Skip to content

Commit

Permalink
feat(udf): store WASM UDF in meta store (#15269)
Browse files Browse the repository at this point in the history
Signed-off-by: Runji Wang <[email protected]>
  • Loading branch information
wangrunji0408 authored Feb 27, 2024
1 parent c8a61fc commit e0f9c68
Show file tree
Hide file tree
Showing 23 changed files with 77 additions and 129 deletions.
5 changes: 4 additions & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 0 additions & 1 deletion e2e_test/batch/catalog/pg_settings.slt.part
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ internal data_directory
internal parallel_compact_size_mb
internal sstable_size_mb
internal state_store
internal wasm_storage_url
postmaster backup_storage_directory
postmaster backup_storage_url
postmaster barrier_interval_ms
Expand Down
3 changes: 3 additions & 0 deletions proto/catalog.proto
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,10 @@ message Function {
string language = 7;
optional string link = 8;
optional string identifier = 10;
// The source code of the function.
optional string body = 14;
// The zstd-compressed binary of the function.
optional bytes compressed_binary = 17;
bool always_retry_on_network_error = 16;

oneof kind {
Expand Down
16 changes: 9 additions & 7 deletions proto/expr.proto
Original file line number Diff line number Diff line change
Expand Up @@ -517,16 +517,17 @@ message UserDefinedFunction {
repeated string arg_names = 8;
repeated data.DataType arg_types = 3;
string language = 4;
// For external UDF: the link to the external function service.
// For WASM UDF: the link to the wasm binary file.
// The link to the external function service.
optional string link = 5;
// An unique identifier for the function.
// For external UDF, it's the name of the function in the external function service.
// For WASM UDF, it's the name of the function in the wasm binary file.
// For JavaScript UDF, it's the name of the function.
// A unique identifier for the function.
// - If `link` is not empty, the name of the function in the external function service.
// - If `language` is `rust` or `wasm`, the name of the function in the wasm binary file.
// - If `language` is `javascript`, the name of the function.
optional string identifier = 6;
// For JavaScript UDF, it's the body of the function.
// - If `language` is `javascript`, the source code of the function.
optional string body = 7;
// - If `language` is `rust` or `wasm`, the zstd-compressed wasm binary.
optional bytes compressed_binary = 10;
bool always_retry_on_network_error = 9;
}

Expand All @@ -538,4 +539,5 @@ message UserDefinedTableFunction {
optional string link = 5;
optional string identifier = 6;
optional string body = 7;
optional bytes compressed_binary = 10;
}
2 changes: 1 addition & 1 deletion proto/meta.proto
Original file line number Diff line number Diff line change
Expand Up @@ -555,7 +555,7 @@ message SystemParams {
optional uint32 parallel_compact_size_mb = 11;
optional uint32 max_concurrent_creating_streaming_jobs = 12;
optional bool pause_on_next_bootstrap = 13;
optional string wasm_storage_url = 14;
optional string wasm_storage_url = 14 [deprecated = true];
optional bool enable_tracing = 15;
}

Expand Down
2 changes: 0 additions & 2 deletions src/common/src/system_param/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,6 @@ macro_rules! for_all_params {
{ backup_storage_directory, String, None, true, "Remote directory for storing snapshots.", },
{ max_concurrent_creating_streaming_jobs, u32, Some(1_u32), true, "Max number of concurrent creating streaming jobs.", },
{ pause_on_next_bootstrap, bool, Some(false), true, "Whether to pause all data sources on next bootstrap.", },
{ wasm_storage_url, String, Some("fs://.risingwave/data".to_string()), false, "", },
{ enable_tracing, bool, Some(false), true, "Whether to enable distributed tracing.", },
}
};
Expand Down Expand Up @@ -440,7 +439,6 @@ mod tests {
(BACKUP_STORAGE_DIRECTORY_KEY, "a"),
(MAX_CONCURRENT_CREATING_STREAMING_JOBS_KEY, "1"),
(PAUSE_ON_NEXT_BOOTSTRAP_KEY, "false"),
(WASM_STORAGE_URL_KEY, "a"),
(ENABLE_TRACING_KEY, "true"),
("a_deprecated_param", "foo"),
];
Expand Down
7 changes: 0 additions & 7 deletions src/common/src/system_param/reader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -160,11 +160,4 @@ where
.enable_tracing
.unwrap_or_else(default::enable_tracing)
}

fn wasm_storage_url(&self) -> &str {
self.inner()
.wasm_storage_url
.as_ref()
.unwrap_or(&default::WASM_STORAGE_URL)
}
}
1 change: 0 additions & 1 deletion src/config/docs.md
Original file line number Diff line number Diff line change
Expand Up @@ -150,4 +150,3 @@ This page is automatically generated by `./risedev generate-example-config`
| pause_on_next_bootstrap | Whether to pause all data sources on next bootstrap. | false |
| sstable_size_mb | Target size of the Sstable. | 256 |
| state_store | | |
| wasm_storage_url | | "fs://.risingwave/data" |
1 change: 0 additions & 1 deletion src/config/example.toml
Original file line number Diff line number Diff line change
Expand Up @@ -195,5 +195,4 @@ block_size_kb = 64
bloom_false_positive = 0.001
max_concurrent_creating_streaming_jobs = 1
pause_on_next_bootstrap = false
wasm_storage_url = "fs://.risingwave/data"
enable_tracing = false
5 changes: 3 additions & 2 deletions src/expr/core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -38,14 +38,14 @@ futures-async-stream = { workspace = true }
futures-util = "0.3"
itertools = "0.12"
linkme = { version = "0.3", features = ["used_linker"] }
moka = { version = "0.12", features = ["future"] }
md5 = "0.7"
moka = { version = "0.12", features = ["sync"] }
num-traits = "0.2"
openssl = { version = "0.10", features = ["vendored"] }
parse-display = "0.9"
paste = "1"
risingwave_common = { workspace = true }
risingwave_expr_macro = { path = "../macro" }
risingwave_object_store = { workspace = true }
risingwave_pb = { workspace = true }
risingwave_udf = { workspace = true }
smallvec = "1"
Expand All @@ -57,6 +57,7 @@ tokio = { version = "0.2", package = "madsim-tokio", features = [
"macros",
] }
tracing = "0.1"
zstd = { version = "0.13", default-features = false }

[target.'cfg(not(madsim))'.dependencies]
workspace-hack = { path = "../../workspace-hack" }
Expand Down
47 changes: 13 additions & 34 deletions src/expr/core/src/expr/expr_udf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,10 @@ use arrow_udf_js::{CallMode, Runtime as JsRuntime};
use arrow_udf_wasm::Runtime as WasmRuntime;
use await_tree::InstrumentAwait;
use cfg_or_panic::cfg_or_panic;
use moka::future::Cache;
use moka::sync::Cache;
use risingwave_common::array::{ArrayError, ArrayRef, DataChunk};
use risingwave_common::config::ObjectStoreConfig;
use risingwave_common::row::OwnedRow;
use risingwave_common::types::{DataType, Datum};
use risingwave_object_store::object::build_remote_object_store;
use risingwave_object_store::object::object_metrics::ObjectStoreMetrics;
use risingwave_pb::expr::ExprNode;
use risingwave_udf::ArrowFlightUdfClient;
use thiserror_ext::AsReport;
Expand Down Expand Up @@ -188,12 +185,11 @@ impl Build for UserDefinedFunction {
let imp = match udf.language.as_str() {
#[cfg(not(madsim))]
"wasm" | "rust" => {
let link = udf.get_link()?;
// Use `block_in_place` as an escape hatch to run async code here in sync context.
// Calling `block_on` directly will panic.
UdfImpl::Wasm(tokio::task::block_in_place(|| {
tokio::runtime::Handle::current().block_on(get_or_create_wasm_runtime(link))
})?)
let compressed_wasm_binary = udf.get_compressed_binary()?;
let wasm_binary = zstd::stream::decode_all(compressed_wasm_binary.as_slice())
.context("failed to decompress wasm binary")?;
let runtime = get_or_create_wasm_runtime(&wasm_binary)?;
UdfImpl::Wasm(runtime)
}
"javascript" => {
let mut rt = JsRuntime::new()?;
Expand Down Expand Up @@ -271,38 +267,21 @@ pub(crate) fn get_or_create_flight_client(link: &str) -> Result<Arc<ArrowFlightU
/// Get or create a wasm runtime.
///
/// Runtimes returned by this function are cached internally and kept for at least 60 seconds after their last use.
/// Later calls with the same link will reuse the same runtime.
/// Later calls with the same binary will reuse the same runtime.
#[cfg_or_panic(not(madsim))]
pub async fn get_or_create_wasm_runtime(link: &str) -> Result<Arc<WasmRuntime>> {
static RUNTIMES: LazyLock<Cache<String, Arc<WasmRuntime>>> = LazyLock::new(|| {
pub fn get_or_create_wasm_runtime(binary: &[u8]) -> Result<Arc<WasmRuntime>> {
static RUNTIMES: LazyLock<Cache<md5::Digest, Arc<WasmRuntime>>> = LazyLock::new(|| {
Cache::builder()
.time_to_idle(Duration::from_secs(60))
.build()
});

if let Some(runtime) = RUNTIMES.get(link).await {
let md5 = md5::compute(binary);
if let Some(runtime) = RUNTIMES.get(&md5) {
return Ok(runtime.clone());
}

// create new runtime
let (wasm_storage_url, object_name) = link
.rsplit_once('/')
.context("invalid link for wasm function")?;

// load wasm binary from object store
let object_store = build_remote_object_store(
wasm_storage_url,
Arc::new(ObjectStoreMetrics::unused()),
"Wasm Engine",
ObjectStoreConfig::default(),
)
.await;
let binary = object_store
.read(object_name, ..)
.await
.context("failed to load wasm binary from object storage")?;

let runtime = Arc::new(arrow_udf_wasm::Runtime::new(&binary)?);
RUNTIMES.insert(link.into(), runtime.clone()).await;
let runtime = Arc::new(arrow_udf_wasm::Runtime::new(binary)?);
RUNTIMES.insert(md5, runtime.clone());
Ok(runtime)
}
15 changes: 7 additions & 8 deletions src/expr/core/src/table_function/user_defined.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

use std::sync::Arc;

use anyhow::Context;
use arrow_array::RecordBatch;
use arrow_schema::{Field, Fields, Schema, SchemaRef};
use arrow_udf_js::{CallMode, Runtime as JsRuntime};
Expand Down Expand Up @@ -188,14 +189,12 @@ pub fn new_user_defined(prost: &PbTableFunction, chunk_size: usize) -> Result<Bo
let return_type = DataType::from(prost.get_return_type()?);

let client = match udtf.language.as_str() {
"wasm" => {
let link = udtf.get_link()?;
// Use `block_in_place` as an escape hatch to run async code here in sync context.
// Calling `block_on` directly will panic.
UdfImpl::Wasm(tokio::task::block_in_place(|| {
tokio::runtime::Handle::current()
.block_on(crate::expr::expr_udf::get_or_create_wasm_runtime(link))
})?)
"wasm" | "rust" => {
let compressed_wasm_binary = udtf.get_compressed_binary()?;
let wasm_binary = zstd::stream::decode_all(compressed_wasm_binary.as_slice())
.context("failed to decompress wasm binary")?;
let runtime = crate::expr::expr_udf::get_or_create_wasm_runtime(&wasm_binary)?;
UdfImpl::Wasm(runtime)
}
"javascript" => {
let mut rt = JsRuntime::new()?;
Expand Down
1 change: 1 addition & 0 deletions src/frontend/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ tokio-stream = "0.1"
tonic = { workspace = true }
tracing = "0.1"
uuid = "1"
zstd = { version = "0.13", default-features = false }

[target.'cfg(not(madsim))'.dependencies]
workspace-hack = { path = "../workspace-hack" }
Expand Down
2 changes: 2 additions & 0 deletions src/frontend/src/catalog/function_catalog.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ pub struct FunctionCatalog {
pub identifier: Option<String>,
pub body: Option<String>,
pub link: Option<String>,
pub compressed_binary: Option<Vec<u8>>,
pub always_retry_on_network_error: bool,
}

Expand Down Expand Up @@ -69,6 +70,7 @@ impl From<&PbFunction> for FunctionCatalog {
identifier: prost.identifier.clone(),
body: prost.body.clone(),
link: prost.link.clone(),
compressed_binary: prost.compressed_binary.clone(),
always_retry_on_network_error: prost.always_retry_on_network_error,
}
}
Expand Down
1 change: 1 addition & 0 deletions src/frontend/src/expr/table_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ impl TableFunction {
link: c.link.clone(),
identifier: c.identifier.clone(),
body: c.body.clone(),
compressed_binary: c.compressed_binary.clone(),
}),
}
}
Expand Down
2 changes: 2 additions & 0 deletions src/frontend/src/expr/user_defined_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ impl UserDefinedFunction {
identifier: udf.identifier.clone(),
body: udf.body.clone(),
link: udf.link.clone(),
compressed_binary: udf.compressed_binary.clone(),
always_retry_on_network_error: udf.always_retry_on_network_error,
};

Expand Down Expand Up @@ -93,6 +94,7 @@ impl Expr for UserDefinedFunction {
identifier: self.catalog.identifier.clone(),
link: self.catalog.link.clone(),
body: self.catalog.body.clone(),
compressed_binary: self.catalog.compressed_binary.clone(),
always_retry_on_network_error: self.catalog.always_retry_on_network_error,
})),
}
Expand Down
Loading

0 comments on commit e0f9c68

Please sign in to comment.