From ba292a6e9a0d0db072b613300fe30af692d1d6bf Mon Sep 17 00:00:00 2001 From: Runji Wang Date: Wed, 28 Feb 2024 19:23:58 +0800 Subject: [PATCH] feat(udf): store WASM UDF in meta store (#15269) (#15341) --- Cargo.lock | 5 +- proto/catalog.proto | 3 + proto/expr.proto | 16 +++-- proto/meta.proto | 2 +- src/common/src/system_param/mod.rs | 2 - src/common/src/system_param/reader.rs | 7 -- src/config/example.toml | 1 - src/expr/core/Cargo.toml | 5 +- src/expr/core/src/expr/expr_udf.rs | 48 ++++--------- .../core/src/table_function/user_defined.rs | 15 ++-- src/frontend/Cargo.toml | 1 + src/frontend/src/catalog/function_catalog.rs | 2 + src/frontend/src/expr/table_function.rs | 1 + .../src/expr/user_defined_function.rs | 2 + src/frontend/src/handler/create_function.rs | 70 ++++++------------- .../src/handler/create_sql_function.rs | 1 + .../migration/src/m20230908_072257_init.rs | 2 + src/meta/model_v2/src/function.rs | 2 + src/meta/src/backup_restore/restore.rs | 1 - src/meta/src/controller/mod.rs | 1 + src/workspace-hack/Cargo.toml | 1 + 21 files changed, 76 insertions(+), 112 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 676be08f8703e..756e793d086e3 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -9224,6 +9224,7 @@ dependencies = [ "itertools 0.12.0", "linkme", "madsim-tokio", + "md5", "moka", "num-traits", "openssl", @@ -9231,7 +9232,6 @@ dependencies = [ "paste", "risingwave_common", "risingwave_expr_macro", - "risingwave_object_store", "risingwave_pb", "risingwave_udf", "smallvec", @@ -9240,6 +9240,7 @@ dependencies = [ "thiserror-ext", "tracing", "workspace-hack", + "zstd 0.13.0", ] [[package]] @@ -9368,6 +9369,7 @@ dependencies = [ "tracing", "uuid", "workspace-hack", + "zstd 0.13.0", ] [[package]] @@ -13865,6 +13867,7 @@ dependencies = [ "madsim-tokio", "md-5", "mio", + "moka", "nom", "num-bigint", "num-integer", diff --git a/proto/catalog.proto b/proto/catalog.proto index cfc2c37976151..2e8a2b19c86c1 100644 --- a/proto/catalog.proto +++ b/proto/catalog.proto @@ -219,7 +219,10 @@ message Function { string language = 7; optional string link = 8; optional string identifier = 10; + // The source code of the function. optional string body = 14; + // The zstd-compressed binary of the function. + optional bytes compressed_binary = 17; bool always_retry_on_network_error = 16; oneof kind { diff --git a/proto/expr.proto b/proto/expr.proto index c58c935d84c6c..bd75dcd8b27a4 100644 --- a/proto/expr.proto +++ b/proto/expr.proto @@ -482,16 +482,17 @@ message UserDefinedFunction { repeated string arg_names = 8; repeated data.DataType arg_types = 3; string language = 4; - // For external UDF: the link to the external function service. - // For WASM UDF: the link to the wasm binary file. + // The link to the external function service. optional string link = 5; - // An unique identifier for the function. - // For external UDF, it's the name of the function in the external function service. - // For WASM UDF, it's the name of the function in the wasm binary file. - // For JavaScript UDF, it's the name of the function. + // An unique identifier to the function. + // - If `link` is not empty, the name of the function in the external function service. + // - If `language` is `rust` or `wasm`, the name of the function in the wasm binary file. + // - If `language` is `javascript`, the name of the function. optional string identifier = 6; - // For JavaScript UDF, it's the body of the function. + // - If `language` is `javascript`, the source code of the function. optional string body = 7; + // - If `language` is `rust` or `wasm`, the zstd-compressed wasm binary. + optional bytes compressed_binary = 10; bool always_retry_on_network_error = 9; } @@ -503,4 +504,5 @@ message UserDefinedTableFunction { optional string link = 5; optional string identifier = 6; optional string body = 7; + optional bytes compressed_binary = 10; } diff --git a/proto/meta.proto b/proto/meta.proto index 01492cc0c4fff..1db290af7b308 100644 --- a/proto/meta.proto +++ b/proto/meta.proto @@ -555,7 +555,7 @@ message SystemParams { optional uint32 parallel_compact_size_mb = 11; optional uint32 max_concurrent_creating_streaming_jobs = 12; optional bool pause_on_next_bootstrap = 13; - optional string wasm_storage_url = 14; + optional string wasm_storage_url = 14 [deprecated = true]; optional bool enable_tracing = 15; } diff --git a/src/common/src/system_param/mod.rs b/src/common/src/system_param/mod.rs index 278390887dd51..335aabb264f32 100644 --- a/src/common/src/system_param/mod.rs +++ b/src/common/src/system_param/mod.rs @@ -86,7 +86,6 @@ macro_rules! for_all_params { { backup_storage_directory, String, None, true, "Remote directory for storing snapshots.", }, { max_concurrent_creating_streaming_jobs, u32, Some(1_u32), true, "Max number of concurrent creating streaming jobs.", }, { pause_on_next_bootstrap, bool, Some(false), true, "Whether to pause all data sources on next bootstrap.", }, - { wasm_storage_url, String, Some("fs://.risingwave/data".to_string()), false, "", }, { enable_tracing, bool, Some(false), true, "Whether to enable distributed tracing.", }, } }; @@ -440,7 +439,6 @@ mod tests { (BACKUP_STORAGE_DIRECTORY_KEY, "a"), (MAX_CONCURRENT_CREATING_STREAMING_JOBS_KEY, "1"), (PAUSE_ON_NEXT_BOOTSTRAP_KEY, "false"), - (WASM_STORAGE_URL_KEY, "a"), (ENABLE_TRACING_KEY, "true"), ("a_deprecated_param", "foo"), ]; diff --git a/src/common/src/system_param/reader.rs b/src/common/src/system_param/reader.rs index c6b8d8c5af6aa..54bfcfa9e9404 100644 --- a/src/common/src/system_param/reader.rs +++ b/src/common/src/system_param/reader.rs @@ -142,11 +142,4 @@ where .enable_tracing .unwrap_or_else(default::enable_tracing) } - - fn wasm_storage_url(&self) -> &str { - self.inner() - .wasm_storage_url - .as_ref() - .unwrap_or(&default::WASM_STORAGE_URL) - } } diff --git a/src/config/example.toml b/src/config/example.toml index 59c68aff3c7c0..6a6314d7832d2 100644 --- a/src/config/example.toml +++ b/src/config/example.toml @@ -195,5 +195,4 @@ block_size_kb = 64 bloom_false_positive = 0.001 max_concurrent_creating_streaming_jobs = 1 pause_on_next_bootstrap = false -wasm_storage_url = "fs://.risingwave/data" enable_tracing = false diff --git a/src/expr/core/Cargo.toml b/src/expr/core/Cargo.toml index 51972f282826c..89fcc846de5c8 100644 --- a/src/expr/core/Cargo.toml +++ b/src/expr/core/Cargo.toml @@ -37,14 +37,14 @@ futures-async-stream = { workspace = true } futures-util = "0.3" itertools = "0.12" linkme = { version = "0.3", features = ["used_linker"] } -moka = { version = "0.12", features = ["future"] } +md5 = "0.7" +moka = { version = "0.12", features = ["sync"] } num-traits = "0.2" openssl = { version = "0.10", features = ["vendored"] } parse-display = "0.8" paste = "1" risingwave_common = { workspace = true } risingwave_expr_macro = { path = "../macro" } -risingwave_object_store = { workspace = true } risingwave_pb = { workspace = true } risingwave_udf = { workspace = true } smallvec = "1" @@ -56,6 +56,7 @@ tokio = { version = "0.2", package = "madsim-tokio", features = [ "macros", ] } tracing = "0.1" +zstd = { version = "0.13", default-features = false } [target.'cfg(not(madsim))'.dependencies] workspace-hack = { path = "../../workspace-hack" } diff --git a/src/expr/core/src/expr/expr_udf.rs b/src/expr/core/src/expr/expr_udf.rs index 1f7e15a591b42..838faa848795d 100644 --- a/src/expr/core/src/expr/expr_udf.rs +++ b/src/expr/core/src/expr/expr_udf.rs @@ -24,13 +24,10 @@ use arrow_udf_js::{CallMode, Runtime as JsRuntime}; use arrow_udf_wasm::Runtime as WasmRuntime; use await_tree::InstrumentAwait; use cfg_or_panic::cfg_or_panic; -use moka::future::Cache; +use moka::sync::Cache; use risingwave_common::array::{ArrayError, ArrayRef, DataChunk}; -use risingwave_common::config::ObjectStoreConfig; use risingwave_common::row::OwnedRow; use risingwave_common::types::{DataType, Datum}; -use risingwave_object_store::object::build_remote_object_store; -use risingwave_object_store::object::object_metrics::ObjectStoreMetrics; use risingwave_pb::expr::ExprNode; use risingwave_udf::ArrowFlightUdfClient; use thiserror_ext::AsReport; @@ -189,13 +186,13 @@ impl Build for UserDefinedFunction { let identifier = udf.get_identifier()?; let imp = match udf.language.as_str() { + #[cfg(not(madsim))] "wasm" => { - let link = udf.get_link()?; - // Use `block_in_place` as an escape hatch to run async code here in sync context. - // Calling `block_on` directly will panic. - UdfImpl::Wasm(tokio::task::block_in_place(|| { - tokio::runtime::Handle::current().block_on(get_or_create_wasm_runtime(link)) - })?) + let compressed_wasm_binary = udf.get_compressed_binary()?; + let wasm_binary = zstd::stream::decode_all(compressed_wasm_binary.as_slice()) + .context("failed to decompress wasm binary")?; + let runtime = get_or_create_wasm_runtime(&wasm_binary)?; + UdfImpl::Wasm(runtime) } "javascript" => { let mut rt = JsRuntime::new()?; @@ -270,38 +267,21 @@ pub(crate) fn get_or_create_flight_client(link: &str) -> Result Result> { - static RUNTIMES: LazyLock>> = LazyLock::new(|| { +pub fn get_or_create_wasm_runtime(binary: &[u8]) -> Result> { + static RUNTIMES: LazyLock>> = LazyLock::new(|| { Cache::builder() .time_to_idle(Duration::from_secs(60)) .build() }); - if let Some(runtime) = RUNTIMES.get(link).await { + let md5 = md5::compute(binary); + if let Some(runtime) = RUNTIMES.get(&md5) { return Ok(runtime.clone()); } - // create new runtime - let (wasm_storage_url, object_name) = link - .rsplit_once('/') - .context("invalid link for wasm function")?; - - // load wasm binary from object store - let object_store = build_remote_object_store( - wasm_storage_url, - Arc::new(ObjectStoreMetrics::unused()), - "Wasm Engine", - ObjectStoreConfig::default(), - ) - .await; - let binary = object_store - .read(object_name, ..) - .await - .context("failed to load wasm binary from object storage")?; - - let runtime = Arc::new(arrow_udf_wasm::Runtime::new(&binary)?); - RUNTIMES.insert(link.into(), runtime.clone()).await; + let runtime = Arc::new(arrow_udf_wasm::Runtime::new(binary)?); + RUNTIMES.insert(md5, runtime.clone()); Ok(runtime) } diff --git a/src/expr/core/src/table_function/user_defined.rs b/src/expr/core/src/table_function/user_defined.rs index 06383543ceb7b..ad9ba03943662 100644 --- a/src/expr/core/src/table_function/user_defined.rs +++ b/src/expr/core/src/table_function/user_defined.rs @@ -14,6 +14,7 @@ use std::sync::Arc; +use anyhow::Context; use arrow_array::RecordBatch; use arrow_schema::{Field, Fields, Schema, SchemaRef}; use arrow_udf_js::{CallMode, Runtime as JsRuntime}; @@ -188,14 +189,12 @@ pub fn new_user_defined(prost: &PbTableFunction, chunk_size: usize) -> Result { - let link = udtf.get_link()?; - // Use `block_in_place` as an escape hatch to run async code here in sync context. - // Calling `block_on` directly will panic. - UdfImpl::Wasm(tokio::task::block_in_place(|| { - tokio::runtime::Handle::current() - .block_on(crate::expr::expr_udf::get_or_create_wasm_runtime(link)) - })?) + "wasm" | "rust" => { + let compressed_wasm_binary = udtf.get_compressed_binary()?; + let wasm_binary = zstd::stream::decode_all(compressed_wasm_binary.as_slice()) + .context("failed to decompress wasm binary")?; + let runtime = crate::expr::expr_udf::get_or_create_wasm_runtime(&wasm_binary)?; + UdfImpl::Wasm(runtime) } "javascript" => { let mut rt = JsRuntime::new()?; diff --git a/src/frontend/Cargo.toml b/src/frontend/Cargo.toml index bf4dd37c233b6..a503b4efd084c 100644 --- a/src/frontend/Cargo.toml +++ b/src/frontend/Cargo.toml @@ -90,6 +90,7 @@ tokio-stream = "0.1" tonic = { workspace = true } tracing = "0.1" uuid = "1" +zstd = { version = "0.13", default-features = false } [target.'cfg(not(madsim))'.dependencies] workspace-hack = { path = "../workspace-hack" } diff --git a/src/frontend/src/catalog/function_catalog.rs b/src/frontend/src/catalog/function_catalog.rs index e60a3a758b7b5..142bc222a59b8 100644 --- a/src/frontend/src/catalog/function_catalog.rs +++ b/src/frontend/src/catalog/function_catalog.rs @@ -33,6 +33,7 @@ pub struct FunctionCatalog { pub identifier: Option, pub body: Option, pub link: Option, + pub compressed_binary: Option>, pub always_retry_on_network_error: bool, } @@ -69,6 +70,7 @@ impl From<&PbFunction> for FunctionCatalog { identifier: prost.identifier.clone(), body: prost.body.clone(), link: prost.link.clone(), + compressed_binary: prost.compressed_binary.clone(), always_retry_on_network_error: prost.always_retry_on_network_error, } } diff --git a/src/frontend/src/expr/table_function.rs b/src/frontend/src/expr/table_function.rs index e3000d0c245ab..c72c207c53783 100644 --- a/src/frontend/src/expr/table_function.rs +++ b/src/frontend/src/expr/table_function.rs @@ -79,6 +79,7 @@ impl TableFunction { link: c.link.clone(), identifier: c.identifier.clone(), body: c.body.clone(), + compressed_binary: c.compressed_binary.clone(), }), } } diff --git a/src/frontend/src/expr/user_defined_function.rs b/src/frontend/src/expr/user_defined_function.rs index 16dd3b0d65634..4231919a1a4dc 100644 --- a/src/frontend/src/expr/user_defined_function.rs +++ b/src/frontend/src/expr/user_defined_function.rs @@ -58,6 +58,7 @@ impl UserDefinedFunction { identifier: udf.identifier.clone(), body: udf.body.clone(), link: udf.link.clone(), + compressed_binary: udf.compressed_binary.clone(), always_retry_on_network_error: udf.always_retry_on_network_error, }; @@ -93,6 +94,7 @@ impl Expr for UserDefinedFunction { identifier: self.catalog.identifier.clone(), link: self.catalog.link.clone(), body: self.catalog.body.clone(), + compressed_binary: self.catalog.compressed_binary.clone(), always_retry_on_network_error: self.catalog.always_retry_on_network_error, })), } diff --git a/src/frontend/src/handler/create_function.rs b/src/frontend/src/handler/create_function.rs index 63262e869d8f1..7daccb0f292b8 100644 --- a/src/frontend/src/handler/create_function.rs +++ b/src/frontend/src/handler/create_function.rs @@ -18,16 +18,13 @@ use bytes::Bytes; use itertools::Itertools; use pgwire::pg_response::StatementType; use risingwave_common::catalog::FunctionId; -use risingwave_common::system_param::reader::SystemParamsRead; use risingwave_common::types::DataType; use risingwave_expr::expr::get_or_create_wasm_runtime; -use risingwave_object_store::object::{build_remote_object_store, ObjectStoreConfig}; use risingwave_pb::catalog::function::{Kind, ScalarFunction, TableFunction}; use risingwave_pb::catalog::Function; use risingwave_sqlparser::ast::{ CreateFunctionBody, FunctionDefinition, ObjectName, OperateFunctionArg, }; -use risingwave_storage::monitor::ObjectStoreMetrics; use risingwave_udf::ArrowFlightUdfClient; use super::*; @@ -126,6 +123,7 @@ pub async fn handle_create_function( let identifier; let mut link = None; let mut body = None; + let mut compressed_binary = None; match language.as_str() { "python" | "java" | "" => { @@ -198,38 +196,21 @@ pub async fn handle_create_function( ) .into()); }; - link = match using { - CreateFunctionUsing::Link(link) => { - let runtime = get_or_create_wasm_runtime(&link).await?; - check_wasm_function(&runtime, &identifier)?; - Some(link) - } + let wasm_binary = match using { + CreateFunctionUsing::Link(link) => download_binary_from_link(&link).await?, CreateFunctionUsing::Base64(encoded) => { // decode wasm binary from base64 use base64::prelude::{Engine, BASE64_STANDARD}; - let wasm_binary = BASE64_STANDARD + BASE64_STANDARD .decode(encoded) - .context("invalid base64 encoding")?; - - let runtime = arrow_udf_wasm::Runtime::new(&wasm_binary)?; - check_wasm_function(&runtime, &identifier)?; - - let system_params = session.env().meta_client().get_system_params().await?; - let object_name = format!("{:?}.wasm", md5::compute(&wasm_binary)); - upload_wasm_binary( - system_params.wasm_storage_url(), - &object_name, - wasm_binary.into(), - ) - .await?; - - Some(format!( - "{}/{}", - system_params.wasm_storage_url(), - object_name - )) + .context("invalid base64 encoding")? + .into() } }; + let runtime = get_or_create_wasm_runtime(&wasm_binary)?; + check_wasm_function(&runtime, &identifier)?; + + compressed_binary = Some(zstd::stream::encode_all(wasm_binary.as_ref(), 0)?); } _ => unreachable!("invalid language: {language}"), }; @@ -247,6 +228,7 @@ pub async fn handle_create_function( identifier: Some(identifier), link, body, + compressed_binary, owner: session.user_id(), always_retry_on_network_error: with_options .always_retry_on_network_error @@ -259,25 +241,17 @@ pub async fn handle_create_function( Ok(PgResponse::empty_result(StatementType::CREATE_FUNCTION)) } -/// Upload wasm binary to object store. -async fn upload_wasm_binary( - wasm_storage_url: &str, - object_name: &str, - wasm_binary: Bytes, -) -> Result<()> { - // Note: it will panic if the url is invalid. We did a validation on meta startup. - let object_store = build_remote_object_store( - wasm_storage_url, - Arc::new(ObjectStoreMetrics::unused()), - "Wasm Engine", - ObjectStoreConfig::default(), - ) - .await; - object_store - .upload(object_name, wasm_binary) - .await - .context("failed to upload wasm binary to object store")?; - Ok(()) +/// Download wasm binary from a link. +#[allow(clippy::unused_async)] +async fn download_binary_from_link(link: &str) -> Result { + // currently only local file system is supported + if let Some(path) = link.strip_prefix("fs://") { + let content = + std::fs::read(path).context("failed to read wasm binary from local file system")?; + Ok(content.into()) + } else { + Err(ErrorCode::InvalidParameterValue("only 'fs://' is supported".to_string()).into()) + } } /// Check if the function exists in the wasm binary. diff --git a/src/frontend/src/handler/create_sql_function.rs b/src/frontend/src/handler/create_sql_function.rs index 45a9804b407ec..2caa5f813dbde 100644 --- a/src/frontend/src/handler/create_sql_function.rs +++ b/src/frontend/src/handler/create_sql_function.rs @@ -234,6 +234,7 @@ pub async fn handle_create_sql_function( language, identifier: None, body: Some(body), + compressed_binary: None, link: None, owner: session.user_id(), always_retry_on_network_error: false, diff --git a/src/meta/model_v2/migration/src/m20230908_072257_init.rs b/src/meta/model_v2/migration/src/m20230908_072257_init.rs index 661b1b6055e73..04230b1ef79e5 100644 --- a/src/meta/model_v2/migration/src/m20230908_072257_init.rs +++ b/src/meta/model_v2/migration/src/m20230908_072257_init.rs @@ -714,6 +714,7 @@ impl MigrationTrait for Migration { .col(ColumnDef::new(Function::Link).string()) .col(ColumnDef::new(Function::Identifier).string()) .col(ColumnDef::new(Function::Body).string()) + .col(ColumnDef::new(Function::CompressedBinary).string()) .col(ColumnDef::new(Function::Kind).string().not_null()) .col( ColumnDef::new(Function::AlwaysRetryOnNetworkError) @@ -1117,6 +1118,7 @@ enum Function { Link, Identifier, Body, + CompressedBinary, Kind, AlwaysRetryOnNetworkError, } diff --git a/src/meta/model_v2/src/function.rs b/src/meta/model_v2/src/function.rs index 1976cee4f867a..eaf368aa15d1e 100644 --- a/src/meta/model_v2/src/function.rs +++ b/src/meta/model_v2/src/function.rs @@ -44,6 +44,7 @@ pub struct Model { pub link: Option, pub identifier: Option, pub body: Option, + pub compressed_binary: Option>, pub kind: FunctionKind, pub always_retry_on_network_error: bool, } @@ -100,6 +101,7 @@ impl From for ActiveModel { link: Set(function.link), identifier: Set(function.identifier), body: Set(function.body), + compressed_binary: Set(function.compressed_binary), kind: Set(function.kind.unwrap().into()), always_retry_on_network_error: Set(function.always_retry_on_network_error), } diff --git a/src/meta/src/backup_restore/restore.rs b/src/meta/src/backup_restore/restore.rs index cc544a0f589aa..1afe17b15bebc 100644 --- a/src/meta/src/backup_restore/restore.rs +++ b/src/meta/src/backup_restore/restore.rs @@ -244,7 +244,6 @@ mod tests { data_directory: Some("data_directory".into()), backup_storage_url: Some("backup_storage_url".into()), backup_storage_directory: Some("backup_storage_directory".into()), - wasm_storage_url: Some("wasm_storage_url".into()), ..SystemConfig::default().into_init_system_params() } } diff --git a/src/meta/src/controller/mod.rs b/src/meta/src/controller/mod.rs index 9d973b710ab28..e7d4a6f1dcd13 100644 --- a/src/meta/src/controller/mod.rs +++ b/src/meta/src/controller/mod.rs @@ -286,6 +286,7 @@ impl From> for PbFunction { link: value.0.link, identifier: value.0.identifier, body: value.0.body, + compressed_binary: value.0.compressed_binary, kind: Some(value.0.kind.into()), always_retry_on_network_error: value.0.always_retry_on_network_error, } diff --git a/src/workspace-hack/Cargo.toml b/src/workspace-hack/Cargo.toml index f822ddea5dcd0..4862e9224c6a2 100644 --- a/src/workspace-hack/Cargo.toml +++ b/src/workspace-hack/Cargo.toml @@ -81,6 +81,7 @@ madsim-rdkafka = { version = "0.3", features = ["cmake-build", "gssapi", "ssl-ve madsim-tokio = { version = "0.2", default-features = false, features = ["fs", "io-util", "macros", "net", "process", "rt", "rt-multi-thread", "signal", "sync", "time", "tracing"] } md-5 = { version = "0.10" } mio = { version = "0.8", features = ["net", "os-ext"] } +moka = { version = "0.12", features = ["future", "sync"] } nom = { version = "7" } num-bigint = { version = "0.4" } num-integer = { version = "0.1", features = ["i128"] }