diff --git a/Cargo.lock b/Cargo.lock index 0ff3d21d8835c..b20149ab4875d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -10150,7 +10150,6 @@ dependencies = [ "risingwave_compactor", "risingwave_compute", "risingwave_ctl", - "risingwave_expr", "risingwave_expr_impl", "risingwave_frontend", "risingwave_meta_node", @@ -10696,13 +10695,7 @@ version = "1.9.0-alpha" dependencies = [ "anyhow", "arrow-array 50.0.0", - "arrow-flight", "arrow-schema 50.0.0", - "arrow-udf-flight", - "arrow-udf-js", - "arrow-udf-js-deno", - "arrow-udf-python", - "arrow-udf-wasm", "async-trait", "auto_impl", "await-tree", @@ -10718,12 +10711,9 @@ dependencies = [ "futures", "futures-async-stream", "futures-util", - "ginepro", "itertools 0.12.1", "linkme", "madsim-tokio", - "md5", - "moka", "num-traits", "openssl", "parse-display", @@ -10737,10 +10727,8 @@ dependencies = [ "static_assertions", "thiserror", "thiserror-ext", - "tonic 0.10.2", "tracing", "workspace-hack", - "zstd 0.13.0", ] [[package]] @@ -10749,7 +10737,14 @@ version = "1.9.0-alpha" dependencies = [ "aho-corasick", "anyhow", + "arrow-array 50.0.0", + "arrow-flight", "arrow-schema 50.0.0", + "arrow-udf-flight", + "arrow-udf-js", + "arrow-udf-js-deno", + "arrow-udf-python", + "arrow-udf-wasm", "async-trait", "auto_enums", "chrono", @@ -10759,6 +10754,7 @@ dependencies = [ "fancy-regex", "futures-async-stream", "futures-util", + "ginepro", "hex", "icelake", "itertools 0.12.1", @@ -10766,6 +10762,7 @@ dependencies = [ "linkme", "madsim-tokio", "md5", + "moka", "num-traits", "openssl", "regex", @@ -10782,8 +10779,10 @@ dependencies = [ "sql-json-path", "thiserror", "thiserror-ext", + "tonic 0.10.2", "tracing", "workspace-hack", + "zstd 0.13.0", ] [[package]] @@ -10803,8 +10802,6 @@ dependencies = [ "anyhow", "arc-swap", "arrow-schema 50.0.0", - "arrow-udf-flight", - "arrow-udf-wasm", "assert_matches", "async-recursion", "async-trait", diff --git a/Makefile.toml b/Makefile.toml index 7f681626ac16f..b0c8e2c4b993b 100644 --- a/Makefile.toml +++ b/Makefile.toml @@ -35,6 +35,9 @@ is_release = get_env ENABLE_RELEASE_PROFILE is_not_release = not ${is_release} is_dynamic_linking = get_env ENABLE_DYNAMIC_LINKING is_hummock_trace = get_env ENABLE_HUMMOCK_TRACE +is_external_udf_enabled = get_env ENABLE_EXTERNAL_UDF +is_wasm_udf_enabled = get_env ENABLE_WASM_UDF +is_js_udf_enabled = get_env ENABLE_JS_UDF is_deno_udf_enabled = get_env ENABLE_DENO_UDF is_python_udf_enabled = get_env ENABLE_PYTHON_UDF @@ -59,14 +62,29 @@ else set_env RISINGWAVE_FEATURE_FLAGS "--features rw-static-link" end +if ${is_external_udf_enabled} + flags = get_env RISINGWAVE_FEATURE_FLAGS + set_env RISINGWAVE_FEATURE_FLAGS "${flags} --features external-udf" +end + +if ${is_wasm_udf_enabled} + flags = get_env RISINGWAVE_FEATURE_FLAGS + set_env RISINGWAVE_FEATURE_FLAGS "${flags} --features wasm-udf" +end + +if ${is_js_udf_enabled} + flags = get_env RISINGWAVE_FEATURE_FLAGS + set_env RISINGWAVE_FEATURE_FLAGS "${flags} --features js-udf" +end + if ${is_deno_udf_enabled} flags = get_env RISINGWAVE_FEATURE_FLAGS - set_env RISINGWAVE_FEATURE_FLAGS "${flags} --features embedded-deno-udf" + set_env RISINGWAVE_FEATURE_FLAGS "${flags} --features deno-udf" end if ${is_python_udf_enabled} flags = get_env RISINGWAVE_FEATURE_FLAGS - set_env RISINGWAVE_FEATURE_FLAGS "${flags} --features embedded-python-udf" + set_env RISINGWAVE_FEATURE_FLAGS "${flags} --features python-udf" end if ${is_hummock_trace} diff --git a/ci/scripts/build.sh b/ci/scripts/build.sh index 026352908cf2f..bf074f0083925 100755 --- a/ci/scripts/build.sh +++ b/ci/scripts/build.sh @@ -54,8 +54,7 @@ cargo build \ -p risingwave_compaction_test \ -p risingwave_e2e_extended_mode_test \ "${RISINGWAVE_FEATURE_FLAGS[@]}" \ - --features embedded-deno-udf \ - --features embedded-python-udf \ + --features all-udf \ --profile "$profile" \ --timings diff --git a/ci/scripts/release.sh b/ci/scripts/release.sh index ac24e43712c79..77c3b993c0939 100755 --- a/ci/scripts/release.sh +++ b/ci/scripts/release.sh @@ -71,8 +71,8 @@ if [ "${ARCH}" == "aarch64" ]; then # see https://github.com/tikv/jemallocator/blob/802969384ae0c581255f3375ee2ba774c8d2a754/jemalloc-sys/build.rs#L218 export JEMALLOC_SYS_WITH_LG_PAGE=16 fi -cargo build -p risingwave_cmd_all --features "rw-static-link" --profile release -cargo build -p risingwave_cmd --bin risectl --features "rw-static-link" --profile release +cargo build -p risingwave_cmd_all --features "rw-static-link" --features all-udf --profile release +cargo build -p risingwave_cmd --bin risectl --features "rw-static-link" --features all-udf --profile release cd target/release && chmod +x risingwave risectl echo "--- Upload nightly binary to s3" diff --git a/docker/Dockerfile b/docker/Dockerfile index de5276eac9cf8..d022a3032fc66 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -69,7 +69,7 @@ WORKDIR /risingwave ENV ENABLE_BUILD_DASHBOARD=1 RUN cargo fetch && \ - cargo build -p risingwave_cmd_all --release --features "rw-static-link" --features embedded-deno-udf --features embedded-python-udf && \ + cargo build -p risingwave_cmd_all --release --features "rw-static-link" --features all-udf && \ mkdir -p /risingwave/bin && \ mv /risingwave/target/release/risingwave /risingwave/bin/ && \ mv /risingwave/target/release/risingwave.dwp /risingwave/bin/ && \ diff --git a/docker/Dockerfile.hdfs b/docker/Dockerfile.hdfs index 5f6a9c4af1ff4..4558b94b176ab 100644 --- a/docker/Dockerfile.hdfs +++ b/docker/Dockerfile.hdfs @@ -98,7 +98,7 @@ ENV JAVA_HOME ${JAVA_HOME_PATH} ENV LD_LIBRARY_PATH ${JAVA_HOME_PATH}/lib/server:${LD_LIBRARY_PATH} RUN cargo fetch && \ - cargo build -p risingwave_cmd_all --release -p risingwave_object_store --features hdfs-backend --features "rw-static-link" --features embedded-deno-udf --features embedded-python-udf && \ + cargo build -p risingwave_cmd_all --release -p risingwave_object_store --features hdfs-backend --features "rw-static-link" --features all-udf && \ mkdir -p /risingwave/bin && \ mv /risingwave/target/release/risingwave /risingwave/bin/ && \ mv /risingwave/target/release/risingwave.dwp /risingwave/bin/ && \ diff --git a/e2e_test/error_ui/simple/main.slt b/e2e_test/error_ui/simple/main.slt index 6bcbbde608cf8..f77aa3dd9dd6d 100644 --- a/e2e_test/error_ui/simple/main.slt +++ b/e2e_test/error_ui/simple/main.slt @@ -14,9 +14,8 @@ create function int_42() returns int as int_42 using link '555.0.0.1:8815'; db error: ERROR: Failed to run the query Caused by these errors (recent errors listed first): - 1: Expr error - 2: UDF error - 3: Flight service error: invalid address: 555.0.0.1:8815, err: failed to parse address: http://555.0.0.1:8815: invalid IPv4 address + 1: failed to parse address: http://555.0.0.1:8815 + 2: invalid IPv4 address statement error diff --git a/src/cmd_all/Cargo.toml b/src/cmd_all/Cargo.toml index 5ed92b65609cc..d3341b3137d43 100644 --- a/src/cmd_all/Cargo.toml +++ b/src/cmd_all/Cargo.toml @@ -8,11 +8,15 @@ license = { workspace = true } repository = { workspace = true } [features] +default = ["rw-static-link"] rw-static-link = ["workspace-config/rw-static-link"] rw-dynamic-link = ["workspace-config/rw-dynamic-link"] -embedded-deno-udf = ["risingwave_expr/embedded-deno-udf"] -embedded-python-udf = ["risingwave_expr/embedded-python-udf"] -default = ["rw-static-link"] +all-udf = ["external-udf", "wasm-udf", "js-udf", "deno-udf", "python-udf"] +external-udf = ["risingwave_expr_impl/external-udf"] +wasm-udf = ["risingwave_expr_impl/wasm-udf"] +js-udf = ["risingwave_expr_impl/js-udf"] +deno-udf = ["risingwave_expr_impl/deno-udf"] +python-udf = ["risingwave_expr_impl/python-udf"] [package.metadata.cargo-machete] ignored = ["workspace-hack", "workspace-config", "task_stats_alloc"] @@ -32,7 +36,6 @@ risingwave_common = { workspace = true } risingwave_compactor = { workspace = true } risingwave_compute = { workspace = true } risingwave_ctl = { workspace = true } -risingwave_expr = { workspace = true } risingwave_expr_impl = { workspace = true } risingwave_frontend = { workspace = true } risingwave_meta_node = { workspace = true } diff --git a/src/expr/core/Cargo.toml b/src/expr/core/Cargo.toml index c811d81b34658..fe08d74e56065 100644 --- a/src/expr/core/Cargo.toml +++ b/src/expr/core/Cargo.toml @@ -15,20 +15,10 @@ ignored = ["workspace-hack", "ctor"] [package.metadata.cargo-udeps.ignore] normal = ["workspace-hack", "ctor"] -[features] -embedded-deno-udf = ["arrow-udf-js-deno"] -embedded-python-udf = ["arrow-udf-python"] - [dependencies] anyhow = "1" arrow-array = { workspace = true } -arrow-flight = "50" arrow-schema = { workspace = true } -arrow-udf-flight = { workspace = true } -arrow-udf-js = { workspace = true } -arrow-udf-js-deno = { workspace = true, optional = true } -arrow-udf-python = { workspace = true, optional = true } -arrow-udf-wasm = { workspace = true } async-trait = "0.1" auto_impl = "1" await-tree = { workspace = true } @@ -46,11 +36,8 @@ enum-as-inner = "0.6" futures = "0.3" futures-async-stream = { workspace = true } futures-util = "0.3" -ginepro = "0.7" itertools = { workspace = true } linkme = { version = "0.3", features = ["used_linker"] } -md5 = "0.7" -moka = { version = "0.12", features = ["sync"] } num-traits = "0.2" openssl = { version = "0.10", features = ["vendored"] } parse-display = "0.9" @@ -68,9 +55,7 @@ tokio = { version = "0.2", package = "madsim-tokio", features = [ "rt-multi-thread", "macros", ] } -tonic = "0.10" tracing = "0.1" -zstd = { version = "0.13", default-features = false } [target.'cfg(not(madsim))'.dependencies] workspace-hack = { path = "../../workspace-hack" } diff --git a/src/expr/core/src/error.rs b/src/expr/core/src/error.rs index 08562b3a973b7..268fcb5e9753b 100644 --- a/src/expr/core/src/error.rs +++ b/src/expr/core/src/error.rs @@ -95,13 +95,6 @@ pub enum ExprError { anyhow::Error, ), - #[error("UDF error: {0}")] - Udf( - #[from] - #[backtrace] - Box, - ), - #[error("not a constant")] NotConstant, @@ -156,12 +149,6 @@ impl From for ExprError { } } -impl From for ExprError { - fn from(err: arrow_udf_flight::Error) -> Self { - Self::Udf(Box::new(err)) - } -} - /// A collection of multiple errors. #[derive(Error, Debug)] pub struct MultiExprError(Box<[ExprError]>); diff --git a/src/expr/core/src/expr/expr_udf.rs b/src/expr/core/src/expr/expr_udf.rs index 54d3006dc3033..3aa6d5cce5d00 100644 --- a/src/expr/core/src/expr/expr_udf.rs +++ b/src/expr/core/src/expr/expr_udf.rs @@ -12,41 +12,27 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::collections::HashMap; -use std::sync::atomic::{AtomicU8, Ordering}; -use std::sync::{Arc, LazyLock, Weak}; -use std::time::Duration; +use std::sync::{Arc, LazyLock}; -use anyhow::{Context, Error}; -use arrow_array::RecordBatch; +use anyhow::Context; use arrow_schema::{Fields, Schema, SchemaRef}; -use arrow_udf_flight::Client as FlightClient; -use arrow_udf_js::{CallMode as JsCallMode, Runtime as JsRuntime}; -#[cfg(feature = "embedded-deno-udf")] -use arrow_udf_js_deno::{CallMode as DenoCallMode, Runtime as DenoRuntime}; -#[cfg(feature = "embedded-python-udf")] -use arrow_udf_python::{CallMode as PythonCallMode, Runtime as PythonRuntime}; -use arrow_udf_wasm::Runtime as WasmRuntime; use await_tree::InstrumentAwait; -use cfg_or_panic::cfg_or_panic; -use ginepro::{LoadBalancedChannel, ResolutionStrategy}; -use moka::sync::Cache; use prometheus::{ exponential_buckets, register_histogram_vec_with_registry, - register_int_counter_vec_with_registry, HistogramVec, IntCounter, IntCounterVec, Registry, + register_int_counter_vec_with_registry, Histogram, HistogramVec, IntCounter, IntCounterVec, + Registry, }; use risingwave_common::array::arrow::{FromArrow, ToArrow, UdfArrowConvert}; use risingwave_common::array::{Array, ArrayRef, DataChunk}; use risingwave_common::monitor::GLOBAL_METRICS_REGISTRY; use risingwave_common::row::OwnedRow; use risingwave_common::types::{DataType, Datum}; -use risingwave_common::util::addr::HostAddr; use risingwave_expr::expr_context::FRAGMENT_ID; use risingwave_pb::expr::ExprNode; -use thiserror_ext::AsReport; use super::{BoxedExpression, Build}; use crate::expr::Expression; +use crate::sig::{UdfImpl, UdfOptions}; use crate::{bail, ExprError, Result}; #[derive(Debug)] @@ -55,36 +41,10 @@ pub struct UserDefinedFunction { arg_types: Vec, return_type: DataType, arg_schema: SchemaRef, - imp: UdfImpl, - identifier: String, - link: Option, + runtime: Box, arrow_convert: UdfArrowConvert, span: await_tree::Span, - /// Number of remaining successful calls until retry is enabled. - /// This parameter is designed to prevent continuous retry on every call, which would increase delay. - /// Logic: - /// It resets to `INITIAL_RETRY_COUNT` after a single failure and then decrements with each call, enabling retry when it reaches zero. - /// If non-zero, we will not retry on connection errors to prevent blocking the stream. - /// On each connection error, the count will be reset to `INITIAL_RETRY_COUNT`. - /// On each successful call, the count will be decreased by 1. - /// Link: - /// See . - disable_retry_count: AtomicU8, - /// Always retry. Overrides `disable_retry_count`. - always_retry_on_network_error: bool, -} - -const INITIAL_RETRY_COUNT: u8 = 16; - -#[derive(Debug)] -pub enum UdfImpl { - External(Arc), - Wasm(Arc), - JavaScript(JsRuntime), - #[cfg(feature = "embedded-python-udf")] - Python(PythonRuntime), - #[cfg(feature = "embedded-deno-udf")] - Deno(Arc), + metrics: Metrics, } #[async_trait::async_trait] @@ -130,99 +90,29 @@ impl UserDefinedFunction { .to_record_batch(self.arg_schema.clone(), input)?; // metrics - let metrics = &*GLOBAL_METRICS; - // batch query does not have a fragment_id - let fragment_id = FRAGMENT_ID::try_with(ToOwned::to_owned) - .unwrap_or(0) - .to_string(); - let language = match &self.imp { - UdfImpl::Wasm(_) => "wasm", - UdfImpl::JavaScript(_) => "javascript(quickjs)", - #[cfg(feature = "embedded-python-udf")] - UdfImpl::Python(_) => "python", - #[cfg(feature = "embedded-deno-udf")] - UdfImpl::Deno(_) => "javascript(deno)", - UdfImpl::External(_) => "external", - }; - let labels: &[&str; 4] = &[ - self.link.as_deref().unwrap_or(""), - language, - &self.identifier, - fragment_id.as_str(), - ]; - metrics - .udf_input_chunk_rows - .with_label_values(labels) + self.metrics + .input_chunk_rows .observe(arrow_input.num_rows() as f64); - metrics - .udf_input_rows - .with_label_values(labels) + self.metrics + .input_rows .inc_by(arrow_input.num_rows() as u64); - metrics - .udf_input_bytes - .with_label_values(labels) + self.metrics + .input_bytes .inc_by(arrow_input.get_array_memory_size() as u64); - let timer = metrics.udf_latency.with_label_values(labels).start_timer(); + let timer = self.metrics.latency.start_timer(); + + let arrow_output_result = self + .runtime + .call(&arrow_input) + .instrument_await(self.span.clone()) + .await; - let arrow_output_result: Result = match &self.imp { - UdfImpl::Wasm(runtime) => runtime.call(&self.identifier, &arrow_input), - UdfImpl::JavaScript(runtime) => runtime.call(&self.identifier, &arrow_input), - #[cfg(feature = "embedded-python-udf")] - UdfImpl::Python(runtime) => runtime.call(&self.identifier, &arrow_input), - #[cfg(feature = "embedded-deno-udf")] - UdfImpl::Deno(runtime) => tokio::task::block_in_place(|| { - tokio::runtime::Handle::current() - .block_on(runtime.call(&self.identifier, arrow_input)) - }), - UdfImpl::External(client) => { - let disable_retry_count = self.disable_retry_count.load(Ordering::Relaxed); - let result = if self.always_retry_on_network_error { - call_with_always_retry_on_network_error( - client, - &self.identifier, - &arrow_input, - &metrics.udf_retry_count.with_label_values(labels), - ) - .instrument_await(self.span.clone()) - .await - } else { - let result = if disable_retry_count != 0 { - client - .call(&self.identifier, &arrow_input) - .instrument_await(self.span.clone()) - .await - } else { - call_with_retry(client, &self.identifier, &arrow_input) - .instrument_await(self.span.clone()) - .await - }; - let disable_retry_count = self.disable_retry_count.load(Ordering::Relaxed); - let connection_error = matches!(&result, Err(e) if is_connection_error(e)); - if connection_error && disable_retry_count != INITIAL_RETRY_COUNT { - // reset count on connection error - self.disable_retry_count - .store(INITIAL_RETRY_COUNT, Ordering::Relaxed); - } else if !connection_error && disable_retry_count != 0 { - // decrease count on success, ignore if exchange failed - _ = self.disable_retry_count.compare_exchange( - disable_retry_count, - disable_retry_count - 1, - Ordering::Relaxed, - Ordering::Relaxed, - ); - } - result - }; - result.map_err(|e| e.into()) - } - }; timer.stop_and_record(); if arrow_output_result.is_ok() { - &metrics.udf_success_count + &self.metrics.success_count } else { - &metrics.udf_failure_count + &self.metrics.failure_count } - .with_label_values(labels) .inc(); let arrow_output = arrow_output_result?; @@ -269,52 +159,6 @@ impl UserDefinedFunction { } } -/// Call a function, retry up to 5 times / 3s if connection is broken. -async fn call_with_retry( - client: &FlightClient, - id: &str, - input: &RecordBatch, -) -> Result { - let mut backoff = Duration::from_millis(100); - for i in 0..5 { - match client.call(id, input).await { - Err(err) if is_connection_error(&err) && i != 4 => { - tracing::error!(error = %err.as_report(), "UDF connection error. retry..."); - } - ret => return ret, - } - tokio::time::sleep(backoff).await; - backoff *= 2; - } - unreachable!() -} - -/// Always retry on connection error -async fn call_with_always_retry_on_network_error( - client: &FlightClient, - id: &str, - input: &RecordBatch, - retry_count: &IntCounter, -) -> Result { - let mut backoff = Duration::from_millis(100); - loop { - match client.call(id, input).await { - Err(err) if is_tonic_error(&err) => { - tracing::error!(error = %err.as_report(), "UDF tonic error. retry..."); - } - ret => { - if ret.is_err() { - tracing::error!(error = %ret.as_ref().unwrap_err().as_report(), "UDF error. exiting..."); - } - return ret; - } - } - retry_count.inc(); - tokio::time::sleep(backoff).await; - backoff *= 2; - } -} - impl Build for UserDefinedFunction { fn build( prost: &ExprNode, @@ -322,115 +166,29 @@ impl Build for UserDefinedFunction { ) -> Result { let return_type = DataType::from(prost.get_return_type().unwrap()); let udf = prost.get_rex_node().unwrap().as_udf().unwrap(); - let mut arrow_convert = UdfArrowConvert::default(); - - #[cfg(not(feature = "embedded-deno-udf"))] - let runtime = "quickjs"; - - #[cfg(feature = "embedded-deno-udf")] - let runtime = match udf.runtime.as_deref() { - Some("deno") => "deno", - _ => "quickjs", - }; - let identifier = udf.get_identifier()?; - let imp = match udf.language.as_str() { - #[cfg(not(madsim))] - "wasm" | "rust" => { - let compressed_wasm_binary = udf.get_compressed_binary()?; - let wasm_binary = zstd::stream::decode_all(compressed_wasm_binary.as_slice()) - .context("failed to decompress wasm binary")?; - let runtime = get_or_create_wasm_runtime(&wasm_binary)?; - // backward compatibility - // see for details - if runtime.abi_version().0 <= 2 { - arrow_convert = UdfArrowConvert { legacy: true }; - } - UdfImpl::Wasm(runtime) - } - "javascript" if runtime != "deno" => { - let mut rt = JsRuntime::new()?; - let body = format!( - "export function {}({}) {{ {} }}", - identifier, - udf.arg_names.join(","), - udf.get_body()? - ); - rt.add_function( - identifier, - arrow_convert.to_arrow_field("", &return_type)?, - JsCallMode::CalledOnNullInput, - &body, - )?; - UdfImpl::JavaScript(rt) - } - #[cfg(feature = "embedded-deno-udf")] - "javascript" if runtime == "deno" => { - let rt = DenoRuntime::new(); - let body = match udf.get_body() { - Ok(body) => body.clone(), - Err(_) => match udf.get_compressed_binary() { - Ok(compressed_binary) => { - let binary = zstd::stream::decode_all(compressed_binary.as_slice()) - .context("failed to decompress binary")?; - String::from_utf8(binary).context("failed to decode binary")? - } - Err(_) => { - bail!("UDF body or compressed binary is required for deno UDF"); - } - }, - }; - - let body = if matches!(udf.function_type.as_deref(), Some("async")) { - format!( - "export async function {}({}) {{ {} }}", - identifier, - udf.arg_names.join(","), - body - ) - } else { - format!( - "export function {}({}) {{ {} }}", - identifier, - udf.arg_names.join(","), - body - ) - }; - futures::executor::block_on(rt.add_function( - identifier, - arrow_convert.to_arrow_field("", &return_type)?, - DenoCallMode::CalledOnNullInput, - &body, - ))?; + let language = udf.language.as_str(); + let runtime = udf.runtime.as_deref(); + let link = udf.link.as_deref(); + + // lookup UDF builder + let build_fn = crate::sig::find_udf_impl(language, runtime, link)?.build_fn; + let runtime = build_fn(UdfOptions { + table_function: false, + body: udf.body.as_deref(), + compressed_binary: udf.compressed_binary.as_deref(), + link: udf.link.as_deref(), + identifier, + arg_names: &udf.arg_names, + return_type: &return_type, + always_retry_on_network_error: udf.always_retry_on_network_error, + function_type: udf.function_type.as_deref(), + }) + .context("failed to build UDF runtime")?; - UdfImpl::Deno(rt) - } - #[cfg(feature = "embedded-python-udf")] - "python" if udf.body.is_some() => { - let mut rt = PythonRuntime::builder().sandboxed(true).build()?; - let body = udf.get_body()?; - rt.add_function( - identifier, - arrow_convert.to_arrow_field("", &return_type)?, - PythonCallMode::CalledOnNullInput, - body, - )?; - UdfImpl::Python(rt) - } - #[cfg(not(madsim))] - _ => { - let link = udf.get_link()?; - let client = get_or_create_flight_client(link)?; - // backward compatibility - // see for details - if client.protocol_version() == 1 { - arrow_convert = UdfArrowConvert { legacy: true }; - } - UdfImpl::External(client) - } - #[cfg(madsim)] - l => panic!("UDF language {l:?} is not supported on madsim"), + let arrow_convert = UdfArrowConvert { + legacy: runtime.is_legacy(), }; let arg_schema = Arc::new(Schema::new( @@ -440,173 +198,98 @@ impl Build for UserDefinedFunction { .try_collect::()?, )); + // batch query does not have a fragment_id + let fragment_id = FRAGMENT_ID::try_with(ToOwned::to_owned) + .unwrap_or(0) + .to_string(); + let labels: &[&str; 4] = &[ + udf.link.as_deref().unwrap_or(""), + language, + identifier, + fragment_id.as_str(), + ]; + Ok(Self { children: udf.children.iter().map(build_child).try_collect()?, arg_types: udf.arg_types.iter().map(|t| t.into()).collect(), return_type, arg_schema, - imp, - identifier: identifier.clone(), - link: udf.link.clone(), + runtime, arrow_convert, span: format!("udf_call({})", identifier).into(), - disable_retry_count: AtomicU8::new(0), - always_retry_on_network_error: udf.always_retry_on_network_error, + metrics: GLOBAL_METRICS.with_label_values(labels), }) } } -#[cfg_or_panic(not(madsim))] -/// Get or create a client for the given UDF service. -/// -/// There is a global cache for clients, so that we can reuse the same client for the same service. -pub fn get_or_create_flight_client(link: &str) -> Result> { - static CLIENTS: LazyLock>>> = - LazyLock::new(Default::default); - let mut clients = CLIENTS.lock().unwrap(); - if let Some(client) = clients.get(link).and_then(|c| c.upgrade()) { - // reuse existing client - Ok(client) - } else { - // create new client - let client = Arc::new(tokio::task::block_in_place(|| { - tokio::runtime::Handle::current().block_on(async { - let channel = connect_tonic(link).await?; - Ok(FlightClient::new(channel).await?) as Result<_> - }) - })?); - clients.insert(link.to_owned(), Arc::downgrade(&client)); - Ok(client) - } -} - -/// Connect to a UDF service and return a tonic `Channel`. -async fn connect_tonic(mut addr: &str) -> Result { - // Interval between two successive probes of the UDF DNS. - const DNS_PROBE_INTERVAL_SECS: u64 = 5; - // Timeout duration for performing an eager DNS resolution. - const EAGER_DNS_RESOLVE_TIMEOUT_SECS: u64 = 5; - const REQUEST_TIMEOUT_SECS: u64 = 5; - const CONNECT_TIMEOUT_SECS: u64 = 5; - - if let Some(s) = addr.strip_prefix("http://") { - addr = s; - } - if let Some(s) = addr.strip_prefix("https://") { - addr = s; - } - let host_addr = addr.parse::().map_err(|e| { - arrow_udf_flight::Error::Service(format!( - "invalid address: {}, err: {}", - addr, - e.as_report() - )) - })?; - let channel = LoadBalancedChannel::builder((host_addr.host.clone(), host_addr.port)) - .dns_probe_interval(std::time::Duration::from_secs(DNS_PROBE_INTERVAL_SECS)) - .timeout(Duration::from_secs(REQUEST_TIMEOUT_SECS)) - .connect_timeout(Duration::from_secs(CONNECT_TIMEOUT_SECS)) - .resolution_strategy(ResolutionStrategy::Eager { - timeout: tokio::time::Duration::from_secs(EAGER_DNS_RESOLVE_TIMEOUT_SECS), - }) - .channel() - .await - .map_err(|e| { - arrow_udf_flight::Error::Service(format!( - "failed to create LoadBalancedChannel, address: {}, err: {}", - host_addr, - e.as_report() - )) - })?; - Ok(channel.into()) -} - -/// Get or create a wasm runtime. -/// -/// Runtimes returned by this function are cached inside for at least 60 seconds. -/// Later calls with the same binary will reuse the same runtime. -#[cfg_or_panic(not(madsim))] -pub fn get_or_create_wasm_runtime(binary: &[u8]) -> Result> { - static RUNTIMES: LazyLock>> = LazyLock::new(|| { - Cache::builder() - .time_to_idle(Duration::from_secs(60)) - .build() - }); - - let md5 = md5::compute(binary); - if let Some(runtime) = RUNTIMES.get(&md5) { - return Ok(runtime.clone()); - } - - let runtime = Arc::new(arrow_udf_wasm::Runtime::new(binary)?); - RUNTIMES.insert(md5, runtime.clone()); - Ok(runtime) -} - -/// Returns true if the arrow flight error is caused by a connection error. -fn is_connection_error(err: &arrow_udf_flight::Error) -> bool { - match err { - // Connection refused - arrow_udf_flight::Error::Tonic(status) if status.code() == tonic::Code::Unavailable => true, - _ => false, - } -} - -fn is_tonic_error(err: &arrow_udf_flight::Error) -> bool { - matches!( - err, - arrow_udf_flight::Error::Tonic(_) - | arrow_udf_flight::Error::Flight(arrow_flight::error::FlightError::Tonic(_)) - ) +/// Monitor metrics for UDF. +#[derive(Debug, Clone)] +struct MetricsVec { + /// Number of successful UDF calls. + success_count: IntCounterVec, + /// Number of failed UDF calls. + failure_count: IntCounterVec, + /// Total number of retried UDF calls. + retry_count: IntCounterVec, + /// Input chunk rows of UDF calls. + input_chunk_rows: HistogramVec, + /// The latency of UDF calls in seconds. + latency: HistogramVec, + /// Total number of input rows of UDF calls. + input_rows: IntCounterVec, + /// Total number of input bytes of UDF calls. + input_bytes: IntCounterVec, } /// Monitor metrics for UDF. #[derive(Debug, Clone)] struct Metrics { /// Number of successful UDF calls. - udf_success_count: IntCounterVec, + success_count: IntCounter, /// Number of failed UDF calls. - udf_failure_count: IntCounterVec, + failure_count: IntCounter, /// Total number of retried UDF calls. - udf_retry_count: IntCounterVec, + #[allow(dead_code)] + retry_count: IntCounter, /// Input chunk rows of UDF calls. - udf_input_chunk_rows: HistogramVec, + input_chunk_rows: Histogram, /// The latency of UDF calls in seconds. - udf_latency: HistogramVec, + latency: Histogram, /// Total number of input rows of UDF calls. - udf_input_rows: IntCounterVec, + input_rows: IntCounter, /// Total number of input bytes of UDF calls. - udf_input_bytes: IntCounterVec, + input_bytes: IntCounter, } /// Global UDF metrics. -static GLOBAL_METRICS: LazyLock = LazyLock::new(|| Metrics::new(&GLOBAL_METRICS_REGISTRY)); +static GLOBAL_METRICS: LazyLock = + LazyLock::new(|| MetricsVec::new(&GLOBAL_METRICS_REGISTRY)); -impl Metrics { +impl MetricsVec { fn new(registry: &Registry) -> Self { let labels = &["link", "language", "name", "fragment_id"]; - let udf_success_count = register_int_counter_vec_with_registry!( + let success_count = register_int_counter_vec_with_registry!( "udf_success_count", "Total number of successful UDF calls", labels, registry ) .unwrap(); - let udf_failure_count = register_int_counter_vec_with_registry!( + let failure_count = register_int_counter_vec_with_registry!( "udf_failure_count", "Total number of failed UDF calls", labels, registry ) .unwrap(); - let udf_retry_count = register_int_counter_vec_with_registry!( + let retry_count = register_int_counter_vec_with_registry!( "udf_retry_count", "Total number of retried UDF calls", labels, registry ) .unwrap(); - let udf_input_chunk_rows = register_histogram_vec_with_registry!( + let input_chunk_rows = register_histogram_vec_with_registry!( "udf_input_chunk_rows", "Input chunk rows of UDF calls", labels, @@ -614,7 +297,7 @@ impl Metrics { registry ) .unwrap(); - let udf_latency = register_histogram_vec_with_registry!( + let latency = register_histogram_vec_with_registry!( "udf_latency", "The latency(s) of UDF calls", labels, @@ -622,14 +305,14 @@ impl Metrics { registry ) .unwrap(); - let udf_input_rows = register_int_counter_vec_with_registry!( + let input_rows = register_int_counter_vec_with_registry!( "udf_input_rows", "Total number of input rows of UDF calls", labels, registry ) .unwrap(); - let udf_input_bytes = register_int_counter_vec_with_registry!( + let input_bytes = register_int_counter_vec_with_registry!( "udf_input_bytes", "Total number of input bytes of UDF calls", labels, @@ -637,14 +320,26 @@ impl Metrics { ) .unwrap(); + MetricsVec { + success_count, + failure_count, + retry_count, + input_chunk_rows, + latency, + input_rows, + input_bytes, + } + } + + fn with_label_values(&self, values: &[&str; 4]) -> Metrics { Metrics { - udf_success_count, - udf_failure_count, - udf_retry_count, - udf_input_chunk_rows, - udf_latency, - udf_input_rows, - udf_input_bytes, + success_count: self.success_count.with_label_values(values), + failure_count: self.failure_count.with_label_values(values), + retry_count: self.retry_count.with_label_values(values), + input_chunk_rows: self.input_chunk_rows.with_label_values(values), + latency: self.latency.with_label_values(values), + input_rows: self.input_rows.with_label_values(values), + input_bytes: self.input_bytes.with_label_values(values), } } } diff --git a/src/expr/core/src/expr/mod.rs b/src/expr/core/src/expr/mod.rs index 9188ced21d111..3f4b07d86bedf 100644 --- a/src/expr/core/src/expr/mod.rs +++ b/src/expr/core/src/expr/mod.rs @@ -51,7 +51,6 @@ use risingwave_common::types::{DataType, Datum}; pub use self::build::*; pub use self::expr_input_ref::InputRefExpression; pub use self::expr_literal::LiteralExpression; -pub use self::expr_udf::{get_or_create_flight_client, get_or_create_wasm_runtime}; pub use self::value::{ValueImpl, ValueRef}; pub use self::wrapper::*; pub use super::{ExprError, Result}; diff --git a/src/expr/core/src/sig/mod.rs b/src/expr/core/src/sig/mod.rs index 747779f81e990..124a002f6519e 100644 --- a/src/expr/core/src/sig/mod.rs +++ b/src/expr/core/src/sig/mod.rs @@ -30,6 +30,10 @@ use crate::expr::BoxedExpression; use crate::table_function::BoxedTableFunction; use crate::ExprError; +mod udf; + +pub use self::udf::*; + /// The global registry of all function signatures. pub static FUNCTION_REGISTRY: LazyLock = LazyLock::new(|| { let mut map = FunctionRegistry::default(); diff --git a/src/expr/core/src/sig/udf.rs b/src/expr/core/src/sig/udf.rs new file mode 100644 index 0000000000000..efc3c42b87614 --- /dev/null +++ b/src/expr/core/src/sig/udf.rs @@ -0,0 +1,130 @@ +// Copyright 2024 RisingWave Labs +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//! UDF implementation interface. +//! +//! To support a new language or runtime for UDF, implement the interface in this module. +//! +//! See expr/impl/src/udf for the implementations. + +use anyhow::{bail, Context, Result}; +use arrow_array::RecordBatch; +use futures::stream::BoxStream; +use risingwave_common::types::DataType; + +/// The global registry of UDF implementations. +/// +/// To register a new UDF implementation: +/// +/// ```ignore +/// #[linkme::distributed_slice(UDF_IMPLS)] +/// static MY_UDF_LANGUAGE: UdfImplDescriptor = UdfImplDescriptor {...}; +/// ``` +#[linkme::distributed_slice] +pub static UDF_IMPLS: [UdfImplDescriptor]; + +/// Find a UDF implementation by language. +pub fn find_udf_impl( + language: &str, + runtime: Option<&str>, + link: Option<&str>, +) -> Result<&'static UdfImplDescriptor> { + let mut impls = UDF_IMPLS + .iter() + .filter(|desc| (desc.match_fn)(language, runtime, link)); + let impl_ = impls.next().context( + "language not found.\nHINT: UDF feature flag may not be enabled during compilation", + )?; + if impls.next().is_some() { + bail!("multiple UDF implementations found for language: {language}"); + } + Ok(impl_) +} + +/// UDF implementation descriptor. +/// +/// Every UDF implementation should provide 3 functions: +pub struct UdfImplDescriptor { + /// Returns if a function matches the implementation. + /// + /// This function is used to determine which implementation to use for a UDF. + pub match_fn: fn(language: &str, runtime: Option<&str>, link: Option<&str>) -> bool, + + /// Creates a function from options. + /// + /// This function will be called when `create function` statement is executed on the frontend. + pub create_fn: fn(opts: CreateFunctionOptions<'_>) -> Result, + + /// Builds UDF runtime from verified options. + /// + /// This function will be called before the UDF is executed on the backend. + pub build_fn: fn(opts: UdfOptions<'_>) -> Result>, +} + +/// Options for creating a function. +/// +/// These information are parsed from `CREATE FUNCTION` statement. +/// Implementations should verify the options and return a `CreateFunctionOutput` in `create_fn`. +pub struct CreateFunctionOptions<'a> { + pub name: &'a str, + pub arg_names: &'a [String], + pub arg_types: &'a [DataType], + pub return_type: &'a DataType, + pub is_table_function: bool, + pub as_: Option<&'a str>, + pub using_link: Option<&'a str>, + pub using_base64_decoded: Option<&'a [u8]>, +} + +/// Output of creating a function. +pub struct CreateFunctionOutput { + pub identifier: String, + pub body: Option, + pub compressed_binary: Option>, +} + +/// Options for building a UDF runtime. +pub struct UdfOptions<'a> { + pub table_function: bool, + pub body: Option<&'a str>, + pub compressed_binary: Option<&'a [u8]>, + pub link: Option<&'a str>, + pub identifier: &'a str, + pub arg_names: &'a [String], + pub return_type: &'a DataType, + pub always_retry_on_network_error: bool, + pub function_type: Option<&'a str>, +} + +/// UDF implementation. +#[async_trait::async_trait] +pub trait UdfImpl: std::fmt::Debug + Send + Sync { + /// Call the scalar function. + async fn call(&self, input: &RecordBatch) -> Result; + + /// Call the table function. + async fn call_table_function<'a>( + &'a self, + input: &'a RecordBatch, + ) -> Result>>; + + /// Whether the UDF talks in legacy mode. + /// + /// If true, decimal and jsonb types are mapped to Arrow `LargeBinary` and `LargeUtf8` types. + /// Otherwise, they are mapped to Arrow extension types. + /// See . + fn is_legacy(&self) -> bool { + false + } +} diff --git a/src/expr/core/src/table_function/user_defined.rs b/src/expr/core/src/table_function/user_defined.rs index 79b14a126f10d..9df12a0aa8d8e 100644 --- a/src/expr/core/src/table_function/user_defined.rs +++ b/src/expr/core/src/table_function/user_defined.rs @@ -15,28 +15,21 @@ use std::sync::Arc; use anyhow::Context; -use arrow_array::RecordBatch; use arrow_schema::{Fields, Schema, SchemaRef}; -use arrow_udf_js::{CallMode as JsCallMode, Runtime as JsRuntime}; -#[cfg(feature = "embedded-deno-udf")] -use arrow_udf_js_deno::{CallMode as DenoCallMode, Runtime as DenoRuntime}; -#[cfg(feature = "embedded-python-udf")] -use arrow_udf_python::{CallMode as PythonCallMode, Runtime as PythonRuntime}; use cfg_or_panic::cfg_or_panic; use risingwave_common::array::arrow::{FromArrow, ToArrow, UdfArrowConvert}; use risingwave_common::array::I32Array; use risingwave_common::bail; use super::*; -use crate::expr::expr_udf::UdfImpl; +use crate::sig::{UdfImpl, UdfOptions}; #[derive(Debug)] pub struct UserDefinedTableFunction { children: Vec, arg_schema: SchemaRef, return_type: DataType, - client: UdfImpl, - identifier: String, + runtime: Box, arrow_convert: UdfArrowConvert, #[allow(dead_code)] chunk_size: usize, @@ -54,44 +47,6 @@ impl TableFunction for UserDefinedTableFunction { } } -#[cfg(not(madsim))] -impl UdfImpl { - #[try_stream(ok = RecordBatch, error = ExprError)] - async fn call_table_function<'a>(&'a self, identifier: &'a str, input: RecordBatch) { - match self { - UdfImpl::External(client) => { - #[for_await] - for res in client.call_table_function(identifier, &input).await? { - yield res?; - } - } - UdfImpl::JavaScript(runtime) => { - for res in runtime.call_table_function(identifier, &input, 1024)? { - yield res?; - } - } - #[cfg(feature = "embedded-python-udf")] - UdfImpl::Python(runtime) => { - for res in runtime.call_table_function(identifier, &input, 1024)? { - yield res?; - } - } - #[cfg(feature = "embedded-deno-udf")] - UdfImpl::Deno(runtime) => { - let mut iter = runtime.call_table_function(identifier, input, 1024).await?; - while let Some(res) = iter.next().await { - yield res?; - } - } - UdfImpl::Wasm(runtime) => { - for res in runtime.call_table_function(identifier, &input)? { - yield res?; - } - } - } - } -} - #[cfg(not(madsim))] impl UserDefinedTableFunction { #[try_stream(boxed, ok = DataChunk, error = ExprError)] @@ -113,10 +68,7 @@ impl UserDefinedTableFunction { // call UDTF #[for_await] - for res in self - .client - .call_table_function(&self.identifier, arrow_input) - { + for res in self.runtime.call_table_function(&arrow_input).await? { let output = self.arrow_convert.from_record_batch(&res?)?; self.check_output(&output)?; @@ -180,108 +132,27 @@ pub fn new_user_defined(prost: &PbTableFunction, chunk_size: usize) -> Result "deno", - _ => "quickjs", - }; - - let mut arrow_convert = UdfArrowConvert::default(); - - let client = match udtf.language.as_str() { - "wasm" | "rust" => { - let compressed_wasm_binary = udtf.get_compressed_binary()?; - let wasm_binary = zstd::stream::decode_all(compressed_wasm_binary.as_slice()) - .context("failed to decompress wasm binary")?; - let runtime = crate::expr::expr_udf::get_or_create_wasm_runtime(&wasm_binary)?; - // backward compatibility - if runtime.abi_version().0 <= 2 { - arrow_convert = UdfArrowConvert { legacy: true }; - } - UdfImpl::Wasm(runtime) - } - "javascript" if runtime != "deno" => { - let mut rt = JsRuntime::new()?; - let body = format!( - "export function* {}({}) {{ {} }}", - identifier, - udtf.arg_names.join(","), - udtf.get_body()? - ); - rt.add_function( - identifier, - arrow_convert.to_arrow_field("", &return_type)?, - JsCallMode::CalledOnNullInput, - &body, - )?; - UdfImpl::JavaScript(rt) - } - #[cfg(feature = "embedded-deno-udf")] - "javascript" if runtime == "deno" => { - let rt = DenoRuntime::new(); - let body = match udtf.get_body() { - Ok(body) => body.clone(), - Err(_) => match udtf.get_compressed_binary() { - Ok(compressed_binary) => { - let binary = zstd::stream::decode_all(compressed_binary.as_slice()) - .context("failed to decompress binary")?; - String::from_utf8(binary).context("failed to decode binary")? - } - Err(_) => { - bail!("UDF body or compressed binary is required for deno UDF"); - } - }, - }; - - let body = format!( - "export {} {}({}) {{ {} }}", - match udtf.function_type.as_deref() { - Some("async") => "async function", - Some("async_generator") => "async function*", - Some("sync") => "function", - _ => "function*", - }, - identifier, - udtf.arg_names.join(","), - body - ); - - futures::executor::block_on(rt.add_function( - identifier, - arrow_convert.to_arrow_field("", &return_type)?, - DenoCallMode::CalledOnNullInput, - &body, - ))?; - UdfImpl::Deno(rt) - } - #[cfg(feature = "embedded-python-udf")] - "python" if udtf.body.is_some() => { - let mut rt = PythonRuntime::builder().sandboxed(true).build()?; - let body = udtf.get_body()?; - rt.add_function( - identifier, - arrow_convert.to_arrow_field("", &return_type)?, - PythonCallMode::CalledOnNullInput, - body, - )?; - UdfImpl::Python(rt) - } - // connect to UDF service - _ => { - let link = udtf.get_link()?; - let client = crate::expr::expr_udf::get_or_create_flight_client(link)?; - // backward compatibility - // see for details - if client.protocol_version() == 1 { - arrow_convert = UdfArrowConvert { legacy: true }; - } - UdfImpl::External(client) - } + let language = udtf.language.as_str(); + let runtime = udtf.runtime.as_deref(); + let link = udtf.link.as_deref(); + + let build_fn = crate::sig::find_udf_impl(language, runtime, link)?.build_fn; + let runtime = build_fn(UdfOptions { + table_function: true, + body: udtf.body.as_deref(), + compressed_binary: udtf.compressed_binary.as_deref(), + link: udtf.link.as_deref(), + identifier, + arg_names: &udtf.arg_names, + return_type: &return_type, + always_retry_on_network_error: false, + function_type: udtf.function_type.as_deref(), + }) + .context("failed to build UDF runtime")?; + + let arrow_convert = UdfArrowConvert { + legacy: runtime.is_legacy(), }; - let arg_schema = Arc::new(Schema::new( udtf.arg_types .iter() @@ -293,8 +164,7 @@ pub fn new_user_defined(prost: &PbTableFunction, chunk_size: usize) -> Result body = Some(as_.to_string()), + (Some(link), None, None) => { + let bytes = read_file_from_link(link)?; + compressed_binary = Some(zstd::stream::encode_all(bytes.as_slice(), 0)?); + } + (None, Some(bytes), None) => { + compressed_binary = Some(zstd::stream::encode_all(bytes, 0)?); + } + (None, None, None) => bail!("Either USING or AS must be specified"), + _ => bail!("Both USING and AS cannot be specified"), + } + Ok(CreateFunctionOutput { + identifier, + body, + compressed_binary, + }) + }, + build_fn: |opts| { + let runtime = Runtime::new(); + let body = match (opts.body, opts.compressed_binary) { + (Some(body), _) => body.to_string(), + (_, Some(compressed_binary)) => { + let binary = zstd::stream::decode_all(compressed_binary) + .context("failed to decompress binary")?; + String::from_utf8(binary).context("failed to decode binary")? + } + _ => bail!("UDF body or compressed binary is required for deno UDF"), + }; + + let body = format!( + "export {} {}({}) {{ {} }}", + match opts.function_type { + Some("sync") => "function", + Some("async") => "async function", + Some("generator") => "function*", + Some("async_generator") => "async function*", + _ if opts.table_function => "function*", + _ => "function", + }, + opts.identifier, + opts.arg_names.join(","), + body + ); + + tokio::task::block_in_place(|| { + tokio::runtime::Handle::current().block_on(runtime.add_function( + opts.identifier, + UdfArrowConvert::default().to_arrow_field("", opts.return_type)?, + CallMode::CalledOnNullInput, + &body, + )) + })?; + + Ok(Box::new(DenoFunction { + runtime, + identifier: opts.identifier.to_string(), + })) + }, +}; + +#[derive(Debug)] +struct DenoFunction { + runtime: Arc, + identifier: String, +} + +#[async_trait::async_trait] +impl UdfImpl for DenoFunction { + async fn call(&self, input: &RecordBatch) -> Result { + // FIXME(runji): make the future Send + tokio::task::block_in_place(|| { + tokio::runtime::Handle::current() + .block_on(self.runtime.call(&self.identifier, input.clone())) + }) + } + + async fn call_table_function<'a>( + &'a self, + input: &'a RecordBatch, + ) -> Result>> { + Ok(self.call_table_function_inner(input.clone())) + } +} + +impl DenoFunction { + #[try_stream(boxed, ok = RecordBatch, error = anyhow::Error)] + async fn call_table_function_inner<'a>(&'a self, input: RecordBatch) { + let mut stream = self + .runtime + .call_table_function(&self.identifier, input, 1024) + .await?; + while let Some(batch) = stream.next().await { + yield batch?; + } + } +} diff --git a/src/expr/impl/src/udf/external.rs b/src/expr/impl/src/udf/external.rs new file mode 100644 index 0000000000000..53fe4feb4402f --- /dev/null +++ b/src/expr/impl/src/udf/external.rs @@ -0,0 +1,295 @@ +// Copyright 2024 RisingWave Labs +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::collections::HashMap; +use std::sync::atomic::{AtomicU8, Ordering}; +use std::sync::{Arc, LazyLock, Weak}; +use std::time::Duration; + +use anyhow::bail; +use arrow_schema::Fields; +use arrow_udf_flight::Client; +use futures_util::{StreamExt, TryStreamExt}; +use ginepro::{LoadBalancedChannel, ResolutionStrategy}; +use risingwave_common::array::arrow::{ToArrow, UdfArrowConvert}; +use risingwave_common::util::addr::HostAddr; +use thiserror_ext::AsReport; + +use super::*; + +#[linkme::distributed_slice(UDF_IMPLS)] +static EXTERNAL: UdfImplDescriptor = UdfImplDescriptor { + match_fn: |language, _runtime, link| { + link.is_some() && matches!(language, "python" | "java" | "") + }, + create_fn: |opts| { + let link = opts.using_link.context("USING LINK must be specified")?; + let identifier = opts.as_.context("AS must be specified")?.to_string(); + + // check UDF server + let client = get_or_create_flight_client(link)?; + let convert = UdfArrowConvert { + legacy: client.protocol_version() == 1, + }; + // A helper function to create a unnamed field from data type. + let to_field = |data_type| convert.to_arrow_field("", data_type); + let args = arrow_schema::Schema::new( + opts.arg_types + .iter() + .map(to_field) + .try_collect::()?, + ); + let returns = arrow_schema::Schema::new(if opts.is_table_function { + vec![ + arrow_schema::Field::new("row", arrow_schema::DataType::Int32, true), + to_field(opts.return_type)?, + ] + } else { + vec![to_field(opts.return_type)?] + }); + let function = tokio::task::block_in_place(|| { + tokio::runtime::Handle::current().block_on(client.get(&identifier)) + }) + .context("failed to check UDF signature")?; + if !data_types_match(&function.args, &args) { + bail!( + "argument type mismatch, expect: {:?}, actual: {:?}", + args, + function.args, + ); + } + if !data_types_match(&function.returns, &returns) { + bail!( + "return type mismatch, expect: {:?}, actual: {:?}", + returns, + function.returns, + ); + } + Ok(CreateFunctionOutput { + identifier, + body: None, + compressed_binary: None, + }) + }, + build_fn: |opts| { + let link = opts.link.context("link is required")?; + let client = get_or_create_flight_client(link)?; + Ok(Box::new(ExternalFunction { + identifier: opts.identifier.to_string(), + client, + disable_retry_count: AtomicU8::new(INITIAL_RETRY_COUNT), + always_retry_on_network_error: opts.always_retry_on_network_error, + })) + }, +}; + +#[derive(Debug)] +struct ExternalFunction { + identifier: String, + client: Arc, + /// Number of remaining successful calls until retry is enabled. + /// This parameter is designed to prevent continuous retry on every call, which would increase delay. + /// Logic: + /// It resets to `INITIAL_RETRY_COUNT` after a single failure and then decrements with each call, enabling retry when it reaches zero. + /// If non-zero, we will not retry on connection errors to prevent blocking the stream. + /// On each connection error, the count will be reset to `INITIAL_RETRY_COUNT`. + /// On each successful call, the count will be decreased by 1. + /// Link: + /// See . + disable_retry_count: AtomicU8, + /// Always retry. Overrides `disable_retry_count`. + always_retry_on_network_error: bool, +} + +const INITIAL_RETRY_COUNT: u8 = 16; + +#[async_trait::async_trait] +impl UdfImpl for ExternalFunction { + fn is_legacy(&self) -> bool { + // see for details + self.client.protocol_version() == 1 + } + + async fn call(&self, input: &RecordBatch) -> Result { + let disable_retry_count = self.disable_retry_count.load(Ordering::Relaxed); + let result = if self.always_retry_on_network_error { + self.call_with_always_retry_on_network_error( + input, + // &metrics.udf_retry_count.with_label_values(labels), + ) + .await + } else { + let result = if disable_retry_count != 0 { + self.client.call(&self.identifier, input).await + } else { + self.call_with_retry(input).await + }; + let disable_retry_count = self.disable_retry_count.load(Ordering::Relaxed); + let connection_error = matches!(&result, Err(e) if is_connection_error(e)); + if connection_error && disable_retry_count != INITIAL_RETRY_COUNT { + // reset count on connection error + self.disable_retry_count + .store(INITIAL_RETRY_COUNT, Ordering::Relaxed); + } else if !connection_error && disable_retry_count != 0 { + // decrease count on success, ignore if exchange failed + _ = self.disable_retry_count.compare_exchange( + disable_retry_count, + disable_retry_count - 1, + Ordering::Relaxed, + Ordering::Relaxed, + ); + } + result + }; + result.map_err(|e| e.into()) + } + + async fn call_table_function<'a>( + &'a self, + input: &'a RecordBatch, + ) -> Result>> { + let stream = self + .client + .call_table_function(&self.identifier, input) + .await?; + Ok(stream.map_err(|e| e.into()).boxed()) + } +} + +/// Get or create a client for the given UDF service. +/// +/// There is a global cache for clients, so that we can reuse the same client for the same service. +fn get_or_create_flight_client(link: &str) -> Result> { + static CLIENTS: LazyLock>>> = + LazyLock::new(Default::default); + let mut clients = CLIENTS.lock().unwrap(); + if let Some(client) = clients.get(link).and_then(|c| c.upgrade()) { + // reuse existing client + Ok(client) + } else { + // create new client + let client = Arc::new(tokio::task::block_in_place(|| { + tokio::runtime::Handle::current().block_on(async { + let channel = connect_tonic(link).await?; + Ok(Client::new(channel).await?) as Result<_> + }) + })?); + clients.insert(link.to_owned(), Arc::downgrade(&client)); + Ok(client) + } +} + +/// Connect to a UDF service and return a tonic `Channel`. +async fn connect_tonic(mut addr: &str) -> Result { + // Interval between two successive probes of the UDF DNS. + const DNS_PROBE_INTERVAL_SECS: u64 = 5; + // Timeout duration for performing an eager DNS resolution. + const EAGER_DNS_RESOLVE_TIMEOUT_SECS: u64 = 5; + const REQUEST_TIMEOUT_SECS: u64 = 5; + const CONNECT_TIMEOUT_SECS: u64 = 5; + + if let Some(s) = addr.strip_prefix("http://") { + addr = s; + } + if let Some(s) = addr.strip_prefix("https://") { + addr = s; + } + let host_addr = addr.parse::()?; + let channel = LoadBalancedChannel::builder((host_addr.host.clone(), host_addr.port)) + .dns_probe_interval(std::time::Duration::from_secs(DNS_PROBE_INTERVAL_SECS)) + .timeout(Duration::from_secs(REQUEST_TIMEOUT_SECS)) + .connect_timeout(Duration::from_secs(CONNECT_TIMEOUT_SECS)) + .resolution_strategy(ResolutionStrategy::Eager { + timeout: tokio::time::Duration::from_secs(EAGER_DNS_RESOLVE_TIMEOUT_SECS), + }) + .channel() + .await + .with_context(|| format!("failed to create LoadBalancedChannel, address: {host_addr}"))?; + Ok(channel.into()) +} + +impl ExternalFunction { + /// Call a function, retry up to 5 times / 3s if connection is broken. + async fn call_with_retry( + &self, + input: &RecordBatch, + ) -> Result { + let mut backoff = Duration::from_millis(100); + for i in 0..5 { + match self.client.call(&self.identifier, input).await { + Err(err) if is_connection_error(&err) && i != 4 => { + tracing::error!(error = %err.as_report(), "UDF connection error. retry..."); + } + ret => return ret, + } + tokio::time::sleep(backoff).await; + backoff *= 2; + } + unreachable!() + } + + /// Always retry on connection error + async fn call_with_always_retry_on_network_error( + &self, + input: &RecordBatch, + // retry_count: &IntCounter, + ) -> Result { + let mut backoff = Duration::from_millis(100); + loop { + match self.client.call(&self.identifier, input).await { + Err(err) if is_tonic_error(&err) => { + tracing::error!(error = %err.as_report(), "UDF tonic error. retry..."); + } + ret => { + if ret.is_err() { + tracing::error!(error = %ret.as_ref().unwrap_err().as_report(), "UDF error. exiting..."); + } + return ret; + } + } + // retry_count.inc(); + tokio::time::sleep(backoff).await; + backoff *= 2; + } + } +} + +/// Returns true if the arrow flight error is caused by a connection error. +fn is_connection_error(err: &arrow_udf_flight::Error) -> bool { + match err { + // Connection refused + arrow_udf_flight::Error::Tonic(status) if status.code() == tonic::Code::Unavailable => true, + _ => false, + } +} + +fn is_tonic_error(err: &arrow_udf_flight::Error) -> bool { + matches!( + err, + arrow_udf_flight::Error::Tonic(_) + | arrow_udf_flight::Error::Flight(arrow_flight::error::FlightError::Tonic(_)) + ) +} + +/// Check if two list of data types match, ignoring field names. +fn data_types_match(a: &arrow_schema::Schema, b: &arrow_schema::Schema) -> bool { + if a.fields().len() != b.fields().len() { + return false; + } + #[allow(clippy::disallowed_methods)] + a.fields() + .iter() + .zip(b.fields()) + .all(|(a, b)| a.data_type().equals_datatype(b.data_type())) +} diff --git a/src/expr/impl/src/udf/mod.rs b/src/expr/impl/src/udf/mod.rs new file mode 100644 index 0000000000000..df2892244d8fa --- /dev/null +++ b/src/expr/impl/src/udf/mod.rs @@ -0,0 +1,46 @@ +// Copyright 2024 RisingWave Labs +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#![allow(dead_code, unused_imports)] + +// common imports for submodules +use anyhow::{Context as _, Result}; +use arrow_array::RecordBatch; +use futures_util::stream::BoxStream; +use risingwave_expr::sig::{ + CreateFunctionOptions, CreateFunctionOutput, UdfImpl, UdfImplDescriptor, UDF_IMPLS, +}; + +#[cfg(feature = "deno-udf")] +mod deno; +#[cfg(feature = "external-udf")] +#[cfg(not(madsim))] +mod external; +#[cfg(feature = "python-udf")] +mod python; +#[cfg(feature = "js-udf")] +mod quickjs; +#[cfg(feature = "wasm-udf")] +mod wasm; + +/// Download wasm binary from a link. +fn read_file_from_link(link: &str) -> Result> { + // currently only local file system is supported + let path = link + .strip_prefix("fs://") + .context("only 'fs://' is supported")?; + let content = + std::fs::read(path).context("failed to read wasm binary from local file system")?; + Ok(content) +} diff --git a/src/expr/impl/src/udf/python.rs b/src/expr/impl/src/udf/python.rs new file mode 100644 index 0000000000000..cc358585bfa78 --- /dev/null +++ b/src/expr/impl/src/udf/python.rs @@ -0,0 +1,66 @@ +// Copyright 2024 RisingWave Labs +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use arrow_udf_python::{CallMode, Runtime}; +use futures_util::StreamExt; +use risingwave_common::array::arrow::{ToArrow, UdfArrowConvert}; + +use super::*; + +#[linkme::distributed_slice(UDF_IMPLS)] +static PYTHON: UdfImplDescriptor = UdfImplDescriptor { + match_fn: |language, _runtime, link| language == "python" && link.is_none(), + create_fn: |opts| { + Ok(CreateFunctionOutput { + identifier: opts.name.to_string(), + body: Some(opts.as_.context("AS must be specified")?.to_string()), + compressed_binary: None, + }) + }, + build_fn: |opts| { + let mut runtime = Runtime::builder().sandboxed(true).build()?; + runtime.add_function( + opts.identifier, + UdfArrowConvert::default().to_arrow_field("", opts.return_type)?, + CallMode::CalledOnNullInput, + opts.body.context("body is required")?, + )?; + Ok(Box::new(PythonFunction { + runtime, + identifier: opts.identifier.to_string(), + })) + }, +}; + +#[derive(Debug)] +struct PythonFunction { + runtime: Runtime, + identifier: String, +} + +#[async_trait::async_trait] +impl UdfImpl for PythonFunction { + async fn call(&self, input: &RecordBatch) -> Result { + self.runtime.call(&self.identifier, input) + } + + async fn call_table_function<'a>( + &'a self, + input: &'a RecordBatch, + ) -> Result>> { + self.runtime + .call_table_function(&self.identifier, input, 1024) + .map(|s| futures_util::stream::iter(s).boxed()) + } +} diff --git a/src/expr/impl/src/udf/quickjs.rs b/src/expr/impl/src/udf/quickjs.rs new file mode 100644 index 0000000000000..9d0a58ec90d90 --- /dev/null +++ b/src/expr/impl/src/udf/quickjs.rs @@ -0,0 +1,75 @@ +// Copyright 2024 RisingWave Labs +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use arrow_udf_js::{CallMode, Runtime}; +use futures_util::StreamExt; +use risingwave_common::array::arrow::{ToArrow, UdfArrowConvert}; + +use super::*; + +#[linkme::distributed_slice(UDF_IMPLS)] +static QUICKJS: UdfImplDescriptor = UdfImplDescriptor { + match_fn: |language, runtime, _link| { + language == "javascript" && matches!(runtime, None | Some("quickjs")) + }, + create_fn: |opts| { + Ok(CreateFunctionOutput { + identifier: opts.name.to_string(), + body: Some(opts.as_.context("AS must be specified")?.to_string()), + compressed_binary: None, + }) + }, + build_fn: |opts| { + let body = format!( + "export function{} {}({}) {{ {} }}", + if opts.table_function { "*" } else { "" }, + opts.identifier, + opts.arg_names.join(","), + opts.body.context("body is required")?, + ); + let mut runtime = Runtime::new()?; + runtime.add_function( + opts.identifier, + UdfArrowConvert::default().to_arrow_field("", opts.return_type)?, + CallMode::CalledOnNullInput, + &body, + )?; + Ok(Box::new(QuickJsFunction { + runtime, + identifier: opts.identifier.to_string(), + })) + }, +}; + +#[derive(Debug)] +struct QuickJsFunction { + runtime: Runtime, + identifier: String, +} + +#[async_trait::async_trait] +impl UdfImpl for QuickJsFunction { + async fn call(&self, input: &RecordBatch) -> Result { + self.runtime.call(&self.identifier, input) + } + + async fn call_table_function<'a>( + &'a self, + input: &'a RecordBatch, + ) -> Result>> { + self.runtime + .call_table_function(&self.identifier, input, 1024) + .map(|s| futures_util::stream::iter(s).boxed()) + } +} diff --git a/src/expr/impl/src/udf/wasm.rs b/src/expr/impl/src/udf/wasm.rs new file mode 100644 index 0000000000000..d444d87d83fd9 --- /dev/null +++ b/src/expr/impl/src/udf/wasm.rs @@ -0,0 +1,283 @@ +// Copyright 2024 RisingWave Labs +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::borrow::Cow; +use std::sync::{Arc, LazyLock}; +use std::time::Duration; + +use anyhow::{anyhow, bail}; +use arrow_udf_wasm::Runtime; +use futures_util::StreamExt; +use itertools::Itertools; +use risingwave_common::types::DataType; +use risingwave_expr::sig::UdfOptions; + +use super::*; + +#[linkme::distributed_slice(UDF_IMPLS)] +static WASM: UdfImplDescriptor = UdfImplDescriptor { + match_fn: |language, _runtime, _link| language == "wasm", + create_fn: create_wasm, + build_fn: build, +}; + +#[linkme::distributed_slice(UDF_IMPLS)] +static RUST: UdfImplDescriptor = UdfImplDescriptor { + match_fn: |language, _runtime, _link| language == "rust", + create_fn: create_rust, + build_fn: build, +}; + +fn create_wasm(opts: CreateFunctionOptions<'_>) -> Result { + let wasm_binary: Cow<'_, [u8]> = if let Some(link) = opts.using_link { + read_file_from_link(link)?.into() + } else if let Some(bytes) = opts.using_base64_decoded { + bytes.into() + } else { + bail!("USING must be specified") + }; + let runtime = get_or_create_wasm_runtime(&wasm_binary)?; + if runtime.abi_version().0 <= 2 { + bail!("legacy arrow-udf is no longer supported. please update arrow-udf to 0.3+"); + } + let identifier_v1 = wasm_identifier_v1( + opts.name, + opts.arg_types, + opts.return_type, + opts.is_table_function, + ); + let identifier = find_wasm_identifier_v2(&runtime, &identifier_v1)?; + let compressed_binary = Some(zstd::stream::encode_all(&*wasm_binary, 0)?); + Ok(CreateFunctionOutput { + identifier, + body: None, + compressed_binary, + }) +} + +fn create_rust(opts: CreateFunctionOptions<'_>) -> Result { + if opts.using_link.is_some() { + bail!("USING is not supported for rust function"); + } + let identifier_v1 = wasm_identifier_v1( + opts.name, + opts.arg_types, + opts.return_type, + opts.is_table_function, + ); + // if the function returns a struct, users need to add `#[function]` macro by themselves. + // otherwise, we add it automatically. the code should start with `fn ...`. + let function_macro = if opts.return_type.is_struct() { + String::new() + } else { + format!("#[function(\"{}\")]", identifier_v1) + }; + let script = format!( + "use arrow_udf::{{function, types::*}};\n{}\n{}", + function_macro, + opts.as_.context("AS must be specified")? + ); + let body = Some(script.clone()); + + let wasm_binary = std::thread::spawn(move || { + let mut opts = arrow_udf_wasm::build::BuildOpts::default(); + opts.arrow_udf_version = Some("0.3".to_string()); + opts.script = script; + // use a fixed tempdir to reuse the build cache + opts.tempdir = Some(std::env::temp_dir().join("risingwave-rust-udf")); + + arrow_udf_wasm::build::build_with(&opts) + }) + .join() + .unwrap() + .context("failed to build rust function")?; + + let runtime = get_or_create_wasm_runtime(&wasm_binary)?; + let identifier = find_wasm_identifier_v2(&runtime, &identifier_v1)?; + + let compressed_binary = Some(zstd::stream::encode_all(wasm_binary.as_slice(), 0)?); + Ok(CreateFunctionOutput { + identifier, + body, + compressed_binary, + }) +} + +fn build(opts: UdfOptions<'_>) -> Result> { + let compressed_binary = opts + .compressed_binary + .context("compressed binary is required")?; + let wasm_binary = + zstd::stream::decode_all(compressed_binary).context("failed to decompress wasm binary")?; + let runtime = get_or_create_wasm_runtime(&wasm_binary)?; + Ok(Box::new(WasmFunction { + runtime, + identifier: opts.identifier.to_string(), + })) +} + +#[derive(Debug)] +struct WasmFunction { + runtime: Arc, + identifier: String, +} + +#[async_trait::async_trait] +impl UdfImpl for WasmFunction { + async fn call(&self, input: &RecordBatch) -> Result { + self.runtime.call(&self.identifier, input) + } + + async fn call_table_function<'a>( + &'a self, + input: &'a RecordBatch, + ) -> Result>> { + self.runtime + .call_table_function(&self.identifier, input) + .map(|s| futures_util::stream::iter(s).boxed()) + } + + fn is_legacy(&self) -> bool { + // see for details + self.runtime.abi_version().0 <= 2 + } +} + +/// Get or create a wasm runtime. +/// +/// Runtimes returned by this function are cached inside for at least 60 seconds. +/// Later calls with the same binary will reuse the same runtime. +fn get_or_create_wasm_runtime(binary: &[u8]) -> Result> { + static RUNTIMES: LazyLock>> = LazyLock::new(|| { + moka::sync::Cache::builder() + .time_to_idle(Duration::from_secs(60)) + .build() + }); + + let md5 = md5::compute(binary); + if let Some(runtime) = RUNTIMES.get(&md5) { + return Ok(runtime.clone()); + } + + let runtime = Arc::new(arrow_udf_wasm::Runtime::new(binary)?); + RUNTIMES.insert(md5, runtime.clone()); + Ok(runtime) +} + +/// Convert a v0.1 function identifier to v0.2 format. +/// +/// In arrow-udf v0.1 format, struct type is inline in the identifier. e.g. +/// +/// ```text +/// keyvalue(varchar,varchar)->struct +/// ``` +/// +/// However, since arrow-udf v0.2, struct type is no longer inline. +/// The above identifier is divided into a function and a type. +/// +/// ```text +/// keyvalue(varchar,varchar)->struct KeyValue +/// KeyValue=key:varchar,value:varchar +/// ``` +/// +/// For compatibility, we should call `find_wasm_identifier_v2` to +/// convert v0.1 identifiers to v0.2 format before looking up the function. +fn find_wasm_identifier_v2( + runtime: &arrow_udf_wasm::Runtime, + inlined_signature: &str, +) -> Result { + // Inline types in function signature. + // + // # Example + // + // ```text + // types = { "KeyValue": "key:varchar,value:varchar" } + // input = "keyvalue(varchar, varchar) -> struct KeyValue" + // output = "keyvalue(varchar, varchar) -> struct" + // ``` + let inline_types = |s: &str| -> String { + let mut inlined = s.to_string(); + // iteratively replace `struct Xxx` with `struct<...>` until no replacement is made. + loop { + let replaced = inlined.clone(); + for (k, v) in runtime.types() { + inlined = inlined.replace(&format!("struct {k}"), &format!("struct<{v}>")); + } + if replaced == inlined { + return inlined; + } + } + }; + // Function signature in arrow-udf is case sensitive. + // However, SQL identifiers are usually case insensitive and stored in lowercase. + // So we should convert the signature to lowercase before comparison. + let identifier = runtime + .functions() + .find(|f| inline_types(f).to_lowercase() == inlined_signature) + .ok_or_else(|| { + anyhow!( + "function not found in wasm binary: \"{}\"\nHINT: available functions:\n {}\navailable types:\n {}", + inlined_signature, + runtime.functions().join("\n "), + runtime.types().map(|(k, v)| format!("{k}: {v}")).join("\n "), + ) + })?; + Ok(identifier.into()) +} + +/// Generate a function identifier in v0.1 format from the function signature. +fn wasm_identifier_v1( + name: &str, + args: &[DataType], + ret: &DataType, + table_function: bool, +) -> String { + format!( + "{}({}){}{}", + name, + args.iter().map(datatype_name).join(","), + if table_function { "->>" } else { "->" }, + datatype_name(ret) + ) +} + +/// Convert a data type to string used in identifier. +fn datatype_name(ty: &DataType) -> String { + match ty { + DataType::Boolean => "boolean".to_string(), + DataType::Int16 => "int16".to_string(), + DataType::Int32 => "int32".to_string(), + DataType::Int64 => "int64".to_string(), + DataType::Float32 => "float32".to_string(), + DataType::Float64 => "float64".to_string(), + DataType::Date => "date32".to_string(), + DataType::Time => "time64".to_string(), + DataType::Timestamp => "timestamp".to_string(), + DataType::Timestamptz => "timestamptz".to_string(), + DataType::Interval => "interval".to_string(), + DataType::Decimal => "decimal".to_string(), + DataType::Jsonb => "json".to_string(), + DataType::Serial => "serial".to_string(), + DataType::Int256 => "int256".to_string(), + DataType::Bytea => "binary".to_string(), + DataType::Varchar => "string".to_string(), + DataType::List(inner) => format!("{}[]", datatype_name(inner)), + DataType::Struct(s) => format!( + "struct<{}>", + s.iter() + .map(|(name, ty)| format!("{}:{}", name, datatype_name(ty))) + .join(",") + ), + } +} diff --git a/src/frontend/Cargo.toml b/src/frontend/Cargo.toml index 3decb8535ca67..4adeaae423ca6 100644 --- a/src/frontend/Cargo.toml +++ b/src/frontend/Cargo.toml @@ -18,8 +18,6 @@ normal = ["workspace-hack"] anyhow = "1" arc-swap = "1" arrow-schema = { workspace = true } -arrow-udf-flight = { workspace = true } -arrow-udf-wasm = { workspace = true } async-recursion = "1.1.0" async-trait = "0.1" auto_enums = { workspace = true } diff --git a/src/frontend/src/handler/create_function.rs b/src/frontend/src/handler/create_function.rs index c6cafdea37fa7..573e1b96920c2 100644 --- a/src/frontend/src/handler/create_function.rs +++ b/src/frontend/src/handler/create_function.rs @@ -13,12 +13,9 @@ // limitations under the License. use anyhow::Context; -use arrow_schema::Fields; -use bytes::Bytes; -use risingwave_common::array::arrow::{ToArrow, UdfArrowConvert}; use risingwave_common::catalog::FunctionId; use risingwave_common::types::DataType; -use risingwave_expr::expr::{get_or_create_flight_client, get_or_create_wasm_runtime}; +use risingwave_expr::sig::CreateFunctionOptions; use risingwave_pb::catalog::function::{Kind, ScalarFunction, TableFunction}; use risingwave_pb::catalog::Function; @@ -62,10 +59,10 @@ pub async fn handle_create_function( None => "".to_string(), }; - let rt = match params.runtime { + let runtime = match params.runtime { Some(runtime) => { - if language.as_str() == "javascript" { - runtime.to_string() + if language == "javascript" { + Some(runtime.real_value()) } else { return Err(ErrorCode::InvalidParameterValue( "runtime is only supported for javascript".to_string(), @@ -73,7 +70,7 @@ pub async fn handle_create_function( .into()); } } - None => "".to_string(), + None => None, }; let return_type; @@ -130,214 +127,40 @@ pub async fn handle_create_function( return Err(CatalogError::Duplicated("function", name).into()); } - let identifier; - let mut link = None; - let mut body = None; - let mut compressed_binary = None; - let mut function_type = None; - let mut runtime = None; - - match language.as_str() { - "python" if params.using.is_none() => { - identifier = function_name.to_string(); - body = Some( - params - .as_ - .ok_or_else(|| ErrorCode::InvalidParameterValue("AS must be specified".into()))? - .into_string(), - ); - } - "python" | "java" | "" => { - let Some(CreateFunctionUsing::Link(l)) = params.using else { - return Err(ErrorCode::InvalidParameterValue( - "USING LINK must be specified".to_string(), - ) - .into()); - }; - let Some(as_) = params.as_ else { - return Err( - ErrorCode::InvalidParameterValue("AS must be specified".to_string()).into(), - ); - }; - identifier = as_.into_string(); - - // check UDF server - { - let client = get_or_create_flight_client(&l)?; - let convert = UdfArrowConvert { - legacy: client.protocol_version() == 1, - }; - // A helper function to create a unnamed field from data type. - let to_field = |data_type| convert.to_arrow_field("", data_type); - let args = arrow_schema::Schema::new( - arg_types - .iter() - .map(to_field) - .try_collect::<_, Fields, _>()?, - ); - let returns = arrow_schema::Schema::new(match kind { - Kind::Scalar(_) => vec![to_field(&return_type)?], - Kind::Table(_) => vec![ - arrow_schema::Field::new("row", arrow_schema::DataType::Int32, true), - to_field(&return_type)?, - ], - _ => unreachable!(), - }); - let function = client - .get(&identifier) - .await - .context("failed to check UDF signature")?; - if !data_types_match(&function.args, &args) { - return Err(ErrorCode::InvalidParameterValue(format!( - "argument type mismatch, expect: {:?}, actual: {:?}", - args, function.args, - )) - .into()); - } - if !data_types_match(&function.returns, &returns) { - return Err(ErrorCode::InvalidParameterValue(format!( - "return type mismatch, expect: {:?}, actual: {:?}", - returns, function.returns, - )) - .into()); - } - } - link = Some(l); - } - "javascript" if rt.as_str() != "deno" => { - identifier = function_name.to_string(); - body = Some( - params - .as_ - .ok_or_else(|| ErrorCode::InvalidParameterValue("AS must be specified".into()))? - .into_string(), - ); - runtime = Some("quickjs".to_string()); - } - "javascript" if rt.as_str() == "deno" => { - identifier = function_name.to_string(); - match (params.using, params.as_) { - (None, None) => { - return Err(ErrorCode::InvalidParameterValue( - "Either USING or AS must be specified".into(), - ) - .into()) - } - (None, Some(_as)) => body = Some(_as.into_string()), - (Some(CreateFunctionUsing::Link(link)), None) => { - let bytes = download_code_from_link(&link).await?; - compressed_binary = Some(zstd::stream::encode_all(bytes.as_slice(), 0)?); - } - (Some(CreateFunctionUsing::Base64(encoded)), None) => { - use base64::prelude::{Engine, BASE64_STANDARD}; - let bytes = BASE64_STANDARD - .decode(encoded) - .context("invalid base64 encoding")?; - compressed_binary = Some(zstd::stream::encode_all(bytes.as_slice(), 0)?); - } - (Some(_), Some(_)) => { - return Err(ErrorCode::InvalidParameterValue( - "Both USING and AS cannot be specified".into(), - ) - .into()) - } - }; - - function_type = match params.function_type { - Some(CreateFunctionType::Sync) => Some("sync".to_string()), - Some(CreateFunctionType::Async) => Some("async".to_string()), - Some(CreateFunctionType::Generator) => Some("generator".to_string()), - Some(CreateFunctionType::AsyncGenerator) => Some("async_generator".to_string()), - None => None, - }; - - runtime = Some("deno".to_string()); - } - "rust" => { - if params.using.is_some() { - return Err(ErrorCode::InvalidParameterValue( - "USING is not supported for rust function".to_string(), - ) - .into()); - } - let identifier_v1 = wasm_identifier_v1( - &function_name, - &arg_types, - &return_type, - matches!(kind, Kind::Table(_)), - ); - // if the function returns a struct, users need to add `#[function]` macro by themselves. - // otherwise, we add it automatically. the code should start with `fn ...`. - let function_macro = if return_type.is_struct() { - String::new() - } else { - format!("#[function(\"{}\")]", identifier_v1) - }; - let script = params - .as_ - .ok_or_else(|| ErrorCode::InvalidParameterValue("AS must be specified".into()))? - .into_string(); - let script = format!( - "use arrow_udf::{{function, types::*}};\n{}\n{}", - function_macro, script - ); - body = Some(script.clone()); - - let wasm_binary = tokio::task::spawn_blocking(move || { - let mut opts = arrow_udf_wasm::build::BuildOpts::default(); - opts.arrow_udf_version = Some("0.3".to_string()); - opts.script = script; - // use a fixed tempdir to reuse the build cache - opts.tempdir = Some(std::env::temp_dir().join("risingwave-rust-udf")); - - arrow_udf_wasm::build::build_with(&opts) - }) - .await? - .context("failed to build rust function")?; - - let runtime = get_or_create_wasm_runtime(&wasm_binary)?; - identifier = find_wasm_identifier_v2(&runtime, &identifier_v1)?; - - compressed_binary = Some(zstd::stream::encode_all(wasm_binary.as_slice(), 0)?); - } - "wasm" => { - let Some(using) = params.using else { - return Err(ErrorCode::InvalidParameterValue( - "USING must be specified".to_string(), - ) - .into()); - }; - let wasm_binary = match using { - CreateFunctionUsing::Link(link) => download_binary_from_link(&link).await?, - CreateFunctionUsing::Base64(encoded) => { - // decode wasm binary from base64 - use base64::prelude::{Engine, BASE64_STANDARD}; - BASE64_STANDARD - .decode(encoded) - .context("invalid base64 encoding")? - .into() - } - }; - let runtime = get_or_create_wasm_runtime(&wasm_binary)?; - if runtime.abi_version().0 <= 2 { - return Err(ErrorCode::InvalidParameterValue( - "legacy arrow-udf is no longer supported. please update arrow-udf to 0.3+" - .to_string(), - ) - .into()); - } - let identifier_v1 = wasm_identifier_v1( - &function_name, - &arg_types, - &return_type, - matches!(kind, Kind::Table(_)), - ); - identifier = find_wasm_identifier_v2(&runtime, &identifier_v1)?; - - compressed_binary = Some(zstd::stream::encode_all(wasm_binary.as_ref(), 0)?); + let link = match ¶ms.using { + Some(CreateFunctionUsing::Link(l)) => Some(l.as_str()), + _ => None, + }; + let base64_decoded = match ¶ms.using { + Some(CreateFunctionUsing::Base64(encoded)) => { + use base64::prelude::{Engine, BASE64_STANDARD}; + let bytes = BASE64_STANDARD + .decode(encoded) + .context("invalid base64 encoding")?; + Some(bytes) } - _ => unreachable!("invalid language: {language}"), + _ => None, }; + let function_type = match params.function_type { + Some(CreateFunctionType::Sync) => Some("sync".to_string()), + Some(CreateFunctionType::Async) => Some("async".to_string()), + Some(CreateFunctionType::Generator) => Some("generator".to_string()), + Some(CreateFunctionType::AsyncGenerator) => Some("async_generator".to_string()), + None => None, + }; + + let create_fn = + risingwave_expr::sig::find_udf_impl(&language, runtime.as_deref(), link)?.create_fn; + let output = create_fn(CreateFunctionOptions { + name: &function_name, + arg_names: &arg_names, + arg_types: &arg_types, + return_type: &return_type, + is_table_function: matches!(kind, Kind::Table(_)), + as_: params.as_.as_ref().map(|s| s.as_str()), + using_link: link, + using_base64_decoded: base64_decoded.as_deref(), + })?; let function = Function { id: FunctionId::placeholder().0, @@ -349,10 +172,10 @@ pub async fn handle_create_function( arg_types: arg_types.into_iter().map(|t| t.into()).collect(), return_type: Some(return_type.into()), language, - identifier: Some(identifier), - link, - body, - compressed_binary, + identifier: Some(output.identifier), + link: link.map(|s| s.to_string()), + body: output.body, + compressed_binary: output.compressed_binary, owner: session.user_id(), always_retry_on_network_error: with_options .always_retry_on_network_error @@ -366,148 +189,3 @@ pub async fn handle_create_function( Ok(PgResponse::empty_result(StatementType::CREATE_FUNCTION)) } - -/// Download wasm binary from a link. -#[allow(clippy::unused_async)] -async fn download_binary_from_link(link: &str) -> Result { - // currently only local file system is supported - if let Some(path) = link.strip_prefix("fs://") { - let content = - std::fs::read(path).context("failed to read wasm binary from local file system")?; - Ok(content.into()) - } else { - Err(ErrorCode::InvalidParameterValue("only 'fs://' is supported".to_string()).into()) - } -} - -/// Convert a v0.1 function identifier to v0.2 format. -/// -/// In arrow-udf v0.1 format, struct type is inline in the identifier. e.g. -/// -/// ```text -/// keyvalue(varchar,varchar)->struct -/// ``` -/// -/// However, since arrow-udf v0.2, struct type is no longer inline. -/// The above identifier is divided into a function and a type. -/// -/// ```text -/// keyvalue(varchar,varchar)->struct KeyValue -/// KeyValue=key:varchar,value:varchar -/// ``` -/// -/// For compatibility, we should call `find_wasm_identifier_v2` to -/// convert v0.1 identifiers to v0.2 format before looking up the function. -fn find_wasm_identifier_v2( - runtime: &arrow_udf_wasm::Runtime, - inlined_signature: &str, -) -> Result { - // Inline types in function signature. - // - // # Example - // - // ```text - // types = { "KeyValue": "key:varchar,value:varchar" } - // input = "keyvalue(varchar, varchar) -> struct KeyValue" - // output = "keyvalue(varchar, varchar) -> struct" - // ``` - let inline_types = |s: &str| -> String { - let mut inlined = s.to_string(); - // iteratively replace `struct Xxx` with `struct<...>` until no replacement is made. - loop { - let replaced = inlined.clone(); - for (k, v) in runtime.types() { - inlined = inlined.replace(&format!("struct {k}"), &format!("struct<{v}>")); - } - if replaced == inlined { - return inlined; - } - } - }; - // Function signature in arrow-udf is case sensitive. - // However, SQL identifiers are usually case insensitive and stored in lowercase. - // So we should convert the signature to lowercase before comparison. - let identifier = runtime - .functions() - .find(|f| inline_types(f).to_lowercase() == inlined_signature) - .ok_or_else(|| { - ErrorCode::InvalidParameterValue(format!( - "function not found in wasm binary: \"{}\"\nHINT: available functions:\n {}\navailable types:\n {}", - inlined_signature, - runtime.functions().join("\n "), - runtime.types().map(|(k, v)| format!("{k}: {v}")).join("\n "), - )) - })?; - Ok(identifier.into()) -} - -/// Download wasm binary from a link. -#[allow(clippy::unused_async)] -async fn download_code_from_link(link: &str) -> Result> { - // currently only local file system is supported - if let Some(path) = link.strip_prefix("fs://") { - let content = - std::fs::read(path).context("failed to read the code from local file system")?; - Ok(content) - } else { - Err(ErrorCode::InvalidParameterValue("only 'fs://' is supported".to_string()).into()) - } -} - -/// Generate a function identifier in v0.1 format from the function signature. -fn wasm_identifier_v1( - name: &str, - args: &[DataType], - ret: &DataType, - table_function: bool, -) -> String { - format!( - "{}({}){}{}", - name, - args.iter().map(datatype_name).join(","), - if table_function { "->>" } else { "->" }, - datatype_name(ret) - ) -} - -/// Convert a data type to string used in identifier. -fn datatype_name(ty: &DataType) -> String { - match ty { - DataType::Boolean => "boolean".to_string(), - DataType::Int16 => "int16".to_string(), - DataType::Int32 => "int32".to_string(), - DataType::Int64 => "int64".to_string(), - DataType::Float32 => "float32".to_string(), - DataType::Float64 => "float64".to_string(), - DataType::Date => "date32".to_string(), - DataType::Time => "time64".to_string(), - DataType::Timestamp => "timestamp".to_string(), - DataType::Timestamptz => "timestamptz".to_string(), - DataType::Interval => "interval".to_string(), - DataType::Decimal => "decimal".to_string(), - DataType::Jsonb => "json".to_string(), - DataType::Serial => "serial".to_string(), - DataType::Int256 => "int256".to_string(), - DataType::Bytea => "binary".to_string(), - DataType::Varchar => "string".to_string(), - DataType::List(inner) => format!("{}[]", datatype_name(inner)), - DataType::Struct(s) => format!( - "struct<{}>", - s.iter() - .map(|(name, ty)| format!("{}:{}", name, datatype_name(ty))) - .join(",") - ), - } -} - -/// Check if two list of data types match, ignoring field names. -fn data_types_match(a: &arrow_schema::Schema, b: &arrow_schema::Schema) -> bool { - if a.fields().len() != b.fields().len() { - return false; - } - #[allow(clippy::disallowed_methods)] - a.fields() - .iter() - .zip(b.fields()) - .all(|(a, b)| a.data_type().equals_datatype(b.data_type())) -} diff --git a/src/risedevtool/config/src/main.rs b/src/risedevtool/config/src/main.rs index 98416c5691890..d69aad43f2dac 100644 --- a/src/risedevtool/config/src/main.rs +++ b/src/risedevtool/config/src/main.rs @@ -75,6 +75,11 @@ pub enum Components { HummockTrace, Coredump, NoBacktrace, + ExternalUdf, + WasmUdf, + JsUdf, + DenoUdf, + PythonUdf, } impl Components { @@ -97,6 +102,11 @@ impl Components { Self::HummockTrace => "[Build] Hummock Trace", Self::Coredump => "[Runtime] Enable coredump", Self::NoBacktrace => "[Runtime] Disable backtrace", + Self::ExternalUdf => "[Build] Enable external UDF", + Self::WasmUdf => "[Build] Enable Wasm UDF", + Self::JsUdf => "[Build] Enable JS UDF", + Self::DenoUdf => "[Build] Enable Deno UDF", + Self::PythonUdf => "[Build] Enable Python UDF", } .into() } @@ -194,11 +204,16 @@ the binaries will also be codesigned with `get-task-allow` enabled. As a result, RisingWave will dump the core on panics. " } - Components::NoBacktrace => { + Self::NoBacktrace => { " With this option enabled, RiseDev will not set `RUST_BACKTRACE` when launching nodes. " } + Self::ExternalUdf => "Required if you want to support external UDF.", + Self::WasmUdf => "Required if you want to support WASM UDF.", + Self::JsUdf => "Required if you want to support JS UDF.", + Self::DenoUdf => "Required if you want to support Deno UDF.", + Self::PythonUdf => "Required if you want to support Python UDF.", } .into() } @@ -222,6 +237,11 @@ With this option enabled, RiseDev will not set `RUST_BACKTRACE` when launching n "ENABLE_HUMMOCK_TRACE" => Some(Self::HummockTrace), "ENABLE_COREDUMP" => Some(Self::Coredump), "DISABLE_BACKTRACE" => Some(Self::NoBacktrace), + "ENABLE_EXTERNAL_UDF" => Some(Self::ExternalUdf), + "ENABLE_WASM_UDF" => Some(Self::WasmUdf), + "ENABLE_JS_UDF" => Some(Self::JsUdf), + "ENABLE_DENO_UDF" => Some(Self::DenoUdf), + "ENABLE_PYTHON_UDF" => Some(Self::PythonUdf), _ => None, } } @@ -245,6 +265,11 @@ With this option enabled, RiseDev will not set `RUST_BACKTRACE` when launching n Self::HummockTrace => "ENABLE_HUMMOCK_TRACE", Self::Coredump => "ENABLE_COREDUMP", Self::NoBacktrace => "DISABLE_BACKTRACE", + Self::ExternalUdf => "ENABLE_EXTERNAL_UDF", + Self::WasmUdf => "ENABLE_WASM_UDF", + Self::JsUdf => "ENABLE_JS_UDF", + Self::DenoUdf => "ENABLE_DENO_UDF", + Self::PythonUdf => "ENABLE_PYTHON_UDF", } .into() } diff --git a/src/sqlparser/src/ast/mod.rs b/src/sqlparser/src/ast/mod.rs index 919d87356a3f2..eed9f274705b7 100644 --- a/src/sqlparser/src/ast/mod.rs +++ b/src/sqlparser/src/ast/mod.rs @@ -2887,8 +2887,8 @@ impl fmt::Display for TableColumnDef { pub struct CreateFunctionBody { /// LANGUAGE lang_name pub language: Option, - - pub runtime: Option, + /// RUNTIME runtime_name + pub runtime: Option, /// IMMUTABLE | STABLE | VOLATILE pub behavior: Option, @@ -2909,11 +2909,9 @@ impl fmt::Display for CreateFunctionBody { if let Some(language) = &self.language { write!(f, " LANGUAGE {language}")?; } - if let Some(runtime) = &self.runtime { write!(f, " RUNTIME {runtime}")?; } - if let Some(behavior) = &self.behavior { write!(f, " {behavior}")?; } @@ -3003,22 +3001,6 @@ impl fmt::Display for CreateFunctionUsing { } } -#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] -#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] -pub enum FunctionRuntime { - QuickJs, - Deno, -} - -impl fmt::Display for FunctionRuntime { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match self { - FunctionRuntime::QuickJs => write!(f, "quickjs"), - FunctionRuntime::Deno => write!(f, "deno"), - } - } -} - #[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub enum CreateFunctionType { @@ -3320,7 +3302,7 @@ mod tests { as_: Some(FunctionDefinition::SingleQuotedDef("SELECT 1".to_string())), return_: None, using: None, - runtime: Some(FunctionRuntime::Deno), + runtime: Some(Ident::new_unchecked("deno")), function_type: Some(CreateFunctionType::AsyncGenerator), }, with_options: CreateFunctionWithOptions { diff --git a/src/sqlparser/src/parser.rs b/src/sqlparser/src/parser.rs index 84b1e1d97808d..5b2351073b55d 100644 --- a/src/sqlparser/src/parser.rs +++ b/src/sqlparser/src/parser.rs @@ -2410,7 +2410,7 @@ impl Parser { body.language = Some(self.parse_identifier()?); } else if self.parse_keyword(Keyword::RUNTIME) { ensure_not_set(&body.runtime, "RUNTIME")?; - body.runtime = Some(self.parse_function_runtime()?); + body.runtime = Some(self.parse_identifier()?); } else if self.parse_keyword(Keyword::IMMUTABLE) { ensure_not_set(&body.behavior, "IMMUTABLE | STABLE | VOLATILE")?; body.behavior = Some(FunctionBehavior::Immutable); @@ -2457,17 +2457,6 @@ impl Parser { } } - fn parse_function_runtime(&mut self) -> Result { - let ident = self.parse_identifier()?; - match ident.value.to_lowercase().as_str() { - "deno" => Ok(FunctionRuntime::Deno), - "quickjs" => Ok(FunctionRuntime::QuickJs), - r => Err(ParserError::ParserError(format!( - "Unsupported runtime: {r}" - ))), - } - } - fn parse_function_type( &mut self, is_async: bool, diff --git a/src/tests/simulation/Cargo.toml b/src/tests/simulation/Cargo.toml index 31736ea848423..57768643eb7dc 100644 --- a/src/tests/simulation/Cargo.toml +++ b/src/tests/simulation/Cargo.toml @@ -38,7 +38,7 @@ risingwave_compute = { workspace = true } risingwave_connector = { workspace = true } risingwave_ctl = { workspace = true } risingwave_e2e_extended_mode_test = { path = "../e2e_extended_mode" } -risingwave_expr_impl = { workspace = true } +risingwave_expr_impl = { workspace = true, features = ["js-udf"] } risingwave_frontend = { workspace = true } risingwave_hummock_sdk = { workspace = true } risingwave_meta_node = { workspace = true }