Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(udf): store WASM UDF in meta store (#15269) #15341

Merged
merged 1 commit into from
Feb 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 3 additions & 0 deletions proto/catalog.proto
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,10 @@ message Function {
string language = 7;
optional string link = 8;
optional string identifier = 10;
// The source code of the function.
optional string body = 14;
// The zstd-compressed binary of the function.
optional bytes compressed_binary = 17;
bool always_retry_on_network_error = 16;

oneof kind {
Expand Down
16 changes: 9 additions & 7 deletions proto/expr.proto
Original file line number Diff line number Diff line change
Expand Up @@ -482,16 +482,17 @@ message UserDefinedFunction {
repeated string arg_names = 8;
repeated data.DataType arg_types = 3;
string language = 4;
// For external UDF: the link to the external function service.
// For WASM UDF: the link to the wasm binary file.
// The link to the external function service.
optional string link = 5;
// A unique identifier for the function.
// For external UDF, it's the name of the function in the external function service.
// For WASM UDF, it's the name of the function in the wasm binary file.
// For JavaScript UDF, it's the name of the function.
// A unique identifier for the function.
// - If `link` is not empty, the name of the function in the external function service.
// - If `language` is `rust` or `wasm`, the name of the function in the wasm binary file.
// - If `language` is `javascript`, the name of the function.
optional string identifier = 6;
// For JavaScript UDF, it's the body of the function.
// - If `language` is `javascript`, the source code of the function.
optional string body = 7;
// - If `language` is `rust` or `wasm`, the zstd-compressed wasm binary.
optional bytes compressed_binary = 10;
bool always_retry_on_network_error = 9;
}

Expand All @@ -503,4 +504,5 @@ message UserDefinedTableFunction {
optional string link = 5;
optional string identifier = 6;
optional string body = 7;
optional bytes compressed_binary = 10;
}
2 changes: 1 addition & 1 deletion proto/meta.proto
Original file line number Diff line number Diff line change
Expand Up @@ -555,7 +555,7 @@ message SystemParams {
optional uint32 parallel_compact_size_mb = 11;
optional uint32 max_concurrent_creating_streaming_jobs = 12;
optional bool pause_on_next_bootstrap = 13;
optional string wasm_storage_url = 14;
optional string wasm_storage_url = 14 [deprecated = true];
optional bool enable_tracing = 15;
}

Expand Down
2 changes: 0 additions & 2 deletions src/common/src/system_param/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,6 @@ macro_rules! for_all_params {
{ backup_storage_directory, String, None, true, "Remote directory for storing snapshots.", },
{ max_concurrent_creating_streaming_jobs, u32, Some(1_u32), true, "Max number of concurrent creating streaming jobs.", },
{ pause_on_next_bootstrap, bool, Some(false), true, "Whether to pause all data sources on next bootstrap.", },
{ wasm_storage_url, String, Some("fs://.risingwave/data".to_string()), false, "", },
{ enable_tracing, bool, Some(false), true, "Whether to enable distributed tracing.", },
}
};
Expand Down Expand Up @@ -440,7 +439,6 @@ mod tests {
(BACKUP_STORAGE_DIRECTORY_KEY, "a"),
(MAX_CONCURRENT_CREATING_STREAMING_JOBS_KEY, "1"),
(PAUSE_ON_NEXT_BOOTSTRAP_KEY, "false"),
(WASM_STORAGE_URL_KEY, "a"),
(ENABLE_TRACING_KEY, "true"),
("a_deprecated_param", "foo"),
];
Expand Down
7 changes: 0 additions & 7 deletions src/common/src/system_param/reader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -142,11 +142,4 @@ where
.enable_tracing
.unwrap_or_else(default::enable_tracing)
}

fn wasm_storage_url(&self) -> &str {
self.inner()
.wasm_storage_url
.as_ref()
.unwrap_or(&default::WASM_STORAGE_URL)
}
}
1 change: 0 additions & 1 deletion src/config/example.toml
Original file line number Diff line number Diff line change
Expand Up @@ -195,5 +195,4 @@ block_size_kb = 64
bloom_false_positive = 0.001
max_concurrent_creating_streaming_jobs = 1
pause_on_next_bootstrap = false
wasm_storage_url = "fs://.risingwave/data"
enable_tracing = false
5 changes: 3 additions & 2 deletions src/expr/core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -37,14 +37,14 @@ futures-async-stream = { workspace = true }
futures-util = "0.3"
itertools = "0.12"
linkme = { version = "0.3", features = ["used_linker"] }
moka = { version = "0.12", features = ["future"] }
md5 = "0.7"
moka = { version = "0.12", features = ["sync"] }
num-traits = "0.2"
openssl = { version = "0.10", features = ["vendored"] }
parse-display = "0.8"
paste = "1"
risingwave_common = { workspace = true }
risingwave_expr_macro = { path = "../macro" }
risingwave_object_store = { workspace = true }
risingwave_pb = { workspace = true }
risingwave_udf = { workspace = true }
smallvec = "1"
Expand All @@ -56,6 +56,7 @@ tokio = { version = "0.2", package = "madsim-tokio", features = [
"macros",
] }
tracing = "0.1"
zstd = { version = "0.13", default-features = false }

[target.'cfg(not(madsim))'.dependencies]
workspace-hack = { path = "../../workspace-hack" }
Expand Down
48 changes: 14 additions & 34 deletions src/expr/core/src/expr/expr_udf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,10 @@ use arrow_udf_js::{CallMode, Runtime as JsRuntime};
use arrow_udf_wasm::Runtime as WasmRuntime;
use await_tree::InstrumentAwait;
use cfg_or_panic::cfg_or_panic;
use moka::future::Cache;
use moka::sync::Cache;
use risingwave_common::array::{ArrayError, ArrayRef, DataChunk};
use risingwave_common::config::ObjectStoreConfig;
use risingwave_common::row::OwnedRow;
use risingwave_common::types::{DataType, Datum};
use risingwave_object_store::object::build_remote_object_store;
use risingwave_object_store::object::object_metrics::ObjectStoreMetrics;
use risingwave_pb::expr::ExprNode;
use risingwave_udf::ArrowFlightUdfClient;
use thiserror_ext::AsReport;
Expand Down Expand Up @@ -189,13 +186,13 @@ impl Build for UserDefinedFunction {

let identifier = udf.get_identifier()?;
let imp = match udf.language.as_str() {
#[cfg(not(madsim))]
"wasm" => {
let link = udf.get_link()?;
// Use `block_in_place` as an escape hatch to run async code here in sync context.
// Calling `block_on` directly will panic.
UdfImpl::Wasm(tokio::task::block_in_place(|| {
tokio::runtime::Handle::current().block_on(get_or_create_wasm_runtime(link))
})?)
let compressed_wasm_binary = udf.get_compressed_binary()?;
let wasm_binary = zstd::stream::decode_all(compressed_wasm_binary.as_slice())
.context("failed to decompress wasm binary")?;
let runtime = get_or_create_wasm_runtime(&wasm_binary)?;
UdfImpl::Wasm(runtime)
}
"javascript" => {
let mut rt = JsRuntime::new()?;
Expand Down Expand Up @@ -270,38 +267,21 @@ pub(crate) fn get_or_create_flight_client(link: &str) -> Result<Arc<ArrowFlightU
/// Get or create a wasm runtime.
///
/// Runtimes returned by this function are cached inside for at least 60 seconds.
/// Later calls with the same link will reuse the same runtime.
/// Later calls with the same binary will reuse the same runtime.
#[cfg_or_panic(not(madsim))]
pub async fn get_or_create_wasm_runtime(link: &str) -> Result<Arc<WasmRuntime>> {
static RUNTIMES: LazyLock<Cache<String, Arc<WasmRuntime>>> = LazyLock::new(|| {
pub fn get_or_create_wasm_runtime(binary: &[u8]) -> Result<Arc<WasmRuntime>> {
static RUNTIMES: LazyLock<Cache<md5::Digest, Arc<WasmRuntime>>> = LazyLock::new(|| {
Cache::builder()
.time_to_idle(Duration::from_secs(60))
.build()
});

if let Some(runtime) = RUNTIMES.get(link).await {
let md5 = md5::compute(binary);
if let Some(runtime) = RUNTIMES.get(&md5) {
return Ok(runtime.clone());
}

// create new runtime
let (wasm_storage_url, object_name) = link
.rsplit_once('/')
.context("invalid link for wasm function")?;

// load wasm binary from object store
let object_store = build_remote_object_store(
wasm_storage_url,
Arc::new(ObjectStoreMetrics::unused()),
"Wasm Engine",
ObjectStoreConfig::default(),
)
.await;
let binary = object_store
.read(object_name, ..)
.await
.context("failed to load wasm binary from object storage")?;

let runtime = Arc::new(arrow_udf_wasm::Runtime::new(&binary)?);
RUNTIMES.insert(link.into(), runtime.clone()).await;
let runtime = Arc::new(arrow_udf_wasm::Runtime::new(binary)?);
RUNTIMES.insert(md5, runtime.clone());
Ok(runtime)
}
15 changes: 7 additions & 8 deletions src/expr/core/src/table_function/user_defined.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

use std::sync::Arc;

use anyhow::Context;
use arrow_array::RecordBatch;
use arrow_schema::{Field, Fields, Schema, SchemaRef};
use arrow_udf_js::{CallMode, Runtime as JsRuntime};
Expand Down Expand Up @@ -188,14 +189,12 @@ pub fn new_user_defined(prost: &PbTableFunction, chunk_size: usize) -> Result<Bo
let return_type = DataType::from(prost.get_return_type()?);

let client = match udtf.language.as_str() {
"wasm" => {
let link = udtf.get_link()?;
// Use `block_in_place` as an escape hatch to run async code here in sync context.
// Calling `block_on` directly will panic.
UdfImpl::Wasm(tokio::task::block_in_place(|| {
tokio::runtime::Handle::current()
.block_on(crate::expr::expr_udf::get_or_create_wasm_runtime(link))
})?)
"wasm" | "rust" => {
let compressed_wasm_binary = udtf.get_compressed_binary()?;
let wasm_binary = zstd::stream::decode_all(compressed_wasm_binary.as_slice())
.context("failed to decompress wasm binary")?;
let runtime = crate::expr::expr_udf::get_or_create_wasm_runtime(&wasm_binary)?;
UdfImpl::Wasm(runtime)
}
"javascript" => {
let mut rt = JsRuntime::new()?;
Expand Down
1 change: 1 addition & 0 deletions src/frontend/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ tokio-stream = "0.1"
tonic = { workspace = true }
tracing = "0.1"
uuid = "1"
zstd = { version = "0.13", default-features = false }

[target.'cfg(not(madsim))'.dependencies]
workspace-hack = { path = "../workspace-hack" }
Expand Down
2 changes: 2 additions & 0 deletions src/frontend/src/catalog/function_catalog.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ pub struct FunctionCatalog {
pub identifier: Option<String>,
pub body: Option<String>,
pub link: Option<String>,
pub compressed_binary: Option<Vec<u8>>,
pub always_retry_on_network_error: bool,
}

Expand Down Expand Up @@ -69,6 +70,7 @@ impl From<&PbFunction> for FunctionCatalog {
identifier: prost.identifier.clone(),
body: prost.body.clone(),
link: prost.link.clone(),
compressed_binary: prost.compressed_binary.clone(),
always_retry_on_network_error: prost.always_retry_on_network_error,
}
}
Expand Down
1 change: 1 addition & 0 deletions src/frontend/src/expr/table_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ impl TableFunction {
link: c.link.clone(),
identifier: c.identifier.clone(),
body: c.body.clone(),
compressed_binary: c.compressed_binary.clone(),
}),
}
}
Expand Down
2 changes: 2 additions & 0 deletions src/frontend/src/expr/user_defined_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ impl UserDefinedFunction {
identifier: udf.identifier.clone(),
body: udf.body.clone(),
link: udf.link.clone(),
compressed_binary: udf.compressed_binary.clone(),
always_retry_on_network_error: udf.always_retry_on_network_error,
};

Expand Down Expand Up @@ -93,6 +94,7 @@ impl Expr for UserDefinedFunction {
identifier: self.catalog.identifier.clone(),
link: self.catalog.link.clone(),
body: self.catalog.body.clone(),
compressed_binary: self.catalog.compressed_binary.clone(),
always_retry_on_network_error: self.catalog.always_retry_on_network_error,
})),
}
Expand Down
Loading
Loading