From 45c9e2b06de3d1921cb9cc33dc03de6f34b0e7bc Mon Sep 17 00:00:00 2001 From: Dylan Date: Fri, 12 Jul 2024 15:55:31 +0800 Subject: [PATCH 01/70] feat(batch): support batch s3 parquet frontend part (#17625) --- Cargo.lock | 1 + proto/expr.proto | 2 + src/batch/src/executor/s3_file_scan.rs | 72 ++--------- src/connector/Cargo.toml | 1 + src/connector/src/source/iceberg/mod.rs | 3 + .../src/source/iceberg/parquet_file_reader.rs | 85 +++++++++++++ src/frontend/src/binder/expr/function.rs | 5 + src/frontend/src/expr/table_function.rs | 112 ++++++++++++++++- src/frontend/src/lib.rs | 1 + .../src/optimizer/logical_optimization.rs | 10 ++ .../optimizer/plan_node/batch_file_scan.rs | 84 +++++++++++++ .../optimizer/plan_node/generic/file_scan.rs | 64 ++++++++++ .../src/optimizer/plan_node/generic/mod.rs | 3 + .../optimizer/plan_node/logical_file_scan.rs | 118 ++++++++++++++++++ src/frontend/src/optimizer/plan_node/mod.rs | 8 ++ src/frontend/src/optimizer/rule/mod.rs | 3 + .../rule/table_function_to_file_scan_rule.rs | 90 +++++++++++++ 17 files changed, 601 insertions(+), 61 deletions(-) create mode 100644 src/connector/src/source/iceberg/parquet_file_reader.rs create mode 100644 src/frontend/src/optimizer/plan_node/batch_file_scan.rs create mode 100644 src/frontend/src/optimizer/plan_node/generic/file_scan.rs create mode 100644 src/frontend/src/optimizer/plan_node/logical_file_scan.rs create mode 100644 src/frontend/src/optimizer/rule/table_function_to_file_scan_rule.rs diff --git a/Cargo.lock b/Cargo.lock index 9400e803bbbf..c08b302494d9 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -11173,6 +11173,7 @@ dependencies = [ "opendal", "openssl", "parking_lot 0.12.1", + "parquet 52.0.0", "paste", "pg_bigdecimal", "postgres-openssl", diff --git a/proto/expr.proto b/proto/expr.proto index 0dc1a96d7861..dedfa3f3cd3b 100644 --- a/proto/expr.proto +++ b/proto/expr.proto @@ -345,6 +345,8 @@ message TableFunction { JSONB_PATH_QUERY = 15; JSONB_POPULATE_RECORDSET = 16; JSONB_TO_RECORDSET = 17; + // file scan + FILE_SCAN = 19; // User defined table function USER_DEFINED = 100; } diff --git a/src/batch/src/executor/s3_file_scan.rs b/src/batch/src/executor/s3_file_scan.rs index 61fc2d5b5dd3..7c56788f85ae 100644 --- a/src/batch/src/executor/s3_file_scan.rs +++ b/src/batch/src/executor/s3_file_scan.rs @@ -12,24 +12,13 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-use std::ops::Range; -use std::sync::Arc; - use anyhow::anyhow; -use bytes::Bytes; use futures_async_stream::try_stream; -use futures_util::future::BoxFuture; use futures_util::stream::StreamExt; -use futures_util::TryFutureExt; -use hashbrown::HashMap; -use iceberg::io::{ - FileIOBuilder, FileMetadata, FileRead, S3_ACCESS_KEY_ID, S3_REGION, S3_SECRET_ACCESS_KEY, -}; -use parquet::arrow::async_reader::{AsyncFileReader, MetadataLoader}; -use parquet::arrow::{ParquetRecordBatchStreamBuilder, ProjectionMask}; -use parquet::file::metadata::ParquetMetaData; +use parquet::arrow::ProjectionMask; use risingwave_common::array::arrow::IcebergArrowConvert; use risingwave_common::catalog::Schema; +use risingwave_connector::source::iceberg::parquet_file_reader::create_parquet_stream_builder; use crate::error::BatchError; use crate::executor::{DataChunk, Executor}; @@ -93,22 +82,13 @@ impl S3FileScanExecutor { async fn do_execute(self: Box) { assert_eq!(self.file_format, FileFormat::Parquet); - let mut props = HashMap::new(); - props.insert(S3_REGION, self.s3_region.clone()); - props.insert(S3_ACCESS_KEY_ID, self.s3_access_key.clone()); - props.insert(S3_SECRET_ACCESS_KEY, self.s3_secret_key.clone()); - - let file_io_builder = FileIOBuilder::new("s3"); - let file_io = file_io_builder.with_props(props.into_iter()).build()?; - let parquet_file = file_io.new_input(&self.location)?; - - let parquet_metadata = parquet_file.metadata().await?; - let parquet_reader = parquet_file.reader().await?; - let arrow_file_reader = ArrowFileReader::new(parquet_metadata, parquet_reader); - - let mut batch_stream_builder = ParquetRecordBatchStreamBuilder::new(arrow_file_reader) - .await - .map_err(|e| anyhow!(e))?; + let mut batch_stream_builder = create_parquet_stream_builder( + self.s3_region.clone(), + self.s3_access_key.clone(), + self.s3_secret_key.clone(), + self.location.clone(), + ) + .await?; let arrow_schema = batch_stream_builder.schema(); assert_eq!(arrow_schema.fields.len(), self.schema.fields.len()); @@ -120,7 +100,9 @@ impl S3FileScanExecutor { batch_stream_builder = batch_stream_builder.with_batch_size(self.batch_size); - let record_batch_stream = batch_stream_builder.build().map_err(|e| anyhow!(e))?; + let record_batch_stream = batch_stream_builder + .build() + .map_err(|e| anyhow!(e).context("fail to build arrow stream builder"))?; #[for_await] for record_batch in record_batch_stream { @@ -131,33 +113,3 @@ impl S3FileScanExecutor { } } } - -struct ArrowFileReader { - meta: FileMetadata, - r: R, -} - -impl ArrowFileReader { - fn new(meta: FileMetadata, r: R) -> Self { - Self { meta, r } - } -} - -impl AsyncFileReader for ArrowFileReader { - fn get_bytes(&mut self, range: Range) -> BoxFuture<'_, parquet::errors::Result> { - Box::pin( - self.r - .read(range.start as _..range.end as _) - .map_err(|err| parquet::errors::ParquetError::External(Box::new(err))), - ) - } - - fn get_metadata(&mut self) -> BoxFuture<'_, parquet::errors::Result>> { - Box::pin(async move { - let file_size = self.meta.size; - let mut loader = MetadataLoader::load(self, file_size as usize, None).await?; - loader.load_page_index(false, false).await?; - Ok(Arc::new(loader.finish())) - }) - } -} diff --git a/src/connector/Cargo.toml b/src/connector/Cargo.toml index bccd77faf890..34f876107345 100644 --- a/src/connector/Cargo.toml +++ b/src/connector/Cargo.toml @@ -93,6 +93,7 @@ opendal = { workspace = true, features = [ ] } openssl = "0.10" parking_lot = { workspace = true } +parquet = { workspace = true } paste = "1" pg_bigdecimal = { git 
= "https://github.com/risingwavelabs/rust-pg_bigdecimal", rev = "0b7893d88894ca082b4525f94f812da034486f7c" } postgres-openssl = "0.5.0" diff --git a/src/connector/src/source/iceberg/mod.rs b/src/connector/src/source/iceberg/mod.rs index 92880341b588..1b76a0ff3e86 100644 --- a/src/connector/src/source/iceberg/mod.rs +++ b/src/connector/src/source/iceberg/mod.rs @@ -12,6 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. +pub mod parquet_file_reader; + use std::collections::HashMap; use anyhow::anyhow; @@ -19,6 +21,7 @@ use async_trait::async_trait; use futures::StreamExt; use iceberg::spec::{DataContentType, ManifestList}; use itertools::Itertools; +pub use parquet_file_reader::*; use risingwave_common::bail; use risingwave_common::types::JsonbVal; use serde::{Deserialize, Serialize}; diff --git a/src/connector/src/source/iceberg/parquet_file_reader.rs b/src/connector/src/source/iceberg/parquet_file_reader.rs new file mode 100644 index 000000000000..d61eab039da0 --- /dev/null +++ b/src/connector/src/source/iceberg/parquet_file_reader.rs @@ -0,0 +1,85 @@ +// Copyright 2024 RisingWave Labs +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::collections::HashMap; +use std::ops::Range; +use std::sync::Arc; + +use anyhow::anyhow; +use bytes::Bytes; +use futures::future::BoxFuture; +use futures::TryFutureExt; +use iceberg::io::{ + FileIOBuilder, FileMetadata, FileRead, S3_ACCESS_KEY_ID, S3_REGION, S3_SECRET_ACCESS_KEY, +}; +use parquet::arrow::async_reader::{AsyncFileReader, MetadataLoader}; +use parquet::arrow::ParquetRecordBatchStreamBuilder; +use parquet::file::metadata::ParquetMetaData; + +pub struct ParquetFileReader { + meta: FileMetadata, + r: R, +} + +impl ParquetFileReader { + pub fn new(meta: FileMetadata, r: R) -> Self { + Self { meta, r } + } +} + +impl AsyncFileReader for ParquetFileReader { + fn get_bytes(&mut self, range: Range) -> BoxFuture<'_, parquet::errors::Result> { + Box::pin( + self.r + .read(range.start as _..range.end as _) + .map_err(|err| parquet::errors::ParquetError::External(Box::new(err))), + ) + } + + fn get_metadata(&mut self) -> BoxFuture<'_, parquet::errors::Result>> { + Box::pin(async move { + let file_size = self.meta.size; + let mut loader = MetadataLoader::load(self, file_size as usize, None).await?; + loader.load_page_index(false, false).await?; + Ok(Arc::new(loader.finish())) + }) + } +} + +pub async fn create_parquet_stream_builder( + s3_region: String, + s3_access_key: String, + s3_secret_key: String, + location: String, +) -> Result>, anyhow::Error> { + let mut props = HashMap::new(); + props.insert(S3_REGION, s3_region.clone()); + props.insert(S3_ACCESS_KEY_ID, s3_access_key.clone()); + props.insert(S3_SECRET_ACCESS_KEY, s3_secret_key.clone()); + + let file_io_builder = FileIOBuilder::new("s3"); + let file_io = file_io_builder + .with_props(props.into_iter()) + .build() + .map_err(|e| anyhow!(e))?; + let parquet_file = file_io.new_input(&location).map_err(|e| anyhow!(e))?; + 
+ let parquet_metadata = parquet_file.metadata().await.map_err(|e| anyhow!(e))?; + let parquet_reader = parquet_file.reader().await.map_err(|e| anyhow!(e))?; + let parquet_file_reader = ParquetFileReader::new(parquet_metadata, parquet_reader); + + ParquetRecordBatchStreamBuilder::new(parquet_file_reader) + .await + .map_err(|e| anyhow!(e)) +} diff --git a/src/frontend/src/binder/expr/function.rs b/src/frontend/src/binder/expr/function.rs index fbe6b8930a68..897b43a2f366 100644 --- a/src/frontend/src/binder/expr/function.rs +++ b/src/frontend/src/binder/expr/function.rs @@ -266,6 +266,11 @@ impl Binder { ); } + // file_scan table function + if function_name.eq_ignore_ascii_case("file_scan") { + self.ensure_table_function_allowed()?; + return Ok(TableFunction::new_file_scan(inputs)?.into()); + } // table function if let Ok(function_type) = TableFunctionType::from_str(function_name.as_str()) { self.ensure_table_function_allowed()?; diff --git a/src/frontend/src/expr/table_function.rs b/src/frontend/src/expr/table_function.rs index 0db14d4736c2..f4dda8d8176b 100644 --- a/src/frontend/src/expr/table_function.rs +++ b/src/frontend/src/expr/table_function.rs @@ -15,12 +15,15 @@ use std::sync::Arc; use itertools::Itertools; -use risingwave_common::types::DataType; +use risingwave_common::array::arrow::IcebergArrowConvert; +use risingwave_common::types::{DataType, ScalarImpl, StructType}; +use risingwave_connector::source::iceberg::create_parquet_stream_builder; pub use risingwave_pb::expr::table_function::PbType as TableFunctionType; use risingwave_pb::expr::PbTableFunction; use super::{infer_type, Expr, ExprImpl, ExprRewriter, RwResult}; use crate::catalog::function_catalog::{FunctionCatalog, FunctionKind}; +use crate::error::ErrorCode::BindError; /// A table function takes a row as input and returns a table. It is also known as Set-Returning /// Function. @@ -62,6 +65,113 @@ impl TableFunction { } } + /// A special table function which would be transformed into `LogicalFileScan` by `TableFunctionToFileScanRule` in the optimizer. + /// select * from `file_scan`('parquet', 's3', region, ak, sk, location) + pub fn new_file_scan(args: Vec) -> RwResult { + let return_type = { + // arguments: + // file format e.g. parquet + // storage type e.g. 
s3 + // s3 region + // s3 access key + // s3 secret key + // file location + if args.len() != 6 { + return Err(BindError("file_scan function only accepts 6 arguments: file_scan('parquet', 's3', s3 region, s3 access key, s3 secret key, file location)".to_string()).into()); + } + let mut eval_args: Vec = vec![]; + for arg in &args { + if arg.return_type() != DataType::Varchar { + return Err(BindError( + "file_scan function only accepts string arguments".to_string(), + ) + .into()); + } + match arg.try_fold_const() { + Some(Ok(value)) => { + if value.is_none() { + return Err(BindError( + "file_scan function does not accept null arguments".to_string(), + ) + .into()); + } + match value { + Some(ScalarImpl::Utf8(s)) => { + eval_args.push(s.to_string()); + } + _ => { + return Err(BindError( + "file_scan function only accepts string arguments".to_string(), + ) + .into()) + } + } + } + Some(Err(err)) => { + return Err(err); + } + None => { + return Err(BindError( + "file_scan function only accepts constant arguments".to_string(), + ) + .into()); + } + } + } + if !"parquet".eq_ignore_ascii_case(&eval_args[0]) { + return Err(BindError( + "file_scan function only accepts 'parquet' as file format".to_string(), + ) + .into()); + } + + if !"s3".eq_ignore_ascii_case(&eval_args[1]) { + return Err(BindError( + "file_scan function only accepts 's3' as storage type".to_string(), + ) + .into()); + } + + #[cfg(madsim)] + return Err(crate::error::ErrorCode::BindError( + "file_scan can't be used in the madsim mode".to_string(), + ) + .into()); + + #[cfg(not(madsim))] + tokio::task::block_in_place(|| { + tokio::runtime::Handle::current().block_on(async { + let parquet_stream_builder = create_parquet_stream_builder( + eval_args[2].clone(), + eval_args[3].clone(), + eval_args[4].clone(), + eval_args[5].clone(), + ) + .await?; + + let mut rw_types = vec![]; + for field in parquet_stream_builder.schema().fields() { + rw_types.push(( + field.name().to_string(), + IcebergArrowConvert.type_from_field(field)?, + )); + } + + Ok::(DataType::Struct( + StructType::new(rw_types), + )) + }) + })? + }; + + Ok(TableFunction { + args, + return_type, + function_type: TableFunctionType::FileScan, + user_defined: None, + }) + } + pub fn to_protobuf(&self) -> PbTableFunction { PbTableFunction { function_type: self.function_type as i32, diff --git a/src/frontend/src/lib.rs b/src/frontend/src/lib.rs index dbc312ac463c..bb27c50053ad 100644 --- a/src/frontend/src/lib.rs +++ b/src/frontend/src/lib.rs @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. 
+#![feature(async_closure)] #![allow(clippy::derive_partial_eq_without_eq)] #![feature(map_try_insert)] #![feature(negative_impls)] diff --git a/src/frontend/src/optimizer/logical_optimization.rs b/src/frontend/src/optimizer/logical_optimization.rs index 931a645b3d68..4f95bde0b852 100644 --- a/src/frontend/src/optimizer/logical_optimization.rs +++ b/src/frontend/src/optimizer/logical_optimization.rs @@ -134,6 +134,14 @@ static TABLE_FUNCTION_TO_PROJECT_SET: LazyLock = LazyLock::ne ) }); +static TABLE_FUNCTION_TO_FILE_SCAN: LazyLock = LazyLock::new(|| { + OptimizationStage::new( + "Table Function To FileScan", + vec![TableFunctionToFileScanRule::create()], + ApplyOrder::TopDown, + ) +}); + static VALUES_EXTRACT_PROJECT: LazyLock = LazyLock::new(|| { OptimizationStage::new( "Values Extract Project", @@ -689,6 +697,8 @@ impl LogicalOptimizer { plan = plan.optimize_by_rules(&SET_OPERATION_MERGE); plan = plan.optimize_by_rules(&SET_OPERATION_TO_JOIN); plan = plan.optimize_by_rules(&ALWAYS_FALSE_FILTER); + // Table function should be converted into `file_scan` before `project_set`. + plan = plan.optimize_by_rules(&TABLE_FUNCTION_TO_FILE_SCAN); // In order to unnest a table function, we need to convert it into a `project_set` first. plan = plan.optimize_by_rules(&TABLE_FUNCTION_TO_PROJECT_SET); diff --git a/src/frontend/src/optimizer/plan_node/batch_file_scan.rs b/src/frontend/src/optimizer/plan_node/batch_file_scan.rs new file mode 100644 index 000000000000..826f39441294 --- /dev/null +++ b/src/frontend/src/optimizer/plan_node/batch_file_scan.rs @@ -0,0 +1,84 @@ +// Copyright 2024 RisingWave Labs +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use pretty_xmlish::XmlNode; +use risingwave_pb::batch_plan::plan_node::NodeBody; + +use super::batch::prelude::*; +use super::utils::{childless_record, column_names_pretty, Distill}; +use super::{ + generic, ExprRewritable, PlanBase, PlanRef, ToBatchPb, ToDistributedBatch, ToLocalBatch, +}; +use crate::error::Result; +use crate::optimizer::plan_node::expr_visitable::ExprVisitable; +use crate::optimizer::property::{Distribution, Order}; + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct BatchFileScan { + pub base: PlanBase, + pub core: generic::FileScan, +} + +impl BatchFileScan { + pub fn new(core: generic::FileScan) -> Self { + let base = PlanBase::new_batch_with_core(&core, Distribution::Single, Order::any()); + + Self { base, core } + } + + pub fn column_names(&self) -> Vec<&str> { + self.schema().names_str() + } + + pub fn clone_with_dist(&self) -> Self { + let base = self + .base + .clone_with_new_distribution(Distribution::SomeShard); + Self { + base, + core: self.core.clone(), + } + } +} + +impl_plan_tree_node_for_leaf! 
{ BatchFileScan } + +impl Distill for BatchFileScan { + fn distill<'a>(&self) -> XmlNode<'a> { + let fields = vec![("columns", column_names_pretty(self.schema()))]; + childless_record("BatchFileScan", fields) + } +} + +impl ToLocalBatch for BatchFileScan { + fn to_local(&self) -> Result { + Ok(self.clone_with_dist().into()) + } +} + +impl ToDistributedBatch for BatchFileScan { + fn to_distributed(&self) -> Result { + Ok(self.clone_with_dist().into()) + } +} + +impl ToBatchPb for BatchFileScan { + fn to_batch_prost_body(&self) -> NodeBody { + todo!() + } +} + +impl ExprRewritable for BatchFileScan {} + +impl ExprVisitable for BatchFileScan {} diff --git a/src/frontend/src/optimizer/plan_node/generic/file_scan.rs b/src/frontend/src/optimizer/plan_node/generic/file_scan.rs new file mode 100644 index 000000000000..f8ed20c12072 --- /dev/null +++ b/src/frontend/src/optimizer/plan_node/generic/file_scan.rs @@ -0,0 +1,64 @@ +// Copyright 2024 RisingWave Labs +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use educe::Educe; +use risingwave_common::catalog::Schema; + +use super::GenericPlanNode; +use crate::optimizer::optimizer_context::OptimizerContextRef; +use crate::optimizer::property::FunctionalDependencySet; + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub enum FileFormat { + Parquet, +} + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub enum StorageType { + S3, +} + +#[derive(Debug, Clone, Educe)] +#[educe(PartialEq, Eq, Hash)] +pub struct FileScan { + pub schema: Schema, + pub file_format: FileFormat, + pub storage_type: StorageType, + pub s3_region: String, + pub s3_access_key: String, + pub s3_secret_key: String, + pub file_location: String, + + #[educe(PartialEq(ignore))] + #[educe(Hash(ignore))] + pub ctx: OptimizerContextRef, +} + +impl GenericPlanNode for FileScan { + fn schema(&self) -> Schema { + self.schema.clone() + } + + fn stream_key(&self) -> Option> { + None + } + + fn ctx(&self) -> OptimizerContextRef { + self.ctx.clone() + } + + fn functional_dependency(&self) -> FunctionalDependencySet { + FunctionalDependencySet::new(self.schema.len()) + } +} diff --git a/src/frontend/src/optimizer/plan_node/generic/mod.rs b/src/frontend/src/optimizer/plan_node/generic/mod.rs index 38efb6fe2a27..d83ab50d8923 100644 --- a/src/frontend/src/optimizer/plan_node/generic/mod.rs +++ b/src/frontend/src/optimizer/plan_node/generic/mod.rs @@ -83,6 +83,9 @@ pub use changelog::*; mod now; pub use now::*; +mod file_scan; +pub use file_scan::*; + pub trait DistillUnit { fn distill_with_name<'a>(&self, name: impl Into>) -> XmlNode<'a>; } diff --git a/src/frontend/src/optimizer/plan_node/logical_file_scan.rs b/src/frontend/src/optimizer/plan_node/logical_file_scan.rs new file mode 100644 index 000000000000..df41023bb484 --- /dev/null +++ b/src/frontend/src/optimizer/plan_node/logical_file_scan.rs @@ -0,0 +1,118 @@ +// Copyright 2024 RisingWave Labs +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use pretty_xmlish::XmlNode; +use risingwave_common::bail; +use risingwave_common::catalog::Schema; + +use super::generic::GenericPlanRef; +use super::utils::{childless_record, Distill}; +use super::{ + generic, BatchFileScan, ColPrunable, ExprRewritable, Logical, LogicalProject, PlanBase, + PlanRef, PredicatePushdown, ToBatch, ToStream, +}; +use crate::error::Result; +use crate::optimizer::plan_node::expr_visitable::ExprVisitable; +use crate::optimizer::plan_node::utils::column_names_pretty; +use crate::optimizer::plan_node::{ + ColumnPruningContext, LogicalFilter, PredicatePushdownContext, RewriteStreamContext, + ToStreamContext, +}; +use crate::utils::{ColIndexMapping, Condition}; +use crate::OptimizerContextRef; + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct LogicalFileScan { + pub base: PlanBase, + pub core: generic::FileScan, +} + +impl LogicalFileScan { + pub fn new( + ctx: OptimizerContextRef, + schema: Schema, + file_format: String, + storage_type: String, + s3_region: String, + s3_access_key: String, + s3_secret_key: String, + file_location: String, + ) -> Self { + assert!("parquet".eq_ignore_ascii_case(&file_format)); + assert!("s3".eq_ignore_ascii_case(&storage_type)); + + let core = generic::FileScan { + schema, + file_format: generic::FileFormat::Parquet, + storage_type: generic::StorageType::S3, + s3_region, + s3_access_key, + s3_secret_key, + file_location, + ctx, + }; + + let base = PlanBase::new_logical_with_core(&core); + + LogicalFileScan { base, core } + } +} + +impl_plan_tree_node_for_leaf! {LogicalFileScan} +impl Distill for LogicalFileScan { + fn distill<'a>(&self) -> XmlNode<'a> { + let fields = vec![("columns", column_names_pretty(self.schema()))]; + childless_record("LogicalFileScan", fields) + } +} + +impl ColPrunable for LogicalFileScan { + fn prune_col(&self, required_cols: &[usize], _ctx: &mut ColumnPruningContext) -> PlanRef { + LogicalProject::with_out_col_idx(self.clone().into(), required_cols.iter().cloned()).into() + } +} + +impl ExprRewritable for LogicalFileScan {} + +impl ExprVisitable for LogicalFileScan {} + +impl PredicatePushdown for LogicalFileScan { + fn predicate_pushdown( + &self, + predicate: Condition, + _ctx: &mut PredicatePushdownContext, + ) -> PlanRef { + // No pushdown. 
+ LogicalFilter::create(self.clone().into(), predicate) + } +} + +impl ToBatch for LogicalFileScan { + fn to_batch(&self) -> Result { + Ok(BatchFileScan::new(self.core.clone()).into()) + } +} + +impl ToStream for LogicalFileScan { + fn to_stream(&self, _ctx: &mut ToStreamContext) -> Result { + bail!("FileScan is not supported in streaming mode") + } + + fn logical_rewrite_for_stream( + &self, + _ctx: &mut RewriteStreamContext, + ) -> Result<(PlanRef, ColIndexMapping)> { + bail!("FileScan is not supported in streaming mode") + } +} diff --git a/src/frontend/src/optimizer/plan_node/mod.rs b/src/frontend/src/optimizer/plan_node/mod.rs index 71c4c44fac8b..b5062398270e 100644 --- a/src/frontend/src/optimizer/plan_node/mod.rs +++ b/src/frontend/src/optimizer/plan_node/mod.rs @@ -910,9 +910,11 @@ mod stream_topn; mod stream_values; mod stream_watermark_filter; +mod batch_file_scan; mod batch_iceberg_scan; mod batch_kafka_scan; mod derive; +mod logical_file_scan; mod logical_iceberg_scan; mod stream_cdc_table_scan; mod stream_share; @@ -923,6 +925,7 @@ pub mod utils; pub use batch_delete::BatchDelete; pub use batch_exchange::BatchExchange; pub use batch_expand::BatchExpand; +pub use batch_file_scan::BatchFileScan; pub use batch_filter::BatchFilter; pub use batch_group_topn::BatchGroupTopN; pub use batch_hash_agg::BatchHashAgg; @@ -959,6 +962,7 @@ pub use logical_dedup::LogicalDedup; pub use logical_delete::LogicalDelete; pub use logical_except::LogicalExcept; pub use logical_expand::LogicalExpand; +pub use logical_file_scan::LogicalFileScan; pub use logical_filter::LogicalFilter; pub use logical_hop_window::LogicalHopWindow; pub use logical_iceberg_scan::LogicalIcebergScan; @@ -1076,6 +1080,7 @@ macro_rules! for_all_plan_nodes { , { Logical, RecursiveUnion } , { Logical, CteRef } , { Logical, ChangeLog } + , { Logical, FileScan } , { Batch, SimpleAgg } , { Batch, HashAgg } , { Batch, SortAgg } @@ -1106,6 +1111,7 @@ macro_rules! for_all_plan_nodes { , { Batch, MaxOneRow } , { Batch, KafkaScan } , { Batch, IcebergScan } + , { Batch, FileScan } , { Stream, Project } , { Stream, Filter } , { Stream, TableScan } @@ -1182,6 +1188,7 @@ macro_rules! for_logical_plan_nodes { , { Logical, RecursiveUnion } , { Logical, CteRef } , { Logical, ChangeLog } + , { Logical, FileScan } } }; } @@ -1221,6 +1228,7 @@ macro_rules! for_batch_plan_nodes { , { Batch, MaxOneRow } , { Batch, KafkaScan } , { Batch, IcebergScan } + , { Batch, FileScan } } }; } diff --git a/src/frontend/src/optimizer/rule/mod.rs b/src/frontend/src/optimizer/rule/mod.rs index fd06402c7497..180dafa0c79b 100644 --- a/src/frontend/src/optimizer/rule/mod.rs +++ b/src/frontend/src/optimizer/rule/mod.rs @@ -160,12 +160,14 @@ pub use agg_call_merge_rule::*; mod pull_up_correlated_predicate_agg_rule; mod source_to_iceberg_scan_rule; mod source_to_kafka_scan_rule; +mod table_function_to_file_scan_rule; mod values_extract_project_rule; pub use batch::batch_push_limit_to_scan_rule::*; pub use pull_up_correlated_predicate_agg_rule::*; pub use source_to_iceberg_scan_rule::*; pub use source_to_kafka_scan_rule::*; +pub use table_function_to_file_scan_rule::*; pub use values_extract_project_rule::*; #[macro_export] @@ -228,6 +230,7 @@ macro_rules! 
for_all_rules { , { CrossJoinEliminateRule } , { ApplyTopNTransposeRule } , { TableFunctionToProjectSetRule } + , { TableFunctionToFileScanRule } , { ApplyLimitTransposeRule } , { CommonSubExprExtractRule } , { BatchProjectMergeRule } diff --git a/src/frontend/src/optimizer/rule/table_function_to_file_scan_rule.rs b/src/frontend/src/optimizer/rule/table_function_to_file_scan_rule.rs new file mode 100644 index 000000000000..a0eb0f9ebd3c --- /dev/null +++ b/src/frontend/src/optimizer/rule/table_function_to_file_scan_rule.rs @@ -0,0 +1,90 @@ +// Copyright 2024 RisingWave Labs +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use itertools::Itertools; +use risingwave_common::catalog::{Field, Schema}; +use risingwave_common::types::{DataType, ScalarImpl}; +use risingwave_common::util::iter_util::ZipEqDebug; + +use super::{BoxedRule, Rule}; +use crate::expr::{Expr, TableFunctionType}; +use crate::optimizer::plan_node::generic::GenericPlanRef; +use crate::optimizer::plan_node::{LogicalFileScan, LogicalTableFunction}; +use crate::optimizer::PlanRef; + +/// Transform a special `TableFunction` (with `FILE_SCAN` table function type) into a `LogicalFileScan` +pub struct TableFunctionToFileScanRule {} +impl Rule for TableFunctionToFileScanRule { + fn apply(&self, plan: PlanRef) -> Option { + let logical_table_function: &LogicalTableFunction = plan.as_logical_table_function()?; + if logical_table_function.table_function.function_type != TableFunctionType::FileScan { + return None; + } + assert!(!logical_table_function.with_ordinality); + let table_function_return_type = logical_table_function.table_function().return_type(); + + if let DataType::Struct(st) = table_function_return_type.clone() { + let fields = st + .types() + .zip_eq_debug(st.names()) + .map(|(data_type, name)| Field::with_name(data_type.clone(), name.to_string())) + .collect_vec(); + + let schema = Schema::new(fields); + + let mut eval_args = vec![]; + for arg in &logical_table_function.table_function().args { + assert_eq!(arg.return_type(), DataType::Varchar); + let value = arg.try_fold_const().unwrap().unwrap(); + match value { + Some(ScalarImpl::Utf8(s)) => { + eval_args.push(s.to_string()); + } + _ => { + unreachable!("must be a varchar") + } + } + } + assert!(eval_args.len() == 6); + assert!("parquet".eq_ignore_ascii_case(&eval_args[0])); + assert!("s3".eq_ignore_ascii_case(&eval_args[1])); + let s3_region = eval_args[2].clone(); + let s3_access_key = eval_args[3].clone(); + let s3_secret_key = eval_args[4].clone(); + let file_location = eval_args[5].clone(); + + Some( + LogicalFileScan::new( + logical_table_function.ctx(), + schema, + "parquet".to_string(), + "s3".to_string(), + s3_region, + s3_access_key, + s3_secret_key, + file_location, + ) + .into(), + ) + } else { + unreachable!("TableFunction return type should be struct") + } + } +} + +impl TableFunctionToFileScanRule { + pub fn create() -> BoxedRule { + Box::new(TableFunctionToFileScanRule {}) + } +} From d06bb47f96c4b15401a8f6893217f0c55ad7dda1 Mon Sep 17 00:00:00 
2001 From: zwang28 <70626450+zwang28@users.noreply.github.com> Date: Fri, 12 Jul 2024 16:32:33 +0800 Subject: [PATCH 02/70] feat(batch): support as of now() - interval for time travel (#17665) --- Cargo.lock | 1 + ci/scripts/e2e-sqlserver-sink-test.sh | 1 + ci/scripts/e2e-time-travel-test.sh | 43 +++++++++++ ci/workflows/main-cron.yml | 20 ++++++ ci/workflows/pull-request.yml | 17 +++++ e2e_test/time_travel/basic.slt | 38 ++++++++++ e2e_test/time_travel/index.slt | 40 +++++++++++ e2e_test/time_travel/join.slt | 54 ++++++++++++++ e2e_test/time_travel/lookup_join.slt | 54 ++++++++++++++ risedev.yml | 14 ++++ src/config/ci-time-travel.toml | 5 ++ src/frontend/Cargo.toml | 1 + src/frontend/src/binder/expr/value.rs | 2 +- src/frontend/src/optimizer/plan_node/utils.rs | 35 +++++++-- src/frontend/src/planner/relation.rs | 13 ++-- src/frontend/src/scheduler/plan_fragmenter.rs | 4 +- src/sqlparser/src/ast/mod.rs | 7 ++ src/sqlparser/src/parser.rs | 37 +++++++++- src/sqlparser/tests/testdata/as_of.yaml | 2 +- src/storage/Cargo.toml | 2 +- src/storage/src/hummock/mod.rs | 1 + .../src/hummock/store/hummock_storage.rs | 34 +++++---- .../src/hummock/time_travel_version_cache.rs | 72 +++++++++++++++++++ 23 files changed, 464 insertions(+), 33 deletions(-) create mode 100755 ci/scripts/e2e-time-travel-test.sh create mode 100644 e2e_test/time_travel/basic.slt create mode 100644 e2e_test/time_travel/index.slt create mode 100644 e2e_test/time_travel/join.slt create mode 100644 e2e_test/time_travel/lookup_join.slt create mode 100644 src/config/ci-time-travel.toml create mode 100644 src/storage/src/hummock/time_travel_version_cache.rs diff --git a/Cargo.lock b/Cargo.lock index c08b302494d9..8aed93149faa 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -11485,6 +11485,7 @@ dependencies = [ "base64 0.22.0", "bk-tree", "bytes", + "chrono", "clap", "downcast-rs", "dyn-clone", diff --git a/ci/scripts/e2e-sqlserver-sink-test.sh b/ci/scripts/e2e-sqlserver-sink-test.sh index f1f62941375c..30032c66a628 100755 --- a/ci/scripts/e2e-sqlserver-sink-test.sh +++ b/ci/scripts/e2e-sqlserver-sink-test.sh @@ -74,6 +74,7 @@ if [[ ${#actual[@]} -eq ${#expected[@]} && ${actual[@]} == ${expected[@]} ]]; th else cat ./query_result.txt echo "The output is not as expected." + exit 1 fi echo "--- Kill cluster" diff --git a/ci/scripts/e2e-time-travel-test.sh b/ci/scripts/e2e-time-travel-test.sh new file mode 100755 index 000000000000..d69556b3c2d7 --- /dev/null +++ b/ci/scripts/e2e-time-travel-test.sh @@ -0,0 +1,43 @@ +#!/usr/bin/env bash + +# Exits as soon as any line fails. +set -euo pipefail + +source ci/scripts/common.sh + +while getopts 'p:' opt; do + case ${opt} in + p ) + profile=$OPTARG + ;; + \? 
) + echo "Invalid Option: -$OPTARG" 1>&2 + exit 1 + ;; + : ) + echo "Invalid option: $OPTARG requires an argument" 1>&2 + ;; + esac +done +shift $((OPTIND -1)) + +sudo apt install sqlite3 -y +download_and_prepare_rw "$profile" common + +echo "--- starting risingwave cluster" +risedev ci-start ci-time-travel +sleep 1 + +sqllogictest -p 4566 -d dev './e2e_test/time_travel/*.slt' + +echo "--- verify time travel metadata" +sleep 30 # ensure another time travel version snapshot has been taken +version_snapshot_count=$(sqlite3 .risingwave/data/sqlite/metadata.db "select count(*) from hummock_time_travel_version;") +if [ "$version_snapshot_count" -le 1 ]; then + echo "test failed: too few version_snapshot_count, actual ${version_snapshot_count}" + exit 1 +fi + +echo "--- Kill cluster" +risedev ci-kill + diff --git a/ci/workflows/main-cron.yml b/ci/workflows/main-cron.yml index 13c98c6bff9f..cab2fb2858ef 100644 --- a/ci/workflows/main-cron.yml +++ b/ci/workflows/main-cron.yml @@ -868,6 +868,26 @@ steps: timeout_in_minutes: 10 retry: *auto-retry + - label: "end-to-end time travel test" + key: "e2e-time-travel-tests" + command: "ci/scripts/e2e-time-travel-test.sh -p ci-release" + if: | + !(build.pull_request.labels includes "ci/main-cron/run-selected") && build.env("CI_STEPS") == null + || build.pull_request.labels includes "ci/run-e2e-time-travel-tests" + || build.env("CI_STEPS") =~ /(^|,)e2e-time-travel-tests?(,|$$)/ + depends_on: + - "build" + - "build-other" + - "docslt" + plugins: + - docker-compose#v5.1.0: + run: rw-build-env + config: ci/docker-compose.yml + mount-buildkite-agent: true + - ./ci/plugins/upload-failure-logs + timeout_in_minutes: 10 + retry: *auto-retry + - label: "end-to-end sqlserver sink test" key: "e2e-sqlserver-sink-tests" command: "ci/scripts/e2e-sqlserver-sink-test.sh -p ci-release" diff --git a/ci/workflows/pull-request.yml b/ci/workflows/pull-request.yml index ae8db23bbeb8..fd3e44f131c3 100644 --- a/ci/workflows/pull-request.yml +++ b/ci/workflows/pull-request.yml @@ -333,6 +333,23 @@ steps: timeout_in_minutes: 10 retry: *auto-retry + - label: "end-to-end time travel test" + key: "e2e-time-travel-tests" + command: "ci/scripts/e2e-time-travel-test.sh -p ci-dev" + if: build.pull_request.labels includes "ci/run-e2e-time-travel-tests" || build.env("CI_STEPS") =~ /(^|,)e2e-time-travel-tests?(,|$$)/ + depends_on: + - "build" + - "build-other" + - "docslt" + plugins: + - docker-compose#v5.1.0: + run: rw-build-env + config: ci/docker-compose.yml + mount-buildkite-agent: true + - ./ci/plugins/upload-failure-logs + timeout_in_minutes: 15 + retry: *auto-retry + - label: "end-to-end sqlserver sink test" if: build.pull_request.labels includes "ci/run-e2e-sqlserver-sink-tests" || build.env("CI_STEPS") =~ /(^|,)e2e-sqlserver-sink-tests?(,|$$)/ command: "ci/scripts/e2e-sqlserver-sink-test.sh -p ci-dev" diff --git a/e2e_test/time_travel/basic.slt b/e2e_test/time_travel/basic.slt new file mode 100644 index 000000000000..962fd8e096a5 --- /dev/null +++ b/e2e_test/time_travel/basic.slt @@ -0,0 +1,38 @@ +statement ok +SET RW_IMPLICIT_FLUSH TO true; + +statement ok +CREATE TABLE t (k INT); + +query I +SELECT * FROM t; +---- + +sleep 5s + +statement ok +INSERT INTO t VALUES (1); + +query I +SELECT * FROM t; +---- +1 + +query I +SELECT * FROM t FOR SYSTEM_TIME AS OF now(); +---- +1 + +query I +SELECT * FROM t FOR SYSTEM_TIME AS OF now() - '5' second; +---- + +sleep 5s + +query I +SELECT * FROM t FOR SYSTEM_TIME AS OF now() - '5' second; +---- +1 + +statement ok +DROP TABLE t; \ No newline at 
end of file diff --git a/e2e_test/time_travel/index.slt b/e2e_test/time_travel/index.slt new file mode 100644 index 000000000000..4c81bc683837 --- /dev/null +++ b/e2e_test/time_travel/index.slt @@ -0,0 +1,40 @@ +statement ok +SET RW_IMPLICIT_FLUSH TO true; + +statement ok +CREATE TABLE t (k INT); + +statement ok +CREATE INDEX idx_t_k on t (k); + +sleep 5s + +statement ok +INSERT INTO t VALUES (1); + +query I +SELECT * FROM t WHERE k=1; +---- +1 + +query I +SELECT * FROM t FOR SYSTEM_TIME AS OF now() WHERE k=1; +---- +1 + +query I +SELECT * FROM t FOR SYSTEM_TIME AS OF now() - '5' second WHERE k=1; +---- + +sleep 5s + +query I +SELECT * FROM t FOR SYSTEM_TIME AS OF now() - '5' second WHERE k=1; +---- +1 + +statement ok +DROP INDEX idx_t_k; + +statement ok +DROP TABLE t; \ No newline at end of file diff --git a/e2e_test/time_travel/join.slt b/e2e_test/time_travel/join.slt new file mode 100644 index 000000000000..cfada20463f9 --- /dev/null +++ b/e2e_test/time_travel/join.slt @@ -0,0 +1,54 @@ +statement ok +SET RW_IMPLICIT_FLUSH TO true; + +statement ok +CREATE TABLE t1 (k1 INT); + +statement ok +CREATE TABLE t2 (k2 INT); + +sleep 5s + +statement ok +INSERT INTO t1 VALUES (1); + +statement ok +INSERT INTO t2 VALUES (1); + +query I +SELECT count(*) FROM t1 join t2 on t1.k1=t2.k2; +---- +1 + +query I +SELECT count(*) FROM t1 FOR SYSTEM_TIME AS OF now() join t2 FOR SYSTEM_TIME AS OF now() on t1.k1=t2.k2; +---- +1 + +query I +SELECT count(*) FROM t1 FOR SYSTEM_TIME AS OF now() - '5' second join t2 FOR SYSTEM_TIME AS OF now() on t1.k1=t2.k2; +---- +0 + +query I +SELECT count(*) FROM t1 FOR SYSTEM_TIME AS OF now() join t2 FOR SYSTEM_TIME AS OF now() - '5' second on t1.k1=t2.k2; +---- +0 + +sleep 5s + +query I +SELECT count(*) FROM t1 FOR SYSTEM_TIME AS OF now() - '5' second join t2 FOR SYSTEM_TIME AS OF now() on t1.k1=t2.k2; +---- +1 + +query I +SELECT count(*) FROM t1 FOR SYSTEM_TIME AS OF now() join t2 FOR SYSTEM_TIME AS OF now() - '5' second on t1.k1=t2.k2; +---- +1 + +statement ok +DROP TABLE t2; + +statement ok +DROP TABLE t1; \ No newline at end of file diff --git a/e2e_test/time_travel/lookup_join.slt b/e2e_test/time_travel/lookup_join.slt new file mode 100644 index 000000000000..272bc7a34ae4 --- /dev/null +++ b/e2e_test/time_travel/lookup_join.slt @@ -0,0 +1,54 @@ +statement ok +SET RW_IMPLICIT_FLUSH TO true; + +statement ok +CREATE TABLE t1 (k1 INT); + +statement ok +CREATE TABLE t2 (k2 INT PRIMARY KEY); + +sleep 5s + +statement ok +INSERT INTO t1 VALUES (1); + +statement ok +INSERT INTO t2 VALUES (1); + +query I +SELECT count(*) FROM t1 join t2 on t1.k1=t2.k2; +---- +1 + +query I +SELECT count(*) FROM t1 FOR SYSTEM_TIME AS OF now() join t2 FOR SYSTEM_TIME AS OF now() on t1.k1=t2.k2; +---- +1 + +query I +SELECT count(*) FROM t1 FOR SYSTEM_TIME AS OF now() - '5' second join t2 FOR SYSTEM_TIME AS OF now() on t1.k1=t2.k2; +---- +0 + +query I +SELECT count(*) FROM t1 FOR SYSTEM_TIME AS OF now() join t2 FOR SYSTEM_TIME AS OF now() - '5' second on t1.k1=t2.k2; +---- +0 + +sleep 5s + +query I +SELECT count(*) FROM t1 FOR SYSTEM_TIME AS OF now() - '5' second join t2 FOR SYSTEM_TIME AS OF now() on t1.k1=t2.k2; +---- +1 + +query I +SELECT count(*) FROM t1 FOR SYSTEM_TIME AS OF now() join t2 FOR SYSTEM_TIME AS OF now() - '5' second on t1.k1=t2.k2; +---- +1 + +statement ok +DROP TABLE t2; + +statement ok +DROP TABLE t1; \ No newline at end of file diff --git a/risedev.yml b/risedev.yml index 984d92ab33d8..b4b558face0e 100644 --- a/risedev.yml +++ b/risedev.yml @@ -354,6 +354,20 @@ profile: - use: 
compute-node - use: frontend + ci-time-travel: + config-path: src/config/ci-time-travel.toml + steps: + - use: minio + - use: sqlite + - use: meta-node + port: 5690 + dashboard-port: 5691 + exporter-port: 1250 + meta-backend: sqlite + - use: compactor + - use: compute-node + - use: frontend + meta-1cn-1fe-sqlite-with-recovery: config-path: src/config/ci-recovery.toml steps: diff --git a/src/config/ci-time-travel.toml b/src/config/ci-time-travel.toml new file mode 100644 index 000000000000..3f17e2aaf371 --- /dev/null +++ b/src/config/ci-time-travel.toml @@ -0,0 +1,5 @@ +[meta] +enable_hummock_time_travel = true +hummock_time_travel_retention_ms = 300000 +hummock_time_travel_snapshot_interval = 30 +min_sst_retention_time_sec = 1 \ No newline at end of file diff --git a/src/frontend/Cargo.toml b/src/frontend/Cargo.toml index 418ab86e4307..89d29e076a38 100644 --- a/src/frontend/Cargo.toml +++ b/src/frontend/Cargo.toml @@ -26,6 +26,7 @@ auto_impl = "1" base64 = "0.22" bk-tree = "0.5.0" bytes = "1" +chrono = { version = "0.4", default-features = false } clap = { workspace = true } downcast-rs = "1.2" dyn-clone = "1.0.14" diff --git a/src/frontend/src/binder/expr/value.rs b/src/frontend/src/binder/expr/value.rs index a0758a15d444..e1fc78e884e0 100644 --- a/src/frontend/src/binder/expr/value.rs +++ b/src/frontend/src/binder/expr/value.rs @@ -100,7 +100,7 @@ impl Binder { Ok(literal) } - fn bind_date_time_field(field: AstDateTimeField) -> DateTimeField { + pub(crate) fn bind_date_time_field(field: AstDateTimeField) -> DateTimeField { // This is a binder function rather than `impl From for DateTimeField`, // so that the `sqlparser` crate and the `common` crate are kept independent. match field { diff --git a/src/frontend/src/optimizer/plan_node/utils.rs b/src/frontend/src/optimizer/plan_node/utils.rs index b9b175ed6f09..afae9cf64ca0 100644 --- a/src/frontend/src/optimizer/plan_node/utils.rs +++ b/src/frontend/src/optimizer/plan_node/utils.rs @@ -16,6 +16,7 @@ use std::collections::HashMap; use std::default::Default; use std::vec; +use anyhow::anyhow; use fixedbitset::FixedBitSet; use itertools::Itertools; use pretty_xmlish::{Pretty, Str, StrAssocArr, XmlNode}; @@ -324,7 +325,7 @@ macro_rules! 
plan_node_name { }; } pub(crate) use plan_node_name; -use risingwave_common::types::DataType; +use risingwave_common::types::{DataType, Interval}; use risingwave_expr::aggregate::AggKind; use risingwave_pb::plan_common::as_of::AsOfType; use risingwave_pb::plan_common::{as_of, PbAsOf}; @@ -397,12 +398,21 @@ pub fn to_pb_time_travel_as_of(a: &Option) -> Result> { return Ok(None); }; let as_of_type = match a { - AsOf::ProcessTime => AsOfType::ProcessTime(as_of::ProcessTime {}), + AsOf::ProcessTime => { + return Err(ErrorCode::NotSupported( + "do not support as of proctime".to_string(), + "please use as of timestamp".to_string(), + ) + .into()); + } AsOf::TimestampNum(ts) => AsOfType::Timestamp(as_of::Timestamp { timestamp: *ts }), - AsOf::TimestampString(ts) => AsOfType::Timestamp(as_of::Timestamp { - // should already have been validated by the parser - timestamp: ts.parse().unwrap(), - }), + AsOf::TimestampString(ts) => { + let date_time = speedate::DateTime::parse_str_rfc3339(ts) + .map_err(|_e| anyhow!("fail to parse timestamp"))?; + AsOfType::Timestamp(as_of::Timestamp { + timestamp: date_time.timestamp_tz(), + }) + } AsOf::VersionNum(_) | AsOf::VersionString(_) => { return Err(ErrorCode::NotSupported( "do not support as of version".to_string(), @@ -410,6 +420,19 @@ pub fn to_pb_time_travel_as_of(a: &Option) -> Result> { ) .into()); } + AsOf::ProcessTimeWithInterval((value, leading_field)) => { + let interval = Interval::parse_with_fields( + value, + Some(crate::Binder::bind_date_time_field(leading_field.clone())), + ) + .map_err(|_| anyhow!("fail to parse interval"))?; + let interval_sec = (interval.epoch_in_micros() / 1_000_000) as i64; + let timestamp = chrono::Utc::now() + .timestamp() + .checked_sub(interval_sec) + .ok_or_else(|| anyhow!("invalid timestamp"))?; + AsOfType::Timestamp(as_of::Timestamp { timestamp }) + } }; Ok(Some(PbAsOf { as_of_type: Some(as_of_type), diff --git a/src/frontend/src/planner/relation.rs b/src/frontend/src/planner/relation.rs index 98a124c72d97..b5f8d39276d2 100644 --- a/src/frontend/src/planner/relation.rs +++ b/src/frontend/src/planner/relation.rs @@ -74,12 +74,11 @@ impl Planner { pub(super) fn plan_base_table(&mut self, base_table: &BoundBaseTable) -> Result { let as_of = base_table.as_of.clone(); match as_of { - None | Some(AsOf::ProcessTime) | Some(AsOf::TimestampNum(_)) => {} - Some(AsOf::TimestampString(ref s)) => { - if s.parse::().is_err() { - return Err(ErrorCode::InvalidParameterValue(s.to_owned()).into()); - } - } + None + | Some(AsOf::ProcessTime) + | Some(AsOf::TimestampNum(_)) + | Some(AsOf::TimestampString(_)) + | Some(AsOf::ProcessTimeWithInterval(_)) => {} Some(AsOf::VersionNum(_)) | Some(AsOf::VersionString(_)) => { bail_not_implemented!("As Of Version is not supported yet.") } @@ -113,7 +112,7 @@ impl Planner { | Some(AsOf::VersionNum(_)) | Some(AsOf::TimestampString(_)) | Some(AsOf::TimestampNum(_)) => {} - Some(AsOf::ProcessTime) => { + Some(AsOf::ProcessTime) | Some(AsOf::ProcessTimeWithInterval(_)) => { bail_not_implemented!("As Of ProcessTime() is not supported yet.") } Some(AsOf::VersionString(_)) => { diff --git a/src/frontend/src/scheduler/plan_fragmenter.rs b/src/frontend/src/scheduler/plan_fragmenter.rs index 699a84fdc3e3..4dfa5f5cd915 100644 --- a/src/frontend/src/scheduler/plan_fragmenter.rs +++ b/src/frontend/src/scheduler/plan_fragmenter.rs @@ -351,7 +351,9 @@ impl SourceScanInfo { }) .map_err(|_e| anyhow!("fail to parse timestamp"))?, ), - Some(AsOf::ProcessTime) => unreachable!(), + Some(AsOf::ProcessTime) | 
Some(AsOf::ProcessTimeWithInterval(_)) => { + unreachable!() + } None => None, }; diff --git a/src/sqlparser/src/ast/mod.rs b/src/sqlparser/src/ast/mod.rs index db86c6fb8a47..ddab0e030678 100644 --- a/src/sqlparser/src/ast/mod.rs +++ b/src/sqlparser/src/ast/mod.rs @@ -3153,6 +3153,8 @@ impl fmt::Display for SetVariableValueSingle { #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub enum AsOf { ProcessTime, + // used by time travel + ProcessTimeWithInterval((String, DateTimeField)), // the number of seconds that have elapsed since the Unix epoch, which is January 1, 1970 at 00:00:00 Coordinated Universal Time (UTC). TimestampNum(i64), TimestampString(String), @@ -3165,6 +3167,11 @@ impl fmt::Display for AsOf { use AsOf::*; match self { ProcessTime => write!(f, " FOR SYSTEM_TIME AS OF PROCTIME()"), + ProcessTimeWithInterval((value, leading_field)) => write!( + f, + " FOR SYSTEM_TIME AS OF NOW() - {} {}", + value, leading_field + ), TimestampNum(ts) => write!(f, " FOR SYSTEM_TIME AS OF {}", ts), TimestampString(ts) => write!(f, " FOR SYSTEM_TIME AS OF '{}'", ts), VersionNum(v) => write!(f, " FOR SYSTEM_VERSION AS OF {}", v), diff --git a/src/sqlparser/src/parser.rs b/src/sqlparser/src/parser.rs index 60e9da014f10..3077d8b9a6c0 100644 --- a/src/sqlparser/src/parser.rs +++ b/src/sqlparser/src/parser.rs @@ -3649,10 +3649,41 @@ impl Parser<'_> { (Keyword::SYSTEM_TIME, Keyword::AS, Keyword::OF), cut_err( alt(( - ( - Self::parse_identifier.verify(|ident| { - ident.real_value() == "proctime" || ident.real_value() == "now" + preceded( + ( + Self::parse_identifier.verify(|ident| ident.real_value() == "now"), + cut_err(Token::LParen), + cut_err(Token::RParen), + Token::Minus, + ), + Self::parse_literal_interval.try_map(|e| match e { + Expr::Value(v) => match v { + Value::Interval { + value, + leading_field, + .. 
+ } => { + let Some(leading_field) = leading_field else { + return Err(StrError("expect duration unit".into())); + }; + Ok(AsOf::ProcessTimeWithInterval((value, leading_field))) + } + _ => Err(StrError("expect Value::Interval".into())), + }, + _ => Err(StrError("expect Expr::Value".into())), }), + ), + ( + Self::parse_identifier.verify(|ident| ident.real_value() == "now"), + cut_err(Token::LParen), + cut_err(Token::RParen), + ) + .value(AsOf::ProcessTimeWithInterval(( + "0".to_owned(), + DateTimeField::Second, + ))), + ( + Self::parse_identifier.verify(|ident| ident.real_value() == "proctime"), cut_err(Token::LParen), cut_err(Token::RParen), ) diff --git a/src/sqlparser/tests/testdata/as_of.yaml b/src/sqlparser/tests/testdata/as_of.yaml index 2d7c81881988..903d2a37de74 100644 --- a/src/sqlparser/tests/testdata/as_of.yaml +++ b/src/sqlparser/tests/testdata/as_of.yaml @@ -2,7 +2,7 @@ - input: select * from t1 left join t2 FOR SYSTEM_TIME AS OF PROCTIME() on a1 = a2; formatted_sql: SELECT * FROM t1 LEFT JOIN t2 FOR SYSTEM_TIME AS OF PROCTIME() ON a1 = a2 - input: select * from t1 left join t2 FOR SYSTEM_TIME AS OF NOW() on a1 = a2; - formatted_sql: SELECT * FROM t1 LEFT JOIN t2 FOR SYSTEM_TIME AS OF PROCTIME() ON a1 = a2 + formatted_sql: SELECT * FROM t1 LEFT JOIN t2 FOR SYSTEM_TIME AS OF NOW() - 0 SECOND ON a1 = a2 - input: select * from t1 left join t2 FOR SYSTEM_TIME AS OF 1 on a1 = a2; formatted_sql: SELECT * FROM t1 LEFT JOIN t2 FOR SYSTEM_TIME AS OF 1 ON a1 = a2 - input: select * from t1 left join t2 FOR SYSTEM_TIME AS OF 'string' on a1 = a2; diff --git a/src/storage/Cargo.toml b/src/storage/Cargo.toml index 0dc4bd8b088f..b49a625111d3 100644 --- a/src/storage/Cargo.toml +++ b/src/storage/Cargo.toml @@ -36,6 +36,7 @@ libc = "0.2" lz4 = "1.25.0" memcomparable = "0.2" metrics-prometheus = "0.7" +moka = { version = "0.12", features = ["future", "sync"] } more-asserts = "0.3" num-integer = "0.1" parking_lot = { workspace = true } @@ -95,7 +96,6 @@ workspace-hack = { path = "../workspace-hack" } bincode = "1" criterion = { workspace = true, features = ["async_futures", "async_tokio"] } expect-test = "1" -moka = { version = "0.12", features = ["future"] } risingwave_hummock_sdk = { workspace = true } risingwave_test_runner = { workspace = true } uuid = { version = "1", features = ["v4"] } diff --git a/src/storage/src/hummock/mod.rs b/src/storage/src/hummock/mod.rs index 21eb4a13e8c3..3974af300621 100644 --- a/src/storage/src/hummock/mod.rs +++ b/src/storage/src/hummock/mod.rs @@ -54,6 +54,7 @@ pub mod recent_filter; pub use recent_filter::*; pub mod block_stream; +mod time_travel_version_cache; pub use error::*; pub use risingwave_common::cache::{CacheableEntry, LookupResult, LruCache}; diff --git a/src/storage/src/hummock/store/hummock_storage.rs b/src/storage/src/hummock/store/hummock_storage.rs index d3c13b77863a..fcc80a1e54e1 100644 --- a/src/storage/src/hummock/store/hummock_storage.rs +++ b/src/storage/src/hummock/store/hummock_storage.rs @@ -54,6 +54,7 @@ use crate::hummock::event_handler::{ use crate::hummock::iterator::change_log::ChangeLogIterator; use crate::hummock::local_version::pinned_version::{start_pinned_version_worker, PinnedVersion}; use crate::hummock::observer_manager::HummockObserverNode; +use crate::hummock::time_travel_version_cache::SimpleTimeTravelVersionCache; use crate::hummock::utils::{validate_safe_epoch, wait_for_epoch}; use crate::hummock::write_limiter::{WriteLimiter, WriteLimiterRef}; use crate::hummock::{ @@ -119,6 +120,8 @@ pub struct HummockStorage { 
compact_await_tree_reg: Option, hummock_meta_client: Arc, + + simple_time_travel_version_cache: Arc, } pub type ReadVersionTuple = (Vec, Vec, CommittedVersion); @@ -237,6 +240,7 @@ impl HummockStorage { write_limiter, compact_await_tree_reg: await_tree_reg, hummock_meta_client, + simple_time_travel_version_cache: Arc::new(SimpleTimeTravelVersionCache::new()), }; tokio::spawn(hummock_event_handler.start_hummock_event_handler_worker()); @@ -326,20 +330,24 @@ impl HummockStorage { table_id: TableId, key_range: TableKeyRange, ) -> StorageResult<(TableKeyRange, ReadVersionTuple)> { - let pb_version = self - .hummock_meta_client - .get_version_by_epoch(epoch) - .await - .inspect_err(|e| tracing::error!("{}", e.to_report_string())) - .map_err(|e| HummockError::meta_error(e.to_report_string()))?; - let version = HummockVersion::from_rpc_protobuf(&pb_version); - validate_safe_epoch(&version, table_id, epoch)?; - let (tx, _rx) = unbounded_channel(); + let fetch = async { + let pb_version = self + .hummock_meta_client + .get_version_by_epoch(epoch) + .await + .inspect_err(|e| tracing::error!("{}", e.to_report_string())) + .map_err(|e| HummockError::meta_error(e.to_report_string()))?; + let version = HummockVersion::from_rpc_protobuf(&pb_version); + validate_safe_epoch(&version, table_id, epoch)?; + let (tx, _rx) = unbounded_channel(); + Ok(PinnedVersion::new(version, tx)) + }; + let version = self + .simple_time_travel_version_cache + .get_or_insert(epoch, fetch) + .await?; Ok(get_committed_read_version_tuple( - PinnedVersion::new(version, tx), - table_id, - key_range, - epoch, + version, table_id, key_range, epoch, )) } diff --git a/src/storage/src/hummock/time_travel_version_cache.rs b/src/storage/src/hummock/time_travel_version_cache.rs new file mode 100644 index 000000000000..08ad70ab44fa --- /dev/null +++ b/src/storage/src/hummock/time_travel_version_cache.rs @@ -0,0 +1,72 @@ +// Copyright 2024 RisingWave Labs +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::future::Future; + +use moka::sync::Cache; +use risingwave_hummock_sdk::HummockEpoch; +use tokio::sync::Mutex; + +use crate::hummock::local_version::pinned_version::PinnedVersion; +use crate::hummock::HummockResult; + +/// A naive cache to reduce number of RPC sent to meta node. 
+pub struct SimpleTimeTravelVersionCache { + inner: Mutex, +} + +impl SimpleTimeTravelVersionCache { + pub fn new() -> Self { + Self { + inner: Mutex::new(SimpleTimeTravelVersionCacheInner::new()), + } + } + + pub async fn get_or_insert( + &self, + epoch: HummockEpoch, + fetch: impl Future>, + ) -> HummockResult { + let mut guard = self.inner.lock().await; + if let Some(v) = guard.get(&epoch) { + return Ok(v); + } + let version = fetch.await?; + guard.add(epoch, version); + Ok(guard.get(&epoch).unwrap()) + } +} + +struct SimpleTimeTravelVersionCacheInner { + cache: Cache, +} + +impl SimpleTimeTravelVersionCacheInner { + fn new() -> Self { + let capacity = std::env::var("RW_HUMMOCK_TIME_TRAVEL_CACHE_SIZE") + .unwrap_or_else(|_| "10".into()) + .parse() + .unwrap(); + let cache = Cache::builder().max_capacity(capacity).build(); + Self { cache } + } + + fn get(&self, epoch: &HummockEpoch) -> Option { + self.cache.get(epoch) + } + + fn add(&mut self, epoch: HummockEpoch, version: PinnedVersion) { + self.cache.insert(epoch, version); + } +} From 102a60d9407c17f2e606dd3e9737ed4905a65bc2 Mon Sep 17 00:00:00 2001 From: congyi wang <58715567+wcy-fdu@users.noreply.github.com> Date: Fri, 12 Jul 2024 18:26:39 +0800 Subject: [PATCH 03/70] feat(connector): introduce parquet file source (#17201) --- ci/scripts/s3-source-test.sh | 2 +- ci/workflows/main-cron.yml | 22 ++ e2e_test/s3/fs_parquet_source.py | 137 +++++++++++++ proto/plan_common.proto | 1 + src/common/src/array/arrow/arrow_impl.rs | 12 +- src/connector/src/error.rs | 6 +- src/connector/src/parser/mod.rs | 4 + src/connector/src/parser/parquet_parser.rs | 190 ++++++++++++++++++ src/connector/src/sink/catalog/mod.rs | 9 +- src/connector/src/source/base.rs | 4 + .../opendal_source/opendal_reader.rs | 78 ++++--- .../src/handler/alter_source_with_sr.rs | 1 + src/frontend/src/handler/create_sink.rs | 2 +- src/frontend/src/handler/create_source.rs | 5 +- src/sqlparser/src/ast/statement.rs | 5 +- .../src/executor/source/fetch_executor.rs | 1 - 16 files changed, 446 insertions(+), 33 deletions(-) create mode 100644 e2e_test/s3/fs_parquet_source.py create mode 100644 src/connector/src/parser/parquet_parser.rs diff --git a/ci/scripts/s3-source-test.sh b/ci/scripts/s3-source-test.sh index 532223693a21..4ae239719476 100755 --- a/ci/scripts/s3-source-test.sh +++ b/ci/scripts/s3-source-test.sh @@ -32,7 +32,7 @@ echo "--- starting risingwave cluster with connector node" risedev ci-start ci-1cn-1fe echo "--- Run test" -python3 -m pip install --break-system-packages minio psycopg2-binary opendal +python3 -m pip install --break-system-packages minio psycopg2-binary opendal pandas if [[ -v format_type ]]; then python3 e2e_test/s3/"$script" "$format_type" else diff --git a/ci/workflows/main-cron.yml b/ci/workflows/main-cron.yml index cab2fb2858ef..50cd3cc7f45d 100644 --- a/ci/workflows/main-cron.yml +++ b/ci/workflows/main-cron.yml @@ -478,6 +478,28 @@ steps: timeout_in_minutes: 25 retry: *auto-retry + - label: "S3_v2 source check on parquet file" + key: "s3-v2-source-check-parquet-file" + command: "ci/scripts/s3-source-test.sh -p ci-release -s fs_parquet_source.py" + if: | + !(build.pull_request.labels includes "ci/main-cron/skip-ci") && build.env("CI_STEPS") == null + || build.pull_request.labels includes "ci/run-s3-source-tests" + || build.env("CI_STEPS") =~ /(^|,)s3-source-tests?(,|$$)/ + depends_on: build + plugins: + - seek-oss/aws-sm#v2.3.1: + env: + S3_SOURCE_TEST_CONF: ci_s3_source_test_aws + - docker-compose#v5.1.0: + run: rw-build-env + config: 
ci/docker-compose.yml + mount-buildkite-agent: true + environment: + - S3_SOURCE_TEST_CONF + - ./ci/plugins/upload-failure-logs + timeout_in_minutes: 25 + retry: *auto-retry + - label: "S3_v2 source batch read on AWS (json parser)" key: "s3-v2-source-batch-read-check-aws-json-parser" command: "ci/scripts/s3-source-test.sh -p ci-release -s fs_source_batch.py -t json" diff --git a/e2e_test/s3/fs_parquet_source.py b/e2e_test/s3/fs_parquet_source.py new file mode 100644 index 000000000000..64060928c775 --- /dev/null +++ b/e2e_test/s3/fs_parquet_source.py @@ -0,0 +1,137 @@ +import os +import sys +import random +import psycopg2 +import json +import pyarrow as pa +import pyarrow.parquet as pq +import pandas as pd +from datetime import datetime, timezone +from time import sleep +from minio import Minio +from random import uniform + +def gen_data(file_num, item_num_per_file): + assert item_num_per_file % 2 == 0, \ + f'item_num_per_file should be even to ensure sum(mark) == 0: {item_num_per_file}' + return [ + [{ + 'id': file_id * item_num_per_file + item_id, + 'name': f'{file_id}_{item_id}_{file_id * item_num_per_file + item_id}', + 'sex': item_id % 2, + 'mark': (-1) ** (item_id % 2), + 'test_int': pa.scalar(1, type=pa.int32()), + 'test_real': pa.scalar(4.0, type=pa.float32()), + 'test_double_precision': pa.scalar(5.0, type=pa.float64()), + 'test_varchar': pa.scalar('7', type=pa.string()), + 'test_bytea': pa.scalar(b'\xDe00BeEf', type=pa.binary()), + 'test_date': pa.scalar(datetime.now().date(), type=pa.date32()), + 'test_time': pa.scalar(datetime.now().time(), type=pa.time64('us')), + 'test_timestamp': pa.scalar(datetime.now().timestamp() * 1000000, type=pa.timestamp('us')), + 'test_timestamptz': pa.scalar(datetime.now().timestamp() * 1000, type=pa.timestamp('us', tz='+00:00')), + } for item_id in range(item_num_per_file)] + for file_id in range(file_num) + ] + +def do_test(config, file_num, item_num_per_file, prefix): + conn = psycopg2.connect( + host="localhost", + port="4566", + user="root", + database="dev" + ) + + # Open a cursor to execute SQL statements + cur = conn.cursor() + + def _table(): + return 's3_test_parquet' + + # Execute a SELECT statement + cur.execute(f'''CREATE TABLE {_table()}( + id bigint primary key, + name TEXT, + sex bigint, + mark bigint, + test_int int, + test_real real, + test_double_precision double precision, + test_varchar varchar, + test_bytea bytea, + test_date date, + test_time time, + test_timestamp timestamp, + test_timestamptz timestamptz, + ) WITH ( + connector = 's3_v2', + match_pattern = '*.parquet', + s3.region_name = '{config['S3_REGION']}', + s3.bucket_name = '{config['S3_BUCKET']}', + s3.credentials.access = '{config['S3_ACCESS_KEY']}', + s3.credentials.secret = '{config['S3_SECRET_KEY']}', + s3.endpoint_url = 'https://{config['S3_ENDPOINT']}' + ) FORMAT PLAIN ENCODE PARQUET;''') + + total_rows = file_num * item_num_per_file + MAX_RETRIES = 40 + for retry_no in range(MAX_RETRIES): + cur.execute(f'select count(*) from {_table()}') + result = cur.fetchone() + if result[0] == total_rows: + break + print(f"[retry {retry_no}] Now got {result[0]} rows in table, {total_rows} expected, wait 10s") + sleep(10) + + stmt = f'select count(*), sum(id) from {_table()}' + print(f'Execute {stmt}') + cur.execute(stmt) + result = cur.fetchone() + + print('Got:', result) + + def _assert_eq(field, got, expect): + assert got == expect, f'{field} assertion failed: got {got}, expect {expect}.' 
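+    # The generated ids are 0..total_rows-1, so sum(id) is the arithmetic series total_rows * (total_rows - 1) / 2.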
+ + _assert_eq('count(*)', result[0], total_rows) + _assert_eq('sum(id)', result[1], (total_rows - 1) * total_rows / 2) + + print('Test pass') + + cur.execute(f'drop table {_table()}') + cur.close() + conn.close() + + +if __name__ == "__main__": + FILE_NUM = 10 + ITEM_NUM_PER_FILE = 2000 + data = gen_data(FILE_NUM, ITEM_NUM_PER_FILE) + + config = json.loads(os.environ["S3_SOURCE_TEST_CONF"]) + client = Minio( + config["S3_ENDPOINT"], + access_key=config["S3_ACCESS_KEY"], + secret_key=config["S3_SECRET_KEY"], + secure=True, + ) + run_id = str(random.randint(1000, 9999)) + _local = lambda idx: f'data_{idx}.parquet' + _s3 = lambda idx: f"{run_id}_data_{idx}.parquet" + + # put s3 files + for idx, file_data in enumerate(data): + table = pa.Table.from_pandas(pd.DataFrame(file_data)) + pq.write_table(table, _local(idx)) + + client.fput_object( + config["S3_BUCKET"], + _s3(idx), + _local(idx) + ) + + # do test + do_test(config, FILE_NUM, ITEM_NUM_PER_FILE, run_id) + + # clean up s3 files + for idx, _ in enumerate(data): + client.remove_object(config["S3_BUCKET"], _s3(idx)) \ No newline at end of file diff --git a/proto/plan_common.proto b/proto/plan_common.proto index c019994b771c..bc2e60503f10 100644 --- a/proto/plan_common.proto +++ b/proto/plan_common.proto @@ -165,6 +165,7 @@ enum EncodeType { ENCODE_TYPE_TEMPLATE = 7; ENCODE_TYPE_NONE = 8; ENCODE_TYPE_TEXT = 9; + ENCODE_TYPE_PARQUET = 10; } enum RowFormatType { diff --git a/src/common/src/array/arrow/arrow_impl.rs b/src/common/src/array/arrow/arrow_impl.rs index 35057f62f774..1d5e5816efe0 100644 --- a/src/common/src/array/arrow/arrow_impl.rs +++ b/src/common/src/array/arrow/arrow_impl.rs @@ -524,9 +524,12 @@ pub trait FromArrow { Float64 => self.from_float64_array(array.as_any().downcast_ref().unwrap()), Date32 => self.from_date32_array(array.as_any().downcast_ref().unwrap()), Time64(Microsecond) => self.from_time64us_array(array.as_any().downcast_ref().unwrap()), - Timestamp(Microsecond, _) => { + Timestamp(Microsecond, None) => { self.from_timestampus_array(array.as_any().downcast_ref().unwrap()) } + Timestamp(Microsecond, Some(_)) => { + self.from_timestampus_some_array(array.as_any().downcast_ref().unwrap()) + } Interval(MonthDayNano) => { self.from_interval_array(array.as_any().downcast_ref().unwrap()) } @@ -628,6 +631,13 @@ pub trait FromArrow { Ok(ArrayImpl::Timestamp(array.into())) } + fn from_timestampus_some_array( + &self, + array: &arrow_array::TimestampMicrosecondArray, + ) -> Result { + Ok(ArrayImpl::Timestamptz(array.into())) + } + fn from_interval_array( &self, array: &arrow_array::IntervalMonthDayNanoArray, diff --git a/src/connector/src/error.rs b/src/connector/src/error.rs index 3a86062d18a0..78a39b735b41 100644 --- a/src/connector/src/error.rs +++ b/src/connector/src/error.rs @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +use risingwave_common::array::ArrayError; use risingwave_common::error::def_anyhow_newtype; use risingwave_pb::PbFieldNotFound; use risingwave_rpc_client::error::RpcError; @@ -41,17 +42,20 @@ def_anyhow_newtype! 
{ url::ParseError => "failed to parse url", serde_json::Error => "failed to parse json", csv::Error => "failed to parse csv", + uuid::Error => transparent, // believed to be self-explanatory // Connector errors opendal::Error => transparent, // believed to be self-explanatory - + parquet::errors::ParquetError => transparent, + ArrayError => "Array error", sqlx::Error => transparent, // believed to be self-explanatory mysql_async::Error => "MySQL error", tokio_postgres::Error => "Postgres error", apache_avro::Error => "Avro error", rdkafka::error::KafkaError => "Kafka error", pulsar::Error => "Pulsar error", + async_nats::jetstream::consumer::StreamError => "Nats error", async_nats::jetstream::consumer::pull::MessagesError => "Nats error", async_nats::jetstream::context::CreateStreamError => "Nats error", diff --git a/src/connector/src/parser/mod.rs b/src/connector/src/parser/mod.rs index fee737fce110..a0a612a812f8 100644 --- a/src/connector/src/parser/mod.rs +++ b/src/connector/src/parser/mod.rs @@ -24,6 +24,7 @@ pub use debezium::*; use futures::{Future, TryFutureExt}; use futures_async_stream::try_stream; pub use json_parser::*; +pub use parquet_parser::ParquetParser; pub use protobuf::*; use risingwave_common::array::{ArrayBuilderImpl, Op, StreamChunk}; use risingwave_common::bail; @@ -76,6 +77,7 @@ mod debezium; mod json_parser; mod maxwell; mod mysql; +pub mod parquet_parser; pub mod plain_parser; mod postgres; @@ -1117,6 +1119,7 @@ pub enum EncodingProperties { Json(JsonProperties), MongoJson, Bytes(BytesProperties), + Parquet, Native, /// Encoding can't be specified because the source will determines it. Now only used in Iceberg. None, @@ -1170,6 +1173,7 @@ impl SpecificParserConfig { delimiter: info.csv_delimiter as u8, has_header: info.csv_has_header, }), + (SourceFormat::Plain, SourceEncode::Parquet) => EncodingProperties::Parquet, (SourceFormat::Plain, SourceEncode::Avro) | (SourceFormat::Upsert, SourceEncode::Avro) => { let mut config = AvroProperties { diff --git a/src/connector/src/parser/parquet_parser.rs b/src/connector/src/parser/parquet_parser.rs new file mode 100644 index 000000000000..eeefaaf2d82b --- /dev/null +++ b/src/connector/src/parser/parquet_parser.rs @@ -0,0 +1,190 @@ +// Copyright 2024 RisingWave Labs +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +use std::sync::Arc; + +use arrow_array_iceberg::RecordBatch; +use futures_async_stream::try_stream; +use risingwave_common::array::arrow::IcebergArrowConvert; +use risingwave_common::array::{ArrayBuilderImpl, DataChunk, StreamChunk}; +use risingwave_common::types::{Datum, ScalarImpl}; + +use crate::parser::ConnectorResult; +use crate::source::SourceColumnDesc; +/// `ParquetParser` is responsible for converting the incoming `record_batch_stream` +/// into a `streamChunk`. 
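+///
+/// Columns are matched by name against the user-defined source schema: a matching Arrow column is
+/// converted via `IcebergArrowConvert`, a missing or type-mismatched column is filled with NULLs,
+/// and the hidden `_rw_file` / `_rw_offset` columns are populated for every row.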
+#[derive(Debug)] +pub struct ParquetParser { + rw_columns: Vec, + file_name: String, + offset: usize, +} + +impl ParquetParser { + pub fn new( + rw_columns: Vec, + file_name: String, + offset: usize, + ) -> ConnectorResult { + Ok(Self { + rw_columns, + file_name, + offset, + }) + } + + #[try_stream(boxed, ok = StreamChunk, error = crate::error::ConnectorError)] + pub async fn into_stream( + mut self, + record_batch_stream: parquet::arrow::async_reader::ParquetRecordBatchStream< + tokio_util::compat::Compat, + >, + ) { + #[for_await] + for record_batch in record_batch_stream { + let record_batch: RecordBatch = record_batch?; + // Convert each record batch into a stream chunk according to user defined schema. + let chunk: StreamChunk = self.convert_record_batch_to_stream_chunk(record_batch)?; + + yield chunk; + } + } + + fn inc_offset(&mut self) { + self.offset += 1; + } + + /// The function `convert_record_batch_to_stream_chunk` is designed to transform the given `RecordBatch` into a `StreamChunk`. + /// + /// For each column in the source column: + /// - If the column's schema matches a column in the `RecordBatch` (both the data type and column name are the same), + /// the corresponding records are converted into a column of the `StreamChunk`. + /// - If the column's schema does not match, null values are inserted. + /// - Hidden columns are handled separately by filling in the appropriate fields to ensure the data chunk maintains the correct format. + /// - If a column in the Parquet file does not exist in the source schema, it is skipped. + /// + /// # Arguments + /// + /// * `record_batch` - The `RecordBatch` to be converted into a `StreamChunk`. + /// + /// # Returns + /// + /// A `StreamChunk` containing the converted data from the `RecordBatch`. + + // The hidden columns that must be included here are _rw_file and _rw_offset. + // Depending on whether the user specifies a primary key (pk), there may be an additional hidden column row_id. + // Therefore, the maximum number of hidden columns is three. + + fn convert_record_batch_to_stream_chunk( + &mut self, + record_batch: RecordBatch, + ) -> Result { + const MAX_HIDDEN_COLUMN_NUMS: usize = 3; + let column_size = self.rw_columns.len(); + let mut chunk_columns = Vec::with_capacity(self.rw_columns.len() + MAX_HIDDEN_COLUMN_NUMS); + for source_column in self.rw_columns.clone() { + match source_column.column_type { + crate::source::SourceColumnType::Normal => { + match source_column.is_hidden_addition_col { + false => { + let rw_data_type = &source_column.data_type; + let rw_column_name = &source_column.name; + if let Some(parquet_column) = + record_batch.column_by_name(rw_column_name) + { + let arrow_field = IcebergArrowConvert + .to_arrow_field(rw_column_name, rw_data_type)?; + let converted_arrow_data_type: &arrow_schema_iceberg::DataType = + arrow_field.data_type(); + if converted_arrow_data_type == parquet_column.data_type() { + let array_impl = IcebergArrowConvert + .array_from_arrow_array(&arrow_field, parquet_column)?; + let column = Arc::new(array_impl); + chunk_columns.push(column); + } else { + // data type mismatch, this column is set to null. + let mut array_builder = ArrayBuilderImpl::with_type( + column_size, + rw_data_type.clone(), + ); + + array_builder.append_n_null(record_batch.num_rows()); + let res = array_builder.finish(); + let column = Arc::new(res); + chunk_columns.push(column); + } + } else { + // For columns defined in the source schema but not present in the Parquet file, null values are filled in. 
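+                            // Such a column still occupies its slot in the chunk: it becomes `record_batch.num_rows()` NULL entries of the declared type.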
+ let mut array_builder = + ArrayBuilderImpl::with_type(column_size, rw_data_type.clone()); + + array_builder.append_n_null(record_batch.num_rows()); + let res = array_builder.finish(); + let column = Arc::new(res); + chunk_columns.push(column); + } + } + // handle hidden columns, for file source, the hidden columns are only `Offset` and `Filename` + true => { + if let Some(additional_column_type) = + &source_column.additional_column.column_type + { + match additional_column_type{ + risingwave_pb::plan_common::additional_column::ColumnType::Offset(_) =>{ + let mut array_builder = + ArrayBuilderImpl::with_type(column_size, source_column.data_type.clone()); + for _ in 0..record_batch.num_rows(){ + let datum: Datum = Some(ScalarImpl::Utf8((self.offset).to_string().into())); + self.inc_offset(); + array_builder.append(datum); + } + let res = array_builder.finish(); + let column = Arc::new(res); + chunk_columns.push(column); + + }, + risingwave_pb::plan_common::additional_column::ColumnType::Filename(_) => { + let mut array_builder = + ArrayBuilderImpl::with_type(column_size, source_column.data_type.clone()); + let datum: Datum = Some(ScalarImpl::Utf8(self.file_name.clone().into())); + array_builder.append_n(record_batch.num_rows(), datum); + let res = array_builder.finish(); + let column = Arc::new(res); + chunk_columns.push(column); + }, + _ => unreachable!() + } + } + } + } + } + crate::source::SourceColumnType::RowId => { + let mut array_builder = + ArrayBuilderImpl::with_type(column_size, source_column.data_type.clone()); + let datum: Datum = None; + array_builder.append_n(record_batch.num_rows(), datum); + let res = array_builder.finish(); + let column = Arc::new(res); + chunk_columns.push(column); + } + // The following fields is only used in CDC source + crate::source::SourceColumnType::Offset | crate::source::SourceColumnType::Meta => { + unreachable!() + } + } + } + + let data_chunk = DataChunk::new(chunk_columns.clone(), record_batch.num_rows()); + Ok(data_chunk.into()) + } +} diff --git a/src/connector/src/sink/catalog/mod.rs b/src/connector/src/sink/catalog/mod.rs index cb814154c822..7638c197f234 100644 --- a/src/connector/src/sink/catalog/mod.rs +++ b/src/connector/src/sink/catalog/mod.rs @@ -235,7 +235,13 @@ impl TryFrom for SinkFormatDesc { E::Protobuf => SinkEncode::Protobuf, E::Template => SinkEncode::Template, E::Avro => SinkEncode::Avro, - e @ (E::Unspecified | E::Native | E::Csv | E::Bytes | E::None | E::Text) => { + e @ (E::Unspecified + | E::Native + | E::Csv + | E::Bytes + | E::None + | E::Text + | E::Parquet) => { return Err(SinkError::Config(anyhow!( "sink encode unsupported: {}", e.as_str_name() @@ -252,6 +258,7 @@ impl TryFrom for SinkFormatDesc { | E::Protobuf | E::Template | E::Native + | E::Parquet | E::None) => { return Err(SinkError::Config(anyhow!( "unsupported {} as sink key encode", diff --git a/src/connector/src/source/base.rs b/src/connector/src/source/base.rs index 1f2444a0db4e..707678b66599 100644 --- a/src/connector/src/source/base.rs +++ b/src/connector/src/source/base.rs @@ -243,6 +243,7 @@ pub enum SourceEncode { Protobuf, Json, Bytes, + Parquet, } #[derive(Clone, Copy, Debug, PartialEq, Eq, Default)] @@ -295,6 +296,9 @@ pub fn extract_source_struct(info: &PbStreamSourceInfo) -> Result (PbFormatType::Maxwell, PbEncodeType::Json) => (SourceFormat::Maxwell, SourceEncode::Json), (PbFormatType::Canal, PbEncodeType::Json) => (SourceFormat::Canal, SourceEncode::Json), (PbFormatType::Plain, PbEncodeType::Csv) => (SourceFormat::Plain, SourceEncode::Csv), + 
(PbFormatType::Plain, PbEncodeType::Parquet) => { + (SourceFormat::Plain, SourceEncode::Parquet) + } (PbFormatType::Native, PbEncodeType::Native) => { (SourceFormat::Native, SourceEncode::Native) } diff --git a/src/connector/src/source/filesystem/opendal_source/opendal_reader.rs b/src/connector/src/source/filesystem/opendal_source/opendal_reader.rs index f3a6b96cf2d6..5757452d2b4c 100644 --- a/src/connector/src/source/filesystem/opendal_source/opendal_reader.rs +++ b/src/connector/src/source/filesystem/opendal_source/opendal_reader.rs @@ -20,20 +20,22 @@ use async_trait::async_trait; use futures::TryStreamExt; use futures_async_stream::try_stream; use opendal::Operator; +use parquet::arrow::ParquetRecordBatchStreamBuilder; use risingwave_common::array::StreamChunk; +use risingwave_common::util::tokio_util::compat::FuturesAsyncReadCompatExt; use tokio::io::{AsyncRead, BufReader}; use tokio_util::io::{ReaderStream, StreamReader}; use super::opendal_enumerator::OpendalEnumerator; use super::OpendalSource; use crate::error::ConnectorResult; -use crate::parser::ParserConfig; +use crate::parser::{ByteStreamSourceParserImpl, EncodingProperties, ParquetParser, ParserConfig}; use crate::source::filesystem::file_common::CompressionFormat; use crate::source::filesystem::nd_streaming::need_nd_streaming; use crate::source::filesystem::{nd_streaming, OpendalFsSplit}; use crate::source::{ - into_chunk_stream, BoxChunkSourceStream, Column, SourceContextRef, SourceMessage, SourceMeta, - SplitMetaData, SplitReader, + BoxChunkSourceStream, Column, SourceContextRef, SourceMessage, SourceMeta, SplitMetaData, + SplitReader, }; const STREAM_READER_CAPACITY: usize = 4096; @@ -76,24 +78,54 @@ impl OpendalReader { #[try_stream(boxed, ok = StreamChunk, error = crate::error::ConnectorError)] async fn into_stream_inner(self) { for split in self.splits { - let data_stream = Self::stream_read_object( - self.connector.op.clone(), - split, - self.source_ctx.clone(), - self.connector.compression_format.clone(), - ); - - let data_stream = if need_nd_streaming(&self.parser_config.specific.encoding_config) { - nd_streaming::split_stream(data_stream) + let source_ctx = self.source_ctx.clone(); + + let object_name = split.name.clone(); + + let msg_stream; + + if let EncodingProperties::Parquet = &self.parser_config.specific.encoding_config { + // // If the format is "parquet", use `ParquetParser` to convert `record_batch` into stream chunk. + let reader: tokio_util::compat::Compat = self + .connector + .op + .reader_with(&object_name) + .into_future() // Unlike `rustc`, `try_stream` seems require manual `into_future`. + .await? + .into_futures_async_read(..) + .await? + .compat(); + // For the Parquet format, we directly convert from a record batch to a stream chunk. + // Therefore, the offset of the Parquet file represents the current position in terms of the number of rows read from the file. + let record_batch_stream = ParquetRecordBatchStreamBuilder::new(reader) + .await? 
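+                    // One record batch per source chunk: the batch size follows `source_ctrl_opts.chunk_size`,
+                    // and `with_offset` skips rows that have already been read for this split.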
+ .with_batch_size(self.source_ctx.source_ctrl_opts.chunk_size) + .with_offset(split.offset) + .build()?; + + let parquet_parser = ParquetParser::new( + self.parser_config.common.rw_columns.clone(), + object_name, + split.offset, + )?; + msg_stream = parquet_parser.into_stream(record_batch_stream); } else { - data_stream - }; - - let msg_stream = into_chunk_stream( - data_stream, - self.parser_config.clone(), - self.source_ctx.clone(), - ); + let data_stream = Self::stream_read_object( + self.connector.op.clone(), + split, + self.source_ctx.clone(), + self.connector.compression_format.clone(), + ); + + let parser = + ByteStreamSourceParserImpl::create(self.parser_config.clone(), source_ctx) + .await?; + msg_stream = if need_nd_streaming(&self.parser_config.specific.encoding_config) { + Box::pin(parser.into_stream(nd_streaming::split_stream(data_stream))) + } else { + Box::pin(parser.into_stream(data_stream)) + }; + } #[for_await] for msg in msg_stream { let msg = msg?; @@ -115,15 +147,12 @@ impl OpendalReader { let source_name = source_ctx.source_name.to_string(); let max_chunk_size = source_ctx.source_ctrl_opts.chunk_size; let split_id = split.id(); - let object_name = split.name.clone(); - let reader = op .read_with(&object_name) .range(split.offset as u64..) .into_future() // Unlike `rustc`, `try_stream` seems require manual `into_future`. .await?; - let stream_reader = StreamReader::new( reader.map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e)), ); @@ -144,11 +173,10 @@ impl OpendalReader { } }; - let stream = ReaderStream::with_capacity(buf_reader, STREAM_READER_CAPACITY); - let mut offset: usize = split.offset; let mut batch_size: usize = 0; let mut batch = Vec::new(); + let stream = ReaderStream::with_capacity(buf_reader, STREAM_READER_CAPACITY); #[for_await] for read in stream { let bytes = read?; diff --git a/src/frontend/src/handler/alter_source_with_sr.rs b/src/frontend/src/handler/alter_source_with_sr.rs index 070b26b6a25e..840205caeadd 100644 --- a/src/frontend/src/handler/alter_source_with_sr.rs +++ b/src/frontend/src/handler/alter_source_with_sr.rs @@ -63,6 +63,7 @@ fn encode_type_to_encode(from: EncodeType) -> Option { EncodeType::Json => Encode::Json, EncodeType::Bytes => Encode::Bytes, EncodeType::Template => Encode::Template, + EncodeType::Parquet => Encode::Parquet, EncodeType::None => Encode::None, EncodeType::Text => Encode::Text, }) diff --git a/src/frontend/src/handler/create_sink.rs b/src/frontend/src/handler/create_sink.rs index 1dcb4167f8c1..b2a65b8d0978 100644 --- a/src/frontend/src/handler/create_sink.rs +++ b/src/frontend/src/handler/create_sink.rs @@ -775,7 +775,7 @@ fn bind_sink_format_desc(value: ConnectorSchema) -> Result { E::Protobuf => SinkEncode::Protobuf, E::Avro => SinkEncode::Avro, E::Template => SinkEncode::Template, - e @ (E::Native | E::Csv | E::Bytes | E::None | E::Text) => { + e @ (E::Native | E::Csv | E::Bytes | E::None | E::Text | E::Parquet) => { return Err(ErrorCode::BindError(format!("sink encode unsupported: {e}")).into()); } }; diff --git a/src/frontend/src/handler/create_source.rs b/src/frontend/src/handler/create_source.rs index 89c5afca08a4..6263cc618fc0 100644 --- a/src/frontend/src/handler/create_source.rs +++ b/src/frontend/src/handler/create_source.rs @@ -442,6 +442,8 @@ pub(crate) async fn bind_columns_from_source( None } + // For parquet format, this step is implemented in parquet parser. 
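+        // Returning `None` means the columns come solely from the user-defined schema in the DDL;
+        // any mismatch with the actual Parquet schema is resolved with NULLs at read time.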
+ (Format::Plain, Encode::Parquet) => None, ( Format::Plain | Format::Upsert | Format::Maxwell | Format::Canal | Format::Debezium, Encode::Json, @@ -1017,7 +1019,7 @@ static CONNECTORS_COMPATIBLE_FORMATS: LazyLock vec![Encode::Csv, Encode::Json], ), OPENDAL_S3_CONNECTOR => hashmap!( - Format::Plain => vec![Encode::Csv, Encode::Json], + Format::Plain => vec![Encode::Csv, Encode::Json, Encode::Parquet], ), GCS_CONNECTOR => hashmap!( Format::Plain => vec![Encode::Csv, Encode::Json], @@ -1605,6 +1607,7 @@ fn row_encode_to_prost(row_encode: &Encode) -> EncodeType { Encode::Csv => EncodeType::Csv, Encode::Bytes => EncodeType::Bytes, Encode::Template => EncodeType::Template, + Encode::Parquet => EncodeType::Parquet, Encode::None => EncodeType::None, Encode::Text => EncodeType::Text, } diff --git a/src/sqlparser/src/ast/statement.rs b/src/sqlparser/src/ast/statement.rs index facedadeb5bc..732badfdd0a5 100644 --- a/src/sqlparser/src/ast/statement.rs +++ b/src/sqlparser/src/ast/statement.rs @@ -173,6 +173,7 @@ pub enum Encode { /// Used internally for schema change Native, Template, + Parquet, } // TODO: unify with `from_keyword` @@ -190,6 +191,7 @@ impl fmt::Display for Encode { Encode::Native => "NATIVE", Encode::Template => "TEMPLATE", Encode::None => "NONE", + Encode::Parquet => "PARQUET", Encode::Text => "TEXT", } ) @@ -208,8 +210,9 @@ impl Encode { "TEMPLATE" => Encode::Template, "NATIVE" => Encode::Native, "NONE" => Encode::None, + "PARQUET" => Encode::Parquet, _ => parser_err!( - "expected AVRO | BYTES | CSV | PROTOBUF | JSON | NATIVE | TEMPLATE | NONE after Encode" + "expected AVRO | BYTES | CSV | PROTOBUF | JSON | NATIVE | TEMPLATE | PARQUET | NONE after Encode" ), }) } diff --git a/src/stream/src/executor/source/fetch_executor.rs b/src/stream/src/executor/source/fetch_executor.rs index b4c006469e65..788a9a45662c 100644 --- a/src/stream/src/executor/source/fetch_executor.rs +++ b/src/stream/src/executor/source/fetch_executor.rs @@ -199,7 +199,6 @@ impl FsFetchExecutor { else { unreachable!("Partition and offset columns must be set."); }; - // Initialize state table. 
state_store_handler.init_epoch(barrier.epoch); From 5c52c7d752804e5e0d66b370b7da70a5cd2e8d88 Mon Sep 17 00:00:00 2001 From: xiangjinwu <17769960+xiangjinwu@users.noreply.github.com> Date: Fri, 12 Jul 2024 20:07:24 +0800 Subject: [PATCH 04/70] feat(source): Avro with AWS Glue Schema Registry (#17605) --- Cargo.lock | 23 ++ Cargo.toml | 1 + e2e_test/source_inline/kafka/avro/glue.slt | 121 ++++++++++ proto/catalog.proto | 5 + src/connector/Cargo.toml | 1 + ...hema_resolver.rs => confluent_resolver.rs} | 0 .../src/parser/avro/glue_resolver.rs | 220 ++++++++++++++++++ src/connector/src/parser/avro/mod.rs | 6 +- src/connector/src/parser/avro/parser.rs | 215 +++++++++++------ .../src/parser/debezium/avro_parser.rs | 16 +- src/connector/src/parser/mod.rs | 98 ++++++-- src/connector/src/schema/mod.rs | 1 + src/frontend/src/handler/create_source.rs | 64 ++--- 13 files changed, 637 insertions(+), 134 deletions(-) create mode 100644 e2e_test/source_inline/kafka/avro/glue.slt rename src/connector/src/parser/avro/{schema_resolver.rs => confluent_resolver.rs} (100%) create mode 100644 src/connector/src/parser/avro/glue_resolver.rs diff --git a/Cargo.lock b/Cargo.lock index 8aed93149faa..70f7d40c510e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1500,6 +1500,28 @@ dependencies = [ "tracing", ] +[[package]] +name = "aws-sdk-glue" +version = "1.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3b6c34f6f4b9e8f76274a9b309838d670b3bb69b4be6756394de54718aa2ca0a" +dependencies = [ + "aws-credential-types", + "aws-http", + "aws-runtime", + "aws-smithy-async", + "aws-smithy-http", + "aws-smithy-json", + "aws-smithy-runtime", + "aws-smithy-runtime-api", + "aws-smithy-types", + "aws-types", + "bytes", + "http 0.2.9", + "regex", + "tracing", +] + [[package]] name = "aws-sdk-kinesis" version = "1.3.0" @@ -11123,6 +11145,7 @@ dependencies = [ "aws-credential-types", "aws-msk-iam-sasl-signer", "aws-sdk-dynamodb", + "aws-sdk-glue", "aws-sdk-kinesis", "aws-sdk-s3", "aws-smithy-http", diff --git a/Cargo.toml b/Cargo.toml index dfa5d090121d..ae8bea531028 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -94,6 +94,7 @@ aws-config = { version = "1", default-features = false, features = [ aws-credential-types = { version = "1", default-features = false, features = [ "hardcoded-credentials", ] } +aws-sdk-glue = "1" aws-sdk-kinesis = { version = "1", default-features = false, features = [ "rt-tokio", "rustls", diff --git a/e2e_test/source_inline/kafka/avro/glue.slt b/e2e_test/source_inline/kafka/avro/glue.slt new file mode 100644 index 000000000000..13a1ae27ff3a --- /dev/null +++ b/e2e_test/source_inline/kafka/avro/glue.slt @@ -0,0 +1,121 @@ +control substitution on + +system ok +rpk topic delete 'glue-sample-my-event' + +system ok +rpk topic create 'glue-sample-my-event' + +system ok +rpk topic produce -f '%v{hex}\n' 'glue-sample-my-event' < ConnectorResult>; + /// Gets the latest schema by arn, which is used as *reader schema*. 
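+    /// The arn looks like `arn:aws:glue:{region}:{account}:schema/{registry}/{schema_name}`.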
+ async fn get_by_name(&self, schema_arn: &str) -> ConnectorResult>; +} + +#[derive(Debug)] +pub enum GlueSchemaCacheImpl { + Real(RealGlueSchemaCache), + Mock(MockGlueSchemaCache), +} + +impl GlueSchemaCacheImpl { + pub async fn new( + aws_auth_props: &AwsAuthProps, + mock_config: Option<&str>, + ) -> ConnectorResult { + if let Some(mock_config) = mock_config { + return Ok(Self::Mock(MockGlueSchemaCache::new(mock_config))); + } + Ok(Self::Real(RealGlueSchemaCache::new(aws_auth_props).await?)) + } +} + +impl GlueSchemaCache for GlueSchemaCacheImpl { + async fn get_by_id(&self, schema_version_id: uuid::Uuid) -> ConnectorResult> { + match self { + Self::Real(inner) => inner.get_by_id(schema_version_id).await, + Self::Mock(inner) => inner.get_by_id(schema_version_id).await, + } + } + + async fn get_by_name(&self, schema_arn: &str) -> ConnectorResult> { + match self { + Self::Real(inner) => inner.get_by_name(schema_arn).await, + Self::Mock(inner) => inner.get_by_name(schema_arn).await, + } + } +} + +#[derive(Debug)] +pub struct RealGlueSchemaCache { + writer_schemas: Cache>, + glue_client: Client, +} + +impl RealGlueSchemaCache { + /// Create a new `GlueSchemaCache` + pub async fn new(aws_auth_props: &AwsAuthProps) -> ConnectorResult { + let client = Client::new(&aws_auth_props.build_config().await?); + Ok(Self { + writer_schemas: Cache::new(u64::MAX), + glue_client: client, + }) + } + + async fn parse_and_cache_schema( + &self, + schema_version_id: uuid::Uuid, + content: &str, + ) -> ConnectorResult> { + let schema = Schema::parse_str(content).context("failed to parse avro schema")?; + let schema = Arc::new(schema); + self.writer_schemas + .insert(schema_version_id, Arc::clone(&schema)) + .await; + Ok(schema) + } +} + +impl GlueSchemaCache for RealGlueSchemaCache { + /// Gets the a specific schema by id, which is used as *writer schema*. + async fn get_by_id(&self, schema_version_id: uuid::Uuid) -> ConnectorResult> { + if let Some(schema) = self.writer_schemas.get(&schema_version_id).await { + return Ok(schema); + } + let res = self + .glue_client + .get_schema_version() + .schema_version_id(schema_version_id) + .send() + .await + .context("glue sdk error")?; + let definition = res + .schema_definition() + .context("glue sdk response without definition")?; + self.parse_and_cache_schema(schema_version_id, definition) + .await + } + + /// Gets the latest schema by arn, which is used as *reader schema*. + async fn get_by_name(&self, schema_arn: &str) -> ConnectorResult> { + let res = self + .glue_client + .get_schema_version() + .schema_id(SchemaId::builder().schema_arn(schema_arn).build()) + .schema_version_number(SchemaVersionNumber::builder().latest_version(true).build()) + .send() + .await + .context("glue sdk error")?; + let schema_version_id = res + .schema_version_id() + .context("glue sdk response without schema version id")? + .parse() + .context("glue sdk response invalid schema version id")?; + let definition = res + .schema_definition() + .context("glue sdk response without definition")?; + self.parse_and_cache_schema(schema_version_id, definition) + .await + } +} + +#[derive(Debug)] +pub struct MockGlueSchemaCache { + by_id: HashMap>, + arn_to_latest_id: HashMap, +} + +impl MockGlueSchemaCache { + pub fn new(mock_config: &str) -> Self { + // The `mock_config` accepted is a JSON that looks like: + // { + // "by_id": { + // "4dc80ccf-2d0c-4846-9325-7e1c9e928121": { + // "type": "record", + // "name": "MyEvent", + // "fields": [...] 
+ // }, + // "3df022f4-b16d-4afe-bdf7-cf4baf8d01d3": { + // ... + // } + // }, + // "arn_to_latest_id": { + // "arn:aws:glue:ap-southeast-1:123456123456:schema/default-registry/MyEvent": "3df022f4-b16d-4afe-bdf7-cf4baf8d01d3" + // } + // } + // + // The format is not public and we can make breaking changes to it. + // Current format only supports avsc. + let parsed: serde_json::Value = + serde_json::from_str(mock_config).expect("mock config shall be valid json"); + let by_id = parsed + .get("by_id") + .unwrap() + .as_object() + .unwrap() + .iter() + .map(|(schema_version_id, schema)| { + let schema_version_id = schema_version_id.parse().unwrap(); + let schema = Schema::parse(schema).unwrap(); + (schema_version_id, Arc::new(schema)) + }) + .collect(); + let arn_to_latest_id = parsed + .get("arn_to_latest_id") + .unwrap() + .as_object() + .unwrap() + .iter() + .map(|(arn, latest_id)| (arn.clone(), latest_id.as_str().unwrap().parse().unwrap())) + .collect(); + Self { + by_id, + arn_to_latest_id, + } + } +} + +impl GlueSchemaCache for MockGlueSchemaCache { + async fn get_by_id(&self, schema_version_id: uuid::Uuid) -> ConnectorResult> { + Ok(self + .by_id + .get(&schema_version_id) + .context("schema version id not found in mock registry")? + .clone()) + } + + async fn get_by_name(&self, schema_arn: &str) -> ConnectorResult> { + let schema_version_id = self + .arn_to_latest_id + .get(schema_arn) + .context("schema arn not found in mock registry")?; + self.get_by_id(*schema_version_id).await + } +} diff --git a/src/connector/src/parser/avro/mod.rs b/src/connector/src/parser/avro/mod.rs index 19193035bd56..536c700efef8 100644 --- a/src/connector/src/parser/avro/mod.rs +++ b/src/connector/src/parser/avro/mod.rs @@ -12,8 +12,10 @@ // See the License for the specific language governing permissions and // limitations under the License. +mod confluent_resolver; +mod glue_resolver; mod parser; -mod schema_resolver; +pub use confluent_resolver::ConfluentSchemaCache; +pub use glue_resolver::{GlueSchemaCache, GlueSchemaCacheImpl}; pub use parser::{AvroAccessBuilder, AvroParserConfig}; -pub use schema_resolver::ConfluentSchemaCache; diff --git a/src/connector/src/parser/avro/parser.rs b/src/connector/src/parser/avro/parser.rs index 5bd0038e3e83..68ac7d446846 100644 --- a/src/connector/src/parser/avro/parser.rs +++ b/src/connector/src/parser/avro/parser.rs @@ -24,11 +24,13 @@ use risingwave_connector_codec::decoder::avro::{ }; use risingwave_pb::plan_common::ColumnDesc; -use super::ConfluentSchemaCache; +use super::{ConfluentSchemaCache, GlueSchemaCache as _, GlueSchemaCacheImpl}; use crate::error::ConnectorResult; use crate::parser::unified::AccessImpl; use crate::parser::util::bytes_from_url; -use crate::parser::{AccessBuilder, AvroProperties, EncodingProperties, EncodingType, MapHandling}; +use crate::parser::{ + AccessBuilder, AvroProperties, EncodingProperties, EncodingType, MapHandling, SchemaLocation, +}; use crate::schema::schema_registry::{ extract_schema_id, get_subject_by_strategy, handle_sr_list, Client, }; @@ -38,7 +40,7 @@ use crate::schema::schema_registry::{ pub struct AvroAccessBuilder { schema: Arc, /// Refer to [`AvroParserConfig::writer_schema_cache`]. 
- pub writer_schema_cache: Option>, + writer_schema_cache: WriterSchemaCache, value: Option, } @@ -93,21 +95,51 @@ impl AvroAccessBuilder { async fn parse_avro_value(&self, payload: &[u8]) -> ConnectorResult> { // parse payload to avro value // if use confluent schema, get writer schema from confluent schema registry - if let Some(resolver) = &self.writer_schema_cache { - let (schema_id, mut raw_payload) = extract_schema_id(payload)?; - let writer_schema = resolver.get_by_id(schema_id).await?; - Ok(Some(from_avro_datum( - writer_schema.as_ref(), - &mut raw_payload, - Some(&self.schema.original_schema), - )?)) - } else { - // FIXME: we should not use `Reader` (file header) here. See comment above and https://github.com/risingwavelabs/risingwave/issues/12871 - let mut reader = Reader::with_schema(&self.schema.original_schema, payload)?; - match reader.next() { - Some(Ok(v)) => Ok(Some(v)), - Some(Err(e)) => Err(e)?, - None => bail!("avro parse unexpected eof"), + match &self.writer_schema_cache { + WriterSchemaCache::Confluent(resolver) => { + let (schema_id, mut raw_payload) = extract_schema_id(payload)?; + let writer_schema = resolver.get_by_id(schema_id).await?; + Ok(Some(from_avro_datum( + writer_schema.as_ref(), + &mut raw_payload, + Some(&self.schema.original_schema), + )?)) + } + WriterSchemaCache::File => { + // FIXME: we should not use `Reader` (file header) here. See comment above and https://github.com/risingwavelabs/risingwave/issues/12871 + let mut reader = Reader::with_schema(&self.schema.original_schema, payload)?; + match reader.next() { + Some(Ok(v)) => Ok(Some(v)), + Some(Err(e)) => Err(e)?, + None => bail!("avro parse unexpected eof"), + } + } + WriterSchemaCache::Glue(resolver) => { + // + // byte 0: header version = 3 + // byte 1: compression: 0 = no compression; 5 = zlib (unsupported) + // byte 2..=17: 16-byte UUID as schema version id + // byte 18..: raw avro payload + if payload.len() < 18 { + bail!("payload shorter than 18-byte glue header"); + } + if payload[0] != 3 { + bail!( + "Only support glue header version 3 but found {}", + payload[0] + ); + } + if payload[1] != 0 { + bail!("Non-zero compression {} not supported", payload[1]); + } + let schema_version_id = uuid::Uuid::from_slice(&payload[2..18]).unwrap(); + let writer_schema = resolver.get_by_id(schema_version_id).await?; + let mut raw_payload = &payload[18..]; + Ok(Some(from_avro_datum( + writer_schema.as_ref(), + &mut raw_payload, + Some(&self.schema.original_schema), + )?)) } } } @@ -115,83 +147,112 @@ impl AvroAccessBuilder { #[derive(Debug, Clone)] pub struct AvroParserConfig { - pub schema: Arc, - pub key_schema: Option>, + schema: Arc, + key_schema: Option>, /// Writer schema is the schema used to write the data. When parsing Avro data, the exactly same schema /// must be used to decode the message, and then convert it with the reader schema. 
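+    /// Depending on `SchemaLocation`, the writer schema comes from Confluent Schema Registry, from
+    /// AWS Glue (resolved by the 16-byte schema version id in the payload header), or from the Avro
+    /// file header when no registry is configured.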
- pub writer_schema_cache: Option>, + writer_schema_cache: WriterSchemaCache, + + map_handling: Option, +} - pub map_handling: Option, +#[derive(Debug, Clone)] +enum WriterSchemaCache { + Confluent(Arc), + Glue(Arc), + File, } impl AvroParserConfig { pub async fn new(encoding_properties: EncodingProperties) -> ConnectorResult { let AvroProperties { - use_schema_registry, - row_schema_location: schema_location, - client_config, - aws_auth_props, - topic, + schema_location, enable_upsert, record_name, key_record_name, - name_strategy, map_handling, } = try_match_expand!(encoding_properties, EncodingProperties::Avro)?; - let url = handle_sr_list(schema_location.as_str())?; - if use_schema_registry { - let client = Client::new(url, &client_config)?; - let resolver = ConfluentSchemaCache::new(client); + match schema_location { + SchemaLocation::Confluent { + urls: schema_location, + client_config, + name_strategy, + topic, + } => { + let url = handle_sr_list(schema_location.as_str())?; + let client = Client::new(url, &client_config)?; + let resolver = ConfluentSchemaCache::new(client); - let subject_key = if enable_upsert { - Some(get_subject_by_strategy( + let subject_key = if enable_upsert { + Some(get_subject_by_strategy( + &name_strategy, + topic.as_str(), + key_record_name.as_deref(), + true, + )?) + } else { + if let Some(name) = &key_record_name { + bail!("unused FORMAT ENCODE option: key.message='{name}'"); + } + None + }; + let subject_value = get_subject_by_strategy( &name_strategy, topic.as_str(), - key_record_name.as_deref(), - true, - )?) - } else { - if let Some(name) = &key_record_name { - bail!("unused FORMAT ENCODE option: key.message='{name}'"); + record_name.as_deref(), + false, + )?; + tracing::debug!("infer key subject {subject_key:?}, value subject {subject_value}"); + + Ok(Self { + schema: Arc::new(ResolvedAvroSchema::create( + resolver.get_by_subject(&subject_value).await?, + )?), + key_schema: if let Some(subject_key) = subject_key { + Some(Arc::new(ResolvedAvroSchema::create( + resolver.get_by_subject(&subject_key).await?, + )?)) + } else { + None + }, + writer_schema_cache: WriterSchemaCache::Confluent(Arc::new(resolver)), + map_handling, + }) + } + SchemaLocation::File { + url: schema_location, + aws_auth_props, + } => { + let url = handle_sr_list(schema_location.as_str())?; + if enable_upsert { + bail!("avro upsert without schema registry is not supported"); } - None - }; - let subject_value = get_subject_by_strategy( - &name_strategy, - topic.as_str(), - record_name.as_deref(), - false, - )?; - tracing::debug!("infer key subject {subject_key:?}, value subject {subject_value}"); - - Ok(Self { - schema: Arc::new(ResolvedAvroSchema::create( - resolver.get_by_subject(&subject_value).await?, - )?), - key_schema: if let Some(subject_key) = subject_key { - Some(Arc::new(ResolvedAvroSchema::create( - resolver.get_by_subject(&subject_key).await?, - )?)) - } else { - None - }, - writer_schema_cache: Some(Arc::new(resolver)), - map_handling, - }) - } else { - if enable_upsert { - bail!("avro upsert without schema registry is not supported"); + let url = url.first().unwrap(); + let schema_content = bytes_from_url(url, aws_auth_props.as_ref()).await?; + let schema = Schema::parse_reader(&mut schema_content.as_slice()) + .context("failed to parse avro schema")?; + Ok(Self { + schema: Arc::new(ResolvedAvroSchema::create(Arc::new(schema))?), + key_schema: None, + writer_schema_cache: WriterSchemaCache::File, + map_handling, + }) + } + SchemaLocation::Glue { + schema_arn, + 
aws_auth_props, + mock_config, + } => { + let resolver = + GlueSchemaCacheImpl::new(&aws_auth_props, mock_config.as_deref()).await?; + let schema = resolver.get_by_name(&schema_arn).await?; + Ok(Self { + schema: Arc::new(ResolvedAvroSchema::create(schema)?), + key_schema: None, + writer_schema_cache: WriterSchemaCache::Glue(Arc::new(resolver)), + map_handling, + }) } - let url = url.first().unwrap(); - let schema_content = bytes_from_url(url, aws_auth_props.as_ref()).await?; - let schema = Schema::parse_reader(&mut schema_content.as_slice()) - .context("failed to parse avro schema")?; - Ok(Self { - schema: Arc::new(ResolvedAvroSchema::create(Arc::new(schema))?), - key_schema: None, - writer_schema_cache: None, - map_handling, - }) } } diff --git a/src/connector/src/parser/debezium/avro_parser.rs b/src/connector/src/parser/debezium/avro_parser.rs index 04f80ebba1ca..2ddfc77073c8 100644 --- a/src/connector/src/parser/debezium/avro_parser.rs +++ b/src/connector/src/parser/debezium/avro_parser.rs @@ -22,13 +22,12 @@ use risingwave_connector_codec::decoder::avro::{ avro_extract_field_schema, avro_schema_skip_nullable_union, avro_schema_to_column_descs, AvroAccess, AvroParseOptions, ResolvedAvroSchema, }; -use risingwave_pb::catalog::PbSchemaRegistryNameStrategy; use risingwave_pb::plan_common::ColumnDesc; use crate::error::ConnectorResult; use crate::parser::avro::ConfluentSchemaCache; use crate::parser::unified::AccessImpl; -use crate::parser::{AccessBuilder, EncodingProperties, EncodingType}; +use crate::parser::{AccessBuilder, EncodingProperties, EncodingType, SchemaLocation}; use crate::schema::schema_registry::{ extract_schema_id, get_subject_by_strategy, handle_sr_list, Client, }; @@ -95,14 +94,19 @@ pub struct DebeziumAvroParserConfig { impl DebeziumAvroParserConfig { pub async fn new(encoding_config: EncodingProperties) -> ConnectorResult { let avro_config = try_match_expand!(encoding_config, EncodingProperties::Avro)?; - let schema_location = &avro_config.row_schema_location; - let client_config = &avro_config.client_config; - let kafka_topic = &avro_config.topic; + let SchemaLocation::Confluent { + urls: schema_location, + client_config, + name_strategy, + topic: kafka_topic, + } = &avro_config.schema_location + else { + unreachable!() + }; let url = handle_sr_list(schema_location)?; let client = Client::new(url, client_config)?; let resolver = ConfluentSchemaCache::new(client); - let name_strategy = &PbSchemaRegistryNameStrategy::Unspecified; let key_subject = get_subject_by_strategy(name_strategy, kafka_topic, None, true)?; let val_subject = get_subject_by_strategy(name_strategy, kafka_topic, None, false)?; let key_schema = resolver.get_by_subject(&key_subject).await?; diff --git a/src/connector/src/parser/mod.rs b/src/connector/src/parser/mod.rs index a0a612a812f8..e88998f00596 100644 --- a/src/connector/src/parser/mod.rs +++ b/src/connector/src/parser/mod.rs @@ -61,6 +61,7 @@ use crate::parser::util::{ extreact_timestamp_from_meta, }; use crate::schema::schema_registry::SchemaRegistryAuth; +use crate::schema::AWS_GLUE_SCHEMA_ARN_KEY; use crate::source::monitor::GLOBAL_SOURCE_METRICS; use crate::source::{ extract_source_struct, BoxSourceStream, ChunkSourceStream, SourceColumnDesc, SourceColumnType, @@ -1069,18 +1070,48 @@ impl SpecificParserConfig { #[derive(Debug, Default, Clone)] pub struct AvroProperties { - pub use_schema_registry: bool, - pub row_schema_location: String, - pub client_config: SchemaRegistryAuth, - pub aws_auth_props: Option, - pub topic: String, + pub 
schema_location: SchemaLocation, pub enable_upsert: bool, pub record_name: Option, pub key_record_name: Option, - pub name_strategy: PbSchemaRegistryNameStrategy, pub map_handling: Option, } +/// WIP: may cover protobuf and json schema later. +#[derive(Debug, Clone)] +pub enum SchemaLocation { + /// Avsc from `https://`, `s3://` or `file://`. + File { + url: String, + aws_auth_props: Option, // for s3 + }, + /// + Confluent { + urls: String, + client_config: SchemaRegistryAuth, + name_strategy: PbSchemaRegistryNameStrategy, + topic: String, + }, + /// + Glue { + schema_arn: String, + aws_auth_props: AwsAuthProps, + // When `Some(_)`, ignore AWS and load schemas from provided config + mock_config: Option, + }, +} + +// TODO: `SpecificParserConfig` shall not `impl`/`derive` a `Default` +impl Default for SchemaLocation { + fn default() -> Self { + // backward compatible but undesired + Self::File { + url: Default::default(), + aws_auth_props: None, + } + } +} + #[derive(Debug, Default, Clone)] pub struct ProtobufProperties { pub message_name: String, @@ -1183,27 +1214,46 @@ impl SpecificParserConfig { Some(info.proto_message_name.clone()) }, key_record_name: info.key_message_name.clone(), - name_strategy: PbSchemaRegistryNameStrategy::try_from(info.name_strategy) - .unwrap(), - use_schema_registry: info.use_schema_registry, - row_schema_location: info.row_schema_location.clone(), map_handling: MapHandling::from_options(&info.format_encode_options)?, ..Default::default() }; if format == SourceFormat::Upsert { config.enable_upsert = true; } - if info.use_schema_registry { - config.topic.clone_from(get_kafka_topic(with_properties)?); - config.client_config = SchemaRegistryAuth::from(&info.format_encode_options); - } else { - config.aws_auth_props = Some( - serde_json::from_value::( + config.schema_location = if let Some(schema_arn) = + info.format_encode_options.get(AWS_GLUE_SCHEMA_ARN_KEY) + { + SchemaLocation::Glue { + schema_arn: schema_arn.clone(), + aws_auth_props: serde_json::from_value::( serde_json::to_value(info.format_encode_options.clone()).unwrap(), ) .map_err(|e| anyhow::anyhow!(e))?, - ); - } + // The option `mock_config` is not public and we can break compatibility. 
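+                // It lets tests inject Glue schemas from an inline JSON blob instead of calling AWS.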
+ mock_config: info + .format_encode_options + .get("aws.glue.mock_config") + .cloned(), + } + } else if info.use_schema_registry { + SchemaLocation::Confluent { + urls: info.row_schema_location.clone(), + client_config: SchemaRegistryAuth::from(&info.format_encode_options), + name_strategy: PbSchemaRegistryNameStrategy::try_from(info.name_strategy) + .unwrap(), + topic: get_kafka_topic(with_properties)?.clone(), + } + } else { + SchemaLocation::File { + url: info.row_schema_location.clone(), + aws_auth_props: Some( + serde_json::from_value::( + serde_json::to_value(info.format_encode_options.clone()).unwrap(), + ) + .map_err(|e| anyhow::anyhow!(e))?, + ), + } + }; EncodingProperties::Avro(config) } (SourceFormat::Plain, SourceEncode::Protobuf) @@ -1243,12 +1293,14 @@ impl SpecificParserConfig { } else { Some(info.proto_message_name.clone()) }, - name_strategy: PbSchemaRegistryNameStrategy::try_from(info.name_strategy) - .unwrap(), key_record_name: info.key_message_name.clone(), - row_schema_location: info.row_schema_location.clone(), - topic: get_kafka_topic(with_properties).unwrap().clone(), - client_config: SchemaRegistryAuth::from(&info.format_encode_options), + schema_location: SchemaLocation::Confluent { + urls: info.row_schema_location.clone(), + client_config: SchemaRegistryAuth::from(&info.format_encode_options), + name_strategy: PbSchemaRegistryNameStrategy::try_from(info.name_strategy) + .unwrap(), + topic: get_kafka_topic(with_properties).unwrap().clone(), + }, ..Default::default() }) } diff --git a/src/connector/src/schema/mod.rs b/src/connector/src/schema/mod.rs index 585dd43fa8bf..9b3757e29c09 100644 --- a/src/connector/src/schema/mod.rs +++ b/src/connector/src/schema/mod.rs @@ -26,6 +26,7 @@ const KEY_MESSAGE_NAME_KEY: &str = "key.message"; const SCHEMA_LOCATION_KEY: &str = "schema.location"; const SCHEMA_REGISTRY_KEY: &str = "schema.registry"; const NAME_STRATEGY_KEY: &str = "schema.registry.name.strategy"; +pub const AWS_GLUE_SCHEMA_ARN_KEY: &str = "aws.glue.schema_arn"; #[derive(Debug, thiserror::Error, thiserror_ext::Macro)] #[error("Invalid option: {message}")] diff --git a/src/frontend/src/handler/create_source.rs b/src/frontend/src/handler/create_source.rs index 6263cc618fc0..3d3c32958e31 100644 --- a/src/frontend/src/handler/create_source.rs +++ b/src/frontend/src/handler/create_source.rs @@ -33,11 +33,13 @@ use risingwave_connector::parser::additional_columns::{ }; use risingwave_connector::parser::{ fetch_json_schema_and_map_to_columns, AvroParserConfig, DebeziumAvroParserConfig, - ProtobufParserConfig, SpecificParserConfig, TimestamptzHandling, DEBEZIUM_IGNORE_KEY, + ProtobufParserConfig, SchemaLocation, SpecificParserConfig, TimestamptzHandling, + DEBEZIUM_IGNORE_KEY, }; use risingwave_connector::schema::schema_registry::{ name_strategy_from_str, SchemaRegistryAuth, SCHEMA_REGISTRY_PASSWORD, SCHEMA_REGISTRY_USERNAME, }; +use risingwave_connector::schema::AWS_GLUE_SCHEMA_ARN_KEY; use risingwave_connector::sink::iceberg::IcebergConfig; use risingwave_connector::source::cdc::{ CDC_SHARING_MODE_KEY, CDC_SNAPSHOT_BACKFILL, CDC_SNAPSHOT_MODE_KEY, CDC_TRANSACTIONAL_KEY, @@ -158,7 +160,7 @@ async fn extract_avro_table_schema( } else { if let risingwave_connector::parser::EncodingProperties::Avro(avro_props) = &parser_config.encoding_config - && !avro_props.use_schema_registry + && matches!(avro_props.schema_location, SchemaLocation::File { .. 
}) && !format_encode_options .get("with_deprecated_file_header") .is_some_and(|v| v == "true") @@ -381,33 +383,43 @@ pub(crate) async fn bind_columns_from_source( ) } (format @ (Format::Plain | Format::Upsert | Format::Debezium), Encode::Avro) => { - let (row_schema_location, use_schema_registry) = - get_schema_location(&mut format_encode_options_to_consume)?; + if format_encode_options_to_consume + .remove(AWS_GLUE_SCHEMA_ARN_KEY) + .is_none() + { + // Legacy logic that assumes either `schema.location` or confluent `schema.registry`. + // The handling of newly added aws glue is centralized in `connector::parser`. + // TODO(xiangjinwu): move these option parsing to `connector::parser` as well. - if matches!(format, Format::Debezium) && !use_schema_registry { - return Err(RwError::from(ProtocolError( - "schema location for DEBEZIUM_AVRO row format is not supported".to_string(), - ))); - } + let (row_schema_location, use_schema_registry) = + get_schema_location(&mut format_encode_options_to_consume)?; - let message_name = try_consume_string_from_options( - &mut format_encode_options_to_consume, - MESSAGE_NAME_KEY, - ); - let name_strategy = get_sr_name_strategy_check( - &mut format_encode_options_to_consume, - use_schema_registry, - )?; + if matches!(format, Format::Debezium) && !use_schema_registry { + return Err(RwError::from(ProtocolError( + "schema location for DEBEZIUM_AVRO row format is not supported".to_string(), + ))); + } - stream_source_info.use_schema_registry = use_schema_registry; - stream_source_info - .row_schema_location - .clone_from(&row_schema_location.0); - stream_source_info.proto_message_name = message_name.unwrap_or(AstString("".into())).0; - stream_source_info.key_message_name = - get_key_message_name(&mut format_encode_options_to_consume); - stream_source_info.name_strategy = - name_strategy.unwrap_or(PbSchemaRegistryNameStrategy::Unspecified as i32); + let message_name = try_consume_string_from_options( + &mut format_encode_options_to_consume, + MESSAGE_NAME_KEY, + ); + let name_strategy = get_sr_name_strategy_check( + &mut format_encode_options_to_consume, + use_schema_registry, + )?; + + stream_source_info.use_schema_registry = use_schema_registry; + stream_source_info + .row_schema_location + .clone_from(&row_schema_location.0); + stream_source_info.proto_message_name = + message_name.unwrap_or(AstString("".into())).0; + stream_source_info.key_message_name = + get_key_message_name(&mut format_encode_options_to_consume); + stream_source_info.name_strategy = + name_strategy.unwrap_or(PbSchemaRegistryNameStrategy::Unspecified as i32); + } Some( extract_avro_table_schema( From 4c4ada19ce57645546ef97b75f4aa9837951c8c8 Mon Sep 17 00:00:00 2001 From: William Wen <44139337+wenym1@users.noreply.github.com> Date: Sat, 13 Jul 2024 19:19:08 +0800 Subject: [PATCH 05/70] feat(storage): pass epoch and table id before barrier (#17635) --- src/meta/src/barrier/mod.rs | 8 +- src/meta/src/barrier/recovery.rs | 4 +- src/meta/src/barrier/rpc.rs | 14 +- .../hummock_test/src/compactor_tests.rs | 40 +++-- .../hummock_test/src/hummock_storage_tests.rs | 76 ++++++++ .../hummock_test/src/snapshot_tests.rs | 5 + .../hummock_test/src/state_store_tests.rs | 19 +- .../event_handler/hummock_event_handler.rs | 17 ++ src/storage/src/hummock/event_handler/mod.rs | 9 + .../src/hummock/event_handler/uploader/mod.rs | 163 +++++++++++++----- .../src/hummock/store/hummock_storage.rs | 9 + .../common/log_store_impl/kv_log_store/mod.rs | 65 +++++++ .../src/common/table/test_state_table.rs | 94 ++++++++++ 
.../src/common/table/test_storage_table.rs | 35 ++++ src/stream/src/task/barrier_manager.rs | 3 - .../src/task/barrier_manager/managed_state.rs | 53 +++++- 16 files changed, 545 insertions(+), 69 deletions(-) diff --git a/src/meta/src/barrier/mod.rs b/src/meta/src/barrier/mod.rs index 60bbb4cee7fa..92fc4dc31a2e 100644 --- a/src/meta/src/barrier/mod.rs +++ b/src/meta/src/barrier/mod.rs @@ -767,10 +767,10 @@ impl GlobalBarrierManager { send_latency_timer.observe_duration(); - let node_to_collect = match self - .control_stream_manager - .inject_barrier(command_ctx.clone()) - { + let node_to_collect = match self.control_stream_manager.inject_barrier( + command_ctx.clone(), + self.state.inflight_actor_infos.existing_table_ids(), + ) { Ok(node_to_collect) => node_to_collect, Err(err) => { for notifier in notifiers { diff --git a/src/meta/src/barrier/recovery.rs b/src/meta/src/barrier/recovery.rs index 3bb51b3b5ef7..0ead9779e914 100644 --- a/src/meta/src/barrier/recovery.rs +++ b/src/meta/src/barrier/recovery.rs @@ -477,8 +477,8 @@ impl GlobalBarrierManager { tracing::Span::current(), // recovery span )); - let mut node_to_collect = - control_stream_manager.inject_barrier(command_ctx.clone())?; + let mut node_to_collect = control_stream_manager + .inject_barrier(command_ctx.clone(), info.existing_table_ids())?; while !node_to_collect.is_empty() { let (worker_id, prev_epoch, _) = control_stream_manager .next_complete_barrier_response() diff --git a/src/meta/src/barrier/rpc.rs b/src/meta/src/barrier/rpc.rs index 0a7e6d4e1e95..c1a337bde046 100644 --- a/src/meta/src/barrier/rpc.rs +++ b/src/meta/src/barrier/rpc.rs @@ -24,6 +24,7 @@ use futures::future::try_join_all; use futures::stream::{BoxStream, FuturesUnordered}; use futures::{pin_mut, FutureExt, StreamExt}; use itertools::Itertools; +use risingwave_common::catalog::TableId; use risingwave_common::hash::ActorId; use risingwave_common::util::tracing::TracingContext; use risingwave_pb::common::{ActorInfo, WorkerNode}; @@ -247,6 +248,7 @@ impl ControlStreamManager { pub(super) fn inject_barrier( &mut self, command_context: Arc, + table_ids_to_sync: HashSet, ) -> MetaResult> { fail_point!("inject_barrier_err", |_| risingwave_common::bail!( "inject_barrier_err" @@ -263,9 +265,13 @@ impl ControlStreamManager { if actor_ids_to_collect.is_empty() { // No need to send or collect barrier for this node. assert!(actor_ids_to_send.is_empty()); - Ok(()) - } else { + } + { let Some(node) = self.nodes.get_mut(node_id) else { + if actor_ids_to_collect.is_empty() { + // Worker node get disconnected but has no actor to collect. Simply skip it. 
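+                    // Nothing needs to be collected from it, so the barrier can still complete without this node.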
+ return Ok(()); + } return Err( anyhow!("unconnected worker node: {:?}", worker_node.host).into() ); @@ -294,9 +300,7 @@ impl ControlStreamManager { barrier: Some(barrier), actor_ids_to_send, actor_ids_to_collect, - table_ids_to_sync: command_context - .info - .existing_table_ids() + table_ids_to_sync: table_ids_to_sync .iter() .map(|table_id| table_id.table_id) .collect(), diff --git a/src/storage/hummock_test/src/compactor_tests.rs b/src/storage/hummock_test/src/compactor_tests.rs index 96f237704abf..9f862e3300dc 100644 --- a/src/storage/hummock_test/src/compactor_tests.rs +++ b/src/storage/hummock_test/src/compactor_tests.rs @@ -15,7 +15,7 @@ #[cfg(test)] pub(crate) mod tests { - use std::collections::{BTreeMap, BTreeSet, VecDeque}; + use std::collections::{BTreeMap, BTreeSet, HashSet, VecDeque}; use std::ops::Bound; use std::sync::Arc; @@ -156,6 +156,9 @@ pub(crate) mod tests { value_size: usize, epochs: Vec, ) { + for epoch in &epochs { + storage.start_epoch(*epoch, HashSet::from_iter([Default::default()])); + } let mut local = storage .new_local(NewLocalOptions::for_test(TableId::default())) .await; @@ -534,17 +537,16 @@ pub(crate) mod tests { existing_table_id: u32, keys_per_epoch: usize, ) { + let table_id = existing_table_id.into(); let kv_count: u16 = 128; let mut epoch = test_epoch(1); - let mut local = storage - .new_local(NewLocalOptions::for_test(existing_table_id.into())) - .await; + let mut local = storage.new_local(NewLocalOptions::for_test(table_id)).await; + + storage.start_epoch(epoch, HashSet::from_iter([table_id])); // 1. add sstables let val = Bytes::from(b"0"[..].repeat(1 << 10)); // 1024 Byte value for idx in 0..kv_count { - epoch.inc_epoch(); - if idx == 0 { local.init_for_test(epoch).await.unwrap(); } @@ -559,9 +561,11 @@ pub(crate) mod tests { } local.flush().await.unwrap(); let next_epoch = epoch.next_epoch(); + storage.start_epoch(next_epoch, HashSet::from_iter([table_id])); local.seal_current_epoch(next_epoch, SealCurrentEpochOptions::for_test()); flush_and_commit(&hummock_meta_client, storage, epoch).await; + epoch.inc_epoch(); } } @@ -727,9 +731,10 @@ pub(crate) mod tests { .await; let vnode = VirtualNode::from_index(1); + global_storage.start_epoch(epoch, HashSet::from_iter([1.into(), 2.into()])); for index in 0..kv_count { - epoch.inc_epoch(); let next_epoch = epoch.next_epoch(); + global_storage.start_epoch(next_epoch, HashSet::from_iter([1.into(), 2.into()])); if index == 0 { storage_1.init_for_test(epoch).await.unwrap(); storage_2.init_for_test(epoch).await.unwrap(); @@ -755,6 +760,7 @@ pub(crate) mod tests { let res = global_storage.seal_and_sync_epoch(epoch).await.unwrap(); hummock_meta_client.commit_epoch(epoch, res).await.unwrap(); + epoch.inc_epoch(); } // Mimic dropping table @@ -838,7 +844,6 @@ pub(crate) mod tests { .unwrap(); assert!(compact_task.is_none()); - epoch.inc_epoch(); // to update version for hummock_storage global_storage.wait_version(version).await; @@ -921,12 +926,14 @@ pub(crate) mod tests { let vnode = VirtualNode::from_index(1); let mut epoch_set = BTreeSet::new(); + storage.start_epoch(epoch, HashSet::from_iter([existing_table_id.into()])); + let mut local = storage .new_local(NewLocalOptions::for_test(existing_table_id.into())) .await; for i in 0..kv_count { - epoch += millisec_interval_epoch; let next_epoch = epoch + millisec_interval_epoch; + storage.start_epoch(next_epoch, HashSet::from_iter([existing_table_id.into()])); if i == 0 { local.init_for_test(epoch).await.unwrap(); } @@ -944,6 +951,7 @@ pub(crate) mod tests { 
let res = storage.seal_and_sync_epoch(epoch).await.unwrap(); hummock_meta_client.commit_epoch(epoch, res).await.unwrap(); + epoch += millisec_interval_epoch; } let manual_compcation_option = ManualCompactionOption { @@ -969,7 +977,10 @@ pub(crate) mod tests { retention_seconds: Some(retention_seconds_expire_second), }, )]); - compact_task.current_epoch_time = epoch; + compact_task.current_epoch_time = hummock_manager_ref + .get_current_version() + .await + .max_committed_epoch; // assert compact_task assert_eq!( @@ -1123,12 +1134,13 @@ pub(crate) mod tests { let mut local = storage .new_local(NewLocalOptions::for_test(existing_table_id.into())) .await; + storage.start_epoch(epoch, HashSet::from_iter([existing_table_id.into()])); for i in 0..kv_count { - epoch += millisec_interval_epoch; if i == 0 { local.init_for_test(epoch).await.unwrap(); } let next_epoch = epoch + millisec_interval_epoch; + storage.start_epoch(next_epoch, HashSet::from_iter([existing_table_id.into()])); epoch_set.insert(epoch); let ramdom_key = [key_prefix.as_ref(), &rand::thread_rng().gen::<[u8; 32]>()].concat(); @@ -1139,6 +1151,7 @@ pub(crate) mod tests { local.seal_current_epoch(next_epoch, SealCurrentEpochOptions::for_test()); let res = storage.seal_and_sync_epoch(epoch).await.unwrap(); hummock_meta_client.commit_epoch(epoch, res).await.unwrap(); + epoch += millisec_interval_epoch; } let manual_compcation_option = ManualCompactionOption { @@ -1166,7 +1179,10 @@ pub(crate) mod tests { let compaction_filter_flag = CompactionFilterFlag::STATE_CLEAN | CompactionFilterFlag::TTL; compact_task.compaction_filter_mask = compaction_filter_flag.bits(); - compact_task.current_epoch_time = epoch; + compact_task.current_epoch_time = hummock_manager_ref + .get_current_version() + .await + .max_committed_epoch; // 3. compact let (_tx, rx) = tokio::sync::oneshot::channel(); diff --git a/src/storage/hummock_test/src/hummock_storage_tests.rs b/src/storage/hummock_test/src/hummock_storage_tests.rs index 8d721e9e560c..b9e576b547d7 100644 --- a/src/storage/hummock_test/src/hummock_storage_tests.rs +++ b/src/storage/hummock_test/src/hummock_storage_tests.rs @@ -1,3 +1,4 @@ +use std::collections::HashSet; // Copyright 2024 RisingWave Labs // // Licensed under the Apache License, Version 2.0 (the "License"); @@ -101,6 +102,9 @@ async fn test_storage_basic() { // epoch 0 is reserved by storage service let epoch1 = test_epoch(1); + test_env + .storage + .start_epoch(epoch1, HashSet::from_iter([TEST_TABLE_ID])); hummock_storage.init_for_test(epoch1).await.unwrap(); // Write the first batch. @@ -165,6 +169,9 @@ async fn test_storage_basic() { assert_eq!(value, None); let epoch2 = epoch1.next_epoch(); + test_env + .storage + .start_epoch(epoch2, HashSet::from_iter([TEST_TABLE_ID])); hummock_storage.seal_current_epoch(epoch2, SealCurrentEpochOptions::for_test()); hummock_storage .ingest_batch( @@ -197,6 +204,9 @@ async fn test_storage_basic() { // Write the third batch. 
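The test hunks above and below all follow the same shape. A condensed sketch of that protocol against a hypothetical trait (not the real state store API): every epoch is declared with start_epoch before any local instance initializes into it or seals forward to it.

use std::collections::HashSet;

type Epoch = u64;
type TableId = u32;

trait EpochDeclaringStore {
    fn start_epoch(&self, epoch: Epoch, table_ids: HashSet<TableId>);
    fn init_for_test(&mut self, epoch: Epoch);
    fn seal_current_epoch(&mut self, next_epoch: Epoch);
}

/// Write `count` consecutive epochs for one table, declaring each epoch up front.
fn write_epochs<S: EpochDeclaringStore>(store: &mut S, table_id: TableId, first: Epoch, count: u64) {
    let mut epoch = first;
    store.start_epoch(epoch, HashSet::from_iter([table_id]));
    store.init_for_test(epoch);
    for _ in 0..count {
        let next_epoch = epoch + 1;
        // Declare the next epoch before sealing the current one into it.
        store.start_epoch(next_epoch, HashSet::from_iter([table_id]));
        store.seal_current_epoch(next_epoch);
        epoch = next_epoch;
    }
}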
let epoch3 = epoch2.next_epoch(); + test_env + .storage + .start_epoch(epoch3, HashSet::from_iter([TEST_TABLE_ID])); hummock_storage.seal_current_epoch(epoch3, SealCurrentEpochOptions::for_test()); hummock_storage .ingest_batch( @@ -457,6 +467,9 @@ async fn test_state_store_sync() { let base_epoch = read_version.read().committed().max_committed_epoch(); let epoch1 = test_epoch(base_epoch.next_epoch()); + test_env + .storage + .start_epoch(epoch1, HashSet::from_iter([TEST_TABLE_ID])); hummock_storage.init_for_test(epoch1).await.unwrap(); // ingest 16B batch @@ -511,6 +524,9 @@ async fn test_state_store_sync() { .unwrap(); let epoch2 = epoch1.next_epoch(); + test_env + .storage + .start_epoch(epoch2, HashSet::from_iter([TEST_TABLE_ID])); hummock_storage.seal_current_epoch(epoch2, SealCurrentEpochOptions::for_test()); // ingest more 8B then will trigger a sync behind the scene @@ -531,6 +547,9 @@ async fn test_state_store_sync() { .unwrap(); let epoch3 = epoch2.next_epoch(); + test_env + .storage + .start_epoch(epoch3, HashSet::from_iter([TEST_TABLE_ID])); hummock_storage.seal_current_epoch(epoch3, SealCurrentEpochOptions::for_test()); let res = test_env.storage.seal_and_sync_epoch(epoch1).await.unwrap(); @@ -809,6 +828,9 @@ async fn test_delete_get() { .max_committed_epoch(); let epoch1 = initial_epoch.next_epoch(); + test_env + .storage + .start_epoch(epoch1, HashSet::from_iter([TEST_TABLE_ID])); hummock_storage.init_for_test(epoch1).await.unwrap(); let batch1 = vec![ @@ -833,6 +855,9 @@ async fn test_delete_get() { .unwrap(); let epoch2 = epoch1.next_epoch(); + test_env + .storage + .start_epoch(epoch2, HashSet::from_iter([TEST_TABLE_ID])); hummock_storage.seal_current_epoch(epoch2, SealCurrentEpochOptions::for_test()); let res = test_env.storage.seal_and_sync_epoch(epoch1).await.unwrap(); test_env @@ -896,6 +921,9 @@ async fn test_multiple_epoch_sync() { .max_committed_epoch(); let epoch1 = initial_epoch.next_epoch(); + test_env + .storage + .start_epoch(epoch1, HashSet::from_iter([TEST_TABLE_ID])); hummock_storage.init_for_test(epoch1).await.unwrap(); let batch1 = vec![ ( @@ -919,6 +947,9 @@ async fn test_multiple_epoch_sync() { .unwrap(); let epoch2 = epoch1.next_epoch(); + test_env + .storage + .start_epoch(epoch2, HashSet::from_iter([TEST_TABLE_ID])); hummock_storage.seal_current_epoch(epoch2, SealCurrentEpochOptions::for_test()); let batch2 = vec![( gen_key_from_str(VirtualNode::ZERO, "bb"), @@ -936,6 +967,9 @@ async fn test_multiple_epoch_sync() { .unwrap(); let epoch3 = epoch2.next_epoch(); + test_env + .storage + .start_epoch(epoch3, HashSet::from_iter([TEST_TABLE_ID])); hummock_storage.seal_current_epoch(epoch3, SealCurrentEpochOptions::for_test()); let batch3 = vec![ ( @@ -1011,6 +1045,9 @@ async fn test_multiple_epoch_sync() { test_get().await; let epoch4 = epoch3.next_epoch(); + test_env + .storage + .start_epoch(epoch4, HashSet::from_iter([TEST_TABLE_ID])); hummock_storage.seal_current_epoch(epoch4, SealCurrentEpochOptions::for_test()); test_env.storage.seal_epoch(epoch1, false); let sync_result2 = test_env.storage.seal_and_sync_epoch(epoch2).await.unwrap(); @@ -1043,6 +1080,9 @@ async fn test_iter_with_min_epoch() { .await; let epoch1 = (31 * 1000) << 16; + test_env + .storage + .start_epoch(epoch1, HashSet::from_iter([TEST_TABLE_ID])); let gen_key = |index: usize| -> TableKey { gen_key_from_str(VirtualNode::ZERO, format!("\0\0key_{}", index).as_str()) @@ -1069,6 +1109,9 @@ async fn test_iter_with_min_epoch() { .unwrap(); let epoch2 = (32 * 1000) << 16; + test_env + .storage 
+ .start_epoch(epoch2, HashSet::from_iter([TEST_TABLE_ID])); hummock_storage.seal_current_epoch(epoch2, SealCurrentEpochOptions::for_test()); // epoch 2 write let batch_epoch2: Vec<(TableKey, StorageValue)> = (20..30) @@ -1087,6 +1130,9 @@ async fn test_iter_with_min_epoch() { .unwrap(); let epoch3 = (33 * 1000) << 16; + test_env + .storage + .start_epoch(epoch3, HashSet::from_iter([TEST_TABLE_ID])); hummock_storage.seal_current_epoch(epoch3, SealCurrentEpochOptions::for_test()); { @@ -1279,6 +1325,9 @@ async fn test_hummock_version_reader() { let hummock_version_reader = test_env.storage.version_reader(); let epoch1 = (31 * 1000) << 16; + test_env + .storage + .start_epoch(epoch1, HashSet::from_iter([TEST_TABLE_ID])); let gen_key = |index: usize| -> TableKey { gen_key_from_str(VirtualNode::ZERO, format!("\0\0key_{}", index).as_str()) @@ -1292,12 +1341,18 @@ async fn test_hummock_version_reader() { .collect(); let epoch2 = (32 * 1000) << 16; + test_env + .storage + .start_epoch(epoch2, HashSet::from_iter([TEST_TABLE_ID])); // epoch 2 write let batch_epoch2: Vec<(TableKey, StorageValue)> = (20..30) .map(|index| (gen_key(index), StorageValue::new_put(gen_val(index)))) .collect(); let epoch3 = (33 * 1000) << 16; + test_env + .storage + .start_epoch(epoch3, HashSet::from_iter([TEST_TABLE_ID])); // epoch 3 write let batch_epoch3: Vec<(TableKey, StorageValue)> = (40..50) .map(|index| (gen_key(index), StorageValue::new_put(gen_val(index)))) @@ -1340,6 +1395,9 @@ async fn test_hummock_version_reader() { .unwrap(); let epoch4 = (34 * 1000) << 16; + test_env + .storage + .start_epoch(epoch4, HashSet::from_iter([TEST_TABLE_ID])); hummock_storage.seal_current_epoch(epoch4, SealCurrentEpochOptions::for_test()); { @@ -1710,6 +1768,9 @@ async fn test_get_with_min_epoch() { .await; let epoch1 = (31 * 1000) << 16; + test_env + .storage + .start_epoch(epoch1, HashSet::from_iter([TEST_TABLE_ID])); hummock_storage.init_for_test(epoch1).await.unwrap(); let gen_key = |index: usize| -> TableKey { @@ -1735,6 +1796,9 @@ async fn test_get_with_min_epoch() { .unwrap(); let epoch2 = (32 * 1000) << 16; + test_env + .storage + .start_epoch(epoch2, HashSet::from_iter([TEST_TABLE_ID])); hummock_storage.seal_current_epoch(epoch2, SealCurrentEpochOptions::for_test()); // epoch 2 write let batch_epoch2: Vec<(TableKey, StorageValue)> = (20..30) @@ -1969,6 +2033,9 @@ async fn test_table_watermark() { }); let epoch1 = (31 * 1000) << 16; + test_env + .storage + .start_epoch(epoch1, HashSet::from_iter([TEST_TABLE_ID])); local1.init_for_test(epoch1).await.unwrap(); local1.update_vnode_bitmap(vnode_bitmap1.clone()); local2.init_for_test(epoch1).await.unwrap(); @@ -2057,6 +2124,9 @@ async fn test_table_watermark() { let watermark1 = 10; let epoch2 = (32 * 1000) << 16; + test_env + .storage + .start_epoch(epoch2, HashSet::from_iter([TEST_TABLE_ID])); for (local, vnode_bitmap) in [ (&mut local1, vnode_bitmap1.clone()), (&mut local2, vnode_bitmap2.clone()), @@ -2159,6 +2229,9 @@ async fn test_table_watermark() { let batch2_epoch2 = gen_batch(vnode2, epoch2_indexes()); let epoch3 = (33 * 1000) << 16; + test_env + .storage + .start_epoch(epoch3, HashSet::from_iter([TEST_TABLE_ID])); for (local, batch) in [(&mut local1, batch1_epoch2), (&mut local2, batch2_epoch2)] { for (key, value) in batch { @@ -2372,6 +2445,9 @@ async fn test_table_watermark() { let (mut local1, mut local2) = test_after_epoch2(local1, local2).await; let epoch4 = (34 * 1000) << 16; + test_env + .storage + .start_epoch(epoch4, HashSet::from_iter([TEST_TABLE_ID])); for 
(local, vnode_bitmap) in [ (&mut local1, vnode_bitmap1.clone()), diff --git a/src/storage/hummock_test/src/snapshot_tests.rs b/src/storage/hummock_test/src/snapshot_tests.rs index 402952dd0968..bde3c046ed6c 100644 --- a/src/storage/hummock_test/src/snapshot_tests.rs +++ b/src/storage/hummock_test/src/snapshot_tests.rs @@ -1,3 +1,4 @@ +use std::collections::HashSet; // Copyright 2024 RisingWave Labs // // Licensed under the Apache License, Version 2.0 (the "License"); @@ -113,6 +114,7 @@ async fn test_snapshot_inner( .await; let epoch1 = test_epoch(1); + hummock_storage.start_epoch(epoch1, HashSet::from_iter([Default::default()])); local.init_for_test(epoch1).await.unwrap(); local .ingest_batch( @@ -134,6 +136,7 @@ async fn test_snapshot_inner( .await .unwrap(); let epoch2 = epoch1.next_epoch(); + hummock_storage.start_epoch(epoch2, HashSet::from_iter([Default::default()])); local.seal_current_epoch(epoch2, SealCurrentEpochOptions::for_test()); if enable_sync { let res = hummock_storage.seal_and_sync_epoch(epoch1).await.unwrap(); @@ -174,6 +177,7 @@ async fn test_snapshot_inner( .await .unwrap(); let epoch3 = epoch2.next_epoch(); + hummock_storage.start_epoch(epoch3, HashSet::from_iter([Default::default()])); local.seal_current_epoch(epoch3, SealCurrentEpochOptions::for_test()); if enable_sync { let res = hummock_storage.seal_and_sync_epoch(epoch2).await.unwrap(); @@ -243,6 +247,7 @@ async fn test_snapshot_range_scan_inner( let mut local = hummock_storage .new_local(NewLocalOptions::for_test(Default::default())) .await; + hummock_storage.start_epoch(epoch, HashSet::from_iter([Default::default()])); local.init_for_test(epoch).await.unwrap(); local diff --git a/src/storage/hummock_test/src/state_store_tests.rs b/src/storage/hummock_test/src/state_store_tests.rs index 2ed1f4359aaa..5c13d73f07ec 100644 --- a/src/storage/hummock_test/src/state_store_tests.rs +++ b/src/storage/hummock_test/src/state_store_tests.rs @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +use std::collections::HashSet; use std::ops::Bound; use std::ops::Bound::{Excluded, Included, Unbounded}; use std::sync::Arc; @@ -25,7 +26,7 @@ use risingwave_common::bitmap::Bitmap; use risingwave_common::catalog::{TableId, TableOption}; use risingwave_common::hash::table_distribution::TableDistribution; use risingwave_common::hash::VirtualNode; -use risingwave_common::util::epoch::{test_epoch, EpochExt}; +use risingwave_common::util::epoch::{test_epoch, EpochExt, MAX_EPOCH}; use risingwave_hummock_sdk::key::{prefixed_range_with_vnode, TableKeyRange}; use risingwave_hummock_sdk::{ HummockReadEpoch, HummockSstableObjectId, LocalSstableInfo, SyncResult, @@ -133,6 +134,7 @@ async fn test_basic_v2() { // epoch 0 is reserved by storage service let epoch1 = test_epoch(1); + hummock_storage.start_epoch(epoch1, HashSet::from_iter([Default::default()])); local.init_for_test(epoch1).await.unwrap(); // try to write an empty batch, and hummock should write nothing @@ -162,6 +164,7 @@ async fn test_basic_v2() { .unwrap(); let epoch2 = epoch1.next_epoch(); + hummock_storage.start_epoch(epoch2, HashSet::from_iter([Default::default()])); local.seal_current_epoch(epoch2, SealCurrentEpochOptions::for_test()); // Get the value after flushing to remote. 
@@ -219,6 +222,7 @@ async fn test_basic_v2() { .unwrap(); let epoch3 = epoch2.next_epoch(); + hummock_storage.start_epoch(epoch3, HashSet::from_iter([Default::default()])); local.seal_current_epoch(epoch3, SealCurrentEpochOptions::for_test()); // Get the value after flushing to remote. @@ -432,6 +436,7 @@ async fn test_state_store_sync_v2() { let mut local = hummock_storage .new_local(NewLocalOptions::for_test(Default::default())) .await; + hummock_storage.start_epoch(epoch, HashSet::from_iter([Default::default()])); local.init_for_test(epoch).await.unwrap(); local .ingest_batch( @@ -481,6 +486,7 @@ async fn test_state_store_sync_v2() { // ); epoch.inc_epoch(); + hummock_storage.start_epoch(epoch, HashSet::from_iter([Default::default()])); local.seal_current_epoch(epoch, SealCurrentEpochOptions::for_test()); // ingest more 8B then will trigger a sync behind the scene @@ -1022,6 +1028,7 @@ async fn test_delete_get_v2() { let initial_epoch = hummock_storage.get_pinned_version().max_committed_epoch(); let epoch1 = initial_epoch.next_epoch(); + hummock_storage.start_epoch(epoch1, HashSet::from_iter([Default::default()])); let batch1 = vec![ ( gen_key_from_str(VirtualNode::ZERO, "aa"), @@ -1047,6 +1054,7 @@ async fn test_delete_get_v2() { .await .unwrap(); let epoch2 = epoch1.next_epoch(); + hummock_storage.start_epoch(epoch2, HashSet::from_iter([Default::default()])); local.seal_current_epoch(epoch2, SealCurrentEpochOptions::for_test()); let res = hummock_storage.seal_and_sync_epoch(epoch1).await.unwrap(); @@ -1107,6 +1115,7 @@ async fn test_multiple_epoch_sync_v2() { let mut local = hummock_storage .new_local(NewLocalOptions::for_test(TableId::default())) .await; + hummock_storage.start_epoch(epoch1, HashSet::from_iter([Default::default()])); local.init_for_test(epoch1).await.unwrap(); local .ingest_batch( @@ -1120,6 +1129,7 @@ async fn test_multiple_epoch_sync_v2() { .unwrap(); let epoch2 = epoch1.next_epoch(); + hummock_storage.start_epoch(epoch2, HashSet::from_iter([Default::default()])); local.seal_current_epoch(epoch2, SealCurrentEpochOptions::for_test()); let batch2 = vec![( gen_key_from_str(VirtualNode::ZERO, "bb"), @@ -1137,6 +1147,7 @@ async fn test_multiple_epoch_sync_v2() { .unwrap(); let epoch3 = epoch2.next_epoch(); + hummock_storage.start_epoch(epoch3, HashSet::from_iter([Default::default()])); let batch3 = vec![ ( gen_key_from_str(VirtualNode::ZERO, "aa"), @@ -1245,6 +1256,7 @@ async fn test_gc_watermark_and_clear_shared_buffer() { let initial_epoch = hummock_storage.get_pinned_version().max_committed_epoch(); let epoch1 = initial_epoch.next_epoch(); + hummock_storage.start_epoch(epoch1, HashSet::from_iter([Default::default()])); local_hummock_storage.init_for_test(epoch1).await.unwrap(); local_hummock_storage .insert( @@ -1270,6 +1282,7 @@ async fn test_gc_watermark_and_clear_shared_buffer() { ); let epoch2 = epoch1.next_epoch(); + hummock_storage.start_epoch(epoch2, HashSet::from_iter([Default::default()])); local_hummock_storage.seal_current_epoch(epoch2, SealCurrentEpochOptions::for_test()); local_hummock_storage .delete( @@ -1536,6 +1549,10 @@ async fn test_iter_log() { let key_count = 10000; let test_log_data = gen_test_data(epoch_count, key_count, 0.05, 0.2); + for (epoch, _) in &test_log_data { + hummock_storage.start_epoch(*epoch, HashSet::from_iter([table_id])); + } + hummock_storage.start_epoch(MAX_EPOCH, HashSet::from_iter([table_id])); let in_memory_state_store = MemoryStateStore::new(); let mut in_memory_local = in_memory_state_store diff --git 
a/src/storage/src/hummock/event_handler/hummock_event_handler.rs b/src/storage/src/hummock/event_handler/hummock_event_handler.rs index b126974c7c08..f4038f0d7d52 100644 --- a/src/storage/src/hummock/event_handler/hummock_event_handler.rs +++ b/src/storage/src/hummock/event_handler/hummock_event_handler.rs @@ -737,6 +737,9 @@ impl HummockEventHandler { HummockEvent::Shutdown => { unreachable!("shutdown is handled specially") } + HummockEvent::StartEpoch { epoch, table_ids } => { + self.uploader.start_epoch(epoch, table_ids); + } HummockEvent::InitEpoch { instance_id, init_epoch, @@ -1146,6 +1149,11 @@ mod tests { rx.await.unwrap() }; + send_event(HummockEvent::StartEpoch { + epoch: epoch1, + table_ids: HashSet::from_iter([TEST_TABLE_ID]), + }); + send_event(HummockEvent::InitEpoch { instance_id: guard.instance_id, init_epoch: epoch1, @@ -1161,6 +1169,11 @@ mod tests { imm: imm1, }); + send_event(HummockEvent::StartEpoch { + epoch: epoch2, + table_ids: HashSet::from_iter([TEST_TABLE_ID]), + }); + send_event(HummockEvent::LocalSealEpoch { instance_id: guard.instance_id, next_epoch: epoch2, @@ -1178,6 +1191,10 @@ mod tests { }); let epoch3 = epoch2.next_epoch(); + send_event(HummockEvent::StartEpoch { + epoch: epoch3, + table_ids: HashSet::from_iter([TEST_TABLE_ID]), + }); send_event(HummockEvent::LocalSealEpoch { instance_id: guard.instance_id, next_epoch: epoch3, diff --git a/src/storage/src/hummock/event_handler/mod.rs b/src/storage/src/hummock/event_handler/mod.rs index 996d5d6a6df7..74a69bfa7194 100644 --- a/src/storage/src/hummock/event_handler/mod.rs +++ b/src/storage/src/hummock/event_handler/mod.rs @@ -74,6 +74,11 @@ pub enum HummockEvent { imm: ImmutableMemtable, }, + StartEpoch { + epoch: HummockEpoch, + table_ids: HashSet, + }, + InitEpoch { instance_id: LocalInstanceId, init_epoch: HummockEpoch, @@ -117,6 +122,10 @@ impl HummockEvent { HummockEvent::Shutdown => "Shutdown".to_string(), + HummockEvent::StartEpoch { epoch, table_ids } => { + format!("StartEpoch {} {:?}", epoch, table_ids) + } + HummockEvent::InitEpoch { instance_id, init_epoch, diff --git a/src/storage/src/hummock/event_handler/uploader/mod.rs b/src/storage/src/hummock/event_handler/uploader/mod.rs index 06f0c3aff77a..f0a18aa9eca6 100644 --- a/src/storage/src/hummock/event_handler/uploader/mod.rs +++ b/src/storage/src/hummock/event_handler/uploader/mod.rs @@ -615,6 +615,11 @@ struct TableUnsyncData { BTreeMap, BitmapBuilder)>, )>, spill_tasks: BTreeMap>, + unsync_epochs: BTreeMap, + // Initialized to be `None`. Transform to `Some(_)` when called + // `local_seal_epoch` with a non-existing epoch, to mark that + // the fragment of the table has stopped. 
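A condensed sketch of the event plumbing added in this hunk, with simplified stand-ins for the handler and uploader types: StartEpoch is a fire-and-forget notification that reaches the uploader before any InitEpoch or LocalSealEpoch referring to that epoch.

use std::collections::HashSet;

type Epoch = u64;
type TableId = u32;

enum Event {
    StartEpoch { epoch: Epoch, table_ids: HashSet<TableId> },
    InitEpoch { instance_id: u64, init_epoch: Epoch },
    LocalSealEpoch { instance_id: u64, next_epoch: Epoch },
}

#[derive(Default)]
struct UploaderSketch {
    started: Vec<(Epoch, HashSet<TableId>)>,
}

impl UploaderSketch {
    fn start_epoch(&mut self, epoch: Epoch, table_ids: HashSet<TableId>) {
        self.started.push((epoch, table_ids));
    }
}

struct HandlerSketch {
    uploader: UploaderSketch,
}

impl HandlerSketch {
    fn handle(&mut self, event: Event) {
        match event {
            // The only new branch: declare the epoch to the uploader.
            Event::StartEpoch { epoch, table_ids } => self.uploader.start_epoch(epoch, table_ids),
            Event::InitEpoch { .. } | Event::LocalSealEpoch { .. } => {
                // Existing handling elided in this sketch.
            }
        }
    }
}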
+ stopped_next_epoch: Option, // newer epoch at the front syncing_epochs: VecDeque, max_synced_epoch: Option, @@ -627,11 +632,21 @@ impl TableUnsyncData { instance_data: Default::default(), table_watermarks: None, spill_tasks: Default::default(), + unsync_epochs: Default::default(), + stopped_next_epoch: None, syncing_epochs: Default::default(), max_synced_epoch: committed_epoch, } } + fn new_epoch(&mut self, epoch: HummockEpoch) { + debug!(table_id = ?self.table_id, epoch, "table new epoch"); + if let Some(latest_epoch) = self.max_epoch() { + assert_gt!(epoch, latest_epoch); + } + self.unsync_epochs.insert(epoch, ()); + } + fn sync( &mut self, epoch: HummockEpoch, @@ -646,6 +661,13 @@ impl TableUnsyncData { if let Some(prev_epoch) = self.max_sync_epoch() { assert_gt!(epoch, prev_epoch) } + let epochs = take_before_epoch(&mut self.unsync_epochs, epoch); + assert_eq!( + *epochs.last_key_value().expect("non-empty").0, + epoch, + "{epochs:?} {epoch} {:?}", + self.table_id + ); self.syncing_epochs.push_front(epoch); ( self.instance_data @@ -711,8 +733,17 @@ impl TableUnsyncData { .or(self.max_synced_epoch) } + fn max_epoch(&self) -> Option { + self.unsync_epochs + .first_key_value() + .map(|(epoch, _)| *epoch) + .or_else(|| self.max_sync_epoch()) + } + fn is_empty(&self) -> bool { - self.instance_data.is_empty() && self.syncing_epochs.is_empty() + self.instance_data.is_empty() + && self.syncing_epochs.is_empty() + && self.unsync_epochs.is_empty() } } @@ -727,7 +758,7 @@ struct UnsyncData { instance_table_id: HashMap, // TODO: this is only used in spill to get existing epochs and can be removed // when we support spill not based on epoch - epochs: BTreeMap, + epochs: BTreeMap>, } impl UnsyncData { @@ -736,27 +767,15 @@ impl UnsyncData { table_id: TableId, instance_id: LocalInstanceId, init_epoch: HummockEpoch, - context: &UploaderContext, ) { debug!( table_id = table_id.table_id, instance_id, init_epoch, "init epoch" ); - let table_data = self.table_data.entry(table_id).or_insert_with(|| { - TableUnsyncData::new( - table_id, - context - .pinned_version - .version() - .state_table_info - .info() - .get(&table_id) - .map(|info| info.committed_epoch), - ) - }); - if let Some(max_prev_epoch) = table_data.max_sync_epoch() { - assert_gt!(init_epoch, max_prev_epoch); - } + let table_data = self + .table_data + .get_mut(&table_id) + .unwrap_or_else(|| panic!("should exist. {table_id:?}")); assert!(table_data .instance_data .insert( @@ -768,7 +787,7 @@ impl UnsyncData { .instance_table_id .insert(instance_id, table_id) .is_none()); - self.epochs.insert(init_epoch, ()); + assert!(table_data.unsync_epochs.contains_key(&init_epoch)); } fn instance_data( @@ -807,7 +826,20 @@ impl UnsyncData { .get_mut(&instance_id) .expect("should exist"); let epoch = instance_data.local_seal_epoch(next_epoch); - self.epochs.insert(next_epoch, ()); + // When drop/cancel a streaming job, for the barrier to stop actor, the + // local instance will call `local_seal_epoch`, but the `next_epoch` won't be + // called `start_epoch` because we have stopped writing on it. 
+ if !table_data.unsync_epochs.contains_key(&next_epoch) { + if let Some(stopped_next_epoch) = table_data.stopped_next_epoch { + assert_eq!(stopped_next_epoch, next_epoch); + } else { + if let Some(max_epoch) = table_data.max_epoch() { + assert_gt!(next_epoch, max_epoch); + } + debug!(?table_id, epoch, next_epoch, "table data has stopped"); + table_data.stopped_next_epoch = Some(next_epoch); + } + } if let Some((direction, table_watermarks)) = opts.table_watermarks { table_data.add_table_watermarks(epoch, table_watermarks, direction); } @@ -838,20 +870,29 @@ impl UploaderData { sync_result_sender: oneshot::Sender>, ) { // clean old epochs - let _epochs = take_before_epoch(&mut self.unsync_data.epochs, epoch); + let epochs = take_before_epoch(&mut self.unsync_data.epochs, epoch); + if cfg!(debug_assertions) { + for epoch_table_ids in epochs.into_values() { + assert_eq!(epoch_table_ids, table_ids); + } + } let mut all_table_watermarks = HashMap::new(); let mut uploading_tasks = HashSet::new(); let mut spilled_tasks = BTreeSet::new(); let mut flush_payload = HashMap::new(); - let mut table_ids_to_ack = HashSet::new(); - for (table_id, table_data) in &mut self.unsync_data.table_data { + for (table_id, table_data) in &self.unsync_data.table_data { if !table_ids.contains(table_id) { table_data.assert_after_epoch(epoch); - continue; } - table_ids_to_ack.insert(*table_id); + } + for table_id in &table_ids { + let table_data = self + .unsync_data + .table_data + .get_mut(table_id) + .expect("should exist"); let (unflushed_payload, table_watermarks, task_ids) = table_data.sync(epoch); for (instance_id, payload) in unflushed_payload { if !payload.is_empty() { @@ -898,10 +939,7 @@ impl UploaderData { .map(|task_id| { let (sst, spill_table_ids) = self.spilled_data.remove(task_id).expect("should exist"); - assert!( - spill_table_ids.is_subset(&table_ids), - "spill_table_ids: {spill_table_ids:?}, table_ids: {table_ids:?}" - ); + assert_eq!(spill_table_ids, table_ids); sst }) .collect(); @@ -911,7 +949,6 @@ impl UploaderData { SyncingData { sync_epoch: epoch, table_ids, - table_ids_to_ack, remaining_uploading_tasks: uploading_tasks, uploaded, table_watermarks: all_table_watermarks, @@ -937,8 +974,6 @@ impl UnsyncData { struct SyncingData { sync_epoch: HummockEpoch, table_ids: HashSet, - /// Subset of `table_ids` that has existing instance - table_ids_to_ack: HashSet, remaining_uploading_tasks: HashSet, // newer data at the front uploaded: VecDeque>, @@ -1079,7 +1114,7 @@ impl HummockUploader { return; }; data.unsync_data - .init_instance(table_id, instance_id, init_epoch, &self.context); + .init_instance(table_id, instance_id, init_epoch); } pub(super) fn local_seal_epoch( @@ -1095,6 +1130,32 @@ impl HummockUploader { .local_seal_epoch(instance_id, next_epoch, opts); } + pub(super) fn start_epoch(&mut self, epoch: HummockEpoch, table_ids: HashSet) { + let UploaderState::Working(data) = &mut self.state else { + return; + }; + for table_id in &table_ids { + let table_data = data + .unsync_data + .table_data + .entry(*table_id) + .or_insert_with(|| { + TableUnsyncData::new( + *table_id, + self.context + .pinned_version + .version() + .state_table_info + .info() + .get(table_id) + .map(|info| info.committed_epoch), + ) + }); + table_data.new_epoch(epoch); + } + data.unsync_data.epochs.insert(epoch, table_ids); + } + pub(super) fn start_sync_epoch( &mut self, epoch: HummockEpoch, @@ -1150,7 +1211,7 @@ impl HummockUploader { if self.context.buffer_tracker.need_flush() { let mut curr_batch_flush_size = 0; // 
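A compact sketch of the stopped_next_epoch bookkeeping above. The struct is a simplified stand-in, not the real TableUnsyncData: sealing into an epoch that was never declared via start_epoch marks the table's fragment as stopped, and the same undeclared epoch is the only one accepted on later seals.

use std::collections::BTreeMap;

type Epoch = u64;

#[derive(Default)]
struct TableEpochsSketch {
    unsync_epochs: BTreeMap<Epoch, ()>,
    stopped_next_epoch: Option<Epoch>,
}

impl TableEpochsSketch {
    fn new_epoch(&mut self, epoch: Epoch) {
        if let Some((&latest, _)) = self.unsync_epochs.last_key_value() {
            assert!(epoch > latest, "epochs must be declared in increasing order");
        }
        self.unsync_epochs.insert(epoch, ());
    }

    fn local_seal_epoch(&mut self, next_epoch: Epoch) {
        if !self.unsync_epochs.contains_key(&next_epoch) {
            // The fragment stopped writing: `next_epoch` was never declared via `start_epoch`.
            match self.stopped_next_epoch {
                Some(stopped) => assert_eq!(stopped, next_epoch),
                None => {
                    if let Some((&max_epoch, _)) = self.unsync_epochs.last_key_value() {
                        assert!(next_epoch > max_epoch);
                    }
                    self.stopped_next_epoch = Some(next_epoch);
                }
            }
        }
    }
}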
iterate from older epoch to newer epoch - for epoch in &mut data.unsync_data.epochs.keys() { + for (epoch, table_ids) in &data.unsync_data.epochs { if !self .context .buffer_tracker @@ -1160,7 +1221,12 @@ impl HummockUploader { } let mut spilled_table_ids = HashSet::new(); let mut payload = HashMap::new(); - for (table_id, table_data) in &mut data.unsync_data.table_data { + for table_id in table_ids { + let table_data = data + .unsync_data + .table_data + .get_mut(table_id) + .expect("should exist"); for (instance_id, instance_data) in &mut table_data.instance_data { let instance_payload = instance_data.spill(*epoch); if !instance_payload.is_empty() { @@ -1240,8 +1306,7 @@ impl UploaderData { let (_, syncing_data) = self.syncing_data.pop_first().expect("non-empty"); let SyncingData { sync_epoch, - table_ids: _table_ids, - table_ids_to_ack, + table_ids, remaining_uploading_tasks: _, uploaded, table_watermarks, @@ -1252,7 +1317,7 @@ impl UploaderData { .uploader_syncing_epoch_count .set(self.syncing_data.len() as _); - for table_id in table_ids_to_ack { + for table_id in table_ids { if let Some(table_data) = self.unsync_data.table_data.get_mut(&table_id) { table_data.ack_synced(sync_epoch); if table_data.is_empty() { @@ -1632,6 +1697,18 @@ pub(crate) mod tests { SealCurrentEpochOptions::for_test(), ); } + + fn start_epochs_for_test(&mut self, epochs: impl IntoIterator) { + let mut last_epoch = None; + for epoch in epochs { + last_epoch = Some(epoch); + self.start_epoch(epoch, HashSet::from_iter([TEST_TABLE_ID])); + } + self.start_epoch( + last_epoch.unwrap().next_epoch(), + HashSet::from_iter([TEST_TABLE_ID]), + ); + } } #[tokio::test] @@ -1709,6 +1786,7 @@ pub(crate) mod tests { async fn test_uploader_basic() { let mut uploader = test_uploader(dummy_success_upload_future); let epoch1 = INITIAL_EPOCH.next_epoch(); + uploader.start_epochs_for_test([epoch1]); let imm = gen_imm(epoch1).await; uploader.init_instance(TEST_LOCAL_INSTANCE_ID, TEST_TABLE_ID, epoch1); uploader.add_imm(TEST_LOCAL_INSTANCE_ID, imm.clone()); @@ -1771,6 +1849,7 @@ pub(crate) mod tests { let epoch1 = INITIAL_EPOCH.next_epoch(); let (sync_tx, sync_rx) = oneshot::channel(); + uploader.start_epochs_for_test([epoch1]); uploader.init_instance(TEST_LOCAL_INSTANCE_ID, TEST_TABLE_ID, epoch1); uploader.local_seal_epoch_for_test(TEST_LOCAL_INSTANCE_ID, epoch1); uploader.start_sync_epoch(epoch1, sync_tx, HashSet::from_iter([TEST_TABLE_ID])); @@ -1799,6 +1878,7 @@ pub(crate) mod tests { let mut uploader = test_uploader(dummy_success_upload_future); let epoch1 = INITIAL_EPOCH.next_epoch(); let epoch2 = epoch1.next_epoch(); + uploader.start_epochs_for_test([epoch1, epoch2]); let imm = gen_imm(epoch2).await; // epoch1 is empty while epoch2 is not. Going to seal empty epoch1. 
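A simplified view of the spill loop above (stand-in types; the real code drains per-instance immutable memtables into upload tasks): epochs are visited from oldest to newest, and for each epoch only the tables registered at start_epoch time are touched, until enough bytes have been flushed.

use std::collections::{BTreeMap, HashMap, HashSet};

type Epoch = u64;
type TableId = u32;

struct SpillSketch {
    // epoch -> tables declared for that epoch
    epochs: BTreeMap<Epoch, HashSet<TableId>>,
    // table -> pending bytes per epoch (a stand-in for per-instance imms)
    pending: HashMap<TableId, BTreeMap<Epoch, usize>>,
}

impl SpillSketch {
    /// Spill oldest epochs first until at least `target` bytes have been flushed.
    fn spill_until(&mut self, target: usize) -> usize {
        let mut flushed = 0;
        for (epoch, table_ids) in &self.epochs {
            if flushed >= target {
                break;
            }
            for table_id in table_ids {
                if let Some(by_epoch) = self.pending.get_mut(table_id) {
                    if let Some(bytes) = by_epoch.remove(epoch) {
                        // In the real uploader this payload becomes a spill/upload task.
                        flushed += bytes;
                    }
                }
            }
        }
        flushed
    }
}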
uploader.init_instance(TEST_LOCAL_INSTANCE_ID, TEST_TABLE_ID, epoch1); @@ -1851,6 +1931,7 @@ pub(crate) mod tests { let version4 = initial_pinned_version.new_pin_version(test_hummock_version(epoch4)); let version5 = initial_pinned_version.new_pin_version(test_hummock_version(epoch5)); + uploader.start_epochs_for_test([epoch6]); uploader.init_instance(TEST_LOCAL_INSTANCE_ID, TEST_TABLE_ID, epoch6); uploader.update_pinned_version(version1); @@ -1980,6 +2061,9 @@ pub(crate) mod tests { let epoch1 = INITIAL_EPOCH.next_epoch(); let epoch2 = epoch1.next_epoch(); + let epoch3 = epoch2.next_epoch(); + let epoch4 = epoch3.next_epoch(); + uploader.start_epochs_for_test([epoch1, epoch2, epoch3, epoch4]); let memory_limiter = buffer_tracker.get_memory_limiter().clone(); let memory_limiter = Some(memory_limiter.deref()); @@ -2039,7 +2123,6 @@ pub(crate) mod tests { let (sync_tx1, mut sync_rx1) = oneshot::channel(); uploader.start_sync_epoch(epoch1, sync_tx1, HashSet::from_iter([TEST_TABLE_ID])); await_start1_4.await; - let epoch3 = epoch2.next_epoch(); uploader.local_seal_epoch_for_test(instance_id1, epoch2); uploader.local_seal_epoch_for_test(instance_id2, epoch2); @@ -2071,7 +2154,6 @@ pub(crate) mod tests { // sealed: uploaded sst([imm2]) // syncing: epoch1: uploading: [imm1_4], [imm1_3], uploaded: sst([imm1_2, imm1_1]) - let epoch4 = epoch3.next_epoch(); uploader.local_seal_epoch_for_test(instance_id1, epoch3); let imm4 = gen_imm_with_limiter(epoch4, memory_limiter).await; uploader.add_imm(instance_id1, imm4.clone()); @@ -2216,6 +2298,7 @@ pub(crate) mod tests { let epoch1 = INITIAL_EPOCH.next_epoch(); let epoch2 = epoch1.next_epoch(); + uploader.start_epochs_for_test([epoch1, epoch2]); let instance_id1 = 1; let instance_id2 = 2; let flush_threshold = buffer_tracker.flush_threshold(); diff --git a/src/storage/src/hummock/store/hummock_storage.rs b/src/storage/src/hummock/store/hummock_storage.rs index fcc80a1e54e1..511d9dd33814 100644 --- a/src/storage/src/hummock/store/hummock_storage.rs +++ b/src/storage/src/hummock/store/hummock_storage.rs @@ -499,6 +499,15 @@ impl HummockStorage { ) } + /// Declare the start of an epoch. This information is provided for spill so that the spill task won't + /// include data of two or more syncs. 
+ // TODO: remove this method when we support spill task that can include data of more two or more syncs + pub fn start_epoch(&self, epoch: HummockEpoch, table_ids: HashSet) { + let _ = self + .hummock_event_sender + .send(HummockEvent::StartEpoch { epoch, table_ids }); + } + pub fn sstable_store(&self) -> SstableStoreRef { self.context.sstable_store.clone() } diff --git a/src/stream/src/common/log_store_impl/kv_log_store/mod.rs b/src/stream/src/common/log_store_impl/kv_log_store/mod.rs index 9f9d2e3abe37..05d8ae73ba99 100644 --- a/src/stream/src/common/log_store_impl/kv_log_store/mod.rs +++ b/src/stream/src/common/log_store_impl/kv_log_store/mod.rs @@ -424,6 +424,7 @@ impl LogStoreFactory for KvLogStoreFactory { #[cfg(test)] mod tests { + use std::collections::HashSet; use std::future::{poll_fn, Future}; use std::iter::empty; use std::pin::pin; @@ -433,6 +434,7 @@ mod tests { use itertools::Itertools; use risingwave_common::array::StreamChunk; use risingwave_common::bitmap::{Bitmap, BitmapBuilder}; + use risingwave_common::catalog::TableId; use risingwave_common::hash::VirtualNode; use risingwave_common::util::epoch::{EpochExt, EpochPair}; use risingwave_connector::sink::log_store::{ @@ -495,12 +497,18 @@ mod tests { .version() .max_committed_epoch .next_epoch(); + test_env + .storage + .start_epoch(epoch1, HashSet::from_iter([TableId::new(table.id)])); writer .init(EpochPair::new_test_epoch(epoch1), false) .await .unwrap(); writer.write_chunk(stream_chunk1.clone()).await.unwrap(); let epoch2 = epoch1.next_epoch(); + test_env + .storage + .start_epoch(epoch2, HashSet::from_iter([TableId::new(table.id)])); writer.flush_current_epoch(epoch2, false).await.unwrap(); writer.write_chunk(stream_chunk2.clone()).await.unwrap(); let epoch3 = epoch2.next_epoch(); @@ -596,12 +604,18 @@ mod tests { .version() .max_committed_epoch .next_epoch(); + test_env + .storage + .start_epoch(epoch1, HashSet::from_iter([TableId::new(table.id)])); writer .init(EpochPair::new_test_epoch(epoch1), false) .await .unwrap(); writer.write_chunk(stream_chunk1.clone()).await.unwrap(); let epoch2 = epoch1.next_epoch(); + test_env + .storage + .start_epoch(epoch2, HashSet::from_iter([TableId::new(table.id)])); writer.flush_current_epoch(epoch2, false).await.unwrap(); writer.write_chunk(stream_chunk2.clone()).await.unwrap(); let epoch3 = epoch2.next_epoch(); @@ -678,6 +692,9 @@ mod tests { pk_info, ); let (mut reader, mut writer) = factory.build().await; + test_env + .storage + .start_epoch(epoch3, HashSet::from_iter([TableId::new(table.id)])); writer .init(EpochPair::new_test_epoch(epoch3), false) .await @@ -776,6 +793,9 @@ mod tests { .version() .max_committed_epoch .next_epoch(); + test_env + .storage + .start_epoch(epoch1, HashSet::from_iter([TableId::new(table.id)])); writer .init(EpochPair::new_test_epoch(epoch1), false) .await @@ -783,6 +803,9 @@ mod tests { writer.write_chunk(stream_chunk1_1.clone()).await.unwrap(); writer.write_chunk(stream_chunk1_2.clone()).await.unwrap(); let epoch2 = epoch1.next_epoch(); + test_env + .storage + .start_epoch(epoch2, HashSet::from_iter([TableId::new(table.id)])); writer.flush_current_epoch(epoch2, true).await.unwrap(); writer.write_chunk(stream_chunk2.clone()).await.unwrap(); @@ -883,6 +906,9 @@ mod tests { ); let (mut reader, mut writer) = factory.build().await; + test_env + .storage + .start_epoch(epoch3, HashSet::from_iter([TableId::new(table.id)])); writer .init(EpochPair::new_test_epoch(epoch3), false) .await @@ -993,6 +1019,9 @@ mod tests { .version() .max_committed_epoch 
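The start_epoch entry point at the top of this hunk is just a fire-and-forget event send. A minimal stand-alone sketch of that shape, with a simplified channel and event type standing in for the real hummock event plumbing:

use std::collections::HashSet;
use std::sync::mpsc;

type Epoch = u64;
type TableId = u32;

enum StoreEvent {
    StartEpoch { epoch: Epoch, table_ids: HashSet<TableId> },
}

struct StorageHandle {
    event_sender: mpsc::Sender<StoreEvent>,
}

impl StorageHandle {
    /// Declare an epoch for a set of tables. The send result is ignored,
    /// mirroring the `let _ =` in the patch: a closed event loop is handled elsewhere.
    fn start_epoch(&self, epoch: Epoch, table_ids: HashSet<TableId>) {
        let _ = self.event_sender.send(StoreEvent::StartEpoch { epoch, table_ids });
    }
}

On the receiving side, the event handler sketched earlier simply forwards this to the uploader.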
.next_epoch(); + test_env + .storage + .start_epoch(epoch1, HashSet::from_iter([TableId::new(table.id)])); writer1 .init(EpochPair::new_test_epoch(epoch1), false) .await @@ -1007,6 +1036,9 @@ mod tests { writer1.write_chunk(chunk1_1.clone()).await.unwrap(); writer2.write_chunk(chunk1_2.clone()).await.unwrap(); let epoch2 = epoch1.next_epoch(); + test_env + .storage + .start_epoch(epoch2, HashSet::from_iter([TableId::new(table.id)])); writer1.flush_current_epoch(epoch2, false).await.unwrap(); writer2.flush_current_epoch(epoch2, false).await.unwrap(); let [chunk2_1, chunk2_2] = gen_multi_vnode_stream_chunks::<2>(200, 100, pk_info); @@ -1108,6 +1140,9 @@ mod tests { pk_info, ); let (mut reader, mut writer) = factory.build().await; + test_env + .storage + .start_epoch(epoch3, HashSet::from_iter([TableId::new(table.id)])); writer .init(EpochPair::new(epoch3, epoch2), false) .await @@ -1177,6 +1212,9 @@ mod tests { .version() .max_committed_epoch .next_epoch(); + test_env + .storage + .start_epoch(epoch1, HashSet::from_iter([TableId::new(table.id)])); writer .init(EpochPair::new_test_epoch(epoch1), false) .await @@ -1314,15 +1352,24 @@ mod tests { .version() .max_committed_epoch .next_epoch(); + test_env + .storage + .start_epoch(epoch1, HashSet::from_iter([TableId::new(table.id)])); writer .init(EpochPair::new_test_epoch(epoch1), false) .await .unwrap(); writer.write_chunk(stream_chunk1.clone()).await.unwrap(); let epoch2 = epoch1.next_epoch(); + test_env + .storage + .start_epoch(epoch2, HashSet::from_iter([TableId::new(table.id)])); writer.flush_current_epoch(epoch2, true).await.unwrap(); writer.write_chunk(stream_chunk2.clone()).await.unwrap(); let epoch3 = epoch2.next_epoch(); + test_env + .storage + .start_epoch(epoch3, HashSet::from_iter([TableId::new(table.id)])); writer.flush_current_epoch(epoch3, true).await.unwrap(); writer.write_chunk(stream_chunk3.clone()).await.unwrap(); writer.flush_current_epoch(u64::MAX, true).await.unwrap(); @@ -1411,6 +1458,9 @@ mod tests { let (mut reader, mut writer) = factory.build().await; let epoch4 = epoch3.next_epoch(); + test_env + .storage + .start_epoch(epoch4, HashSet::from_iter([TableId::new(table.id)])); writer .init(EpochPair::new(epoch4, epoch3), false) .await @@ -1471,6 +1521,9 @@ mod tests { .unwrap(); writer.write_chunk(stream_chunk4.clone()).await.unwrap(); let epoch5 = epoch4 + 1; + test_env + .storage + .start_epoch(epoch5, HashSet::from_iter([TableId::new(table.id)])); writer.flush_current_epoch(epoch5, true).await.unwrap(); writer.write_chunk(stream_chunk5.clone()).await.unwrap(); @@ -1629,12 +1682,18 @@ mod tests { .version() .max_committed_epoch .next_epoch(); + test_env + .storage + .start_epoch(epoch1, HashSet::from_iter([TableId::new(table.id)])); writer .init(EpochPair::new_test_epoch(epoch1), false) .await .unwrap(); writer.write_chunk(stream_chunk1.clone()).await.unwrap(); let epoch2 = epoch1.next_epoch(); + test_env + .storage + .start_epoch(epoch2, HashSet::from_iter([TableId::new(table.id)])); writer.flush_current_epoch(epoch2, false).await.unwrap(); writer.write_chunk(stream_chunk2.clone()).await.unwrap(); let epoch3 = epoch2.next_epoch(); @@ -1693,6 +1752,9 @@ mod tests { pk_info, ); let (mut reader, mut writer) = factory.build().await; + test_env + .storage + .start_epoch(epoch3, HashSet::from_iter([TableId::new(table.id)])); writer .init(EpochPair::new_test_epoch(epoch3), false) .await @@ -1754,6 +1816,9 @@ mod tests { pk_info, ); let (mut reader, mut writer) = factory.build().await; + test_env + .storage + 
.start_epoch(epoch4, HashSet::from_iter([TableId::new(table.id)])); writer .init(EpochPair::new_test_epoch(epoch4), false) .await diff --git a/src/stream/src/common/table/test_state_table.rs b/src/stream/src/common/table/test_state_table.rs index 89944cdfc487..72ffa72479cf 100644 --- a/src/stream/src/common/table/test_state_table.rs +++ b/src/stream/src/common/table/test_state_table.rs @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +use std::collections::HashSet; use std::ops::Bound::{self, *}; use futures::{pin_mut, StreamExt}; @@ -61,6 +62,9 @@ async fn test_state_table_update_insert() { .await; let mut epoch = EpochPair::new_test_epoch(test_epoch(1)); + test_env + .storage + .start_epoch(epoch.curr, HashSet::from_iter([TEST_TABLE_ID])); state_table.init_epoch(epoch); state_table.insert(OwnedRow::new(vec![ @@ -78,6 +82,9 @@ async fn test_state_table_update_insert() { ])); epoch.inc_for_test(); + test_env + .storage + .start_epoch(epoch.curr, HashSet::from_iter([TEST_TABLE_ID])); state_table.commit(epoch).await.unwrap(); state_table.delete(OwnedRow::new(vec![ @@ -134,6 +141,9 @@ async fn test_state_table_update_insert() { ); epoch.inc_for_test(); + test_env + .storage + .start_epoch(epoch.curr, HashSet::from_iter([TEST_TABLE_ID])); state_table.commit(epoch).await.unwrap(); let row6_commit = state_table @@ -171,6 +181,9 @@ async fn test_state_table_update_insert() { ])); epoch.inc_for_test(); + test_env + .storage + .start_epoch(epoch.curr, HashSet::from_iter([TEST_TABLE_ID])); state_table.commit(epoch).await.unwrap(); // one epoch: delete (1, 2, 3, 4), insert (5, 6, 7, None), delete(5, 6, 7, None) @@ -200,6 +213,9 @@ async fn test_state_table_update_insert() { assert_eq!(row1, None); epoch.inc_for_test(); + test_env + .storage + .start_epoch(epoch.curr, HashSet::from_iter([TEST_TABLE_ID])); state_table.commit(epoch).await.unwrap(); let row1_commit = state_table @@ -239,6 +255,9 @@ async fn test_state_table_iter_with_prefix() { .await; let mut epoch = EpochPair::new_test_epoch(test_epoch(1)); + test_env + .storage + .start_epoch(epoch.curr, HashSet::from_iter([TEST_TABLE_ID])); state_table.init_epoch(epoch); state_table.insert(OwnedRow::new(vec![ @@ -265,6 +284,9 @@ async fn test_state_table_iter_with_prefix() { ])); epoch.inc_for_test(); + test_env + .storage + .start_epoch(epoch.curr, HashSet::from_iter([TEST_TABLE_ID])); state_table.commit(epoch).await.unwrap(); state_table.insert(OwnedRow::new(vec![ @@ -368,6 +390,9 @@ async fn test_state_table_iter_with_pk_range() { .await; let mut epoch = EpochPair::new_test_epoch(test_epoch(1)); + test_env + .storage + .start_epoch(epoch.curr, HashSet::from_iter([TEST_TABLE_ID])); state_table.init_epoch(epoch); state_table.insert(OwnedRow::new(vec![ @@ -394,6 +419,9 @@ async fn test_state_table_iter_with_pk_range() { ])); epoch.inc_for_test(); + test_env + .storage + .start_epoch(epoch.curr, HashSet::from_iter([TEST_TABLE_ID])); state_table.commit(epoch).await.unwrap(); state_table.insert(OwnedRow::new(vec![ @@ -501,6 +529,9 @@ async fn test_mem_table_assertion() { StateTable::from_table_catalog(&table, test_env.storage.clone(), None).await; let epoch = EpochPair::new_test_epoch(test_epoch(1)); + test_env + .storage + .start_epoch(epoch.curr, HashSet::from_iter([TEST_TABLE_ID])); state_table.init_epoch(epoch); state_table.insert(OwnedRow::new(vec![ Some(1_i32.into()), @@ -544,6 +575,9 @@ async fn test_state_table_iter_with_value_indices() { .await; let mut epoch = 
EpochPair::new_test_epoch(test_epoch(1)); + test_env + .storage + .start_epoch(epoch.curr, HashSet::from_iter([TEST_TABLE_ID])); state_table.init_epoch(epoch); state_table.insert(OwnedRow::new(vec![ @@ -599,6 +633,9 @@ async fn test_state_table_iter_with_value_indices() { } epoch.inc_for_test(); + test_env + .storage + .start_epoch(epoch.curr, HashSet::from_iter([TEST_TABLE_ID])); state_table.commit(epoch).await.unwrap(); // write [3, 33, 333], [4, 44, 444], [5, 55, 555], [7, 77, 777], [8, 88, 888]into mem_table, @@ -711,6 +748,9 @@ async fn test_state_table_iter_with_shuffle_value_indices() { .await; let mut epoch = EpochPair::new_test_epoch(test_epoch(1)); + test_env + .storage + .start_epoch(epoch.curr, HashSet::from_iter([TEST_TABLE_ID])); state_table.init_epoch(epoch); state_table.insert(OwnedRow::new(vec![ @@ -787,6 +827,9 @@ async fn test_state_table_iter_with_shuffle_value_indices() { } epoch.inc_for_test(); + test_env + .storage + .start_epoch(epoch.curr, HashSet::from_iter([TEST_TABLE_ID])); state_table.commit(epoch).await.unwrap(); // write [3, 33, 333], [4, 44, 444], [5, 55, 555], [7, 77, 777], [8, 88, 888]into mem_table, @@ -952,6 +995,9 @@ async fn test_state_table_write_chunk() { .await; let epoch = EpochPair::new_test_epoch(test_epoch(1)); + test_env + .storage + .start_epoch(epoch.curr, HashSet::from_iter([TEST_TABLE_ID])); state_table.init_epoch(epoch); let chunk = StreamChunk::from_rows( @@ -1081,6 +1127,9 @@ async fn test_state_table_write_chunk_visibility() { .await; let epoch = EpochPair::new_test_epoch(test_epoch(1)); + test_env + .storage + .start_epoch(epoch.curr, HashSet::from_iter([TEST_TABLE_ID])); state_table.init_epoch(epoch); let chunk = StreamChunk::from_rows( @@ -1205,6 +1254,9 @@ async fn test_state_table_write_chunk_value_indices() { .await; let epoch = EpochPair::new_test_epoch(test_epoch(1)); + test_env + .storage + .start_epoch(epoch.curr, HashSet::from_iter([TEST_TABLE_ID])); state_table.init_epoch(epoch); let chunk = StreamChunk::from_rows( @@ -1299,6 +1351,9 @@ async fn test_state_table_watermark_cache_ignore_null() { WatermarkCacheStateTable::from_table_catalog(&table, test_env.storage.clone(), None).await; let mut epoch = EpochPair::new_test_epoch(test_epoch(1)); + test_env + .storage + .start_epoch(epoch.curr, HashSet::from_iter([TEST_TABLE_ID])); state_table.init_epoch(epoch); let rows = vec![ @@ -1347,6 +1402,9 @@ async fn test_state_table_watermark_cache_ignore_null() { state_table.update_watermark(watermark, true); epoch.inc_for_test(); + test_env + .storage + .start_epoch(epoch.curr, HashSet::from_iter([TEST_TABLE_ID])); state_table.commit(epoch).await.unwrap(); let cache = state_table.get_watermark_cache(); @@ -1419,6 +1477,9 @@ async fn test_state_table_watermark_cache_write_chunk() { WatermarkCacheStateTable::from_table_catalog(&table, test_env.storage.clone(), None).await; let mut epoch = EpochPair::new_test_epoch(test_epoch(1)); + test_env + .storage + .start_epoch(epoch.curr, HashSet::from_iter([TEST_TABLE_ID])); state_table.init_epoch(epoch); let cache = state_table.get_watermark_cache(); @@ -1428,6 +1489,9 @@ async fn test_state_table_watermark_cache_write_chunk() { state_table.update_watermark(watermark, true); epoch.inc_for_test(); + test_env + .storage + .start_epoch(epoch.curr, HashSet::from_iter([TEST_TABLE_ID])); state_table.commit(epoch).await.unwrap(); let inserts_1 = vec![ @@ -1537,6 +1601,9 @@ async fn test_state_table_watermark_cache_write_chunk() { state_table.update_watermark(watermark, true); epoch.inc_for_test(); + 
test_env + .storage + .start_epoch(epoch.curr, HashSet::from_iter([TEST_TABLE_ID])); state_table.commit(epoch).await.unwrap(); // After sync, we should scan all rows into watermark cache. @@ -1585,6 +1652,9 @@ async fn test_state_table_watermark_cache_refill() { WatermarkCacheStateTable::from_table_catalog(&table, test_env.storage.clone(), None).await; let mut epoch = EpochPair::new_test_epoch(test_epoch(1)); + test_env + .storage + .start_epoch(epoch.curr, HashSet::from_iter([TEST_TABLE_ID])); state_table.init_epoch(epoch); let rows = vec![ @@ -1634,6 +1704,9 @@ async fn test_state_table_watermark_cache_refill() { state_table.update_watermark(watermark, true); epoch.inc_for_test(); + test_env + .storage + .start_epoch(epoch.curr, HashSet::from_iter([TEST_TABLE_ID])); state_table.commit(epoch).await.unwrap(); // After the first barrier, watermark cache won't be filled. @@ -1675,6 +1748,9 @@ async fn test_state_table_iter_prefix_and_sub_range() { StateTable::from_table_catalog_inconsistent_op(&table, test_env.storage.clone(), None) .await; let mut epoch = EpochPair::new_test_epoch(test_epoch(1)); + test_env + .storage + .start_epoch(epoch.curr, HashSet::from_iter([TEST_TABLE_ID])); state_table.init_epoch(epoch); state_table.insert(OwnedRow::new(vec![ @@ -1700,6 +1776,9 @@ async fn test_state_table_iter_prefix_and_sub_range() { ])); epoch.inc_for_test(); + test_env + .storage + .start_epoch(epoch.curr, HashSet::from_iter([TEST_TABLE_ID])); state_table.commit(epoch).await.unwrap(); let pk_prefix = OwnedRow::new(vec![Some(1_i32.into())]); @@ -1870,6 +1949,9 @@ async fn test_replicated_state_table_replication() { .await; let mut epoch = EpochPair::new_test_epoch(test_epoch(1)); + test_env + .storage + .start_epoch(epoch.curr, HashSet::from_iter([TEST_TABLE_ID])); state_table.init_epoch(epoch); replicated_state_table.init_epoch(epoch).await.unwrap(); @@ -1881,6 +1963,12 @@ async fn test_replicated_state_table_replication() { ])); epoch.inc_for_test(); + test_env + .storage + .start_epoch(epoch.curr, HashSet::from_iter([TEST_TABLE_ID])); + test_env + .storage + .start_epoch(epoch.curr, HashSet::from_iter([TEST_TABLE_ID])); state_table.commit(epoch).await.unwrap(); replicated_state_table.commit(epoch).await.unwrap(); test_env.commit_epoch(epoch.prev).await; @@ -1941,6 +2029,12 @@ async fn test_replicated_state_table_replication() { replicated_state_table.write_chunk(replicate_chunk); epoch.inc_for_test(); + test_env + .storage + .start_epoch(epoch.curr, HashSet::from_iter([TEST_TABLE_ID])); + test_env + .storage + .start_epoch(epoch.curr, HashSet::from_iter([TEST_TABLE_ID])); state_table.commit(epoch).await.unwrap(); replicated_state_table.commit(epoch).await.unwrap(); diff --git a/src/stream/src/common/table/test_storage_table.rs b/src/stream/src/common/table/test_storage_table.rs index 098632192d02..1f130330e3be 100644 --- a/src/stream/src/common/table/test_storage_table.rs +++ b/src/stream/src/common/table/test_storage_table.rs @@ -12,6 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. 
+use std::collections::HashSet; + use futures::{pin_mut, StreamExt}; use itertools::Itertools; use risingwave_common::catalog::{ColumnDesc, ColumnId, TableId}; @@ -77,6 +79,9 @@ async fn test_storage_table_value_indices() { value_indices.into_iter().map(|v| v as usize).collect_vec(), ); let mut epoch = EpochPair::new_test_epoch(test_epoch(1)); + test_env + .storage + .start_epoch(epoch.curr, HashSet::from_iter([TEST_TABLE_ID])); state.init_epoch(epoch); state.insert(OwnedRow::new(vec![ @@ -110,6 +115,9 @@ async fn test_storage_table_value_indices() { ])); epoch.inc_for_test(); + test_env + .storage + .start_epoch(epoch.curr, HashSet::from_iter([TEST_TABLE_ID])); state.commit(epoch).await.unwrap(); test_env.commit_epoch(epoch.prev).await; @@ -197,6 +205,9 @@ async fn test_shuffled_column_id_for_storage_table_get_row() { .await; let mut epoch = EpochPair::new_test_epoch(test_epoch(1)); + test_env + .storage + .start_epoch(epoch.curr, HashSet::from_iter([TEST_TABLE_ID])); state.init_epoch(epoch); let table = StorageTable::for_test( @@ -223,6 +234,9 @@ async fn test_shuffled_column_id_for_storage_table_get_row() { ])); epoch.inc_for_test(); + test_env + .storage + .start_epoch(epoch.curr, HashSet::from_iter([TEST_TABLE_ID])); state.commit(epoch).await.unwrap(); test_env.commit_epoch(epoch.prev).await; @@ -310,6 +324,9 @@ async fn test_row_based_storage_table_point_get_in_batch_mode() { value_indices, ); let mut epoch = EpochPair::new_test_epoch(test_epoch(1)); + test_env + .storage + .start_epoch(epoch.curr, HashSet::from_iter([TEST_TABLE_ID])); state.init_epoch(epoch); state.insert(OwnedRow::new(vec![Some(1_i32.into()), None, None])); @@ -326,6 +343,9 @@ async fn test_row_based_storage_table_point_get_in_batch_mode() { Some(222_i32.into()), ])); epoch.inc_for_test(); + test_env + .storage + .start_epoch(epoch.curr, HashSet::from_iter([TEST_TABLE_ID])); state.commit(epoch).await.unwrap(); test_env.commit_epoch(epoch.prev).await; @@ -415,6 +435,9 @@ async fn test_batch_scan_with_value_indices() { value_indices, ); let mut epoch = EpochPair::new_test_epoch(test_epoch(1)); + test_env + .storage + .start_epoch(epoch.curr, HashSet::from_iter([TEST_TABLE_ID])); state.init_epoch(epoch); state.insert(OwnedRow::new(vec![ @@ -437,6 +460,9 @@ async fn test_batch_scan_with_value_indices() { ])); epoch.inc_for_test(); + test_env + .storage + .start_epoch(epoch.curr, HashSet::from_iter([TEST_TABLE_ID])); state.commit(epoch).await.unwrap(); test_env.commit_epoch(epoch.prev).await; @@ -513,6 +539,9 @@ async fn test_batch_scan_chunk_with_value_indices() { value_indices.clone(), ); let mut epoch = EpochPair::new_test_epoch(test_epoch(1)); + test_env + .storage + .start_epoch(epoch.curr, HashSet::from_iter([TEST_TABLE_ID])); state.init_epoch(epoch); let gen_row = |i: i32, is_update: bool| { @@ -554,6 +583,12 @@ async fn test_batch_scan_chunk_with_value_indices() { .collect_vec(); epoch.inc_for_test(); + test_env + .storage + .start_epoch(epoch.curr, HashSet::from_iter([TEST_TABLE_ID])); + test_env + .storage + .start_epoch(epoch.curr, HashSet::from_iter([TEST_TABLE_ID])); state.commit(epoch).await.unwrap(); test_env.commit_epoch(epoch.prev).await; diff --git a/src/stream/src/task/barrier_manager.rs b/src/stream/src/task/barrier_manager.rs index bc33f434bf22..a1108d9e4627 100644 --- a/src/stream/src/task/barrier_manager.rs +++ b/src/stream/src/task/barrier_manager.rs @@ -701,9 +701,6 @@ impl LocalBarrierWorker { to_collect ); - // There must be some actors to collect from. 
- assert!(!to_collect.is_empty()); - for actor_id in &to_collect { if let Some(e) = self.failure_actors.get(actor_id) { // The failure actors could exit before the barrier is issued, while their diff --git a/src/stream/src/task/barrier_manager/managed_state.rs b/src/stream/src/task/barrier_manager/managed_state.rs index 40b47ee26e8f..f4a3fb31c03c 100644 --- a/src/stream/src/task/barrier_manager/managed_state.rs +++ b/src/stream/src/task/barrier_manager/managed_state.rs @@ -27,6 +27,7 @@ use futures::{FutureExt, StreamExt, TryFutureExt}; use prometheus::HistogramTimer; use risingwave_common::catalog::TableId; use risingwave_common::must_match; +use risingwave_common::util::epoch::EpochPair; use risingwave_hummock_sdk::SyncResult; use risingwave_pb::stream_plan::barrier::BarrierKind; use risingwave_pb::stream_service::barrier_complete_response::CreateMviewProgress; @@ -49,7 +50,8 @@ struct IssuedState { pub barrier_inflight_latency: HistogramTimer, - pub table_ids: HashSet, + /// Only be `Some(_)` when `kind` is `Checkpoint` + pub table_ids: Option>, pub kind: BarrierKind, } @@ -202,6 +204,8 @@ pub(super) struct ManagedBarrierState { mutation_subscribers: HashMap, + prev_barrier_table_ids: Option<(EpochPair, HashSet)>, + /// Record the progress updates of creating mviews for each epoch of concurrent checkpoints. pub(super) create_mview_progress: HashMap>, @@ -235,6 +239,7 @@ impl ManagedBarrierState { Self { epoch_barrier_state_map: BTreeMap::default(), mutation_subscribers: Default::default(), + prev_barrier_table_ids: None, create_mview_progress: Default::default(), state_store, streaming_metrics, @@ -403,7 +408,7 @@ impl ManagedBarrierState { state_store, &self.streaming_metrics, prev_epoch, - table_ids, + table_ids.expect("should be Some on BarrierKind::Checkpoint"), )) }) } @@ -527,6 +532,50 @@ impl ManagedBarrierState { .streaming_metrics .barrier_inflight_latency .start_timer(); + + if let Some(hummock) = self.state_store.as_hummock() { + hummock.start_epoch(barrier.epoch.curr, table_ids.clone()); + } + + let table_ids = match barrier.kind { + BarrierKind::Unspecified => { + unreachable!() + } + BarrierKind::Initial => { + assert!( + self.prev_barrier_table_ids.is_none(), + "non empty table_ids at initial barrier: {:?}", + self.prev_barrier_table_ids + ); + info!(epoch = ?barrier.epoch, "initialize at Initial barrier"); + self.prev_barrier_table_ids = Some((barrier.epoch, table_ids)); + None + } + BarrierKind::Barrier => { + if let Some((prev_epoch, prev_table_ids)) = self.prev_barrier_table_ids.as_mut() { + assert_eq!(prev_epoch.curr, barrier.epoch.prev); + assert_eq!(prev_table_ids, &table_ids); + *prev_epoch = barrier.epoch; + } else { + info!(epoch = ?barrier.epoch, "initialize at non-checkpoint barrier"); + self.prev_barrier_table_ids = Some((barrier.epoch, table_ids)); + } + None + } + BarrierKind::Checkpoint => Some( + if let Some((prev_epoch, prev_table_ids)) = self + .prev_barrier_table_ids + .replace((barrier.epoch, table_ids)) + { + assert_eq!(prev_epoch.curr, barrier.epoch.prev); + prev_table_ids + } else { + info!(epoch = ?barrier.epoch, "initialize at Checkpoint barrier"); + HashSet::new() + }, + ), + }; + if let Some(BarrierState { ref inner, .. 
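A distilled sketch of the barrier bookkeeping above (simplified types; BarrierKind and the epoch pair mirror the patch, the rest is a stand-in): non-checkpoint barriers only carry the table set forward, while a checkpoint barrier yields the table set accumulated since the previous checkpoint, which is what the sync must cover.

use std::collections::HashSet;

type Epoch = u64;
type TableId = u32;

#[derive(Clone, Copy)]
struct EpochPair {
    curr: Epoch,
    prev: Epoch,
}

enum BarrierKind {
    Initial,
    Barrier,
    Checkpoint,
}

#[derive(Default)]
struct BarrierStateSketch {
    prev_barrier_table_ids: Option<(EpochPair, HashSet<TableId>)>,
}

impl BarrierStateSketch {
    /// Returns `Some(table_ids)` only for checkpoint barriers.
    fn on_barrier(
        &mut self,
        kind: BarrierKind,
        epoch: EpochPair,
        table_ids: HashSet<TableId>,
    ) -> Option<HashSet<TableId>> {
        match kind {
            BarrierKind::Initial => {
                assert!(self.prev_barrier_table_ids.is_none());
                self.prev_barrier_table_ids = Some((epoch, table_ids));
                None
            }
            BarrierKind::Barrier => {
                if let Some((prev_epoch, prev_table_ids)) = self.prev_barrier_table_ids.as_mut() {
                    assert_eq!(prev_epoch.curr, epoch.prev);
                    assert_eq!(prev_table_ids, &table_ids);
                    *prev_epoch = epoch;
                } else {
                    self.prev_barrier_table_ids = Some((epoch, table_ids));
                }
                None
            }
            BarrierKind::Checkpoint => Some(
                if let Some((prev_epoch, prev_table_ids)) =
                    self.prev_barrier_table_ids.replace((epoch, table_ids))
                {
                    assert_eq!(prev_epoch.curr, epoch.prev);
                    prev_table_ids
                } else {
                    HashSet::new()
                },
            ),
        }
    }
}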
}) = self.epoch_barrier_state_map.get_mut(&barrier.epoch.prev) { From 442a08771606397c7c75c557f9369c6e36f21394 Mon Sep 17 00:00:00 2001 From: William Wen <44139337+wenym1@users.noreply.github.com> Date: Sun, 14 Jul 2024 16:49:52 +0800 Subject: [PATCH 06/70] feat(storage): decouple spill task from epoch (#17539) --- .../event_handler/hummock_event_handler.rs | 2 +- .../src/hummock/event_handler/uploader/mod.rs | 473 ++++-------------- .../hummock/event_handler/uploader/spiller.rs | 427 ++++++++++++++++ .../event_handler/uploader/test_utils.rs | 346 +++++++++++++ .../src/common/table/test_state_table.rs | 6 - .../src/common/table/test_storage_table.rs | 3 - 6 files changed, 865 insertions(+), 392 deletions(-) create mode 100644 src/storage/src/hummock/event_handler/uploader/spiller.rs create mode 100644 src/storage/src/hummock/event_handler/uploader/test_utils.rs diff --git a/src/storage/src/hummock/event_handler/hummock_event_handler.rs b/src/storage/src/hummock/event_handler/hummock_event_handler.rs index f4038f0d7d52..addb2d08d5fc 100644 --- a/src/storage/src/hummock/event_handler/hummock_event_handler.rs +++ b/src/storage/src/hummock/event_handler/hummock_event_handler.rs @@ -936,7 +936,7 @@ mod tests { use tokio::sync::oneshot; use crate::hummock::event_handler::refiller::CacheRefiller; - use crate::hummock::event_handler::uploader::tests::{gen_imm, TEST_TABLE_ID}; + use crate::hummock::event_handler::uploader::test_utils::{gen_imm, TEST_TABLE_ID}; use crate::hummock::event_handler::uploader::UploadTaskOutput; use crate::hummock::event_handler::{HummockEvent, HummockEventHandler, HummockVersionUpdate}; use crate::hummock::iterator::test_utils::mock_sstable_store; diff --git a/src/storage/src/hummock/event_handler/uploader/mod.rs b/src/storage/src/hummock/event_handler/uploader/mod.rs index f0a18aa9eca6..101f54541fed 100644 --- a/src/storage/src/hummock/event_handler/uploader/mod.rs +++ b/src/storage/src/hummock/event_handler/uploader/mod.rs @@ -12,7 +12,9 @@ // See the License for the specific language governing permissions and // limitations under the License. 
+mod spiller; mod task_manager; +pub(crate) mod test_utils; use std::cmp::Ordering; use std::collections::btree_map::Entry; @@ -43,6 +45,7 @@ use tokio::task::JoinHandle; use tracing::{debug, error, info, warn}; use crate::hummock::event_handler::hummock_event_handler::{send_sync_result, BufferTracker}; +use crate::hummock::event_handler::uploader::spiller::Spiller; use crate::hummock::event_handler::uploader::uploader_imm::UploaderImm; use crate::hummock::event_handler::LocalInstanceId; use crate::hummock::local_version::pinned_version::PinnedVersion; @@ -657,6 +660,7 @@ impl TableUnsyncData { impl Iterator)>, )>, impl Iterator, + BTreeMap, ) { if let Some(prev_epoch) = self.max_sync_epoch() { assert_gt!(epoch, prev_epoch) @@ -687,6 +691,7 @@ impl TableUnsyncData { take_before_epoch(&mut self.spill_tasks, epoch) .into_values() .flat_map(|tasks| tasks.into_iter()), + epochs, ) } @@ -735,7 +740,7 @@ impl TableUnsyncData { fn max_epoch(&self) -> Option { self.unsync_epochs - .first_key_value() + .last_key_value() .map(|(epoch, _)| *epoch) .or_else(|| self.max_sync_epoch()) } @@ -747,6 +752,22 @@ impl TableUnsyncData { } } +#[derive(Eq, Hash, PartialEq, Copy, Clone)] +struct UnsyncEpochId(HummockEpoch, TableId); + +impl UnsyncEpochId { + fn epoch(&self) -> HummockEpoch { + self.0 + } +} + +fn get_unsync_epoch_id(epoch: HummockEpoch, table_ids: &HashSet) -> Option { + table_ids + .iter() + .min() + .map(|table_id| UnsyncEpochId(epoch, *table_id)) +} + #[derive(Default)] /// Unsync data, can be either imm or spilled sst, and some aggregated epoch information. /// @@ -756,9 +777,7 @@ struct UnsyncData { table_data: HashMap, // An index as a mapping from instance id to its table id instance_table_id: HashMap, - // TODO: this is only used in spill to get existing epochs and can be removed - // when we support spill not based on epoch - epochs: BTreeMap>, + unsync_epochs: HashMap>, } impl UnsyncData { @@ -869,49 +888,56 @@ impl UploaderData { table_ids: HashSet, sync_result_sender: oneshot::Sender>, ) { - // clean old epochs - let epochs = take_before_epoch(&mut self.unsync_data.epochs, epoch); - if cfg!(debug_assertions) { - for epoch_table_ids in epochs.into_values() { - assert_eq!(epoch_table_ids, table_ids); - } - } - let mut all_table_watermarks = HashMap::new(); let mut uploading_tasks = HashSet::new(); let mut spilled_tasks = BTreeSet::new(); let mut flush_payload = HashMap::new(); - for (table_id, table_data) in &self.unsync_data.table_data { - if !table_ids.contains(table_id) { - table_data.assert_after_epoch(epoch); - } - } - for table_id in &table_ids { - let table_data = self + + if let Some(UnsyncEpochId(_, min_table_id)) = get_unsync_epoch_id(epoch, &table_ids) { + let min_table_id_data = self .unsync_data .table_data - .get_mut(table_id) + .get_mut(&min_table_id) .expect("should exist"); - let (unflushed_payload, table_watermarks, task_ids) = table_data.sync(epoch); - for (instance_id, payload) in unflushed_payload { - if !payload.is_empty() { - flush_payload.insert(instance_id, payload); - } - } - if let Some((direction, watermarks)) = table_watermarks { - Self::add_table_watermarks( - &mut all_table_watermarks, - *table_id, - direction, - watermarks, + let epochs = take_before_epoch(&mut min_table_id_data.unsync_epochs.clone(), epoch); + for epoch in epochs.keys() { + assert_eq!( + self.unsync_data + .unsync_epochs + .remove(&UnsyncEpochId(*epoch, min_table_id)) + .expect("should exist"), + table_ids ); } - for task_id in task_ids { - if self.spilled_data.contains_key(&task_id) { - 
spilled_tasks.insert(task_id); - } else { - uploading_tasks.insert(task_id); + for table_id in &table_ids { + let table_data = self + .unsync_data + .table_data + .get_mut(table_id) + .expect("should exist"); + let (unflushed_payload, table_watermarks, task_ids, table_unsync_epochs) = + table_data.sync(epoch); + assert_eq!(table_unsync_epochs, epochs); + for (instance_id, payload) in unflushed_payload { + if !payload.is_empty() { + flush_payload.insert(instance_id, payload); + } + } + if let Some((direction, watermarks)) = table_watermarks { + Self::add_table_watermarks( + &mut all_table_watermarks, + *table_id, + direction, + watermarks, + ); + } + for task_id in task_ids { + if self.spilled_data.contains_key(&task_id) { + spilled_tasks.insert(task_id); + } else { + uploading_tasks.insert(task_id); + } } } } @@ -1153,7 +1179,13 @@ impl HummockUploader { }); table_data.new_epoch(epoch); } - data.unsync_data.epochs.insert(epoch, table_ids); + if let Some(unsync_epoch_id) = get_unsync_epoch_id(epoch, &table_ids) { + assert!(data + .unsync_data + .unsync_epochs + .insert(unsync_epoch_id, table_ids) + .is_none()); + } } pub(super) fn start_sync_epoch( @@ -1209,43 +1241,28 @@ impl HummockUploader { return false; }; if self.context.buffer_tracker.need_flush() { + let mut spiller = Spiller::new(&mut data.unsync_data); let mut curr_batch_flush_size = 0; // iterate from older epoch to newer epoch - for (epoch, table_ids) in &data.unsync_data.epochs { - if !self - .context - .buffer_tracker - .need_more_flush(curr_batch_flush_size) + while self + .context + .buffer_tracker + .need_more_flush(curr_batch_flush_size) + && let Some((epoch, payload, spilled_table_ids)) = spiller.next_spilled_payload() + { + assert!(!payload.is_empty()); { - break; - } - let mut spilled_table_ids = HashSet::new(); - let mut payload = HashMap::new(); - for table_id in table_ids { - let table_data = data - .unsync_data - .table_data - .get_mut(table_id) - .expect("should exist"); - for (instance_id, instance_data) in &mut table_data.instance_data { - let instance_payload = instance_data.spill(*epoch); - if !instance_payload.is_empty() { - payload.insert(*instance_id, instance_payload); - spilled_table_ids.insert(*table_id); - } - } - } - if !payload.is_empty() { let (task_id, task_size, spilled_table_ids) = data.task_manager .spill(&self.context, spilled_table_ids, payload); for table_id in spilled_table_ids { - data.unsync_data + spiller + .unsync_data() .table_data .get_mut(table_id) .expect("should exist") .spill_tasks - .entry(*epoch) + .entry(epoch) .or_default() .push_front(task_id); } @@ -1459,257 +1476,25 @@ impl HummockUploader { #[cfg(test)] pub(crate) mod tests { - use std::collections::{HashMap, HashSet, VecDeque}; + use std::collections::{HashMap, HashSet}; use std::future::{poll_fn, Future}; use std::ops::Deref; use std::pin::pin; use std::sync::atomic::AtomicUsize; - use std::sync::atomic::Ordering::{Relaxed, SeqCst}; - use std::sync::{Arc, LazyLock}; + use std::sync::atomic::Ordering::SeqCst; + use std::sync::Arc; use std::task::Poll; - use bytes::Bytes; - use futures::future::BoxFuture; use futures::FutureExt; - use itertools::Itertools; - use prometheus::core::GenericGauge; - use risingwave_common::catalog::TableId; - use risingwave_common::must_match; - use risingwave_common::util::epoch::{test_epoch, EpochExt}; - use risingwave_hummock_sdk::compaction_group::StaticCompactionGroupId; - use risingwave_hummock_sdk::key::{FullKey, TableKey}; - use risingwave_hummock_sdk::version::HummockVersion; - use 
risingwave_hummock_sdk::{HummockEpoch, LocalSstableInfo}; - use risingwave_pb::hummock::{KeyRange, SstableInfo, StateTableInfoDelta}; - use spin::Mutex; - use tokio::spawn; - use tokio::sync::mpsc::unbounded_channel; + use risingwave_common::util::epoch::EpochExt; + use risingwave_hummock_sdk::HummockEpoch; use tokio::sync::oneshot; - use tokio::task::yield_now; - - use crate::hummock::event_handler::hummock_event_handler::BufferTracker; - use crate::hummock::event_handler::uploader::uploader_imm::UploaderImm; - use crate::hummock::event_handler::uploader::{ - get_payload_imm_ids, HummockUploader, SyncedData, TableUnsyncData, UploadTaskInfo, - UploadTaskOutput, UploadTaskPayload, UploaderContext, UploaderData, UploaderState, - UploadingTask, UploadingTaskId, - }; - use crate::hummock::event_handler::{LocalInstanceId, TEST_LOCAL_INSTANCE_ID}; - use crate::hummock::local_version::pinned_version::PinnedVersion; - use crate::hummock::shared_buffer::shared_buffer_batch::{ - SharedBufferBatch, SharedBufferBatchId, SharedBufferValue, - }; - use crate::hummock::{HummockError, HummockResult, MemoryLimiter}; - use crate::mem_table::{ImmId, ImmutableMemtable}; - use crate::monitor::HummockStateStoreMetrics; - use crate::opts::StorageOpts; - use crate::store::SealCurrentEpochOptions; - - const INITIAL_EPOCH: HummockEpoch = test_epoch(5); - pub(crate) const TEST_TABLE_ID: TableId = TableId { table_id: 233 }; - - pub trait UploadOutputFuture = - Future> + Send + 'static; - pub trait UploadFn = - Fn(UploadTaskPayload, UploadTaskInfo) -> Fut + Send + Sync + 'static; - - impl HummockUploader { - fn data(&self) -> &UploaderData { - must_match!(&self.state, UploaderState::Working(data) => data) - } - fn table_data(&self) -> &TableUnsyncData { - self.data() - .unsync_data - .table_data - .get(&TEST_TABLE_ID) - .expect("should exist") - } - - fn test_max_syncing_epoch(&self) -> HummockEpoch { - self.table_data().max_sync_epoch().unwrap() - } - - fn test_max_synced_epoch(&self) -> HummockEpoch { - self.table_data().max_synced_epoch.unwrap() - } - } - - fn test_hummock_version(epoch: HummockEpoch) -> HummockVersion { - let mut version = HummockVersion::default(); - version.id = epoch; - version.max_committed_epoch = epoch; - version.state_table_info.apply_delta( - &HashMap::from_iter([( - TEST_TABLE_ID, - StateTableInfoDelta { - committed_epoch: epoch, - safe_epoch: epoch, - compaction_group_id: StaticCompactionGroupId::StateDefault as _, - }, - )]), - &HashSet::new(), - ); - version - } - - fn initial_pinned_version() -> PinnedVersion { - PinnedVersion::new(test_hummock_version(INITIAL_EPOCH), unbounded_channel().0) - } - - fn dummy_table_key() -> Vec { - vec![b't', b'e', b's', b't'] - } - - async fn gen_imm_with_limiter( - epoch: HummockEpoch, - limiter: Option<&MemoryLimiter>, - ) -> ImmutableMemtable { - let sorted_items = vec![( - TableKey(Bytes::from(dummy_table_key())), - SharedBufferValue::Delete, - )]; - let size = SharedBufferBatch::measure_batch_size(&sorted_items, None).0; - let tracker = match limiter { - Some(limiter) => Some(limiter.require_memory(size as u64).await), - None => None, - }; - SharedBufferBatch::build_shared_buffer_batch( - epoch, - 0, - sorted_items, - None, - size, - TEST_TABLE_ID, - tracker, - ) - } - - pub(crate) async fn gen_imm(epoch: HummockEpoch) -> ImmutableMemtable { - gen_imm_with_limiter(epoch, None).await - } - - fn gen_sstable_info( - start_epoch: HummockEpoch, - end_epoch: HummockEpoch, - ) -> Vec { - let start_full_key = FullKey::new(TEST_TABLE_ID, 
TableKey(dummy_table_key()), start_epoch); - let end_full_key = FullKey::new(TEST_TABLE_ID, TableKey(dummy_table_key()), end_epoch); - let gen_sst_object_id = (start_epoch << 8) + end_epoch; - vec![LocalSstableInfo::for_test(SstableInfo { - object_id: gen_sst_object_id, - sst_id: gen_sst_object_id, - key_range: Some(KeyRange { - left: start_full_key.encode(), - right: end_full_key.encode(), - right_exclusive: true, - }), - table_ids: vec![TEST_TABLE_ID.table_id], - ..Default::default() - })] - } - - fn test_uploader_context(upload_fn: F) -> UploaderContext - where - Fut: UploadOutputFuture, - F: UploadFn, - { - let config = StorageOpts::default(); - UploaderContext::new( - initial_pinned_version(), - Arc::new(move |payload, task_info| spawn(upload_fn(payload, task_info))), - BufferTracker::for_test(), - &config, - Arc::new(HummockStateStoreMetrics::unused()), - ) - } - - fn test_uploader(upload_fn: F) -> HummockUploader - where - Fut: UploadOutputFuture, - F: UploadFn, - { - let config = StorageOpts { - ..Default::default() - }; - HummockUploader::new( - Arc::new(HummockStateStoreMetrics::unused()), - initial_pinned_version(), - Arc::new(move |payload, task_info| spawn(upload_fn(payload, task_info))), - BufferTracker::for_test(), - &config, - ) - } - - fn dummy_success_upload_output() -> UploadTaskOutput { - UploadTaskOutput { - new_value_ssts: gen_sstable_info(INITIAL_EPOCH, INITIAL_EPOCH), - old_value_ssts: vec![], - wait_poll_timer: None, - } - } - - #[allow(clippy::unused_async)] - async fn dummy_success_upload_future( - _: UploadTaskPayload, - _: UploadTaskInfo, - ) -> HummockResult { - Ok(dummy_success_upload_output()) - } - - #[allow(clippy::unused_async)] - async fn dummy_fail_upload_future( - _: UploadTaskPayload, - _: UploadTaskInfo, - ) -> HummockResult { - Err(HummockError::other("failed")) - } - - impl UploadingTask { - fn from_vec(imms: Vec, context: &UploaderContext) -> Self { - let input = HashMap::from_iter([( - TEST_LOCAL_INSTANCE_ID, - imms.into_iter().map(UploaderImm::for_test).collect_vec(), - )]); - static NEXT_TASK_ID: LazyLock = LazyLock::new(|| AtomicUsize::new(0)); - Self::new( - UploadingTaskId(NEXT_TASK_ID.fetch_add(1, Relaxed)), - input, - context, - ) - } - } - - fn get_imm_ids<'a>( - imms: impl IntoIterator, - ) -> HashMap> { - HashMap::from_iter([( - TEST_LOCAL_INSTANCE_ID, - imms.into_iter().map(|imm| imm.batch_id()).collect_vec(), - )]) - } - - impl HummockUploader { - fn local_seal_epoch_for_test(&mut self, instance_id: LocalInstanceId, epoch: HummockEpoch) { - self.local_seal_epoch( - instance_id, - epoch.next_epoch(), - SealCurrentEpochOptions::for_test(), - ); - } - - fn start_epochs_for_test(&mut self, epochs: impl IntoIterator) { - let mut last_epoch = None; - for epoch in epochs { - last_epoch = Some(epoch); - self.start_epoch(epoch, HashSet::from_iter([TEST_TABLE_ID])); - } - self.start_epoch( - last_epoch.unwrap().next_epoch(), - HashSet::from_iter([TEST_TABLE_ID]), - ); - } - } + use super::test_utils::*; + use crate::hummock::event_handler::uploader::{get_payload_imm_ids, SyncedData, UploadingTask}; + use crate::hummock::event_handler::TEST_LOCAL_INSTANCE_ID; + use crate::hummock::HummockError; + use crate::opts::StorageOpts; #[tokio::test] pub async fn test_uploading_task_future() { @@ -1973,82 +1758,6 @@ pub(crate) mod tests { assert_eq!(epoch6, uploader.test_max_syncing_epoch()); } - fn prepare_uploader_order_test( - config: &StorageOpts, - skip_schedule: bool, - ) -> ( - BufferTracker, - HummockUploader, - impl Fn(HashMap>) -> 
(BoxFuture<'static, ()>, oneshot::Sender<()>), - ) { - let gauge = GenericGauge::new("test", "test").unwrap(); - let buffer_tracker = BufferTracker::from_storage_opts(config, gauge); - // (the started task send the imm ids of payload, the started task wait for finish notify) - #[allow(clippy::type_complexity)] - let task_notifier_holder: Arc< - Mutex, oneshot::Receiver<()>)>>, - > = Arc::new(Mutex::new(VecDeque::new())); - - let new_task_notifier = { - let task_notifier_holder = task_notifier_holder.clone(); - move |imm_ids: HashMap>| { - let (start_tx, start_rx) = oneshot::channel(); - let (finish_tx, finish_rx) = oneshot::channel(); - task_notifier_holder - .lock() - .push_front((start_tx, finish_rx)); - let await_start_future = async move { - let task_info = start_rx.await.unwrap(); - assert_eq!(imm_ids, task_info.imm_ids); - } - .boxed(); - (await_start_future, finish_tx) - } - }; - - let config = StorageOpts::default(); - let uploader = HummockUploader::new( - Arc::new(HummockStateStoreMetrics::unused()), - initial_pinned_version(), - Arc::new({ - move |_, task_info: UploadTaskInfo| { - let task_notifier_holder = task_notifier_holder.clone(); - let task_item = task_notifier_holder.lock().pop_back(); - let start_epoch = *task_info.epochs.last().unwrap(); - let end_epoch = *task_info.epochs.first().unwrap(); - assert!(end_epoch >= start_epoch); - spawn(async move { - let ssts = gen_sstable_info(start_epoch, end_epoch); - if !skip_schedule { - let (start_tx, finish_rx) = task_item.unwrap(); - start_tx.send(task_info).unwrap(); - finish_rx.await.unwrap(); - } - Ok(UploadTaskOutput { - new_value_ssts: ssts, - old_value_ssts: vec![], - wait_poll_timer: None, - }) - }) - } - }), - buffer_tracker.clone(), - &config, - ); - (buffer_tracker, uploader, new_task_notifier) - } - - async fn assert_uploader_pending(uploader: &mut HummockUploader) { - for _ in 0..10 { - yield_now().await; - } - assert!( - poll_fn(|cx| Poll::Ready(uploader.next_uploaded_sst().poll_unpin(cx))) - .await - .is_pending() - ) - } - #[tokio::test] async fn test_uploader_finish_in_order() { let config = StorageOpts { diff --git a/src/storage/src/hummock/event_handler/uploader/spiller.rs b/src/storage/src/hummock/event_handler/uploader/spiller.rs new file mode 100644 index 000000000000..a4caa3c05fe3 --- /dev/null +++ b/src/storage/src/hummock/event_handler/uploader/spiller.rs @@ -0,0 +1,427 @@ +// Copyright 2024 RisingWave Labs +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +use std::collections::{HashMap, HashSet}; + +use risingwave_common::catalog::TableId; +use risingwave_hummock_sdk::HummockEpoch; + +use crate::hummock::event_handler::uploader::{ + LocalInstanceUnsyncData, UnsyncData, UnsyncEpochId, UploadTaskInput, +}; +use crate::hummock::event_handler::LocalInstanceId; + +#[derive(Default)] +struct EpochSpillableDataInfo { + instance_ids: HashSet, + payload_size: usize, +} + +pub(super) struct Spiller<'a> { + unsync_data: &'a mut UnsyncData, + epoch_info: HashMap, + unsync_epoch_id_map: HashMap<(HummockEpoch, TableId), UnsyncEpochId>, +} + +impl<'a> Spiller<'a> { + pub(super) fn new(unsync_data: &'a mut UnsyncData) -> Self { + let unsync_epoch_id_map: HashMap<_, _> = unsync_data + .unsync_epochs + .iter() + .flat_map(|(unsync_epoch_id, table_ids)| { + let epoch = unsync_epoch_id.epoch(); + let unsync_epoch_id = *unsync_epoch_id; + table_ids + .iter() + .map(move |table_id| ((epoch, *table_id), unsync_epoch_id)) + }) + .collect(); + let mut epoch_info: HashMap<_, EpochSpillableDataInfo> = HashMap::new(); + for instance_data in unsync_data + .table_data + .values() + .flat_map(|table_data| table_data.instance_data.values()) + { + if let Some((epoch, spill_size)) = instance_data.spillable_data_info() { + let unsync_epoch_id = unsync_epoch_id_map + .get(&(epoch, instance_data.table_id)) + .expect("should exist"); + let epoch_info = epoch_info.entry(*unsync_epoch_id).or_default(); + assert!(epoch_info.instance_ids.insert(instance_data.instance_id)); + epoch_info.payload_size += spill_size; + } + } + Self { + unsync_data, + epoch_info, + unsync_epoch_id_map, + } + } + + pub(super) fn next_spilled_payload( + &mut self, + ) -> Option<(HummockEpoch, UploadTaskInput, HashSet)> { + if let Some(unsync_epoch_id) = self + .epoch_info + .iter() + .max_by_key(|(_, info)| info.payload_size) + .map(|(unsync_epoch_id, _)| *unsync_epoch_id) + { + let spill_epoch = unsync_epoch_id.epoch(); + let spill_info = self + .epoch_info + .remove(&unsync_epoch_id) + .expect("should exist"); + let epoch_info = &mut self.epoch_info; + let mut payload = HashMap::new(); + let mut spilled_table_ids = HashSet::new(); + for instance_id in spill_info.instance_ids { + let table_id = *self + .unsync_data + .instance_table_id + .get(&instance_id) + .expect("should exist"); + let instance_data = self + .unsync_data + .table_data + .get_mut(&table_id) + .expect("should exist") + .instance_data + .get_mut(&instance_id) + .expect("should exist"); + let instance_payload = instance_data.spill(spill_epoch); + assert!(!instance_payload.is_empty()); + payload.insert(instance_id, instance_payload); + spilled_table_ids.insert(table_id); + + // update the spill info + if let Some((new_spill_epoch, size)) = instance_data.spillable_data_info() { + let new_unsync_epoch_id = self + .unsync_epoch_id_map + .get(&(new_spill_epoch, instance_data.table_id)) + .expect("should exist"); + let info = epoch_info.entry(*new_unsync_epoch_id).or_default(); + assert!(info.instance_ids.insert(instance_id)); + info.payload_size += size; + } + } + Some((spill_epoch, payload, spilled_table_ids)) + } else { + None + } + } + + pub(super) fn unsync_data(&mut self) -> &mut UnsyncData { + self.unsync_data + } +} + +impl LocalInstanceUnsyncData { + fn spillable_data_info(&self) -> Option<(HummockEpoch, usize)> { + self.sealed_data + .back() + .or(self.current_epoch_data.as_ref()) + .and_then(|epoch_data| { + if !epoch_data.is_empty() { + Some(( + epoch_data.epoch, + epoch_data.imms.iter().map(|imm| imm.size()).sum(), + )) + } else 
{ + None + } + }) + } +} + +#[cfg(test)] +mod tests { + use std::collections::{HashMap, HashSet}; + use std::ops::Deref; + + use futures::future::join_all; + use futures::FutureExt; + use itertools::Itertools; + use risingwave_common::catalog::TableId; + use risingwave_common::util::epoch::EpochExt; + use tokio::sync::oneshot; + + use crate::hummock::event_handler::uploader::test_utils::*; + use crate::opts::StorageOpts; + use crate::store::SealCurrentEpochOptions; + + #[tokio::test] + async fn test_spill_in_order() { + let config = StorageOpts { + shared_buffer_capacity_mb: 1024 * 1024, + shared_buffer_flush_ratio: 0.0, + ..Default::default() + }; + let (buffer_tracker, mut uploader, new_task_notifier) = + prepare_uploader_order_test(&config, false); + + let table_id1 = TableId::new(1); + let table_id2 = TableId::new(2); + + let instance_id1_1 = 1; + let instance_id1_2 = 2; + let instance_id2 = 3; + + let epoch1 = INITIAL_EPOCH.next_epoch(); + let epoch2 = epoch1.next_epoch(); + let epoch3 = epoch2.next_epoch(); + let epoch4 = epoch3.next_epoch(); + let memory_limiter = buffer_tracker.get_memory_limiter().clone(); + let memory_limiter = Some(memory_limiter.deref()); + + // epoch1 + uploader.start_epoch(epoch1, HashSet::from_iter([table_id1])); + uploader.start_epoch(epoch1, HashSet::from_iter([table_id2])); + + uploader.init_instance(instance_id1_1, table_id1, epoch1); + uploader.init_instance(instance_id1_2, table_id1, epoch1); + uploader.init_instance(instance_id2, table_id2, epoch1); + + // naming: imm__ + let imm1_1_1 = gen_imm_inner(table_id1, epoch1, 0, memory_limiter).await; + uploader.add_imm(instance_id1_1, imm1_1_1.clone()); + let imm1_2_1 = gen_imm_inner(table_id1, epoch1, 0, memory_limiter).await; + uploader.add_imm(instance_id1_2, imm1_2_1.clone()); + let imm2_1 = gen_imm_inner(table_id2, epoch1, 0, memory_limiter).await; + uploader.add_imm(instance_id2, imm2_1.clone()); + + // epoch2 + uploader.start_epoch(epoch2, HashSet::from_iter([table_id1])); + uploader.start_epoch(epoch2, HashSet::from_iter([table_id2])); + + uploader.local_seal_epoch(instance_id1_1, epoch2, SealCurrentEpochOptions::for_test()); + uploader.local_seal_epoch(instance_id1_2, epoch2, SealCurrentEpochOptions::for_test()); + uploader.local_seal_epoch(instance_id2, epoch2, SealCurrentEpochOptions::for_test()); + + let imms1_1_2 = join_all( + [0, 1, 2].map(|offset| gen_imm_inner(table_id1, epoch2, offset, memory_limiter)), + ) + .await; + for imm in imms1_1_2.clone() { + uploader.add_imm(instance_id1_1, imm); + } + + // epoch3 + uploader.start_epoch(epoch3, HashSet::from_iter([table_id1])); + uploader.start_epoch(epoch3, HashSet::from_iter([table_id2])); + + uploader.local_seal_epoch(instance_id1_1, epoch3, SealCurrentEpochOptions::for_test()); + uploader.local_seal_epoch(instance_id1_2, epoch3, SealCurrentEpochOptions::for_test()); + uploader.local_seal_epoch(instance_id2, epoch3, SealCurrentEpochOptions::for_test()); + + let imms1_2_3 = join_all( + [0, 1, 2, 3].map(|offset| gen_imm_inner(table_id1, epoch3, offset, memory_limiter)), + ) + .await; + for imm in imms1_2_3.clone() { + uploader.add_imm(instance_id1_2, imm); + } + + // epoch4 + uploader.start_epoch(epoch4, HashSet::from_iter([table_id1, table_id2])); + + uploader.local_seal_epoch(instance_id1_1, epoch4, SealCurrentEpochOptions::for_test()); + uploader.local_seal_epoch(instance_id1_2, epoch4, SealCurrentEpochOptions::for_test()); + uploader.local_seal_epoch(instance_id2, epoch4, SealCurrentEpochOptions::for_test()); + + let imm1_1_4 = 
gen_imm_inner(table_id1, epoch4, 0, memory_limiter).await; + uploader.add_imm(instance_id1_1, imm1_1_4.clone()); + let imm1_2_4 = gen_imm_inner(table_id1, epoch4, 0, memory_limiter).await; + uploader.add_imm(instance_id1_2, imm1_2_4.clone()); + let imm2_4_1 = gen_imm_inner(table_id2, epoch4, 0, memory_limiter).await; + uploader.add_imm(instance_id2, imm2_4_1.clone()); + + // uploader state: + // table_id1: + // instance_id1_1: instance_id1_2: instance_id2 + // epoch1 imm1_1_1 imm1_2_1 | imm2_1 | + // epoch2 imms1_1_2(size 3) | | + // epoch3 imms_1_2_3(size 4) | | + // epoch4 imm1_1_4 imm1_2_4 imm2_4_1 | + + let (await_start1_1, finish_tx1_1) = new_task_notifier(HashMap::from_iter([ + (instance_id1_1, vec![imm1_1_1.batch_id()]), + (instance_id1_2, vec![imm1_2_1.batch_id()]), + ])); + let (await_start3, finish_tx3) = new_task_notifier(HashMap::from_iter([( + instance_id1_2, + imms1_2_3 + .iter() + .rev() + .map(|imm| imm.batch_id()) + .collect_vec(), + )])); + let (await_start2, finish_tx2) = new_task_notifier(HashMap::from_iter([( + instance_id1_1, + imms1_1_2 + .iter() + .rev() + .map(|imm| imm.batch_id()) + .collect_vec(), + )])); + let (await_start1_4, finish_tx1_4) = new_task_notifier(HashMap::from_iter([ + (instance_id1_1, vec![imm1_1_4.batch_id()]), + (instance_id1_2, vec![imm1_2_4.batch_id()]), + ])); + let (await_start2_1, finish_tx2_1) = new_task_notifier(HashMap::from_iter([( + instance_id2, + vec![imm2_1.batch_id()], + )])); + let (await_start2_4_1, finish_tx2_4_1) = new_task_notifier(HashMap::from_iter([( + instance_id2, + vec![imm2_4_1.batch_id()], + )])); + + uploader.may_flush(); + await_start1_1.await; + await_start3.await; + await_start2.await; + await_start1_4.await; + await_start2_1.await; + await_start2_4_1.await; + + assert_uploader_pending(&mut uploader).await; + + let imm2_4_2 = gen_imm_inner(table_id2, epoch4, 1, memory_limiter).await; + uploader.add_imm(instance_id2, imm2_4_2.clone()); + + uploader.local_seal_epoch( + instance_id1_1, + u64::MAX, + SealCurrentEpochOptions::for_test(), + ); + uploader.local_seal_epoch( + instance_id1_2, + u64::MAX, + SealCurrentEpochOptions::for_test(), + ); + uploader.local_seal_epoch(instance_id2, u64::MAX, SealCurrentEpochOptions::for_test()); + + // uploader state: + // table_id1: + // instance_id1_1: instance_id1_2: instance_id2 + // epoch1 spill(imm1_1_1, imm1_2_1, size 2) | spill(imm2_1, size 1) | + // epoch2 spill(imms1_1_2, size 3) | | + // epoch3 spill(imms_1_2_3, size 4) | | + // epoch4 spill(imm1_1_4, imm1_2_4, size 2) | spill(imm2_4_1, size 1), imm2_4_2 | + + let (sync_tx1_1, sync_rx1_1) = oneshot::channel(); + uploader.start_sync_epoch(epoch1, sync_tx1_1, HashSet::from_iter([table_id1])); + let (sync_tx2_1, sync_rx2_1) = oneshot::channel(); + uploader.start_sync_epoch(epoch2, sync_tx2_1, HashSet::from_iter([table_id1])); + let (sync_tx3_1, sync_rx3_1) = oneshot::channel(); + uploader.start_sync_epoch(epoch3, sync_tx3_1, HashSet::from_iter([table_id1])); + let (sync_tx1_2, sync_rx1_2) = oneshot::channel(); + uploader.start_sync_epoch(epoch1, sync_tx1_2, HashSet::from_iter([table_id2])); + let (sync_tx2_2, sync_rx2_2) = oneshot::channel(); + uploader.start_sync_epoch(epoch2, sync_tx2_2, HashSet::from_iter([table_id2])); + let (sync_tx3_2, sync_rx3_2) = oneshot::channel(); + uploader.start_sync_epoch(epoch3, sync_tx3_2, HashSet::from_iter([table_id2])); + + let (await_start2_4_2, finish_tx2_4_2) = new_task_notifier(HashMap::from_iter([( + instance_id2, + vec![imm2_4_2.batch_id()], + )])); + + let (sync_tx4, mut sync_rx4) = 
oneshot::channel(); + uploader.start_sync_epoch(epoch4, sync_tx4, HashSet::from_iter([table_id1, table_id2])); + await_start2_4_2.await; + + finish_tx2_4_1.send(()).unwrap(); + finish_tx3.send(()).unwrap(); + finish_tx2.send(()).unwrap(); + finish_tx1_4.send(()).unwrap(); + assert_uploader_pending(&mut uploader).await; + + finish_tx1_1.send(()).unwrap(); + { + let imm_ids = HashMap::from_iter([ + (instance_id1_1, vec![imm1_1_1.batch_id()]), + (instance_id1_2, vec![imm1_2_1.batch_id()]), + ]); + let sst = uploader.next_uploaded_sst().await; + assert_eq!(&imm_ids, sst.imm_ids()); + let synced_data = sync_rx1_1.await.unwrap().unwrap(); + assert_eq!(synced_data.uploaded_ssts.len(), 1); + assert_eq!(&imm_ids, synced_data.uploaded_ssts[0].imm_ids()); + } + { + let imm_ids3 = HashMap::from_iter([( + instance_id1_2, + imms1_2_3 + .iter() + .rev() + .map(|imm| imm.batch_id()) + .collect_vec(), + )]); + let imm_ids2 = HashMap::from_iter([( + instance_id1_1, + imms1_1_2 + .iter() + .rev() + .map(|imm| imm.batch_id()) + .collect_vec(), + )]); + let sst = uploader.next_uploaded_sst().await; + assert_eq!(&imm_ids3, sst.imm_ids()); + let sst = uploader.next_uploaded_sst().await; + assert_eq!(&imm_ids2, sst.imm_ids()); + let synced_data = sync_rx2_1.await.unwrap().unwrap(); + assert_eq!(synced_data.uploaded_ssts.len(), 1); + assert_eq!(&imm_ids2, synced_data.uploaded_ssts[0].imm_ids()); + let synced_data = sync_rx3_1.await.unwrap().unwrap(); + assert_eq!(synced_data.uploaded_ssts.len(), 1); + assert_eq!(&imm_ids3, synced_data.uploaded_ssts[0].imm_ids()); + } + { + let imm_ids1_4 = HashMap::from_iter([ + (instance_id1_1, vec![imm1_1_4.batch_id()]), + (instance_id1_2, vec![imm1_2_4.batch_id()]), + ]); + let imm_ids2_1 = HashMap::from_iter([(instance_id2, vec![imm2_1.batch_id()])]); + let imm_ids2_4_1 = HashMap::from_iter([(instance_id2, vec![imm2_4_1.batch_id()])]); + finish_tx2_1.send(()).unwrap(); + let sst = uploader.next_uploaded_sst().await; + assert_eq!(&imm_ids1_4, sst.imm_ids()); + let sst = uploader.next_uploaded_sst().await; + assert_eq!(&imm_ids2_1, sst.imm_ids()); + let sst = uploader.next_uploaded_sst().await; + assert_eq!(&imm_ids2_4_1, sst.imm_ids()); + let synced_data = sync_rx1_2.await.unwrap().unwrap(); + assert_eq!(synced_data.uploaded_ssts.len(), 1); + assert_eq!(&imm_ids2_1, synced_data.uploaded_ssts[0].imm_ids()); + let synced_data = sync_rx2_2.await.unwrap().unwrap(); + assert!(synced_data.uploaded_ssts.is_empty()); + let synced_data = sync_rx3_2.await.unwrap().unwrap(); + assert!(synced_data.uploaded_ssts.is_empty()); + + let imm_ids2_4_2 = HashMap::from_iter([(instance_id2, vec![imm2_4_2.batch_id()])]); + + assert!((&mut sync_rx4).now_or_never().is_none()); + finish_tx2_4_2.send(()).unwrap(); + let sst = uploader.next_uploaded_sst().await; + assert_eq!(&imm_ids2_4_2, sst.imm_ids()); + let synced_data = sync_rx4.await.unwrap().unwrap(); + assert_eq!(synced_data.uploaded_ssts.len(), 3); + assert_eq!(&imm_ids2_4_2, synced_data.uploaded_ssts[0].imm_ids()); + assert_eq!(&imm_ids2_4_1, synced_data.uploaded_ssts[1].imm_ids()); + assert_eq!(&imm_ids1_4, synced_data.uploaded_ssts[2].imm_ids()); + } + } +} diff --git a/src/storage/src/hummock/event_handler/uploader/test_utils.rs b/src/storage/src/hummock/event_handler/uploader/test_utils.rs new file mode 100644 index 000000000000..2fa574c72fc2 --- /dev/null +++ b/src/storage/src/hummock/event_handler/uploader/test_utils.rs @@ -0,0 +1,346 @@ +// Copyright 2024 RisingWave Labs +// +// Licensed under the Apache License, Version 2.0 (the 
"License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#![cfg(test)] + +use std::collections::{HashMap, HashSet, VecDeque}; +use std::future::{poll_fn, Future}; +use std::sync::atomic::AtomicUsize; +use std::sync::atomic::Ordering::Relaxed; +use std::sync::{Arc, LazyLock}; +use std::task::Poll; + +use bytes::Bytes; +use futures::future::BoxFuture; +use futures::FutureExt; +use itertools::Itertools; +use prometheus::core::GenericGauge; +use risingwave_common::catalog::TableId; +use risingwave_common::must_match; +use risingwave_common::util::epoch::{test_epoch, EpochExt}; +use risingwave_hummock_sdk::compaction_group::StaticCompactionGroupId; +use risingwave_hummock_sdk::key::{FullKey, TableKey}; +use risingwave_hummock_sdk::version::HummockVersion; +use risingwave_hummock_sdk::{HummockEpoch, LocalSstableInfo}; +use risingwave_pb::hummock::{KeyRange, SstableInfo, StateTableInfoDelta}; +use spin::Mutex; +use tokio::spawn; +use tokio::sync::mpsc::unbounded_channel; +use tokio::sync::oneshot; +use tokio::task::yield_now; + +use crate::hummock::event_handler::hummock_event_handler::BufferTracker; +use crate::hummock::event_handler::uploader::uploader_imm::UploaderImm; +use crate::hummock::event_handler::uploader::{ + HummockUploader, TableUnsyncData, UploadTaskInfo, UploadTaskOutput, UploadTaskPayload, + UploaderContext, UploaderData, UploaderState, UploadingTask, UploadingTaskId, +}; +use crate::hummock::event_handler::{LocalInstanceId, TEST_LOCAL_INSTANCE_ID}; +use crate::hummock::local_version::pinned_version::PinnedVersion; +use crate::hummock::shared_buffer::shared_buffer_batch::{ + SharedBufferBatch, SharedBufferBatchId, SharedBufferValue, +}; +use crate::hummock::{HummockError, HummockResult, MemoryLimiter}; +use crate::mem_table::{ImmId, ImmutableMemtable}; +use crate::monitor::HummockStateStoreMetrics; +use crate::opts::StorageOpts; +use crate::store::SealCurrentEpochOptions; + +pub(crate) const INITIAL_EPOCH: HummockEpoch = test_epoch(5); +pub(crate) const TEST_TABLE_ID: TableId = TableId { table_id: 233 }; + +pub trait UploadOutputFuture = Future> + Send + 'static; +pub trait UploadFn = + Fn(UploadTaskPayload, UploadTaskInfo) -> Fut + Send + Sync + 'static; + +impl HummockUploader { + pub(super) fn data(&self) -> &UploaderData { + must_match!(&self.state, UploaderState::Working(data) => data) + } + + pub(super) fn table_data(&self) -> &TableUnsyncData { + self.data() + .unsync_data + .table_data + .get(&TEST_TABLE_ID) + .expect("should exist") + } + + pub(super) fn test_max_syncing_epoch(&self) -> HummockEpoch { + self.table_data().max_sync_epoch().unwrap() + } + + pub(super) fn test_max_synced_epoch(&self) -> HummockEpoch { + self.table_data().max_synced_epoch.unwrap() + } +} + +pub(super) fn test_hummock_version(epoch: HummockEpoch) -> HummockVersion { + let mut version = HummockVersion::default(); + version.id = epoch; + version.max_committed_epoch = epoch; + version.state_table_info.apply_delta( + &HashMap::from_iter([( + TEST_TABLE_ID, + StateTableInfoDelta { + committed_epoch: epoch, + safe_epoch: epoch, + compaction_group_id: 
StaticCompactionGroupId::StateDefault as _, + }, + )]), + &HashSet::new(), + ); + version +} + +pub(super) fn initial_pinned_version() -> PinnedVersion { + PinnedVersion::new(test_hummock_version(INITIAL_EPOCH), unbounded_channel().0) +} + +pub(super) fn dummy_table_key() -> Vec { + vec![b't', b'e', b's', b't'] +} + +pub(super) async fn gen_imm_with_limiter( + epoch: HummockEpoch, + limiter: Option<&MemoryLimiter>, +) -> ImmutableMemtable { + gen_imm_inner(TEST_TABLE_ID, epoch, 0, limiter).await +} + +pub(super) async fn gen_imm_inner( + table_id: TableId, + epoch: HummockEpoch, + spill_offset: u16, + limiter: Option<&MemoryLimiter>, +) -> ImmutableMemtable { + let sorted_items = vec![( + TableKey(Bytes::from(dummy_table_key())), + SharedBufferValue::Delete, + )]; + let size = SharedBufferBatch::measure_batch_size(&sorted_items, None).0; + let tracker = match limiter { + Some(limiter) => Some(limiter.require_memory(size as u64).await), + None => None, + }; + SharedBufferBatch::build_shared_buffer_batch( + epoch, + spill_offset, + sorted_items, + None, + size, + table_id, + tracker, + ) +} + +pub(crate) async fn gen_imm(epoch: HummockEpoch) -> ImmutableMemtable { + gen_imm_with_limiter(epoch, None).await +} + +pub(super) fn gen_sstable_info( + start_epoch: HummockEpoch, + end_epoch: HummockEpoch, +) -> Vec { + let start_full_key = FullKey::new(TEST_TABLE_ID, TableKey(dummy_table_key()), start_epoch); + let end_full_key = FullKey::new(TEST_TABLE_ID, TableKey(dummy_table_key()), end_epoch); + let gen_sst_object_id = (start_epoch << 8) + end_epoch; + vec![LocalSstableInfo::for_test(SstableInfo { + object_id: gen_sst_object_id, + sst_id: gen_sst_object_id, + key_range: Some(KeyRange { + left: start_full_key.encode(), + right: end_full_key.encode(), + right_exclusive: true, + }), + table_ids: vec![TEST_TABLE_ID.table_id], + ..Default::default() + })] +} + +pub(super) fn test_uploader_context(upload_fn: F) -> UploaderContext +where + Fut: UploadOutputFuture, + F: UploadFn, +{ + let config = StorageOpts::default(); + UploaderContext::new( + initial_pinned_version(), + Arc::new(move |payload, task_info| spawn(upload_fn(payload, task_info))), + BufferTracker::for_test(), + &config, + Arc::new(HummockStateStoreMetrics::unused()), + ) +} + +pub(super) fn test_uploader(upload_fn: F) -> HummockUploader +where + Fut: UploadOutputFuture, + F: UploadFn, +{ + let config = StorageOpts { + ..Default::default() + }; + HummockUploader::new( + Arc::new(HummockStateStoreMetrics::unused()), + initial_pinned_version(), + Arc::new(move |payload, task_info| spawn(upload_fn(payload, task_info))), + BufferTracker::for_test(), + &config, + ) +} + +pub(super) fn dummy_success_upload_output() -> UploadTaskOutput { + UploadTaskOutput { + new_value_ssts: gen_sstable_info(INITIAL_EPOCH, INITIAL_EPOCH), + old_value_ssts: vec![], + wait_poll_timer: None, + } +} + +#[allow(clippy::unused_async)] +pub(super) async fn dummy_success_upload_future( + _: UploadTaskPayload, + _: UploadTaskInfo, +) -> HummockResult { + Ok(dummy_success_upload_output()) +} + +#[allow(clippy::unused_async)] +pub(super) async fn dummy_fail_upload_future( + _: UploadTaskPayload, + _: UploadTaskInfo, +) -> HummockResult { + Err(HummockError::other("failed")) +} + +impl UploadingTask { + pub(super) fn from_vec(imms: Vec, context: &UploaderContext) -> Self { + let input = HashMap::from_iter([( + TEST_LOCAL_INSTANCE_ID, + imms.into_iter().map(UploaderImm::for_test).collect_vec(), + )]); + static NEXT_TASK_ID: LazyLock = LazyLock::new(|| AtomicUsize::new(0)); 
+ Self::new( + UploadingTaskId(NEXT_TASK_ID.fetch_add(1, Relaxed)), + input, + context, + ) + } +} + +pub(super) fn get_imm_ids<'a>( + imms: impl IntoIterator, +) -> HashMap> { + HashMap::from_iter([( + TEST_LOCAL_INSTANCE_ID, + imms.into_iter().map(|imm| imm.batch_id()).collect_vec(), + )]) +} + +impl HummockUploader { + pub(super) fn local_seal_epoch_for_test( + &mut self, + instance_id: LocalInstanceId, + epoch: HummockEpoch, + ) { + self.local_seal_epoch( + instance_id, + epoch.next_epoch(), + SealCurrentEpochOptions::for_test(), + ); + } + + pub(super) fn start_epochs_for_test(&mut self, epochs: impl IntoIterator) { + for epoch in epochs { + self.start_epoch(epoch, HashSet::from_iter([TEST_TABLE_ID])); + } + } +} + +pub(crate) fn prepare_uploader_order_test( + config: &StorageOpts, + skip_schedule: bool, +) -> ( + BufferTracker, + HummockUploader, + impl Fn(HashMap>) -> (BoxFuture<'static, ()>, oneshot::Sender<()>), +) { + let gauge = GenericGauge::new("test", "test").unwrap(); + let buffer_tracker = BufferTracker::from_storage_opts(config, gauge); + // (the started task send the imm ids of payload, the started task wait for finish notify) + #[allow(clippy::type_complexity)] + let task_notifier_holder: Arc< + Mutex, oneshot::Receiver<()>)>>, + > = Arc::new(Mutex::new(VecDeque::new())); + + let new_task_notifier = { + let task_notifier_holder = task_notifier_holder.clone(); + move |imm_ids: HashMap>| { + let (start_tx, start_rx) = oneshot::channel(); + let (finish_tx, finish_rx) = oneshot::channel(); + task_notifier_holder + .lock() + .push_front((start_tx, finish_rx)); + let await_start_future = async move { + let task_info = start_rx.await.unwrap(); + assert_eq!(imm_ids, task_info.imm_ids); + } + .boxed(); + (await_start_future, finish_tx) + } + }; + + let config = StorageOpts::default(); + let uploader = HummockUploader::new( + Arc::new(HummockStateStoreMetrics::unused()), + initial_pinned_version(), + Arc::new({ + move |_, task_info: UploadTaskInfo| { + let task_notifier_holder = task_notifier_holder.clone(); + let task_item = task_notifier_holder.lock().pop_back(); + let start_epoch = *task_info.epochs.last().unwrap(); + let end_epoch = *task_info.epochs.first().unwrap(); + assert!(end_epoch >= start_epoch); + spawn(async move { + let ssts = gen_sstable_info(start_epoch, end_epoch); + if !skip_schedule { + let (start_tx, finish_rx) = task_item.unwrap(); + start_tx.send(task_info).unwrap(); + finish_rx.await.unwrap(); + } + Ok(UploadTaskOutput { + new_value_ssts: ssts, + old_value_ssts: vec![], + wait_poll_timer: None, + }) + }) + } + }), + buffer_tracker.clone(), + &config, + ); + (buffer_tracker, uploader, new_task_notifier) +} + +pub(crate) async fn assert_uploader_pending(uploader: &mut HummockUploader) { + for _ in 0..10 { + yield_now().await; + } + assert!( + poll_fn(|cx| Poll::Ready(uploader.next_uploaded_sst().poll_unpin(cx))) + .await + .is_pending() + ) +} diff --git a/src/stream/src/common/table/test_state_table.rs b/src/stream/src/common/table/test_state_table.rs index 72ffa72479cf..d6e9c9bed5b9 100644 --- a/src/stream/src/common/table/test_state_table.rs +++ b/src/stream/src/common/table/test_state_table.rs @@ -1963,9 +1963,6 @@ async fn test_replicated_state_table_replication() { ])); epoch.inc_for_test(); - test_env - .storage - .start_epoch(epoch.curr, HashSet::from_iter([TEST_TABLE_ID])); test_env .storage .start_epoch(epoch.curr, HashSet::from_iter([TEST_TABLE_ID])); @@ -2029,9 +2026,6 @@ async fn test_replicated_state_table_replication() { 
     replicated_state_table.write_chunk(replicate_chunk);
 
     epoch.inc_for_test();
-    test_env
-        .storage
-        .start_epoch(epoch.curr, HashSet::from_iter([TEST_TABLE_ID]));
     test_env
         .storage
         .start_epoch(epoch.curr, HashSet::from_iter([TEST_TABLE_ID]));
diff --git a/src/stream/src/common/table/test_storage_table.rs b/src/stream/src/common/table/test_storage_table.rs
index 1f130330e3be..1eb552271dce 100644
--- a/src/stream/src/common/table/test_storage_table.rs
+++ b/src/stream/src/common/table/test_storage_table.rs
@@ -583,9 +583,6 @@ async fn test_batch_scan_chunk_with_value_indices() {
         .collect_vec();
 
     epoch.inc_for_test();
-    test_env
-        .storage
-        .start_epoch(epoch.curr, HashSet::from_iter([TEST_TABLE_ID]));
     test_env
         .storage
         .start_epoch(epoch.curr, HashSet::from_iter([TEST_TABLE_ID]));

From d149d99316dbb38b9f8128a07aa9b6ccf734ebbd Mon Sep 17 00:00:00 2001
From: zwang28 <70626450+zwang28@users.noreply.github.com>
Date: Mon, 15 Jul 2024 13:25:26 +0800
Subject: [PATCH 07/70] fix(test): adapt to ctl change (#17682)

---
 src/storage/backup/integration_tests/common.sh | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/storage/backup/integration_tests/common.sh b/src/storage/backup/integration_tests/common.sh
index d28ab1b9a982..d72b47686953 100644
--- a/src/storage/backup/integration_tests/common.sh
+++ b/src/storage/backup/integration_tests/common.sh
@@ -111,7 +111,7 @@ function backup() {
 function delete_snapshot() {
   local snapshot_id
   snapshot_id=$1
-  ${BACKUP_TEST_RW_ALL_IN_ONE} risectl meta delete-meta-snapshots "${snapshot_id}"
+  ${BACKUP_TEST_RW_ALL_IN_ONE} risectl meta delete-meta-snapshots --snapshot-ids "${snapshot_id}"
 }
 
 function restore() {

From a011ac6775f9bfdb0b5fc4e44c34ff263ef93dee Mon Sep 17 00:00:00 2001
From: Dylan 
Date: Mon, 15 Jul 2024 13:39:14 +0800
Subject: [PATCH 08/70] fix(iceberg): fix iceberg source with rest catalog (#17684)

---
 src/connector/src/sink/iceberg/mod.rs | 16 ++++++++++++++++
 1 file changed, 16 insertions(+)

diff --git a/src/connector/src/sink/iceberg/mod.rs b/src/connector/src/sink/iceberg/mod.rs
index a63f5e49f0f8..4f4af102406c 100644
--- a/src/connector/src/sink/iceberg/mod.rs
+++ b/src/connector/src/sink/iceberg/mod.rs
@@ -514,10 +514,26 @@ impl IcebergConfig {
                 Ok(Arc::new(catalog))
             }
             "rest" => {
+                let mut iceberg_configs = HashMap::new();
+                if let Some(region) = &self.region {
+                    iceberg_configs.insert(S3_REGION.to_string(), region.clone().to_string());
+                }
+                if let Some(endpoint) = &self.endpoint {
+                    iceberg_configs.insert(S3_ENDPOINT.to_string(), endpoint.clone().to_string());
+                }
+                iceberg_configs.insert(
+                    S3_ACCESS_KEY_ID.to_string(),
+                    self.access_key.clone().to_string(),
+                );
+                iceberg_configs.insert(
+                    S3_SECRET_ACCESS_KEY.to_string(),
+                    self.secret_key.clone().to_string(),
+                );
                 let config = iceberg_catalog_rest::RestCatalogConfig::builder()
                     .uri(self.uri.clone().ok_or_else(|| {
                         SinkError::Iceberg(anyhow!("`catalog.uri` must be set in rest catalog"))
                     })?)
+ .props(iceberg_configs) .build(); let catalog = iceberg_catalog_rest::RestCatalog::new(config).await?; Ok(Arc::new(catalog)) From c31be2d65221bb7dc26e1eb75a56fbe78d09d164 Mon Sep 17 00:00:00 2001 From: zwang28 <70626450+zwang28@users.noreply.github.com> Date: Mon, 15 Jul 2024 15:25:27 +0800 Subject: [PATCH 09/70] fix(batch): fix time travel issue (#17686) --- src/batch/src/executor/row_seq_scan.rs | 2 +- src/common/src/util/epoch.rs | 4 ++++ src/meta/src/hummock/manager/time_travel.rs | 2 -- 3 files changed, 5 insertions(+), 3 deletions(-) diff --git a/src/batch/src/executor/row_seq_scan.rs b/src/batch/src/executor/row_seq_scan.rs index 455e81342fec..9e3710a7040b 100644 --- a/src/batch/src/executor/row_seq_scan.rs +++ b/src/batch/src/executor/row_seq_scan.rs @@ -499,7 +499,7 @@ impl RowSeqScanExecutor { pub fn unix_timestamp_sec_to_epoch(ts: i64) -> risingwave_common::util::epoch::Epoch { let ts = ts.checked_add(1).unwrap(); - risingwave_common::util::epoch::Epoch::from_unix_millis( + risingwave_common::util::epoch::Epoch::from_unix_millis_or_earliest( u64::try_from(ts).unwrap().checked_mul(1000).unwrap(), ) } diff --git a/src/common/src/util/epoch.rs b/src/common/src/util/epoch.rs index 56dbdf6c54da..cba00f4c5ddb 100644 --- a/src/common/src/util/epoch.rs +++ b/src/common/src/util/epoch.rs @@ -77,6 +77,10 @@ impl Epoch { Epoch((mi - UNIX_RISINGWAVE_DATE_SEC * 1000) << EPOCH_PHYSICAL_SHIFT_BITS) } + pub fn from_unix_millis_or_earliest(mi: u64) -> Self { + Epoch((mi.saturating_sub(UNIX_RISINGWAVE_DATE_SEC * 1000)) << EPOCH_PHYSICAL_SHIFT_BITS) + } + pub fn physical_now() -> u64 { UNIX_RISINGWAVE_DATE_EPOCH .elapsed() diff --git a/src/meta/src/hummock/manager/time_travel.rs b/src/meta/src/hummock/manager/time_travel.rs index eec78c70fab9..788ab693579c 100644 --- a/src/meta/src/hummock/manager/time_travel.rs +++ b/src/meta/src/hummock/manager/time_travel.rs @@ -111,8 +111,6 @@ impl HummockManager { res.rows_affected ); let earliest_valid_version = hummock_time_travel_version::Entity::find() - .select_only() - .column(hummock_time_travel_version::Column::VersionId) .filter( hummock_time_travel_version::Column::VersionId.lte(version_watermark.version_id), ) From c186a8fc26decb8fc8f37cf4ad9e63bb2afd81ed Mon Sep 17 00:00:00 2001 From: Bohan Zhang Date: Mon, 15 Jul 2024 16:06:55 +0800 Subject: [PATCH 10/70] fix: enable upsert protobuf combination (#17624) Signed-off-by: tabVersion --- .../source_inline/kafka/protobuf/basic.slt | 23 +++++++++++++++++++ e2e_test/source_inline/kafka/protobuf/pb.py | 2 ++ src/frontend/src/handler/create_source.rs | 7 ++++-- 3 files changed, 30 insertions(+), 2 deletions(-) diff --git a/e2e_test/source_inline/kafka/protobuf/basic.slt b/e2e_test/source_inline/kafka/protobuf/basic.slt index 0eae891d04bc..44153949e79e 100644 --- a/e2e_test/source_inline/kafka/protobuf/basic.slt +++ b/e2e_test/source_inline/kafka/protobuf/basic.slt @@ -33,6 +33,22 @@ FORMAT plain ENCODE protobuf( message = 'test.User' ); + +# for upsert protobuf source +# NOTE: the key part is in json format and rw only read it as bytes +statement ok +create table sr_pb_upsert (primary key (rw_key)) +include + key as rw_key +with ( + ${RISEDEV_KAFKA_WITH_OPTIONS_COMMON}, + topic = 'sr_pb_test', + scan.startup.mode = 'earliest') +FORMAT plain ENCODE protobuf( + schema.registry = '${RISEDEV_SCHEMA_REGISTRY_URL}', + message = 'test.User' + ); + # Wait for source sleep 2s @@ -50,9 +66,16 @@ select min(id), max(id), max((sc).file_name) from sr_pb_test; ---- 0 19 source/context_019.proto +query TT +select 
convert_from(min(rw_key), 'UTF-8'), convert_from(max(rw_key), 'UTF-8') from sr_pb_upsert; +---- +{"id": 0} {"id": 9} statement ok drop table sr_pb_test; statement ok drop table sr_pb_test_bk; + +statement ok +drop table sr_pb_upsert; diff --git a/e2e_test/source_inline/kafka/protobuf/pb.py b/e2e_test/source_inline/kafka/protobuf/pb.py index 4cab50f899e5..d78db1b536b9 100644 --- a/e2e_test/source_inline/kafka/protobuf/pb.py +++ b/e2e_test/source_inline/kafka/protobuf/pb.py @@ -1,4 +1,5 @@ import sys +import json import importlib from google.protobuf.source_context_pb2 import SourceContext from confluent_kafka import Producer @@ -55,6 +56,7 @@ def send_to_kafka( producer.produce( topic=topic, partition=0, + key=json.dumps({"id": i}), # RisingWave does not handle key schema, so we use json value=serializer(user, SerializationContext(topic, MessageField.VALUE)), on_delivery=delivery_report, ) diff --git a/src/frontend/src/handler/create_source.rs b/src/frontend/src/handler/create_source.rs index 3d3c32958e31..aac3649d808c 100644 --- a/src/frontend/src/handler/create_source.rs +++ b/src/frontend/src/handler/create_source.rs @@ -798,7 +798,10 @@ pub(crate) async fn bind_source_pk( // For all Upsert formats, we only accept one and only key column as primary key. // Additional KEY columns must be set in this case and must be primary key. - (Format::Upsert, encode @ Encode::Json | encode @ Encode::Avro) => { + ( + Format::Upsert, + encode @ Encode::Json | encode @ Encode::Avro | encode @ Encode::Protobuf, + ) => { if let Some(ref key_column_name) = include_key_column_name && sql_defined_pk { @@ -993,7 +996,7 @@ static CONNECTORS_COMPATIBLE_FORMATS: LazyLock hashmap!( Format::Plain => vec![Encode::Json, Encode::Protobuf, Encode::Avro, Encode::Bytes, Encode::Csv], - Format::Upsert => vec![Encode::Json, Encode::Avro], + Format::Upsert => vec![Encode::Json, Encode::Avro, Encode::Protobuf], Format::Debezium => vec![Encode::Json, Encode::Avro], Format::Maxwell => vec![Encode::Json], Format::Canal => vec![Encode::Json], From 4e724c07c36dc81afa5120ffb0f99378ba82be1c Mon Sep 17 00:00:00 2001 From: Yuhao Su <31772373+yuhao-su@users.noreply.github.com> Date: Mon, 15 Jul 2024 10:53:38 -0500 Subject: [PATCH 11/70] feat(ddl): allow alter table with generated columns (#17652) --- e2e_test/ddl/table/generated_columns.slt.part | 15 ++++ .../source_inline/kafka/avro/alter_table.slt | 81 +++++++++++++++++++ .../src/handler/alter_table_column.rs | 30 ++++--- .../src/handler/alter_table_with_sr.rs | 36 ++++++--- src/frontend/src/handler/create_table.rs | 7 +- 5 files changed, 147 insertions(+), 22 deletions(-) create mode 100644 e2e_test/source_inline/kafka/avro/alter_table.slt diff --git a/e2e_test/ddl/table/generated_columns.slt.part b/e2e_test/ddl/table/generated_columns.slt.part index 2271522a47fd..e1be2f26417e 100644 --- a/e2e_test/ddl/table/generated_columns.slt.part +++ b/e2e_test/ddl/table/generated_columns.slt.part @@ -58,6 +58,21 @@ select * from t2; 1 2 2 3 +statement error +alter table t2 drop column v1; +---- +db error: ERROR: Failed to run the query + +Caused by: + failed to drop column "v1" because it's referenced by a generated column "v2" + + +statement ok +alter table t2 drop column v2; + +statement ok +alter table t2 drop column v1; + statement ok drop table t2; diff --git a/e2e_test/source_inline/kafka/avro/alter_table.slt b/e2e_test/source_inline/kafka/avro/alter_table.slt new file mode 100644 index 000000000000..f33a5dc0d594 --- /dev/null +++ b/e2e_test/source_inline/kafka/avro/alter_table.slt @@ 
-0,0 +1,81 @@ +control substitution on + +# https://github.com/risingwavelabs/risingwave/issues/16486 + +# cleanup +system ok +rpk topic delete 'avro_alter_table_test' || true; \ +(rpk sr subject delete 'avro_alter_table_test-value' && rpk sr subject delete 'avro_alter_table_test-value' --permanent) || true; + +# create topic and sr subject +system ok +rpk topic create 'avro_alter_table_test' + +# create a schema and produce a message +system ok +echo '{"type":"record","name":"Root","fields":[{"name":"bar","type":"int","default":0},{"name":"foo","type":"string"}]}' | jq '{"schema": tojson}' \ +| curl -s -X POST -H 'content-type:application/json' -d @- "${RISEDEV_SCHEMA_REGISTRY_URL}/subjects/avro_alter_table_test-value/versions" + +system ok +echo '{"foo":"ABC", "bar":1}' | rpk topic produce --schema-id=topic avro_alter_table_test + +statement ok +create table t (*, gen_col int as bar + 1) +WITH ( + ${RISEDEV_KAFKA_WITH_OPTIONS_COMMON}, + topic = 'avro_alter_table_test' +) +FORMAT PLAIN ENCODE AVRO ( + schema.registry = '${RISEDEV_SCHEMA_REGISTRY_URL}' +); + +sleep 4s + +query ? +select * from t +---- +1 ABC 2 + +# create a new version of schema that removed field bar +system ok +echo '{"type":"record","name":"Root","fields":[{"name":"foo","type":"string"}]}' | jq '{"schema": tojson}' \ +| curl -s -X POST -H 'content-type:application/json' -d @- "${RISEDEV_SCHEMA_REGISTRY_URL}/subjects/avro_alter_table_test-value/versions" + +# Refresh table schema should fail +statement error +ALTER TABLE t REFRESH SCHEMA; +---- +db error: ERROR: Failed to run the query + +Caused by these errors (recent errors listed first): + 1: failed to refresh schema because some of the columns to drop are referenced by a generated column "gen_col" + 2: fail to bind expression in generated column "gen_col" + 3: Failed to bind expression: bar + 1 + 4: Item not found: Invalid column: bar + + +# TODO: Add support for dropping generated columns for table with schema registry +statement error +ALTER TABLE t DROP COLUMN gen_col; +---- +db error: ERROR: Failed to run the query + +Caused by: + Not supported: alter table with schema registry +HINT: try `ALTER TABLE .. FORMAT .. ENCODE .. (...)` instead + + +# statement ok +# ALTER TABLE t DROP COLUMN gen_col; + +# # Refresh table schema +# statement ok +# ALTER TABLE t REFRESH SCHEMA; + +# query ? 
+# select * from t +# ---- +# ABC + +statement ok +drop table t; diff --git a/src/frontend/src/handler/alter_table_column.rs b/src/frontend/src/handler/alter_table_column.rs index 0ddeb2d1e3d3..0af3ace68b6c 100644 --- a/src/frontend/src/handler/alter_table_column.rs +++ b/src/frontend/src/handler/alter_table_column.rs @@ -17,8 +17,8 @@ use std::sync::Arc; use anyhow::Context; use itertools::Itertools; use pgwire::pg_response::{PgResponse, StatementType}; -use risingwave_common::bail_not_implemented; use risingwave_common::util::column_index_mapping::ColIndexMapping; +use risingwave_common::{bail, bail_not_implemented}; use risingwave_sqlparser::ast::{ AlterTableOperation, ColumnOption, ConnectorSchema, Encode, ObjectName, Statement, }; @@ -30,7 +30,8 @@ use super::util::SourceSchemaCompatExt; use super::{HandlerArgs, RwPgResponse}; use crate::catalog::root_catalog::SchemaPath; use crate::catalog::table_catalog::TableType; -use crate::error::{ErrorCode, Result, RwError}; +use crate::error::{ErrorCode, Result}; +use crate::expr::ExprImpl; use crate::session::SessionImpl; use crate::{Binder, TableCatalog, WithOptions}; @@ -113,13 +114,6 @@ pub async fn handle_alter_table_column( bail_not_implemented!("alter table with incoming sinks"); } - // TODO(yuhao): alter table with generated columns. - if original_catalog.has_generated_column() { - return Err(RwError::from(ErrorCode::BindError( - "Alter a table with generated column has not been implemented.".to_string(), - ))); - } - // Retrieve the original table definition and parse it to AST. let [mut definition]: [_; 1] = Parser::parse_sql(&original_catalog.definition) .context("unable to parse original table definition")? @@ -193,6 +187,24 @@ pub async fn handle_alter_table_column( bail_not_implemented!(issue = 6903, "drop column cascade"); } + // Check if the column to drop is referenced by any generated columns. + for column in original_catalog.columns() { + if let Some(expr) = column.generated_expr() { + let expr = ExprImpl::from_expr_proto(expr)?; + let refs = expr.collect_input_refs(original_catalog.columns().len()); + for idx in refs.ones() { + let refed_column = &original_catalog.columns()[idx]; + if refed_column.name() == column_name.real_value() { + bail!(format!( + "failed to drop column \"{}\" because it's referenced by a generated column \"{}\"", + column_name, + column.name() + )) + } + } + } + } + // Locate the column by name and remove it. let column_name = column_name.real_value(); let removed_column = columns diff --git a/src/frontend/src/handler/alter_table_with_sr.rs b/src/frontend/src/handler/alter_table_with_sr.rs index c1700c36a3c9..d932246759e2 100644 --- a/src/frontend/src/handler/alter_table_with_sr.rs +++ b/src/frontend/src/handler/alter_table_with_sr.rs @@ -12,11 +12,13 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-use anyhow::Context; +use anyhow::{anyhow, Context}; +use fancy_regex::Regex; use pgwire::pg_response::StatementType; use risingwave_common::bail_not_implemented; use risingwave_sqlparser::ast::{ConnectorSchema, ObjectName, Statement}; use risingwave_sqlparser::parser::Parser; +use thiserror_ext::AsReport; use super::alter_source_with_sr::alter_definition_format_encode; use super::alter_table_column::{ @@ -24,7 +26,7 @@ use super::alter_table_column::{ }; use super::util::SourceSchemaCompatExt; use super::{HandlerArgs, RwPgResponse}; -use crate::error::{ErrorCode, Result, RwError}; +use crate::error::{ErrorCode, Result}; use crate::TableCatalog; fn get_connector_schema_from_table(table: &TableCatalog) -> Result> { @@ -49,13 +51,6 @@ pub async fn handle_refresh_schema( bail_not_implemented!("alter table with incoming sinks"); } - // TODO(yuhao): alter table with generated columns. - if original_table.has_generated_column() { - return Err(RwError::from(ErrorCode::BindError( - "Alter a table with generated column has not been implemented.".to_string(), - ))); - } - let connector_schema = { let connector_schema = get_connector_schema_from_table(&original_table)?; if !connector_schema @@ -81,14 +76,31 @@ pub async fn handle_refresh_schema( .try_into() .unwrap(); - replace_table_with_definition( + let result = replace_table_with_definition( &session, table_name, definition, &original_table, Some(connector_schema), ) - .await?; + .await; - Ok(RwPgResponse::empty_result(StatementType::ALTER_TABLE)) + match result { + Ok(_) => Ok(RwPgResponse::empty_result(StatementType::ALTER_TABLE)), + Err(e) => { + let report = e.to_report_string(); + // This is a workaround for reporting errors when columns to drop is referenced by generated column. + // Finding the actual columns to drop requires generating `PbSource` from the sql definition + // and fetching schema from schema registry, which will cause a lot of unnecessary refactor. + // Here we match the error message to yield when failing to bind generated column exprs. 
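+            // The single capture group below grabs the generated column name that
+            // `bind_sql_column_constraints` now quotes in its "fail to bind expression
+            // in generated column \"...\"" context (see the create_table.rs hunk of this
+            // patch), so the hint can point at the offending column.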
+ let re = Regex::new(r#"fail to bind expression in generated column "(.*?)""#).unwrap(); + let captures = re.captures(&report).map_err(anyhow::Error::from)?; + if let Some(gen_col_name) = captures.and_then(|captures| captures.get(1)) { + Err(anyhow!(e).context(format!("failed to refresh schema because some of the columns to drop are referenced by a generated column \"{}\"", + gen_col_name.as_str())).into()) + } else { + Err(e) + } + } + } } diff --git a/src/frontend/src/handler/create_table.rs b/src/frontend/src/handler/create_table.rs index 11d0d0ebd08d..ef8b57f8064d 100644 --- a/src/frontend/src/handler/create_table.rs +++ b/src/frontend/src/handler/create_table.rs @@ -288,7 +288,12 @@ pub fn bind_sql_column_constraints( binder.set_clause(Some(Clause::GeneratedColumn)); let idx = binder .get_column_binding_index(table_name.clone(), &column.name.real_value())?; - let expr_impl = binder.bind_expr(expr)?; + let expr_impl = binder.bind_expr(expr).with_context(|| { + format!( + "fail to bind expression in generated column \"{}\"", + column.name.real_value() + ) + })?; check_generated_column_constraints( &column.name.real_value(), From 96de66440d5f8238d920ac76ce04cbd6247bcb21 Mon Sep 17 00:00:00 2001 From: Yuhao Su <31772373+yuhao-su@users.noreply.github.com> Date: Mon, 15 Jul 2024 11:07:59 -0500 Subject: [PATCH 12/70] feat(secret): introduce secret management (#17456) --- Cargo.lock | 57 ++---- Makefile.toml | 3 + ci/scripts/common.sh | 2 + e2e_test/sink/deltalake_rust_sink.slt | 10 +- e2e_test/sink/iceberg_sink.slt | 18 +- e2e_test/source/cdc/cdc.share_stream.slt | 24 ++- proto/catalog.proto | 5 + proto/meta.proto | 1 + src/batch/src/executor/source.rs | 12 +- src/bench/sink_bench/main.rs | 1 + src/cmd_all/src/standalone.rs | 15 +- src/common/Cargo.toml | 1 + src/common/secret/Cargo.toml | 29 +++ src/common/secret/src/encryption.rs | 80 ++++++++ src/common/secret/src/error.rs | 45 +++++ src/common/secret/src/lib.rs | 24 +++ src/common/secret/src/secret_manager.rs | 180 ++++++++++++++++++ src/common/src/catalog/external_table.rs | 5 +- src/common/src/config.rs | 7 - src/common/src/lib.rs | 5 +- src/compute/src/lib.rs | 9 + src/compute/src/observer/observer_manager.rs | 37 ++-- src/compute/src/server.rs | 7 + src/config/docs.md | 1 - src/config/example.toml | 1 - src/connector/src/error.rs | 1 + src/connector/src/lib.rs | 2 +- .../src/parser/debezium/avro_parser.rs | 4 +- src/connector/src/parser/mod.rs | 44 +++-- src/connector/src/sink/catalog/desc.rs | 8 +- src/connector/src/sink/catalog/mod.rs | 9 +- src/connector/src/sink/mod.rs | 38 +++- src/connector/src/sink/redis.rs | 2 + src/connector/src/source/base.rs | 32 +++- src/connector/src/source/cdc/external/mod.rs | 9 +- src/connector/src/source/cdc/mod.rs | 5 +- src/connector/src/source/reader/desc.rs | 10 +- src/connector/src/source/reader/fs_reader.rs | 7 +- src/connector/src/source/reader/reader.rs | 5 +- src/connector/src/with_options.rs | 67 +++++++ src/frontend/planner_test/src/lib.rs | 5 +- src/frontend/src/catalog/secret_catalog.rs | 4 +- src/frontend/src/catalog/source_catalog.rs | 16 +- .../rw_catalog/rw_iceberg_files.rs | 3 +- .../rw_catalog/rw_iceberg_snapshots.rs | 3 +- src/frontend/src/catalog/view_catalog.rs | 2 +- src/frontend/src/error.rs | 7 + .../src/handler/alter_source_with_sr.rs | 16 +- src/frontend/src/handler/create_secret.rs | 16 +- src/frontend/src/handler/create_sink.rs | 109 ++++++----- src/frontend/src/handler/create_source.rs | 92 +++++---- src/frontend/src/handler/create_table.rs | 30 ++- 
src/frontend/src/handler/create_table_as.rs | 7 + src/frontend/src/handler/create_view.rs | 10 +- src/frontend/src/handler/drop_secret.rs | 6 +- src/frontend/src/handler/explain.rs | 7 +- src/frontend/src/lib.rs | 11 +- src/frontend/src/observer/observer_manager.rs | 54 ++++-- src/frontend/src/optimizer/mod.rs | 5 +- .../optimizer/plan_node/batch_iceberg_scan.rs | 5 +- .../optimizer/plan_node/batch_kafka_scan.rs | 5 +- .../src/optimizer/plan_node/batch_source.rs | 5 +- .../optimizer/plan_node/stream_fs_fetch.rs | 43 +++-- .../src/optimizer/plan_node/stream_sink.rs | 21 +- .../src/optimizer/plan_node/stream_source.rs | 42 ++-- .../optimizer/plan_node/stream_source_scan.rs | 5 +- .../plan_visitor/distributed_dml_visitor.rs | 5 +- src/frontend/src/scheduler/plan_fragmenter.rs | 18 +- src/frontend/src/session.rs | 32 ++++ src/frontend/src/user/user_authentication.rs | 3 +- src/frontend/src/utils/overwrite_options.rs | 6 +- src/frontend/src/utils/with_options.rs | 112 ++++++++--- src/meta/Cargo.toml | 1 - src/meta/node/Cargo.toml | 2 + src/meta/node/src/lib.rs | 24 ++- src/meta/node/src/server.rs | 47 +++-- src/meta/service/Cargo.toml | 1 + src/meta/service/src/cloud_service.rs | 6 +- src/meta/service/src/cluster_service.rs | 3 + src/meta/service/src/notification_service.rs | 76 +++++++- src/meta/src/controller/catalog.rs | 26 ++- src/meta/src/controller/cluster.rs | 5 + src/meta/src/controller/streaming_job.rs | 8 +- src/meta/src/manager/catalog/mod.rs | 67 +++++-- src/meta/src/manager/cluster.rs | 7 +- src/meta/src/manager/env.rs | 7 +- src/meta/src/manager/metadata.rs | 11 +- src/meta/src/manager/streaming_job.rs | 3 +- src/meta/src/rpc/ddl_controller.rs | 77 ++++---- src/meta/src/stream/sink.rs | 2 +- src/meta/src/stream/source_manager.rs | 10 +- src/prost/build.rs | 3 +- src/risedevtool/common.toml | 1 + src/rpc_client/src/meta_client.rs | 6 + src/sqlparser/src/ast/mod.rs | 4 +- src/sqlparser/src/ast/value.rs | 26 ++- src/sqlparser/src/keywords.rs | 1 + src/sqlparser/src/parser.rs | 12 ++ src/stream/src/error.rs | 8 +- src/stream/src/executor/sink.rs | 3 + src/stream/src/from_proto/sink.rs | 11 +- src/stream/src/from_proto/source/fs_fetch.rs | 7 +- .../src/from_proto/source/trad_source.rs | 13 +- src/stream/src/from_proto/source_backfill.rs | 5 +- src/stream/src/from_proto/stream_cdc_scan.rs | 17 +- src/stream/src/task/barrier_manager.rs | 3 +- src/tests/simulation/src/cluster.rs | 6 + 107 files changed, 1521 insertions(+), 515 deletions(-) create mode 100644 src/common/secret/Cargo.toml create mode 100644 src/common/secret/src/encryption.rs create mode 100644 src/common/secret/src/error.rs create mode 100644 src/common/secret/src/lib.rs create mode 100644 src/common/secret/src/secret_manager.rs diff --git a/Cargo.lock b/Cargo.lock index 70f7d40c510e..68a9060f529d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -77,22 +77,6 @@ dependencies = [ "aes", ] -[[package]] -name = "aes-siv" -version = "0.7.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7e08d0cdb774acd1e4dac11478b1a0c0d203134b2aab0ba25eb430de9b18f8b9" -dependencies = [ - "aead", - "aes", - "cipher", - "cmac", - "ctr", - "dbl", - "digest", - "zeroize", -] - [[package]] name = "ahash" version = "0.7.8" @@ -2811,17 +2795,6 @@ dependencies = [ "cc", ] -[[package]] -name = "cmac" -version = "0.7.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8543454e3c3f5126effff9cd44d562af4e31fb8ce1cc0d3dcd8f084515dbc1aa" -dependencies = [ - "cipher", - "dbl", - "digest", -] - 
[[package]] name = "cmake" version = "0.1.50" @@ -3917,15 +3890,6 @@ dependencies = [ "syn 1.0.109", ] -[[package]] -name = "dbl" -version = "0.3.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bd2735a791158376708f9347fe8faba9667589d82427ef3aed6794a8981de3d9" -dependencies = [ - "generic-array", -] - [[package]] name = "debugid" version = "0.8.0" @@ -10891,6 +10855,7 @@ dependencies = [ "risingwave_common_estimate_size", "risingwave_common_metrics", "risingwave_common_proc_macro", + "risingwave_common_secret", "risingwave_error", "risingwave_license", "risingwave_pb", @@ -11004,6 +10969,22 @@ dependencies = [ "syn 1.0.109", ] +[[package]] +name = "risingwave_common_secret" +version = "1.11.0-alpha" +dependencies = [ + "aes-gcm", + "anyhow", + "bincode 1.3.3", + "parking_lot 0.12.1", + "prost 0.12.1", + "risingwave_pb", + "serde", + "thiserror", + "thiserror-ext", + "tracing", +] + [[package]] name = "risingwave_common_service" version = "1.11.0-alpha" @@ -11747,7 +11728,6 @@ dependencies = [ name = "risingwave_meta" version = "1.11.0-alpha" dependencies = [ - "aes-siv", "anyhow", "arc-swap", "assert_matches", @@ -11870,8 +11850,10 @@ version = "1.11.0-alpha" dependencies = [ "anyhow", "clap", + "educe", "either", "futures", + "hex", "itertools 0.12.1", "madsim-etcd-client", "madsim-tokio", @@ -11907,6 +11889,7 @@ dependencies = [ "itertools 0.12.1", "madsim-tokio", "madsim-tonic", + "prost 0.12.1", "rand", "regex", "risingwave_common", diff --git a/Makefile.toml b/Makefile.toml index 3d3e1c7e20fc..aa3d38baf59b 100644 --- a/Makefile.toml +++ b/Makefile.toml @@ -28,6 +28,8 @@ env_scripts = [ set_env ENABLE_TELEMETRY "false" set_env RW_TELEMETRY_TYPE "test" + set_env RW_SECRET_STORE_PRIVATE_KEY_HEX "0123456789abcdef" + set_env RW_TEMP_SECRET_FILE_DIR "${PREFIX_SECRET}" is_sanitizer_enabled = get_env ENABLE_SANITIZER is_hdfs_backend = get_env ENABLE_HDFS @@ -144,6 +146,7 @@ rm -rf "${PREFIX_DATA}" rm -rf "${PREFIX_LOG}" rm -rf "${PREFIX_CONFIG}" rm -rf "${PREFIX_PROFILING}" +rm -rf "${PREFIX_SECRET}" ''' [tasks.reset-rw] diff --git a/ci/scripts/common.sh b/ci/scripts/common.sh index 3b31afeef253..ac64d1a7a89c 100755 --- a/ci/scripts/common.sh +++ b/ci/scripts/common.sh @@ -15,6 +15,8 @@ export MCLI_DOWNLOAD_BIN=https://rw-ci-deps-dist.s3.amazonaws.com/mc export GCLOUD_DOWNLOAD_TGZ=https://rw-ci-deps-dist.s3.amazonaws.com/google-cloud-cli-475.0.0-linux-x86_64.tar.gz export NEXTEST_HIDE_PROGRESS_BAR=true export RW_TELEMETRY_TYPE=test +export RW_SECRET_STORE_PRIVATE_KEY_HEX="0123456789abcdef" + unset LANG if [ -n "${BUILDKITE_COMMIT:-}" ]; then export GIT_SHA=$BUILDKITE_COMMIT diff --git a/e2e_test/sink/deltalake_rust_sink.slt b/e2e_test/sink/deltalake_rust_sink.slt index a658520cdeb4..74dca623a9d0 100644 --- a/e2e_test/sink/deltalake_rust_sink.slt +++ b/e2e_test/sink/deltalake_rust_sink.slt @@ -4,6 +4,11 @@ CREATE TABLE t6 (v1 int primary key, v2 smallint, v3 bigint, v4 real, v5 float, statement ok CREATE MATERIALIZED VIEW mv6 AS SELECT * FROM t6; +statement ok +CREATE SECRET deltalake_s3_secret_key WITH ( + backend = 'meta' +) as 'hummockadmin'; + statement ok create sink s6 as select * from mv6 with ( @@ -12,7 +17,7 @@ with ( force_append_only = 'true', location = 's3a://deltalake/deltalake-test', s3.access.key = 'hummockadmin', - s3.secret.key = 'hummockadmin', + s3.secret.key = secret deltalake_s3_secret_key, s3.endpoint = 'http://127.0.0.1:9301' ); @@ -25,6 +30,9 @@ FLUSH; statement ok DROP SINK s6; +statement ok +DROP SECRET deltalake_s3_secret_key; + 
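+# Note: the DROP SECRET above succeeds only because the sink referencing it was
+# dropped first; dropping a secret that is still referenced by a sink or source
+# is rejected (`Permission denied`).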
statement ok DROP MATERIALIZED VIEW mv6; diff --git a/e2e_test/sink/iceberg_sink.slt b/e2e_test/sink/iceberg_sink.slt index a7f7567075e9..08fa23dd839c 100644 --- a/e2e_test/sink/iceberg_sink.slt +++ b/e2e_test/sink/iceberg_sink.slt @@ -4,6 +4,16 @@ CREATE TABLE t6 (v1 int primary key, v2 bigint, v3 varchar); statement ok CREATE MATERIALIZED VIEW mv6 AS SELECT * FROM t6; +statement ok +CREATE SECRET iceberg_s3_access_key WITH ( + backend = 'meta' +) as 'hummockadmin'; + +statement ok +CREATE SECRET iceberg_s3_secret_key WITH ( + backend = 'meta' +) as 'hummockadmin'; + statement ok CREATE SINK s6 AS select mv6.v1 as v1, mv6.v2 as v2, mv6.v3 as v3 from mv6 WITH ( connector = 'iceberg', @@ -11,8 +21,8 @@ CREATE SINK s6 AS select mv6.v1 as v1, mv6.v2 as v2, mv6.v3 as v3 from mv6 WITH primary_key = 'v1', warehouse.path = 's3a://iceberg', s3.endpoint = 'http://127.0.0.1:9301', - s3.access.key = 'hummockadmin', - s3.secret.key = 'hummockadmin', + s3.access.key = secret iceberg_s3_access_key, + s3.secret.key = secret iceberg_s3_secret_key, s3.region = 'us-east-1', catalog.name = 'demo', catalog.type = 'storage', @@ -25,8 +35,8 @@ CREATE SOURCE iceberg_demo_source WITH ( connector = 'iceberg', warehouse.path = 's3a://iceberg', s3.endpoint = 'http://127.0.0.1:9301', - s3.access.key = 'hummockadmin', - s3.secret.key = 'hummockadmin', + s3.access.key = secret iceberg_s3_access_key, + s3.secret.key = secret iceberg_s3_secret_key, s3.region = 'us-east-1', catalog.name = 'demo', catalog.type = 'storage', diff --git a/e2e_test/source/cdc/cdc.share_stream.slt b/e2e_test/source/cdc/cdc.share_stream.slt index 3dc26d98c628..33c9c16a776c 100644 --- a/e2e_test/source/cdc/cdc.share_stream.slt +++ b/e2e_test/source/cdc/cdc.share_stream.slt @@ -11,6 +11,11 @@ mysql --protocol=tcp -u root mytest < e2e_test/source/cdc/mysql_create.sql system ok mysql --protocol=tcp -u root mytest < e2e_test/source/cdc/mysql_init_data.sql +statement ok +create secret mysql_pwd with ( + backend = 'meta' +) as '${MYSQL_PWD:}'; + # create a cdc source job, which format fixed to `FORMAT PLAIN ENCODE JSON` statement ok create source mysql_mytest with ( @@ -18,7 +23,7 @@ create source mysql_mytest with ( hostname = '${MYSQL_HOST:localhost}', port = '${MYSQL_TCP_PORT:8306}', username = 'rwcdc', - password = '${MYSQL_PWD:}', + password = secret mysql_pwd, database.name = 'mytest', server.id = '5601' ); @@ -48,6 +53,9 @@ from mysql_mytest table 'mytest.products'; # sleep to ensure (default,'Milk','Milk is a white liquid food') is consumed from Debezium message instead of backfill. 
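+# `mysql_pwd` is still referenced by the shared CDC source created above, so the
+# DROP SECRET below is expected to be rejected.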
sleep 10s +statement error Permission denied +drop secret mysql_pwd; + system ok mysql --protocol=tcp -u root mytest -e "INSERT INTO products VALUES (default,'Milk','Milk is a white liquid food'); INSERT INTO orders VALUES (default, '2023-11-28 15:08:22', 'Bob', 10.52, 100, false);" @@ -190,13 +198,23 @@ SELECT c_tinyint, c_smallint, c_mediumint, c_integer, c_bigint, c_decimal, c_flo -128 -32767 -8388608 -2147483647 -9223372036854775807 -10 -10000 -10000 a b 1001-01-01 00:00:00 1998-01-01 00:00:00 1970-01-01 00:00:01+00:00 NULL NULL -8388608 -2147483647 9223372036854775806 -10 -10000 -10000 c d 1001-01-01 NULL 2000-01-01 00:00:00 NULL +statement ok +create secret pg_pwd with ( + backend = 'meta' +) as '${PGPASSWORD:}'; + +statement ok +create secret pg_username with ( + backend = 'meta' +) as '${PGUSER:$USER}'; + statement ok create source pg_source with ( connector = 'postgres-cdc', hostname = '${PGHOST:localhost}', port = '${PGPORT:5432}', - username = '${PGUSER:$USER}', - password = '${PGPASSWORD:}', + username = secret pg_username, + password = secret pg_pwd, database.name = '${PGDATABASE:postgres}', slot.name = 'pg_slot' ); diff --git a/proto/catalog.proto b/proto/catalog.proto index 550a7bdf044b..b18275e32d4c 100644 --- a/proto/catalog.proto +++ b/proto/catalog.proto @@ -463,3 +463,8 @@ message Secret { uint32 owner = 5; uint32 schema_id = 6; } + +message OptionsWithSecret { + map options = 1; + map secret_refs = 2; +} diff --git a/proto/meta.proto b/proto/meta.proto index 328488759018..9ceb2b0143fc 100644 --- a/proto/meta.proto +++ b/proto/meta.proto @@ -316,6 +316,7 @@ message AddWorkerNodeResponse { reserved "system_params"; common.Status status = 1; optional uint32 node_id = 2; + string cluster_id = 4; } message ActivateWorkerNodeRequest { diff --git a/src/batch/src/executor/source.rs b/src/batch/src/executor/source.rs index c71f94ae36b0..51d8da9d14d9 100644 --- a/src/batch/src/executor/source.rs +++ b/src/batch/src/executor/source.rs @@ -27,6 +27,7 @@ use risingwave_connector::source::reader::reader::SourceReader; use risingwave_connector::source::{ ConnectorProperties, SourceColumnDesc, SourceContext, SourceCtrlOpts, SplitImpl, SplitMetaData, }; +use risingwave_connector::WithOptionsSecResolved; use risingwave_pb::batch_plan::plan_node::NodeBody; use super::Executor; @@ -64,12 +65,15 @@ impl BoxedExecutorBuilder for SourceExecutor { )?; // prepare connector source - let source_props = source_node.with_properties.clone(); - let config = - ConnectorProperties::extract(source_props, false).map_err(BatchError::connector)?; + let options_with_secret = WithOptionsSecResolved::new( + source_node.with_properties.clone(), + source_node.secret_refs.clone(), + ); + let config = ConnectorProperties::extract(options_with_secret.clone(), false) + .map_err(BatchError::connector)?; let info = source_node.get_info().unwrap(); - let parser_config = SpecificParserConfig::new(info, &source_node.with_properties)?; + let parser_config = SpecificParserConfig::new(info, &options_with_secret)?; let columns: Vec<_> = source_node .columns diff --git a/src/bench/sink_bench/main.rs b/src/bench/sink_bench/main.rs index 9eea1c94655b..94c0427e5135 100644 --- a/src/bench/sink_bench/main.rs +++ b/src/bench/sink_bench/main.rs @@ -469,6 +469,7 @@ fn mock_from_legacy_type( format, encode: SinkEncode::Json, options: Default::default(), + secret_refs: Default::default(), key_encode: None, })) } else { diff --git a/src/cmd_all/src/standalone.rs b/src/cmd_all/src/standalone.rs index 325f2f8ff395..26d4aefeb56d 100644 
--- a/src/cmd_all/src/standalone.rs +++ b/src/cmd_all/src/standalone.rs @@ -293,17 +293,17 @@ mod test { fn test_parse_opt_args() { // Test parsing into standalone-level opts. let raw_opts = " ---compute-opts=--listen-addr 127.0.0.1:8000 --total-memory-bytes 34359738368 --parallelism 10 ---meta-opts=--advertise-addr 127.0.0.1:9999 --data-directory \"some path with spaces\" --listen-addr 127.0.0.1:8001 --etcd-password 1234 ---frontend-opts=--config-path=src/config/original.toml +--compute-opts=--listen-addr 127.0.0.1:8000 --total-memory-bytes 34359738368 --parallelism 10 --temp-secret-file-dir ./compute/secrets/ +--meta-opts=--advertise-addr 127.0.0.1:9999 --data-directory \"some path with spaces\" --listen-addr 127.0.0.1:8001 --etcd-password 1234 --temp-secret-file-dir ./meta/secrets/ +--frontend-opts=--config-path=src/config/original.toml --temp-secret-file-dir ./frontend/secrets/ --prometheus-listener-addr=127.0.0.1:1234 --config-path=src/config/test.toml "; let actual = StandaloneOpts::parse_from(raw_opts.lines()); let opts = StandaloneOpts { - compute_opts: Some("--listen-addr 127.0.0.1:8000 --total-memory-bytes 34359738368 --parallelism 10".into()), - meta_opts: Some("--advertise-addr 127.0.0.1:9999 --data-directory \"some path with spaces\" --listen-addr 127.0.0.1:8001 --etcd-password 1234".into()), - frontend_opts: Some("--config-path=src/config/original.toml".into()), + compute_opts: Some("--listen-addr 127.0.0.1:8000 --total-memory-bytes 34359738368 --parallelism 10 --temp-secret-file-dir ./compute/secrets/".into()), + meta_opts: Some("--advertise-addr 127.0.0.1:9999 --data-directory \"some path with spaces\" --listen-addr 127.0.0.1:8001 --etcd-password 1234 --temp-secret-file-dir ./meta/secrets/".into()), + frontend_opts: Some("--config-path=src/config/original.toml --temp-secret-file-dir ./frontend/secrets/".into()), compactor_opts: None, prometheus_listener_addr: Some("127.0.0.1:1234".into()), config_path: Some("src/config/test.toml".into()), @@ -354,6 +354,7 @@ mod test { heap_profiling_dir: None, dangerous_max_idle_secs: None, connector_rpc_endpoint: None, + temp_secret_file_dir: "./meta/secrets/", }, ), compute_opts: Some( @@ -377,6 +378,7 @@ mod test { async_stack_trace: None, heap_profiling_dir: None, connector_rpc_endpoint: None, + temp_secret_file_dir: "./compute/secrets/", }, ), frontend_opts: Some( @@ -393,6 +395,7 @@ mod test { config_path: "src/config/test.toml", metrics_level: None, enable_barrier_read: None, + temp_secret_file_dir: "./frontend/secrets/", }, ), compactor_opts: None, diff --git a/src/common/Cargo.toml b/src/common/Cargo.toml index 3ae8fb38fcd5..86e229ddb7b9 100644 --- a/src/common/Cargo.toml +++ b/src/common/Cargo.toml @@ -86,6 +86,7 @@ risingwave-fields-derive = { path = "./fields-derive" } risingwave_common_estimate_size = { workspace = true } risingwave_common_metrics = { path = "./metrics" } risingwave_common_proc_macro = { workspace = true } +risingwave_common_secret = { path = "./secret" } risingwave_error = { workspace = true } risingwave_license = { workspace = true } risingwave_pb = { workspace = true } diff --git a/src/common/secret/Cargo.toml b/src/common/secret/Cargo.toml new file mode 100644 index 000000000000..4b698a737cb5 --- /dev/null +++ b/src/common/secret/Cargo.toml @@ -0,0 +1,29 @@ +[package] +name = "risingwave_common_secret" +version = { workspace = true } +edition = { workspace = true } +homepage = { workspace = true } +keywords = { workspace = true } +license = { workspace = true } +repository = { workspace = true } + 
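+# `aes-gcm` provides the AES-128-GCM primitive used by `SecretEncryption`, and
+# `bincode` serializes its (nonce, ciphertext) pair.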
+[package.metadata.cargo-machete] +ignored = ["workspace-hack"] + +[package.metadata.cargo-udeps.ignore] +normal = ["workspace-hack"] + +[dependencies] +aes-gcm = "0.10" +anyhow = "1" +bincode = "1" +parking_lot = { workspace = true } +prost = { workspace = true } +risingwave_pb = { workspace = true } +serde = { version = "1" } +thiserror = "1" +thiserror-ext = { workspace = true } +tracing = "0.1" + +[lints] +workspace = true diff --git a/src/common/secret/src/encryption.rs b/src/common/secret/src/encryption.rs new file mode 100644 index 000000000000..a6c0253fb1f9 --- /dev/null +++ b/src/common/secret/src/encryption.rs @@ -0,0 +1,80 @@ +// Copyright 2024 RisingWave Labs +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use aes_gcm::aead::generic_array::GenericArray; +use aes_gcm::aead::{Aead, AeadCore, KeyInit, OsRng}; +use aes_gcm::Aes128Gcm; +use serde::{Deserialize, Serialize}; + +use super::{SecretError, SecretResult}; + +#[derive(Deserialize, Serialize)] +pub struct SecretEncryption { + nonce: [u8; 12], + ciphertext: Vec, +} + +impl SecretEncryption { + pub fn encrypt(key: &[u8], plaintext: &[u8]) -> SecretResult { + let encrypt_key = Self::fill_key(key); + let nonce_array = Aes128Gcm::generate_nonce(&mut OsRng); + let cipher = Aes128Gcm::new(encrypt_key.as_slice().into()); + let ciphertext = cipher + .encrypt(&nonce_array, plaintext) + .map_err(|_| SecretError::AesError)?; + Ok(Self { + nonce: nonce_array.into(), + ciphertext, + }) + } + + pub fn decrypt(&self, key: &[u8]) -> SecretResult> { + let decrypt_key = Self::fill_key(key); + let nonce_array = GenericArray::from_slice(&self.nonce); + let cipher = Aes128Gcm::new(decrypt_key.as_slice().into()); + let plaintext = cipher + .decrypt(nonce_array, self.ciphertext.as_slice()) + .map_err(|_| SecretError::AesError)?; + Ok(plaintext) + } + + fn fill_key(key: &[u8]) -> Vec { + let mut k = key[..(std::cmp::min(key.len(), 16))].to_vec(); + k.resize_with(16, || 0); + k + } + + pub fn serialize(&self) -> SecretResult> { + let res = bincode::serialize(&self)?; + Ok(res) + } + + pub fn deserialize(data: &[u8]) -> SecretResult { + let res = bincode::deserialize(data)?; + Ok(res) + } +} + +#[cfg(test)] +mod test { + use super::*; + #[test] + fn test_secret_encryption_decyption() { + let key = b"0123456789abcdef"; + let plaintext = "Hello, world!".as_bytes(); + let secret = SecretEncryption::encrypt(key, plaintext).unwrap(); + let decrypted = secret.decrypt(key).unwrap(); + assert_eq!(plaintext, decrypted.as_slice()); + } +} diff --git a/src/common/secret/src/error.rs b/src/common/secret/src/error.rs new file mode 100644 index 000000000000..6db7b7f927e5 --- /dev/null +++ b/src/common/secret/src/error.rs @@ -0,0 +1,45 @@ +// Copyright 2024 RisingWave Labs +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +pub use anyhow::anyhow; +use thiserror::Error; +use thiserror_ext::Construct; + +use super::SecretId; + +pub type SecretResult = Result; + +#[derive(Error, Debug, Construct)] +pub enum SecretError { + #[error("secret not found: {0}")] + ItemNotFound(SecretId), + + #[error("decode utf8 error: {0}")] + DecodeUtf8Error(#[from] std::string::FromUtf8Error), + + #[error("I/O error: {0}")] + IoError(#[from] std::io::Error), + + #[error("unspecified secret ref type: {0}")] + UnspecifiedRefType(SecretId), + + #[error("fail to encrypt/decrypt secret")] + AesError, + + #[error("ser/de proto message error: {0}")] + ProtoError(#[from] bincode::Error), + + #[error(transparent)] + Internal(#[from] anyhow::Error), +} diff --git a/src/common/secret/src/lib.rs b/src/common/secret/src/lib.rs new file mode 100644 index 000000000000..8ac065e5ea18 --- /dev/null +++ b/src/common/secret/src/lib.rs @@ -0,0 +1,24 @@ +// Copyright 2024 RisingWave Labs +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#![feature(lazy_cell)] + +type SecretId = u32; + +mod secret_manager; +pub use secret_manager::*; +mod encryption; +pub use encryption::*; +mod error; +pub use error::*; diff --git a/src/common/secret/src/secret_manager.rs b/src/common/secret/src/secret_manager.rs new file mode 100644 index 000000000000..7483ac4d2508 --- /dev/null +++ b/src/common/secret/src/secret_manager.rs @@ -0,0 +1,180 @@ +// Copyright 2024 RisingWave Labs +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
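The `SecretEncryption` helper added in `encryption.rs` above wraps AES-128-GCM plus bincode framing. A minimal round-trip sketch (illustrative only, not part of the patch; it assumes the `risingwave_common::secret` re-export added elsewhere in this patch):

    use risingwave_common::secret::{SecretEncryption, SecretResult};

    fn secret_roundtrip() -> SecretResult<()> {
        // Keys shorter than 16 bytes are zero-padded by `fill_key`; longer keys are truncated.
        let key = b"0123456789abcdef";
        let encrypted = SecretEncryption::encrypt(key, b"my-password")?;
        // The bincode-serialized (nonce, ciphertext) pair is what a caller would persist or ship around.
        let bytes = encrypted.serialize()?;
        // Restore and decrypt with the same key.
        let restored = SecretEncryption::deserialize(&bytes)?;
        assert_eq!(restored.decrypt(key)?, b"my-password".to_vec());
        Ok(())
    }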
+ +use std::collections::{BTreeMap, HashMap}; +use std::fs::File; +use std::io::Write; +use std::path::PathBuf; + +use anyhow::{anyhow, Context}; +use parking_lot::RwLock; +use prost::Message; +use risingwave_pb::catalog::PbSecret; +use risingwave_pb::secret::secret_ref::RefAsType; +use risingwave_pb::secret::PbSecretRef; +use thiserror_ext::AsReport; + +use super::error::{SecretError, SecretResult}; +use super::SecretId; + +static INSTANCE: std::sync::OnceLock = std::sync::OnceLock::new(); + +#[derive(Debug)] +pub struct LocalSecretManager { + secrets: RwLock>>, + /// The local directory used to write secrets into file, so that it can be passed into some libararies + secret_file_dir: PathBuf, +} + +impl LocalSecretManager { + /// Initialize the secret manager with the given temp file path, cluster id, and encryption key. + /// # Panics + /// Panics if fail to create the secret file directory. + pub fn init(temp_file_dir: String, cluster_id: String, worker_id: u32) { + // use `get_or_init` to handle concurrent initialization in single node mode. + INSTANCE.get_or_init(|| { + let secret_file_dir = PathBuf::from(temp_file_dir) + .join(cluster_id) + .join(worker_id.to_string()); + std::fs::remove_dir_all(&secret_file_dir).ok(); + std::fs::create_dir_all(&secret_file_dir).unwrap(); + + Self { + secrets: RwLock::new(HashMap::new()), + secret_file_dir, + } + }); + } + + /// Get the global secret manager instance. + /// # Panics + /// Panics if the secret manager is not initialized. + pub fn global() -> &'static LocalSecretManager { + // Initialize the secret manager for unit tests. + #[cfg(debug_assertions)] + LocalSecretManager::init("./tmp".to_string(), "test_cluster".to_string(), 0); + + INSTANCE.get().unwrap() + } + + pub fn add_secret(&self, secret_id: SecretId, secret: Vec) { + let mut secret_guard = self.secrets.write(); + secret_guard.insert(secret_id, secret); + } + + pub fn init_secrets(&self, secrets: Vec) { + let mut secret_guard = self.secrets.write(); + // Reset the secrets + secret_guard.clear(); + // Error should only occurs when running simulation tests when we have multiple nodes + // in 1 process and can fail . + std::fs::remove_dir_all(&self.secret_file_dir) + .inspect_err(|e| { + tracing::error!( + error = %e.as_report(), + path = %self.secret_file_dir.to_string_lossy(), + "Failed to remove secret directory") + }) + .ok(); + std::fs::create_dir_all(&self.secret_file_dir).unwrap(); + for secret in secrets { + secret_guard.insert(secret.id, secret.value); + } + } + + pub fn get_secret(&self, secret_id: SecretId) -> Option> { + let secret_guard = self.secrets.read(); + secret_guard.get(&secret_id).cloned() + } + + pub fn remove_secret(&self, secret_id: SecretId) { + let mut secret_guard = self.secrets.write(); + secret_guard.remove(&secret_id); + self.remove_secret_file_if_exist(&secret_id); + } + + pub fn fill_secrets( + &self, + mut options: BTreeMap, + secret_refs: BTreeMap, + ) -> SecretResult> { + let secret_guard = self.secrets.read(); + for (option_key, secret_ref) in secret_refs { + let secret_id = secret_ref.secret_id; + let pb_secret_bytes = secret_guard + .get(&secret_id) + .ok_or(SecretError::ItemNotFound(secret_id))?; + let secret_value_bytes = Self::get_secret_value(pb_secret_bytes)?; + match secret_ref.ref_as() { + RefAsType::Text => { + // We converted the secret string from sql to bytes using `as_bytes` in frontend. + // So use `from_utf8` here to convert it back to string. 
+ options.insert(option_key, String::from_utf8(secret_value_bytes.clone())?); + } + RefAsType::File => { + let path_str = + self.get_or_init_secret_file(secret_id, secret_value_bytes.clone())?; + options.insert(option_key, path_str); + } + RefAsType::Unspecified => { + return Err(SecretError::UnspecifiedRefType(secret_id)); + } + } + } + Ok(options) + } + + /// Get the secret file for the given secret id and return the path string. If the file does not exist, create it. + /// WARNING: This method should be called only when the secret manager is locked. + fn get_or_init_secret_file( + &self, + secret_id: SecretId, + secret_bytes: Vec, + ) -> SecretResult { + let path = self.secret_file_dir.join(secret_id.to_string()); + if !path.exists() { + let mut file = File::create(&path)?; + file.write_all(&secret_bytes)?; + file.sync_all()?; + } + Ok(path.to_string_lossy().to_string()) + } + + /// WARNING: This method should be called only when the secret manager is locked. + fn remove_secret_file_if_exist(&self, secret_id: &SecretId) { + let path = self.secret_file_dir.join(secret_id.to_string()); + if path.exists() { + std::fs::remove_file(&path) + .inspect_err(|e| { + tracing::error!( + error = %e.as_report(), + path = %path.to_string_lossy(), + "Failed to remove secret file") + }) + .ok(); + } + } + + fn get_secret_value(pb_secret_bytes: &[u8]) -> SecretResult> { + let pb_secret = risingwave_pb::secret::Secret::decode(pb_secret_bytes) + .context("failed to decode secret")?; + let secret_value = match pb_secret.get_secret_backend().unwrap() { + risingwave_pb::secret::secret::SecretBackend::Meta(backend) => backend.value.clone(), + risingwave_pb::secret::secret::SecretBackend::HashicorpVault(_) => { + return Err(anyhow!("hashicorp_vault backend is not implemented yet").into()) + } + }; + Ok(secret_value) + } +} diff --git a/src/common/src/catalog/external_table.rs b/src/common/src/catalog/external_table.rs index 797d789b85fc..adc49cac4da9 100644 --- a/src/common/src/catalog/external_table.rs +++ b/src/common/src/catalog/external_table.rs @@ -15,6 +15,7 @@ use std::collections::{BTreeMap, HashMap}; use risingwave_pb::plan_common::ExternalTableDesc; +use risingwave_pb::secret::PbSecretRef; use super::{ColumnDesc, ColumnId, TableId}; use crate::util::sort_util::ColumnOrder; @@ -42,6 +43,8 @@ pub struct CdcTableDesc { /// properties will be passed into the `StreamScanNode` pub connect_properties: BTreeMap, + /// Secret refs + pub secret_refs: BTreeMap, } impl CdcTableDesc { @@ -65,7 +68,7 @@ impl CdcTableDesc { table_name: self.external_table_name.clone(), stream_key: self.stream_key.iter().map(|k| *k as _).collect(), connect_properties: self.connect_properties.clone(), - secret_refs: Default::default(), + secret_refs: self.secret_refs.clone(), } } diff --git a/src/common/src/config.rs b/src/common/src/config.rs index 5c9de58d898a..8d6bd7f34b38 100644 --- a/src/common/src/config.rs +++ b/src/common/src/config.rs @@ -383,9 +383,6 @@ pub struct MetaConfig { /// Whether compactor should rewrite row to remove dropped column. 
#[serde(default = "default::meta::enable_dropped_column_reclaim")] pub enable_dropped_column_reclaim: bool, - - #[serde(default = "default::meta::secret_store_private_key")] - pub secret_store_private_key: Vec, } #[derive(Copy, Clone, Debug, Default)] @@ -1425,10 +1422,6 @@ pub mod default { pub fn enable_dropped_column_reclaim() -> bool { false } - - pub fn secret_store_private_key() -> Vec { - "demo-secret-private-key".as_bytes().to_vec() - } } pub mod server { diff --git a/src/common/src/lib.rs b/src/common/src/lib.rs index 0681554dbb99..467b313720f2 100644 --- a/src/common/src/lib.rs +++ b/src/common/src/lib.rs @@ -80,7 +80,10 @@ pub use risingwave_common_metrics::{ register_guarded_histogram_vec_with_registry, register_guarded_int_counter_vec_with_registry, register_guarded_int_gauge_vec_with_registry, }; -pub use {risingwave_common_metrics as metrics, risingwave_license as license}; +pub use { + risingwave_common_metrics as metrics, risingwave_common_secret as secret, + risingwave_license as license, +}; pub mod lru; pub mod opts; pub mod range; diff --git a/src/compute/src/lib.rs b/src/compute/src/lib.rs index 7a3d2be65d1d..cee364ceb26d 100644 --- a/src/compute/src/lib.rs +++ b/src/compute/src/lib.rs @@ -136,6 +136,15 @@ pub struct ComputeNodeOpts { #[deprecated = "connector node has been deprecated."] #[clap(long, hide = true, env = "RW_CONNECTOR_RPC_ENDPOINT")] pub connector_rpc_endpoint: Option, + + /// The path of the temp secret file directory. + #[clap( + long, + hide = true, + env = "RW_TEMP_SECRET_FILE_DIR", + default_value = "./secrets" + )] + pub temp_secret_file_dir: String, } impl risingwave_common::opts::Opts for ComputeNodeOpts { diff --git a/src/compute/src/observer/observer_manager.rs b/src/compute/src/observer/observer_manager.rs index 62e2d699668f..c028c1e85161 100644 --- a/src/compute/src/observer/observer_manager.rs +++ b/src/compute/src/observer/observer_manager.rs @@ -12,9 +12,10 @@ // See the License for the specific language governing permissions and // limitations under the License. 
+use risingwave_common::secret::LocalSecretManager; use risingwave_common::system_param::local_manager::LocalSystemParamsManagerRef; use risingwave_common_service::ObserverState; -use risingwave_pb::meta::subscribe_response::Info; +use risingwave_pb::meta::subscribe_response::{Info, Operation}; use risingwave_pb::meta::SubscribeResponse; pub struct ComputeObserverNode { @@ -27,19 +28,33 @@ impl ObserverState for ComputeObserverNode { } fn handle_notification(&mut self, resp: SubscribeResponse) { - let Some(info) = resp.info.as_ref() else { - return; - }; - - match info.to_owned() { - Info::SystemParams(p) => self.system_params_manager.try_set_params(p), - _ => { - panic!("error type notification"); + if let Some(info) = resp.info.as_ref() { + match info.to_owned() { + Info::SystemParams(p) => self.system_params_manager.try_set_params(p), + Info::Secret(s) => match resp.operation() { + Operation::Add => { + LocalSecretManager::global().add_secret(s.id, s.value); + } + Operation::Delete => { + LocalSecretManager::global().remove_secret(s.id); + } + _ => { + panic!("error type notification"); + } + }, + _ => { + panic!("error type notification"); + } } - } + }; } - fn handle_initialization_notification(&mut self, _resp: SubscribeResponse) {} + fn handle_initialization_notification(&mut self, resp: SubscribeResponse) { + let Some(Info::Snapshot(snapshot)) = resp.info else { + unreachable!(); + }; + LocalSecretManager::global().init_secrets(snapshot.secrets); + } } impl ComputeObserverNode { diff --git a/src/compute/src/server.rs b/src/compute/src/server.rs index d7dcbd5146c3..8b1ea60cc9b1 100644 --- a/src/compute/src/server.rs +++ b/src/compute/src/server.rs @@ -29,6 +29,7 @@ use risingwave_common::config::{ }; use risingwave_common::lru::init_global_sequencer_args; use risingwave_common::monitor::{RouterExt, TcpConfig}; +use risingwave_common::secret::LocalSecretManager; use risingwave_common::system_param::local_manager::LocalSystemParamsManager; use risingwave_common::system_param::reader::SystemParamsRead; use risingwave_common::telemetry::manager::TelemetryManager; @@ -215,6 +216,12 @@ pub async fn compute_node_serve( .await .unwrap(); + LocalSecretManager::init( + opts.temp_secret_file_dir, + meta_client.cluster_id().to_string(), + worker_id, + ); + // Initialize observer manager. let system_params_manager = Arc::new(LocalSystemParamsManager::new(system_params.clone())); let compute_observer_node = ComputeObserverNode::new(system_params_manager.clone()); diff --git a/src/config/docs.md b/src/config/docs.md index 5b0789325945..563d9c5e1cbb 100644 --- a/src/config/docs.md +++ b/src/config/docs.md @@ -60,7 +60,6 @@ This page is automatically generated by `./risedev generate-example-config` | periodic_split_compact_group_interval_sec | | 10 | | periodic_tombstone_reclaim_compaction_interval_sec | | 600 | | periodic_ttl_reclaim_compaction_interval_sec | Schedule `ttl_reclaim` compaction for all compaction groups with this interval. | 1800 | -| secret_store_private_key | | [100, 101, 109, 111, 45, 115, 101, 99, 114, 101, 116, 45, 112, 114, 105, 118, 97, 116, 101, 45, 107, 101, 121] | | split_group_size_limit | | 68719476736 | | table_write_throughput_threshold | The threshold of write throughput to trigger a group split. Increase this configuration value to avoid split too many groups with few data write. 
| 16777216 | | unrecognized | | | diff --git a/src/config/example.toml b/src/config/example.toml index e5c4b557d3a2..1623bc114ccd 100644 --- a/src/config/example.toml +++ b/src/config/example.toml @@ -57,7 +57,6 @@ compact_task_table_size_partition_threshold_high = 536870912 event_log_enabled = true event_log_channel_max_size = 10 enable_dropped_column_reclaim = false -secret_store_private_key = [100, 101, 109, 111, 45, 115, 101, 99, 114, 101, 116, 45, 112, 114, 105, 118, 97, 116, 101, 45, 107, 101, 121] [meta.compaction_config] max_bytes_for_level_base = 536870912 diff --git a/src/connector/src/error.rs b/src/connector/src/error.rs index 78a39b735b41..076163469316 100644 --- a/src/connector/src/error.rs +++ b/src/connector/src/error.rs @@ -72,6 +72,7 @@ def_anyhow_newtype! { mongodb::error::Error => "Mongodb error", openssl::error::ErrorStack => "OpenSSL error", + risingwave_common::secret::SecretError => "Secret error", } pub type ConnectorResult = std::result::Result; diff --git a/src/connector/src/lib.rs b/src/connector/src/lib.rs index 8c0ade401cb4..85bb7740ae9f 100644 --- a/src/connector/src/lib.rs +++ b/src/connector/src/lib.rs @@ -58,7 +58,7 @@ pub use paste::paste; pub use risingwave_jni_core::{call_method, call_static_method, jvm_runtime}; mod with_options; -pub use with_options::WithPropertiesExt; +pub use with_options::{WithOptionsSecResolved, WithPropertiesExt}; #[cfg(test)] mod with_options_test; diff --git a/src/connector/src/parser/debezium/avro_parser.rs b/src/connector/src/parser/debezium/avro_parser.rs index 2ddfc77073c8..dbf685d6b620 100644 --- a/src/connector/src/parser/debezium/avro_parser.rs +++ b/src/connector/src/parser/debezium/avro_parser.rs @@ -194,6 +194,7 @@ mod tests { use super::*; use crate::parser::{DebeziumParser, SourceStreamChunkBuilder, SpecificParserConfig}; use crate::source::{SourceColumnDesc, SourceContext}; + use crate::WithOptionsSecResolved; const DEBEZIUM_AVRO_DATA: &[u8] = b"\x00\x00\x00\x00\x06\x00\x02\xd2\x0f\x0a\x53\x61\x6c\x6c\x79\x0c\x54\x68\x6f\x6d\x61\x73\x2a\x73\x61\x6c\x6c\x79\x2e\x74\x68\x6f\x6d\x61\x73\x40\x61\x63\x6d\x65\x2e\x63\x6f\x6d\x16\x32\x2e\x31\x2e\x32\x2e\x46\x69\x6e\x61\x6c\x0a\x6d\x79\x73\x71\x6c\x12\x64\x62\x73\x65\x72\x76\x65\x72\x31\xc0\xb4\xe8\xb7\xc9\x61\x00\x30\x66\x69\x72\x73\x74\x5f\x69\x6e\x5f\x64\x61\x74\x61\x5f\x63\x6f\x6c\x6c\x65\x63\x74\x69\x6f\x6e\x12\x69\x6e\x76\x65\x6e\x74\x6f\x72\x79\x00\x02\x12\x63\x75\x73\x74\x6f\x6d\x65\x72\x73\x00\x00\x20\x6d\x79\x73\x71\x6c\x2d\x62\x69\x6e\x2e\x30\x30\x30\x30\x30\x33\x8c\x06\x00\x00\x00\x02\x72\x02\x92\xc3\xe8\xb7\xc9\x61\x00"; @@ -399,7 +400,8 @@ mod tests { row_encode: PbEncodeType::Avro.into(), ..Default::default() }; - let parser_config = SpecificParserConfig::new(&info, &props)?; + let parser_config = + SpecificParserConfig::new(&info, &WithOptionsSecResolved::without_secrets(props))?; let config = DebeziumAvroParserConfig::new(parser_config.clone().encoding_config).await?; let columns = config .map_to_columns()? diff --git a/src/connector/src/parser/mod.rs b/src/connector/src/parser/mod.rs index e88998f00596..fcf945184458 100644 --- a/src/connector/src/parser/mod.rs +++ b/src/connector/src/parser/mod.rs @@ -12,7 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-use std::collections::BTreeMap; use std::fmt::Debug; use std::sync::LazyLock; @@ -31,6 +30,7 @@ use risingwave_common::bail; use risingwave_common::catalog::{KAFKA_TIMESTAMP_COLUMN_NAME, TABLE_NAME_COLUMN_NAME}; use risingwave_common::log::LogSuppresser; use risingwave_common::metrics::GLOBAL_ERROR_METRICS; +use risingwave_common::secret::LocalSecretManager; use risingwave_common::types::{Datum, DatumCow, DatumRef, ScalarRefImpl}; use risingwave_common::util::iter_util::ZipEqFast; use risingwave_common::util::tracing::InstrumentStream; @@ -67,6 +67,7 @@ use crate::source::{ extract_source_struct, BoxSourceStream, ChunkSourceStream, SourceColumnDesc, SourceColumnType, SourceContext, SourceContextRef, SourceEncode, SourceFormat, SourceMessage, SourceMeta, }; +use crate::with_options::WithOptionsSecResolved; pub mod additional_columns; mod avro; @@ -1177,9 +1178,15 @@ impl SpecificParserConfig { // The validity of (format, encode) is ensured by `extract_format_encode` pub fn new( info: &StreamSourceInfo, - with_properties: &BTreeMap, + with_properties: &WithOptionsSecResolved, ) -> ConnectorResult { - let source_struct = extract_source_struct(info)?; + let info = info.clone(); + let source_struct = extract_source_struct(&info)?; + let format_encode_options_with_secret = LocalSecretManager::global() + .fill_secrets(info.format_encode_options, info.format_encode_secret_refs)?; + let (options, secret_refs) = with_properties.clone().into_parts(); + let options_with_secret = + LocalSecretManager::global().fill_secrets(options.clone(), secret_refs.clone())?; let format = source_struct.format; let encode = source_struct.encode; // this transformation is needed since there may be config for the protocol @@ -1188,7 +1195,7 @@ impl SpecificParserConfig { SourceFormat::Native => ProtocolProperties::Native, SourceFormat::None => ProtocolProperties::None, SourceFormat::Debezium => { - let debezium_props = DebeziumProps::from(&info.format_encode_options); + let debezium_props = DebeziumProps::from(&format_encode_options_with_secret); ProtocolProperties::Debezium(debezium_props) } SourceFormat::DebeziumMongo => ProtocolProperties::DebeziumMongo, @@ -1214,31 +1221,31 @@ impl SpecificParserConfig { Some(info.proto_message_name.clone()) }, key_record_name: info.key_message_name.clone(), - map_handling: MapHandling::from_options(&info.format_encode_options)?, + map_handling: MapHandling::from_options(&format_encode_options_with_secret)?, ..Default::default() }; if format == SourceFormat::Upsert { config.enable_upsert = true; } config.schema_location = if let Some(schema_arn) = - info.format_encode_options.get(AWS_GLUE_SCHEMA_ARN_KEY) + format_encode_options_with_secret.get(AWS_GLUE_SCHEMA_ARN_KEY) { SchemaLocation::Glue { schema_arn: schema_arn.clone(), aws_auth_props: serde_json::from_value::( - serde_json::to_value(info.format_encode_options.clone()).unwrap(), + serde_json::to_value(format_encode_options_with_secret.clone()) + .unwrap(), ) .map_err(|e| anyhow::anyhow!(e))?, // The option `mock_config` is not public and we can break compatibility. 
- mock_config: info - .format_encode_options + mock_config: format_encode_options_with_secret .get("aws.glue.mock_config") .cloned(), } } else if info.use_schema_registry { SchemaLocation::Confluent { urls: info.row_schema_location.clone(), - client_config: SchemaRegistryAuth::from(&info.format_encode_options), + client_config: SchemaRegistryAuth::from(&format_encode_options_with_secret), name_strategy: PbSchemaRegistryNameStrategy::try_from(info.name_strategy) .unwrap(), topic: get_kafka_topic(with_properties)?.clone(), @@ -1248,7 +1255,8 @@ impl SpecificParserConfig { url: info.row_schema_location.clone(), aws_auth_props: Some( serde_json::from_value::( - serde_json::to_value(info.format_encode_options.clone()).unwrap(), + serde_json::to_value(format_encode_options_with_secret.clone()) + .unwrap(), ) .map_err(|e| anyhow::anyhow!(e))?, ), @@ -1274,12 +1282,16 @@ impl SpecificParserConfig { config.enable_upsert = true; } if info.use_schema_registry { - config.topic.clone_from(get_kafka_topic(with_properties)?); - config.client_config = SchemaRegistryAuth::from(&info.format_encode_options); + config + .topic + .clone_from(get_kafka_topic(&options_with_secret)?); + config.client_config = + SchemaRegistryAuth::from(&format_encode_options_with_secret); } else { config.aws_auth_props = Some( serde_json::from_value::( - serde_json::to_value(info.format_encode_options.clone()).unwrap(), + serde_json::to_value(format_encode_options_with_secret.clone()) + .unwrap(), ) .map_err(|e| anyhow::anyhow!(e))?, ); @@ -1296,7 +1308,7 @@ impl SpecificParserConfig { key_record_name: info.key_message_name.clone(), schema_location: SchemaLocation::Confluent { urls: info.row_schema_location.clone(), - client_config: SchemaRegistryAuth::from(&info.format_encode_options), + client_config: SchemaRegistryAuth::from(&format_encode_options_with_secret), name_strategy: PbSchemaRegistryNameStrategy::try_from(info.name_strategy) .unwrap(), topic: get_kafka_topic(with_properties).unwrap().clone(), @@ -1314,7 +1326,7 @@ impl SpecificParserConfig { ) => EncodingProperties::Json(JsonProperties { use_schema_registry: info.use_schema_registry, timestamptz_handling: TimestamptzHandling::from_options( - &info.format_encode_options, + &format_encode_options_with_secret, )?, }), (SourceFormat::DebeziumMongo, SourceEncode::Json) => { diff --git a/src/connector/src/sink/catalog/desc.rs b/src/connector/src/sink/catalog/desc.rs index 64e618bdd933..614f5d09516e 100644 --- a/src/connector/src/sink/catalog/desc.rs +++ b/src/connector/src/sink/catalog/desc.rs @@ -51,6 +51,9 @@ pub struct SinkDesc { /// The properties of the sink. pub properties: BTreeMap, + /// Secret ref + pub secret_refs: BTreeMap, + // The append-only behavior of the physical sink connector. Frontend will determine `sink_type` // based on both its own derivation on the append-only attribute and other user-specified // options in `properties`. 
@@ -84,7 +87,6 @@ impl SinkDesc { owner: UserId, connection_id: Option, dependent_relations: Vec, - secret_ref: BTreeMap, ) -> SinkCatalog { SinkCatalog { id: self.id, @@ -99,7 +101,7 @@ impl SinkDesc { owner, dependent_relations, properties: self.properties, - secret_refs: secret_ref, + secret_refs: self.secret_refs, sink_type: self.sink_type, format_desc: self.format_desc, connection_id, @@ -134,7 +136,7 @@ impl SinkDesc { sink_from_name: self.sink_from_name.clone(), target_table: self.target_table.map(|table_id| table_id.table_id()), extra_partition_col_idx: self.extra_partition_col_idx.map(|idx| idx as u64), - secret_refs: Default::default(), + secret_refs: self.secret_refs.clone(), } } } diff --git a/src/connector/src/sink/catalog/mod.rs b/src/connector/src/sink/catalog/mod.rs index 7638c197f234..335f39e60cbc 100644 --- a/src/connector/src/sink/catalog/mod.rs +++ b/src/connector/src/sink/catalog/mod.rs @@ -120,7 +120,7 @@ pub struct SinkFormatDesc { pub format: SinkFormat, pub encode: SinkEncode, pub options: BTreeMap, - + pub secret_refs: BTreeMap, pub key_encode: Option, } @@ -170,6 +170,7 @@ impl SinkFormatDesc { format, encode, options: Default::default(), + secret_refs: Default::default(), key_encode: None, })) } @@ -203,7 +204,7 @@ impl SinkFormatDesc { encode: encode.into(), options, key_encode, - secret_refs: Default::default(), + secret_refs: self.secret_refs.clone(), } } } @@ -266,13 +267,13 @@ impl TryFrom for SinkFormatDesc { ))) } }; - let options = value.options.into_iter().collect(); Ok(Self { format, encode, - options, + options: value.options, key_encode, + secret_refs: value.secret_refs, }) } } diff --git a/src/connector/src/sink/mod.rs b/src/connector/src/sink/mod.rs index a1a993803568..7af7a13183aa 100644 --- a/src/connector/src/sink/mod.rs +++ b/src/connector/src/sink/mod.rs @@ -58,6 +58,7 @@ use risingwave_common::catalog::{ColumnDesc, Field, Schema}; use risingwave_common::metrics::{ LabelGuardedHistogram, LabelGuardedIntCounter, LabelGuardedIntGauge, }; +use risingwave_common::secret::{LocalSecretManager, SecretError}; use risingwave_common::session_config::sink_decouple::SinkDecouple; use risingwave_pb::catalog::PbSinkType; use risingwave_pb::connector_service::{PbSinkParam, SinkMetadata, TableSchema}; @@ -230,25 +231,42 @@ impl SinkParam { fields: self.columns.iter().map(Field::from).collect(), } } -} -impl From for SinkParam { - fn from(sink_catalog: SinkCatalog) -> Self { + // `SinkParams` should only be used when there is a secret context. + // FIXME: Use a new type for `SinkFormatDesc` with properties contain filled secrets. + pub fn fill_secret_for_format_desc( + format_desc: Option, + ) -> Result> { + match format_desc { + Some(mut format_desc) => { + format_desc.options = LocalSecretManager::global() + .fill_secrets(format_desc.options, format_desc.secret_refs.clone())?; + Ok(Some(format_desc)) + } + None => Ok(None), + } + } + + /// Try to convert a `SinkCatalog` to a `SinkParam` and fill the secrets to properties. 
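+    /// This replaces the previous infallible `From<SinkCatalog>` impl: resolving the
+    /// catalog's secret references through `LocalSecretManager` can fail, so the
+    /// conversion is now fallible.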
+ pub fn try_from_sink_catalog(sink_catalog: SinkCatalog) -> Result { let columns = sink_catalog .visible_columns() .map(|col| col.column_desc.clone()) .collect(); - Self { + let properties_with_secret = LocalSecretManager::global() + .fill_secrets(sink_catalog.properties, sink_catalog.secret_refs)?; + let format_desc_with_secret = Self::fill_secret_for_format_desc(sink_catalog.format_desc)?; + Ok(Self { sink_id: sink_catalog.id, sink_name: sink_catalog.name, - properties: sink_catalog.properties, + properties: properties_with_secret, columns, downstream_pk: sink_catalog.downstream_pk, sink_type: sink_catalog.sink_type, - format_desc: sink_catalog.format_desc, + format_desc: format_desc_with_secret, db_name: sink_catalog.db_name, sink_from_name: sink_catalog.sink_from_name, - } + }) } } @@ -597,6 +615,12 @@ pub enum SinkError { #[backtrace] ConnectorError, ), + #[error("Secret error: {0}")] + Secret( + #[from] + #[backtrace] + SecretError, + ), #[error("Mongodb error: {0}")] Mongodb( #[source] diff --git a/src/connector/src/sink/redis.rs b/src/connector/src/sink/redis.rs index 9d6a33d5131a..ad55d8990474 100644 --- a/src/connector/src/sink/redis.rs +++ b/src/connector/src/sink/redis.rs @@ -410,6 +410,7 @@ mod test { format: SinkFormat::AppendOnly, encode: SinkEncode::Json, options: BTreeMap::default(), + secret_refs: BTreeMap::default(), key_encode: None, }; @@ -487,6 +488,7 @@ mod test { format: SinkFormat::AppendOnly, encode: SinkEncode::Template, options: btree_map, + secret_refs: Default::default(), key_encode: None, }; diff --git a/src/connector/src/source/base.rs b/src/connector/src/source/base.rs index 707678b66599..4c44f9610bd1 100644 --- a/src/connector/src/source/base.rs +++ b/src/connector/src/source/base.rs @@ -26,6 +26,7 @@ use itertools::Itertools; use risingwave_common::array::StreamChunk; use risingwave_common::bail; use risingwave_common::catalog::TableId; +use risingwave_common::secret::LocalSecretManager; use risingwave_common::types::{JsonbVal, Scalar}; use risingwave_pb::catalog::{PbSource, PbStreamSourceInfo}; use risingwave_pb::plan_common::ExternalTableDesc; @@ -48,7 +49,7 @@ use crate::source::SplitImpl::{CitusCdc, MongodbCdc, MysqlCdc, PostgresCdc}; use crate::with_options::WithOptions; use crate::{ dispatch_source_prop, dispatch_split_impl, for_all_sources, impl_connector_properties, - impl_split, match_source_name_str, + impl_split, match_source_name_str, WithOptionsSecResolved, }; const SPLIT_TYPE_FIELD: &str = "split_type"; @@ -387,16 +388,19 @@ impl ConnectorProperties { /// `deny_unknown_fields`: Since `WITH` options are persisted in meta, we do not deny unknown fields when restoring from /// existing data to avoid breaking backwards compatibility. We only deny unknown fields when creating new sources. 
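+    /// `with_properties`: the `WITH` options together with any resolved secret
+    /// references; their plaintext values are filled in via
+    /// `LocalSecretManager::fill_secrets` before connector-specific parsing.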
pub fn extract( - mut with_properties: BTreeMap, + with_properties: WithOptionsSecResolved, deny_unknown_fields: bool, ) -> Result { - let connector = with_properties + let (options, secret_refs) = with_properties.into_parts(); + let mut options_with_secret = + LocalSecretManager::global().fill_secrets(options, secret_refs)?; + let connector = options_with_secret .remove(UPSTREAM_SOURCE_KEY) .ok_or_else(|| anyhow!("Must specify 'connector' in WITH clause"))?; match_source_name_str!( connector.to_lowercase().as_str(), PropType, - PropType::try_from_btreemap(with_properties, deny_unknown_fields) + PropType::try_from_btreemap(options_with_secret, deny_unknown_fields) .map(ConnectorProperties::from), |other| bail!("connector '{}' is not supported", other) ) @@ -690,7 +694,9 @@ mod tests { "nexmark.split.num" => "1", )); - let props = ConnectorProperties::extract(props, true).unwrap(); + let props = + ConnectorProperties::extract(WithOptionsSecResolved::without_secrets(props), true) + .unwrap(); if let ConnectorProperties::Nexmark(props) = props { assert_eq!(props.table_type, Some(EventType::Person)); @@ -710,7 +716,9 @@ mod tests { "broker.rewrite.endpoints" => r#"{"b-1:9092":"dns-1", "b-2:9092":"dns-2"}"#, )); - let props = ConnectorProperties::extract(props, true).unwrap(); + let props = + ConnectorProperties::extract(WithOptionsSecResolved::without_secrets(props), true) + .unwrap(); if let ConnectorProperties::Kafka(k) = props { let btreemap = btreemap! { "b-1:9092".to_string() => "dns-1".to_string(), @@ -745,7 +753,11 @@ mod tests { "table.name" => "orders", )); - let conn_props = ConnectorProperties::extract(user_props_mysql, true).unwrap(); + let conn_props = ConnectorProperties::extract( + WithOptionsSecResolved::without_secrets(user_props_mysql), + true, + ) + .unwrap(); if let ConnectorProperties::MysqlCdc(c) = conn_props { assert_eq!(c.properties.get("database.hostname").unwrap(), "127.0.0.1"); assert_eq!(c.properties.get("database.port").unwrap(), "3306"); @@ -757,7 +769,11 @@ mod tests { panic!("extract cdc config failed"); } - let conn_props = ConnectorProperties::extract(user_props_postgres, true).unwrap(); + let conn_props = ConnectorProperties::extract( + WithOptionsSecResolved::without_secrets(user_props_postgres), + true, + ) + .unwrap(); if let ConnectorProperties::PostgresCdc(c) = conn_props { assert_eq!(c.properties.get("database.hostname").unwrap(), "127.0.0.1"); assert_eq!(c.properties.get("database.port").unwrap(), "5432"); diff --git a/src/connector/src/source/cdc/external/mod.rs b/src/connector/src/source/cdc/external/mod.rs index 528e14bd6022..9a7a4dcac247 100644 --- a/src/connector/src/source/cdc/external/mod.rs +++ b/src/connector/src/source/cdc/external/mod.rs @@ -29,6 +29,8 @@ use futures_async_stream::try_stream; use risingwave_common::bail; use risingwave_common::catalog::{ColumnDesc, Schema}; use risingwave_common::row::OwnedRow; +use risingwave_common::secret::LocalSecretManager; +use risingwave_pb::secret::PbSecretRef; use serde_derive::{Deserialize, Serialize}; use crate::error::{ConnectorError, ConnectorResult}; @@ -97,7 +99,7 @@ pub const SCHEMA_NAME_KEY: &str = "schema.name"; pub const DATABASE_NAME_KEY: &str = "database.name"; impl SchemaTableName { - pub fn from_properties(properties: &HashMap) -> Self { + pub fn from_properties(properties: &BTreeMap) -> Self { let table_type = CdcTableType::from_properties(properties); let table_name = properties.get(TABLE_NAME_KEY).cloned().unwrap_or_default(); @@ -215,8 +217,11 @@ pub struct ExternalTableConfig { 
impl ExternalTableConfig { pub fn try_from_btreemap( connect_properties: BTreeMap, + secret_refs: BTreeMap, ) -> ConnectorResult { - let json_value = serde_json::to_value(connect_properties)?; + let options_with_secret = + LocalSecretManager::global().fill_secrets(connect_properties, secret_refs)?; + let json_value = serde_json::to_value(options_with_secret)?; let config = serde_json::from_value::(json_value)?; Ok(config) } diff --git a/src/connector/src/source/cdc/mod.rs b/src/connector/src/source/cdc/mod.rs index 2a4200b5cc4a..c86450b59471 100644 --- a/src/connector/src/source/cdc/mod.rs +++ b/src/connector/src/source/cdc/mod.rs @@ -120,7 +120,7 @@ impl TryFromBTreeMap for CdcProperties { properties: BTreeMap, _deny_unknown_fields: bool, ) -> ConnectorResult { - let is_share_source = properties + let is_share_source: bool = properties .get(CDC_SHARING_MODE_KEY) .is_some_and(|v| v == "true"); Ok(CdcProperties { @@ -180,8 +180,6 @@ where } fn init_from_pb_cdc_table_desc(&mut self, table_desc: &ExternalTableDesc) { - let properties = table_desc.connect_properties.clone(); - let table_schema = TableSchema { columns: table_desc .columns @@ -197,7 +195,6 @@ where pk_indices: table_desc.stream_key.clone(), }; - self.properties = properties; self.table_schema = table_schema; self.is_cdc_source_job = false; self.is_backfill_table = true; diff --git a/src/connector/src/source/reader/desc.rs b/src/connector/src/source/reader/desc.rs index 44d2effe5160..af607d2537ea 100644 --- a/src/connector/src/source/reader/desc.rs +++ b/src/connector/src/source/reader/desc.rs @@ -12,7 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::collections::BTreeMap; use std::sync::Arc; use risingwave_common::bail; @@ -29,6 +28,7 @@ use crate::parser::additional_columns::source_add_partition_offset_cols; use crate::parser::{EncodingProperties, ProtocolProperties, SpecificParserConfig}; use crate::source::monitor::SourceMetrics; use crate::source::{SourceColumnDesc, SourceColumnType, UPSTREAM_SOURCE_KEY}; +use crate::WithOptionsSecResolved; pub const DEFAULT_CONNECTOR_MESSAGE_BUFFER_SIZE: usize = 16; @@ -55,7 +55,7 @@ pub struct SourceDescBuilder { columns: Vec, metrics: Arc, row_id_index: Option, - with_properties: BTreeMap, + with_properties: WithOptionsSecResolved, source_info: PbStreamSourceInfo, connector_message_buffer_size: usize, pk_indices: Vec, @@ -66,7 +66,7 @@ impl SourceDescBuilder { columns: Vec, metrics: Arc, row_id_index: Option, - with_properties: BTreeMap, + with_properties: WithOptionsSecResolved, source_info: PbStreamSourceInfo, connector_message_buffer_size: usize, pk_indices: Vec, @@ -199,11 +199,13 @@ pub mod test_utils { )) }) .collect(); + let options_with_secret = + crate::WithOptionsSecResolved::without_secrets(with_properties.clone()); SourceDescBuilder { columns, metrics: Default::default(), row_id_index, - with_properties, + with_properties: options_with_secret, source_info, connector_message_buffer_size: DEFAULT_CONNECTOR_MESSAGE_BUFFER_SIZE, pk_indices, diff --git a/src/connector/src/source/reader/fs_reader.rs b/src/connector/src/source/reader/fs_reader.rs index 5bd139e70983..ae05bc64ca1a 100644 --- a/src/connector/src/source/reader/fs_reader.rs +++ b/src/connector/src/source/reader/fs_reader.rs @@ -14,7 +14,6 @@ #![deprecated = "will be replaced by new fs source (list + fetch)"] -use std::collections::BTreeMap; use std::sync::Arc; use anyhow::Context; @@ -22,26 +21,26 @@ use futures::stream::pending; use 
futures::StreamExt; use risingwave_common::catalog::ColumnId; -use crate::dispatch_source_prop; use crate::error::ConnectorResult; use crate::parser::{CommonParserConfig, ParserConfig, SpecificParserConfig}; use crate::source::{ create_split_reader, BoxChunkSourceStream, ConnectorProperties, ConnectorState, SourceColumnDesc, SourceContext, SplitReader, }; +use crate::{dispatch_source_prop, WithOptionsSecResolved}; #[derive(Clone, Debug)] pub struct FsSourceReader { pub config: ConnectorProperties, pub columns: Vec, - pub properties: BTreeMap, + pub properties: WithOptionsSecResolved, pub parser_config: SpecificParserConfig, } impl FsSourceReader { #[allow(clippy::too_many_arguments)] pub fn new( - properties: BTreeMap, + properties: WithOptionsSecResolved, columns: Vec, parser_config: SpecificParserConfig, ) -> ConnectorResult { diff --git a/src/connector/src/source/reader/reader.rs b/src/connector/src/source/reader/reader.rs index 02012841c5a4..2258617f84b4 100644 --- a/src/connector/src/source/reader/reader.rs +++ b/src/connector/src/source/reader/reader.rs @@ -12,7 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::collections::BTreeMap; use std::sync::Arc; use anyhow::Context; @@ -26,7 +25,6 @@ use risingwave_common::catalog::ColumnId; use rw_futures_util::select_all; use thiserror_ext::AsReport as _; -use crate::dispatch_source_prop; use crate::error::ConnectorResult; use crate::parser::{CommonParserConfig, ParserConfig, SpecificParserConfig}; use crate::source::filesystem::opendal_source::opendal_enumerator::OpendalEnumerator; @@ -38,6 +36,7 @@ use crate::source::{ create_split_reader, BoxChunkSourceStream, BoxTryStream, Column, ConnectorProperties, ConnectorState, SourceColumnDesc, SourceContext, SplitReader, WaitCheckpointTask, }; +use crate::{dispatch_source_prop, WithOptionsSecResolved}; #[derive(Clone, Debug)] pub struct SourceReader { @@ -49,7 +48,7 @@ pub struct SourceReader { impl SourceReader { pub fn new( - properties: BTreeMap, + properties: WithOptionsSecResolved, columns: Vec, connector_message_buffer_size: usize, parser_config: SpecificParserConfig, diff --git a/src/connector/src/with_options.rs b/src/connector/src/with_options.rs index 154586d77052..bb96a6298f5d 100644 --- a/src/connector/src/with_options.rs +++ b/src/connector/src/with_options.rs @@ -14,6 +14,9 @@ use std::collections::{BTreeMap, HashMap}; +use risingwave_pb::secret::PbSecretRef; + +use crate::sink::catalog::SinkFormatDesc; use crate::source::cdc::external::CdcTableType; use crate::source::iceberg::ICEBERG_CONNECTOR; use crate::source::{ @@ -135,3 +138,67 @@ pub trait WithPropertiesExt: Get + Sized { } impl WithPropertiesExt for T {} + +/// Options or properties extracted from the `WITH` clause of DDLs. +#[derive(Default, Clone, Debug, PartialEq, Eq, Hash)] +pub struct WithOptionsSecResolved { + inner: BTreeMap, + secret_ref: BTreeMap, +} + +impl std::ops::Deref for WithOptionsSecResolved { + type Target = BTreeMap; + + fn deref(&self) -> &Self::Target { + &self.inner + } +} + +impl std::ops::DerefMut for WithOptionsSecResolved { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.inner + } +} + +impl WithOptionsSecResolved { + /// Create a new [`WithOptions`] from a option [`BTreeMap`] and resolved secret ref. + pub fn new(inner: BTreeMap, secret_ref: BTreeMap) -> Self { + Self { inner, secret_ref } + } + + /// Create a new [`WithOptions`] from a [`BTreeMap`]. 
+ pub fn without_secrets(inner: BTreeMap) -> Self { + Self { + inner, + secret_ref: Default::default(), + } + } + + /// Take the value of the option map and secret refs. + pub fn into_parts(self) -> (BTreeMap, BTreeMap) { + (self.inner, self.secret_ref) + } + + pub fn value_eq_ignore_case(&self, key: &str, val: &str) -> bool { + if let Some(inner_val) = self.inner.get(key) { + if inner_val.eq_ignore_ascii_case(val) { + return true; + } + } + false + } +} + +/// For `planner_test` crate so that it does not depend directly on `connector` crate just for `SinkFormatDesc`. +impl TryFrom<&WithOptionsSecResolved> for Option { + type Error = crate::sink::SinkError; + + fn try_from(value: &WithOptionsSecResolved) -> std::result::Result { + let connector = value.get(crate::sink::CONNECTOR_TYPE_KEY); + let r#type = value.get(crate::sink::SINK_TYPE_OPTION); + match (connector, r#type) { + (Some(c), Some(t)) => SinkFormatDesc::from_legacy_type(c, t), + _ => Ok(None), + } + } +} diff --git a/src/frontend/planner_test/src/lib.rs b/src/frontend/planner_test/src/lib.rs index ee4f942f0988..abb291cc37d7 100644 --- a/src/frontend/planner_test/src/lib.rs +++ b/src/frontend/planner_test/src/lib.rs @@ -36,7 +36,7 @@ use risingwave_frontend::session::SessionImpl; use risingwave_frontend::test_utils::{create_proto_file, get_explain_output, LocalFrontend}; use risingwave_frontend::{ build_graph, explain_stream_graph, Binder, Explain, FrontendOpts, OptimizerContext, - OptimizerContextRef, PlanRef, Planner, WithOptions, + OptimizerContextRef, PlanRef, Planner, WithOptionsSecResolved, }; use risingwave_sqlparser::ast::{ AstOption, DropMode, EmitMode, ExplainOptions, ObjectName, Statement, @@ -836,7 +836,8 @@ impl TestCase { let mut options = BTreeMap::new(); options.insert("connector".to_string(), "blackhole".to_string()); options.insert("type".to_string(), "append-only".to_string()); - let options = WithOptions::new(options); + // let options = WithOptionsSecResolved::without_secrets(options); + let options = WithOptionsSecResolved::without_secrets(options); let format_desc = (&options).try_into().unwrap(); match plan_root.gen_sink_plan( sink_name.to_string(), diff --git a/src/frontend/src/catalog/secret_catalog.rs b/src/frontend/src/catalog/secret_catalog.rs index 5e9aaae7dec9..d1f9048baf0e 100644 --- a/src/frontend/src/catalog/secret_catalog.rs +++ b/src/frontend/src/catalog/secret_catalog.rs @@ -19,7 +19,7 @@ use crate::user::UserId; #[derive(Clone, Debug, PartialEq, Eq, Hash)] pub struct SecretCatalog { - pub secret_id: SecretId, + pub id: SecretId, pub name: String, pub database_id: DatabaseId, pub value: Vec, @@ -29,7 +29,7 @@ pub struct SecretCatalog { impl From<&PbSecret> for SecretCatalog { fn from(value: &PbSecret) -> Self { Self { - secret_id: SecretId::new(value.id), + id: SecretId::new(value.id), database_id: value.database_id, owner: value.owner, name: value.name.clone(), diff --git a/src/frontend/src/catalog/source_catalog.rs b/src/frontend/src/catalog/source_catalog.rs index 1ee095a918e5..f060fad72c78 100644 --- a/src/frontend/src/catalog/source_catalog.rs +++ b/src/frontend/src/catalog/source_catalog.rs @@ -12,11 +12,9 @@ // See the License for the specific language governing permissions and // limitations under the License. 
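[Editor's note] For readers following the new secret plumbing: the sketch below is illustrative only and not part of the patch. It shows how the `WithOptionsSecResolved` type introduced in `with_options.rs` above is meant to flow: plaintext options and secret references travel together, and real secret values are substituted only at the point of use via `LocalSecretManager::fill_secrets`. Error handling is simplified to `anyhow` here; the actual call sites in this patch return `ConnectorResult` or the frontend `Result`.

    // Illustrative sketch; mirrors the pattern used by ConnectorProperties::extract
    // and ExternalTableConfig::try_from_btreemap in this patch.
    use std::collections::BTreeMap;

    use risingwave_common::secret::LocalSecretManager;
    use risingwave_connector::WithOptionsSecResolved;
    use risingwave_pb::secret::PbSecretRef;

    fn resolve_connector_options(
        options: BTreeMap<String, String>,
        secret_refs: BTreeMap<String, PbSecretRef>,
    ) -> anyhow::Result<BTreeMap<String, String>> {
        // Keep plaintext options and secret references side by side...
        let with_options = WithOptionsSecResolved::new(options, secret_refs);
        // ...and fill in the actual secret values only when they are needed.
        let (options, secret_refs) = with_options.into_parts();
        Ok(LocalSecretManager::global().fill_secrets(options, secret_refs)?)
    }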
-use std::collections::BTreeMap; - use risingwave_common::catalog::{ColumnCatalog, SourceVersionId}; use risingwave_common::util::epoch::Epoch; -use risingwave_connector::WithPropertiesExt; +use risingwave_connector::{WithOptionsSecResolved, WithPropertiesExt}; use risingwave_pb::catalog::source::OptionalAssociatedTableId; use risingwave_pb::catalog::{PbSource, StreamSourceInfo, WatermarkDesc}; @@ -36,7 +34,7 @@ pub struct SourceCatalog { pub owner: UserId, pub info: StreamSourceInfo, pub row_id_index: Option, - pub with_properties: BTreeMap, + pub with_properties: WithOptionsSecResolved, pub watermark_descs: Vec, pub associated_table_id: Option, pub definition: String, @@ -55,6 +53,7 @@ impl SourceCatalog { } pub fn to_prost(&self, schema_id: SchemaId, database_id: DatabaseId) -> PbSource { + let (with_properties, secret_refs) = self.with_properties.clone().into_parts(); PbSource { id: self.id, schema_id, @@ -63,7 +62,7 @@ impl SourceCatalog { row_id_index: self.row_id_index.map(|idx| idx as _), columns: self.columns.iter().map(|c| c.to_protobuf()).collect(), pk_column_ids: self.pk_col_ids.iter().map(Into::into).collect(), - with_properties: self.with_properties.clone().into_iter().collect(), + with_properties, owner: self.owner, info: Some(self.info.clone()), watermark_descs: self.watermark_descs.clone(), @@ -77,7 +76,7 @@ impl SourceCatalog { version: self.version, created_at_cluster_version: self.created_at_cluster_version.clone(), initialized_at_cluster_version: self.initialized_at_cluster_version.clone(), - secret_refs: Default::default(), + secret_refs, } } @@ -104,7 +103,8 @@ impl From<&PbSource> for SourceCatalog { .into_iter() .map(Into::into) .collect(); - let with_properties = prost.with_properties.clone().into_iter().collect(); + let options_with_secrets = + WithOptionsSecResolved::new(prost.with_properties.clone(), prost.secret_refs.clone()); let columns = prost_columns.into_iter().map(ColumnCatalog::from).collect(); let row_id_index = prost.row_id_index.map(|idx| idx as _); @@ -131,7 +131,7 @@ impl From<&PbSource> for SourceCatalog { owner, info: prost.info.clone().unwrap(), row_id_index, - with_properties, + with_properties: options_with_secrets, watermark_descs, associated_table_id: associated_table_id.map(|x| x.into()), definition: prost.definition.clone(), diff --git a/src/frontend/src/catalog/system_catalog/rw_catalog/rw_iceberg_files.rs b/src/frontend/src/catalog/system_catalog/rw_catalog/rw_iceberg_files.rs index aaeb8aaa064c..b025723857b1 100644 --- a/src/frontend/src/catalog/system_catalog/rw_catalog/rw_iceberg_files.rs +++ b/src/frontend/src/catalog/system_catalog/rw_catalog/rw_iceberg_files.rs @@ -79,8 +79,7 @@ async fn read(reader: &SysCatalogReaderImpl) -> Result> { let mut result = vec![]; for (schema_name, source) in iceberg_sources { - let source_props = source.with_properties.clone(); - let config = ConnectorProperties::extract(source_props, false)?; + let config = ConnectorProperties::extract(source.with_properties.clone(), false)?; if let ConnectorProperties::Iceberg(iceberg_properties) = config { let iceberg_config: IcebergConfig = iceberg_properties.to_iceberg_config(); let table: Table = iceberg_config.load_table_v2().await?; diff --git a/src/frontend/src/catalog/system_catalog/rw_catalog/rw_iceberg_snapshots.rs b/src/frontend/src/catalog/system_catalog/rw_catalog/rw_iceberg_snapshots.rs index d7491cfeb0ca..e2bbcb486b92 100644 --- a/src/frontend/src/catalog/system_catalog/rw_catalog/rw_iceberg_snapshots.rs +++ 
b/src/frontend/src/catalog/system_catalog/rw_catalog/rw_iceberg_snapshots.rs @@ -57,8 +57,7 @@ async fn read(reader: &SysCatalogReaderImpl) -> Result> let mut result = vec![]; for (schema_name, source) in iceberg_sources { - let source_props = source.with_properties.clone(); - let config = ConnectorProperties::extract(source_props, false)?; + let config = ConnectorProperties::extract(source.with_properties.clone(), false)?; if let ConnectorProperties::Iceberg(iceberg_properties) = config { let iceberg_config: IcebergConfig = iceberg_properties.to_iceberg_config(); let table: Table = iceberg_config.load_table_v2().await?; diff --git a/src/frontend/src/catalog/view_catalog.rs b/src/frontend/src/catalog/view_catalog.rs index a884eed3d0e2..331613be9415 100644 --- a/src/frontend/src/catalog/view_catalog.rs +++ b/src/frontend/src/catalog/view_catalog.rs @@ -36,7 +36,7 @@ impl From<&PbView> for ViewCatalog { id: view.id, name: view.name.clone(), owner: view.owner, - properties: WithOptions::new(view.properties.clone()), + properties: WithOptions::new_with_options(view.properties.clone()), sql: view.sql.clone(), columns: view.columns.iter().map(|f| f.into()).collect(), } diff --git a/src/frontend/src/error.rs b/src/frontend/src/error.rs index 93b3e627ae85..3092c9bee91a 100644 --- a/src/frontend/src/error.rs +++ b/src/frontend/src/error.rs @@ -15,6 +15,7 @@ use risingwave_batch::error::BatchError; use risingwave_common::array::ArrayError; use risingwave_common::error::{BoxedError, NoFunction, NotImplemented}; +use risingwave_common::secret::SecretError; use risingwave_common::session_config::SessionConfigError; use risingwave_common::util::value_encoding::error::ValueEncodingError; use risingwave_connector::error::ConnectorError; @@ -164,6 +165,12 @@ pub enum ErrorCode { #[backtrace] SessionConfigError, ), + #[error("Secret error: {0}")] + SecretError( + #[from] + #[backtrace] + SecretError, + ), #[error("{0} has been deprecated, please use {1} instead.")] Deprecated(String, String), } diff --git a/src/frontend/src/handler/alter_source_with_sr.rs b/src/frontend/src/handler/alter_source_with_sr.rs index 840205caeadd..01e09b0c4f5f 100644 --- a/src/frontend/src/handler/alter_source_with_sr.rs +++ b/src/frontend/src/handler/alter_source_with_sr.rs @@ -15,6 +15,7 @@ use std::sync::Arc; use anyhow::Context; +use either::Either; use itertools::Itertools; use pgwire::pg_response::StatementType; use risingwave_common::bail_not_implemented; @@ -37,6 +38,7 @@ use crate::catalog::source_catalog::SourceCatalog; use crate::catalog::{DatabaseId, SchemaId}; use crate::error::{ErrorCode, Result}; use crate::session::SessionImpl; +use crate::utils::resolve_secret_ref_in_with_options; use crate::{Binder, WithOptions}; fn format_type_to_format(from: FormatType) -> Option { @@ -162,7 +164,8 @@ pub async fn refresh_sr_and_get_columns_diff( } let (Some(columns_from_resolve_source), source_info) = - bind_columns_from_source(session, connector_schema, &with_properties).await? + bind_columns_from_source(session, connector_schema, Either::Right(&with_properties)) + .await? else { // Source without schema registry is rejected. 
unreachable!("source without schema registry is rejected") @@ -254,12 +257,21 @@ pub async fn handle_alter_source_with_sr( source.definition = alter_definition_format_encode(&source.definition, connector_schema.row_options.clone())?; - let format_encode_options = WithOptions::try_from(connector_schema.row_options())?.into_inner(); + let (format_encode_options, format_encode_secret_ref) = resolve_secret_ref_in_with_options( + WithOptions::try_from(connector_schema.row_options())?, + session.as_ref(), + )? + .into_parts(); source .info .format_encode_options .extend(format_encode_options); + source + .info + .format_encode_secret_refs + .extend(format_encode_secret_ref); + let mut pb_source = source.to_prost(schema_id, database_id); // update version diff --git a/src/frontend/src/handler/create_secret.rs b/src/frontend/src/handler/create_secret.rs index 2e99f26e97cb..8e3e56f324b4 100644 --- a/src/frontend/src/handler/create_secret.rs +++ b/src/frontend/src/handler/create_secret.rs @@ -45,15 +45,17 @@ pub async fn handle_create_secret( }; } + let secret = secret_to_str(&stmt.credential)?.as_bytes().to_vec(); + // check if the secret backend is supported let with_props = WithOptions::try_from(stmt.with_properties.0.as_ref() as &[SqlOption])?; let secret_payload: Vec = { - if let Some(backend) = with_props.inner().get(SECRET_BACKEND_KEY) { + if let Some(backend) = with_props.get(SECRET_BACKEND_KEY) { match backend.to_lowercase().as_ref() { SECRET_BACKEND_META => { let backend = risingwave_pb::secret::Secret { secret_backend: Some(risingwave_pb::secret::secret::SecretBackend::Meta( - risingwave_pb::secret::SecretMetaBackend { value: vec![] }, + risingwave_pb::secret::SecretMetaBackend { value: secret }, )), }; backend.encode_to_vec() @@ -100,3 +102,13 @@ pub async fn handle_create_secret( Ok(PgResponse::empty_result(StatementType::CREATE_SECRET)) } + +fn secret_to_str(value: &Value) -> Result { + match value { + Value::DoubleQuotedString(s) | Value::SingleQuotedString(s) => Ok(s.to_string()), + _ => Err(ErrorCode::InvalidInputSyntax( + "secret value should be quoted by ' or \" ".to_string(), + ) + .into()), + } +} diff --git a/src/frontend/src/handler/create_sink.rs b/src/frontend/src/handler/create_sink.rs index b2a65b8d0978..c44e6a2367bb 100644 --- a/src/frontend/src/handler/create_sink.rs +++ b/src/frontend/src/handler/create_sink.rs @@ -13,7 +13,6 @@ // limitations under the License. 
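[Editor's note] A side note on the `create_secret` change shown above: with the meta backend, the quoted secret value is now carried inside the protobuf payload. A minimal sketch of that encoding, mirroring `handle_create_secret` (illustrative only, not part of the patch; the import of `prost::Message` is assumed for `encode_to_vec`):

    use prost::Message;
    use risingwave_pb::secret::secret::SecretBackend;
    use risingwave_pb::secret::{Secret, SecretMetaBackend};

    // Wrap the raw secret bytes in the meta-backend variant and serialize it;
    // the resulting bytes are used as the secret payload in the handler above.
    fn encode_meta_secret(value: Vec<u8>) -> Vec<u8> {
        let secret = Secret {
            secret_backend: Some(SecretBackend::Meta(SecretMetaBackend { value })),
        };
        secret.encode_to_vec()
    }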
use std::collections::{BTreeMap, BTreeSet, HashMap}; -use std::rc::Rc; use std::sync::{Arc, LazyLock}; use anyhow::Context; @@ -24,6 +23,7 @@ use maplit::{convert_args, hashmap}; use pgwire::pg_response::{PgResponse, StatementType}; use risingwave_common::array::arrow::IcebergArrowConvert; use risingwave_common::catalog::{ConnectionId, DatabaseId, Schema, SchemaId, TableId, UserId}; +use risingwave_common::secret::LocalSecretManager; use risingwave_common::types::DataType; use risingwave_common::{bail, catalog}; use risingwave_connector::sink::catalog::{SinkCatalog, SinkFormatDesc, SinkType}; @@ -37,7 +37,8 @@ use risingwave_pb::stream_plan::stream_fragment_graph::Parallelism; use risingwave_pb::stream_plan::stream_node::NodeBody; use risingwave_pb::stream_plan::{DispatcherType, MergeNode, StreamFragmentGraph, StreamNode}; use risingwave_sqlparser::ast::{ - ConnectorSchema, CreateSink, CreateSinkStatement, EmitMode, Encode, Format, Query, Statement, + ConnectorSchema, CreateSink, CreateSinkStatement, EmitMode, Encode, ExplainOptions, Format, + Query, Statement, }; use risingwave_sqlparser::parser::Parser; @@ -60,12 +61,12 @@ use crate::handler::HandlerArgs; use crate::optimizer::plan_node::{ generic, IcebergPartitionInfo, LogicalSource, PartitionComputeInfo, StreamProject, }; -use crate::optimizer::{OptimizerContext, OptimizerContextRef, PlanRef, RelationCollectorVisitor}; +use crate::optimizer::{OptimizerContext, PlanRef, RelationCollectorVisitor}; use crate::scheduler::streaming_manager::CreatingStreamingJobInfo; use crate::session::SessionImpl; use crate::stream_fragmenter::build_graph; -use crate::utils::{resolve_privatelink_in_with_option, resolve_secret_in_with_options}; -use crate::{Explain, Planner, TableCatalog, WithOptions}; +use crate::utils::{resolve_privatelink_in_with_option, resolve_secret_ref_in_with_options}; +use crate::{Explain, Planner, TableCatalog, WithOptions, WithOptionsSecResolved}; // used to store result of `gen_sink_plan` pub struct SinkPlanContext { @@ -75,17 +76,36 @@ pub struct SinkPlanContext { pub target_table_catalog: Option>, } -pub fn gen_sink_plan( - session: &SessionImpl, - context: OptimizerContextRef, +pub async fn gen_sink_plan( + handler_args: HandlerArgs, stmt: CreateSinkStatement, - partition_info: Option, + explain_options: Option, ) -> Result { + let session = handler_args.session.clone(); + let session = session.as_ref(); let user_specified_columns = !stmt.columns.is_empty(); let db_name = session.database(); let (sink_schema_name, sink_table_name) = Binder::resolve_schema_qualified_name(db_name, stmt.sink_name.clone())?; + let mut with_options = handler_args.with_options.clone(); + + let connection_id = { + let conn_id = + resolve_privatelink_in_with_option(&mut with_options, &sink_schema_name, session)?; + conn_id.map(ConnectionId) + }; + + let mut resolved_with_options = resolve_secret_ref_in_with_options(with_options, session)?; + + let partition_info = get_partition_compute_info(&resolved_with_options).await?; + + let context = if let Some(explain_options) = explain_options { + OptimizerContext::new(handler_args.clone(), explain_options) + } else { + OptimizerContext::from_handler_args(handler_args.clone()) + }; + // Used for debezium's table name let sink_from_table_name; // `true` means that sink statement has the form: `CREATE SINK s1 FROM ...` @@ -109,8 +129,6 @@ pub fn gen_sink_plan( let (sink_database_id, sink_schema_id) = session.get_database_and_schema_id_for_create(sink_schema_name.clone())?; - let definition = 
context.normalized_sql().to_owned(); - let (dependent_relations, bound) = { let mut binder = Binder::new_for_stream(session); let bound = binder.bind_query(*query.clone())?; @@ -127,12 +145,9 @@ pub fn gen_sink_plan( get_column_names(&bound, session, stmt.columns)? }; - let mut with_options = context.with_options().clone(); - if sink_into_table_name.is_some() { - let prev = with_options - .inner_mut() - .insert(CONNECTOR_TYPE_KEY.to_string(), "table".to_string()); + let prev = + resolved_with_options.insert(CONNECTOR_TYPE_KEY.to_string(), "table".to_string()); if prev.is_some() { return Err(RwError::from(ErrorCode::BindError( @@ -141,19 +156,12 @@ pub fn gen_sink_plan( } } - let connection_id = { - let conn_id = - resolve_privatelink_in_with_option(&mut with_options, &sink_schema_name, session)?; - conn_id.map(ConnectionId) - }; - let secret_ref = resolve_secret_in_with_options(&mut with_options, session)?; - let emit_on_window_close = stmt.emit_mode == Some(EmitMode::OnWindowClose); if emit_on_window_close { context.warn_to_user("EMIT ON WINDOW CLOSE is currently an experimental feature. Please use it with caution."); } - let connector = with_options + let connector = resolved_with_options .get(CONNECTOR_TYPE_KEY) .cloned() .ok_or_else(|| ErrorCode::BindError(format!("missing field '{CONNECTOR_TYPE_KEY}'")))?; @@ -162,13 +170,13 @@ pub fn gen_sink_plan( // Case A: new syntax `format ... encode ...` Some(f) => { validate_compatibility(&connector, &f)?; - Some(bind_sink_format_desc(f)?) + Some(bind_sink_format_desc(session,f)?) } - None => match with_options.get(SINK_TYPE_OPTION) { + None => match resolved_with_options.get(SINK_TYPE_OPTION) { // Case B: old syntax `type = '...'` Some(t) => SinkFormatDesc::from_legacy_type(&connector, t)?.map(|mut f| { session.notice_to_user("Consider using the newer syntax `FORMAT ... ENCODE ...` instead of `type = '...'`."); - if let Some(v) = with_options.get(SINK_USER_FORCE_APPEND_ONLY_OPTION) { + if let Some(v) = resolved_with_options.get(SINK_USER_FORCE_APPEND_ONLY_OPTION) { f.options.insert(SINK_USER_FORCE_APPEND_ONLY_OPTION.into(), v.into()); } f @@ -178,12 +186,13 @@ pub fn gen_sink_plan( }, }; - let mut plan_root = Planner::new(context).plan_query(bound)?; + let definition = context.normalized_sql().to_owned(); + let mut plan_root = Planner::new(context.into()).plan_query(bound)?; if let Some(col_names) = &col_names { plan_root.set_out_names(col_names.clone())?; }; - let without_backfill = match with_options.remove(SINK_WITHOUT_BACKFILL) { + let without_backfill = match resolved_with_options.remove(SINK_WITHOUT_BACKFILL) { Some(flag) if flag.eq_ignore_ascii_case("false") => { if direct_sink { true @@ -227,7 +236,7 @@ pub fn gen_sink_plan( let sink_plan = plan_root.gen_sink_plan( sink_table_name, definition, - with_options, + resolved_with_options, emit_on_window_close, db_name.to_owned(), sink_from_table_name, @@ -236,6 +245,7 @@ pub fn gen_sink_plan( target_table, partition_info, )?; + let sink_desc = sink_plan.sink_desc().clone(); let mut sink_plan: PlanRef = sink_plan.into(); @@ -256,7 +266,6 @@ pub fn gen_sink_plan( UserId::new(session.user_id()), connection_id, dependent_relations.into_iter().collect_vec(), - secret_ref, ); if let Some(table_catalog) = &target_table_catalog { @@ -310,12 +319,13 @@ pub fn gen_sink_plan( // `Some(PartitionComputeInfo)` if the sink need to compute partition. // `None` if the sink does not need to compute partition. 
pub async fn get_partition_compute_info( - with_options: &WithOptions, + with_options: &WithOptionsSecResolved, ) -> Result> { - let properties = with_options.clone().into_inner(); - let Some(connector) = properties.get(UPSTREAM_SOURCE_KEY) else { + let (options, secret_refs) = with_options.clone().into_parts(); + let Some(connector) = options.get(UPSTREAM_SOURCE_KEY).cloned() else { return Ok(None); }; + let properties = LocalSecretManager::global().fill_secrets(options, secret_refs)?; match connector.as_str() { ICEBERG_SINK => { let iceberg_config = IcebergConfig::from_btreemap(properties)?; @@ -409,21 +419,17 @@ pub async fn handle_create_sink( return Ok(resp); } - let partition_info = get_partition_compute_info(&handle_args.with_options).await?; - let (sink, graph, target_table_catalog) = { - let context = Rc::new(OptimizerContext::from_handler_args(handle_args)); - let SinkPlanContext { query, sink_plan: plan, sink_catalog: sink, target_table_catalog, - } = gen_sink_plan(&session, context.clone(), stmt, partition_info)?; + } = gen_sink_plan(handle_args, stmt, None).await?; let has_order_by = !query.order_by.is_empty(); if has_order_by { - context.warn_to_user( + plan.ctx().warn_to_user( r#"The ORDER BY clause in the CREATE SINK statement has no effect at all."# .to_string(), ); @@ -757,7 +763,7 @@ fn derive_default_column_project_for_sink( /// Transforms the (format, encode, options) from sqlparser AST into an internal struct `SinkFormatDesc`. /// This is an analogy to (part of) [`crate::handler::create_source::bind_columns_from_source`] /// which transforms sqlparser AST `SourceSchemaV2` into `StreamSourceInfo`. -fn bind_sink_format_desc(value: ConnectorSchema) -> Result { +fn bind_sink_format_desc(session: &SessionImpl, value: ConnectorSchema) -> Result { use risingwave_connector::sink::catalog::{SinkEncode, SinkFormat}; use risingwave_connector::sink::encoder::TimestamptzHandlingMode; use risingwave_sqlparser::ast::{Encode as E, Format as F}; @@ -792,7 +798,11 @@ fn bind_sink_format_desc(value: ConnectorSchema) -> Result { } } - let mut options = WithOptions::try_from(value.row_options.as_slice())?.into_inner(); + let (mut options, secret_refs) = resolve_secret_ref_in_with_options( + WithOptions::try_from(value.row_options.as_slice())?, + session, + )? + .into_parts(); options .entry(TimestamptzHandlingMode::OPTION_KEY.to_owned()) @@ -802,6 +812,7 @@ fn bind_sink_format_desc(value: ConnectorSchema) -> Result { format, encode, options, + secret_refs, key_encode, }) } @@ -870,20 +881,6 @@ pub fn validate_compatibility(connector: &str, format_desc: &ConnectorSchema) -> Ok(()) } -/// For `planner_test` crate so that it does not depend directly on `connector` crate just for `SinkFormatDesc`. 
-impl TryFrom<&WithOptions> for Option { - type Error = risingwave_connector::sink::SinkError; - - fn try_from(value: &WithOptions) -> std::result::Result { - let connector = value.get(CONNECTOR_TYPE_KEY); - let r#type = value.get(SINK_TYPE_OPTION); - match (connector, r#type) { - (Some(c), Some(t)) => SinkFormatDesc::from_legacy_type(c, t), - _ => Ok(None), - } - } -} - #[cfg(test)] pub mod tests { use risingwave_common::catalog::{DEFAULT_DATABASE_NAME, DEFAULT_SCHEMA_NAME}; diff --git a/src/frontend/src/handler/create_source.rs b/src/frontend/src/handler/create_source.rs index aac3649d808c..de838b7bebf9 100644 --- a/src/frontend/src/handler/create_source.rs +++ b/src/frontend/src/handler/create_source.rs @@ -27,6 +27,7 @@ use risingwave_common::catalog::{ debug_assert_column_ids_distinct, ColumnCatalog, ColumnDesc, ColumnId, Schema, TableId, INITIAL_SOURCE_VERSION_ID, KAFKA_TIMESTAMP_COLUMN_NAME, }; +use risingwave_common::secret::LocalSecretManager; use risingwave_common::types::DataType; use risingwave_connector::parser::additional_columns::{ build_additional_column_desc, get_supported_additional_columns, @@ -83,8 +84,8 @@ use crate::handler::HandlerArgs; use crate::optimizer::plan_node::generic::SourceNodeKind; use crate::optimizer::plan_node::{LogicalSource, ToStream, ToStreamContext}; use crate::session::SessionImpl; -use crate::utils::{resolve_privatelink_in_with_option, resolve_secret_in_with_options}; -use crate::{bind_data_type, build_graph, OptimizerContext, WithOptions}; +use crate::utils::{resolve_privatelink_in_with_option, resolve_secret_ref_in_with_options}; +use crate::{bind_data_type, build_graph, OptimizerContext, WithOptions, WithOptionsSecResolved}; pub(crate) const UPSTREAM_SOURCE_KEY: &str = "connector"; @@ -145,7 +146,7 @@ fn json_schema_infer_use_schema_registry(schema_config: &Option<(AstString, bool /// Map an Avro schema to a relational schema. async fn extract_avro_table_schema( info: &StreamSourceInfo, - with_properties: &BTreeMap, + with_properties: &WithOptionsSecResolved, format_encode_options: &mut BTreeMap, is_debezium: bool, ) -> Result> { @@ -181,7 +182,7 @@ async fn extract_avro_table_schema( async fn extract_debezium_avro_table_pk_columns( info: &StreamSourceInfo, - with_properties: &WithOptions, + with_properties: &WithOptionsSecResolved, ) -> Result> { let parser_config = SpecificParserConfig::new(info, with_properties)?; let conf = DebeziumAvroParserConfig::new(parser_config.encoding_config).await?; @@ -191,7 +192,7 @@ async fn extract_debezium_avro_table_pk_columns( /// Map a protobuf schema to a relational schema. 
async fn extract_protobuf_table_schema( schema: &ProtobufSchema, - with_properties: &BTreeMap, + with_properties: &WithOptionsSecResolved, format_encode_options: &mut BTreeMap, ) -> Result> { let info = StreamSourceInfo { @@ -299,15 +300,28 @@ fn get_name_strategy_or_default(name_strategy: Option) -> Result, + with_properties: Either<&WithOptions, &WithOptionsSecResolved>, ) -> Result<(Option>, StreamSourceInfo)> { const MESSAGE_NAME_KEY: &str = "message"; const KEY_MESSAGE_NAME_KEY: &str = "key.message"; const NAME_STRATEGY_KEY: &str = "schema.registry.name.strategy"; - let is_kafka: bool = with_properties.is_kafka_connector(); - let format_encode_options = WithOptions::try_from(source_schema.row_options())?.into_inner(); - let mut format_encode_options_to_consume = format_encode_options.clone(); + let options_with_secret = match with_properties { + Either::Left(options) => resolve_secret_ref_in_with_options(options.clone(), session)?, + Either::Right(options_with_secret) => options_with_secret.clone(), + }; + + let is_kafka: bool = options_with_secret.is_kafka_connector(); + let (format_encode_options, format_encode_secret_refs) = resolve_secret_ref_in_with_options( + WithOptions::try_from(source_schema.row_options())?, + session, + )? + .into_parts(); + // Need real secret to access the schema registry + let mut format_encode_options_to_consume = LocalSecretManager::global().fill_secrets( + format_encode_options.clone(), + format_encode_secret_refs.clone(), + )?; fn get_key_message_name(options: &mut BTreeMap) -> Option { consume_string_from_options(options, KEY_MESSAGE_NAME_KEY) @@ -334,6 +348,7 @@ pub(crate) async fn bind_columns_from_source( format: format_to_prost(&source_schema.format) as i32, row_encode: row_encode_to_prost(&source_schema.row_encode) as i32, format_encode_options, + format_encode_secret_refs, ..Default::default() }; @@ -376,7 +391,7 @@ pub(crate) async fn bind_columns_from_source( Some( extract_protobuf_table_schema( &protobuf_schema, - with_properties, + &options_with_secret, &mut format_encode_options_to_consume, ) .await?, @@ -424,7 +439,7 @@ pub(crate) async fn bind_columns_from_source( Some( extract_avro_table_schema( &stream_source_info, - with_properties, + &options_with_secret, &mut format_encode_options_to_consume, matches!(format, Format::Debezium), ) @@ -482,15 +497,15 @@ pub(crate) async fn bind_columns_from_source( extract_json_table_schema( &schema_config, - with_properties, + &options_with_secret, &mut format_encode_options_to_consume, ) .await? } (Format::None, Encode::None) => { - if with_properties.is_iceberg_connector() { + if options_with_secret.is_iceberg_connector() { Some( - extract_iceberg_columns(with_properties) + extract_iceberg_columns(&options_with_secret) .await .map_err(|err| ProtocolError(err.to_report_string()))?, ) @@ -526,8 +541,17 @@ fn bind_columns_from_source_for_cdc( session: &SessionImpl, source_schema: &ConnectorSchema, ) -> Result<(Option>, StreamSourceInfo)> { - let format_encode_options = WithOptions::try_from(source_schema.row_options())?.into_inner(); - let mut format_encode_options_to_consume = format_encode_options.clone(); + let (format_encode_options, format_encode_secret_refs) = resolve_secret_ref_in_with_options( + WithOptions::try_from(source_schema.row_options())?, + session, + )? 
+ .into_parts(); + + // Need real secret to access the schema registry + let mut format_encode_options_to_consume = LocalSecretManager::global().fill_secrets( + format_encode_options.clone(), + format_encode_secret_refs.clone(), + )?; match (&source_schema.format, &source_schema.row_encode) { (Format::Plain, Encode::Json) => (), @@ -550,6 +574,7 @@ fn bind_columns_from_source_for_cdc( use_schema_registry: json_schema_infer_use_schema_registry(&schema_config), cdc_source_job: true, is_distributed: false, + format_encode_secret_refs, ..Default::default() }; if !format_encode_options_to_consume.is_empty() { @@ -763,7 +788,7 @@ pub(crate) async fn bind_source_pk( source_info: &StreamSourceInfo, columns: &mut [ColumnCatalog], sql_defined_pk_names: Vec, - with_properties: &WithOptions, + with_properties: &WithOptionsSecResolved, ) -> Result> { let sql_defined_pk = !sql_defined_pk_names.is_empty(); let include_key_column_name: Option = { @@ -1159,7 +1184,7 @@ pub fn validate_compatibility( /// One should only call this function after all properties of all columns are resolved, like /// generated column descriptors. pub(super) async fn check_source_schema( - props: &WithOptions, + props: &WithOptionsSecResolved, row_id_index: Option, columns: &[ColumnCatalog], ) -> Result<()> { @@ -1168,9 +1193,9 @@ pub(super) async fn check_source_schema( }; if connector == NEXMARK_CONNECTOR { - check_nexmark_schema(props.inner(), row_id_index, columns) + check_nexmark_schema(props, row_id_index, columns) } else if connector == ICEBERG_CONNECTOR { - Ok(check_iceberg_source(props.inner(), columns) + Ok(check_iceberg_source(props, columns) .await .map_err(|err| ProtocolError(err.to_report_string()))?) } else { @@ -1179,7 +1204,7 @@ pub(super) async fn check_source_schema( } pub(super) fn check_nexmark_schema( - props: &BTreeMap, + props: &WithOptionsSecResolved, row_id_index: Option, columns: &[ColumnCatalog], ) -> Result<()> { @@ -1233,7 +1258,7 @@ pub(super) fn check_nexmark_schema( } pub async fn extract_iceberg_columns( - with_properties: &BTreeMap, + with_properties: &WithOptionsSecResolved, ) -> anyhow::Result> { let props = ConnectorProperties::extract(with_properties.clone(), true)?; if let ConnectorProperties::Iceberg(properties) = props { @@ -1272,7 +1297,7 @@ pub async fn extract_iceberg_columns( } pub async fn check_iceberg_source( - props: &BTreeMap, + props: &WithOptionsSecResolved, columns: &[ColumnCatalog], ) -> anyhow::Result<()> { let props = ConnectorProperties::extract(props.clone(), true)?; @@ -1347,7 +1372,7 @@ pub fn bind_connector_props( .to_string(), ); } - Ok(WithOptions::new(with_properties)) + Ok(with_properties) } #[allow(clippy::too_many_arguments)] @@ -1421,6 +1446,13 @@ pub async fn bind_create_source_or_table_with_connector( check_and_add_timestamp_column(&with_properties, &mut columns); } + // resolve privatelink connection for Kafka + let mut with_properties = with_properties; + let connection_id = + resolve_privatelink_in_with_option(&mut with_properties, &schema_name, session)?; + + let with_properties = resolve_secret_ref_in_with_options(with_properties, session)?; + let pk_names = bind_source_pk( &source_schema, &source_info, @@ -1473,13 +1505,7 @@ pub async fn bind_create_source_or_table_with_connector( )?; check_source_schema(&with_properties, row_id_index, &columns).await?; - // resolve privatelink connection for Kafka - let mut with_properties = with_properties; - let connection_id = - resolve_privatelink_in_with_option(&mut with_properties, &schema_name, session)?; - let 
_secret_ref = resolve_secret_in_with_options(&mut with_properties, session)?; - - let definition: String = handler_args.normalized_sql.clone(); + let definition = handler_args.normalized_sql.clone(); let associated_table_id = if is_create_source { None @@ -1495,7 +1521,7 @@ pub async fn bind_create_source_or_table_with_connector( owner: session.user_id(), info: source_info, row_id_index, - with_properties: with_properties.into_inner().into_iter().collect(), + with_properties, watermark_descs, associated_table_id, definition, @@ -1539,7 +1565,7 @@ pub async fn handle_create_source( let (columns_from_resolve_source, mut source_info) = if create_cdc_source_job { bind_columns_from_source_for_cdc(&session, &source_schema)? } else { - bind_columns_from_source(&session, &source_schema, &with_properties).await? + bind_columns_from_source(&session, &source_schema, Either::Left(&with_properties)).await? }; if is_shared { // Note: this field should be called is_shared. Check field doc for more details. diff --git a/src/frontend/src/handler/create_table.rs b/src/frontend/src/handler/create_table.rs index ef8b57f8064d..5dce6028d0cc 100644 --- a/src/frontend/src/handler/create_table.rs +++ b/src/frontend/src/handler/create_table.rs @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::collections::{BTreeMap, HashMap}; +use std::collections::HashMap; use std::rc::Rc; use std::sync::Arc; @@ -28,10 +28,10 @@ use risingwave_common::catalog::{ }; use risingwave_common::util::sort_util::{ColumnOrder, OrderType}; use risingwave_common::util::value_encoding::DatumToProtoExt; -use risingwave_connector::source; use risingwave_connector::source::cdc::external::{ ExternalTableConfig, ExternalTableImpl, DATABASE_NAME_KEY, SCHEMA_NAME_KEY, TABLE_NAME_KEY, }; +use risingwave_connector::{source, WithOptionsSecResolved}; use risingwave_pb::catalog::{PbSource, PbTable, Table, WatermarkDesc}; use risingwave_pb::ddl_service::TableJobType; use risingwave_pb::plan_common::column_desc::GeneratedOrDefaultColumn; @@ -486,7 +486,7 @@ pub(crate) async fn gen_create_table_plan_with_source( let with_properties = bind_connector_props(&handler_args, &source_schema, false)?; let (columns_from_resolve_source, source_info) = - bind_columns_from_source(session, &source_schema, &with_properties).await?; + bind_columns_from_source(session, &source_schema, Either::Left(&with_properties)).await?; let (source_catalog, database_id, schema_id) = bind_create_source_or_table_with_connector( handler_args.clone(), @@ -541,6 +541,12 @@ pub(crate) fn gen_create_table_plan( for c in &mut columns { c.column_desc.column_id = col_id_gen.generate(c.name()) } + + let (_, secret_refs) = context.with_options().clone().into_parts(); + if !secret_refs.is_empty() { + return Err(crate::error::ErrorCode::InvalidParameterValue("Secret reference is not allowed in options when creating table without external source".to_string()).into()); + } + gen_create_table_plan_without_source( context, table_name, @@ -740,7 +746,7 @@ pub(crate) fn gen_create_table_plan_for_cdc_table( external_table_name: String, mut columns: Vec, pk_names: Vec, - connect_properties: BTreeMap, + connect_properties: WithOptionsSecResolved, mut col_id_gen: ColumnIdGenerator, on_conflict: Option, with_version_column: Option, @@ -786,6 +792,8 @@ pub(crate) fn gen_create_table_plan_for_cdc_table( .map(|idx| ColumnOrder::new(*idx, OrderType::ascending())) .collect(); + let (options, secret_refs) = 
connect_properties.into_parts(); + let cdc_table_desc = CdcTableDesc { table_id, source_id: source.id.into(), // id of cdc source streaming job @@ -793,7 +801,8 @@ pub(crate) fn gen_create_table_plan_for_cdc_table( pk: table_pk, columns: columns.iter().map(|c| c.column_desc.clone()).collect(), stream_key: pk_column_indices, - connect_properties: connect_properties.into_iter().collect(), + connect_properties: options, + secret_refs, }; tracing::debug!(?cdc_table_desc, "create cdc table"); @@ -841,9 +850,9 @@ pub(crate) fn gen_create_table_plan_for_cdc_table( } fn derive_connect_properties( - source_with_properties: &BTreeMap, + source_with_properties: &WithOptionsSecResolved, external_table_name: String, -) -> Result> { +) -> Result { use source::cdc::{MYSQL_CDC_CONNECTOR, POSTGRES_CDC_CONNECTOR}; // we should remove the prefix from `full_table_name` let mut connect_properties = source_with_properties.clone(); @@ -1091,12 +1100,13 @@ fn sanity_check_for_cdc_table( async fn derive_schema_for_cdc_table( column_defs: &Vec, constraints: &Vec, - connect_properties: BTreeMap, + connect_properties: WithOptionsSecResolved, need_auto_schema_map: bool, ) -> Result<(Vec, Vec)> { // read cdc table schema from external db or parsing the schema from SQL definitions if need_auto_schema_map { - let config = ExternalTableConfig::try_from_btreemap(connect_properties) + let (options, secret_refs) = connect_properties.into_parts(); + let config = ExternalTableConfig::try_from_btreemap(options, secret_refs) .context("failed to extract external table config")?; let table = ExternalTableImpl::connect(config) @@ -1205,7 +1215,7 @@ pub fn check_create_table_with_source( if cdc_table_info.is_some() { return Ok(source_schema); } - let defined_source = with_options.inner().contains_key(UPSTREAM_SOURCE_KEY); + let defined_source = with_options.contains_key(UPSTREAM_SOURCE_KEY); if !include_column_options.is_empty() && !defined_source { return Err(ErrorCode::InvalidInputSyntax( "INCLUDE should be used with a connector".to_owned(), diff --git a/src/frontend/src/handler/create_table_as.rs b/src/frontend/src/handler/create_table_as.rs index 9a01d2919086..bb00be2dfa48 100644 --- a/src/frontend/src/handler/create_table_as.rs +++ b/src/frontend/src/handler/create_table_as.rs @@ -90,6 +90,13 @@ pub async fn handle_create_as( let (graph, source, table) = { let context = OptimizerContext::from_handler_args(handler_args.clone()); + let (_, secret_refs) = context.with_options().clone().into_parts(); + if !secret_refs.is_empty() { + return Err(crate::error::ErrorCode::InvalidParameterValue( + "Secret reference is not allowed in options for CREATE TABLE AS".to_string(), + ) + .into()); + } let (plan, table) = gen_create_table_plan_without_source( context, table_name.clone(), diff --git a/src/frontend/src/handler/create_view.rs b/src/frontend/src/handler/create_view.rs index 673fc149dd8c..5ad0e8956b96 100644 --- a/src/frontend/src/handler/create_view.rs +++ b/src/frontend/src/handler/create_view.rs @@ -87,12 +87,20 @@ pub async fn handle_create_view( .collect() }; + let (properties, secret_refs) = properties.into_parts(); + if !secret_refs.is_empty() { + return Err(crate::error::ErrorCode::InvalidParameterValue( + "Secret reference is not allowed in create view options".to_string(), + ) + .into()); + } + let view = PbView { id: 0, schema_id, database_id, name: view_name, - properties: properties.inner().clone().into_iter().collect(), + properties, owner: session.user_id(), dependent_relations: dependent_relations .into_iter() diff 
--git a/src/frontend/src/handler/drop_secret.rs b/src/frontend/src/handler/drop_secret.rs index 37fbd2cedd40..4001cc99cfd8 100644 --- a/src/frontend/src/handler/drop_secret.rs +++ b/src/frontend/src/handler/drop_secret.rs @@ -52,13 +52,11 @@ pub async fn handle_drop_secret( }; session.check_privilege_for_drop_alter(schema_name, &**secret)?; - secret.secret_id + secret.id }; let catalog_writer = session.catalog_writer()?; catalog_writer.drop_secret(secret_id).await?; - Ok(RwPgResponse::builder(StatementType::DROP_SECRET) - .notice(format!("dropped secret \"{}\"", secret_name)) - .into()) + Ok(RwPgResponse::empty_result(StatementType::DROP_SECRET)) } diff --git a/src/frontend/src/handler/explain.rs b/src/frontend/src/handler/explain.rs index db124b373181..ed22462a6688 100644 --- a/src/frontend/src/handler/explain.rs +++ b/src/frontend/src/handler/explain.rs @@ -21,7 +21,7 @@ use thiserror_ext::AsReport; use super::create_index::{gen_create_index_plan, resolve_index_schema}; use super::create_mv::gen_create_mv_plan; -use super::create_sink::{gen_sink_plan, get_partition_compute_info}; +use super::create_sink::gen_sink_plan; use super::query::gen_batch_plan_by_statement; use super::util::SourceSchemaCompatExt; use super::{RwPgResponse, RwPgResponseBuilderExt}; @@ -87,9 +87,8 @@ async fn do_handle_explain( (Ok(plan), context) } Statement::CreateSink { stmt } => { - let partition_info = get_partition_compute_info(&handler_args.with_options).await?; - let context = OptimizerContext::new(handler_args, explain_options); - let plan = gen_sink_plan(&session, context.into(), stmt, partition_info) + let plan = gen_sink_plan(handler_args, stmt, Some(explain_options)) + .await .map(|plan| plan.sink_plan)?; let context = plan.ctx(); (Ok(plan), context) diff --git a/src/frontend/src/lib.rs b/src/frontend/src/lib.rs index bb27c50053ad..185d65cb567a 100644 --- a/src/frontend/src/lib.rs +++ b/src/frontend/src/lib.rs @@ -65,7 +65,7 @@ use risingwave_common::util::meta_addr::MetaAddressStrategy; use risingwave_common::util::tokio_util::sync::CancellationToken; pub use stream_fragmenter::build_graph; mod utils; -pub use utils::{explain_stream_graph, WithOptions}; +pub use utils::{explain_stream_graph, WithOptions, WithOptionsSecResolved}; pub(crate) mod error; mod meta_client; pub mod test_utils; @@ -143,6 +143,15 @@ pub struct FrontendOpts { #[clap(long, hide = true, env = "RW_ENABLE_BARRIER_READ")] #[override_opts(path = batch.enable_barrier_read)] pub enable_barrier_read: Option, + + /// The path of the temp secret file directory. 
+ #[clap( + long, + hide = true, + env = "RW_TEMP_SECRET_FILE_DIR", + default_value = "./secrets" + )] + pub temp_secret_file_dir: String, } impl risingwave_common::opts::Opts for FrontendOpts { diff --git a/src/frontend/src/observer/observer_manager.rs b/src/frontend/src/observer/observer_manager.rs index 169c6bff1b95..80f6316f9ca9 100644 --- a/src/frontend/src/observer/observer_manager.rs +++ b/src/frontend/src/observer/observer_manager.rs @@ -20,6 +20,7 @@ use parking_lot::RwLock; use risingwave_batch::worker_manager::worker_node_manager::WorkerNodeManagerRef; use risingwave_common::catalog::CatalogVersion; use risingwave_common::hash::WorkerSlotMapping; +use risingwave_common::secret::LocalSecretManager; use risingwave_common::session_config::SessionConfig; use risingwave_common::system_param::local_manager::LocalSystemParamsManagerRef; use risingwave_common_service::ObserverState; @@ -65,10 +66,13 @@ impl ObserverState for FrontendObserverNode { | Info::Schema(_) | Info::RelationGroup(_) | Info::Function(_) - | Info::Secret(_) | Info::Connection(_) => { self.handle_catalog_notification(resp); } + Info::Secret(_) => { + self.handle_catalog_notification(resp.clone()); + self.handle_secret_notification(resp); + } Info::Node(node) => { self.update_worker_node_manager(resp.operation(), node); } @@ -178,8 +182,8 @@ impl ObserverState for FrontendObserverNode { for connection in connections { catalog_guard.create_connection(&connection) } - for secret in secrets { - catalog_guard.create_secret(&secret) + for secret in &secrets { + catalog_guard.create_secret(secret) } for user in users { user_guard.create_user(user) @@ -204,6 +208,7 @@ impl ObserverState for FrontendObserverNode { .unwrap(); *self.session_params.write() = serde_json::from_str(&session_params.unwrap().params).unwrap(); + LocalSecretManager::global().init_secrets(secrets); } } @@ -353,16 +358,21 @@ impl FrontendObserverNode { Operation::Update => catalog_guard.update_connection(connection), _ => panic!("receive an unsupported notify {:?}", resp), }, - Info::Secret(secret) => match resp.operation() { - Operation::Add => catalog_guard.create_secret(secret), - Operation::Delete => catalog_guard.drop_secret( - secret.database_id, - secret.schema_id, - SecretId::new(secret.id), - ), - Operation::Update => catalog_guard.update_secret(secret), - _ => panic!("receive an unsupported notify {:?}", resp), - }, + Info::Secret(secret) => { + let mut secret = secret.clone(); + // The secret value should not be revealed to users. So mask it in the frontend catalog. 
+ secret.value = "SECRET VALUE SHOULD NOT BE REVEALED".as_bytes().to_vec(); + match resp.operation() { + Operation::Add => catalog_guard.create_secret(&secret), + Operation::Delete => catalog_guard.drop_secret( + secret.database_id, + secret.schema_id, + SecretId::new(secret.id), + ), + Operation::Update => catalog_guard.update_secret(&secret), + _ => panic!("receive an unsupported notify {:?}", resp), + } + } _ => unreachable!(), } assert!( @@ -471,6 +481,24 @@ impl FrontendObserverNode { } } + fn handle_secret_notification(&mut self, resp: SubscribeResponse) { + let resp_op = resp.operation(); + let Some(Info::Secret(secret)) = resp.info else { + unreachable!(); + }; + match resp_op { + Operation::Add => { + LocalSecretManager::global().add_secret(secret.id, secret.value); + } + Operation::Delete => { + LocalSecretManager::global().remove_secret(secret.id); + } + _ => { + panic!("error type notification"); + } + } + } + /// `update_worker_node_manager` is called in `start` method. /// It calls `add_worker_node` and `remove_worker_node` of `WorkerNodeManager`. fn update_worker_node_manager(&self, operation: Operation, node: WorkerNode) { diff --git a/src/frontend/src/optimizer/mod.rs b/src/frontend/src/optimizer/mod.rs index c33625491082..491f3b71875b 100644 --- a/src/frontend/src/optimizer/mod.rs +++ b/src/frontend/src/optimizer/mod.rs @@ -81,8 +81,7 @@ use crate::optimizer::plan_node::{ }; use crate::optimizer::plan_visitor::TemporalJoinValidator; use crate::optimizer::property::Distribution; -use crate::utils::ColIndexMappingRewriteExt; -use crate::WithOptions; +use crate::utils::{ColIndexMappingRewriteExt, WithOptionsSecResolved}; /// `PlanRoot` is used to describe a plan. planner will construct a `PlanRoot` with `LogicalNode`. /// and required distribution and order. 
And `PlanRoot` can generate corresponding streaming or @@ -926,7 +925,7 @@ impl PlanRoot { &mut self, sink_name: String, definition: String, - properties: WithOptions, + properties: WithOptionsSecResolved, emit_on_window_close: bool, db_name: String, sink_from_table_name: String, diff --git a/src/frontend/src/optimizer/plan_node/batch_iceberg_scan.rs b/src/frontend/src/optimizer/plan_node/batch_iceberg_scan.rs index 0ac245175954..3433feb8d210 100644 --- a/src/frontend/src/optimizer/plan_node/batch_iceberg_scan.rs +++ b/src/frontend/src/optimizer/plan_node/batch_iceberg_scan.rs @@ -98,6 +98,7 @@ impl ToDistributedBatch for BatchIcebergScan { impl ToBatchPb for BatchIcebergScan { fn to_batch_prost_body(&self) -> NodeBody { let source_catalog = self.source_catalog().unwrap(); + let (with_properties, secret_refs) = source_catalog.with_properties.clone().into_parts(); NodeBody::Source(SourceNode { source_id: source_catalog.id, info: Some(source_catalog.info.clone()), @@ -107,9 +108,9 @@ impl ToBatchPb for BatchIcebergScan { .iter() .map(|c| c.to_protobuf()) .collect(), - with_properties: source_catalog.with_properties.clone().into_iter().collect(), + with_properties, split: vec![], - secret_refs: Default::default(), + secret_refs, }) } } diff --git a/src/frontend/src/optimizer/plan_node/batch_kafka_scan.rs b/src/frontend/src/optimizer/plan_node/batch_kafka_scan.rs index 8dca30512139..0883f3aa697c 100644 --- a/src/frontend/src/optimizer/plan_node/batch_kafka_scan.rs +++ b/src/frontend/src/optimizer/plan_node/batch_kafka_scan.rs @@ -120,6 +120,7 @@ impl ToDistributedBatch for BatchKafkaScan { impl ToBatchPb for BatchKafkaScan { fn to_batch_prost_body(&self) -> NodeBody { let source_catalog = self.source_catalog().unwrap(); + let (with_properties, secret_refs) = source_catalog.with_properties.clone().into_parts(); NodeBody::Source(SourceNode { source_id: source_catalog.id, info: Some(source_catalog.info.clone()), @@ -129,9 +130,9 @@ impl ToBatchPb for BatchKafkaScan { .iter() .map(|c| c.to_protobuf()) .collect(), - with_properties: source_catalog.with_properties.clone().into_iter().collect(), + with_properties, split: vec![], - secret_refs: Default::default(), + secret_refs, }) } } diff --git a/src/frontend/src/optimizer/plan_node/batch_source.rs b/src/frontend/src/optimizer/plan_node/batch_source.rs index ac5e25d70363..fd33d2dba003 100644 --- a/src/frontend/src/optimizer/plan_node/batch_source.rs +++ b/src/frontend/src/optimizer/plan_node/batch_source.rs @@ -102,6 +102,7 @@ impl ToDistributedBatch for BatchSource { impl ToBatchPb for BatchSource { fn to_batch_prost_body(&self) -> NodeBody { let source_catalog = self.source_catalog().unwrap(); + let (with_properties, secret_refs) = source_catalog.with_properties.clone().into_parts(); NodeBody::Source(SourceNode { source_id: source_catalog.id, info: Some(source_catalog.info.clone()), @@ -111,9 +112,9 @@ impl ToBatchPb for BatchSource { .iter() .map(|c| c.to_protobuf()) .collect(), - with_properties: source_catalog.with_properties.clone().into_iter().collect(), + with_properties, split: vec![], - secret_refs: Default::default(), + secret_refs, }) } } diff --git a/src/frontend/src/optimizer/plan_node/stream_fs_fetch.rs b/src/frontend/src/optimizer/plan_node/stream_fs_fetch.rs index 43c14686e2a2..c57494123672 100644 --- a/src/frontend/src/optimizer/plan_node/stream_fs_fetch.rs +++ b/src/frontend/src/optimizer/plan_node/stream_fs_fetch.rs @@ -97,25 +97,30 @@ impl StreamNode for StreamFsFetch { fn to_stream_prost_body(&self, state: &mut 
BuildFragmentGraphState) -> NodeBody { // `StreamFsFetch` is same as source in proto def, so the following code is the same as `StreamSource` let source_catalog = self.source_catalog(); - let source_inner = source_catalog.map(|source_catalog| PbStreamFsFetch { - source_id: source_catalog.id, - source_name: source_catalog.name.clone(), - state_table: Some( - generic::Source::infer_internal_table_catalog(true) - .with_id(state.gen_table_id_wrapped()) - .to_internal_table_prost(), - ), - info: Some(source_catalog.info.clone()), - row_id_index: self.core.row_id_index.map(|index| index as _), - columns: self - .core - .column_catalog - .iter() - .map(|c| c.to_protobuf()) - .collect_vec(), - with_properties: source_catalog.with_properties.clone().into_iter().collect(), - rate_limit: self.base.ctx().overwrite_options().streaming_rate_limit, - secret_refs: Default::default(), + + let source_inner = source_catalog.map(|source_catalog| { + let (with_properties, secret_refs) = + source_catalog.with_properties.clone().into_parts(); + PbStreamFsFetch { + source_id: source_catalog.id, + source_name: source_catalog.name.clone(), + state_table: Some( + generic::Source::infer_internal_table_catalog(true) + .with_id(state.gen_table_id_wrapped()) + .to_internal_table_prost(), + ), + info: Some(source_catalog.info.clone()), + row_id_index: self.core.row_id_index.map(|index| index as _), + columns: self + .core + .column_catalog + .iter() + .map(|c| c.to_protobuf()) + .collect_vec(), + with_properties, + rate_limit: self.base.ctx().overwrite_options().streaming_rate_limit, + secret_refs, + } }); NodeBody::StreamFsFetch(StreamFsFetchNode { node_inner: source_inner, diff --git a/src/frontend/src/optimizer/plan_node/stream_sink.rs b/src/frontend/src/optimizer/plan_node/stream_sink.rs index 49378ab0a53e..dcdd15b067a0 100644 --- a/src/frontend/src/optimizer/plan_node/stream_sink.rs +++ b/src/frontend/src/optimizer/plan_node/stream_sink.rs @@ -49,7 +49,8 @@ use crate::optimizer::plan_node::utils::plan_has_backfill_leaf_nodes; use crate::optimizer::plan_node::PlanTreeNodeUnary; use crate::optimizer::property::{Distribution, Order, RequiredDist}; use crate::stream_fragmenter::BuildFragmentGraphState; -use crate::{TableCatalog, WithOptions}; +use crate::utils::WithOptionsSecResolved; +use crate::TableCatalog; const DOWNSTREAM_PK_KEY: &str = "primary_key"; @@ -205,7 +206,7 @@ impl StreamSink { user_cols: FixedBitSet, out_names: Vec, definition: String, - properties: WithOptions, + properties: WithOptionsSecResolved, format_desc: Option, partition_info: Option, ) -> Result { @@ -308,7 +309,7 @@ impl StreamSink { user_order_by: Order, columns: Vec, definition: String, - properties: WithOptions, + properties: WithOptionsSecResolved, format_desc: Option, partition_info: Option, ) -> Result<(PlanRef, SinkDesc)> { @@ -381,6 +382,7 @@ impl StreamSink { } else { CreateType::Foreground }; + let (properties, secret_refs) = properties.into_parts(); let sink_desc = SinkDesc { id: SinkId::placeholder(), name, @@ -391,7 +393,8 @@ impl StreamSink { plan_pk: pk, downstream_pk, distribution_key, - properties: properties.into_inner(), + properties, + secret_refs, sink_type, format_desc, target_table, @@ -401,7 +404,7 @@ impl StreamSink { Ok((input, sink_desc)) } - fn is_user_defined_append_only(properties: &WithOptions) -> Result { + fn is_user_defined_append_only(properties: &WithOptionsSecResolved) -> Result { if let Some(sink_type) = properties.get(SINK_TYPE_OPTION) { if sink_type != SINK_TYPE_APPEND_ONLY && sink_type != 
SINK_TYPE_DEBEZIUM @@ -423,7 +426,7 @@ impl StreamSink { Ok(properties.value_eq_ignore_case(SINK_TYPE_OPTION, SINK_TYPE_APPEND_ONLY)) } - fn is_user_force_append_only(properties: &WithOptions) -> Result { + fn is_user_force_append_only(properties: &WithOptionsSecResolved) -> Result { if properties.contains_key(SINK_USER_FORCE_APPEND_ONLY_OPTION) && !properties.value_eq_ignore_case(SINK_USER_FORCE_APPEND_ONLY_OPTION, "true") && !properties.value_eq_ignore_case(SINK_USER_FORCE_APPEND_ONLY_OPTION, "false") @@ -442,14 +445,16 @@ impl StreamSink { fn derive_sink_type( input_append_only: bool, - properties: &WithOptions, + properties: &WithOptionsSecResolved, format_desc: Option<&SinkFormatDesc>, ) -> Result { let frontend_derived_append_only = input_append_only; let (user_defined_append_only, user_force_append_only, syntax_legacy) = match format_desc { Some(f) => ( f.format == SinkFormat::AppendOnly, - Self::is_user_force_append_only(&WithOptions::new(f.options.clone()))?, + Self::is_user_force_append_only(&WithOptionsSecResolved::without_secrets( + f.options.clone(), + ))?, false, ), None => ( diff --git a/src/frontend/src/optimizer/plan_node/stream_source.rs b/src/frontend/src/optimizer/plan_node/stream_source.rs index 83ad14886872..7b0703aa8436 100644 --- a/src/frontend/src/optimizer/plan_node/stream_source.rs +++ b/src/frontend/src/optimizer/plan_node/stream_source.rs @@ -91,25 +91,29 @@ impl Distill for StreamSource { impl StreamNode for StreamSource { fn to_stream_prost_body(&self, state: &mut BuildFragmentGraphState) -> PbNodeBody { let source_catalog = self.source_catalog(); - let source_inner = source_catalog.map(|source_catalog| PbStreamSource { - source_id: source_catalog.id, - source_name: source_catalog.name.clone(), - state_table: Some( - generic::Source::infer_internal_table_catalog(false) - .with_id(state.gen_table_id_wrapped()) - .to_internal_table_prost(), - ), - info: Some(source_catalog.info.clone()), - row_id_index: self.core.row_id_index.map(|index| index as _), - columns: self - .core - .column_catalog - .iter() - .map(|c| c.to_protobuf()) - .collect_vec(), - with_properties: source_catalog.with_properties.clone().into_iter().collect(), - rate_limit: self.base.ctx().overwrite_options().streaming_rate_limit, - secret_refs: Default::default(), + let source_inner = source_catalog.map(|source_catalog| { + let (with_properties, secret_refs) = + source_catalog.with_properties.clone().into_parts(); + PbStreamSource { + source_id: source_catalog.id, + source_name: source_catalog.name.clone(), + state_table: Some( + generic::Source::infer_internal_table_catalog(false) + .with_id(state.gen_table_id_wrapped()) + .to_internal_table_prost(), + ), + info: Some(source_catalog.info.clone()), + row_id_index: self.core.row_id_index.map(|index| index as _), + columns: self + .core + .column_catalog + .iter() + .map(|c| c.to_protobuf()) + .collect_vec(), + with_properties, + rate_limit: self.base.ctx().overwrite_options().streaming_rate_limit, + secret_refs, + } }); PbNodeBody::Source(SourceNode { source_inner }) } diff --git a/src/frontend/src/optimizer/plan_node/stream_source_scan.rs b/src/frontend/src/optimizer/plan_node/stream_source_scan.rs index 0449648798ea..b947cee641d4 100644 --- a/src/frontend/src/optimizer/plan_node/stream_source_scan.rs +++ b/src/frontend/src/optimizer/plan_node/stream_source_scan.rs @@ -139,6 +139,7 @@ impl StreamSourceScan { .collect_vec(); let source_catalog = self.source_catalog(); + let (with_properties, secret_refs) = 
source_catalog.with_properties.clone().into_parts(); let backfill = SourceBackfillNode { upstream_source_id: source_catalog.id, source_name: source_catalog.name.clone(), @@ -155,9 +156,9 @@ impl StreamSourceScan { .iter() .map(|c| c.to_protobuf()) .collect_vec(), - with_properties: source_catalog.with_properties.clone().into_iter().collect(), + with_properties, rate_limit: self.base.ctx().overwrite_options().streaming_rate_limit, - secret_refs: Default::default(), + secret_refs, }; let fields = self.schema().to_prost(); diff --git a/src/frontend/src/optimizer/plan_visitor/distributed_dml_visitor.rs b/src/frontend/src/optimizer/plan_visitor/distributed_dml_visitor.rs index f79bf5978996..672e886862d8 100644 --- a/src/frontend/src/optimizer/plan_visitor/distributed_dml_visitor.rs +++ b/src/frontend/src/optimizer/plan_visitor/distributed_dml_visitor.rs @@ -40,10 +40,7 @@ impl DistributedDmlVisitor { } fn is_iceberg_source(source_catalog: &Rc) -> bool { - let property = ConnectorProperties::extract( - source_catalog.with_properties.clone().into_iter().collect(), - false, - ); + let property = ConnectorProperties::extract(source_catalog.with_properties.clone(), false); if let Ok(property) = property { matches!(property, ConnectorProperties::Iceberg(_)) } else { diff --git a/src/frontend/src/scheduler/plan_fragmenter.rs b/src/frontend/src/scheduler/plan_fragmenter.rs index 4dfa5f5cd915..ee6f49486c1b 100644 --- a/src/frontend/src/scheduler/plan_fragmenter.rs +++ b/src/frontend/src/scheduler/plan_fragmenter.rs @@ -1011,10 +1011,8 @@ impl BatchPlanFragmenter { let batch_kafka_scan: &BatchKafkaScan = batch_kafka_node; let source_catalog = batch_kafka_scan.source_catalog(); if let Some(source_catalog) = source_catalog { - let property = ConnectorProperties::extract( - source_catalog.with_properties.clone().into_iter().collect(), - false, - )?; + let property = + ConnectorProperties::extract(source_catalog.with_properties.clone(), false)?; let timestamp_bound = batch_kafka_scan.kafka_timestamp_range_value(); return Ok(Some(SourceScanInfo::new(SourceFetchInfo { connector: property, @@ -1026,10 +1024,8 @@ impl BatchPlanFragmenter { let batch_iceberg_scan: &BatchIcebergScan = batch_iceberg_scan; let source_catalog = batch_iceberg_scan.source_catalog(); if let Some(source_catalog) = source_catalog { - let property = ConnectorProperties::extract( - source_catalog.with_properties.clone().into_iter().collect(), - false, - )?; + let property = + ConnectorProperties::extract(source_catalog.with_properties.clone(), false)?; let as_of = batch_iceberg_scan.as_of(); return Ok(Some(SourceScanInfo::new(SourceFetchInfo { connector: property, @@ -1042,10 +1038,8 @@ impl BatchPlanFragmenter { let source_node: &BatchSource = source_node; let source_catalog = source_node.source_catalog(); if let Some(source_catalog) = source_catalog { - let property = ConnectorProperties::extract( - source_catalog.with_properties.clone().into_iter().collect(), - false, - )?; + let property = + ConnectorProperties::extract(source_catalog.with_properties.clone(), false)?; let as_of = source_node.as_of(); return Ok(Some(SourceScanInfo::new(SourceFetchInfo { connector: property, diff --git a/src/frontend/src/session.rs b/src/frontend/src/session.rs index 5f0ad6d62b75..7c694ba00266 100644 --- a/src/frontend/src/session.rs +++ b/src/frontend/src/session.rs @@ -51,6 +51,7 @@ use risingwave_common::config::{ load_config, BatchConfig, MetaConfig, MetricLevel, StreamingConfig, }; use risingwave_common::memory::MemoryContext; +use 
risingwave_common::secret::LocalSecretManager; use risingwave_common::session_config::{ConfigReporter, SessionConfig, VisibilityMode}; use risingwave_common::system_param::local_manager::{ LocalSystemParamsManager, LocalSystemParamsManagerRef, @@ -85,6 +86,7 @@ use crate::binder::{Binder, BoundStatement, ResolveQualifiedNameError}; use crate::catalog::catalog_service::{CatalogReader, CatalogWriter, CatalogWriterImpl}; use crate::catalog::connection_catalog::ConnectionCatalog; use crate::catalog::root_catalog::Catalog; +use crate::catalog::secret_catalog::SecretCatalog; use crate::catalog::subscription_catalog::SubscriptionCatalog; use crate::catalog::{ check_schema_writable, CatalogError, DatabaseId, OwnedByUserCatalog, SchemaId, TableId, @@ -317,6 +319,12 @@ impl FrontendEnv { let system_params_manager = Arc::new(LocalSystemParamsManager::new(system_params_reader.clone())); + LocalSecretManager::init( + opts.temp_secret_file_dir, + meta_client.cluster_id().to_string(), + worker_id, + ); + // This `session_params` should be initialized during the initial notification in `observer_manager` let session_params = Arc::new(RwLock::new(SessionConfig::default())); let frontend_observer_node = FrontendObserverNode::new( @@ -992,6 +1000,30 @@ impl SessionImpl { Ok(table.clone()) } + pub fn get_secret_by_name( + &self, + schema_name: Option, + secret_name: &str, + ) -> Result> { + let db_name = self.database(); + let search_path = self.config().search_path(); + let user_name = &self.auth_context().user_name; + + let catalog_reader = self.env().catalog_reader().read_guard(); + let schema = match schema_name { + Some(schema_name) => catalog_reader.get_schema_by_name(db_name, &schema_name)?, + None => catalog_reader.first_valid_schema(db_name, &search_path, user_name)?, + }; + let schema = catalog_reader.get_schema_by_name(db_name, schema.name().as_str())?; + let secret = schema.get_secret_by_name(secret_name).ok_or_else(|| { + RwError::from(ErrorCode::ItemNotFound(format!( + "secret {} not found", + secret_name + ))) + })?; + Ok(secret.clone()) + } + pub async fn list_change_log_epochs( &self, table_id: u32, diff --git a/src/frontend/src/user/user_authentication.rs b/src/frontend/src/user/user_authentication.rs index b0cefc1faedc..a36c09e987a7 100644 --- a/src/frontend/src/user/user_authentication.rs +++ b/src/frontend/src/user/user_authentication.rs @@ -35,9 +35,8 @@ pub const OAUTH_ISSUER_KEY: &str = "issuer"; /// Build `AuthInfo` for `OAuth`. #[inline(always)] pub fn build_oauth_info(options: &Vec) -> Option { - let metadata: HashMap = WithOptions::try_from(options.as_slice()) + let metadata: HashMap = WithOptions::oauth_options_to_map(options.as_slice()) .ok()? - .into_inner() .into_iter() .collect(); if !metadata.contains_key(OAUTH_JWKS_URL_KEY) || !metadata.contains_key(OAUTH_ISSUER_KEY) { diff --git a/src/frontend/src/utils/overwrite_options.rs b/src/frontend/src/utils/overwrite_options.rs index e1719c348107..904518331234 100644 --- a/src/frontend/src/utils/overwrite_options.rs +++ b/src/frontend/src/utils/overwrite_options.rs @@ -25,11 +25,7 @@ impl OverwriteOptions { pub fn new(args: &mut HandlerArgs) -> Self { let streaming_rate_limit = { // CREATE MATERIALIZED VIEW m1 WITH (rate_limit = N) ... 
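A quick sketch of how the pieces added here fit together: the frontend turns each secret name written in a `WITH` clause into a `PbSecretRef` carrying only the secret id and the text/file flag, using the `get_secret_by_name` lookup above. This condenses `resolve_secret_ref_in_with_options` from later in this patch; repo-internal types (`SessionImpl`, `Binder`, `RwResult`) are assumed, so it is not compilable on its own.

```rust
use std::collections::BTreeMap;

use risingwave_pb::secret::secret_ref::PbRefAsType;
use risingwave_pb::secret::PbSecretRef;
use risingwave_sqlparser::ast::{SecretRef, SecretRefAsType};

fn resolve_secret_refs(
    session: &SessionImpl,
    secret_refs: BTreeMap<String, SecretRef>, // option key -> parsed `secret ...` value
) -> RwResult<BTreeMap<String, PbSecretRef>> {
    let db_name: &str = session.database();
    let mut resolved = BTreeMap::new();
    for (key, secret_ref) in secret_refs {
        // Split an optionally schema-qualified name, then look it up with the
        // session's search path and privileges.
        let (schema_name, secret_name) =
            Binder::resolve_schema_qualified_name(db_name, secret_ref.secret_name.clone())?;
        let secret_catalog = session.get_secret_by_name(schema_name, &secret_name)?;
        // Only the id and the "as text" / "as file" flag travel further; the actual
        // value is filled in later by `LocalSecretManager` on the node that needs it.
        let ref_as = match secret_ref.ref_as {
            SecretRefAsType::Text => PbRefAsType::Text,
            SecretRefAsType::File => PbRefAsType::File,
        };
        resolved.insert(
            key,
            PbSecretRef {
                secret_id: secret_catalog.id.secret_id(),
                ref_as: ref_as.into(),
            },
        );
    }
    Ok(resolved)
}
```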
- if let Some(x) = args - .with_options - .inner_mut() - .remove(Self::STREAMING_RATE_LIMIT_KEY) - { + if let Some(x) = args.with_options.remove(Self::STREAMING_RATE_LIMIT_KEY) { // FIXME(tabVersion): validate the value Some(x.parse::().unwrap()) } else { diff --git a/src/frontend/src/utils/with_options.rs b/src/frontend/src/utils/with_options.rs index a6984d687bc7..2456dceb9fb1 100644 --- a/src/frontend/src/utils/with_options.rs +++ b/src/frontend/src/utils/with_options.rs @@ -18,11 +18,13 @@ use std::num::NonZeroU32; use risingwave_connector::source::kafka::private_link::{ insert_privatelink_broker_rewrite_map, CONNECTION_NAME_KEY, PRIVATELINK_ENDPOINT_KEY, }; +pub use risingwave_connector::WithOptionsSecResolved; use risingwave_connector::WithPropertiesExt; +use risingwave_pb::secret::secret_ref::PbRefAsType; use risingwave_pb::secret::PbSecretRef; use risingwave_sqlparser::ast::{ CreateConnectionStatement, CreateSinkStatement, CreateSourceStatement, - CreateSubscriptionStatement, SqlOption, Statement, Value, + CreateSubscriptionStatement, SecretRef, SecretRefAsType, SqlOption, Statement, Value, }; use super::OverwriteOptions; @@ -30,16 +32,18 @@ use crate::catalog::connection_catalog::resolve_private_link_connection; use crate::catalog::ConnectionId; use crate::error::{ErrorCode, Result as RwResult, RwError}; use crate::session::SessionImpl; +use crate::Binder; mod options { pub const RETENTION_SECONDS: &str = "retention_seconds"; } -/// Options or properties extracted fro m the `WITH` clause of DDLs. +/// Options or properties extracted from the `WITH` clause of DDLs. #[derive(Default, Clone, Debug, PartialEq, Eq, Hash)] pub struct WithOptions { inner: BTreeMap, + secret_ref: BTreeMap, } impl std::ops::Deref for WithOptions { @@ -58,35 +62,42 @@ impl std::ops::DerefMut for WithOptions { impl WithOptions { /// Create a new [`WithOptions`] from a [`BTreeMap`]. - pub fn new(inner: BTreeMap) -> Self { + pub fn new_with_options(inner: BTreeMap) -> Self { Self { - inner: inner.into_iter().collect(), + inner, + secret_ref: Default::default(), } } - /// Get the reference of the inner map. - pub fn inner(&self) -> &BTreeMap { - &self.inner + /// Create a new [`WithOptions`] from a option [`BTreeMap`] and secret ref. + pub fn new(inner: BTreeMap, secret_ref: BTreeMap) -> Self { + Self { inner, secret_ref } } pub fn inner_mut(&mut self) -> &mut BTreeMap { &mut self.inner } - /// Take the value of the inner map. - pub fn into_inner(self) -> BTreeMap { - self.inner + /// Take the value of the option map and secret refs. + pub fn into_parts(self) -> (BTreeMap, BTreeMap) { + (self.inner, self.secret_ref) } /// Convert to connector props, remove the key-value pairs used in the top-level. - pub fn into_connector_props(self) -> BTreeMap { - self.inner + pub fn into_connector_props(self) -> WithOptions { + let inner = self + .inner .into_iter() .filter(|(key, _)| { key != OverwriteOptions::STREAMING_RATE_LIMIT_KEY && key != options::RETENTION_SECONDS }) - .collect() + .collect(); + + Self { + inner, + secret_ref: self.secret_ref, + } } /// Parse the retention seconds from the options. 
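The `into_parts()` method defined here is the hinge for the repeated plan-node edits earlier in this patch: instead of flattening every option into a string map and leaving `secret_refs` empty, each source-like plan node now splits the resolved options and forwards both halves to the proto node. A minimal sketch of that shared pattern (it assumes the prost-generated `SourceNode` derives `Default` and is not compilable on its own):

```rust
// Shape of the change in BatchSource / BatchKafkaScan / BatchIcebergScan /
// StreamSource / StreamFsFetch / StreamSourceScan:
let (with_properties, secret_refs) = source_catalog.with_properties.clone().into_parts();

NodeBody::Source(SourceNode {
    source_id: source_catalog.id,
    // Plain key/value options, with secret-backed entries excluded.
    with_properties,
    // Option key -> PbSecretRef { secret_id, ref_as }, resolved on the frontend.
    secret_refs,
    // Remaining fields (info, columns, split, ...) are unchanged by this patch.
    ..Default::default()
})
```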
@@ -107,7 +118,10 @@ impl WithOptions { }) .collect(); - Self { inner } + Self { + inner, + secret_ref: self.secret_ref.clone(), + } } pub fn value_eq_ignore_case(&self, key: &str, val: &str) -> bool { @@ -118,15 +132,57 @@ impl WithOptions { } false } -} -pub(crate) fn resolve_secret_in_with_options( - _with_options: &mut WithOptions, - _session: &SessionImpl, -) -> RwResult> { - // todo: implement the function and take `resolve_privatelink_in_with_option` as reference + pub fn secret_ref(&self) -> &BTreeMap { + &self.secret_ref + } + + pub fn encode_options_to_map(sql_options: &[SqlOption]) -> RwResult> { + let WithOptions { inner, secret_ref } = WithOptions::try_from(sql_options)?; + if secret_ref.is_empty() { + Ok(inner) + } else { + Err(RwError::from(ErrorCode::InvalidParameterValue( + "Secret reference is not allowed in encode options".to_string(), + ))) + } + } - Ok(BTreeMap::new()) + pub fn oauth_options_to_map(sql_options: &[SqlOption]) -> RwResult> { + let WithOptions { inner, secret_ref } = WithOptions::try_from(sql_options)?; + if secret_ref.is_empty() { + Ok(inner) + } else { + Err(RwError::from(ErrorCode::InvalidParameterValue( + "Secret reference is not allowed in OAuth options".to_string(), + ))) + } + } +} + +/// Get the secret id from the name. +pub(crate) fn resolve_secret_ref_in_with_options( + with_options: WithOptions, + session: &SessionImpl, +) -> RwResult { + let (options, secret_refs) = with_options.into_parts(); + let mut resolved_secret_refs = BTreeMap::new(); + let db_name: &str = session.database(); + for (key, secret_ref) in secret_refs { + let (schema_name, secret_name) = + Binder::resolve_schema_qualified_name(db_name, secret_ref.secret_name.clone())?; + let secret_catalog = session.get_secret_by_name(schema_name, &secret_name)?; + let ref_as = match secret_ref.ref_as { + SecretRefAsType::Text => PbRefAsType::Text, + SecretRefAsType::File => PbRefAsType::File, + }; + let pb_secret_ref = PbSecretRef { + secret_id: secret_catalog.id.secret_id(), + ref_as: ref_as.into(), + }; + resolved_secret_refs.insert(key.clone(), pb_secret_ref); + } + Ok(WithOptionsSecResolved::new(options, resolved_secret_refs)) } pub(crate) fn resolve_privatelink_in_with_option( @@ -175,8 +231,18 @@ impl TryFrom<&[SqlOption]> for WithOptions { fn try_from(options: &[SqlOption]) -> Result { let mut inner: BTreeMap = BTreeMap::new(); + let mut secret_ref: BTreeMap = BTreeMap::new(); for option in options { let key = option.name.real_value(); + if let Value::Ref(r) = &option.value { + if secret_ref.insert(key.clone(), r.clone()).is_some() || inner.contains_key(&key) { + return Err(RwError::from(ErrorCode::InvalidParameterValue(format!( + "Duplicated option: {}", + key + )))); + } + continue; + } let value: String = match option.value.clone() { Value::CstyleEscapedString(s) => s.value, Value::SingleQuotedString(s) => s, @@ -189,7 +255,7 @@ impl TryFrom<&[SqlOption]> for WithOptions { ))) } }; - if inner.insert(key.clone(), value).is_some() { + if inner.insert(key.clone(), value).is_some() || secret_ref.contains_key(&key) { return Err(RwError::from(ErrorCode::InvalidParameterValue(format!( "Duplicated option: {}", key @@ -197,7 +263,7 @@ impl TryFrom<&[SqlOption]> for WithOptions { } } - Ok(Self { inner }) + Ok(Self { inner, secret_ref }) } } diff --git a/src/meta/Cargo.toml b/src/meta/Cargo.toml index ddae0c9c2462..c0afe22177e5 100644 --- a/src/meta/Cargo.toml +++ b/src/meta/Cargo.toml @@ -14,7 +14,6 @@ ignored = ["workspace-hack"] normal = ["workspace-hack"] [dependencies] -aes-siv = "0.7" 
anyhow = "1" arc-swap = "1" assert_matches = "1" diff --git a/src/meta/node/Cargo.toml b/src/meta/node/Cargo.toml index d0ab59edee99..9300a08727b4 100644 --- a/src/meta/node/Cargo.toml +++ b/src/meta/node/Cargo.toml @@ -16,9 +16,11 @@ normal = ["workspace-hack"] [dependencies] anyhow = "1" clap = { workspace = true } +educe = "0.6" either = "1" etcd-client = { workspace = true } futures = { version = "0.3", default-features = false, features = ["alloc"] } +hex = "0.4" itertools = { workspace = true } otlp-embedded = { workspace = true } prometheus-http-query = "0.8" diff --git a/src/meta/node/src/lib.rs b/src/meta/node/src/lib.rs index cb712914f25e..0fd7ee94d9fc 100644 --- a/src/meta/node/src/lib.rs +++ b/src/meta/node/src/lib.rs @@ -21,6 +21,7 @@ mod server; use std::time::Duration; use clap::Parser; +use educe::Educe; pub use error::{MetaError, MetaResult}; use redact::Secret; use risingwave_common::config::OverrideConfig; @@ -37,7 +38,8 @@ pub use server::started::get as is_server_started; use crate::manager::MetaOpts; -#[derive(Debug, Clone, Parser, OverrideConfig)] +#[derive(Educe, Clone, Parser, OverrideConfig)] +#[educe(Debug)] #[command(version, about = "The central metadata management service")] pub struct MetaNodeOpts { // TODO: use `SocketAddr` @@ -184,6 +186,20 @@ pub struct MetaNodeOpts { #[deprecated = "connector node has been deprecated."] #[clap(long, hide = true, env = "RW_CONNECTOR_RPC_ENDPOINT")] pub connector_rpc_endpoint: Option, + + /// 128-bit AES key for secret store in HEX format. + #[educe(Debug(ignore))] + #[clap(long, hide = true, env = "RW_SECRET_STORE_PRIVATE_KEY_HEX")] + pub secret_store_private_key_hex: Option, + + /// The path of the temp secret file directory. + #[clap( + long, + hide = true, + env = "RW_TEMP_SECRET_FILE_DIR", + default_value = "./secrets" + )] + pub temp_secret_file_dir: String, } impl risingwave_common::opts::Opts for MetaNodeOpts { @@ -283,6 +299,9 @@ pub fn start( // Run a background heap profiler heap_profiler.start(); + let secret_store_private_key = opts + .secret_store_private_key_hex + .map(|key| hex::decode(key).unwrap()); let max_heartbeat_interval = Duration::from_secs(config.meta.max_heartbeat_interval_secs as u64); let max_idle_ms = config.meta.dangerous_max_idle_secs.unwrap_or(0) * 1000; @@ -430,7 +449,8 @@ pub fn start( .developer .max_trivial_move_task_count_per_loop, max_get_task_probe_times: config.meta.developer.max_get_task_probe_times, - secret_store_private_key: config.meta.secret_store_private_key, + secret_store_private_key, + temp_secret_file_dir: opts.temp_secret_file_dir, table_info_statistic_history_times: config .storage .table_info_statistic_history_times, diff --git a/src/meta/node/src/server.rs b/src/meta/node/src/server.rs index 74310c75374e..bb3ef15aff91 100644 --- a/src/meta/node/src/server.rs +++ b/src/meta/node/src/server.rs @@ -20,6 +20,7 @@ use etcd_client::ConnectOptions; use otlp_embedded::TraceServiceServer; use regex::Regex; use risingwave_common::monitor::{RouterExt, TcpConfig}; +use risingwave_common::secret::LocalSecretManager; use risingwave_common::session_config::SessionConfig; use risingwave_common::system_param::reader::SystemParamsRead; use risingwave_common::telemetry::manager::TelemetryManager; @@ -29,7 +30,9 @@ use risingwave_common_service::{MetricsManager, TracingExtractLayer}; use risingwave_meta::barrier::StreamRpcManager; use risingwave_meta::controller::catalog::CatalogController; use risingwave_meta::controller::cluster::ClusterController; -use 
risingwave_meta::manager::{MetaStoreImpl, MetadataManager, SystemParamsManagerImpl}; +use risingwave_meta::manager::{ + MetaStoreImpl, MetadataManager, SystemParamsManagerImpl, META_NODE_ID, +}; use risingwave_meta::rpc::election::dummy::DummyElectionClient; use risingwave_meta::rpc::intercept::MetricsMiddlewareLayer; use risingwave_meta::rpc::ElectionClientRef; @@ -508,6 +511,31 @@ pub async fn start_service_as_election_leader( system_params_reader.checkpoint_frequency() as usize, ); + // Initialize services. + let backup_manager = BackupManager::new( + env.clone(), + hummock_manager.clone(), + meta_metrics.clone(), + system_params_reader.backup_storage_url(), + system_params_reader.backup_storage_directory(), + ) + .await?; + + LocalSecretManager::init( + opts.temp_secret_file_dir, + env.cluster_id().to_string(), + META_NODE_ID, + ); + + let notification_srv = NotificationServiceImpl::new( + env.clone(), + metadata_manager.clone(), + hummock_manager.clone(), + backup_manager.clone(), + serving_vnode_mapping.clone(), + ) + .await?; + let source_manager = Arc::new( SourceManager::new( barrier_scheduler.clone(), @@ -567,15 +595,6 @@ pub async fn start_service_as_election_leader( .await .unwrap(); - // Initialize services. - let backup_manager = BackupManager::new( - env.clone(), - hummock_manager.clone(), - meta_metrics.clone(), - system_params_reader.backup_storage_url(), - system_params_reader.backup_storage_directory(), - ) - .await?; let vacuum_manager = Arc::new(hummock::VacuumManager::new( env.clone(), hummock_manager.clone(), @@ -626,13 +645,7 @@ pub async fn start_service_as_election_leader( vacuum_manager.clone(), metadata_manager.clone(), ); - let notification_srv = NotificationServiceImpl::new( - env.clone(), - metadata_manager.clone(), - hummock_manager.clone(), - backup_manager.clone(), - serving_vnode_mapping.clone(), - ); + let health_srv = HealthServiceImpl::new(); let backup_srv = BackupServiceImpl::new(backup_manager); let telemetry_srv = TelemetryInfoServiceImpl::new(env.meta_store()); diff --git a/src/meta/service/Cargo.toml b/src/meta/service/Cargo.toml index 6e00e6eb4318..10f53b9162e2 100644 --- a/src/meta/service/Cargo.toml +++ b/src/meta/service/Cargo.toml @@ -19,6 +19,7 @@ async-trait = "0.1" either = "1" futures = { version = "0.3", default-features = false, features = ["alloc"] } itertools = { workspace = true } +prost ={ workspace = true } rand = { workspace = true } regex = "1" risingwave_common = { workspace = true } diff --git a/src/meta/service/src/cloud_service.rs b/src/meta/service/src/cloud_service.rs index 5777f1c40500..739d7dc2d197 100644 --- a/src/meta/service/src/cloud_service.rs +++ b/src/meta/service/src/cloud_service.rs @@ -17,12 +17,12 @@ use std::sync::LazyLock; use async_trait::async_trait; use regex::Regex; -use risingwave_connector::dispatch_source_prop; use risingwave_connector::error::ConnectorResult; use risingwave_connector::source::kafka::private_link::insert_privatelink_broker_rewrite_map; use risingwave_connector::source::{ ConnectorProperties, SourceEnumeratorContext, SourceProperties, SplitEnumerator, }; +use risingwave_connector::{dispatch_source_prop, WithOptionsSecResolved}; use risingwave_meta::manager::{ConnectionId, MetadataManager}; use risingwave_pb::catalog::connection::Info::PrivateLinkService; use risingwave_pb::cloud_service::cloud_service_server::CloudService; @@ -146,6 +146,10 @@ impl CloudService for CloudServiceImpl { )); } } + + // XXX: We can't use secret in cloud validate source. 
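Note that `ConnectorProperties::extract` now takes a `WithOptionsSecResolved` rather than a bare `BTreeMap`. Call sites with no secrets at all, like this cloud validation path, wrap their map with `without_secrets`, while call sites that carry catalog secret refs pair the two maps explicitly. A short sketch of both constructions, paraphrasing call sites elsewhere in this patch (fragments, not standalone code):

```rust
// 1. No secrets available (or allowed): wrap the plain map.
let plain = WithOptionsSecResolved::without_secrets(source_cfg);
let _props = ConnectorProperties::extract(plain, false)?;

// 2. Secrets resolved earlier, e.g. from a catalog `Source`: pair options with refs.
let resolved = WithOptionsSecResolved::new(
    source.with_properties.clone(), // BTreeMap<String, String>
    source.secret_refs.clone(),     // BTreeMap<String, PbSecretRef>
);
let _props = ConnectorProperties::extract(resolved, true)?;
```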
+ let source_cfg = WithOptionsSecResolved::without_secrets(source_cfg); + // try fetch kafka metadata, return error message on failure let props = ConnectorProperties::extract(source_cfg, false); if let Err(e) = props { diff --git a/src/meta/service/src/cluster_service.rs b/src/meta/service/src/cluster_service.rs index 842dde71efc4..39cd40ed3740 100644 --- a/src/meta/service/src/cluster_service.rs +++ b/src/meta/service/src/cluster_service.rs @@ -62,10 +62,12 @@ impl ClusterService for ClusterServiceImpl { .metadata_manager .add_worker_node(worker_type, host, property, resource) .await; + let cluster_id = self.metadata_manager.cluster_id().to_string(); match result { Ok(worker_id) => Ok(Response::new(AddWorkerNodeResponse { status: None, node_id: Some(worker_id), + cluster_id, })), Err(e) => { if e.is_invalid_worker() { @@ -75,6 +77,7 @@ impl ClusterService for ClusterServiceImpl { message: e.to_report_string(), }), node_id: None, + cluster_id, })); } Err(e.into()) diff --git a/src/meta/service/src/notification_service.rs b/src/meta/service/src/notification_service.rs index 5d68211a71d5..b2b78ba28f19 100644 --- a/src/meta/service/src/notification_service.rs +++ b/src/meta/service/src/notification_service.rs @@ -12,12 +12,13 @@ // See the License for the specific language governing permissions and // limitations under the License. -use anyhow::Context; +use anyhow::{anyhow, Context}; use itertools::Itertools; +use risingwave_common::secret::{LocalSecretManager, SecretEncryption}; use risingwave_meta::manager::{MetadataManager, SessionParamsManagerImpl}; use risingwave_meta::MetaResult; use risingwave_pb::backup_service::MetaBackupManifestId; -use risingwave_pb::catalog::Table; +use risingwave_pb::catalog::{Secret, Table}; use risingwave_pb::common::worker_node::State::Running; use risingwave_pb::common::{WorkerNode, WorkerType}; use risingwave_pb::hummock::WriteLimits; @@ -47,20 +48,23 @@ pub struct NotificationServiceImpl { } impl NotificationServiceImpl { - pub fn new( + pub async fn new( env: MetaSrvEnv, metadata_manager: MetadataManager, hummock_manager: HummockManagerRef, backup_manager: BackupManagerRef, serving_vnode_mapping: ServingVnodeMappingRef, - ) -> Self { - Self { + ) -> MetaResult { + let service = Self { env, metadata_manager, hummock_manager, backup_manager, serving_vnode_mapping, - } + }; + let (secrets, _catalog_version) = service.get_decrypted_secret_snapshot().await?; + LocalSecretManager::global().init_secrets(secrets); + Ok(service) } async fn get_catalog_snapshot( @@ -142,6 +146,47 @@ impl NotificationServiceImpl { } } + /// Get decrypted secret snapshot + async fn get_decrypted_secret_snapshot( + &self, + ) -> MetaResult<(Vec, NotificationVersion)> { + let secrets = match &self.metadata_manager { + MetadataManager::V1(mgr) => { + let catalog_guard = mgr.catalog_manager.get_catalog_core_guard().await; + catalog_guard.database.list_secrets() + } + MetadataManager::V2(mgr) => { + let catalog_guard = mgr.catalog_controller.get_inner_read_guard().await; + catalog_guard.list_secrets().await? 
+ } + }; + let notification_version = self.env.notification_manager().current_version().await; + + let decrypted_secrets = self.decrypt_secrets(secrets)?; + + Ok((decrypted_secrets, notification_version)) + } + + fn decrypt_secrets(&self, secrets: Vec) -> MetaResult> { + let secret_store_private_key = self + .env + .opts + .secret_store_private_key + .clone() + .ok_or_else(|| anyhow!("secret_store_private_key is not configured"))?; + let mut decrypted_secrets = Vec::with_capacity(secrets.len()); + for mut secret in secrets { + let encrypted_secret = SecretEncryption::deserialize(secret.get_value()) + .context(format!("failed to deserialize secret {}", secret.name))?; + let decrypted_secret = encrypted_secret + .decrypt(secret_store_private_key.as_slice()) + .context(format!("failed to decrypt secret {}", secret.name))?; + secret.value = decrypted_secret; + decrypted_secrets.push(secret); + } + Ok(decrypted_secrets) + } + async fn get_worker_slot_mapping_snapshot( &self, ) -> MetaResult<(Vec, NotificationVersion)> { @@ -247,6 +292,9 @@ impl NotificationServiceImpl { catalog_version, ) = self.get_catalog_snapshot().await?; + // Use the plain text secret value for frontend. The secret value will be masked in frontend handle. + let decrypted_secrets = self.decrypt_secrets(secrets)?; + let (streaming_worker_slot_mappings, streaming_worker_slot_mapping_version) = self.get_worker_slot_mapping_snapshot().await?; let serving_worker_slot_mappings = self.get_serving_vnode_mappings(); @@ -276,7 +324,7 @@ impl NotificationServiceImpl { subscriptions, functions, connections, - secrets, + secrets: decrypted_secrets, users, nodes, hummock_snapshot, @@ -315,8 +363,16 @@ impl NotificationServiceImpl { }) } - fn compute_subscribe(&self) -> MetaSnapshot { - MetaSnapshot::default() + async fn compute_subscribe(&self) -> MetaResult { + let (secrets, catalog_version) = self.get_decrypted_secret_snapshot().await?; + Ok(MetaSnapshot { + secrets, + version: Some(SnapshotVersion { + catalog_version, + ..Default::default() + }), + ..Default::default() + }) } } @@ -355,7 +411,7 @@ impl NotificationService for NotificationServiceImpl { .await?; self.hummock_subscribe().await? 
} - SubscribeType::Compute => self.compute_subscribe(), + SubscribeType::Compute => self.compute_subscribe().await?, SubscribeType::Unspecified => unreachable!(), }; diff --git a/src/meta/src/controller/catalog.rs b/src/meta/src/controller/catalog.rs index 6bfaaa2c1e79..6fc372197285 100644 --- a/src/meta/src/controller/catalog.rs +++ b/src/meta/src/controller/catalog.rs @@ -19,6 +19,7 @@ use std::sync::Arc; use anyhow::anyhow; use itertools::Itertools; use risingwave_common::catalog::{TableOption, DEFAULT_SCHEMA_NAME, SYSTEM_SCHEMAS}; +use risingwave_common::secret::LocalSecretManager; use risingwave_common::util::stream_graph_visitor::visit_stream_node_cont_mut; use risingwave_common::{bail, current_cluster_version}; use risingwave_connector::source::UPSTREAM_SOURCE_KEY; @@ -1086,7 +1087,11 @@ impl CatalogController { Ok(version) } - pub async fn create_secret(&self, mut pb_secret: PbSecret) -> MetaResult { + pub async fn create_secret( + &self, + mut pb_secret: PbSecret, + secret_plain_payload: Vec, + ) -> MetaResult { let inner = self.inner.write().await; let owner_id = pb_secret.owner as _; let txn = inner.db.begin().await?; @@ -1109,12 +1114,22 @@ impl CatalogController { txn.commit().await?; + // Notify the compute and frontend node plain secret + let mut secret_plain = pb_secret; + secret_plain.value.clone_from(&secret_plain_payload); + + LocalSecretManager::global().add_secret(secret_plain.id, secret_plain_payload); + self.env + .notification_manager() + .notify_compute_without_version(Operation::Add, Info::Secret(secret_plain.clone())); + let version = self .notify_frontend( NotificationOperation::Add, - NotificationInfo::Secret(pb_secret), + NotificationInfo::Secret(secret_plain), ) .await; + Ok(version) } @@ -1159,6 +1174,11 @@ impl CatalogController { let pb_secret: PbSecret = ObjectModel(secret, secret_obj.unwrap()).into(); self.notify_users_update(user_infos).await; + + LocalSecretManager::global().remove_secret(pb_secret.id); + self.env + .notification_manager() + .notify_compute_without_version(Operation::Delete, Info::Secret(pb_secret.clone())); let version = self .notify_frontend( NotificationOperation::Delete, @@ -3118,7 +3138,7 @@ impl CatalogControllerInner { .collect()) } - async fn list_secrets(&self) -> MetaResult> { + pub async fn list_secrets(&self) -> MetaResult> { let secret_objs = Secret::find() .find_also_related(Object) .all(&self.db) diff --git a/src/meta/src/controller/cluster.rs b/src/meta/src/controller/cluster.rs index 50300f0ec282..4749bafe370f 100644 --- a/src/meta/src/controller/cluster.rs +++ b/src/meta/src/controller/cluster.rs @@ -52,6 +52,7 @@ use tokio::task::JoinHandle; use crate::manager::{ LocalNotification, MetaSrvEnv, StreamingClusterInfo, WorkerKey, META_NODE_ID, }; +use crate::model::ClusterId; use crate::{MetaError, MetaResult}; pub type ClusterControllerRef = Arc; @@ -392,6 +393,10 @@ impl ClusterController { .await .get_worker_extra_info_by_id(worker_id) } + + pub fn cluster_id(&self) -> &ClusterId { + self.env.cluster_id() + } } #[derive(Default, Clone)] diff --git a/src/meta/src/controller/streaming_job.rs b/src/meta/src/controller/streaming_job.rs index 2db28e153e59..802e6dabbb90 100644 --- a/src/meta/src/controller/streaming_job.rs +++ b/src/meta/src/controller/streaming_job.rs @@ -288,14 +288,14 @@ impl CatalogController { } } - // get dependent secret ref. - let dependent_secret_refs = streaming_job.dependent_secret_refs()?; + // get dependent secrets. 
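The catalog changes above establish the secret fan-out: meta persists only the encrypted value, keeps the plaintext in its process-local `LocalSecretManager`, and pushes the plaintext to compute nodes (unversioned) and frontends (versioned) so executors can inline it later. Condensed from `create_secret` above, with `drop_secret` as its mirror image (repo-internal imports elided):

```rust
// After the encrypted payload is committed to the catalog, rebuild a plaintext copy
// for notification purposes only.
let mut secret_plain = pb_secret;
secret_plain.value.clone_from(&secret_plain_payload);

// Cache locally on meta, then fan out.
LocalSecretManager::global().add_secret(secret_plain.id, secret_plain_payload);
self.env
    .notification_manager()
    .notify_compute_without_version(Operation::Add, Info::Secret(secret_plain.clone()));
let version = self
    .notify_frontend(NotificationOperation::Add, NotificationInfo::Secret(secret_plain))
    .await;
// DROP SECRET does the reverse: `remove_secret` plus `Operation::Delete` notifications.
```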
+ let dependent_secret_ids = streaming_job.dependent_secret_ids()?; let dependent_objs = dependent_relations .iter() - .chain(dependent_secret_refs.iter()); + .chain(dependent_secret_ids.iter()); // record object dependency. - if !dependent_secret_refs.is_empty() || !dependent_relations.is_empty() { + if !dependent_secret_ids.is_empty() || !dependent_relations.is_empty() { ObjectDependency::insert_many(dependent_objs.map(|id| { object_dependency::ActiveModel { oid: Set(*id as _), diff --git a/src/meta/src/manager/catalog/mod.rs b/src/meta/src/manager/catalog/mod.rs index 3f815f8f9c8d..b1afc612bcea 100644 --- a/src/meta/src/manager/catalog/mod.rs +++ b/src/meta/src/manager/catalog/mod.rs @@ -30,6 +30,7 @@ use risingwave_common::catalog::{ DEFAULT_SCHEMA_NAME, DEFAULT_SUPER_USER, DEFAULT_SUPER_USER_FOR_PG, DEFAULT_SUPER_USER_FOR_PG_ID, DEFAULT_SUPER_USER_ID, SYSTEM_SCHEMAS, }; +use risingwave_common::secret::LocalSecretManager; use risingwave_common::{bail, current_cluster_version, ensure}; use risingwave_connector::source::{should_copy_to_format_encode_options, UPSTREAM_SOURCE_KEY}; use risingwave_pb::catalog::subscription::PbSubscriptionState; @@ -487,7 +488,11 @@ impl CatalogManager { } } - pub async fn create_secret(&self, secret: Secret) -> MetaResult { + pub async fn create_secret( + &self, + secret: Secret, + secret_plain_payload: Vec, + ) -> MetaResult { let core = &mut *self.core.lock().await; let database_core = &mut core.database; let user_core = &mut core.user; @@ -504,14 +509,25 @@ impl CatalogManager { let secret_id = secret.id; let mut secret_entry = BTreeMapTransaction::new(&mut database_core.secrets); + secret_entry.insert(secret_id, secret.to_owned()); commit_meta!(self, secret_entry)?; user_core.increase_ref(secret.owner); + // Notify the compute and frontend node plain secret + let mut secret_plain = secret; + secret_plain.value.clone_from(&secret_plain_payload); + + LocalSecretManager::global().add_secret(secret_id, secret_plain_payload); + self.env + .notification_manager() + .notify_compute_without_version(Operation::Add, Info::Secret(secret_plain.clone())); + let version = self - .notify_frontend(Operation::Add, Info::Secret(secret)) + .notify_frontend(Operation::Add, Info::Secret(secret_plain)) .await; + Ok(version) } @@ -521,21 +537,40 @@ impl CatalogManager { let user_core = &mut core.user; let mut secrets = BTreeMapTransaction::new(&mut database_core.secrets); - // todo: impl a ref count check for secret - // if secret is used by other relations, not found in the catalog or do not have the privilege to drop, return error - // else: commit the change and notify frontend - - let secret = secrets - .remove(secret_id) - .ok_or_else(|| anyhow!("secret not found"))?; - - commit_meta!(self, secrets)?; - user_core.decrease_ref(secret.owner); + match database_core.secret_ref_count.get(&secret_id) { + Some(ref_count) => { + let secret_name = secrets + .get(&secret_id) + .ok_or_else(|| MetaError::catalog_id_not_found("connection", secret_id))? 
+ .name + .clone(); + Err(MetaError::permission_denied(format!( + "Fail to delete secret {} because {} other relation(s) depend on it", + secret_name, ref_count + ))) + } + None => { + let secret = secrets + .remove(secret_id) + .ok_or_else(|| anyhow!("secret not found"))?; + + commit_meta!(self, secrets)?; + user_core.decrease_ref(secret.owner); + + LocalSecretManager::global().remove_secret(secret.id); + self.env + .notification_manager() + .notify_compute_without_version( + Operation::Delete, + Info::Secret(secret.clone()), + ); - let version = self - .notify_frontend(Operation::Delete, Info::Secret(secret)) - .await; - Ok(version) + let version = self + .notify_frontend(Operation::Delete, Info::Secret(secret)) + .await; + Ok(version) + } + } } pub async fn create_connection( diff --git a/src/meta/src/manager/cluster.rs b/src/meta/src/manager/cluster.rs index a5e2c8175ea6..e5e3dfa47049 100644 --- a/src/meta/src/manager/cluster.rs +++ b/src/meta/src/manager/cluster.rs @@ -39,7 +39,8 @@ use tokio::task::JoinHandle; use crate::manager::{IdCategory, LocalNotification, MetaSrvEnv}; use crate::model::{ - InMemValTransaction, MetadataModel, ValTransaction, VarTransaction, Worker, INVALID_EXPIRE_AT, + ClusterId, InMemValTransaction, MetadataModel, ValTransaction, VarTransaction, Worker, + INVALID_EXPIRE_AT, }; use crate::storage::{MetaStore, Transaction}; use crate::{MetaError, MetaResult}; @@ -539,6 +540,10 @@ impl ClusterManager { pub async fn get_worker_by_id(&self, worker_id: WorkerId) -> Option { self.core.read().await.get_worker_by_id(worker_id) } + + pub fn cluster_id(&self) -> &ClusterId { + self.env.cluster_id() + } } /// The cluster info used for scheduling a streaming job. diff --git a/src/meta/src/manager/env.rs b/src/meta/src/manager/env.rs index 78b5f3989935..3c6a18adef74 100644 --- a/src/meta/src/manager/env.rs +++ b/src/meta/src/manager/env.rs @@ -286,7 +286,9 @@ pub struct MetaOpts { pub compact_task_table_size_partition_threshold_high: u64, // The private key for the secret store, used when the secret is stored in the meta. - pub secret_store_private_key: Vec, + pub secret_store_private_key: Option>, + /// The path of the temp secret file directory. 
+ pub temp_secret_file_dir: String, pub table_info_statistic_history_times: usize, } @@ -352,7 +354,8 @@ impl MetaOpts { object_store_config: ObjectStoreConfig::default(), max_trivial_move_task_count_per_loop: 256, max_get_task_probe_times: 5, - secret_store_private_key: "demo-secret-private-key".as_bytes().to_vec(), + secret_store_private_key: Some("0123456789abcdef".as_bytes().to_vec()), + temp_secret_file_dir: "./secrets".to_string(), table_info_statistic_history_times: 240, } } diff --git a/src/meta/src/manager/metadata.rs b/src/meta/src/manager/metadata.rs index 1caf04fa65a8..76e266e2ddd2 100644 --- a/src/meta/src/manager/metadata.rs +++ b/src/meta/src/manager/metadata.rs @@ -37,7 +37,9 @@ use crate::manager::{ CatalogManagerRef, ClusterManagerRef, FragmentManagerRef, LocalNotification, StreamingClusterInfo, WorkerId, }; -use crate::model::{ActorId, FragmentId, MetadataModel, TableFragments, TableParallelism}; +use crate::model::{ + ActorId, ClusterId, FragmentId, MetadataModel, TableFragments, TableParallelism, +}; use crate::stream::{to_build_actor_info, SplitAssignment}; use crate::telemetry::MetaTelemetryJobDesc; use crate::MetaResult; @@ -832,4 +834,11 @@ impl MetadataManager { } } } + + pub fn cluster_id(&self) -> &ClusterId { + match self { + MetadataManager::V1(mgr) => mgr.cluster_manager.cluster_id(), + MetadataManager::V2(mgr) => mgr.cluster_controller.cluster_id(), + } + } } diff --git a/src/meta/src/manager/streaming_job.rs b/src/meta/src/manager/streaming_job.rs index 90d0781174b5..67150ce8351d 100644 --- a/src/meta/src/manager/streaming_job.rs +++ b/src/meta/src/manager/streaming_job.rs @@ -289,7 +289,8 @@ impl StreamingJob { } } - pub fn dependent_secret_refs(&self) -> MetaResult> { + // Get the secret ids that are referenced by this job. 
+ pub fn dependent_secret_ids(&self) -> MetaResult> { match self { StreamingJob::Sink(sink, _) => Ok(get_refed_secret_ids_from_sink(sink)), StreamingJob::Table(source, _, _) => { diff --git a/src/meta/src/rpc/ddl_controller.rs b/src/meta/src/rpc/ddl_controller.rs index 6bcc59fee553..4c75e751b848 100644 --- a/src/meta/src/rpc/ddl_controller.rs +++ b/src/meta/src/rpc/ddl_controller.rs @@ -18,14 +18,12 @@ use std::num::NonZeroUsize; use std::sync::Arc; use std::time::Duration; -use aes_siv::aead::generic_array::GenericArray; -use aes_siv::aead::Aead; -use aes_siv::{Aes128SivAead, KeyInit}; -use anyhow::Context; +use anyhow::{anyhow, Context}; use itertools::Itertools; -use rand::{Rng, RngCore}; +use rand::Rng; use risingwave_common::config::DefaultParallelism; use risingwave_common::hash::{ParallelUnitMapping, VirtualNode}; +use risingwave_common::secret::SecretEncryption; use risingwave_common::system_param::reader::SystemParamsRead; use risingwave_common::util::column_index_mapping::ColIndexMapping; use risingwave_common::util::epoch::Epoch; @@ -33,13 +31,13 @@ use risingwave_common::util::stream_graph_visitor::{ visit_fragment, visit_stream_node, visit_stream_node_cont_mut, }; use risingwave_common::{bail, current_cluster_version}; -use risingwave_connector::dispatch_source_prop; use risingwave_connector::error::ConnectorError; use risingwave_connector::source::cdc::CdcSourceType; use risingwave_connector::source::{ ConnectorProperties, SourceEnumeratorContext, SourceProperties, SplitEnumerator, UPSTREAM_SOURCE_KEY, }; +use risingwave_connector::{dispatch_source_prop, WithOptionsSecResolved}; use risingwave_meta_model_v2::object::ObjectType; use risingwave_meta_model_v2::ObjectId; use risingwave_pb::catalog::connection::private_link_service::PbPrivateLinkProvider; @@ -61,7 +59,6 @@ use risingwave_pb::stream_plan::{ Dispatcher, DispatcherType, FragmentTypeFlag, MergeNode, PbStreamFragmentGraph, StreamFragmentGraph as StreamFragmentGraphProto, }; -use serde::{Deserialize, Serialize}; use thiserror_ext::AsReport; use tokio::sync::Semaphore; use tokio::time::sleep; @@ -69,7 +66,6 @@ use tracing::log::warn; use tracing::Instrument; use crate::barrier::BarrierManagerRef; -use crate::error::MetaErrorInner; use crate::manager::{ CatalogManagerRef, ConnectionId, DatabaseId, DdlType, FragmentManagerRef, FunctionId, IdCategory, IdCategoryType, IndexId, LocalNotification, MetaSrvEnv, MetadataManager, @@ -160,12 +156,6 @@ pub enum DdlCommand { DropSubscription(SubscriptionId, DropMode), } -#[derive(Deserialize, Serialize)] -struct SecretEncryption { - nonce: [u8; 16], - ciphertext: Vec, -} - impl DdlCommand { fn allow_in_recovery(&self) -> bool { match self { @@ -629,44 +619,38 @@ impl DdlController { async fn create_secret(&self, mut secret: Secret) -> MetaResult { // The 'secret' part of the request we receive from the frontend is in plaintext; // here, we need to encrypt it before storing it in the catalog. 
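Both directions of the crypto now go through `risingwave_common::secret::SecretEncryption`, replacing the inline aes-siv code removed from this file. Roughly, the write path here and the read path in the notification service pair up as below (condensed, with errors reduced to `anyhow` and the `private_key` / `plaintext` bindings assumed):

```rust
use anyhow::Context;
use risingwave_common::secret::SecretEncryption;

// Write path (CREATE SECRET): encrypt the plaintext from the frontend before it
// is stored in the catalog.
let encrypted_payload: Vec<u8> =
    SecretEncryption::encrypt(private_key.as_slice(), plaintext.as_slice())
        .context("failed to encrypt secret")?
        .serialize()
        .context("failed to serialize secret")?;

// Read path (notification snapshot): decrypt before pushing to frontend/compute.
let decrypted: Vec<u8> = SecretEncryption::deserialize(encrypted_payload.as_slice())
    .context("failed to deserialize secret")?
    .decrypt(private_key.as_slice())
    .context("failed to decrypt secret")?;
```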
+ let secret_plain_payload = secret.value.clone(); + let secret_store_private_key = self + .env + .opts + .secret_store_private_key + .clone() + .ok_or_else(|| anyhow!("secret_store_private_key is not configured"))?; let encrypted_payload = { - let data = secret.get_value().as_slice(); - let key = self.env.opts.secret_store_private_key.as_slice(); - let encrypt_key = { - let mut k = key[..(std::cmp::min(key.len(), 32))].to_vec(); - k.resize_with(32, || 0); - k - }; - - let mut rng = rand::thread_rng(); - let mut nonce: [u8; 16] = [0; 16]; - rng.fill_bytes(&mut nonce); - let nonce_array = GenericArray::from_slice(&nonce); - let cipher = Aes128SivAead::new(encrypt_key.as_slice().into()); - - let ciphertext = cipher.encrypt(nonce_array, data).map_err(|e| { - MetaError::from(MetaErrorInner::InvalidParameter(format!( - "failed to encrypt secret {}: {:?}", - secret.name, e - ))) - })?; - bincode::serialize(&SecretEncryption { nonce, ciphertext }).map_err(|e| { - MetaError::from(MetaErrorInner::InvalidParameter(format!( - "failed to serialize secret {}: {:?}", - secret.name, - e.as_report() - ))) - })? + let encrypted_secret = SecretEncryption::encrypt( + secret_store_private_key.as_slice(), + secret.get_value().as_slice(), + ) + .context(format!("failed to encrypt secret {}", secret.name))?; + encrypted_secret + .serialize() + .context(format!("failed to serialize secret {}", secret.name))? }; secret.value = encrypted_payload; match &self.metadata_manager { MetadataManager::V1(mgr) => { secret.id = self.gen_unique_id::<{ IdCategory::Secret }>().await?; - mgr.catalog_manager.create_secret(secret).await + mgr.catalog_manager + .create_secret(secret, secret_plain_payload) + .await + } + MetadataManager::V2(mgr) => { + mgr.catalog_controller + .create_secret(secret, secret_plain_payload) + .await } - MetadataManager::V2(mgr) => mgr.catalog_controller.create_secret(secret).await, } } @@ -1053,8 +1037,11 @@ impl DdlController { actor.nodes.as_ref().unwrap().node_body && let Some(ref cdc_table_desc) = stream_cdc_scan.cdc_table_desc { - let properties = cdc_table_desc.connect_properties.clone(); - let mut props = ConnectorProperties::extract(properties, true)?; + let options_with_secret = WithOptionsSecResolved::new( + cdc_table_desc.connect_properties.clone(), + cdc_table_desc.secret_refs.clone(), + ); + let mut props = ConnectorProperties::extract(options_with_secret, true)?; props.init_from_pb_cdc_table_desc(cdc_table_desc); dispatch_source_prop!(props, props, { diff --git a/src/meta/src/stream/sink.rs b/src/meta/src/stream/sink.rs index 6b91c52c85f4..90a496823cc4 100644 --- a/src/meta/src/stream/sink.rs +++ b/src/meta/src/stream/sink.rs @@ -22,7 +22,7 @@ use crate::MetaResult; pub async fn validate_sink(prost_sink_catalog: &PbSink) -> MetaResult<()> { let sink_catalog = SinkCatalog::from(prost_sink_catalog); - let param = SinkParam::from(sink_catalog); + let param = SinkParam::try_from_sink_catalog(sink_catalog)?; let sink = build_sink(param)?; diff --git a/src/meta/src/stream/source_manager.rs b/src/meta/src/stream/source_manager.rs index 0fe9d4a96142..b56b9e582f94 100644 --- a/src/meta/src/stream/source_manager.rs +++ b/src/meta/src/stream/source_manager.rs @@ -23,12 +23,12 @@ use std::time::Duration; use anyhow::Context; use risingwave_common::catalog::TableId; use risingwave_common::metrics::LabelGuardedIntGauge; -use risingwave_connector::dispatch_source_prop; use risingwave_connector::error::ConnectorResult; use risingwave_connector::source::{ ConnectorProperties, SourceEnumeratorContext, 
SourceEnumeratorInfo, SourceProperties, SplitEnumerator, SplitId, SplitImpl, SplitMetaData, }; +use risingwave_connector::{dispatch_source_prop, WithOptionsSecResolved}; use risingwave_pb::catalog::Source; use risingwave_pb::source::{ConnectorSplit, ConnectorSplits}; use risingwave_pb::stream_plan::Dispatcher; @@ -81,12 +81,16 @@ struct ConnectorSourceWorker { } fn extract_prop_from_existing_source(source: &Source) -> ConnectorResult { - let mut properties = ConnectorProperties::extract(source.with_properties.clone(), false)?; + let options_with_secret = + WithOptionsSecResolved::new(source.with_properties.clone(), source.secret_refs.clone()); + let mut properties = ConnectorProperties::extract(options_with_secret, false)?; properties.init_from_pb_source(source); Ok(properties) } fn extract_prop_from_new_source(source: &Source) -> ConnectorResult { - let mut properties = ConnectorProperties::extract(source.with_properties.clone(), true)?; + let options_with_secret = + WithOptionsSecResolved::new(source.with_properties.clone(), source.secret_refs.clone()); + let mut properties = ConnectorProperties::extract(options_with_secret, true)?; properties.init_from_pb_source(source); Ok(properties) } diff --git a/src/prost/build.rs b/src/prost/build.rs index 6d31201fa473..729651b71d69 100644 --- a/src/prost/build.rs +++ b/src/prost/build.rs @@ -66,10 +66,11 @@ fn main() -> Result<(), Box> { ".plan_common.ExternalTableDesc", ".hummock.CompactTask", ".catalog.StreamSourceInfo", - ".catalog.SecretRef", + ".secret.SecretRef", ".catalog.Source", ".catalog.Sink", ".catalog.View", + ".catalog.SinkFormatDesc", ".connector_service.ValidateSourceRequest", ".connector_service.GetEventStreamRequest", ".connector_service.SinkParam", diff --git a/src/risedevtool/common.toml b/src/risedevtool/common.toml index 391d52f399cf..d58c09970981 100644 --- a/src/risedevtool/common.toml +++ b/src/risedevtool/common.toml @@ -15,6 +15,7 @@ PREFIX_LOG = "${PREFIX}/log" PREFIX_TMP = "${PREFIX}/tmp" PREFIX_DOCKER = "${PREFIX}/rw-docker" PREFIX_PROFILING = "${PREFIX}/profiling" +PREFIX_SECRET = "${PREFIX}/secrets" BUILD_MODE_DIR = { source = "${ENABLE_RELEASE_PROFILE}", default_value = "debug", mapping = { true = "release" } } RISINGWAVE_BUILD_PROFILE = { source = "${ENABLE_RELEASE_PROFILE}", default_value = "dev", mapping = { true = "release" } } diff --git a/src/rpc_client/src/meta_client.rs b/src/rpc_client/src/meta_client.rs index 7982c7b20999..db24711ba237 100644 --- a/src/rpc_client/src/meta_client.rs +++ b/src/rpc_client/src/meta_client.rs @@ -114,6 +114,7 @@ pub struct MetaClient { host_addr: HostAddr, inner: GrpcMetaClient, meta_config: MetaConfig, + cluster_id: String, } impl MetaClient { @@ -129,6 +130,10 @@ impl MetaClient { self.worker_type } + pub fn cluster_id(&self) -> &str { + &self.cluster_id + } + /// Subscribe to notification from meta. 
pub async fn subscribe( &self, @@ -270,6 +275,7 @@ impl MetaClient { host_addr: addr.clone(), inner: grpc_meta_client, meta_config: meta_config.to_owned(), + cluster_id: add_worker_resp.cluster_id, }; static REPORT_PANIC: std::sync::Once = std::sync::Once::new(); diff --git a/src/sqlparser/src/ast/mod.rs b/src/sqlparser/src/ast/mod.rs index ddab0e030678..448f448628e4 100644 --- a/src/sqlparser/src/ast/mod.rs +++ b/src/sqlparser/src/ast/mod.rs @@ -52,8 +52,8 @@ pub use self::query::{ }; pub use self::statement::*; pub use self::value::{ - CstyleEscapedString, DateTimeField, DollarQuotedString, JsonPredicateType, TrimWhereField, - Value, + CstyleEscapedString, DateTimeField, DollarQuotedString, JsonPredicateType, SecretRef, + SecretRefAsType, TrimWhereField, Value, }; pub use crate::ast::ddl::{ AlterIndexOperation, AlterSinkOperation, AlterSourceOperation, AlterSubscriptionOperation, diff --git a/src/sqlparser/src/ast/value.rs b/src/sqlparser/src/ast/value.rs index 79f2a6ebd99c..9cae715e0927 100644 --- a/src/sqlparser/src/ast/value.rs +++ b/src/sqlparser/src/ast/value.rs @@ -60,7 +60,7 @@ pub enum Value { /// `NULL` value Null, /// name of the reference to secret - Ref(ObjectName), + Ref(SecretRef), } impl fmt::Display for Value { @@ -115,7 +115,7 @@ impl fmt::Display for Value { Ok(()) } Value::Null => write!(f, "NULL"), - Value::Ref(v) => write!(f, "ref secret {}", v), + Value::Ref(v) => write!(f, "secret {}", v), } } } @@ -238,3 +238,25 @@ impl fmt::Display for JsonPredicateType { }) } } +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub struct SecretRef { + pub secret_name: ObjectName, + pub ref_as: SecretRefAsType, +} + +impl fmt::Display for SecretRef { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self.ref_as { + SecretRefAsType::Text => write!(f, "{}", self.secret_name), + SecretRefAsType::File => write!(f, "{} AS FILE", self.secret_name), + } + } +} + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub enum SecretRefAsType { + Text, + File, +} diff --git a/src/sqlparser/src/keywords.rs b/src/sqlparser/src/keywords.rs index fc5971bf4640..561b602298f1 100644 --- a/src/sqlparser/src/keywords.rs +++ b/src/sqlparser/src/keywords.rs @@ -230,6 +230,7 @@ define_keywords!( EXTRACT, FALSE, FETCH, + FILE, FILTER, FIRST, FIRST_VALUE, diff --git a/src/sqlparser/src/parser.rs b/src/sqlparser/src/parser.rs index 3077d8b9a6c0..cd073b6bc261 100644 --- a/src/sqlparser/src/parser.rs +++ b/src/sqlparser/src/parser.rs @@ -3495,6 +3495,18 @@ impl Parser<'_> { Some('\'') => Ok(Value::SingleQuotedString(w.value)), _ => self.expected_at(checkpoint, "A value")?, }, + Keyword::SECRET => { + let secret_name = self.parse_object_name()?; + let ref_as = if self.parse_keywords(&[Keyword::AS, Keyword::FILE]) { + SecretRefAsType::File + } else { + SecretRefAsType::Text + }; + Ok(Value::Ref(SecretRef { + secret_name, + ref_as, + })) + } _ => self.expected_at(checkpoint, "a concrete value"), }, Token::Number(ref n) => Ok(Value::Number(n.clone())), diff --git a/src/stream/src/error.rs b/src/stream/src/error.rs index defefa7b474a..42b3e92e4e04 100644 --- a/src/stream/src/error.rs +++ b/src/stream/src/error.rs @@ -13,6 +13,7 @@ // limitations under the License. 
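With the parser change above, a `WITH`-clause value can now be written as `secret <name>` or `secret <name> AS FILE` (for instance `password = secret my_pass` or `ssl.ca = secret my_ca AS FILE`; the option names are illustrative, not taken from this patch). A small test-style sketch of the AST values this produces, assuming the usual sqlparser `From<&str>` impl for `Ident`:

```rust
use risingwave_sqlparser::ast::{ObjectName, SecretRef, SecretRefAsType, Value};

// `... = secret my_pass`        -> substituted as plain text at runtime
let as_text = Value::Ref(SecretRef {
    secret_name: ObjectName(vec!["my_pass".into()]),
    ref_as: SecretRefAsType::Text,
});
// `... = secret my_ca AS FILE`  -> materialized under the node's temp secret file dir
let as_file = Value::Ref(SecretRef {
    secret_name: ObjectName(vec!["my_ca".into()]),
    ref_as: SecretRefAsType::File,
});

assert_eq!(as_text.to_string(), "secret my_pass");
assert_eq!(as_file.to_string(), "secret my_ca AS FILE");
```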
use risingwave_common::array::ArrayError; +use risingwave_common::secret::SecretError; use risingwave_connector::error::ConnectorError; use risingwave_connector::sink::SinkError; use risingwave_expr::ExprError; @@ -85,7 +86,12 @@ pub enum ErrorKind { actor_id: ActorId, reason: &'static str, }, - + #[error("Secret error: {0}")] + Secret( + #[from] + #[backtrace] + SecretError, + ), #[error(transparent)] Uncategorized( #[from] diff --git a/src/stream/src/executor/sink.rs b/src/stream/src/executor/sink.rs index 9b6c84f674cf..726fc3ccad55 100644 --- a/src/stream/src/executor/sink.rs +++ b/src/stream/src/executor/sink.rs @@ -563,6 +563,7 @@ mod test { sink_id: 0.into(), sink_name: "test".into(), properties, + columns: columns .iter() .filter(|col| !col.is_hidden) @@ -691,6 +692,7 @@ mod test { sink_id: 0.into(), sink_name: "test".into(), properties, + columns: columns .iter() .filter(|col| !col.is_hidden) @@ -792,6 +794,7 @@ mod test { sink_id: 0.into(), sink_name: "test".into(), properties, + columns: columns .iter() .filter(|col| !col.is_hidden) diff --git a/src/stream/src/from_proto/sink.rs b/src/stream/src/from_proto/sink.rs index 5e77be7beb7a..7ed4fd3802d5 100644 --- a/src/stream/src/from_proto/sink.rs +++ b/src/stream/src/from_proto/sink.rs @@ -16,6 +16,7 @@ use std::sync::Arc; use anyhow::anyhow; use risingwave_common::catalog::{ColumnCatalog, Schema}; +use risingwave_common::secret::LocalSecretManager; use risingwave_common::types::DataType; use risingwave_connector::match_sink_name_str; use risingwave_connector::sink::catalog::{SinkFormatDesc, SinkType}; @@ -115,6 +116,7 @@ impl ExecutorBuilder for SinkExecutorBuilder { let db_name = sink_desc.get_db_name().into(); let sink_from_name = sink_desc.get_sink_from_name().into(); let properties = sink_desc.get_properties().clone(); + let secret_refs = sink_desc.get_secret_refs().clone(); let downstream_pk = sink_desc .downstream_pk .iter() @@ -155,10 +157,15 @@ impl ExecutorBuilder for SinkExecutorBuilder { }, }; + let properties_with_secret = + LocalSecretManager::global().fill_secrets(properties, secret_refs)?; + + let format_desc_with_secret = SinkParam::fill_secret_for_format_desc(format_desc)?; + let sink_param = SinkParam { sink_id, sink_name, - properties, + properties: properties_with_secret, columns: columns .iter() .filter(|col| !col.is_hidden) @@ -166,7 +173,7 @@ impl ExecutorBuilder for SinkExecutorBuilder { .collect(), downstream_pk, sink_type, - format_desc, + format_desc: format_desc_with_secret, db_name, sink_from_name, }; diff --git a/src/stream/src/from_proto/source/fs_fetch.rs b/src/stream/src/from_proto/source/fs_fetch.rs index 1951365a47ee..7a08c2d2f512 100644 --- a/src/stream/src/from_proto/source/fs_fetch.rs +++ b/src/stream/src/from_proto/source/fs_fetch.rs @@ -20,6 +20,7 @@ use risingwave_connector::source::filesystem::opendal_source::{ }; use risingwave_connector::source::reader::desc::SourceDescBuilder; use risingwave_connector::source::ConnectorProperties; +use risingwave_connector::WithOptionsSecResolved; use risingwave_pb::stream_plan::StreamFsFetchNode; use risingwave_storage::StateStore; @@ -46,12 +47,14 @@ impl ExecutorBuilder for FsFetchExecutorBuilder { let source_id = TableId::new(source.source_id); let source_name = source.source_name.clone(); let source_info = source.get_info()?; - let properties = ConnectorProperties::extract(source.with_properties.clone(), false)?; + let source_options_with_secret = + WithOptionsSecResolved::new(source.with_properties.clone(), source.secret_refs.clone()); + let 
properties = ConnectorProperties::extract(source_options_with_secret.clone(), false)?; let source_desc_builder = SourceDescBuilder::new( source.columns.clone(), params.env.source_metrics(), source.row_id_index.map(|x| x as _), - source.with_properties.clone(), + source_options_with_secret, source_info.clone(), params.env.config().developer.connector_message_buffer_size, params.info.pk_indices.clone(), diff --git a/src/stream/src/from_proto/source/trad_source.rs b/src/stream/src/from_proto/source/trad_source.rs index 667f2c5d49bc..25bc01d817fb 100644 --- a/src/stream/src/from_proto/source/trad_source.rs +++ b/src/stream/src/from_proto/source/trad_source.rs @@ -12,14 +12,12 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::collections::BTreeMap; - use risingwave_common::catalog::{ default_key_column_name_version_mapping, TableId, KAFKA_TIMESTAMP_COLUMN_NAME, }; use risingwave_connector::source::reader::desc::SourceDescBuilder; use risingwave_connector::source::{should_copy_to_format_encode_options, UPSTREAM_SOURCE_KEY}; -use risingwave_connector::WithPropertiesExt; +use risingwave_connector::{WithOptionsSecResolved, WithPropertiesExt}; use risingwave_pb::catalog::PbStreamSourceInfo; use risingwave_pb::data::data_type::TypeName as PbTypeName; use risingwave_pb::plan_common::additional_column::ColumnType as AdditionalColumnType; @@ -46,7 +44,7 @@ pub fn create_source_desc_builder( params: &ExecutorParams, source_info: PbStreamSourceInfo, row_id_index: Option, - with_properties: BTreeMap, + with_properties: WithOptionsSecResolved, ) -> SourceDescBuilder { { // compatible code: introduced in https://github.com/risingwavelabs/risingwave/pull/13707 @@ -168,12 +166,17 @@ impl ExecutorBuilder for SourceExecutorBuilder { ); } + let with_properties = WithOptionsSecResolved::new( + source.with_properties.clone(), + source.secret_refs.clone(), + ); + let source_desc_builder = create_source_desc_builder( source.columns.clone(), ¶ms, source_info, source.row_id_index, - source.with_properties.clone(), + with_properties, ); let source_column_ids: Vec<_> = source_desc_builder diff --git a/src/stream/src/from_proto/source_backfill.rs b/src/stream/src/from_proto/source_backfill.rs index 84f4bf7adab8..5f3ef58e0fb4 100644 --- a/src/stream/src/from_proto/source_backfill.rs +++ b/src/stream/src/from_proto/source_backfill.rs @@ -13,6 +13,7 @@ // limitations under the License. use risingwave_common::catalog::TableId; +use risingwave_connector::WithOptionsSecResolved; use risingwave_pb::stream_plan::SourceBackfillNode; use super::*; @@ -35,12 +36,14 @@ impl ExecutorBuilder for SourceBackfillExecutorBuilder { let source_name = node.source_name.clone(); let source_info = node.get_info()?; + let options_with_secret = + WithOptionsSecResolved::new(node.with_properties.clone(), node.secret_refs.clone()); let source_desc_builder = super::source::create_source_desc_builder( node.columns.clone(), ¶ms, source_info.clone(), node.row_id_index, - node.with_properties.clone(), + options_with_secret, ); let source_column_ids: Vec<_> = source_desc_builder diff --git a/src/stream/src/from_proto/stream_cdc_scan.rs b/src/stream/src/from_proto/stream_cdc_scan.rs index 150812c57a1c..26794bc6153b 100644 --- a/src/stream/src/from_proto/stream_cdc_scan.rs +++ b/src/stream/src/from_proto/stream_cdc_scan.rs @@ -12,10 +12,9 @@ // See the License for the specific language governing permissions and // limitations under the License. 
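The executor builders above are the last mile: the proto node still carries plain options and secret refs separately, and only on the compute node are real values spliced back in from the `LocalSecretManager` cache that the notification service populated. Condensed from the sink builder change above (not standalone):

```rust
use risingwave_common::secret::LocalSecretManager;

// Both maps come straight off the `SinkDesc` proto: plain options plus
// option key -> PbSecretRef resolved on the frontend.
let properties_with_secret =
    LocalSecretManager::global().fill_secrets(properties, secret_refs)?;
// Format/encode options may reference secrets as well.
let format_desc_with_secret = SinkParam::fill_secret_for_format_desc(format_desc)?;
// These two then replace the plain `properties` / `format_desc` when the executor's
// `SinkParam` is assembled; source builders do the same via `WithOptionsSecResolved::new`.
```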
-use std::collections::HashMap; use std::sync::Arc; -use anyhow::anyhow; +use anyhow::Context; use risingwave_common::catalog::{Schema, TableId}; use risingwave_common::util::sort_util::OrderType; use risingwave_connector::source::cdc::external::{ @@ -52,11 +51,7 @@ impl ExecutorBuilder for StreamCdcScanExecutorBuilder { assert_eq!(output_indices, (0..output_schema.len()).collect_vec()); assert_eq!(output_schema.data_types(), params.info.schema.data_types()); - let properties: HashMap = table_desc - .connect_properties - .iter() - .map(|(k, v)| (k.clone(), v.clone())) - .collect(); + let properties = table_desc.connect_properties.clone(); let table_pk_order_types = table_desc .pk @@ -92,10 +87,12 @@ impl ExecutorBuilder for StreamCdcScanExecutorBuilder { .collect(); let schema_table_name = SchemaTableName::from_properties(&properties); - let table_config = serde_json::from_value::( - serde_json::to_value(properties).unwrap(), + let table_config = ExternalTableConfig::try_from_btreemap( + properties.clone(), + table_desc.secret_refs.clone(), ) - .map_err(|e| anyhow!("failed to parse external table config").context(e))?; + .context("failed to parse external table config")?; + let database_name = table_config.database.clone(); let table_reader = table_type .create_table_reader( diff --git a/src/stream/src/task/barrier_manager.rs b/src/stream/src/task/barrier_manager.rs index a1108d9e4627..8d6b4e5056ad 100644 --- a/src/stream/src/task/barrier_manager.rs +++ b/src/stream/src/task/barrier_manager.rs @@ -960,7 +960,8 @@ pub fn try_find_root_actor_failure<'a>( | ErrorKind::Storage(_) | ErrorKind::Expression(_) | ErrorKind::Array(_) - | ErrorKind::Sink(_) => 1000, + | ErrorKind::Sink(_) + | ErrorKind::Secret(_) => 1000, } } diff --git a/src/tests/simulation/src/cluster.rs b/src/tests/simulation/src/cluster.rs index 8e2ffece2fed..6c9db8c48170 100644 --- a/src/tests/simulation/src/cluster.rs +++ b/src/tests/simulation/src/cluster.rs @@ -450,6 +450,8 @@ impl Cluster { "hummock+sim://hummockadmin:hummockadmin@192.168.12.1:9301/hummock001", "--data-directory", "hummock_001", + "--temp-secret-file-dir", + &format!("./secrets/meta-{i}"), ]); handle .create_node() @@ -477,6 +479,8 @@ impl Cluster { "0.0.0.0:4566", "--advertise-addr", &format!("192.168.2.{i}:4566"), + "--temp-secret-file-dir", + &format!("./secrets/frontend-{i}"), ]); handle .create_node() @@ -505,6 +509,8 @@ impl Cluster { "6979321856", "--parallelism", &conf.compute_node_cores.to_string(), + "--temp-secret-file-dir", + &format!("./secrets/compute-{i}"), ]); handle .create_node() From 46b4ccd73238a9ab5ee7187569ee3337b18212aa Mon Sep 17 00:00:00 2001 From: Bugen Zhao Date: Tue, 16 Jul 2024 10:59:37 +0800 Subject: [PATCH 13/70] refactor: graceful shutdown in standalone mode (#17633) Signed-off-by: Bugen Zhao --- src/cmd_all/src/bin/risingwave.rs | 6 +- src/cmd_all/src/standalone.rs | 168 ++++++++++++++++++++++-------- 2 files changed, 128 insertions(+), 46 deletions(-) diff --git a/src/cmd_all/src/bin/risingwave.rs b/src/cmd_all/src/bin/risingwave.rs index 13c73217b77c..c6ca4675bb81 100644 --- a/src/cmd_all/src/bin/risingwave.rs +++ b/src/cmd_all/src/bin/risingwave.rs @@ -230,8 +230,7 @@ fn standalone(opts: StandaloneOpts) -> ! 
{ .with_target("risingwave_storage", Level::WARN) .with_thread_name(true); risingwave_rt::init_risingwave_logger(settings); - // TODO(shutdown): pass the shutdown token - risingwave_rt::main_okk(|_| risingwave_cmd_all::standalone(opts)); + risingwave_rt::main_okk(|shutdown| risingwave_cmd_all::standalone(opts, shutdown)); } /// For single node, the internals are just a config mapping from its @@ -246,8 +245,7 @@ fn single_node(opts: SingleNodeOpts) -> ! { .with_target("risingwave_storage", Level::WARN) .with_thread_name(true); risingwave_rt::init_risingwave_logger(settings); - // TODO(shutdown): pass the shutdown token - risingwave_rt::main_okk(|_| risingwave_cmd_all::standalone(opts)); + risingwave_rt::main_okk(|shutdown| risingwave_cmd_all::standalone(opts, shutdown)); } #[cfg(test)] diff --git a/src/cmd_all/src/standalone.rs b/src/cmd_all/src/standalone.rs index 26d4aefeb56d..aac4bdff35d4 100644 --- a/src/cmd_all/src/standalone.rs +++ b/src/cmd_all/src/standalone.rs @@ -12,16 +12,18 @@ // See the License for the specific language governing permissions and // limitations under the License. +use std::future::Future; + use clap::Parser; use risingwave_common::config::MetaBackend; use risingwave_common::util::meta_addr::MetaAddressStrategy; +use risingwave_common::util::runtime::BackgroundShutdownRuntime; use risingwave_common::util::tokio_util::sync::CancellationToken; use risingwave_compactor::CompactorOpts; use risingwave_compute::ComputeNodeOpts; use risingwave_frontend::FrontendOpts; use risingwave_meta_node::MetaNodeOpts; use shell_words::split; -use tokio::signal; use crate::common::osstrs; @@ -173,9 +175,65 @@ pub fn parse_standalone_opt_args(opts: &StandaloneOpts) -> ParsedStandaloneOpts } } +/// A service under standalone mode. +struct Service { + name: &'static str, + runtime: BackgroundShutdownRuntime, + main_task: tokio::task::JoinHandle<()>, + shutdown: CancellationToken, +} + +impl Service { + /// Spawn a new tokio runtime and start a service in it. + /// + /// By using a separate runtime, we get better isolation between services. For example, + /// + /// - The logs in the main runtime of each service can be distinguished by the thread name. + /// - Each service can be shutdown cleanly by shutting down its runtime. + fn spawn(name: &'static str, f: F) -> Self + where + F: FnOnce(CancellationToken) -> Fut, + Fut: Future + Send + 'static, + { + let runtime = tokio::runtime::Builder::new_multi_thread() + .thread_name(format!("rw-standalone-{name}")) + .enable_all() + .build() + .unwrap(); + let shutdown = CancellationToken::new(); + let main_task = runtime.spawn(f(shutdown.clone())); + + Self { + name, + runtime: runtime.into(), + main_task, + shutdown, + } + } + + /// Shutdown the service and the runtime gracefully. + /// + /// As long as the main task of the service is resolved after signaling `shutdown`, + /// the service is considered stopped and the runtime will be shutdown. This follows + /// the same convention as described in `risingwave_rt::main_okk`. + async fn shutdown(self) { + tracing::info!("stopping {} service...", self.name); + + self.shutdown.cancel(); + let _ = self.main_task.await; + drop(self.runtime); // shutdown in background + + tracing::info!("{} service stopped", self.name); + } +} + /// For `standalone` mode, we can configure and start multiple services in one process. /// `standalone` mode is meant to be used by our cloud service and docker, /// where we can configure and start multiple services in one process. 
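// Illustration: `Service::spawn` above gives every node its own tokio runtime
// and stops it by cancelling a token, awaiting the main task, then retiring
// the runtime. A self-contained sketch of that pattern (needs the `tokio`
// [rt-multi-thread, macros] and `tokio-util` crates); the `Service` shape
// mirrors the patch but the demo worker and names are made up.
use tokio_util::sync::CancellationToken;

struct Service {
    name: &'static str,
    runtime: tokio::runtime::Runtime,
    main_task: tokio::task::JoinHandle<()>,
    shutdown: CancellationToken,
}

impl Service {
    fn spawn<F, Fut>(name: &'static str, f: F) -> Self
    where
        F: FnOnce(CancellationToken) -> Fut,
        Fut: std::future::Future<Output = ()> + Send + 'static,
    {
        // A dedicated runtime per service: log lines carry the thread name and
        // the whole service can be torn down by retiring its runtime.
        let runtime = tokio::runtime::Builder::new_multi_thread()
            .thread_name(format!("svc-{name}"))
            .enable_all()
            .build()
            .unwrap();
        let shutdown = CancellationToken::new();
        let main_task = runtime.spawn(f(shutdown.clone()));
        Self { name, runtime, main_task, shutdown }
    }

    async fn shutdown(self) {
        println!("stopping {} ...", self.name);
        self.shutdown.cancel(); // ask the service to stop
        let _ = self.main_task.await; // wait for its main task to resolve
        self.runtime.shutdown_background(); // then retire the runtime without blocking
        println!("{} stopped", self.name);
    }
}

#[tokio::main]
async fn main() {
    let svc = Service::spawn("demo", |shutdown| async move {
        // A real service would `select!` between its work and this token.
        shutdown.cancelled().await;
        println!("demo main task exiting");
    });
    // With several services, they would be stopped in reverse start order.
    svc.shutdown().await;
}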
+/// +/// Services are started in the order of `meta`, `compute`, `frontend`, then `compactor`. +/// When the `shutdown` token is signaled, all services will be stopped gracefully in the +/// reverse order. pub async fn standalone( ParsedStandaloneOpts { meta_opts, @@ -183,61 +241,69 @@ pub async fn standalone( frontend_opts, compactor_opts, }: ParsedStandaloneOpts, + shutdown: CancellationToken, ) { tracing::info!("launching Risingwave in standalone mode"); - // TODO(shutdown): use the real one passed-in - let shutdown = CancellationToken::new(); - - let mut is_in_memory = false; - if let Some(opts) = meta_opts { - is_in_memory = matches!(opts.backend, Some(MetaBackend::Mem)); + let (meta, is_in_memory) = if let Some(opts) = meta_opts { + let is_in_memory = matches!(opts.backend, Some(MetaBackend::Mem)); tracing::info!("starting meta-node thread with cli args: {:?}", opts); - - let shutdown = shutdown.clone(); - let _meta_handle = tokio::spawn(async move { - let dangerous_max_idle_secs = opts.dangerous_max_idle_secs; - risingwave_meta_node::start(opts, shutdown).await; - tracing::warn!("meta is stopped, shutdown all nodes"); - if let Some(idle_exit_secs) = dangerous_max_idle_secs { - eprintln!("{}", - console::style(format_args!( - "RisingWave playground exited after being idle for {idle_exit_secs} seconds. Bye!" - )).bold()); - std::process::exit(0); - } + let service = Service::spawn("meta", |shutdown| { + risingwave_meta_node::start(opts, shutdown) }); + // wait for the service to be ready let mut tries = 0; while !risingwave_meta_node::is_server_started() { if tries % 50 == 0 { tracing::info!("waiting for meta service to be ready..."); } + if service.main_task.is_finished() { + tracing::error!("meta service failed to start, exiting..."); + return; + } tries += 1; tokio::time::sleep(std::time::Duration::from_millis(100)).await; } - } - if let Some(opts) = compute_opts { + + (Some(service), is_in_memory) + } else { + (None, false) + }; + + let compute = if let Some(opts) = compute_opts { tracing::info!("starting compute-node thread with cli args: {:?}", opts); - let shutdown = shutdown.clone(); - let _compute_handle = - tokio::spawn(async move { risingwave_compute::start(opts, shutdown).await }); - } - if let Some(opts) = frontend_opts.clone() { + let service = Service::spawn("compute", |shutdown| { + risingwave_compute::start(opts, shutdown) + }); + Some(service) + } else { + None + }; + + let frontend = if let Some(opts) = frontend_opts.clone() { tracing::info!("starting frontend-node thread with cli args: {:?}", opts); - let shutdown = shutdown.clone(); - let _frontend_handle = - tokio::spawn(async move { risingwave_frontend::start(opts, shutdown).await }); - } - if let Some(opts) = compactor_opts { + let service = Service::spawn("frontend", |shutdown| { + risingwave_frontend::start(opts, shutdown) + }); + Some(service) + } else { + None + }; + + let compactor = if let Some(opts) = compactor_opts { tracing::info!("starting compactor-node thread with cli args: {:?}", opts); - let shutdown = shutdown.clone(); - let _compactor_handle = - tokio::spawn(async move { risingwave_compactor::start(opts, shutdown).await }); - } + let service = Service::spawn("compactor", |shutdown| { + risingwave_compactor::start(opts, shutdown) + }); + Some(service) + } else { + None + }; // wait for log messages to be flushed tokio::time::sleep(std::time::Duration::from_millis(5000)).await; + eprintln!("----------------------------------------"); eprintln!("| RisingWave standalone mode is ready. 
|"); eprintln!("----------------------------------------"); @@ -252,6 +318,7 @@ It SHOULD NEVER be used in benchmarks and production environment!!!" .bold() ); } + if let Some(opts) = frontend_opts { let host = opts.listen_addr.split(':').next().unwrap_or("localhost"); let port = opts.listen_addr.split(':').last().unwrap_or("4566"); @@ -268,12 +335,29 @@ It SHOULD NEVER be used in benchmarks and production environment!!!" ); } - // TODO: should we join all handles? - // Currently, not all services can be shutdown gracefully, just quit on Ctrl-C now. - // TODO(kwannoel): Why can't be shutdown gracefully? Is it that the service just does not - // support it? - signal::ctrl_c().await.unwrap(); - tracing::info!("Ctrl+C received, now exiting"); + let meta_stopped = meta + .as_ref() + .map(|m| m.shutdown.clone()) + // If there's no meta service, use a dummy token which will never resolve. + .unwrap_or_else(CancellationToken::new) + .cancelled_owned(); + + // Wait for shutdown signals. + tokio::select! { + // Meta service stopped itself, typically due to leadership loss of idleness. + // Directly exit in this case. + _ = meta_stopped => { + tracing::info!("meta service is stopped, terminating..."); + } + + // Shutdown requested by the user. + _ = shutdown.cancelled() => { + for service in [compactor, frontend, compute, meta].into_iter().flatten() { + service.shutdown().await; + } + tracing::info!("all services stopped, bye"); + } + } } #[cfg(test)] From 5aabf542875465e591d65f4b1ad8a4cd925a5800 Mon Sep 17 00:00:00 2001 From: Bugen Zhao Date: Tue, 16 Jul 2024 11:11:43 +0800 Subject: [PATCH 14/70] fix(streaming): find and return root actor failure when injection failed (#17672) --- src/meta/src/rpc/ddl_controller_v2.rs | 2 +- src/stream/src/task/barrier_manager.rs | 77 +++++++++++++++++--------- 2 files changed, 51 insertions(+), 28 deletions(-) diff --git a/src/meta/src/rpc/ddl_controller_v2.rs b/src/meta/src/rpc/ddl_controller_v2.rs index 0dabc9b19022..518d6e7b3eaf 100644 --- a/src/meta/src/rpc/ddl_controller_v2.rs +++ b/src/meta/src/rpc/ddl_controller_v2.rs @@ -94,7 +94,7 @@ impl DdlController { { Ok(version) => Ok(version), Err(err) => { - tracing::error!(id = job_id, error = ?err.as_report(), "failed to create streaming job"); + tracing::error!(id = job_id, error = %err.as_report(), "failed to create streaming job"); let event = risingwave_pb::meta::event_log::EventCreateStreamJobFail { id: streaming_job.id(), name: streaming_job.name(), diff --git a/src/stream/src/task/barrier_manager.rs b/src/stream/src/task/barrier_manager.rs index 8d6b4e5056ad..b0ce6ad30540 100644 --- a/src/stream/src/task/barrier_manager.rs +++ b/src/stream/src/task/barrier_manager.rs @@ -120,12 +120,6 @@ impl ControlStreamHandle { } } - fn inspect_result(&mut self, result: StreamResult<()>) { - if let Err(e) = result { - self.reset_stream_with_err(e.to_status_unnamed(Code::Internal)); - } - } - fn send_response(&mut self, response: StreamingControlStreamResponse) { if let Some((sender, _)) = self.pair.as_ref() { if sender.send(Ok(response)).is_err() { @@ -374,7 +368,8 @@ pub(super) struct LocalBarrierWorker { actor_failure_rx: UnboundedReceiver<(ActorId, StreamError)>, - root_failure: Option, + /// Cached result of [`Self::try_find_root_failure`]. 
+ cached_root_failure: Option, } impl LocalBarrierWorker { @@ -403,7 +398,7 @@ impl LocalBarrierWorker { current_shared_context: shared_context, barrier_event_rx: event_rx, actor_failure_rx: failure_rx, - root_failure: None, + cached_root_failure: None, } } @@ -431,14 +426,16 @@ impl LocalBarrierWorker { } completed_epoch = self.state.next_completed_epoch() => { let result = self.on_epoch_completed(completed_epoch); - self.control_stream_handle.inspect_result(result); + if let Err(err) = result { + self.notify_other_failure(err, "failed to complete epoch").await; + } }, event = self.barrier_event_rx.recv() => { self.handle_barrier_event(event.expect("should not be none")); }, failure = self.actor_failure_rx.recv() => { let (actor_id, err) = failure.unwrap(); - self.notify_failure(actor_id, err).await; + self.notify_actor_failure(actor_id, err).await; }, actor_op = actor_op_rx.recv() => { if let Some(actor_op) = actor_op { @@ -462,7 +459,9 @@ impl LocalBarrierWorker { }, request = self.control_stream_handle.next_request() => { let result = self.handle_streaming_control_request(request); - self.control_stream_handle.inspect_result(result); + if let Err(err) = result { + self.notify_other_failure(err, "failed to inject barrier").await; + } }, } } @@ -661,6 +660,10 @@ impl LocalBarrierWorker { /// Broadcast a barrier to all senders. Save a receiver which will get notified when this /// barrier is finished, in managed mode. + /// + /// Note that the error returned here is typically a [`StreamError::barrier_send`], which is not + /// the root cause of the failure. The caller should then call [`Self::try_find_root_failure`] + /// to find the root cause. fn send_barrier( &mut self, barrier: &Barrier, @@ -668,8 +671,7 @@ impl LocalBarrierWorker { to_collect: HashSet, table_ids: HashSet, ) -> StreamResult<()> { - #[cfg(not(test))] - { + if !cfg!(test) { // The barrier might be outdated and been injected after recovery in some certain extreme // scenarios. So some newly creating actors in the barrier are possibly not rebuilt during // recovery. Check it here and return an error here if some actors are not found to @@ -702,12 +704,15 @@ impl LocalBarrierWorker { ); for actor_id in &to_collect { - if let Some(e) = self.failure_actors.get(actor_id) { + if self.failure_actors.contains_key(actor_id) { // The failure actors could exit before the barrier is issued, while their // up-downstream actors could be stuck somehow. Return error directly to trigger the // recovery. - // try_find_root_failure is not used merely because it requires async. - return Err(self.root_failure.clone().unwrap_or(e.clone())); + return Err(StreamError::barrier_send( + barrier.clone(), + *actor_id, + "actor has already failed", + )); } } @@ -763,11 +768,11 @@ impl LocalBarrierWorker { self.state.collect(actor_id, barrier) } - /// When a actor exit unexpectedly, it should report this event using this function, so meta - /// will notice actor's exit while collecting. - async fn notify_failure(&mut self, actor_id: ActorId, err: StreamError) { + /// When a actor exit unexpectedly, the error is reported using this function. The control stream + /// will be reset and the meta service will then trigger recovery. 
+ async fn notify_actor_failure(&mut self, actor_id: ActorId, err: StreamError) { self.add_failure(actor_id, err.clone()); - let root_err = self.try_find_root_failure().await; + let root_err = self.try_find_root_failure().await.unwrap(); // always `Some` because we just added one let failed_epochs = self.state.epochs_await_on_actor(actor_id).collect_vec(); if !failed_epochs.is_empty() { @@ -782,6 +787,21 @@ impl LocalBarrierWorker { } } + /// When some other failure happens (like failed to send barrier), the error is reported using + /// this function. The control stream will be reset and the meta service will then trigger recovery. + /// + /// This is similar to [`Self::notify_actor_failure`], but since there's not always an actor failure, + /// the given `err` will be used if there's no root failure found. + async fn notify_other_failure(&mut self, err: StreamError, message: impl Into) { + let root_err = self.try_find_root_failure().await.unwrap_or(err); + + self.control_stream_handle.reset_stream_with_err( + anyhow!(root_err) + .context(message.into()) + .to_status_unnamed(Code::Internal), + ); + } + fn add_failure(&mut self, actor_id: ActorId, err: StreamError) { if let Some(prev_err) = self.failure_actors.insert(actor_id, err) { warn!( @@ -792,9 +812,12 @@ impl LocalBarrierWorker { } } - async fn try_find_root_failure(&mut self) -> StreamError { - if let Some(root_failure) = &self.root_failure { - return root_failure.clone(); + /// Collect actor errors for a while and find the one that might be the root cause. + /// + /// Returns `None` if there's no actor error received. + async fn try_find_root_failure(&mut self) -> Option { + if self.cached_root_failure.is_some() { + return self.cached_root_failure.clone(); } // fetch more actor errors within a timeout let _ = tokio::time::timeout(Duration::from_secs(3), async { @@ -803,11 +826,9 @@ impl LocalBarrierWorker { } }) .await; - self.root_failure = try_find_root_actor_failure(self.failure_actors.values()); + self.cached_root_failure = try_find_root_actor_failure(self.failure_actors.values()); - self.root_failure - .clone() - .expect("failure actors should not be empty") + self.cached_root_failure.clone() } } @@ -915,6 +936,8 @@ impl LocalBarrierManager { } /// Tries to find the root cause of actor failures, based on hard-coded rules. +/// +/// Returns `None` if the input is empty. 
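// Illustration: `try_find_root_failure` above buffers actor errors for a short
// window and then ranks them with hard-coded scores so that follow-up noise
// (e.g. closed channels) does not mask the real root cause. A std-only sketch
// of the ranking step; the error kinds and scores here are made up and do not
// reproduce the actual rules in `try_find_root_actor_failure`.
#[derive(Clone, Debug)]
enum ToyError {
    ChannelClosed,           // usually a symptom of another actor exiting
    ExecutorFailure(String), // more likely to be the root cause
}

fn score(err: &ToyError) -> i32 {
    match err {
        ToyError::ChannelClosed => 1,
        ToyError::ExecutorFailure(_) => 1000,
    }
}

/// Returns `None` when no error has been collected, matching the new
/// `Option`-returning signature introduced by this patch.
fn find_root_failure<'a>(errors: impl IntoIterator<Item = &'a ToyError>) -> Option<ToyError> {
    errors.into_iter().max_by_key(|e| score(e)).cloned()
}

fn main() {
    let collected = vec![
        ToyError::ChannelClosed,
        ToyError::ExecutorFailure("division by zero".into()),
        ToyError::ChannelClosed,
    ];
    let root = find_root_failure(&collected);
    assert!(matches!(root, Some(ToyError::ExecutorFailure(_))));
    println!("root failure: {:?}", root);
}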
pub fn try_find_root_actor_failure<'a>( actor_errors: impl IntoIterator, ) -> Option { From fe209424c208cb7343913a028773b6274d38e0a3 Mon Sep 17 00:00:00 2001 From: zwang28 <70626450+zwang28@users.noreply.github.com> Date: Tue, 16 Jul 2024 12:17:14 +0800 Subject: [PATCH 15/70] fix(storage): license time travel feature (#17683) --- e2e_test/time_travel/license.slt | 37 ++++++++++++ e2e_test/time_travel/syntax.slt | 58 +++++++++++++++++++ src/batch/src/executor/row_seq_scan.rs | 2 +- src/frontend/src/optimizer/plan_node/utils.rs | 4 ++ src/license/src/feature.rs | 1 + 5 files changed, 101 insertions(+), 1 deletion(-) create mode 100644 e2e_test/time_travel/license.slt create mode 100644 e2e_test/time_travel/syntax.slt diff --git a/e2e_test/time_travel/license.slt b/e2e_test/time_travel/license.slt new file mode 100644 index 000000000000..c5d200c05116 --- /dev/null +++ b/e2e_test/time_travel/license.slt @@ -0,0 +1,37 @@ +statement ok +SET RW_IMPLICIT_FLUSH TO true; + +statement ok +SET QUERY_MODE TO local; + +statement ok +ALTER SYSTEM SET license_key TO ''; + +statement ok +CREATE TABLE t (k INT); + +query error +SELECT * FROM t FOR SYSTEM_TIME AS OF now(); +---- +db error: ERROR: Failed to run the query + +Caused by these errors (recent errors listed first): + 1: Scheduler error + 2: feature TimeTravel is only available for tier Paid and above, while the current tier is Free + +Hint: You may want to set a license key with `ALTER SYSTEM SET license_key = '...';` command. + + + +statement ok +ALTER SYSTEM SET license_key TO DEFAULT; + +query I +SELECT * FROM t FOR SYSTEM_TIME AS OF now(); +---- + +statement ok +DROP TABLE t; + +statement ok +SET QUERY_MODE TO auto; \ No newline at end of file diff --git a/e2e_test/time_travel/syntax.slt b/e2e_test/time_travel/syntax.slt new file mode 100644 index 000000000000..6c3408a27676 --- /dev/null +++ b/e2e_test/time_travel/syntax.slt @@ -0,0 +1,58 @@ +statement ok +SET RW_IMPLICIT_FLUSH TO true; + +statement ok +SET QUERY_MODE TO local; + +statement ok +CREATE TABLE t (k INT); + +query error +SELECT * FROM t FOR SYSTEM_TIME AS OF 963716300; +---- +db error: ERROR: Failed to run the query + +Caused by these errors (recent errors listed first): + 1: gRPC request to batch service failed: Internal error + 2: Storage error + 3: Hummock error + 4: Meta error: gRPC request to hummock service failed: Internal error: time travel: version not found for epoch 0 + + +query error +SELECT * FROM t FOR SYSTEM_TIME AS OF '2000-02-20T12:13:14-08:30'; +---- +db error: ERROR: Failed to run the query + +Caused by these errors (recent errors listed first): + 1: gRPC request to batch service failed: Internal error + 2: Storage error + 3: Hummock error + 4: Meta error: gRPC request to hummock service failed: Internal error: time travel: version not found for epoch 0 + + +query error +SELECT * FROM t FOR SYSTEM_TIME AS OF NOW() - '100' YEAR; +---- +db error: ERROR: Failed to run the query + +Caused by these errors (recent errors listed first): + 1: gRPC request to batch service failed: Internal error + 2: Storage error + 3: Hummock error + 4: Meta error: gRPC request to hummock service failed: Internal error: time travel: version not found for epoch 0 + + +query I +SELECT * FROM t FOR SYSTEM_TIME AS OF '2130-02-10T12:13:14-08:30'; +---- + +query I +SELECT * FROM t FOR SYSTEM_TIME AS OF 5066075780; +---- + +statement ok +DROP TABLE t; + +statement ok +SET QUERY_MODE TO auto; \ No newline at end of file diff --git a/src/batch/src/executor/row_seq_scan.rs 
b/src/batch/src/executor/row_seq_scan.rs index 9e3710a7040b..b8287147c675 100644 --- a/src/batch/src/executor/row_seq_scan.rs +++ b/src/batch/src/executor/row_seq_scan.rs @@ -500,6 +500,6 @@ impl RowSeqScanExecutor { pub fn unix_timestamp_sec_to_epoch(ts: i64) -> risingwave_common::util::epoch::Epoch { let ts = ts.checked_add(1).unwrap(); risingwave_common::util::epoch::Epoch::from_unix_millis_or_earliest( - u64::try_from(ts).unwrap().checked_mul(1000).unwrap(), + u64::try_from(ts).unwrap_or(0).checked_mul(1000).unwrap(), ) } diff --git a/src/frontend/src/optimizer/plan_node/utils.rs b/src/frontend/src/optimizer/plan_node/utils.rs index afae9cf64ca0..288b0957db19 100644 --- a/src/frontend/src/optimizer/plan_node/utils.rs +++ b/src/frontend/src/optimizer/plan_node/utils.rs @@ -325,6 +325,7 @@ macro_rules! plan_node_name { }; } pub(crate) use plan_node_name; +use risingwave_common::license::Feature; use risingwave_common::types::{DataType, Interval}; use risingwave_expr::aggregate::AggKind; use risingwave_pb::plan_common::as_of::AsOfType; @@ -397,6 +398,9 @@ pub fn to_pb_time_travel_as_of(a: &Option) -> Result> { let Some(ref a) = a else { return Ok(None); }; + Feature::TimeTravel + .check_available() + .map_err(|e| anyhow::anyhow!(e))?; let as_of_type = match a { AsOf::ProcessTime => { return Err(ErrorCode::NotSupported( diff --git a/src/license/src/feature.rs b/src/license/src/feature.rs index 302538cc3ecc..144eaacc580e 100644 --- a/src/license/src/feature.rs +++ b/src/license/src/feature.rs @@ -44,6 +44,7 @@ macro_rules! for_all_features { $macro! { // name min tier doc { TestPaid, Paid, "A dummy feature that's only available on paid tier for testing purposes." }, + { TimeTravel, Paid, "Query historical data within the retention period."}, } }; } From e0be76dbb7fb38615d11eba9297d543ca5829262 Mon Sep 17 00:00:00 2001 From: xxchan Date: Tue, 16 Jul 2024 15:00:27 +0800 Subject: [PATCH 16/70] refactor(types): remove dead code and minor (#17691) --- src/common/src/array/arrow/arrow_impl.rs | 5 +- src/common/src/array/proto_reader.rs | 2 +- src/common/src/types/macros.rs | 1 + src/common/src/types/mod.rs | 202 ++++++------------ .../system_catalog/pg_catalog/pg_cast.rs | 4 +- 5 files changed, 76 insertions(+), 138 deletions(-) diff --git a/src/common/src/array/arrow/arrow_impl.rs b/src/common/src/array/arrow/arrow_impl.rs index 1d5e5816efe0..2ecef10e7aa3 100644 --- a/src/common/src/array/arrow/arrow_impl.rs +++ b/src/common/src/array/arrow/arrow_impl.rs @@ -54,6 +54,9 @@ use crate::types::*; use crate::util::iter_util::ZipEqFast; /// Defines how to convert RisingWave arrays to Arrow arrays. +/// +/// This trait allows for customized conversion logic for different external systems using Arrow. +/// The default implementation is based on the `From` implemented in this mod. pub trait ToArrow { /// Converts RisingWave `DataChunk` to Arrow `RecordBatch` with specified schema. /// @@ -777,7 +780,7 @@ converts!(IntervalArray, arrow_array::IntervalMonthDayNanoArray, @map); converts!(SerialArray, arrow_array::Int64Array, @map); /// Converts RisingWave value from and into Arrow value. -pub trait FromIntoArrow { +trait FromIntoArrow { /// The corresponding element type in the Arrow array. 
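// Illustration: `ToArrow` above is a customization point: every conversion
// method has a default body, so an Arrow-speaking external system overrides
// only the mappings it handles differently. A tiny sketch of that trait shape
// with made-up value types; none of these names are the real RisingWave API.
#[derive(Debug, PartialEq)]
enum ToyArrowValue {
    Int32(i32),
    Utf8(String),
}

trait ToToyArrow {
    // Default conversions; an implementor overrides only what it needs.
    fn int32_to_arrow(&self, v: i32) -> ToyArrowValue {
        ToyArrowValue::Int32(v)
    }
    fn utf8_to_arrow(&self, v: &str) -> ToyArrowValue {
        ToyArrowValue::Utf8(v.to_owned())
    }
}

/// The default converter keeps every conversion as-is.
struct DefaultConvert;
impl ToToyArrow for DefaultConvert {}

/// A system-specific converter that, say, uppercases strings on the way out.
struct ShoutingConvert;
impl ToToyArrow for ShoutingConvert {
    fn utf8_to_arrow(&self, v: &str) -> ToyArrowValue {
        ToyArrowValue::Utf8(v.to_uppercase())
    }
}

fn main() {
    assert_eq!(DefaultConvert.utf8_to_arrow("hi"), ToyArrowValue::Utf8("hi".into()));
    assert_eq!(ShoutingConvert.utf8_to_arrow("hi"), ToyArrowValue::Utf8("HI".into()));
    assert_eq!(ShoutingConvert.int32_to_arrow(1), ToyArrowValue::Int32(1));
}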
type ArrowType; fn from_arrow(value: Self::ArrowType) -> Self; diff --git a/src/common/src/array/proto_reader.rs b/src/common/src/array/proto_reader.rs index 073ad0b3de7b..7c3b05437770 100644 --- a/src/common/src/array/proto_reader.rs +++ b/src/common/src/array/proto_reader.rs @@ -26,6 +26,7 @@ impl ArrayImpl { pub fn from_protobuf(array: &PbArray, cardinality: usize) -> ArrayResult { use crate::array::value_reader::*; let array = match array.array_type() { + PbArrayType::Unspecified => unreachable!(), PbArrayType::Int16 => read_numeric_array::(array, cardinality)?, PbArrayType::Int32 => read_numeric_array::(array, cardinality)?, PbArrayType::Int64 => read_numeric_array::(array, cardinality)?, @@ -49,7 +50,6 @@ impl ArrayImpl { PbArrayType::Jsonb => JsonbArray::from_protobuf(array)?, PbArrayType::Struct => StructArray::from_protobuf(array)?, PbArrayType::List => ListArray::from_protobuf(array)?, - PbArrayType::Unspecified => unreachable!(), PbArrayType::Bytea => { read_string_array::(array, cardinality)? } diff --git a/src/common/src/types/macros.rs b/src/common/src/types/macros.rs index 35f106aafdff..520e4ab8f45e 100644 --- a/src/common/src/types/macros.rs +++ b/src/common/src/types/macros.rs @@ -39,6 +39,7 @@ macro_rules! for_all_variants { ($macro:ident $(, $x:tt)*) => { $macro! { $($x, )* + //data_type variant_name suffix_name scalar scalar_ref array builder { Int16, Int16, int16, i16, i16, $crate::array::I16Array, $crate::array::I16ArrayBuilder }, { Int32, Int32, int32, i32, i32, $crate::array::I32Array, $crate::array::I32ArrayBuilder }, { Int64, Int64, int64, i64, i64, $crate::array::I64Array, $crate::array::I64ArrayBuilder }, diff --git a/src/common/src/types/mod.rs b/src/common/src/types/mod.rs index 3b02b8c38d02..91bebde846f0 100644 --- a/src/common/src/types/mod.rs +++ b/src/common/src/types/mod.rs @@ -179,67 +179,39 @@ impl std::str::FromStr for Box { impl ZeroHeapSize for DataType {} -impl DataTypeName { - pub fn is_scalar(&self) -> bool { - match self { - DataTypeName::Boolean - | DataTypeName::Int16 - | DataTypeName::Int32 - | DataTypeName::Int64 - | DataTypeName::Int256 - | DataTypeName::Serial - | DataTypeName::Decimal - | DataTypeName::Float32 - | DataTypeName::Float64 - | DataTypeName::Varchar - | DataTypeName::Date - | DataTypeName::Timestamp - | DataTypeName::Timestamptz - | DataTypeName::Time - | DataTypeName::Bytea - | DataTypeName::Jsonb - | DataTypeName::Interval => true, - - DataTypeName::Struct | DataTypeName::List => false, - } - } +impl TryFrom for DataType { + type Error = &'static str; - pub fn to_type(self) -> Option { - let t = match self { - DataTypeName::Boolean => DataType::Boolean, - DataTypeName::Int16 => DataType::Int16, - DataTypeName::Int32 => DataType::Int32, - DataTypeName::Int64 => DataType::Int64, - DataTypeName::Int256 => DataType::Int256, - DataTypeName::Serial => DataType::Serial, - DataTypeName::Decimal => DataType::Decimal, - DataTypeName::Float32 => DataType::Float32, - DataTypeName::Float64 => DataType::Float64, - DataTypeName::Varchar => DataType::Varchar, - DataTypeName::Bytea => DataType::Bytea, - DataTypeName::Date => DataType::Date, - DataTypeName::Timestamp => DataType::Timestamp, - DataTypeName::Timestamptz => DataType::Timestamptz, - DataTypeName::Time => DataType::Time, - DataTypeName::Interval => DataType::Interval, - DataTypeName::Jsonb => DataType::Jsonb, + fn try_from(type_name: DataTypeName) -> Result { + match type_name { + DataTypeName::Boolean => Ok(DataType::Boolean), + DataTypeName::Int16 => Ok(DataType::Int16), + 
DataTypeName::Int32 => Ok(DataType::Int32), + DataTypeName::Int64 => Ok(DataType::Int64), + DataTypeName::Int256 => Ok(DataType::Int256), + DataTypeName::Serial => Ok(DataType::Serial), + DataTypeName::Decimal => Ok(DataType::Decimal), + DataTypeName::Float32 => Ok(DataType::Float32), + DataTypeName::Float64 => Ok(DataType::Float64), + DataTypeName::Varchar => Ok(DataType::Varchar), + DataTypeName::Bytea => Ok(DataType::Bytea), + DataTypeName::Date => Ok(DataType::Date), + DataTypeName::Timestamp => Ok(DataType::Timestamp), + DataTypeName::Timestamptz => Ok(DataType::Timestamptz), + DataTypeName::Time => Ok(DataType::Time), + DataTypeName::Interval => Ok(DataType::Interval), + DataTypeName::Jsonb => Ok(DataType::Jsonb), DataTypeName::Struct | DataTypeName::List => { - return None; + Err("Functions returning struct or list can not be inferred. Please use `FunctionCall::new_unchecked`.") } - }; - Some(t) - } -} - -impl From for DataType { - fn from(type_name: DataTypeName) -> Self { - type_name.to_type().unwrap_or_else(|| panic!("Functions returning struct or list can not be inferred. Please use `FunctionCall::new_unchecked`.")) + } } } impl From<&PbDataType> for DataType { fn from(proto: &PbDataType) -> DataType { match proto.get_type_name().expect("missing type field") { + PbTypeName::TypeUnspecified => unreachable!(), PbTypeName::Int16 => DataType::Int16, PbTypeName::Int32 => DataType::Int32, PbTypeName::Int64 => DataType::Int64, @@ -265,7 +237,6 @@ impl From<&PbDataType> for DataType { // The first (and only) item is the list element type. Box::new((&proto.field_type[0]).into()), ), - PbTypeName::TypeUnspecified => unreachable!(), PbTypeName::Int256 => DataType::Int256, } } @@ -337,27 +308,7 @@ impl DataType { } pub fn prost_type_name(&self) -> PbTypeName { - match self { - DataType::Int16 => PbTypeName::Int16, - DataType::Int32 => PbTypeName::Int32, - DataType::Int64 => PbTypeName::Int64, - DataType::Int256 => PbTypeName::Int256, - DataType::Serial => PbTypeName::Serial, - DataType::Float32 => PbTypeName::Float, - DataType::Float64 => PbTypeName::Double, - DataType::Boolean => PbTypeName::Boolean, - DataType::Varchar => PbTypeName::Varchar, - DataType::Date => PbTypeName::Date, - DataType::Time => PbTypeName::Time, - DataType::Timestamp => PbTypeName::Timestamp, - DataType::Timestamptz => PbTypeName::Timestamptz, - DataType::Decimal => PbTypeName::Decimal, - DataType::Interval => PbTypeName::Interval, - DataType::Jsonb => PbTypeName::Jsonb, - DataType::Struct { .. } => PbTypeName::Struct, - DataType::List { .. } => PbTypeName::List, - DataType::Bytea => PbTypeName::Bytea, - } + DataTypeName::from(self).into() } pub fn to_protobuf(&self) -> PbDataType { @@ -374,7 +325,23 @@ impl DataType { DataType::List(datatype) => { pb.field_type = vec![datatype.to_protobuf()]; } - _ => {} + DataType::Boolean + | DataType::Int16 + | DataType::Int32 + | DataType::Int64 + | DataType::Float32 + | DataType::Float64 + | DataType::Decimal + | DataType::Date + | DataType::Varchar + | DataType::Time + | DataType::Timestamp + | DataType::Timestamptz + | DataType::Interval + | DataType::Bytea + | DataType::Jsonb + | DataType::Serial + | DataType::Int256 => (), } pb } @@ -392,10 +359,6 @@ impl DataType { ) } - pub fn is_scalar(&self) -> bool { - DataTypeName::from(self).is_scalar() - } - pub fn is_array(&self) -> bool { matches!(self, DataType::List(_)) } @@ -440,37 +403,6 @@ impl DataType { } } - /// WARNING: Currently this should only be used in `WatermarkFilterExecutor`. 
Please be careful - /// if you want to use this. - pub fn min_value(&self) -> ScalarImpl { - match self { - DataType::Int16 => ScalarImpl::Int16(i16::MIN), - DataType::Int32 => ScalarImpl::Int32(i32::MIN), - DataType::Int64 => ScalarImpl::Int64(i64::MIN), - DataType::Int256 => ScalarImpl::Int256(Int256::min_value()), - DataType::Serial => ScalarImpl::Serial(Serial::from(i64::MIN)), - DataType::Float32 => ScalarImpl::Float32(F32::neg_infinity()), - DataType::Float64 => ScalarImpl::Float64(F64::neg_infinity()), - DataType::Boolean => ScalarImpl::Bool(false), - DataType::Varchar => ScalarImpl::Utf8("".into()), - DataType::Bytea => ScalarImpl::Bytea("".to_string().into_bytes().into()), - DataType::Date => ScalarImpl::Date(Date::MIN), - DataType::Time => ScalarImpl::Time(Time::from_hms_uncheck(0, 0, 0)), - DataType::Timestamp => ScalarImpl::Timestamp(Timestamp::MIN), - DataType::Timestamptz => ScalarImpl::Timestamptz(Timestamptz::MIN), - DataType::Decimal => ScalarImpl::Decimal(Decimal::NegativeInf), - DataType::Interval => ScalarImpl::Interval(Interval::MIN), - DataType::Jsonb => ScalarImpl::Jsonb(JsonbVal::null()), // NOT `min` #7981 - DataType::Struct(data_types) => ScalarImpl::Struct(StructValue::new( - data_types - .types() - .map(|data_type| Some(data_type.min_value())) - .collect_vec(), - )), - DataType::List(data_type) => ScalarImpl::List(ListValue::empty(data_type)), - } - } - /// Return a new type that removes the outer list. /// /// ``` @@ -513,28 +445,32 @@ impl From for PbDataType { } } -/// Common trait bounds of scalar and scalar reference types. -/// -/// NOTE(rc): `Hash` is not in the trait bound list, it's implemented as [`ScalarRef::hash_scalar`]. -pub trait ScalarBounds = Debug - + Send - + Sync - + Clone - + PartialEq - + Eq - // in default ascending order - + PartialOrd - + Ord - + TryFrom - // `ScalarImpl`/`ScalarRefImpl` - + Into; +mod private { + use super::*; + + /// Common trait bounds of scalar and scalar reference types. + /// + /// NOTE(rc): `Hash` is not in the trait bound list, it's implemented as [`ScalarRef::hash_scalar`]. + pub trait ScalarBounds = Debug + + Send + + Sync + + Clone + + PartialEq + + Eq + // in default ascending order + + PartialOrd + + Ord + + TryFrom + // `ScalarImpl`/`ScalarRefImpl` + + Into; +} /// `Scalar` is a trait over all possible owned types in the evaluation /// framework. /// /// `Scalar` is reciprocal to `ScalarRef`. Use `as_scalar_ref` to get a /// reference which has the same lifetime as `self`. -pub trait Scalar: ScalarBounds + 'static { +pub trait Scalar: private::ScalarBounds + 'static { /// Type for reference of `Scalar` type ScalarRefType<'a>: ScalarRef<'a, ScalarType = Self> + 'a where @@ -548,17 +484,12 @@ pub trait Scalar: ScalarBounds + 'static { } } -/// Convert an `Option` to corresponding `Option`. -pub fn option_as_scalar_ref(scalar: &Option) -> Option> { - scalar.as_ref().map(|x| x.as_scalar_ref()) -} - /// `ScalarRef` is a trait over all possible references in the evaluation /// framework. /// /// `ScalarRef` is reciprocal to `Scalar`. Use `to_owned_scalar` to get an /// owned scalar. -pub trait ScalarRef<'a>: ScalarBounds> + 'a + Copy { +pub trait ScalarRef<'a>: private::ScalarBounds> + 'a + Copy { /// `ScalarType` is the owned type of current `ScalarRef`. type ScalarType: Scalar = Self>; @@ -653,6 +584,9 @@ impl ToDatumRef for DatumRef<'_> { } /// To make sure there is `as_scalar_ref` for all scalar ref types. +/// See +/// +/// This is used by the expr macro. 
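// Illustration: `Scalar` and `ScalarRef` above form a reciprocal owned /
// borrowed pair linked through associated types, so generic code can move
// between `as_scalar_ref` and `to_owned_scalar`. A stripped-down sketch of
// that pattern for a single string-like type; the `Toy*` traits only imitate
// the real ones and drop most of the bounds.
trait ToyScalar: Sized + 'static {
    type RefType<'a>: ToyScalarRef<'a, Owned = Self>
    where
        Self: 'a;
    fn as_scalar_ref(&self) -> Self::RefType<'_>;
}

trait ToyScalarRef<'a>: Copy {
    type Owned: ToyScalar;
    fn to_owned_scalar(&self) -> Self::Owned;
}

impl ToyScalar for String {
    type RefType<'a> = &'a str
    where
        Self: 'a;
    fn as_scalar_ref(&self) -> &str {
        self.as_str()
    }
}

impl<'a> ToyScalarRef<'a> for &'a str {
    type Owned = String;
    fn to_owned_scalar(&self) -> String {
        (*self).to_owned()
    }
}

fn main() {
    let owned = String::from("risingwave");
    let borrowed = owned.as_scalar_ref();
    assert_eq!(borrowed.to_owned_scalar(), owned);
}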
pub trait SelfAsScalarRef { fn as_scalar_ref(&self) -> Self; } @@ -1021,7 +955,7 @@ impl ScalarRefImpl<'_> { } impl ScalarImpl { - /// Serialize the scalar. + /// Serialize the scalar into the `memcomparable` format. pub fn serialize( &self, ser: &mut memcomparable::Serializer, @@ -1029,7 +963,7 @@ impl ScalarImpl { self.as_scalar_ref_impl().serialize(ser) } - /// Deserialize the scalar. + /// Deserialize the scalar from the `memcomparable` format. pub fn deserialize( ty: &DataType, de: &mut memcomparable::Deserializer, diff --git a/src/frontend/src/catalog/system_catalog/pg_catalog/pg_cast.rs b/src/frontend/src/catalog/system_catalog/pg_catalog/pg_cast.rs index 11bcabcde0f6..d5b1332c25b3 100644 --- a/src/frontend/src/catalog/system_catalog/pg_catalog/pg_cast.rs +++ b/src/frontend/src/catalog/system_catalog/pg_catalog/pg_cast.rs @@ -38,8 +38,8 @@ fn read_pg_cast(_: &SysCatalogReaderImpl) -> Vec { .enumerate() .map(|(idx, (src, target, ctx))| PgCast { oid: idx as i32, - castsource: DataType::from(*src).to_oid(), - casttarget: DataType::from(*target).to_oid(), + castsource: DataType::try_from(*src).unwrap().to_oid(), + casttarget: DataType::try_from(*target).unwrap().to_oid(), castcontext: ctx.to_string(), }) .collect() From 0ea5f056d160771c7573d177e436f225b442225e Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 16 Jul 2024 17:09:55 +0800 Subject: [PATCH 17/70] build(deps): bump curve25519-dalek from 4.1.2 to 4.1.3 (#17688) Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- Cargo.lock | 11 ++--------- 1 file changed, 2 insertions(+), 9 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 68a9060f529d..08d753895393 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3436,16 +3436,15 @@ dependencies = [ [[package]] name = "curve25519-dalek" -version = "4.1.2" +version = "4.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0a677b8922c94e01bdbb12126b0bc852f00447528dee1782229af9c720c3f348" +checksum = "97fb8b7c4503de7d6ae7b42ab72a5a59857b4c937ec27a3d4539dba95b5ab2be" dependencies = [ "cfg-if", "cpufeatures", "curve25519-dalek-derive", "digest", "fiat-crypto", - "platforms", "rustc_version 0.4.0", "subtle", "zeroize", @@ -9268,12 +9267,6 @@ version = "0.3.27" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "26072860ba924cbfa98ea39c8c19b4dd6a4a25423dbdf219c1eca91aa0cf6964" -[[package]] -name = "platforms" -version = "3.1.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4503fa043bf02cee09a9582e9554b4c6403b2ef55e4612e96561d294419429f8" - [[package]] name = "plotlib" version = "0.5.1" From dc3a3fc153fa2ad38a63ae0f289895ad051148c8 Mon Sep 17 00:00:00 2001 From: Richard Chien Date: Tue, 16 Jul 2024 18:15:06 +0800 Subject: [PATCH 18/70] refactor(optimizer): some clean up for stream nodes' constructors (#17696) Signed-off-by: Richard Chien --- .../optimizer/plan_node/stream_changelog.rs | 5 ++-- .../optimizer/plan_node/stream_delta_join.rs | 1 + .../src/optimizer/plan_node/stream_expand.rs | 8 ++--- .../optimizer/plan_node/stream_group_topn.rs | 2 ++ .../optimizer/plan_node/stream_hop_window.rs | 20 ++++++------- .../src/optimizer/plan_node/stream_project.rs | 26 +++++++++-------- .../optimizer/plan_node/stream_project_set.rs | 29 +++++++++++-------- src/frontend/src/optimizer/plan_node/utils.rs | 28 ++++++++---------- 8 files changed, 60 insertions(+), 59 deletions(-) diff --git 
a/src/frontend/src/optimizer/plan_node/stream_changelog.rs b/src/frontend/src/optimizer/plan_node/stream_changelog.rs index 0ee696c58067..b02c5eeb0c35 100644 --- a/src/frontend/src/optimizer/plan_node/stream_changelog.rs +++ b/src/frontend/src/optimizer/plan_node/stream_changelog.rs @@ -33,12 +33,13 @@ impl StreamChangeLog { pub fn new(core: generic::ChangeLog) -> Self { let input = core.input.clone(); let dist = input.distribution().clone(); + let input_len = input.schema().len(); // Filter executor won't change the append-only behavior of the stream. let mut watermark_columns = input.watermark_columns().clone(); if core.need_op { - watermark_columns.grow(input.watermark_columns().len() + 2); + watermark_columns.grow(input_len + 2); } else { - watermark_columns.grow(input.watermark_columns().len() + 1); + watermark_columns.grow(input_len + 1); } let base = PlanBase::new_stream_with_core( &core, diff --git a/src/frontend/src/optimizer/plan_node/stream_delta_join.rs b/src/frontend/src/optimizer/plan_node/stream_delta_join.rs index 49257676bc00..7a99c8f7955b 100644 --- a/src/frontend/src/optimizer/plan_node/stream_delta_join.rs +++ b/src/frontend/src/optimizer/plan_node/stream_delta_join.rs @@ -68,6 +68,7 @@ impl StreamDeltaJoin { let watermark_columns = from_left.bitand(&from_right); core.i2o_col_mapping().rewrite_bitset(&watermark_columns) }; + // TODO: derive from input let base = PlanBase::new_stream_with_core( &core, diff --git a/src/frontend/src/optimizer/plan_node/stream_expand.rs b/src/frontend/src/optimizer/plan_node/stream_expand.rs index 5eefede3469c..4f38e95cdfea 100644 --- a/src/frontend/src/optimizer/plan_node/stream_expand.rs +++ b/src/frontend/src/optimizer/plan_node/stream_expand.rs @@ -33,6 +33,7 @@ pub struct StreamExpand { impl StreamExpand { pub fn new(core: generic::Expand) -> Self { let input = core.input.clone(); + let input_len = input.schema().len(); let dist = match input.distribution() { Distribution::Single => Distribution::Single, @@ -43,12 +44,7 @@ impl StreamExpand { }; let mut watermark_columns = FixedBitSet::with_capacity(core.output_len()); - watermark_columns.extend( - input - .watermark_columns() - .ones() - .map(|idx| idx + input.schema().len()), - ); + watermark_columns.extend(input.watermark_columns().ones().map(|idx| idx + input_len)); let base = PlanBase::new_stream_with_core( &core, diff --git a/src/frontend/src/optimizer/plan_node/stream_group_topn.rs b/src/frontend/src/optimizer/plan_node/stream_group_topn.rs index 8500e24b0fd9..0cd8edc996c8 100644 --- a/src/frontend/src/optimizer/plan_node/stream_group_topn.rs +++ b/src/frontend/src/optimizer/plan_node/stream_group_topn.rs @@ -42,6 +42,8 @@ impl StreamGroupTopN { let input = &core.input; let schema = input.schema().clone(); + // FIXME(rc): Actually only watermark messages on the first group-by column are propagated + // acccoring to the current GroupTopN implementation. This should be fixed. let watermark_columns = if input.append_only() { input.watermark_columns().clone() } else { diff --git a/src/frontend/src/optimizer/plan_node/stream_hop_window.rs b/src/frontend/src/optimizer/plan_node/stream_hop_window.rs index 0cdddf77ed0e..a94dfbe788f8 100644 --- a/src/frontend/src/optimizer/plan_node/stream_hop_window.rs +++ b/src/frontend/src/optimizer/plan_node/stream_hop_window.rs @@ -13,7 +13,6 @@ // limitations under the License. 
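// Illustration: the stream-node hunks above derive each operator's output
// watermark columns from its input's, e.g. `StreamExpand` re-emits every input
// watermark column shifted by `input_len`. A small sketch of that index
// remapping with the `fixedbitset` crate (the same bitset type the planner
// uses); the helper name and the column layout in `main` are made up.
use fixedbitset::FixedBitSet;

/// Shift a set of input watermark column indices by `offset` into an output
/// bitset with `output_len` columns.
fn shift_watermark_columns(input: &FixedBitSet, offset: usize, output_len: usize) -> FixedBitSet {
    let mut out = FixedBitSet::with_capacity(output_len);
    out.extend(input.ones().map(|idx| idx + offset));
    out
}

fn main() {
    // Say the input schema has 3 columns and column 1 carries a watermark.
    let mut input = FixedBitSet::with_capacity(3);
    input.insert(1);

    // If the operator copies the input columns starting at offset 3 into a
    // 7-column output, the watermark ends up on output column 4.
    let output = shift_watermark_columns(&input, 3, 7);
    assert!(output.contains(4));
    assert_eq!(output.count_ones(..), 1);
}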
use pretty_xmlish::XmlNode; -use risingwave_common::util::column_index_mapping::ColIndexMapping; use risingwave_pb::stream_plan::stream_node::PbNodeBody; use risingwave_pb::stream_plan::HopWindowNode; @@ -41,29 +40,28 @@ impl StreamHopWindow { window_end_exprs: Vec, ) -> Self { let input = core.input.clone(); - let i2o = core.i2o_col_mapping(); - let dist = i2o.rewrite_provided_distribution(input.distribution()); + let dist = core + .i2o_col_mapping() + .rewrite_provided_distribution(input.distribution()); - let mut watermark_columns = input.watermark_columns().clone(); + let input2internal = core.input2internal_col_mapping(); + let internal2output = core.internal2output_col_mapping(); + + let mut watermark_columns = input2internal.rewrite_bitset(input.watermark_columns()); watermark_columns.grow(core.internal_column_num()); - if watermark_columns.contains(core.time_col.index) { + if input.watermark_columns().contains(core.time_col.index) { // Watermark on `time_col` indicates watermark on both `window_start` and `window_end`. watermark_columns.insert(core.internal_window_start_col_idx()); watermark_columns.insert(core.internal_window_end_col_idx()); } - let watermark_columns = ColIndexMapping::with_remaining_columns( - &core.output_indices, - core.internal_column_num(), - ) - .rewrite_bitset(&watermark_columns); let base = PlanBase::new_stream_with_core( &core, dist, input.append_only(), input.emit_on_window_close(), - watermark_columns, + internal2output.rewrite_bitset(&watermark_columns), ); Self { base, diff --git a/src/frontend/src/optimizer/plan_node/stream_project.rs b/src/frontend/src/optimizer/plan_node/stream_project.rs index eae1bd5a34d5..e5828a326706 100644 --- a/src/frontend/src/optimizer/plan_node/stream_project.rs +++ b/src/frontend/src/optimizer/plan_node/stream_project.rs @@ -81,24 +81,26 @@ impl StreamProject { let mut watermark_derivations = vec![]; let mut nondecreasing_exprs = vec![]; - let mut watermark_columns = FixedBitSet::with_capacity(core.exprs.len()); + let mut out_watermark_columns = FixedBitSet::with_capacity(core.exprs.len()); for (expr_idx, expr) in core.exprs.iter().enumerate() { use monotonicity_variants::*; match analyze_monotonicity(expr) { + Inherent(Constant) => { + // XXX(rc): we can produce one watermark on each recovery for this case. + } + Inherent(monotonicity) => { + if monotonicity.is_non_decreasing() { + nondecreasing_exprs.push(expr_idx); // to produce watermarks + out_watermark_columns.insert(expr_idx); + } + } FollowingInput(input_idx) => { if input.watermark_columns().contains(input_idx) { - watermark_derivations.push((input_idx, expr_idx)); - watermark_columns.insert(expr_idx); + watermark_derivations.push((input_idx, expr_idx)); // to propagate watermarks + out_watermark_columns.insert(expr_idx); } } - Inherent(NonDecreasing) => { - nondecreasing_exprs.push(expr_idx); - watermark_columns.insert(expr_idx); - } - Inherent(Constant) => { - // XXX(rc): we can produce one watermark on each recovery for this case. 
- } - Inherent(_) | _FollowingInputInversely(_) => {} + _FollowingInputInversely(_) => {} } } // Project executor won't change the append-only behavior of the stream, so it depends on @@ -108,7 +110,7 @@ impl StreamProject { distribution, input.append_only(), input.emit_on_window_close(), - watermark_columns, + out_watermark_columns, ); StreamProject { diff --git a/src/frontend/src/optimizer/plan_node/stream_project_set.rs b/src/frontend/src/optimizer/plan_node/stream_project_set.rs index b65d4e8da0b5..4630e1c62c83 100644 --- a/src/frontend/src/optimizer/plan_node/stream_project_set.rs +++ b/src/frontend/src/optimizer/plan_node/stream_project_set.rs @@ -47,24 +47,29 @@ impl StreamProjectSet { let mut watermark_derivations = vec![]; let mut nondecreasing_exprs = vec![]; - let mut watermark_columns = FixedBitSet::with_capacity(core.output_len()); + let mut out_watermark_columns = FixedBitSet::with_capacity(core.output_len()); for (expr_idx, expr) in core.select_list.iter().enumerate() { + let out_expr_idx = expr_idx + 1; + use monotonicity_variants::*; match analyze_monotonicity(expr) { + Inherent(Constant) => { + // XXX(rc): we can produce one watermark on each recovery for this case. + } + Inherent(monotonicity) => { + if monotonicity.is_non_decreasing() { + // FIXME(rc): we need to check expr is not table function + nondecreasing_exprs.push(expr_idx); // to produce watermarks + out_watermark_columns.insert(out_expr_idx); + } + } FollowingInput(input_idx) => { if input.watermark_columns().contains(input_idx) { - watermark_derivations.push((input_idx, expr_idx)); - watermark_columns.insert(expr_idx + 1); + watermark_derivations.push((input_idx, expr_idx)); // to propagate watermarks + out_watermark_columns.insert(out_expr_idx); } } - Inherent(NonDecreasing) => { - nondecreasing_exprs.push(expr_idx); - watermark_columns.insert(expr_idx + 1); - } - Inherent(Constant) => { - // XXX(rc): we can produce one watermark on each recovery for this case. - } - Inherent(_) | _FollowingInputInversely(_) => {} + _FollowingInputInversely(_) => {} } } @@ -75,7 +80,7 @@ impl StreamProjectSet { distribution, input.append_only(), input.emit_on_window_close(), - watermark_columns, + out_watermark_columns, ); StreamProjectSet { base, diff --git a/src/frontend/src/optimizer/plan_node/utils.rs b/src/frontend/src/optimizer/plan_node/utils.rs index 288b0957db19..155381ab4310 100644 --- a/src/frontend/src/optimizer/plan_node/utils.rs +++ b/src/frontend/src/optimizer/plan_node/utils.rs @@ -106,11 +106,6 @@ impl TableCatalogBuilder { self.value_indices = Some(value_indices); } - #[allow(dead_code)] - pub fn set_watermark_columns(&mut self, watermark_columns: FixedBitSet) { - self.watermark_columns = Some(watermark_columns); - } - pub fn set_dist_key_in_pk(&mut self, dist_key_in_pk: Vec) { self.dist_key_in_pk = Some(dist_key_in_pk); } @@ -236,21 +231,22 @@ pub(crate) fn watermark_pretty<'a>( watermark_columns: &FixedBitSet, schema: &Schema, ) -> Option> { - if watermark_columns.count_ones(..) 
> 0 { - Some(watermark_fields_pretty(watermark_columns.ones(), schema)) - } else { - None - } + iter_fields_pretty(watermark_columns.ones(), schema) } -pub(crate) fn watermark_fields_pretty<'a>( - watermark_columns: impl Iterator, + +pub(crate) fn iter_fields_pretty<'a>( + columns: impl Iterator, schema: &Schema, -) -> Pretty<'a> { - let arr = watermark_columns +) -> Option> { + let arr = columns .map(|idx| FieldDisplay(schema.fields.get(idx).unwrap())) .map(|d| Pretty::display(&d)) - .collect(); - Pretty::Array(arr) + .collect::>(); + if arr.is_empty() { + None + } else { + Some(Pretty::Array(arr)) + } } #[derive(Clone, Copy)] From 10220edcf6fcf4411b4b77215207374b7838cb82 Mon Sep 17 00:00:00 2001 From: Noel Kwan <47273164+kwannoel@users.noreply.github.com> Date: Tue, 16 Jul 2024 18:17:26 +0800 Subject: [PATCH 19/70] refactor(meta): commit finish catalog in barrier manager (#17428) Co-authored-by: William Wen Co-authored-by: William Wen <44139337+wenym1@users.noreply.github.com> --- ci/scripts/deterministic-recovery-test.sh | 4 +- src/meta/src/barrier/command.rs | 13 +- src/meta/src/barrier/mod.rs | 35 ++- src/meta/src/barrier/notifier.rs | 28 +-- src/meta/src/barrier/progress.rs | 276 +++++++++++++--------- src/meta/src/barrier/recovery.rs | 121 +--------- src/meta/src/barrier/schedule.rs | 13 +- src/meta/src/controller/catalog.rs | 55 ++++- src/meta/src/controller/streaming_job.rs | 43 +++- src/meta/src/manager/catalog/database.rs | 21 +- src/meta/src/manager/catalog/mod.rs | 141 ++++++++++- src/meta/src/manager/metadata.rs | 64 ++++- src/meta/src/manager/streaming_job.rs | 10 + src/meta/src/rpc/ddl_controller.rs | 117 +++++---- src/meta/src/rpc/ddl_controller_v2.rs | 70 ++---- src/meta/src/stream/stream_manager.rs | 194 ++++++++------- 16 files changed, 708 insertions(+), 497 deletions(-) diff --git a/ci/scripts/deterministic-recovery-test.sh b/ci/scripts/deterministic-recovery-test.sh index 4dd2c1ec8893..1a400d4ade9e 100755 --- a/ci/scripts/deterministic-recovery-test.sh +++ b/ci/scripts/deterministic-recovery-test.sh @@ -13,7 +13,9 @@ export RUST_LOG="risingwave_meta::barrier::recovery=debug,\ risingwave_meta::manager::catalog=debug,\ risingwave_meta::rpc::ddl_controller=debug,\ risingwave_meta::barrier::mod=debug,\ -risingwave_simulation=debug" +risingwave_simulation=debug,\ +risingwave_meta::stream::stream_manager=debug,\ +risingwave_meta::barrier::progress=debug" # Extra logs you can enable if the existing trace does not give enough info. 
#risingwave_stream::executor::backfill=trace, diff --git a/src/meta/src/barrier/command.rs b/src/meta/src/barrier/command.rs index b5c538758529..486e314ae046 100644 --- a/src/meta/src/barrier/command.rs +++ b/src/meta/src/barrier/command.rs @@ -24,7 +24,7 @@ use risingwave_common::types::Timestamptz; use risingwave_common::util::epoch::Epoch; use risingwave_connector::source::SplitImpl; use risingwave_hummock_sdk::HummockEpoch; -use risingwave_pb::catalog::CreateType; +use risingwave_pb::catalog::{CreateType, Table}; use risingwave_pb::meta::table_fragments::PbActorStatus; use risingwave_pb::meta::PausedReason; use risingwave_pb::source::{ConnectorSplit, ConnectorSplits}; @@ -44,7 +44,7 @@ use tracing::warn; use super::info::{CommandActorChanges, CommandFragmentChanges, InflightActorInfo}; use super::trace::TracedEpoch; use crate::barrier::GlobalBarrierManagerContext; -use crate::manager::{DdlType, MetadataManager, WorkerId}; +use crate::manager::{DdlType, MetadataManager, StreamingJob, WorkerId}; use crate::model::{ActorId, DispatcherId, FragmentId, TableFragments, TableParallelism}; use crate::stream::{build_actor_connector_splits, SplitAssignment, ThrottleConfig}; use crate::MetaResult; @@ -97,6 +97,10 @@ pub struct ReplaceTablePlan { /// Note that there's no `SourceBackfillExecutor` involved for table with connector, so we don't need to worry about /// `backfill_splits`. pub init_split_assignment: SplitAssignment, + /// The `StreamingJob` info of the table to be replaced. Must be `StreamingJob::Table` + pub streaming_job: StreamingJob, + /// The temporary dummy table fragments id of new table fragment + pub dummy_id: u32, } impl ReplaceTablePlan { @@ -183,6 +187,8 @@ pub enum Command { /// for a while** until the `finish` channel is signaled, then the state of `TableFragments` /// will be set to `Created`. CreateStreamingJob { + streaming_job: StreamingJob, + internal_tables: Vec
, table_fragments: TableFragments, /// Refer to the doc on [`MetadataManager::get_upstream_root_fragments`] for the meaning of "root". upstream_root_actors: HashMap>, @@ -551,6 +557,7 @@ impl CommandContext { merge_updates, dispatchers, init_split_assignment, + .. }) = replace_table { // TODO: support in v2. @@ -1019,6 +1026,7 @@ impl CommandContext { merge_updates, dispatchers, init_split_assignment, + .. }) = replace_table { self.clean_up(old_table_fragments.actor_ids()).await?; @@ -1104,6 +1112,7 @@ impl CommandContext { merge_updates, dispatchers, init_split_assignment, + .. }) => { self.clean_up(old_table_fragments.actor_ids()).await?; diff --git a/src/meta/src/barrier/mod.rs b/src/meta/src/barrier/mod.rs index 92fc4dc31a2e..df1a34544e6c 100644 --- a/src/meta/src/barrier/mod.rs +++ b/src/meta/src/barrier/mod.rs @@ -63,7 +63,6 @@ use crate::manager::{ ActiveStreamingWorkerChange, ActiveStreamingWorkerNodes, LocalNotification, MetaSrvEnv, MetadataManager, SystemParamsManagerImpl, WorkerId, }; -use crate::model::{ActorId, TableFragments}; use crate::rpc::metrics::MetaMetrics; use crate::stream::{ScaleControllerRef, SourceManagerRef}; use crate::{MetaError, MetaResult}; @@ -88,12 +87,6 @@ pub(crate) struct TableMap { inner: HashMap, } -impl TableMap { - pub fn remove(&mut self, table_id: &TableId) -> Option { - self.inner.remove(table_id) - } -} - impl From> for TableMap { fn from(inner: HashMap) -> Self { Self { inner } @@ -106,12 +99,6 @@ impl From> for HashMap { } } -pub(crate) type TableActorMap = TableMap>; -pub(crate) type TableUpstreamMvCountMap = TableMap>; -pub(crate) type TableDefinitionMap = TableMap; -pub(crate) type TableNotifierMap = TableMap; -pub(crate) type TableFragmentMap = TableMap; - /// The reason why the cluster is recovering. enum RecoveryReason { /// After bootstrap. @@ -802,7 +789,12 @@ impl GlobalBarrierManager { } async fn failure_recovery(&mut self, err: MetaError) { - self.context.tracker.lock().await.abort_all(&err); + self.context + .tracker + .lock() + .await + .abort_all(&err, &self.context) + .await; self.checkpoint_control.clear_on_err(&err).await; self.pending_non_checkpoint_barriers.clear(); @@ -830,7 +822,12 @@ impl GlobalBarrierManager { async fn adhoc_recovery(&mut self) { let err = MetaErrorInner::AdhocRecovery.into(); - self.context.tracker.lock().await.abort_all(&err); + self.context + .tracker + .lock() + .await + .abort_all(&err, &self.context) + .await; self.checkpoint_control.clear_on_err(&err).await; if self.enable_recovery { @@ -859,7 +856,7 @@ impl GlobalBarrierManagerContext { async fn complete_barrier(self, node: EpochNode) -> MetaResult { let EpochNode { command_ctx, - mut notifiers, + notifiers, enqueue_time, state, .. 
@@ -877,11 +874,11 @@ impl GlobalBarrierManagerContext { } return Err(e); }; - notifiers.iter_mut().for_each(|notifier| { + notifiers.into_iter().for_each(|notifier| { notifier.notify_collected(); }); let has_remaining = self - .update_tracking_jobs(notifiers, command_ctx.clone(), create_mview_progress) + .update_tracking_jobs(command_ctx.clone(), create_mview_progress) .await?; let duration_sec = enqueue_time.stop_and_record(); self.report_complete_event(duration_sec, &command_ctx); @@ -943,7 +940,6 @@ impl GlobalBarrierManagerContext { async fn update_tracking_jobs( &self, - notifiers: Vec, command_ctx: Arc, create_mview_progress: Vec, ) -> MetaResult { @@ -960,7 +956,6 @@ impl GlobalBarrierManagerContext { if let Some(command) = tracker.add( TrackingCommand { context: command_ctx.clone(), - notifiers, }, &version_stats, ) { diff --git a/src/meta/src/barrier/notifier.rs b/src/meta/src/barrier/notifier.rs index d8df005600e4..24f13050b030 100644 --- a/src/meta/src/barrier/notifier.rs +++ b/src/meta/src/barrier/notifier.rs @@ -36,9 +36,6 @@ pub(crate) struct Notifier { /// Get notified when scheduled barrier is collected or failed. pub collected: Option>>, - - /// Get notified when scheduled barrier is finished. - pub finished: Option>>, } impl Notifier { @@ -50,8 +47,8 @@ impl Notifier { } /// Notify when we have collected a barrier from all actors. - pub fn notify_collected(&mut self) { - if let Some(tx) = self.collected.take() { + pub fn notify_collected(self) { + if let Some(tx) = self.collected { tx.send(Ok(())).ok(); } } @@ -63,31 +60,10 @@ impl Notifier { } } - /// Notify when we have finished a barrier from all actors. This function consumes `self`. - /// - /// Generally when a barrier is collected, it's also finished since it does not require further - /// report of finishing from actors. - /// However for creating MV, this is only called when all `BackfillExecutor` report it finished. - pub fn notify_finished(self) { - if let Some(tx) = self.finished { - tx.send(Ok(())).ok(); - } - } - - /// Notify when we failed to finish a barrier. This function consumes `self`. - pub fn notify_finish_failed(self, err: MetaError) { - if let Some(tx) = self.finished { - tx.send(Err(err)).ok(); - } - } - /// Notify when we failed to collect or finish a barrier. This function consumes `self`. pub fn notify_failed(self, err: MetaError) { if let Some(tx) = self.collected { tx.send(Err(err.clone())).ok(); } - if let Some(tx) = self.finished { - tx.send(Err(err)).ok(); - } } } diff --git a/src/meta/src/barrier/progress.rs b/src/meta/src/barrier/progress.rs index 746a263b0631..5fdf875486fd 100644 --- a/src/meta/src/barrier/progress.rs +++ b/src/meta/src/barrier/progress.rs @@ -13,23 +13,22 @@ // limitations under the License. 
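// Illustration: the `Progress` bookkeeping reworked below aggregates per-actor
// backfill states into one job-level figure: each actor reports how many rows
// it has consumed, and the job finishes once every actor reports `Done`. A toy
// version of that aggregation; the fields and methods only loosely follow the
// real `progress.rs` and the numbers are invented.
use std::collections::HashMap;

#[derive(Clone, Copy, Debug, PartialEq)]
enum ToyBackfillState {
    Init,
    ConsumingUpstream { consumed_rows: u64 },
    Done { consumed_rows: u64 },
}

struct ToyProgress {
    states: HashMap<u32, ToyBackfillState>, // actor id -> backfill state
    done_count: usize,
    upstream_total_key_count: u64,
}

impl ToyProgress {
    fn new(actors: impl IntoIterator<Item = u32>, upstream_total_key_count: u64) -> Self {
        let states = actors.into_iter().map(|a| (a, ToyBackfillState::Init)).collect();
        Self { states, done_count: 0, upstream_total_key_count }
    }

    fn update(&mut self, actor: u32, new_state: ToyBackfillState) {
        let old = self.states.insert(actor, new_state).expect("unknown actor");
        // Count an actor as done exactly once.
        if matches!(new_state, ToyBackfillState::Done { .. })
            && !matches!(old, ToyBackfillState::Done { .. })
        {
            self.done_count += 1;
        }
    }

    fn is_done(&self) -> bool {
        self.done_count == self.states.len()
    }

    /// Percentage of upstream keys consumed so far, capped at 100%.
    fn percentage(&self) -> f64 {
        let consumed: u64 = self
            .states
            .values()
            .map(|s| match s {
                ToyBackfillState::Init => 0,
                ToyBackfillState::ConsumingUpstream { consumed_rows }
                | ToyBackfillState::Done { consumed_rows } => *consumed_rows,
            })
            .sum();
        let total = self.upstream_total_key_count.max(1) as f64;
        (consumed as f64 / total * 100.0).min(100.0)
    }
}

fn main() {
    let mut progress = ToyProgress::new([1, 2], 1000);
    progress.update(1, ToyBackfillState::ConsumingUpstream { consumed_rows: 400 });
    progress.update(2, ToyBackfillState::Done { consumed_rows: 600 });
    assert!(!progress.is_done());
    println!("backfill progress: {:.1}%", progress.percentage());
}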
use std::collections::hash_map::Entry; -use std::collections::{HashMap, HashSet}; +use std::collections::HashMap; use std::sync::Arc; use risingwave_common::catalog::TableId; use risingwave_common::util::epoch::Epoch; -use risingwave_pb::catalog::CreateType; +use risingwave_meta_model_v2::ObjectId; +use risingwave_pb::catalog::{CreateType, Table}; use risingwave_pb::ddl_service::DdlProgress; use risingwave_pb::hummock::HummockVersionStats; use risingwave_pb::stream_service::barrier_complete_response::CreateMviewProgress; use super::command::CommandContext; -use super::notifier::Notifier; -use crate::barrier::{ - Command, TableActorMap, TableDefinitionMap, TableFragmentMap, TableNotifierMap, - TableUpstreamMvCountMap, +use crate::barrier::{Command, GlobalBarrierManagerContext}; +use crate::manager::{ + DdlType, MetadataManager, MetadataManagerV1, MetadataManagerV2, StreamingJob, }; -use crate::manager::{DdlType, MetadataManager}; use crate::model::{ActorId, TableFragments}; use crate::{MetaError, MetaResult}; @@ -92,6 +91,7 @@ impl Progress { fn update(&mut self, actor: ActorId, new_state: BackfillState, upstream_total_key_count: u64) { self.upstream_total_key_count = upstream_total_key_count; let total_actors = self.states.len(); + tracing::debug!(?actor, states = ?self.states, "update progress for actor"); match self.states.remove(&actor).unwrap() { BackfillState::Init => {} BackfillState::ConsumingUpstream(_, old_consumed_rows) => { @@ -155,23 +155,17 @@ impl Progress { /// On recovery, the barrier manager will recover and start managing the job. pub enum TrackingJob { New(TrackingCommand), - Recovered(RecoveredTrackingJob), + RecoveredV1(RecoveredTrackingJobV1), + RecoveredV2(RecoveredTrackingJobV2), } impl TrackingJob { - fn metadata_manager(&self) -> &MetadataManager { - match self { - TrackingJob::New(command) => command.context.metadata_manager(), - TrackingJob::Recovered(recovered) => &recovered.metadata_manager, - } - } - /// Returns whether the `TrackingJob` requires a checkpoint to complete. pub(crate) fn is_checkpoint_required(&self) -> bool { match self { // Recovered tracking job is always a streaming job, // It requires a checkpoint to complete. - TrackingJob::Recovered(_) => true, + TrackingJob::RecoveredV1(_) | TrackingJob::RecoveredV2(_) => true, TrackingJob::New(command) => { command.context.kind.is_initial() || command.context.kind.is_checkpoint() } @@ -179,54 +173,55 @@ impl TrackingJob { } pub(crate) async fn pre_finish(&self) -> MetaResult<()> { - let table_fragments = match &self { + match &self { TrackingJob::New(command) => match &command.context.command { Command::CreateStreamingJob { - table_fragments, .. - } => Some(table_fragments), - _ => None, + table_fragments, + streaming_job, + internal_tables, + replace_table, + .. + } => match command.context.metadata_manager() { + MetadataManager::V1(mgr) => { + mgr.fragment_manager + .mark_table_fragments_created(table_fragments.table_id()) + .await?; + mgr.catalog_manager + .finish_stream_job(streaming_job.clone(), internal_tables.clone()) + .await?; + Ok(()) + } + MetadataManager::V2(mgr) => { + mgr.catalog_controller + .finish_streaming_job(streaming_job.id() as i32, replace_table.clone()) + .await?; + Ok(()) + } + }, + _ => Ok(()), }, - TrackingJob::Recovered(recovered) => Some(&recovered.fragments), - }; - // Update the state of the table fragments from `Creating` to `Created`, so that the - // fragments can be scaled. 
- if let Some(table_fragments) = table_fragments { - match self.metadata_manager() { - MetadataManager::V1(mgr) => { - mgr.fragment_manager - .mark_table_fragments_created(table_fragments.table_id()) - .await?; - } - MetadataManager::V2(_) => {} - } - } - Ok(()) - } - - pub(crate) fn notify_finished(self) { - match self { - TrackingJob::New(command) => { - command - .notifiers - .into_iter() - .for_each(Notifier::notify_finished); - } - TrackingJob::Recovered(recovered) => { - recovered.finished.notify_finished(); - } - } - } - - pub(crate) fn notify_finish_failed(self, err: MetaError) { - match self { - TrackingJob::New(command) => { - command - .notifiers - .into_iter() - .for_each(|n| n.notify_finish_failed(err.clone())); + TrackingJob::RecoveredV1(recovered) => { + let manager = &recovered.metadata_manager; + manager + .fragment_manager + .mark_table_fragments_created(recovered.fragments.table_id()) + .await?; + manager + .catalog_manager + .finish_stream_job( + recovered.streaming_job.clone(), + recovered.internal_tables.clone(), + ) + .await?; + Ok(()) } - TrackingJob::Recovered(recovered) => { - recovered.finished.notify_finish_failed(err); + TrackingJob::RecoveredV2(recovered) => { + recovered + .metadata_manager + .catalog_controller + .finish_streaming_job(recovered.id, None) + .await?; + Ok(()) } } } @@ -234,7 +229,8 @@ impl TrackingJob { pub(crate) fn table_to_create(&self) -> Option { match self { TrackingJob::New(command) => command.context.table_to_create(), - TrackingJob::Recovered(recovered) => Some(recovered.fragments.table_id()), + TrackingJob::RecoveredV1(recovered) => Some(recovered.fragments.table_id()), + TrackingJob::RecoveredV2(recovered) => Some((recovered.id as u32).into()), } } } @@ -247,35 +243,38 @@ impl std::fmt::Debug for TrackingJob { "TrackingJob::New({:?})", command.context.table_to_create() ), - TrackingJob::Recovered(recovered) => { + TrackingJob::RecoveredV1(recovered) => { write!( f, - "TrackingJob::Recovered({:?})", + "TrackingJob::RecoveredV1({:?})", recovered.fragments.table_id() ) } + TrackingJob::RecoveredV2(recovered) => { + write!(f, "TrackingJob::RecoveredV2({:?})", recovered.id) + } } } } -pub struct RecoveredTrackingJob { +pub struct RecoveredTrackingJobV1 { pub fragments: TableFragments, - pub finished: Notifier, - pub metadata_manager: MetadataManager, + pub streaming_job: StreamingJob, + pub internal_tables: Vec
, + pub metadata_manager: MetadataManagerV1, +} + +pub struct RecoveredTrackingJobV2 { + pub id: ObjectId, + pub metadata_manager: MetadataManagerV2, } /// The command tracking by the [`CreateMviewProgressTracker`]. pub(super) struct TrackingCommand { /// The context of the command. pub context: Arc, - - /// Should be called when the command is finished. - pub notifiers: Vec, } -/// Track the progress of all creating mviews. When creation is done, `notify_finished` will be -/// called on registered notifiers. -/// /// Tracking is done as follows: /// 1. We identify a `StreamJob` by its `TableId` of its `Materialized` table. /// 2. For each stream job, there are several actors which run its tasks. @@ -300,48 +299,73 @@ impl CreateMviewProgressTracker { /// Other state are persisted by the `BackfillExecutor`, such as: /// 1. `CreateMviewProgress`. /// 2. `Backfill` position. - pub fn recover( - table_map: TableActorMap, - mut upstream_mv_counts: TableUpstreamMvCountMap, - mut definitions: TableDefinitionMap, + pub fn recover_v1( version_stats: HummockVersionStats, - mut finished_notifiers: TableNotifierMap, - mut table_fragment_map: TableFragmentMap, - metadata_manager: MetadataManager, + mviews: HashMap< + TableId, + ( + TableFragments, + Table, // mview table + Vec
, // internal tables + ), + >, + metadata_manager: MetadataManagerV1, ) -> Self { let mut actor_map = HashMap::new(); let mut progress_map = HashMap::new(); - let table_map: HashMap<_, HashSet> = table_map.into(); - for (creating_table_id, actors) in table_map { - // 1. Recover `BackfillState` in the tracker. + for (creating_table_id, (table_fragments, mview, internal_tables)) in mviews { + let actors = table_fragments.backfill_actor_ids(); let mut states = HashMap::new(); + tracing::debug!(?actors, ?creating_table_id, "recover progress for actors"); for actor in actors { actor_map.insert(actor, creating_table_id); states.insert(actor, BackfillState::ConsumingUpstream(Epoch(0), 0)); } - let upstream_mv_count = upstream_mv_counts.remove(&creating_table_id).unwrap(); - let upstream_total_key_count = upstream_mv_count - .iter() - .map(|(upstream_mv, count)| { - *count as u64 - * version_stats - .table_stats - .get(&upstream_mv.table_id) - .map_or(0, |stat| stat.total_key_count as u64) - }) - .sum(); - let definition = definitions.remove(&creating_table_id).unwrap(); - let progress = Progress { + + let progress = Self::recover_progress( states, - done_count: 0, // Fill only after first barrier pass - upstream_mv_count, - upstream_total_key_count, - consumed_rows: 0, // Fill only after first barrier pass + table_fragments.dependent_table_ids(), + mview.definition.clone(), + &version_stats, + ); + let tracking_job = TrackingJob::RecoveredV1(RecoveredTrackingJobV1 { + fragments: table_fragments, + metadata_manager: metadata_manager.clone(), + internal_tables, + streaming_job: StreamingJob::MaterializedView(mview), + }); + progress_map.insert(creating_table_id, (progress, tracking_job)); + } + Self { + progress_map, + actor_map, + finished_jobs: Vec::new(), + } + } + + pub fn recover_v2( + mview_map: HashMap, + version_stats: HummockVersionStats, + metadata_manager: MetadataManagerV2, + ) -> Self { + let mut actor_map = HashMap::new(); + let mut progress_map = HashMap::new(); + for (creating_table_id, (definition, table_fragments)) in mview_map { + let mut states = HashMap::new(); + let actors = table_fragments.backfill_actor_ids(); + for actor in actors { + actor_map.insert(actor, creating_table_id); + states.insert(actor, BackfillState::ConsumingUpstream(Epoch(0), 0)); + } + + let progress = Self::recover_progress( + states, + table_fragments.dependent_table_ids(), definition, - }; - let tracking_job = TrackingJob::Recovered(RecoveredTrackingJob { - fragments: table_fragment_map.remove(&creating_table_id).unwrap(), - finished: finished_notifiers.remove(&creating_table_id).unwrap(), + &version_stats, + ); + let tracking_job = TrackingJob::RecoveredV2(RecoveredTrackingJobV2 { + id: creating_table_id.table_id as i32, metadata_manager: metadata_manager.clone(), }); progress_map.insert(creating_table_id, (progress, tracking_job)); @@ -353,6 +377,32 @@ impl CreateMviewProgressTracker { } } + fn recover_progress( + states: HashMap, + upstream_mv_count: HashMap, + definition: String, + version_stats: &HummockVersionStats, + ) -> Progress { + let upstream_total_key_count = upstream_mv_count + .iter() + .map(|(upstream_mv, count)| { + *count as u64 + * version_stats + .table_stats + .get(&upstream_mv.table_id) + .map_or(0, |stat| stat.total_key_count as u64) + }) + .sum(); + Progress { + states, + done_count: 0, // Fill only after first barrier pass + upstream_mv_count, + upstream_total_key_count, + consumed_rows: 0, // Fill only after first barrier pass + definition, + } + } + pub fn new() -> Self { Self 
{ progress_map: Default::default(), @@ -394,7 +444,6 @@ impl CreateMviewProgressTracker { { // The command is ready to finish. We can now call `pre_finish`. job.pre_finish().await?; - job.notify_finished(); } Ok(!self.finished_jobs.is_empty()) } @@ -407,14 +456,18 @@ impl CreateMviewProgressTracker { } /// Notify all tracked commands that error encountered and clear them. - pub fn abort_all(&mut self, err: &MetaError) { + pub async fn abort_all(&mut self, err: &MetaError, context: &GlobalBarrierManagerContext) { self.actor_map.clear(); - self.finished_jobs.drain(..).for_each(|job| { - job.notify_finish_failed(err.clone()); - }); - self.progress_map - .drain() - .for_each(|(_, (_, job))| job.notify_finish_failed(err.clone())); + self.finished_jobs.clear(); + self.progress_map.clear(); + match &context.metadata_manager { + MetadataManager::V1(mgr) => { + mgr.notify_finish_failed(err).await; + } + MetadataManager::V2(mgr) => { + mgr.notify_finish_failed(err).await; + } + } } /// Add a new create-mview DDL command to track. @@ -554,6 +607,7 @@ impl CreateMviewProgressTracker { }) .sum(); + tracing::debug!(?table_id, "updating progress for table"); progress.update(actor, new_state, upstream_total_key_count); if progress.is_done() { diff --git a/src/meta/src/barrier/recovery.rs b/src/meta/src/barrier/recovery.rs index 0ead9779e914..4bb9d2f669c0 100644 --- a/src/meta/src/barrier/recovery.rs +++ b/src/meta/src/barrier/recovery.rs @@ -28,14 +28,12 @@ use risingwave_pb::meta::{PausedReason, Recovery}; use risingwave_pb::stream_plan::barrier_mutation::Mutation; use risingwave_pb::stream_plan::AddMutation; use thiserror_ext::AsReport; -use tokio::sync::oneshot; use tokio_retry::strategy::{jitter, ExponentialBackoff}; use tracing::{debug, warn, Instrument}; use super::TracedEpoch; use crate::barrier::command::CommandContext; use crate::barrier::info::InflightActorInfo; -use crate::barrier::notifier::Notifier; use crate::barrier::progress::CreateMviewProgressTracker; use crate::barrier::rpc::ControlStreamManager; use crate::barrier::schedule::ScheduledBarriers; @@ -131,12 +129,7 @@ impl GlobalBarrierManagerContext { let mgr = self.metadata_manager.as_v1_ref(); let mviews = mgr.catalog_manager.list_creating_background_mvs().await; - let mut mview_definitions = HashMap::new(); - let mut table_map = HashMap::new(); - let mut table_fragment_map = HashMap::new(); - let mut upstream_mv_counts = HashMap::new(); - let mut senders = HashMap::new(); - let mut receivers = Vec::new(); + let mut table_mview_map = HashMap::new(); for mview in mviews { let table_id = TableId::new(mview.id); let fragments = mgr @@ -152,19 +145,7 @@ impl GlobalBarrierManagerContext { .await?; tracing::debug!("notified frontend for stream job {}", table_id.table_id); } else { - table_map.insert(table_id, fragments.backfill_actor_ids()); - mview_definitions.insert(table_id, mview.definition.clone()); - upstream_mv_counts.insert(table_id, fragments.dependent_table_ids()); - table_fragment_map.insert(table_id, fragments); - let (finished_tx, finished_rx) = oneshot::channel(); - senders.insert( - table_id, - Notifier { - finished: Some(finished_tx), - ..Default::default() - }, - ); - receivers.push((mview, internal_tables, finished_rx)); + table_mview_map.insert(table_id, (fragments, mview, internal_tables)); } } @@ -172,45 +153,8 @@ impl GlobalBarrierManagerContext { // If failed, enter recovery mode. 
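// NOTE: as a reading aid, the map shapes consumed by the two recovery paths are assumed to
// be roughly the following, inferred from the surrounding call sites (simplified types,
// illustrative only):
//
//     // v1: CreateMviewProgressTracker::recover_v1(version_stats, mviews, mgr)
//     type V1MviewMap = HashMap<TableId, (TableFragments, Table /* mview */, Vec<Table> /* internal tables */)>;
//
//     // v2: CreateMviewProgressTracker::recover_v2(mview_map, version_stats, mgr)
//     type V2MviewMap = HashMap<TableId, (String /* definition */, TableFragments)>;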
{ let mut tracker = self.tracker.lock().await; - *tracker = CreateMviewProgressTracker::recover( - table_map.into(), - upstream_mv_counts.into(), - mview_definitions.into(), - version_stats, - senders.into(), - table_fragment_map.into(), - self.metadata_manager.clone(), - ); - } - for (table, internal_tables, finished) in receivers { - let catalog_manager = mgr.catalog_manager.clone(); - tokio::spawn(async move { - let res: MetaResult<()> = try { - tracing::debug!("recovering stream job {}", table.id); - finished.await.ok().context("failed to finish command")??; - - tracing::debug!("finished stream job {}", table.id); - // Once notified that job is finished we need to notify frontend. - // and mark catalog as created and commit to meta. - // both of these are done by catalog manager. - catalog_manager - .finish_create_materialized_view_procedure(internal_tables, table.clone()) - .await?; - tracing::debug!("notified frontend for stream job {}", table.id); - }; - if let Err(e) = res.as_ref() { - tracing::error!( - id = table.id, - error = %e.as_report(), - "stream job interrupted, will retry after recovery", - ); - // NOTE(kwannoel): We should not cleanup stream jobs, - // we don't know if it's just due to CN killed, - // or the job has actually failed. - // Users have to manually cancel the stream jobs, - // if they want to clean it. - } - }); + *tracker = + CreateMviewProgressTracker::recover_v1(version_stats, table_mview_map, mgr.clone()); } Ok(()) } @@ -222,73 +166,24 @@ impl GlobalBarrierManagerContext { .list_background_creating_mviews() .await?; - let mut senders = HashMap::new(); - let mut receivers = Vec::new(); - let mut table_fragment_map = HashMap::new(); - let mut mview_definitions = HashMap::new(); - let mut table_map = HashMap::new(); - let mut upstream_mv_counts = HashMap::new(); + let mut mview_map = HashMap::new(); for mview in &mviews { - let (finished_tx, finished_rx) = oneshot::channel(); let table_id = TableId::new(mview.table_id as _); - senders.insert( - table_id, - Notifier { - finished: Some(finished_tx), - ..Default::default() - }, - ); - let table_fragments = mgr .catalog_controller .get_job_fragments_by_id(mview.table_id) .await?; let table_fragments = TableFragments::from_protobuf(table_fragments); - upstream_mv_counts.insert(table_id, table_fragments.dependent_table_ids()); - table_map.insert(table_id, table_fragments.backfill_actor_ids()); - table_fragment_map.insert(table_id, table_fragments); - mview_definitions.insert(table_id, mview.definition.clone()); - receivers.push((mview.table_id, finished_rx)); + mview_map.insert(table_id, (mview.definition.clone(), table_fragments)); } let version_stats = self.hummock_manager.get_version_stats().await; // If failed, enter recovery mode. 
{ let mut tracker = self.tracker.lock().await; - *tracker = CreateMviewProgressTracker::recover( - table_map.into(), - upstream_mv_counts.into(), - mview_definitions.into(), - version_stats, - senders.into(), - table_fragment_map.into(), - self.metadata_manager.clone(), - ); - } - for (id, finished) in receivers { - let catalog_controller = mgr.catalog_controller.clone(); - tokio::spawn(async move { - let res: MetaResult<()> = try { - tracing::debug!("recovering stream job {}", id); - finished.await.ok().context("failed to finish command")??; - tracing::debug!(id, "finished stream job"); - catalog_controller.finish_streaming_job(id, None).await?; - }; - if let Err(e) = &res { - tracing::error!( - id, - error = %e.as_report(), - "stream job interrupted, will retry after recovery", - ); - // NOTE(kwannoel): We should not cleanup stream jobs, - // we don't know if it's just due to CN killed, - // or the job has actually failed. - // Users have to manually cancel the stream jobs, - // if they want to clean it. - } - }); + *tracker = + CreateMviewProgressTracker::recover_v2(mview_map, version_stats, mgr.clone()); } - Ok(()) } diff --git a/src/meta/src/barrier/schedule.rs b/src/meta/src/barrier/schedule.rs index 50034811a339..7af683c18fc7 100644 --- a/src/meta/src/barrier/schedule.rs +++ b/src/meta/src/barrier/schedule.rs @@ -259,16 +259,14 @@ impl BarrierScheduler { for command in commands { let (started_tx, started_rx) = oneshot::channel(); let (collect_tx, collect_rx) = oneshot::channel(); - let (finish_tx, finish_rx) = oneshot::channel(); - contexts.push((started_rx, collect_rx, finish_rx)); + contexts.push((started_rx, collect_rx)); scheduleds.push(self.inner.new_scheduled( command.need_checkpoint(), command, once(Notifier { started: Some(started_tx), collected: Some(collect_tx), - finished: Some(finish_tx), }), )); } @@ -277,7 +275,7 @@ impl BarrierScheduler { let mut infos = Vec::with_capacity(contexts.len()); - for (injected_rx, collect_rx, finish_rx) in contexts { + for (injected_rx, collect_rx) in contexts { // Wait for this command to be injected, and record the result. tracing::trace!("waiting for injected_rx"); let info = injected_rx.await.ok().context("failed to inject barrier")?; @@ -289,10 +287,6 @@ impl BarrierScheduler { .await .ok() .context("failed to collect barrier")??; - - tracing::trace!("waiting for finish_rx"); - // Wait for this command to be finished. 
- finish_rx.await.ok().context("failed to finish command")??; } Ok(infos) @@ -448,9 +442,8 @@ impl ScheduledBarriers { unreachable!("only drop and cancel streaming jobs should be buffered"); } } - notifiers.into_iter().for_each(|mut notify| { + notifiers.into_iter().for_each(|notify| { notify.notify_collected(); - notify.notify_finished(); }); } (dropped_actors, cancel_table_ids) diff --git a/src/meta/src/controller/catalog.rs b/src/meta/src/controller/catalog.rs index 6fc372197285..6fbf7cb0ac04 100644 --- a/src/meta/src/controller/catalog.rs +++ b/src/meta/src/controller/catalog.rs @@ -14,6 +14,7 @@ use std::collections::{BTreeSet, HashMap, HashSet}; use std::iter; +use std::mem::take; use std::sync::Arc; use anyhow::anyhow; @@ -57,7 +58,8 @@ use sea_orm::{ IntoActiveModel, JoinType, PaginatorTrait, QueryFilter, QuerySelect, RelationTrait, TransactionTrait, Value, }; -use tokio::sync::{RwLock, RwLockReadGuard}; +use tokio::sync::oneshot::Sender; +use tokio::sync::{RwLock, RwLockReadGuard, RwLockWriteGuard}; use super::utils::{check_subscription_name_duplicate, get_fragment_ids_by_jobs}; use crate::controller::rename::{alter_relation_rename, alter_relation_rename_refs}; @@ -108,6 +110,7 @@ impl CatalogController { env, inner: RwLock::new(CatalogControllerInner { db: meta_store.conn, + creating_table_finish_notifier: HashMap::new(), }), } } @@ -117,10 +120,20 @@ impl CatalogController { pub async fn get_inner_read_guard(&self) -> RwLockReadGuard<'_, CatalogControllerInner> { self.inner.read().await } + + pub async fn get_inner_write_guard(&self) -> RwLockWriteGuard<'_, CatalogControllerInner> { + self.inner.write().await + } } pub struct CatalogControllerInner { pub(crate) db: DatabaseConnection, + /// Registered finish notifiers for creating tables. + /// + /// `DdlController` will update this map, and pass the `tx` side to `CatalogController`. + /// On notifying, we can remove the entry from this map. 
+ pub creating_table_finish_notifier: + HashMap>>>, } impl CatalogController { @@ -145,6 +158,10 @@ impl CatalogController { .notify_frontend_relation_info(operation, relation_info) .await } + + pub(crate) async fn current_notification_version(&self) -> NotificationVersion { + self.env.notification_manager().current_version().await + } } impl CatalogController { @@ -3160,6 +3177,42 @@ impl CatalogControllerInner { .map(|(func, obj)| ObjectModel(func, obj.unwrap()).into()) .collect()) } + + pub(crate) fn register_finish_notifier( + &mut self, + id: i32, + sender: Sender>, + ) { + self.creating_table_finish_notifier + .entry(id) + .or_default() + .push(sender); + } + + pub(crate) async fn streaming_job_is_finished(&mut self, id: i32) -> MetaResult { + let status = StreamingJob::find() + .select_only() + .column(streaming_job::Column::JobStatus) + .filter(streaming_job::Column::JobId.eq(id)) + .into_tuple::() + .one(&self.db) + .await?; + + status + .map(|status| status == JobStatus::Created) + .ok_or_else(|| { + MetaError::catalog_id_not_found("streaming job", "may have been cancelled/dropped") + }) + } + + pub(crate) fn notify_finish_failed(&mut self, err: &MetaError) { + for tx in take(&mut self.creating_table_finish_notifier) + .into_values() + .flatten() + { + let _ = tx.send(Err(err.clone())); + } + } } #[cfg(test)] diff --git a/src/meta/src/controller/streaming_job.rs b/src/meta/src/controller/streaming_job.rs index 802e6dabbb90..2c1332fe65cc 100644 --- a/src/meta/src/controller/streaming_job.rs +++ b/src/meta/src/controller/streaming_job.rs @@ -49,7 +49,7 @@ use risingwave_pb::meta::{ use risingwave_pb::source::{PbConnectorSplit, PbConnectorSplits}; use risingwave_pb::stream_plan::stream_fragment_graph::Parallelism; use risingwave_pb::stream_plan::stream_node::PbNodeBody; -use risingwave_pb::stream_plan::update_mutation::{MergeUpdate, PbMergeUpdate}; +use risingwave_pb::stream_plan::update_mutation::PbMergeUpdate; use risingwave_pb::stream_plan::{ PbDispatcher, PbDispatcherType, PbFragmentTypeFlag, PbStreamActor, }; @@ -61,7 +61,7 @@ use sea_orm::{ TransactionTrait, }; -use crate::barrier::Reschedule; +use crate::barrier::{ReplaceTablePlan, Reschedule}; use crate::controller::catalog::CatalogController; use crate::controller::rename::ReplaceTableExprRewriter; use crate::controller::utils::{ @@ -418,7 +418,7 @@ impl CatalogController { job_id: ObjectId, is_cancelled: bool, ) -> MetaResult { - let inner = self.inner.write().await; + let mut inner = self.inner.write().await; let txn = inner.db.begin().await?; let cnt = Object::find_by_id(job_id).count(&txn).await?; @@ -473,6 +473,23 @@ impl CatalogController { if let Some(source_id) = associated_source_id { Object::delete_by_id(source_id).exec(&txn).await?; } + + for tx in inner + .creating_table_finish_notifier + .remove(&job_id) + .into_iter() + .flatten() + { + let err = if is_cancelled { + MetaError::cancelled(format!("stremaing job {job_id} is cancelled")) + } else { + MetaError::catalog_id_not_found( + "stream job", + format!("streaming job {job_id} failed"), + ) + }; + let _ = tx.send(Err(err)); + } txn.commit().await?; Ok(true) @@ -625,9 +642,9 @@ impl CatalogController { pub async fn finish_streaming_job( &self, job_id: ObjectId, - replace_table_job_info: Option<(crate::manager::StreamingJob, Vec, u32)>, - ) -> MetaResult { - let inner = self.inner.write().await; + replace_table_job_info: Option, + ) -> MetaResult<()> { + let mut inner = self.inner.write().await; let txn = inner.db.begin().await?; let job_type = 
Object::find_by_id(job_id) @@ -756,7 +773,12 @@ impl CatalogController { let fragment_mapping = get_fragment_mappings(&txn, job_id).await?; let replace_table_mapping_update = match replace_table_job_info { - Some((streaming_job, merge_updates, dummy_id)) => { + Some(ReplaceTablePlan { + streaming_job, + merge_updates, + dummy_id, + .. + }) => { let incoming_sink_id = job_id; let (relations, fragment_mapping) = Self::finish_replace_streaming_job_inner( @@ -797,8 +819,13 @@ impl CatalogController { ) .await; } + if let Some(txs) = inner.creating_table_finish_notifier.remove(&job_id) { + for tx in txs { + let _ = tx.send(Ok(version)); + } + } - Ok(version) + Ok(()) } pub async fn finish_replace_streaming_job( diff --git a/src/meta/src/manager/catalog/database.rs b/src/meta/src/manager/catalog/database.rs index 299065fe9586..e3b2e0d02cae 100644 --- a/src/meta/src/manager/catalog/database.rs +++ b/src/meta/src/manager/catalog/database.rs @@ -26,13 +26,14 @@ use risingwave_pb::catalog::{ }; use risingwave_pb::data::DataType; use risingwave_pb::user::grant_privilege::PbObject; +use tokio::sync::oneshot::Sender; use super::utils::{get_refed_secret_ids_from_sink, get_refed_secret_ids_from_source}; use super::{ ConnectionId, DatabaseId, FunctionId, RelationId, SchemaId, SecretId, SinkId, SourceId, SubscriptionId, ViewId, }; -use crate::manager::{IndexId, MetaSrvEnv, TableId, UserId}; +use crate::manager::{IndexId, MetaSrvEnv, NotificationVersion, TableId, UserId}; use crate::model::MetadataModel; use crate::{MetaError, MetaResult}; @@ -95,6 +96,13 @@ pub struct DatabaseManager { pub(super) in_progress_creation_streaming_job: HashMap, // In-progress creating tables, including internal tables. pub(super) in_progress_creating_tables: HashMap, + + /// Registered finish notifiers for creating tables. + /// + /// `DdlController` will update this map, and pass the `tx` side to `CatalogController`. + /// On notifying, we can remove the entry from this map. 
+ pub creating_table_finish_notifier: + HashMap>>>, } impl DatabaseManager { @@ -187,6 +195,7 @@ impl DatabaseManager { in_progress_creation_tracker: HashSet::default(), in_progress_creation_streaming_job: HashMap::default(), in_progress_creating_tables: HashMap::default(), + creating_table_finish_notifier: Default::default(), }) } @@ -567,6 +576,16 @@ impl DatabaseManager { pub fn unmark_creating_streaming_job(&mut self, table_id: TableId) { self.in_progress_creation_streaming_job.remove(&table_id); + for tx in self + .creating_table_finish_notifier + .remove(&table_id) + .into_iter() + .flatten() + { + let _ = tx.send(Err(MetaError::cancelled(format!( + "streaing_job {table_id} has been cancelled" + )))); + } } pub fn find_creating_streaming_job_id(&self, key: &RelationKey) -> Option { diff --git a/src/meta/src/manager/catalog/mod.rs b/src/meta/src/manager/catalog/mod.rs index b1afc612bcea..02cc9ee8de0b 100644 --- a/src/meta/src/manager/catalog/mod.rs +++ b/src/meta/src/manager/catalog/mod.rs @@ -19,6 +19,7 @@ mod utils; use std::collections::{BTreeMap, HashMap, HashSet, VecDeque}; use std::iter; +use std::mem::take; use std::sync::Arc; use anyhow::{anyhow, Context}; @@ -44,6 +45,7 @@ use risingwave_pb::meta::subscribe_response::{Info, Operation}; use risingwave_pb::user::grant_privilege::{Action, ActionWithGrantOption, Object}; use risingwave_pb::user::update_user_request::UpdateField; use risingwave_pb::user::{GrantPrivilege, UserInfo}; +use tokio::sync::oneshot::Sender; use tokio::sync::{Mutex, MutexGuard}; use user::*; @@ -165,6 +167,89 @@ impl CatalogManagerCore { let user = UserManager::new(env.clone(), &database).await?; Ok(Self { database, user }) } + + pub(crate) fn register_finish_notifier( + &mut self, + id: TableId, + sender: Sender>, + ) { + self.database + .creating_table_finish_notifier + .entry(id) + .or_default() + .push(sender); + } + + pub(crate) fn streaming_job_is_finished(&mut self, job: &StreamingJob) -> MetaResult { + fn gen_err(job: &StreamingJob, name: &String) -> MetaError { + MetaError::catalog_id_not_found( + job.job_type_str(), + format!("{} may have been dropped/cancelled", name), + ) + } + let (job_status, name) = match job { + StreamingJob::MaterializedView(table) | StreamingJob::Table(_, table, _) => ( + self.database + .tables + .get(&table.id) + .map(|table| table.stream_job_status), + &table.name, + ), + StreamingJob::Sink(sink, _) => ( + self.database + .sinks + .get(&sink.id) + .map(|sink| sink.stream_job_status), + &sink.name, + ), + StreamingJob::Index(index, _) => ( + self.database + .indexes + .get(&index.id) + .map(|index| index.stream_job_status), + &index.name, + ), + StreamingJob::Source(source) => { + return Ok(self.database.sources.contains_key(&source.id)); + } + }; + + job_status + .map(|status| status == StreamJobStatus::Created as i32) + .or_else(|| { + if self + .database + .in_progress_creation_streaming_job + .contains_key(&job.id()) + { + Some(false) + } else { + None + } + }) + .ok_or_else(|| gen_err(job, name)) + } + + pub(crate) fn notify_finish(&mut self, id: TableId, version: NotificationVersion) { + for tx in self + .database + .creating_table_finish_notifier + .remove(&id) + .into_iter() + .flatten() + { + let _ = tx.send(Ok(version)); + } + } + + pub(crate) fn notify_finish_failed(&mut self, err: &MetaError) { + for tx in take(&mut self.database.creating_table_finish_notifier) + .into_values() + .flatten() + { + let _ = tx.send(Err(err.clone())); + } + } } impl CatalogManager { @@ -182,6 +267,10 @@ impl CatalogManager { 
Ok(()) } + pub async fn current_notification_version(&self) -> NotificationVersion { + self.env.notification_manager().current_version().await + } + pub async fn get_catalog_core_guard(&self) -> MutexGuard<'_, CatalogManagerCore> { self.core.lock().await } @@ -1173,18 +1262,21 @@ impl CatalogManager { &self, mut stream_job: StreamingJob, internal_tables: Vec
, - ) -> MetaResult { + ) -> MetaResult<()> { // 1. finish procedure. let mut creating_internal_table_ids = internal_tables.iter().map(|t| t.id).collect_vec(); // Update the corresponding 'created_at' field. stream_job.mark_created(); - let version = match stream_job { + let (version, table_id) = match stream_job { StreamingJob::MaterializedView(table) => { creating_internal_table_ids.push(table.id); - self.finish_create_materialized_view_procedure(internal_tables, table) - .await? + let table_id = table.id; + let version = self + .finish_create_materialized_view_procedure(internal_tables, table) + .await?; + (version, table_id) } StreamingJob::Sink(sink, target_table) => { let sink_id = sink.id; @@ -1199,26 +1291,34 @@ impl CatalogManager { .await?; } - version + (version, sink_id) } StreamingJob::Table(source, table, ..) => { creating_internal_table_ids.push(table.id); - if let Some(source) = source { + let table_id = table.id; + let version = if let Some(source) = source { self.finish_create_table_procedure_with_source(source, table, internal_tables) .await? } else { self.finish_create_table_procedure(internal_tables, table) .await? - } + }; + (version, table_id) } StreamingJob::Index(index, table) => { creating_internal_table_ids.push(table.id); - self.finish_create_index_procedure(internal_tables, index, table) - .await? + let table_id = table.id; + let version = self + .finish_create_index_procedure(internal_tables, index, table) + .await?; + (version, table_id) } StreamingJob::Source(source) => { - self.finish_create_source_procedure(source, internal_tables) - .await? + let table_id = source.id; + let version = self + .finish_create_source_procedure(source, internal_tables) + .await?; + (version, table_id) } }; @@ -1226,7 +1326,10 @@ impl CatalogManager { self.unmark_creating_tables(&creating_internal_table_ids, false) .await; - Ok(version) + // 3. notify create streaming job finish + self.core.lock().await.notify_finish(table_id, version); + + Ok(()) } /// This is used for `CREATE TABLE`. @@ -1384,6 +1487,18 @@ impl CatalogManager { } } + for tx in core + .database + .creating_table_finish_notifier + .remove(&table_id) + .into_iter() + .flatten() + { + let _ = tx.send(Err(MetaError::cancelled(format!( + "materialized view {table_id} has been cancelled" + )))); + } + // FIXME(kwannoel): Propagate version to fe let _version = self .notify_frontend( @@ -3627,7 +3742,7 @@ impl CatalogManager { } } - /// This is used for `ALTER TABLE ADD/DROP COLUMN`. + /// This is used for `ALTER TABLE ADD/DROP COLUMN` and `SINK INTO TABLE`. 
pub async fn finish_replace_table_procedure( &self, source: &Option, diff --git a/src/meta/src/manager/metadata.rs b/src/meta/src/manager/metadata.rs index 76e266e2ddd2..ecd9d4971d2b 100644 --- a/src/meta/src/manager/metadata.rs +++ b/src/meta/src/manager/metadata.rs @@ -16,9 +16,10 @@ use std::collections::{BTreeMap, HashMap, HashSet}; use std::pin::pin; use std::time::Duration; +use anyhow::anyhow; use futures::future::{select, Either}; use risingwave_common::catalog::{TableId, TableOption}; -use risingwave_meta_model_v2::SourceId; +use risingwave_meta_model_v2::{ObjectId, SourceId}; use risingwave_pb::catalog::{PbSource, PbTable}; use risingwave_pb::common::worker_node::{PbResource, State}; use risingwave_pb::common::{HostAddress, PbWorkerNode, PbWorkerType, WorkerNode, WorkerType}; @@ -27,6 +28,7 @@ use risingwave_pb::meta::table_fragments::{ActorStatus, Fragment, PbFragment}; use risingwave_pb::stream_plan::{PbDispatchStrategy, StreamActor}; use risingwave_pb::stream_service::BuildActorInfo; use tokio::sync::mpsc::{unbounded_channel, UnboundedReceiver}; +use tokio::sync::oneshot; use tokio::time::sleep; use tracing::warn; @@ -35,14 +37,14 @@ use crate::controller::catalog::CatalogControllerRef; use crate::controller::cluster::{ClusterControllerRef, WorkerExtraInfo}; use crate::manager::{ CatalogManagerRef, ClusterManagerRef, FragmentManagerRef, LocalNotification, - StreamingClusterInfo, WorkerId, + NotificationVersion, StreamingClusterInfo, StreamingJob, WorkerId, }; use crate::model::{ ActorId, ClusterId, FragmentId, MetadataModel, TableFragments, TableParallelism, }; use crate::stream::{to_build_actor_info, SplitAssignment}; use crate::telemetry::MetaTelemetryJobDesc; -use crate::MetaResult; +use crate::{MetaError, MetaResult}; #[derive(Clone)] pub enum MetadataManager { @@ -842,3 +844,59 @@ impl MetadataManager { } } } + +impl MetadataManager { + pub(crate) async fn wait_streaming_job_finished( + &self, + job: &StreamingJob, + ) -> MetaResult { + match self { + MetadataManager::V1(mgr) => mgr.wait_streaming_job_finished(job).await, + MetadataManager::V2(mgr) => mgr.wait_streaming_job_finished(job.id() as _).await, + } + } +} + +impl MetadataManagerV2 { + pub(crate) async fn wait_streaming_job_finished( + &self, + id: ObjectId, + ) -> MetaResult { + let mut mgr = self.catalog_controller.get_inner_write_guard().await; + if mgr.streaming_job_is_finished(id).await? { + return Ok(self.catalog_controller.current_notification_version().await); + } + let (tx, rx) = oneshot::channel(); + + mgr.register_finish_notifier(id, tx); + drop(mgr); + rx.await.map_err(|e| anyhow!(e))? + } + + pub(crate) async fn notify_finish_failed(&self, err: &MetaError) { + let mut mgr = self.catalog_controller.get_inner_write_guard().await; + mgr.notify_finish_failed(err); + } +} + +impl MetadataManagerV1 { + pub(crate) async fn wait_streaming_job_finished( + &self, + job: &StreamingJob, + ) -> MetaResult { + let mut mgr = self.catalog_manager.get_catalog_core_guard().await; + if mgr.streaming_job_is_finished(job)? { + return Ok(self.catalog_manager.current_notification_version().await); + } + let (tx, rx) = oneshot::channel(); + + mgr.register_finish_notifier(job.id(), tx); + drop(mgr); + rx.await.map_err(|e| anyhow!(e))? 
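// NOTE: both backends above follow the same "check, then register a waiter, all under one
// write guard" pattern, which closes the race where the job finishes between the status
// check and the registration of the oneshot sender. A condensed sketch of the idea, with
// illustrative names only (JobState / register_waiter are not real meta-service APIs):
//
//     async fn wait_finished(state: &tokio::sync::RwLock<JobState>) -> anyhow::Result<u64> {
//         let mut guard = state.write().await;
//         if guard.is_finished() {
//             return Ok(guard.current_notification_version());
//         }
//         let (tx, rx) = tokio::sync::oneshot::channel();
//         guard.register_waiter(tx);   // later resolved by finish / notify_finish_failed
//         drop(guard);                 // release the lock before awaiting
//         rx.await?                    // Result<u64, _> sent by the notifier side
//     }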
+ } + + pub(crate) async fn notify_finish_failed(&self, err: &MetaError) { + let mut mgr = self.catalog_manager.get_catalog_core_guard().await; + mgr.notify_finish_failed(err); + } +} diff --git a/src/meta/src/manager/streaming_job.rs b/src/meta/src/manager/streaming_job.rs index 67150ce8351d..5b6826de5421 100644 --- a/src/meta/src/manager/streaming_job.rs +++ b/src/meta/src/manager/streaming_job.rs @@ -241,6 +241,16 @@ impl StreamingJob { } } + pub fn job_type_str(&self) -> &'static str { + match self { + StreamingJob::MaterializedView(_) => "materialized view", + StreamingJob::Sink(_, _) => "sink", + StreamingJob::Table(_, _, _) => "table", + StreamingJob::Index(_, _) => "index", + StreamingJob::Source(_) => "source", + } + } + pub fn definition(&self) -> String { match self { Self::MaterializedView(table) => table.definition.clone(), diff --git a/src/meta/src/rpc/ddl_controller.rs b/src/meta/src/rpc/ddl_controller.rs index 4c75e751b848..2d9a1acc4fea 100644 --- a/src/meta/src/rpc/ddl_controller.rs +++ b/src/meta/src/rpc/ddl_controller.rs @@ -30,7 +30,7 @@ use risingwave_common::util::epoch::Epoch; use risingwave_common::util::stream_graph_visitor::{ visit_fragment, visit_stream_node, visit_stream_node_cont_mut, }; -use risingwave_common::{bail, current_cluster_version}; +use risingwave_common::{bail, current_cluster_version, must_match}; use risingwave_connector::error::ConnectorError; use risingwave_connector::source::cdc::CdcSourceType; use risingwave_connector::source::{ @@ -905,20 +905,22 @@ impl DdlController { None => None, }; + let stream_job_clone_for_err_handle = stream_job.clone(); + // 4. Build and persist stream job. let result: MetaResult<_> = try { tracing::debug!(id = stream_job.id(), "building stream job"); let (ctx, table_fragments) = self .build_stream_job( stream_ctx, - &stream_job, + stream_job, fragment_graph, affected_table_replace_info, ) .await?; // Do some type-specific work for each type of stream job. - match stream_job { + match &ctx.streaming_job { StreamingJob::Table(None, ref table, TableJobType::SharedCdcSource) => { Self::validate_cdc_table(table, &table_fragments) .await @@ -928,16 +930,7 @@ impl DdlController { // Register the source on the connector node. self.source_manager.register_source(source).await?; } - StreamingJob::Sink(ref sink, ref mut target_table) => { - // When sinking into table occurs, some variables of the target table may be modified, - // such as `fragment_id` being altered by `prepare_replace_table`. - // At this point, it’s necessary to update the table info carried with the sink. - if let Some((StreamingJob::Table(source, table, _), ..)) = - &ctx.replace_table_job_info - { - *target_table = Some((table.clone(), source.clone())); - } - + StreamingJob::Sink(ref sink, _) => { // Validate the sink on the connector node. validate_sink(sink).await?; } @@ -954,6 +947,7 @@ impl DdlController { let (ctx, table_fragments) = match result { Ok(r) => r, Err(e) => { + let stream_job = stream_job_clone_for_err_handle; tracing::error!(error = %e.as_report(), id = stream_job.id(), "failed to create streaming job"); self.cancel_stream_job(&stream_job, internal_tables, Some(&e)) .await?; @@ -961,14 +955,13 @@ impl DdlController { } }; - match (create_type, &stream_job) { + match (create_type, &ctx.streaming_job) { (CreateType::Foreground, _) | (CreateType::Unspecified, _) // FIXME(kwannoel): Unify background stream's creation path with MV below. 
| (CreateType::Background, &StreamingJob::Sink(_, _)) => { self.create_streaming_job_inner( mgr, - stream_job, table_fragments, ctx, internal_tables, @@ -978,12 +971,11 @@ impl DdlController { (CreateType::Background, &StreamingJob::MaterializedView(_)) => { let ctrl = self.clone(); let mgr = mgr.clone(); - let stream_job_id = stream_job.id(); + let stream_job_id = ctx.streaming_job.id(); let fut = async move { let result = ctrl .create_streaming_job_inner( &mgr, - stream_job, table_fragments, ctx, internal_tables, @@ -1002,7 +994,7 @@ impl DdlController { Ok(IGNORED_NOTIFICATION_VERSION) } (CreateType::Background, _) => { - let d: StreamingJobDiscriminants = stream_job.into(); + let d: StreamingJobDiscriminants = ctx.streaming_job.into(); bail!("background_ddl not supported for: {:?}", d) } } @@ -1267,15 +1259,15 @@ impl DdlController { async fn create_streaming_job_inner( &self, mgr: &MetadataManagerV1, - stream_job: StreamingJob, table_fragments: TableFragments, ctx: CreateStreamingJobContext, internal_tables: Vec
, ) -> MetaResult { + let stream_job = ctx.streaming_job.clone(); let job_id = stream_job.id(); tracing::debug!(id = job_id, "creating stream job"); - let result: MetaResult<()> = try { + let result: MetaResult = try { // Add table fragments to meta store with state: `State::Initial`. mgr.fragment_manager .start_create_table_fragments(table_fragments.clone()) @@ -1286,44 +1278,41 @@ impl DdlController { .await? }; - if let Err(e) = result { - match stream_job.create_type() { - CreateType::Background => { - tracing::error!(id = job_id, error = %e.as_report(), "finish stream job failed"); - let should_cancel = match mgr - .fragment_manager - .select_table_fragments_by_table_id(&job_id.into()) - .await - { - Err(err) => err.is_fragment_not_found(), - Ok(table_fragments) => table_fragments.is_initial(), - }; - if should_cancel { - // If the table fragments are not found or in initial state, it means that the stream job has not been created. - // We need to cancel the stream job. + match result { + Err(e) => { + match stream_job.create_type() { + CreateType::Background => { + tracing::error!(id = job_id, error = %e.as_report(), "finish stream job failed"); + let should_cancel = match mgr + .fragment_manager + .select_table_fragments_by_table_id(&job_id.into()) + .await + { + Err(err) => err.is_fragment_not_found(), + Ok(table_fragments) => table_fragments.is_initial(), + }; + if should_cancel { + // If the table fragments are not found or in initial state, it means that the stream job has not been created. + // We need to cancel the stream job. + self.cancel_stream_job(&stream_job, internal_tables, Some(&e)) + .await?; + } else { + // NOTE: This assumes that we will trigger recovery, + // and recover stream job progress. + } + } + _ => { self.cancel_stream_job(&stream_job, internal_tables, Some(&e)) .await?; - } else { - // NOTE: This assumes that we will trigger recovery, - // and recover stream job progress. } } - _ => { - self.cancel_stream_job(&stream_job, internal_tables, Some(&e)) - .await?; - } + Err(e) } - return Err(e); - }; - - tracing::debug!(id = job_id, "finishing stream job"); - let version = mgr - .catalog_manager - .finish_stream_job(stream_job, internal_tables) - .await?; - tracing::debug!(id = job_id, "finished stream job"); - - Ok(version) + Ok(version) => { + tracing::info!(id = job_id, "finish stream job succeeded"); + Ok(version) + } + } } async fn drop_streaming_job( @@ -1527,7 +1516,7 @@ impl DdlController { pub(crate) async fn build_stream_job( &self, stream_ctx: StreamContext, - stream_job: &StreamingJob, + mut stream_job: StreamingJob, fragment_graph: StreamFragmentGraph, affected_table_replace_info: Option<(StreamingJob, StreamFragmentGraph)>, ) -> MetaResult<(CreateStreamingJobContext, TableFragments)> { @@ -1557,7 +1546,7 @@ impl DdlController { let complete_graph = CompleteStreamFragmentGraph::with_upstreams( fragment_graph, upstream_root_fragments, - stream_job.into(), + (&stream_job).into(), )?; // 2. Build the actor graph. 
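// NOTE: a minimal, self-contained sketch (stand-in types, not the real meta structs) of the
// ownership change above: `build_stream_job` now consumes the `StreamingJob` and parks it in
// the returned context, so later stages read `ctx.streaming_job` instead of a separately
// threaded `&StreamingJob` reference.
mod ownership_sketch {
    #[derive(Clone, Debug)]
    pub struct StreamingJob(pub String); // assumption: placeholder for the real enum

    pub struct CreateStreamingJobContext {
        pub streaming_job: StreamingJob,
        // ...other one-time build artifacts elided
    }

    pub fn build_stream_job(mut job: StreamingJob) -> CreateStreamingJobContext {
        // The builder may rewrite parts of the job (e.g. attach a sink's target table)
        // before ownership moves into the context.
        job.0.push_str(" (built)");
        CreateStreamingJobContext { streaming_job: job }
    }

    pub fn later_stage(ctx: &CreateStreamingJobContext) -> &StreamingJob {
        // Downstream code (validation, finish, error reporting) reads from the context.
        &ctx.streaming_job
    }
}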
@@ -1575,7 +1564,7 @@ impl DdlController { dispatchers, merge_updates, } = actor_graph_builder - .generate_graph(&self.env, stream_job, expr_context) + .generate_graph(&self.env, &stream_job, expr_context) .await?; assert!(merge_updates.is_empty()); @@ -1600,7 +1589,7 @@ impl DdlController { let replace_table_job_info = match affected_table_replace_info { Some((streaming_job, fragment_graph)) => { - let StreamingJob::Sink(s, _) = stream_job else { + let StreamingJob::Sink(s, target_table) = &mut stream_job else { bail!("additional replace table event only occurs when sinking into table"); }; @@ -1637,6 +1626,13 @@ impl DdlController { fragment_graph, ) .await?; + // When sinking into table occurs, some variables of the target table may be modified, + // such as `fragment_id` being altered by `prepare_replace_table`. + // At this point, it’s necessary to update the table info carried with the sink. + must_match!(&streaming_job, StreamingJob::Table(source, table, _) => { + // The StreamingJob in ReplaceTableInfo must be StreamingJob::Table + *target_table = Some((table.clone(), source.clone())); + }); Some((streaming_job, context, table_fragments)) } @@ -1652,7 +1648,8 @@ impl DdlController { definition: stream_job.definition(), mv_table_id: stream_job.mv_table(), create_type: stream_job.create_type(), - ddl_type: stream_job.into(), + ddl_type: (&stream_job).into(), + streaming_job: stream_job, replace_table_job_info, option: CreateStreamingJobOption {}, }; @@ -1661,7 +1658,7 @@ impl DdlController { let creating_tables = ctx .internal_tables() .into_iter() - .chain(stream_job.table().cloned()) + .chain(ctx.streaming_job.table().cloned()) .collect_vec(); if let MetadataManager::V1(mgr) = &self.metadata_manager { @@ -2002,6 +1999,8 @@ impl DdlController { dispatchers, building_locations, existing_locations, + streaming_job: stream_job.clone(), + dummy_id: dummy_table_id, }; Ok((ctx, table_fragments)) diff --git a/src/meta/src/rpc/ddl_controller_v2.rs b/src/meta/src/rpc/ddl_controller_v2.rs index 518d6e7b3eaf..d8308703ddd4 100644 --- a/src/meta/src/rpc/ddl_controller_v2.rs +++ b/src/meta/src/rpc/ddl_controller_v2.rs @@ -81,12 +81,20 @@ impl DdlController { .unwrap(); let _reschedule_job_lock = self.stream_manager.reschedule_lock_read_guard().await; + let id = streaming_job.id(); + let name = streaming_job.name(); + let definition = streaming_job.definition(); + let source_id = match &streaming_job { + StreamingJob::Table(Some(src), _, _) | StreamingJob::Source(src) => Some(src.id), + _ => None, + }; + // create streaming job. 
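// NOTE: because `streaming_job` is moved into the inner call below, the fields needed for
// failure reporting (id, name, definition, optional source id) are captured first. This is
// the usual "copy the cheap bits out before the move" pattern; schematically (names here
// are illustrative only):
//
//     let id = job.id();
//     let name = job.name();
//     let definition = job.definition();
//     match consume(job) {                       // `job` is moved here
//         Ok(version) => Ok(version),
//         Err(err) => {
//             report_create_failure(id, &name, &definition, &err); // uses the captured copies
//             Err(err)
//         }
//     }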
match self .create_streaming_job_inner_v2( mgr, ctx, - &mut streaming_job, + streaming_job, fragment_graph, affected_table_replace_info, ) @@ -96,9 +104,9 @@ impl DdlController { Err(err) => { tracing::error!(id = job_id, error = %err.as_report(), "failed to create streaming job"); let event = risingwave_pb::meta::event_log::EventCreateStreamJobFail { - id: streaming_job.id(), - name: streaming_job.name(), - definition: streaming_job.definition(), + id, + name, + definition, error: err.as_report().to_string(), }; self.env.event_log_manager_ref().add_event_logs(vec![ @@ -110,11 +118,10 @@ impl DdlController { .await?; if aborted { tracing::warn!(id = job_id, "aborted streaming job"); - match &streaming_job { - StreamingJob::Table(Some(src), _, _) | StreamingJob::Source(src) => { - self.source_manager.unregister_sources(vec![src.id]).await; - } - _ => {} + if let Some(source_id) = source_id { + self.source_manager + .unregister_sources(vec![source_id]) + .await; } } Err(err) @@ -126,12 +133,12 @@ impl DdlController { &self, mgr: &MetadataManagerV2, ctx: StreamContext, - streaming_job: &mut StreamingJob, + mut streaming_job: StreamingJob, fragment_graph: StreamFragmentGraphProto, affected_table_replace_info: Option, ) -> MetaResult { let mut fragment_graph = - StreamFragmentGraph::new(&self.env, fragment_graph, streaming_job).await?; + StreamFragmentGraph::new(&self.env, fragment_graph, &streaming_job).await?; streaming_job.set_table_fragment_id(fragment_graph.table_fragment_id()); streaming_job.set_dml_fragment_id(fragment_graph.dml_fragment_id()); @@ -173,6 +180,8 @@ impl DdlController { ) .await?; + let streaming_job = &ctx.streaming_job; + match streaming_job { StreamingJob::Table(None, table, TableJobType::SharedCdcSource) => { Self::validate_cdc_table(table, &table_fragments).await?; @@ -181,13 +190,7 @@ impl DdlController { // Register the source on the connector node. self.source_manager.register_source(source).await?; } - StreamingJob::Sink(sink, target_table) => { - if let Some((StreamingJob::Table(source, table, _), ..)) = - &ctx.replace_table_job_info - { - *target_table = Some((table.clone(), source.clone())); - } - + StreamingJob::Sink(sink, _) => { // Validate the sink on the connector node. validate_sink(sink).await?; } @@ -204,50 +207,25 @@ impl DdlController { // create streaming jobs. let stream_job_id = streaming_job.id(); - match (streaming_job.create_type(), streaming_job) { + match (streaming_job.create_type(), &streaming_job) { (CreateType::Unspecified, _) | (CreateType::Foreground, _) // FIXME(kwannoel): Unify background stream's creation path with MV below. 
| (CreateType::Background, StreamingJob::Sink(_, _)) => { - let replace_table_job_info = ctx.replace_table_job_info.as_ref().map( - |(streaming_job, ctx, table_fragments)| { - ( - streaming_job.clone(), - ctx.merge_updates.clone(), - table_fragments.table_id().table_id(), - ) - }, - ); - - self.stream_manager + let version = self.stream_manager .create_streaming_job(table_fragments, ctx) .await?; - - let version = mgr - .catalog_controller - .finish_streaming_job(stream_job_id as _, replace_table_job_info) - .await?; - Ok(version) } (CreateType::Background, _) => { let ctrl = self.clone(); - let mgr = mgr.clone(); let fut = async move { - let result = ctrl + let _ = ctrl .stream_manager .create_streaming_job(table_fragments, ctx) .await.inspect_err(|err| { tracing::error!(id = stream_job_id, error = ?err.as_report(), "failed to create background streaming job"); }); - if result.is_ok() { - let _ = mgr - .catalog_controller - .finish_streaming_job(stream_job_id as _, None) - .await.inspect_err(|err| { - tracing::error!(id = stream_job_id, error = ?err.as_report(), "failed to finish background streaming job"); - }); - } }; tokio::spawn(fut); Ok(IGNORED_NOTIFICATION_VERSION) diff --git a/src/meta/src/stream/stream_manager.rs b/src/meta/src/stream/stream_manager.rs index c2301a9cf3db..2756d71a8a6c 100644 --- a/src/meta/src/stream/stream_manager.rs +++ b/src/meta/src/stream/stream_manager.rs @@ -17,6 +17,7 @@ use std::sync::Arc; use futures::future::join_all; use itertools::Itertools; +use risingwave_common::bail; use risingwave_common::catalog::TableId; use risingwave_meta_model_v2::ObjectId; use risingwave_pb::catalog::{CreateType, Subscription, Table}; @@ -29,7 +30,7 @@ use tracing::Instrument; use super::{Locations, RescheduleOptions, ScaleControllerRef, TableResizePolicy}; use crate::barrier::{BarrierScheduler, Command, ReplaceTablePlan, StreamRpcManager}; -use crate::manager::{DdlType, MetaSrvEnv, MetadataManager, StreamingJob}; +use crate::manager::{DdlType, MetaSrvEnv, MetadataManager, NotificationVersion, StreamingJob}; use crate::model::{ActorId, FragmentId, MetadataModel, TableFragments, TableParallelism}; use crate::stream::{to_build_actor_info, SourceManagerRef}; use crate::{MetaError, MetaResult}; @@ -44,7 +45,6 @@ pub struct CreateStreamingJobOption { /// [`CreateStreamingJobContext`] carries one-time infos for creating a streaming job. /// /// Note: for better readability, keep this struct complete and immutable once created. -#[cfg_attr(test, derive(Default))] pub struct CreateStreamingJobContext { /// New dispatchers to add from upstream actors to downstream actors. pub dispatchers: HashMap>, @@ -76,6 +76,8 @@ pub struct CreateStreamingJobContext { pub replace_table_job_info: Option<(StreamingJob, ReplaceTableContext, TableFragments)>, pub option: CreateStreamingJobOption, + + pub streaming_job: StreamingJob, } impl CreateStreamingJobContext { @@ -88,7 +90,7 @@ pub enum CreatingState { Failed { reason: MetaError }, // sender is used to notify the canceling result. Canceling { finish_tx: oneshot::Sender<()> }, - Created, + Created { version: NotificationVersion }, } struct StreamingJobExecution { @@ -174,6 +176,10 @@ pub struct ReplaceTableContext { /// The locations of the existing actors, essentially the downstream chain actors to update. pub existing_locations: Locations, + + pub streaming_job: StreamingJob, + + pub dummy_id: u32, } /// `GlobalStreamManager` manages all the streams in the system. 
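// NOTE: a condensed, self-contained sketch (simplified types, illustrative only) of the new
// contract above: `create_streaming_job` now resolves to a catalog `NotificationVersion`.
// The spawned worker reports `CreatingState::Created { version }` only after the first
// barrier is collected and the metadata manager confirms the job is finished; the caller
// loops on the channel and returns that version (or the failure reason).
mod creating_state_sketch {
    use tokio::sync::mpsc;

    type NotificationVersion = u64; // assumption: stand-in for the real version type

    pub enum CreatingState {
        Failed { reason: String },
        Created { version: NotificationVersion },
    }

    pub async fn await_creation(
        mut rx: mpsc::Receiver<CreatingState>,
    ) -> Result<NotificationVersion, String> {
        while let Some(state) = rx.recv().await {
            match state {
                CreatingState::Failed { reason } => return Err(reason),
                CreatingState::Created { version } => return Ok(version),
            }
        }
        Err("worker dropped before reporting a notification version".to_string())
    }
}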
@@ -228,7 +234,7 @@ impl GlobalStreamManager { self: &Arc, table_fragments: TableFragments, ctx: CreateStreamingJobContext, - ) -> MetaResult<()> { + ) -> MetaResult { let table_id = table_fragments.table_id(); let (sender, mut receiver) = tokio::sync::mpsc::channel(10); let execution = StreamingJobExecution::new(table_id, sender.clone()); @@ -237,12 +243,12 @@ impl GlobalStreamManager { let stream_manager = self.clone(); let fut = async move { let res = stream_manager - .create_streaming_job_impl( table_fragments, ctx) + .create_streaming_job_impl(table_fragments, ctx) .await; match res { - Ok(_) => { + Ok(version) => { let _ = sender - .send(CreatingState::Created) + .send(CreatingState::Created { version }) .await .inspect_err(|_| tracing::warn!("failed to notify created: {table_id}")); } @@ -261,72 +267,66 @@ impl GlobalStreamManager { .in_current_span(); tokio::spawn(fut); - let res = try { - while let Some(state) = receiver.recv().await { - match state { - CreatingState::Failed { reason } => { - tracing::debug!(id=?table_id, "stream job failed"); - self.creating_job_info.delete_job(table_id).await; - return Err(reason); - } - CreatingState::Canceling { finish_tx } => { - tracing::debug!(id=?table_id, "cancelling streaming job"); - if let Ok(table_fragments) = self - .metadata_manager - .get_job_fragments_by_id(&table_id) - .await - { - // try to cancel buffered creating command. - if self.barrier_scheduler.try_cancel_scheduled_create(table_id) { - tracing::debug!( - "cancelling streaming job {table_id} in buffer queue." - ); - let node_actors = table_fragments.worker_actor_ids(); - let cluster_info = - self.metadata_manager.get_streaming_cluster_info().await?; - self.stream_rpc_manager - .drop_actors( - &cluster_info.worker_nodes, - node_actors.into_iter(), - ) - .await?; - - if let MetadataManager::V1(mgr) = &self.metadata_manager { - mgr.fragment_manager - .drop_table_fragments_vec(&HashSet::from_iter( - std::iter::once(table_id), - )) - .await?; - } - } else if !table_fragments.is_created() { - tracing::debug!( - "cancelling streaming job {table_id} by issue cancel command." - ); - - self.barrier_scheduler - .run_command(Command::CancelStreamingJob(table_fragments)) + while let Some(state) = receiver.recv().await { + match state { + CreatingState::Failed { reason } => { + tracing::debug!(id=?table_id, "stream job failed"); + // FIXME(kwannoel): For creating stream jobs + // we need to clean up the resources in the stream manager. + self.creating_job_info.delete_job(table_id).await; + return Err(reason); + } + CreatingState::Canceling { finish_tx } => { + tracing::debug!(id=?table_id, "cancelling streaming job"); + if let Ok(table_fragments) = self + .metadata_manager + .get_job_fragments_by_id(&table_id) + .await + { + // try to cancel buffered creating command. + if self.barrier_scheduler.try_cancel_scheduled_create(table_id) { + tracing::debug!("cancelling streaming job {table_id} in buffer queue."); + let node_actors = table_fragments.worker_actor_ids(); + let cluster_info = + self.metadata_manager.get_streaming_cluster_info().await?; + self.stream_rpc_manager + .drop_actors(&cluster_info.worker_nodes, node_actors.into_iter()) + .await?; + + if let MetadataManager::V1(mgr) = &self.metadata_manager { + mgr.fragment_manager + .drop_table_fragments_vec(&HashSet::from_iter(std::iter::once( + table_id, + ))) .await?; - } else { - // streaming job is already completed. 
- continue; } - let _ = finish_tx.send(()).inspect_err(|_| { - tracing::warn!("failed to notify cancelled: {table_id}") - }); - self.creating_job_info.delete_job(table_id).await; - return Err(MetaError::cancelled("create")); + } else if !table_fragments.is_created() { + tracing::debug!( + "cancelling streaming job {table_id} by issue cancel command." + ); + + self.barrier_scheduler + .run_command(Command::CancelStreamingJob(table_fragments)) + .await?; + } else { + // streaming job is already completed. + continue; } - } - CreatingState::Created => { + let _ = finish_tx.send(()).inspect_err(|_| { + tracing::warn!("failed to notify cancelled: {table_id}") + }); self.creating_job_info.delete_job(table_id).await; - return Ok(()); + return Err(MetaError::cancelled("create")); } } + CreatingState::Created { version } => { + self.creating_job_info.delete_job(table_id).await; + return Ok(version); + } } - }; - + } self.creating_job_info.delete_job(table_id).await; - res + bail!("receiver failed to get notification version for finished stream job") } async fn build_actors( @@ -393,6 +393,7 @@ impl GlobalStreamManager { &self, table_fragments: TableFragments, CreateStreamingJobContext { + streaming_job, dispatchers, upstream_root_actors, building_locations, @@ -401,9 +402,10 @@ impl GlobalStreamManager { create_type, ddl_type, replace_table_job_info, + internal_tables, .. }: CreateStreamingJobContext, - ) -> MetaResult<()> { + ) -> MetaResult { let mut replace_table_command = None; let mut replace_table_id = None; @@ -446,15 +448,17 @@ impl GlobalStreamManager { let init_split_assignment = self.source_manager.allocate_splits(&dummy_table_id).await?; + replace_table_id = Some(dummy_table_id); + replace_table_command = Some(ReplaceTablePlan { old_table_fragments: context.old_table_fragments, new_table_fragments: table_fragments, merge_updates: context.merge_updates, dispatchers: context.dispatchers, init_split_assignment, + streaming_job, + dummy_id: dummy_table_id.table_id, }); - - replace_table_id = Some(dummy_table_id); } let table_id = table_fragments.table_id(); @@ -476,28 +480,38 @@ impl GlobalStreamManager { dispatchers, init_split_assignment, definition: definition.to_string(), + streaming_job: streaming_job.clone(), + internal_tables: internal_tables.into_values().collect_vec(), ddl_type, replace_table: replace_table_command, create_type, }; tracing::debug!("sending Command::CreateStreamingJob"); - if let Err(err) = self.barrier_scheduler.run_command(command).await { - if create_type == CreateType::Foreground || err.is_cancelled() { - let mut table_ids = HashSet::from_iter(std::iter::once(table_id)); - if let Some(dummy_table_id) = replace_table_id { - table_ids.insert(dummy_table_id); - } - if let MetadataManager::V1(mgr) = &self.metadata_manager { - mgr.fragment_manager - .drop_table_fragments_vec(&table_ids) - .await?; + let result: MetaResult = try { + self.barrier_scheduler.run_command(command).await?; + tracing::debug!("first barrier collected for stream job"); + self.metadata_manager + .wait_streaming_job_finished(&streaming_job) + .await? 
+ }; + match result { + Err(err) => { + if create_type == CreateType::Foreground || err.is_cancelled() { + let mut table_ids = HashSet::from_iter(std::iter::once(table_id)); + if let Some(dummy_table_id) = replace_table_id { + table_ids.insert(dummy_table_id); + } + if let MetadataManager::V1(mgr) = &self.metadata_manager { + mgr.fragment_manager + .drop_table_fragments_vec(&table_ids) + .await?; + } } - } - return Err(err); + Err(err) + } + Ok(version) => Ok(version), } - - Ok(()) } pub async fn replace_table( @@ -509,6 +523,8 @@ impl GlobalStreamManager { dispatchers, building_locations, existing_locations, + dummy_id, + streaming_job, }: ReplaceTableContext, ) -> MetaResult<()> { self.build_actors( @@ -530,6 +546,8 @@ impl GlobalStreamManager { merge_updates, dispatchers, init_split_assignment, + dummy_id, + streaming_job, })) .await && let MetadataManager::V1(mgr) = &self.metadata_manager @@ -1136,7 +1154,17 @@ mod tests { ); let ctx = CreateStreamingJobContext { building_locations: locations, - ..Default::default() + streaming_job: StreamingJob::MaterializedView(table.clone()), + mv_table_id: Some(table_fragments.table_id().table_id), + dispatchers: Default::default(), + upstream_root_actors: Default::default(), + internal_tables: Default::default(), + existing_locations: Default::default(), + definition: "".to_string(), + create_type: Default::default(), + ddl_type: Default::default(), + replace_table_job_info: None, + option: Default::default(), }; self.catalog_manager From cd35372b0844a42060b0ddb098fc028768a5e083 Mon Sep 17 00:00:00 2001 From: Noel Kwan <47273164+kwannoel@users.noreply.github.com> Date: Tue, 16 Jul 2024 19:19:47 +0800 Subject: [PATCH 20/70] chore(ci): bump backfill test timeout (#17701) --- ci/workflows/main-cron.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ci/workflows/main-cron.yml b/ci/workflows/main-cron.yml index 50cd3cc7f45d..21f26827b4f0 100644 --- a/ci/workflows/main-cron.yml +++ b/ci/workflows/main-cron.yml @@ -714,7 +714,7 @@ steps: config: ci/docker-compose.yml mount-buildkite-agent: true - ./ci/plugins/upload-failure-logs - timeout_in_minutes: 22 + timeout_in_minutes: 24 retry: *auto-retry - label: "e2e standalone binary test" From 5a56574cd9a602d3ade78b53a5f61c96684b1334 Mon Sep 17 00:00:00 2001 From: Noel Kwan <47273164+kwannoel@users.noreply.github.com> Date: Tue, 16 Jul 2024 19:45:03 +0800 Subject: [PATCH 21/70] feat(meta): support drop creating materialized views for v2 backend (#17503) --- Makefile.toml | 2 +- ci/scripts/run-e2e-test.sh | 2 +- ci/scripts/single-node-utils.sh | 6 + e2e_test/ddl/drop/drop_creating_mv.slt | 76 ++++++ .../src/worker_manager/worker_node_manager.rs | 19 +- src/frontend/src/catalog/schema_catalog.rs | 5 +- src/meta/src/barrier/mod.rs | 30 +-- src/meta/src/controller/catalog.rs | 242 ++++++++---------- src/meta/src/controller/streaming_job.rs | 142 ++++++++-- src/meta/src/controller/utils.rs | 79 +++++- src/meta/src/manager/streaming_job.rs | 4 +- 11 files changed, 412 insertions(+), 195 deletions(-) create mode 100644 e2e_test/ddl/drop/drop_creating_mv.slt diff --git a/Makefile.toml b/Makefile.toml index aa3d38baf59b..16b5e6393e9a 100644 --- a/Makefile.toml +++ b/Makefile.toml @@ -581,7 +581,7 @@ set -euo pipefail RC_ENV_FILE="${PREFIX_CONFIG}/risedev-env" if [ ! -f "${RC_ENV_FILE}" ]; then - echo "risedev-env file not found. Did you start cluster using $(tput setaf 4)\`./risedev d\`$(tput sgr0) or $(tput setaf 4)\`./risedev p\`$(tput sgr0)?" + echo "risedev-env file not found at ${RC_ENV_FILE}. 
Did you start cluster using $(tput setaf 4)\`./risedev d\`$(tput sgr0) or $(tput setaf 4)\`./risedev p\`$(tput sgr0)?" exit 1 fi ''' diff --git a/ci/scripts/run-e2e-test.sh b/ci/scripts/run-e2e-test.sh index 4736f4aa53a8..c1c912723267 100755 --- a/ci/scripts/run-e2e-test.sh +++ b/ci/scripts/run-e2e-test.sh @@ -88,7 +88,7 @@ cluster_stop echo "--- e2e, $mode, batch" RUST_LOG="info,risingwave_stream=info,risingwave_batch=info,risingwave_storage=info" \ cluster_start -sqllogictest -p 4566 -d dev './e2e_test/ddl/**/*.slt' --junit "batch-ddl-${profile}" +sqllogictest -p 4566 -d dev './e2e_test/ddl/**/*.slt' --junit "batch-ddl-${profile}" --label "can-use-recover" if [[ "$mode" != "single-node" ]]; then sqllogictest -p 4566 -d dev './e2e_test/background_ddl/basic.slt' --junit "batch-ddl-${profile}" fi diff --git a/ci/scripts/single-node-utils.sh b/ci/scripts/single-node-utils.sh index 8f7240d3d906..4b19b2444f84 100755 --- a/ci/scripts/single-node-utils.sh +++ b/ci/scripts/single-node-utils.sh @@ -12,6 +12,12 @@ export PREFIX_LOG=$RW_PREFIX/log start_single_node() { mkdir -p "$HOME/.risingwave/state_store" mkdir -p "$HOME/.risingwave/meta_store" + mkdir -p .risingwave/config + cat < .risingwave/config/risedev-env +RW_META_ADDR="http://127.0.0.1:5690" +RISEDEV_RW_FRONTEND_LISTEN_ADDRESS="127.0.0.1" +RISEDEV_RW_FRONTEND_PORT="4566" +EOF RUST_BACKTRACE=1 "$PREFIX_BIN"/risingwave >"$1" 2>&1 } diff --git a/e2e_test/ddl/drop/drop_creating_mv.slt b/e2e_test/ddl/drop/drop_creating_mv.slt new file mode 100644 index 000000000000..621ac216d4ec --- /dev/null +++ b/e2e_test/ddl/drop/drop_creating_mv.slt @@ -0,0 +1,76 @@ +statement ok +create table t(v1 int); + +statement ok +insert into t select * from generate_series(1, 10000); + +statement ok +flush; + +statement ok +set streaming_rate_limit=1; + +############## Test drop foreground mv +onlyif can-use-recover +system ok +risedev psql -c 'create materialized view m1 as select * from t;' & + +onlyif can-use-recover +sleep 5s + +onlyif can-use-recover +statement ok +drop materialized view m1; + +############## Test drop background mv BEFORE recovery +statement ok +set background_ddl=true; + +onlyif can-use-recover +statement ok +create materialized view m1 as select * from t; + +onlyif can-use-recover +sleep 5s + +onlyif can-use-recover +statement ok +drop materialized view m1; + +############## Test drop background mv AFTER recovery +statement ok +set background_ddl=true; + +onlyif can-use-recover +statement ok +create materialized view m1 as select * from t; + +onlyif can-use-recover +sleep 5s + +onlyif can-use-recover +statement ok +recover; + +onlyif can-use-recover +sleep 10s + +onlyif can-use-recover +statement ok +drop materialized view m1; + +############## Make sure the mv can still be successfully created later. 
+statement ok +set streaming_rate_limit=default; + +statement ok +set background_ddl=false; + +statement ok +create materialized view m1 as select * from t; + +statement ok +drop materialized view m1; + +statement ok +drop table t; \ No newline at end of file diff --git a/src/batch/src/worker_manager/worker_node_manager.rs b/src/batch/src/worker_manager/worker_node_manager.rs index b1e08517954d..1e7b96b5b8bb 100644 --- a/src/batch/src/worker_manager/worker_node_manager.rs +++ b/src/batch/src/worker_manager/worker_node_manager.rs @@ -18,6 +18,7 @@ use std::time::Duration; use rand::seq::SliceRandom; use risingwave_common::bail; +use risingwave_common::catalog::OBJECT_ID_PLACEHOLDER; use risingwave_common::hash::{WorkerSlotId, WorkerSlotMapping}; use risingwave_common::vnode_mapping::vnode_placement::place_vnode; use risingwave_pb::common::{WorkerNode, WorkerType}; @@ -213,10 +214,20 @@ impl WorkerNodeManager { pub fn remove_streaming_fragment_mapping(&self, fragment_id: &FragmentId) { let mut guard = self.inner.write().unwrap(); - guard - .streaming_fragment_vnode_mapping - .remove(fragment_id) - .unwrap(); + + let res = guard.streaming_fragment_vnode_mapping.remove(fragment_id); + match &res { + Some(_) => {} + None if OBJECT_ID_PLACEHOLDER == *fragment_id => { + // Do nothing for placeholder fragment. + } + None => { + panic!( + "Streaming vnode mapping not found for fragment_id: {}", + fragment_id + ) + } + }; } /// Returns fragment's vnode mapping for serving. diff --git a/src/frontend/src/catalog/schema_catalog.rs b/src/frontend/src/catalog/schema_catalog.rs index 61ec11e144dc..0394da2a70f8 100644 --- a/src/frontend/src/catalog/schema_catalog.rs +++ b/src/frontend/src/catalog/schema_catalog.rs @@ -168,10 +168,7 @@ impl SchemaCatalog { pub fn create_index(&mut self, prost: &PbIndex) { let name = prost.name.clone(); let id = prost.id.into(); - - let index_table = self - .get_created_table_by_id(&prost.index_table_id.into()) - .unwrap(); + let index_table = self.get_table_by_id(&prost.index_table_id.into()).unwrap(); let primary_table = self .get_created_table_by_id(&prost.primary_table_id.into()) .unwrap(); diff --git a/src/meta/src/barrier/mod.rs b/src/meta/src/barrier/mod.rs index df1a34544e6c..1a792b5ebfab 100644 --- a/src/meta/src/barrier/mod.rs +++ b/src/meta/src/barrier/mod.rs @@ -830,24 +830,20 @@ impl GlobalBarrierManager { .await; self.checkpoint_control.clear_on_err(&err).await; - if self.enable_recovery { - self.context - .set_status(BarrierManagerStatus::Recovering(RecoveryReason::Adhoc)); - let latest_snapshot = self.context.hummock_manager.latest_snapshot(); - let prev_epoch = TracedEpoch::new(latest_snapshot.committed_epoch.into()); // we can only recover from the committed epoch - let span = tracing::info_span!( - "adhoc_recovery", - error = %err.as_report(), - prev_epoch = prev_epoch.value().0 - ); + self.context + .set_status(BarrierManagerStatus::Recovering(RecoveryReason::Adhoc)); + let latest_snapshot = self.context.hummock_manager.latest_snapshot(); + let prev_epoch = TracedEpoch::new(latest_snapshot.committed_epoch.into()); // we can only recover from the committed epoch + let span = tracing::info_span!( + "adhoc_recovery", + error = %err.as_report(), + prev_epoch = prev_epoch.value().0 + ); - // No need to clean dirty tables for barrier recovery, - // The foreground stream job should cleanup their own tables. 
- self.recovery(None).instrument(span).await; - self.context.set_status(BarrierManagerStatus::Running); - } else { - panic!("failed to execute barrier: {}", err.as_report()); - } + // No need to clean dirty tables for barrier recovery, + // The foreground stream job should cleanup their own tables. + self.recovery(None).instrument(span).await; + self.context.set_status(BarrierManagerStatus::Running); } } diff --git a/src/meta/src/controller/catalog.rs b/src/meta/src/controller/catalog.rs index 6fbf7cb0ac04..cb5dc2bf41b7 100644 --- a/src/meta/src/controller/catalog.rs +++ b/src/meta/src/controller/catalog.rs @@ -64,7 +64,7 @@ use tokio::sync::{RwLock, RwLockReadGuard, RwLockWriteGuard}; use super::utils::{check_subscription_name_duplicate, get_fragment_ids_by_jobs}; use crate::controller::rename::{alter_relation_rename, alter_relation_rename_refs}; use crate::controller::utils::{ - check_connection_name_duplicate, check_database_name_duplicate, + build_relation_group, check_connection_name_duplicate, check_database_name_duplicate, check_function_signature_duplicate, check_relation_name_duplicate, check_schema_name_duplicate, check_secret_name_duplicate, ensure_object_id, ensure_object_not_refer, ensure_schema_empty, ensure_user_id, get_referring_objects, get_referring_objects_cascade, get_user_privilege, @@ -648,10 +648,15 @@ impl CatalogController { let inner = self.inner.write().await; let txn = inner.db.begin().await?; - let creating_jobs: Vec<(ObjectId, ObjectType)> = streaming_job::Entity::find() + let mut dirty_objs: Vec = streaming_job::Entity::find() .select_only() .column(streaming_job::Column::JobId) - .column(object::Column::ObjType) + .columns([ + object::Column::Oid, + object::Column::ObjType, + object::Column::SchemaId, + object::Column::DatabaseId, + ]) .join(JoinType::InnerJoin, streaming_job::Relation::Object.def()) .filter( streaming_job::Column::JobStatus.eq(JobStatus::Initial).or( @@ -660,13 +665,13 @@ impl CatalogController { .and(streaming_job::Column::CreateType.eq(CreateType::Foreground)), ), ) - .into_tuple() + .into_partial_model() .all(&txn) .await?; let changed = Self::clean_dirty_sink_downstreams(&txn).await?; - if creating_jobs.is_empty() { + if dirty_objs.is_empty() { if changed { txn.commit().await?; } @@ -674,22 +679,95 @@ impl CatalogController { return Ok(ReleaseContext::default()); } + self.log_cleaned_dirty_jobs(&dirty_objs, &txn).await?; + + let dirty_job_ids = dirty_objs.iter().map(|obj| obj.oid).collect::>(); + + let associated_source_ids: Vec = Table::find() + .select_only() + .column(table::Column::OptionalAssociatedSourceId) + .filter( + table::Column::TableId + .is_in(dirty_job_ids.clone()) + .and(table::Column::OptionalAssociatedSourceId.is_not_null()), + ) + .into_tuple() + .all(&txn) + .await?; + let dirty_source_objs: Vec = Object::find() + .filter(object::Column::Oid.is_in(associated_source_ids.clone())) + .into_partial_model() + .all(&txn) + .await?; + dirty_objs.extend(dirty_source_objs); + + let mut dirty_state_table_ids = vec![]; + let to_drop_internal_table_objs: Vec = Object::find() + .select_only() + .columns([ + object::Column::Oid, + object::Column::ObjType, + object::Column::SchemaId, + object::Column::DatabaseId, + ]) + .join(JoinType::InnerJoin, object::Relation::Table.def()) + .filter(table::Column::BelongsToJobId.is_in(dirty_job_ids.clone())) + .into_partial_model() + .all(&txn) + .await?; + dirty_state_table_ids.extend(to_drop_internal_table_objs.iter().map(|obj| obj.oid)); + dirty_objs.extend(to_drop_internal_table_objs); + + 
let to_delete_objs: HashSet = dirty_job_ids + .clone() + .into_iter() + .chain(dirty_state_table_ids.clone().into_iter()) + .chain(associated_source_ids.clone().into_iter()) + .collect(); + + let res = Object::delete_many() + .filter(object::Column::Oid.is_in(to_delete_objs)) + .exec(&txn) + .await?; + assert!(res.rows_affected > 0); + + txn.commit().await?; + + let relation_group = build_relation_group(dirty_objs); + + let _version = self + .notify_frontend(NotificationOperation::Delete, relation_group) + .await; + + Ok(ReleaseContext { + state_table_ids: dirty_state_table_ids, + source_ids: associated_source_ids, + ..Default::default() + }) + } + + async fn log_cleaned_dirty_jobs( + &self, + dirty_objs: &[PartialObject], + txn: &DatabaseTransaction, + ) -> MetaResult<()> { // Record cleaned streaming jobs in event logs. - let mut creating_table_ids = vec![]; - let mut creating_source_ids = vec![]; - let mut creating_sink_ids = vec![]; - let mut creating_job_ids = vec![]; - for (job_id, job_type) in creating_jobs { - creating_job_ids.push(job_id); + let mut dirty_table_ids = vec![]; + let mut dirty_source_ids = vec![]; + let mut dirty_sink_ids = vec![]; + for dirty_job_obj in dirty_objs { + let job_id = dirty_job_obj.oid; + let job_type = dirty_job_obj.obj_type; match job_type { - ObjectType::Table | ObjectType::Index => creating_table_ids.push(job_id), - ObjectType::Source => creating_source_ids.push(job_id), - ObjectType::Sink => creating_sink_ids.push(job_id), + ObjectType::Table | ObjectType::Index => dirty_table_ids.push(job_id), + ObjectType::Source => dirty_source_ids.push(job_id), + ObjectType::Sink => dirty_sink_ids.push(job_id), _ => unreachable!("unexpected streaming job type"), } } + let mut event_logs = vec![]; - if !creating_table_ids.is_empty() { + if !dirty_table_ids.is_empty() { let table_info: Vec<(TableId, String, String)> = Table::find() .select_only() .columns([ @@ -697,9 +775,9 @@ impl CatalogController { table::Column::Name, table::Column::Definition, ]) - .filter(table::Column::TableId.is_in(creating_table_ids)) + .filter(table::Column::TableId.is_in(dirty_table_ids)) .into_tuple() - .all(&txn) + .all(txn) .await?; for (table_id, name, definition) in table_info { let event = risingwave_pb::meta::event_log::EventDirtyStreamJobClear { @@ -713,7 +791,7 @@ impl CatalogController { )); } } - if !creating_source_ids.is_empty() { + if !dirty_source_ids.is_empty() { let source_info: Vec<(SourceId, String, String)> = Source::find() .select_only() .columns([ @@ -721,9 +799,9 @@ impl CatalogController { source::Column::Name, source::Column::Definition, ]) - .filter(source::Column::SourceId.is_in(creating_source_ids)) + .filter(source::Column::SourceId.is_in(dirty_source_ids)) .into_tuple() - .all(&txn) + .all(txn) .await?; for (source_id, name, definition) in source_info { let event = risingwave_pb::meta::event_log::EventDirtyStreamJobClear { @@ -737,7 +815,7 @@ impl CatalogController { )); } } - if !creating_sink_ids.is_empty() { + if !dirty_sink_ids.is_empty() { let sink_info: Vec<(SinkId, String, String)> = Sink::find() .select_only() .columns([ @@ -745,9 +823,9 @@ impl CatalogController { sink::Column::Name, sink::Column::Definition, ]) - .filter(sink::Column::SinkId.is_in(creating_sink_ids)) + .filter(sink::Column::SinkId.is_in(dirty_sink_ids)) .into_tuple() - .all(&txn) + .all(txn) .await?; for (sink_id, name, definition) in sink_info { let event = risingwave_pb::meta::event_log::EventDirtyStreamJobClear { @@ -761,53 +839,8 @@ impl CatalogController { )); } } - - let 
associated_source_ids: Vec = Table::find() - .select_only() - .column(table::Column::OptionalAssociatedSourceId) - .filter( - table::Column::TableId - .is_in(creating_job_ids.clone()) - .and(table::Column::OptionalAssociatedSourceId.is_not_null()), - ) - .into_tuple() - .all(&txn) - .await?; - - let state_table_ids: Vec = Table::find() - .select_only() - .column(table::Column::TableId) - .filter( - table::Column::BelongsToJobId - .is_in(creating_job_ids.clone()) - .or(table::Column::TableId.is_in(creating_job_ids.clone())), - ) - .into_tuple() - .all(&txn) - .await?; - - let to_delete_objs: HashSet = creating_job_ids - .clone() - .into_iter() - .chain(state_table_ids.clone().into_iter()) - .chain(associated_source_ids.clone().into_iter()) - .collect(); - - let res = Object::delete_many() - .filter(object::Column::Oid.is_in(to_delete_objs)) - .exec(&txn) - .await?; - assert!(res.rows_affected > 0); - - txn.commit().await?; - self.env.event_log_manager_ref().add_event_logs(event_logs); - - Ok(ReleaseContext { - state_table_ids, - source_ids: associated_source_ids, - ..Default::default() - }) + Ok(()) } async fn clean_dirty_sink_downstreams(txn: &DatabaseTransaction) -> MetaResult { @@ -2134,75 +2167,10 @@ impl CatalogController { // notify about them. self.notify_users_update(user_infos).await; - let mut relations = vec![]; - for obj in to_drop_objects { - match obj.obj_type { - ObjectType::Table => relations.push(PbRelation { - relation_info: Some(PbRelationInfo::Table(PbTable { - id: obj.oid as _, - schema_id: obj.schema_id.unwrap() as _, - database_id: obj.database_id.unwrap() as _, - ..Default::default() - })), - }), - ObjectType::Source => relations.push(PbRelation { - relation_info: Some(PbRelationInfo::Source(PbSource { - id: obj.oid as _, - schema_id: obj.schema_id.unwrap() as _, - database_id: obj.database_id.unwrap() as _, - ..Default::default() - })), - }), - ObjectType::Sink => relations.push(PbRelation { - relation_info: Some(PbRelationInfo::Sink(PbSink { - id: obj.oid as _, - schema_id: obj.schema_id.unwrap() as _, - database_id: obj.database_id.unwrap() as _, - ..Default::default() - })), - }), - ObjectType::Subscription => relations.push(PbRelation { - relation_info: Some(PbRelationInfo::Subscription(PbSubscription { - id: obj.oid as _, - schema_id: obj.schema_id.unwrap() as _, - database_id: obj.database_id.unwrap() as _, - ..Default::default() - })), - }), - ObjectType::View => relations.push(PbRelation { - relation_info: Some(PbRelationInfo::View(PbView { - id: obj.oid as _, - schema_id: obj.schema_id.unwrap() as _, - database_id: obj.database_id.unwrap() as _, - ..Default::default() - })), - }), - ObjectType::Index => { - relations.push(PbRelation { - relation_info: Some(PbRelationInfo::Index(PbIndex { - id: obj.oid as _, - schema_id: obj.schema_id.unwrap() as _, - database_id: obj.database_id.unwrap() as _, - ..Default::default() - })), - }); - relations.push(PbRelation { - relation_info: Some(PbRelationInfo::Table(PbTable { - id: obj.oid as _, - schema_id: obj.schema_id.unwrap() as _, - database_id: obj.database_id.unwrap() as _, - ..Default::default() - })), - }); - } - _ => unreachable!("only relations will be dropped."), - } - } + let relation_group = build_relation_group(to_drop_objects); + let version = self - .notify_frontend( - NotificationOperation::Delete, - NotificationInfo::RelationGroup(PbRelationGroup { relations }), - ) + .notify_frontend(NotificationOperation::Delete, relation_group) .await; let fragment_mappings = fragment_ids diff --git 
a/src/meta/src/controller/streaming_job.rs b/src/meta/src/controller/streaming_job.rs index 2c1332fe65cc..199861220bfe 100644 --- a/src/meta/src/controller/streaming_job.rs +++ b/src/meta/src/controller/streaming_job.rs @@ -37,14 +37,14 @@ use risingwave_meta_model_v2::{ use risingwave_pb::catalog::source::PbOptionalAssociatedTableId; use risingwave_pb::catalog::table::{PbOptionalAssociatedSourceId, PbTableVersion}; use risingwave_pb::catalog::{PbCreateType, PbTable}; -use risingwave_pb::meta::relation::PbRelationInfo; +use risingwave_pb::meta::relation::{PbRelationInfo, RelationInfo}; use risingwave_pb::meta::subscribe_response::{ - Info as NotificationInfo, Operation as NotificationOperation, Operation, + Info as NotificationInfo, Info, Operation as NotificationOperation, Operation, }; use risingwave_pb::meta::table_fragments::PbActorStatus; use risingwave_pb::meta::{ FragmentWorkerSlotMapping, PbFragmentWorkerSlotMapping, PbRelation, PbRelationGroup, - PbTableFragments, Relation, + PbTableFragments, Relation, RelationGroup, }; use risingwave_pb::source::{PbConnectorSplit, PbConnectorSplits}; use risingwave_pb::stream_plan::stream_fragment_graph::Parallelism; @@ -65,8 +65,9 @@ use crate::barrier::{ReplaceTablePlan, Reschedule}; use crate::controller::catalog::CatalogController; use crate::controller::rename::ReplaceTableExprRewriter; use crate::controller::utils::{ - check_relation_name_duplicate, check_sink_into_table_cycle, ensure_object_id, ensure_user_id, - get_fragment_actor_ids, get_fragment_mappings, get_parallel_unit_to_worker_map, + build_relation_group, check_relation_name_duplicate, check_sink_into_table_cycle, + ensure_object_id, ensure_user_id, get_fragment_actor_ids, get_fragment_mappings, + get_parallel_unit_to_worker_map, PartialObject, }; use crate::controller::ObjectModel; use crate::manager::{NotificationVersion, SinkId, StreamingJob}; @@ -157,6 +158,8 @@ impl CatalogController { } } + let mut relations = vec![]; + match streaming_job { StreamingJob::MaterializedView(table) => { let job_id = Self::create_streaming_job_obj( @@ -171,8 +174,12 @@ impl CatalogController { ) .await?; table.id = job_id as _; - let table: table::ActiveModel = table.clone().into(); - Table::insert(table).exec(&txn).await?; + let table_model: table::ActiveModel = table.clone().into(); + Table::insert(table_model).exec(&txn).await?; + + relations.push(Relation { + relation_info: Some(RelationInfo::Table(table.to_owned())), + }); } StreamingJob::Sink(sink, _) => { if let Some(target_table_id) = sink.target_table { @@ -202,8 +209,11 @@ impl CatalogController { ) .await?; sink.id = job_id as _; - let sink: sink::ActiveModel = sink.clone().into(); - Sink::insert(sink).exec(&txn).await?; + let sink_model: sink::ActiveModel = sink.clone().into(); + Sink::insert(sink_model).exec(&txn).await?; + relations.push(Relation { + relation_info: Some(RelationInfo::Sink(sink.to_owned())), + }); } StreamingJob::Table(src, table, _) => { let job_id = Self::create_streaming_job_obj( @@ -235,9 +245,15 @@ impl CatalogController { ); let source: source::ActiveModel = src.clone().into(); Source::insert(source).exec(&txn).await?; + relations.push(Relation { + relation_info: Some(RelationInfo::Source(src.to_owned())), + }); } - let table: table::ActiveModel = table.clone().into(); - Table::insert(table).exec(&txn).await?; + let table_model: table::ActiveModel = table.clone().into(); + Table::insert(table_model).exec(&txn).await?; + relations.push(Relation { + relation_info: Some(RelationInfo::Table(table.to_owned())), + 
}); } StreamingJob::Index(index, table) => { ensure_object_id(ObjectType::Table, index.primary_table_id as _, &txn).await?; @@ -265,10 +281,16 @@ impl CatalogController { .exec(&txn) .await?; - let table: table::ActiveModel = table.clone().into(); - Table::insert(table).exec(&txn).await?; - let index: index::ActiveModel = index.clone().into(); - Index::insert(index).exec(&txn).await?; + let table_model: table::ActiveModel = table.clone().into(); + Table::insert(table_model).exec(&txn).await?; + let index_model: index::ActiveModel = index.clone().into(); + Index::insert(index_model).exec(&txn).await?; + relations.push(Relation { + relation_info: Some(RelationInfo::Table(table.to_owned())), + }); + relations.push(Relation { + relation_info: Some(RelationInfo::Index(index.to_owned())), + }); } StreamingJob::Source(src) => { let job_id = Self::create_streaming_job_obj( @@ -283,8 +305,11 @@ impl CatalogController { ) .await?; src.id = job_id as _; - let source: source::ActiveModel = src.clone().into(); - Source::insert(source).exec(&txn).await?; + let source_model: source::ActiveModel = src.clone().into(); + Source::insert(source_model).exec(&txn).await?; + relations.push(Relation { + relation_info: Some(RelationInfo::Source(src.to_owned())), + }); } } @@ -309,18 +334,25 @@ impl CatalogController { txn.commit().await?; + let _version = self + .notify_frontend( + Operation::Add, + Info::RelationGroup(RelationGroup { relations }), + ) + .await; + Ok(()) } pub async fn create_internal_table_catalog( &self, job_id: ObjectId, - internal_tables: Vec, + mut internal_tables: Vec, ) -> MetaResult> { let inner = self.inner.write().await; let txn = inner.db.begin().await?; let mut table_id_map = HashMap::new(); - for table in internal_tables { + for table in &mut internal_tables { let table_id = Self::create_object( &txn, ObjectType::Table, @@ -331,14 +363,27 @@ impl CatalogController { .await? 
.oid; table_id_map.insert(table.id, table_id as u32); - let mut table: table::ActiveModel = table.into(); - table.table_id = Set(table_id as _); - table.belongs_to_job_id = Set(Some(job_id as _)); - table.fragment_id = NotSet; - Table::insert(table).exec(&txn).await?; + table.id = table_id as _; + let mut table_model: table::ActiveModel = table.clone().into(); + table_model.table_id = Set(table_id as _); + table_model.belongs_to_job_id = Set(Some(job_id as _)); + table_model.fragment_id = NotSet; + Table::insert(table_model).exec(&txn).await?; } txn.commit().await?; - + let _version = self + .notify_frontend( + Operation::Add, + Info::RelationGroup(RelationGroup { + relations: internal_tables + .iter() + .map(|table| Relation { + relation_info: Some(RelationInfo::Table(table.clone())), + }) + .collect(), + }), + ) + .await; Ok(table_id_map) } @@ -463,6 +508,48 @@ impl CatalogController { .one(&txn) .await?; + // Get notification info + let mut objs = vec![]; + let obj: Option = Object::find_by_id(job_id) + .select_only() + .columns([ + object::Column::Oid, + object::Column::ObjType, + object::Column::SchemaId, + object::Column::DatabaseId, + ]) + .into_partial_model() + .one(&txn) + .await?; + let obj = obj.ok_or_else(|| MetaError::catalog_id_not_found("streaming job", job_id))?; + objs.push(obj); + let internal_table_objs: Vec = Object::find() + .select_only() + .columns([ + object::Column::Oid, + object::Column::ObjType, + object::Column::SchemaId, + object::Column::DatabaseId, + ]) + .join(JoinType::InnerJoin, object::Relation::Table.def()) + .filter(table::Column::BelongsToJobId.is_in(internal_table_ids.clone())) + .into_partial_model() + .all(&txn) + .await?; + objs.extend(internal_table_objs); + if let Some(source_id) = associated_source_id { + let source_obj = Object::find_by_id(source_id) + .select_only() + .column(object::Column::ObjType) + .into_partial_model() + .one(&txn) + .await? 
+ .ok_or_else(|| MetaError::catalog_id_not_found("source", source_id))?; + objs.push(source_obj); + } + let relation_group = build_relation_group(objs); + + // Can delete objects after queried notification info Object::delete_by_id(job_id).exec(&txn).await?; if !internal_table_ids.is_empty() { Object::delete_many() @@ -492,6 +579,9 @@ impl CatalogController { } txn.commit().await?; + let _version = self + .notify_frontend(Operation::Delete, relation_group) + .await; Ok(true) } @@ -804,7 +894,7 @@ impl CatalogController { let mut version = self .notify_frontend( - NotificationOperation::Add, + NotificationOperation::Update, NotificationInfo::RelationGroup(PbRelationGroup { relations }), ) .await; diff --git a/src/meta/src/controller/utils.rs b/src/meta/src/controller/utils.rs index 835805766331..7b4fae3c0eca 100644 --- a/src/meta/src/controller/utils.rs +++ b/src/meta/src/controller/utils.rs @@ -27,8 +27,14 @@ use risingwave_meta_model_v2::{ view, worker_property, ActorId, DataTypeArray, DatabaseId, FragmentId, FragmentVnodeMapping, I32Array, ObjectId, PrivilegeId, SchemaId, SourceId, StreamNode, UserId, WorkerId, }; -use risingwave_pb::catalog::{PbConnection, PbFunction, PbSecret, PbSubscription}; -use risingwave_pb::meta::{PbFragmentParallelUnitMapping, PbFragmentWorkerSlotMapping}; +use risingwave_pb::catalog::{ + PbConnection, PbFunction, PbIndex, PbSecret, PbSink, PbSource, PbSubscription, PbTable, PbView, +}; +use risingwave_pb::meta::relation::PbRelationInfo; +use risingwave_pb::meta::subscribe_response::Info as NotificationInfo; +use risingwave_pb::meta::{ + PbFragmentParallelUnitMapping, PbFragmentWorkerSlotMapping, PbRelation, PbRelationGroup, +}; use risingwave_pb::stream_plan::stream_node::NodeBody; use risingwave_pb::stream_plan::{PbFragmentTypeFlag, PbStreamNode, StreamSource}; use risingwave_pb::user::grant_privilege::{PbAction, PbActionWithGrantOption, PbObject}; @@ -44,7 +50,6 @@ use sea_orm::{ use crate::controller::catalog::CatalogController; use crate::{MetaError, MetaResult}; - /// This function will construct a query using recursive cte to find all objects[(id, `obj_type`)] that are used by the given object. 
/// /// # Examples @@ -1026,3 +1031,71 @@ where Ok(parallel_unit_to_worker) } + +pub(crate) fn build_relation_group(relation_objects: Vec) -> NotificationInfo { + let mut relations = vec![]; + for obj in relation_objects { + match obj.obj_type { + ObjectType::Table => relations.push(PbRelation { + relation_info: Some(PbRelationInfo::Table(PbTable { + id: obj.oid as _, + schema_id: obj.schema_id.unwrap() as _, + database_id: obj.database_id.unwrap() as _, + ..Default::default() + })), + }), + ObjectType::Source => relations.push(PbRelation { + relation_info: Some(PbRelationInfo::Source(PbSource { + id: obj.oid as _, + schema_id: obj.schema_id.unwrap() as _, + database_id: obj.database_id.unwrap() as _, + ..Default::default() + })), + }), + ObjectType::Sink => relations.push(PbRelation { + relation_info: Some(PbRelationInfo::Sink(PbSink { + id: obj.oid as _, + schema_id: obj.schema_id.unwrap() as _, + database_id: obj.database_id.unwrap() as _, + ..Default::default() + })), + }), + ObjectType::Subscription => relations.push(PbRelation { + relation_info: Some(PbRelationInfo::Subscription(PbSubscription { + id: obj.oid as _, + schema_id: obj.schema_id.unwrap() as _, + database_id: obj.database_id.unwrap() as _, + ..Default::default() + })), + }), + ObjectType::View => relations.push(PbRelation { + relation_info: Some(PbRelationInfo::View(PbView { + id: obj.oid as _, + schema_id: obj.schema_id.unwrap() as _, + database_id: obj.database_id.unwrap() as _, + ..Default::default() + })), + }), + ObjectType::Index => { + relations.push(PbRelation { + relation_info: Some(PbRelationInfo::Index(PbIndex { + id: obj.oid as _, + schema_id: obj.schema_id.unwrap() as _, + database_id: obj.database_id.unwrap() as _, + ..Default::default() + })), + }); + relations.push(PbRelation { + relation_info: Some(PbRelationInfo::Table(PbTable { + id: obj.oid as _, + schema_id: obj.schema_id.unwrap() as _, + database_id: obj.database_id.unwrap() as _, + ..Default::default() + })), + }); + } + _ => unreachable!("only relations will be dropped."), + } + } + NotificationInfo::RelationGroup(PbRelationGroup { relations }) +} diff --git a/src/meta/src/manager/streaming_job.rs b/src/meta/src/manager/streaming_job.rs index 5b6826de5421..0aca9eccf897 100644 --- a/src/meta/src/manager/streaming_job.rs +++ b/src/meta/src/manager/streaming_job.rs @@ -19,7 +19,7 @@ use risingwave_common::current_cluster_version; use risingwave_common::util::epoch::Epoch; use risingwave_pb::catalog::{CreateType, Index, PbSource, Sink, Table}; use risingwave_pb::ddl_service::TableJobType; -use strum::EnumDiscriminants; +use strum::{EnumDiscriminants, EnumIs}; use super::{get_refed_secret_ids_from_sink, get_refed_secret_ids_from_source}; use crate::model::FragmentId; @@ -27,7 +27,7 @@ use crate::MetaResult; // This enum is used in order to re-use code in `DdlServiceImpl` for creating MaterializedView and // Sink. 
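Note: the only functional change to this enum in the hunk below is the extra `EnumIs` derive. As far as strum's derive feature goes, `EnumIs` generates one `is_<variant>()` predicate per variant, so call sites can test the job kind without a full `match`. A minimal sketch with a hypothetical two-variant enum (not the real `StreamingJob`, whose variants carry catalog payloads):

    use strum::EnumIs;

    // Hypothetical stand-in enum, only to show what the `EnumIs` derive adds.
    #[derive(Debug, EnumIs)]
    enum JobKind {
        MaterializedView,
        Sink,
    }

    fn main() {
        let job = JobKind::MaterializedView;
        // `EnumIs` generates `is_materialized_view()` / `is_sink()` predicates.
        assert!(job.is_materialized_view());
        assert!(!job.is_sink());
        println!("{job:?}");
    }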
-#[derive(Debug, Clone, EnumDiscriminants)] +#[derive(Debug, Clone, EnumDiscriminants, EnumIs)] pub enum StreamingJob { MaterializedView(Table), Sink(Sink, Option<(Table, Option)>), From 45f27dc8ed098b2882a838c0bc1f7922479a383f Mon Sep 17 00:00:00 2001 From: Bohan Zhang Date: Tue, 16 Jul 2024 23:03:34 +0800 Subject: [PATCH 22/70] fix: s3_v2 connector cannot read incremental files (#17702) Signed-off-by: tabVersion --- ci/workflows/main-cron.yml | 28 +++++- e2e_test/s3/fs_source_v2.py | 15 ++-- e2e_test/s3/fs_source_v2_new_file.py | 86 +++++++++++++++++++ .../src/executor/source/list_executor.rs | 59 +++++++------ 4 files changed, 153 insertions(+), 35 deletions(-) create mode 100644 e2e_test/s3/fs_source_v2_new_file.py diff --git a/ci/workflows/main-cron.yml b/ci/workflows/main-cron.yml index 21f26827b4f0..0060335d8504 100644 --- a/ci/workflows/main-cron.yml +++ b/ci/workflows/main-cron.yml @@ -78,9 +78,9 @@ steps: key: "slow-e2e-test-release" command: "ci/scripts/slow-e2e-test.sh -p ci-release -m ci-3streaming-2serving-3fe" if: | - !(build.pull_request.labels includes "ci/main-cron/run-selected") && build.env("CI_STEPS") == null - || build.pull_request.labels includes "ci/run-slow-e2e-tests" - || build.env("CI_STEPS") =~ /(^|,)slow-e2e-tests?(,|$$)/ + !(build.pull_request.labels includes "ci/main-cron/run-selected") && build.env("CI_STEPS") == null + || build.pull_request.labels includes "ci/run-slow-e2e-tests" + || build.env("CI_STEPS") =~ /(^|,)slow-e2e-tests?(,|$$)/ depends_on: - "build" - "build-other" @@ -478,6 +478,28 @@ steps: timeout_in_minutes: 25 retry: *auto-retry + - label: "S3_v2 source new file check on AWS (json)" + key: "s3-v2-source-new-file-check-aws" + command: "ci/scripts/s3-source-test.sh -p ci-release -s fs_source_v2_new_file.py" + if: | + !(build.pull_request.labels includes "ci/main-cron/run-selected") && build.env("CI_STEPS") == null + || build.pull_request.labels includes "ci/run-s3-source-tests" + || build.env("CI_STEPS") =~ /(^|,)s3-source-tests?(,|$$)/ + depends_on: build + plugins: + - seek-oss/aws-sm#v2.3.1: + env: + S3_SOURCE_TEST_CONF: ci_s3_source_test_aws + - docker-compose#v5.1.0: + run: rw-build-env + config: ci/docker-compose.yml + mount-buildkite-agent: true + environment: + - S3_SOURCE_TEST_CONF + - ./ci/plugins/upload-failure-logs + timeout_in_minutes: 25 + retry: *auto-retry + - label: "S3_v2 source check on parquet file" key: "s3-v2-source-check-parquet-file" command: "ci/scripts/s3-source-test.sh -p ci-release -s fs_parquet_source.py" diff --git a/e2e_test/s3/fs_source_v2.py b/e2e_test/s3/fs_source_v2.py index 760b8d07a09a..6706d4b6d4a9 100644 --- a/e2e_test/s3/fs_source_v2.py +++ b/e2e_test/s3/fs_source_v2.py @@ -43,7 +43,7 @@ def format_csv(data, with_header): csv_files.append(ostream.getvalue()) return csv_files -def do_test(config, file_num, item_num_per_file, prefix, fmt): +def do_test(config, file_num, item_num_per_file, prefix, fmt, need_drop_table=True): conn = psycopg2.connect( host="localhost", port="4566", @@ -106,10 +106,16 @@ def _assert_eq(field, got, expect): print('Test pass') - cur.execute(f'drop table {_table()}') + if need_drop_table: + cur.execute(f'drop table {_table()}') cur.close() conn.close() +FORMATTER = { + 'json': format_json, + 'csv_with_header': partial(format_csv, with_header=True), + 'csv_without_header': partial(format_csv, with_header=False), + } if __name__ == "__main__": FILE_NUM = 4001 @@ -117,11 +123,6 @@ def _assert_eq(field, got, expect): data = gen_data(FILE_NUM, ITEM_NUM_PER_FILE) fmt = sys.argv[1] - 
FORMATTER = { - 'json': format_json, - 'csv_with_header': partial(format_csv, with_header=True), - 'csv_without_header': partial(format_csv, with_header=False), - } assert fmt in FORMATTER, f"Unsupported format: {fmt}" formatted_files = FORMATTER[fmt](data) diff --git a/e2e_test/s3/fs_source_v2_new_file.py b/e2e_test/s3/fs_source_v2_new_file.py new file mode 100644 index 000000000000..a7ae53f3a37d --- /dev/null +++ b/e2e_test/s3/fs_source_v2_new_file.py @@ -0,0 +1,86 @@ +from fs_source_v2 import gen_data, FORMATTER, do_test +import json +import os +import random +import psycopg2 +import time +from minio import Minio + + +def upload_to_s3_bucket(config, minio_client, run_id, files, start_bias): + _local = lambda idx, start_bias: f"data_{idx + start_bias}.{fmt}" + _s3 = lambda idx, start_bias: f"{run_id}_data_{idx + start_bias}.{fmt}" + for idx, file_str in enumerate(files): + with open(_local(idx, start_bias), "w") as f: + f.write(file_str) + os.fsync(f.fileno()) + + minio_client.fput_object( + config["S3_BUCKET"], _s3(idx, start_bias), _local(idx, start_bias) + ) + + +def check_for_new_files(file_num, item_num_per_file, fmt): + conn = psycopg2.connect(host="localhost", port="4566", user="root", database="dev") + + # Open a cursor to execute SQL statements + cur = conn.cursor() + + def _table(): + return f"s3_test_{fmt}" + + total_rows = file_num * item_num_per_file + + MAX_RETRIES = 40 + for retry_no in range(MAX_RETRIES): + cur.execute(f"select count(*) from {_table()}") + result = cur.fetchone() + if result[0] == total_rows: + return True + print( + f"[retry {retry_no}] Now got {result[0]} rows in table, {total_rows} expected, wait 10s" + ) + time.sleep(10) + return False + + +if __name__ == "__main__": + FILE_NUM = 101 + ITEM_NUM_PER_FILE = 2 + data = gen_data(FILE_NUM, ITEM_NUM_PER_FILE) + fmt = "json" + + split_idx = 51 + data_batch1 = data[:split_idx] + data_batch2 = data[split_idx:] + + config = json.loads(os.environ["S3_SOURCE_TEST_CONF"]) + client = Minio( + config["S3_ENDPOINT"], + access_key=config["S3_ACCESS_KEY"], + secret_key=config["S3_SECRET_KEY"], + secure=True, + ) + run_id = str(random.randint(1000, 9999)) + print(f"S3 Source New File Test: run ID: {run_id} to bucket {config['S3_BUCKET']}") + + formatted_batch1 = FORMATTER[fmt](data_batch1) + upload_to_s3_bucket(config, client, run_id, formatted_batch1, 0) + + do_test( + config, len(data_batch1), ITEM_NUM_PER_FILE, run_id, fmt, need_drop_table=False + ) + + formatted_batch2 = FORMATTER[fmt](data_batch2) + upload_to_s3_bucket(config, client, run_id, formatted_batch2, split_idx) + + success_flag = check_for_new_files(FILE_NUM, ITEM_NUM_PER_FILE, fmt) + if success_flag: + print("Test(add new file) pass") + else: + print("Test(add new file) fail") + + _s3 = lambda idx, start_bias: f"{run_id}_data_{idx + start_bias}.{fmt}" + # clean up s3 files + for idx, _ in enumerate(data): + client.remove_object(config["S3_BUCKET"], _s3(idx, 0)) diff --git a/src/stream/src/executor/source/list_executor.rs b/src/stream/src/executor/source/list_executor.rs index 9317cfbc9aa4..25b32c0a0e4b 100644 --- a/src/stream/src/executor/source/list_executor.rs +++ b/src/stream/src/executor/source/list_executor.rs @@ -146,36 +146,45 @@ impl FsListExecutor { yield Message::Barrier(barrier); - while let Some(msg) = stream.next().await { - match msg { - Err(e) => { - tracing::warn!(error = %e.as_report(), "encountered an error, recovering"); - // todo: rebuild stream here - } - Ok(msg) => match msg { - // Barrier arrives. 
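Note: the rewrite below boils down to never treating one pass over the bucket listing as end-of-input. When the chunked paginate stream is exhausted (or fails), it is rebuilt, so objects uploaded after the previous pass are eventually listed. A minimal synchronous sketch of that re-listing idea; `list_files` is a made-up helper standing in for rebuilding the paginate stream, and the `seen` set exists only to make the sketch's output readable:

    use std::collections::HashSet;
    use std::thread::sleep;
    use std::time::Duration;

    // Made-up lister: pretend one extra object shows up on every pass.
    fn list_files(pass: u64) -> Vec<String> {
        (0..=pass).map(|i| format!("data_{i}.json")).collect()
    }

    fn main() {
        let mut seen = HashSet::new();
        for pass in 0..3u64 {
            // One exhausted "stream" per pass; the real executor rebuilds the
            // stream and keeps looping instead of returning.
            for file in list_files(pass) {
                if seen.insert(file.clone()) {
                    println!("pass {pass}: new file {file}");
                }
            }
            sleep(Duration::from_millis(10));
        }
    }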
- Either::Left(msg) => match &msg { - Message::Barrier(barrier) => { - if let Some(mutation) = barrier.mutation.as_deref() { - match mutation { - Mutation::Pause => stream.pause_stream(), - Mutation::Resume => stream.resume_stream(), - _ => (), + loop { + // a list file stream never ends, keep list to find if there is any new file. + while let Some(msg) = stream.next().await { + match msg { + Err(e) => { + tracing::warn!(error = %e.as_report(), "encountered an error, recovering"); + stream + .replace_data_stream(self.build_chunked_paginate_stream(&source_desc)?); + } + Ok(msg) => match msg { + // Barrier arrives. + Either::Left(msg) => match &msg { + Message::Barrier(barrier) => { + if let Some(mutation) = barrier.mutation.as_deref() { + match mutation { + Mutation::Pause => stream.pause_stream(), + Mutation::Resume => stream.resume_stream(), + _ => (), + } } - } - // Propagate the barrier. - yield msg; + // Propagate the barrier. + yield msg; + } + // Only barrier can be received. + _ => unreachable!(), + }, + // Chunked FsPage arrives. + Either::Right(chunk) => { + yield Message::Chunk(chunk); } - // Only barrier can be received. - _ => unreachable!(), }, - // Chunked FsPage arrives. - Either::Right(chunk) => { - yield Message::Chunk(chunk); - } - }, + } } + + stream.replace_data_stream( + self.build_chunked_paginate_stream(&source_desc) + .map_err(StreamExecutorError::from)?, + ); } } } From 46ee5674401b27ad4a725981f860ff360137faa8 Mon Sep 17 00:00:00 2001 From: Yuhao Su <31772373+yuhao-su@users.noreply.github.com> Date: Tue, 16 Jul 2024 15:32:20 -0500 Subject: [PATCH 23/70] feat(streaming): introduce nested loop temporal join executor (#16445) --- src/stream/src/executor/mod.rs | 3 +- .../src/executor/nested_loop_temporal_join.rs | 276 ++++++++++++++++++ src/stream/src/executor/temporal_join.rs | 55 ++-- 3 files changed, 308 insertions(+), 26 deletions(-) create mode 100644 src/stream/src/executor/nested_loop_temporal_join.rs diff --git a/src/stream/src/executor/mod.rs b/src/stream/src/executor/mod.rs index c111bf4ee99d..dc63e62b1d58 100644 --- a/src/stream/src/executor/mod.rs +++ b/src/stream/src/executor/mod.rs @@ -77,6 +77,7 @@ mod lookup; mod lookup_union; mod merge; mod mview; +mod nested_loop_temporal_join; mod no_op; mod now; mod over_window; @@ -142,7 +143,7 @@ pub use simple_agg::SimpleAggExecutor; pub use sink::SinkExecutor; pub use sort::*; pub use stateless_simple_agg::StatelessSimpleAggExecutor; -pub use temporal_join::*; +pub use temporal_join::TemporalJoinExecutor; pub use top_n::{ AppendOnlyGroupTopNExecutor, AppendOnlyTopNExecutor, GroupTopNExecutor, TopNExecutor, }; diff --git a/src/stream/src/executor/nested_loop_temporal_join.rs b/src/stream/src/executor/nested_loop_temporal_join.rs new file mode 100644 index 000000000000..0888d8981fc8 --- /dev/null +++ b/src/stream/src/executor/nested_loop_temporal_join.rs @@ -0,0 +1,276 @@ +// Copyright 2024 RisingWave Labs +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
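Note: in this new executor, "nested loop" means that for every incoming left-side row the whole right-side storage table is scanned at the current epoch and every right row is a match candidate; there is no equi-join key, so the optional condition is evaluated afterwards. A toy in-memory sketch of just the inner-join shape, with plain vectors instead of `StorageTable` and stream chunks:

    // Toy rows: (id, payload). Not the real row representation.
    type Row = (i64, String);

    // Inner nested-loop join: every left row is tested against every right row
    // with an arbitrary predicate instead of a hash lookup on a join key.
    fn nested_loop_inner_join(
        left: &[Row],
        right: &[Row],
        cond: impl Fn(&Row, &Row) -> bool,
    ) -> Vec<(Row, Row)> {
        let mut out = Vec::new();
        for l in left {
            for r in right {
                if cond(l, r) {
                    out.push((l.clone(), r.clone()));
                }
            }
        }
        out
    }

    fn main() {
        let left = vec![(1, "a".to_string()), (2, "b".to_string())];
        let right = vec![(10, "a".to_string()), (20, "c".to_string())];
        // Join on equal payloads, just to have a non-trivial condition.
        let joined = nested_loop_inner_join(&left, &right, |l, r| l.1 == r.1);
        assert_eq!(joined.len(), 1);
        println!("{joined:?}");
    }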
+ +use std::collections::HashMap; +use std::sync::Arc; + +use futures::StreamExt; +use futures_async_stream::try_stream; +use risingwave_common::array::stream_chunk_builder::StreamChunkBuilder; +use risingwave_common::array::StreamChunk; +use risingwave_common::bitmap::BitmapBuilder; +use risingwave_common::types::DataType; +use risingwave_common::util::iter_util::ZipEqDebug; +use risingwave_expr::expr::NonStrictExpression; +use risingwave_hummock_sdk::{HummockEpoch, HummockReadEpoch}; +use risingwave_storage::store::PrefetchOptions; +use risingwave_storage::table::batch_table::storage_table::StorageTable; +use risingwave_storage::StateStore; + +use super::join::{JoinType, JoinTypePrimitive}; +use super::temporal_join::{align_input, apply_indices_map, phase1, InternalMessage}; +use super::{Execute, ExecutorInfo, Message, StreamExecutorError}; +use crate::common::metrics::MetricsInfo; +use crate::executor::join::builder::JoinStreamChunkBuilder; +use crate::executor::monitor::StreamingMetrics; +use crate::executor::{ActorContextRef, Executor}; + +pub struct NestedLoopTemporalJoinExecutor { + ctx: ActorContextRef, + #[allow(dead_code)] + info: ExecutorInfo, + left: Executor, + right: Executor, + right_table: TemporalSide, + condition: Option, + output_indices: Vec, + chunk_size: usize, + // TODO: update metrics + #[allow(dead_code)] + metrics: Arc, +} + +struct TemporalSide { + source: StorageTable, +} + +impl TemporalSide {} + +#[try_stream(ok = StreamChunk, error = StreamExecutorError)] +#[allow(clippy::too_many_arguments)] +async fn phase1_handle_chunk( + chunk_size: usize, + right_size: usize, + full_schema: Vec, + epoch: HummockEpoch, + right_table: &mut TemporalSide, + chunk: StreamChunk, +) { + let mut builder = StreamChunkBuilder::new(chunk_size, full_schema); + + for (op, left_row) in chunk.rows() { + let mut matched = false; + #[for_await] + for keyed_row in right_table + .source + .batch_iter( + HummockReadEpoch::NoWait(epoch), + false, + PrefetchOptions::prefetch_for_large_range_scan(), + ) + .await? 
+ { + let keyed_row = keyed_row?; + let right_row = keyed_row.row(); + matched = true; + if let Some(chunk) = E::append_matched_row(op, &mut builder, left_row, right_row) { + yield chunk; + } + } + if let Some(chunk) = E::match_end(&mut builder, op, left_row, right_size, matched) { + yield chunk; + } + } + if let Some(chunk) = builder.take() { + yield chunk; + } +} + +impl NestedLoopTemporalJoinExecutor { + #[allow(clippy::too_many_arguments)] + #[expect(dead_code)] + pub fn new( + ctx: ActorContextRef, + info: ExecutorInfo, + left: Executor, + right: Executor, + table: StorageTable, + condition: Option, + output_indices: Vec, + metrics: Arc, + chunk_size: usize, + ) -> Self { + let _metrics_info = MetricsInfo::new( + metrics.clone(), + table.table_id().table_id, + ctx.id, + "nested loop temporal join", + ); + + Self { + ctx: ctx.clone(), + info, + left, + right, + right_table: TemporalSide { source: table }, + condition, + output_indices, + chunk_size, + metrics, + } + } + + #[try_stream(ok = Message, error = StreamExecutorError)] + async fn into_stream(mut self) { + let right_size = self.right.schema().len(); + + let (left_map, _right_map) = JoinStreamChunkBuilder::get_i2o_mapping( + &self.output_indices, + self.left.schema().len(), + right_size, + ); + + let left_to_output: HashMap = HashMap::from_iter(left_map.iter().cloned()); + + let mut prev_epoch = None; + + let full_schema: Vec<_> = self + .left + .schema() + .data_types() + .into_iter() + .chain(self.right.schema().data_types().into_iter()) + .collect(); + + #[for_await] + for msg in align_input::(self.left, self.right) { + match msg? { + InternalMessage::WaterMark(watermark) => { + let output_watermark_col_idx = *left_to_output.get(&watermark.col_idx).unwrap(); + yield Message::Watermark(watermark.with_idx(output_watermark_col_idx)); + } + InternalMessage::Chunk(chunk) => { + let epoch = prev_epoch.expect("Chunk data should come after some barrier."); + + let full_schema = full_schema.clone(); + + if T == JoinType::Inner { + let st1 = phase1_handle_chunk::( + self.chunk_size, + right_size, + full_schema, + epoch, + &mut self.right_table, + chunk, + ); + #[for_await] + for chunk in st1 { + let chunk = chunk?; + let new_chunk = if let Some(ref cond) = self.condition { + let (data_chunk, ops) = chunk.into_parts(); + let passed_bitmap = cond.eval_infallible(&data_chunk).await; + let passed_bitmap = + Arc::unwrap_or_clone(passed_bitmap).into_bool().to_bitmap(); + let (columns, vis) = data_chunk.into_parts(); + let new_vis = vis & passed_bitmap; + StreamChunk::with_visibility(ops, columns, new_vis) + } else { + chunk + }; + let new_chunk = apply_indices_map(new_chunk, &self.output_indices); + yield Message::Chunk(new_chunk); + } + } else if let Some(ref cond) = self.condition { + // Joined result without evaluating non-lookup conditions. 
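Note: the conditional branch that follows leans on a small protocol from the phase-1 builder: after the candidate matches for one left row, one extra row is appended that starts out invisible and acts as an end-of-match marker. Phase 2 then evaluates the non-lookup condition and keeps that marker visible only when none of the candidates passed, which is what produces the NULL-padded output row of a left outer join. A standalone sketch of just that visibility fix-up, with plain bools instead of bitmaps:

    /// One entry per phase-1 output row: `passed` is the condition result,
    /// `is_marker` flags the trailing row appended after each left row's candidates.
    fn fix_visibility(rows: &[(bool, bool)]) -> Vec<bool> {
        let mut vis = Vec::with_capacity(rows.len());
        let mut matched = 0usize;
        for &(passed, is_marker) in rows {
            if is_marker {
                // The marker becomes visible only when nothing passed.
                vis.push(matched == 0);
                matched = 0;
            } else {
                if passed {
                    matched += 1;
                }
                vis.push(passed);
            }
        }
        vis
    }

    fn main() {
        // No candidate passes, so the trailing marker is made visible.
        assert_eq!(
            fix_visibility(&[(false, false), (false, false), (false, true)]),
            vec![false, false, true]
        );
        // One candidate passes, so the marker stays hidden.
        assert_eq!(fix_visibility(&[(true, false), (false, true)]), vec![true, false]);
    }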
+ let st1 = phase1_handle_chunk::( + self.chunk_size, + right_size, + full_schema, + epoch, + &mut self.right_table, + chunk, + ); + let mut matched_count = 0usize; + #[for_await] + for chunk in st1 { + let chunk = chunk?; + let (data_chunk, ops) = chunk.into_parts(); + let passed_bitmap = cond.eval_infallible(&data_chunk).await; + let passed_bitmap = + Arc::unwrap_or_clone(passed_bitmap).into_bool().to_bitmap(); + let (columns, vis) = data_chunk.into_parts(); + let mut new_vis = BitmapBuilder::with_capacity(vis.len()); + for (passed, not_match_end) in + passed_bitmap.iter().zip_eq_debug(vis.iter()) + { + let is_match_end = !not_match_end; + let vis = if is_match_end && matched_count == 0 { + // Nothing is matched, so the marker row should be visible. + true + } else if is_match_end { + // reset the count + matched_count = 0; + // rows found, so the marker row should be invisible. + false + } else { + if passed { + matched_count += 1; + } + passed + }; + new_vis.append(vis); + } + let new_chunk = apply_indices_map( + StreamChunk::with_visibility(ops, columns, new_vis.finish()), + &self.output_indices, + ); + yield Message::Chunk(new_chunk); + } + // The last row should always be marker row, + assert_eq!(matched_count, 0); + } else { + let st1 = phase1_handle_chunk::( + self.chunk_size, + right_size, + full_schema, + epoch, + &mut self.right_table, + chunk, + ); + #[for_await] + for chunk in st1 { + let chunk = chunk?; + let new_chunk = apply_indices_map(chunk, &self.output_indices); + yield Message::Chunk(new_chunk); + } + } + } + InternalMessage::Barrier(chunk, barrier) => { + assert!(chunk.is_empty()); + if let Some(vnodes) = barrier.as_update_vnode_bitmap(self.ctx.id) { + let _vnodes = self.right_table.source.update_vnode_bitmap(vnodes.clone()); + } + prev_epoch = Some(barrier.epoch.curr); + yield Message::Barrier(barrier) + } + } + } + } +} + +impl Execute for NestedLoopTemporalJoinExecutor { + fn execute(self: Box) -> super::BoxedMessageStream { + self.into_stream().boxed() + } +} diff --git a/src/stream/src/executor/temporal_join.rs b/src/stream/src/executor/temporal_join.rs index db9f8c850ce6..8a994acb5e8d 100644 --- a/src/stream/src/executor/temporal_join.rs +++ b/src/stream/src/executor/temporal_join.rs @@ -200,7 +200,7 @@ impl TemporalSide { } } -enum InternalMessage { +pub(super) enum InternalMessage { Chunk(StreamChunk), Barrier(Vec, Barrier), WaterMark(Watermark), @@ -245,7 +245,7 @@ async fn internal_messages_until_barrier(stream: impl MessageStream, expected_ba // any number of `InternalMessage::Chunk(left_chunk)` and followed by // `InternalMessage::Barrier(right_chunks, barrier)`. #[try_stream(ok = InternalMessage, error = StreamExecutorError)] -async fn align_input(left: Executor, right: Executor) { +pub(super) async fn align_input(left: Executor, right: Executor) { let mut left = pin!(left.execute()); let mut right = pin!(right.execute()); // Keep producing intervals until stream exhaustion or errors. 
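Note: the next hunk gates the buffering of right-side chunks behind `align_input`'s compile-time flag `YIELD_RIGHT_CHUNKS`: the cached temporal join still consumes the buffered right chunks to keep its cache up to date at each barrier, while the nested-loop variant reads the right table directly and simply drops them (hence the `assert!(chunk.is_empty())` on barriers above). A small sketch of the const-generic gating pattern, with hypothetical names:

    // Buffers items only when the instantiation asks for them; with KEEP = false
    // the branch is resolved at compile time and the items are dropped.
    fn buffer_if<T, const KEEP: bool>(src: Vec<T>, buf: &mut Vec<T>) {
        for item in src {
            if KEEP {
                buf.push(item);
            }
        }
    }

    fn main() {
        let mut kept = Vec::new();
        buffer_if::<_, true>(vec![1, 2, 3], &mut kept);

        let mut ignored: Vec<i32> = Vec::new();
        buffer_if::<_, false>(vec![4, 5, 6], &mut ignored);

        assert_eq!(kept, vec![1, 2, 3]);
        assert!(ignored.is_empty());
    }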
@@ -260,12 +260,18 @@ async fn align_input(left: Executor, right: Executor) { ); match combined.next().await { Some(Either::Left(Ok(Message::Chunk(c)))) => yield InternalMessage::Chunk(c), - Some(Either::Right(Ok(Message::Chunk(c)))) => right_chunks.push(c), + Some(Either::Right(Ok(Message::Chunk(c)))) => { + if YIELD_RIGHT_CHUNKS { + right_chunks.push(c); + } + } Some(Either::Left(Ok(Message::Barrier(b)))) => { let mut remain = chunks_until_barrier(right.by_ref(), b.clone()) .try_collect() .await?; - right_chunks.append(&mut remain); + if YIELD_RIGHT_CHUNKS { + right_chunks.append(&mut remain); + } yield InternalMessage::Barrier(right_chunks, b); break 'inner; } @@ -292,7 +298,18 @@ async fn align_input(left: Executor, right: Executor) { } } -mod phase1 { +pub(super) fn apply_indices_map(chunk: StreamChunk, indices: &[usize]) -> StreamChunk { + let (data_chunk, ops) = chunk.into_parts(); + let (columns, vis) = data_chunk.into_parts(); + let output_columns = indices + .iter() + .cloned() + .map(|idx| columns[idx].clone()) + .collect(); + StreamChunk::with_visibility(ops, output_columns, vis) +} + +pub(super) mod phase1 { use std::ops::Bound; use futures::{pin_mut, StreamExt}; @@ -310,7 +327,7 @@ mod phase1 { use crate::common::table::state_table::StateTable; use crate::executor::monitor::TemporalJoinMetrics; - pub(super) trait Phase1Evaluation { + pub trait Phase1Evaluation { /// Called when a matched row is found. #[must_use = "consume chunk if produced"] fn append_matched_row( @@ -331,9 +348,9 @@ mod phase1 { ) -> Option; } - pub(super) struct Inner; - pub(super) struct LeftOuter; - pub(super) struct LeftOuterWithCond; + pub struct Inner; + pub struct LeftOuter; + pub struct LeftOuterWithCond; impl Phase1Evaluation for Inner { fn append_matched_row( @@ -635,17 +652,6 @@ impl StreamChunk { - let (data_chunk, ops) = chunk.into_parts(); - let (columns, vis) = data_chunk.into_parts(); - let output_columns = indices - .iter() - .cloned() - .map(|idx| columns[idx].clone()) - .collect(); - StreamChunk::with_visibility(ops, output_columns, vis) - } - #[try_stream(ok = Message, error = StreamExecutorError)] async fn into_stream(mut self) { let right_size = self.right.schema().len(); @@ -682,7 +688,7 @@ impl(self.left, self.right) { self.right_table.cache.evict(); self.metrics .temporal_join_cached_entry_count @@ -725,8 +731,7 @@ impl Date: Wed, 17 Jul 2024 05:10:22 +0800 Subject: [PATCH 24/70] refactor(types): doc & refine ToBinary/Text (#17697) --- src/common/src/types/datetime.rs | 42 +--------------------------- src/common/src/types/decimal.rs | 16 +---------- src/common/src/types/interval.rs | 13 --------- src/common/src/types/mod.rs | 22 +++++++++++---- src/common/src/types/serial.rs | 13 +-------- src/common/src/types/timestamptz.rs | 15 +--------- src/common/src/types/to_binary.rs | 32 +++++++++++---------- src/common/src/types/to_text.rs | 43 +++++++++++++++-------------- src/connector/src/parser/mysql.rs | 4 +-- src/expr/impl/src/scalar/cast.rs | 6 ++-- src/utils/pgwire/src/types.rs | 1 + 11 files changed, 66 insertions(+), 141 deletions(-) diff --git a/src/common/src/types/datetime.rs b/src/common/src/types/datetime.rs index bac96b6c1dea..fac104b3f5aa 100644 --- a/src/common/src/types/datetime.rs +++ b/src/common/src/types/datetime.rs @@ -20,7 +20,7 @@ use std::hash::Hash; use std::io::Write; use std::str::FromStr; -use bytes::{Bytes, BytesMut}; +use bytes::BytesMut; use chrono::{ DateTime, Datelike, Days, Duration, NaiveDate, NaiveDateTime, NaiveTime, Timelike, Weekday, }; @@ -28,7 
+28,6 @@ use postgres_types::{accepts, to_sql_checked, FromSql, IsNull, ToSql, Type}; use risingwave_common_estimate_size::ZeroHeapSize; use thiserror::Error; -use super::to_binary::ToBinary; use super::to_text::ToText; use super::{CheckedAdd, DataType, Interval}; use crate::array::{ArrayError, ArrayResult}; @@ -427,45 +426,6 @@ impl ToText for Timestamp { } } -impl ToBinary for Date { - fn to_binary_with_type(&self, ty: &DataType) -> super::to_binary::Result> { - match ty { - super::DataType::Date => { - let mut output = BytesMut::new(); - self.0.to_sql(&Type::ANY, &mut output).unwrap(); - Ok(Some(output.freeze())) - } - _ => unreachable!(), - } - } -} - -impl ToBinary for Time { - fn to_binary_with_type(&self, ty: &DataType) -> super::to_binary::Result> { - match ty { - super::DataType::Time => { - let mut output = BytesMut::new(); - self.0.to_sql(&Type::ANY, &mut output).unwrap(); - Ok(Some(output.freeze())) - } - _ => unreachable!(), - } - } -} - -impl ToBinary for Timestamp { - fn to_binary_with_type(&self, ty: &DataType) -> super::to_binary::Result> { - match ty { - super::DataType::Timestamp => { - let mut output = BytesMut::new(); - self.0.to_sql(&Type::ANY, &mut output).unwrap(); - Ok(Some(output.freeze())) - } - _ => unreachable!(), - } - } -} - impl Date { pub fn with_days(days: i32) -> Result { Ok(Date::new( diff --git a/src/common/src/types/decimal.rs b/src/common/src/types/decimal.rs index b38dbb4822e6..9523157239e2 100644 --- a/src/common/src/types/decimal.rs +++ b/src/common/src/types/decimal.rs @@ -17,7 +17,7 @@ use std::io::{Cursor, Read, Write}; use std::ops::{Add, Div, Mul, Neg, Rem, Sub}; use byteorder::{BigEndian, ReadBytesExt}; -use bytes::{BufMut, Bytes, BytesMut}; +use bytes::{BufMut, BytesMut}; use num_traits::{ CheckedAdd, CheckedDiv, CheckedMul, CheckedNeg, CheckedRem, CheckedSub, Num, One, Zero, }; @@ -26,7 +26,6 @@ use risingwave_common_estimate_size::ZeroHeapSize; use rust_decimal::prelude::FromStr; use rust_decimal::{Decimal as RustDecimal, Error, MathematicalOps as _, RoundingStrategy}; -use super::to_binary::ToBinary; use super::to_text::ToText; use super::DataType; use crate::array::ArrayResult; @@ -90,19 +89,6 @@ impl Decimal { } } -impl ToBinary for Decimal { - fn to_binary_with_type(&self, ty: &DataType) -> super::to_binary::Result> { - match ty { - DataType::Decimal => { - let mut output = BytesMut::new(); - self.to_sql(&Type::NUMERIC, &mut output).unwrap(); - Ok(Some(output.freeze())) - } - _ => unreachable!(), - } - } -} - impl ToSql for Decimal { accepts!(NUMERIC); diff --git a/src/common/src/types/interval.rs b/src/common/src/types/interval.rs index 495561561e93..d446669d4a94 100644 --- a/src/common/src/types/interval.rs +++ b/src/common/src/types/interval.rs @@ -1181,19 +1181,6 @@ impl<'a> FromSql<'a> for Interval { } } -impl ToBinary for Interval { - fn to_binary_with_type(&self, ty: &DataType) -> super::to_binary::Result> { - match ty { - DataType::Interval => { - let mut output = BytesMut::new(); - self.to_sql(&Type::ANY, &mut output).unwrap(); - Ok(Some(output.freeze())) - } - _ => unreachable!(), - } - } -} - #[derive(Debug, Clone, PartialEq, Eq, Hash)] pub enum DateTimeField { Year, diff --git a/src/common/src/types/mod.rs b/src/common/src/types/mod.rs index 91bebde846f0..ed424b4f4956 100644 --- a/src/common/src/types/mod.rs +++ b/src/common/src/types/mod.rs @@ -89,7 +89,6 @@ pub use self::serial::Serial; pub use self::struct_type::StructType; pub use self::successor::Successor; pub use self::timestamptz::*; -pub use 
self::to_binary::ToBinary; pub use self::to_text::ToText; pub use self::with_data_type::WithDataType; @@ -505,6 +504,10 @@ macro_rules! scalar_impl_enum { ($( { $variant_name:ident, $suffix_name:ident, $scalar:ty, $scalar_ref:ty } ),*) => { /// `ScalarImpl` embeds all possible scalars in the evaluation framework. /// + /// Note: `ScalarImpl` doesn't contain all information of its `DataType`, + /// so sometimes they need to be used together. + /// e.g., for `Struct`, we don't have the field names in the value. + /// /// See `for_all_variants` for the definition. #[derive(Debug, Clone, PartialEq, Eq, EstimateSize)] pub enum ScalarImpl { @@ -513,6 +516,12 @@ macro_rules! scalar_impl_enum { /// `ScalarRefImpl` embeds all possible scalar references in the evaluation /// framework. + /// + /// Note: `ScalarRefImpl` doesn't contain all information of its `DataType`, + /// so sometimes they need to be used together. + /// e.g., for `Struct`, we don't have the field names in the value. + /// + /// See `for_all_variants` for the definition. #[derive(Debug, Copy, Clone, PartialEq, Eq)] pub enum ScalarRefImpl<'scalar> { $( $variant_name($scalar_ref) ),* @@ -791,7 +800,9 @@ impl From for ScalarImpl { } impl ScalarImpl { - /// Creates a scalar from binary. + /// Creates a scalar from pgwire "BINARY" format. + /// + /// The counterpart of [`to_binary::ToBinary`]. pub fn from_binary(bytes: &Bytes, data_type: &DataType) -> Result { let res = match data_type { DataType::Varchar => Self::Utf8(String::from_sql(&Type::VARCHAR, bytes)?.into()), @@ -827,7 +838,9 @@ impl ScalarImpl { Ok(res) } - /// Creates a scalar from text. + /// Creates a scalar from pgwire "TEXT" format. + /// + /// The counterpart of [`ToText`]. pub fn from_text(s: &str, data_type: &DataType) -> Result { Ok(match data_type { DataType::Boolean => str_to_bool(s)?.into(), @@ -908,9 +921,8 @@ pub fn hash_datum(datum: impl ToDatumRef, state: &mut impl std::hash::Hasher) { } impl ScalarRefImpl<'_> { - /// Encode the scalar to postgresql binary format. 
- /// The encoder implements encoding using pub fn binary_format(&self, data_type: &DataType) -> to_binary::Result { + use self::to_binary::ToBinary; self.to_binary_with_type(data_type).transpose().unwrap() } diff --git a/src/common/src/types/serial.rs b/src/common/src/types/serial.rs index 5f4ba237ee30..c36fb1eb11ff 100644 --- a/src/common/src/types/serial.rs +++ b/src/common/src/types/serial.rs @@ -24,7 +24,7 @@ use crate::util::row_id::RowId; // Serial is an alias for i64 #[derive(Debug, Copy, Clone, PartialEq, Eq, Ord, PartialOrd, Default, Hash)] -pub struct Serial(i64); +pub struct Serial(pub(crate) i64); impl From for i64 { fn from(value: Serial) -> i64 { @@ -75,17 +75,6 @@ impl crate::types::to_text::ToText for Serial { } } -impl crate::types::to_binary::ToBinary for Serial { - fn to_binary_with_type( - &self, - _ty: &crate::types::DataType, - ) -> super::to_binary::Result> { - let mut output = bytes::BytesMut::new(); - self.0.to_sql(&Type::ANY, &mut output).unwrap(); - Ok(Some(output.freeze())) - } -} - impl ToSql for Serial { accepts!(INT8); diff --git a/src/common/src/types/timestamptz.rs b/src/common/src/types/timestamptz.rs index feafee6e212b..e5e7cb3c71fd 100644 --- a/src/common/src/types/timestamptz.rs +++ b/src/common/src/types/timestamptz.rs @@ -16,14 +16,13 @@ use std::error::Error; use std::io::Write; use std::str::FromStr; -use bytes::{Bytes, BytesMut}; +use bytes::BytesMut; use chrono::{DateTime, Datelike, TimeZone, Utc}; use chrono_tz::Tz; use postgres_types::{accepts, to_sql_checked, FromSql, IsNull, ToSql, Type}; use risingwave_common_estimate_size::ZeroHeapSize; use serde::{Deserialize, Serialize}; -use super::to_binary::ToBinary; use super::to_text::ToText; use super::DataType; use crate::array::ArrayResult; @@ -65,18 +64,6 @@ impl<'a> FromSql<'a> for Timestamptz { } } -impl ToBinary for Timestamptz { - fn to_binary_with_type(&self, _ty: &DataType) -> super::to_binary::Result> { - let instant = self.to_datetime_utc(); - let mut out = BytesMut::new(); - // postgres_types::Type::ANY is only used as a placeholder. - instant - .to_sql(&postgres_types::Type::ANY, &mut out) - .unwrap(); - Ok(Some(out.freeze())) - } -} - impl ToText for Timestamptz { fn write(&self, f: &mut W) -> std::fmt::Result { // Just a meaningful representation as placeholder. The real implementation depends diff --git a/src/common/src/types/to_binary.rs b/src/common/src/types/to_binary.rs index 5ab9fd316dca..56eea301f3f6 100644 --- a/src/common/src/types/to_binary.rs +++ b/src/common/src/types/to_binary.rs @@ -15,7 +15,10 @@ use bytes::{Bytes, BytesMut}; use postgres_types::{ToSql, Type}; -use super::{DataType, DatumRef, ScalarRefImpl, F32, F64}; +use super::{ + DataType, Date, Decimal, Interval, ScalarRefImpl, Serial, Time, Timestamp, Timestamptz, F32, + F64, +}; use crate::error::NotImplemented; /// Error type for [`ToBinary`] trait. @@ -30,14 +33,15 @@ pub enum ToBinaryError { pub type Result = std::result::Result; -// Used to convert ScalarRef to text format +/// Converts `ScalarRef` to pgwire "BINARY" format. +/// +/// [`postgres_types::ToSql`] has similar functionality, and most of our types implement +/// that trait and forward `ToBinary` to it directly. pub trait ToBinary { fn to_binary_with_type(&self, ty: &DataType) -> Result>; } - -// implement use to_sql macro_rules! implement_using_to_sql { - ($({ $scalar_type:ty, $data_type:ident, $accessor:expr } ),*) => { + ($({ $scalar_type:ty, $data_type:ident, $accessor:expr } ),* $(,)?) 
=> { $( impl ToBinary for $scalar_type { fn to_binary_with_type(&self, ty: &DataType) -> Result> { @@ -64,7 +68,14 @@ implement_using_to_sql! { { F32, Float32, |x: &F32| x.0 }, { F64, Float64, |x: &F64| x.0 }, { bool, Boolean, |x| x }, - { &[u8], Bytea, |x| x } + { &[u8], Bytea, |x| x }, + { Time, Time, |x: &Time| x.0 }, + { Date, Date, |x: &Date| x.0 }, + { Timestamp, Timestamp, |x: &Timestamp| x.0 }, + { Decimal, Decimal, |x| x }, + { Interval, Interval, |x| x }, + { Serial, Serial, |x: &Serial| x.0 }, + { Timestamptz, Timestamptz, |x: &Timestamptz| x.to_datetime_utc() } } impl ToBinary for ScalarRefImpl<'_> { @@ -94,12 +105,3 @@ impl ToBinary for ScalarRefImpl<'_> { } } } - -impl ToBinary for DatumRef<'_> { - fn to_binary_with_type(&self, ty: &DataType) -> Result> { - match self { - Some(scalar) => scalar.to_binary_with_type(ty), - None => Ok(None), - } - } -} diff --git a/src/common/src/types/to_text.rs b/src/common/src/types/to_text.rs index ca140b93c37b..166356f1977e 100644 --- a/src/common/src/types/to_text.rs +++ b/src/common/src/types/to_text.rs @@ -18,7 +18,24 @@ use std::num::FpCategory; use super::{DataType, DatumRef, ScalarRefImpl}; use crate::dispatch_scalar_ref_variants; -// Used to convert ScalarRef to text format +/// Converts `ScalarRef` to pgwire "TEXT" format. +/// +/// ## Relationship with casting to varchar +/// +/// For most types, this is also the implementation for casting to varchar, but there are exceptions. +/// e.g., The TEXT format for boolean is `t` / `f` while they cast to varchar `true` / `false`. +/// - +/// - +/// +/// ## Relationship with `ToString`/`Display` +/// +/// For some types, the implementation diverge from Rust's standard `ToString`/`Display`, +/// to match PostgreSQL's representation. +/// +/// --- +/// +/// FIXME: `ToText` should depend on a lot of other stuff +/// but we have not implemented them yet: timezone, date style, interval style, bytea output, etc pub trait ToText { /// Write the text to the writer *regardless* of its data type /// @@ -39,26 +56,10 @@ pub trait ToText { /// text. E.g. for Int64, it will convert to text as a Int64 type. /// We should prefer to use `to_text_with_type` because it's more clear and readable. /// - /// Following is the relationship between scalar and default type: - /// - `ScalarRefImpl::Int16` -> `DataType::Int16` - /// - `ScalarRefImpl::Int32` -> `DataType::Int32` - /// - `ScalarRefImpl::Int64` -> `DataType::Int64` - /// - `ScalarRefImpl::Int256` -> `DataType::Int256` - /// - `ScalarRefImpl::Float32` -> `DataType::Float32` - /// - `ScalarRefImpl::Float64` -> `DataType::Float64` - /// - `ScalarRefImpl::Decimal` -> `DataType::Decimal` - /// - `ScalarRefImpl::Bool` -> `DataType::Boolean` - /// - `ScalarRefImpl::Utf8` -> `DataType::Varchar` - /// - `ScalarRefImpl::Bytea` -> `DataType::Bytea` - /// - `ScalarRefImpl::Date` -> `DataType::Date` - /// - `ScalarRefImpl::Time` -> `DataType::Time` - /// - `ScalarRefImpl::Timestamp` -> `DataType::Timestamp` - /// - `ScalarRefImpl::Timestamptz` -> `DataType::Timestamptz` - /// - `ScalarRefImpl::Interval` -> `DataType::Interval` - /// - `ScalarRefImpl::Jsonb` -> `DataType::Jsonb` - /// - `ScalarRefImpl::List` -> `DataType::List` - /// - `ScalarRefImpl::Struct` -> `DataType::Struct` - /// - `ScalarRefImpl::Serial` -> `DataType::Serial` + /// Note: currently the `DataType` param is actually unnecessary. + /// Previously, Timestamptz is also represented as int64, and we need the data type to distinguish them. 
+ /// Now we have 1-1 mapping, and it happens to be the case that PostgreSQL default `ToText` format does + /// not need additional metadata like field names contained in `DataType`. fn to_text(&self) -> String { let mut s = String::new(); self.write(&mut s).unwrap(); diff --git a/src/connector/src/parser/mysql.rs b/src/connector/src/parser/mysql.rs index d1df27263e80..a28dddc9aa65 100644 --- a/src/connector/src/parser/mysql.rs +++ b/src/connector/src/parser/mysql.rs @@ -149,7 +149,7 @@ mod tests { use mysql_async::Row as MySqlRow; use risingwave_common::catalog::{Field, Schema}; use risingwave_common::row::Row; - use risingwave_common::types::{DataType, ToText}; + use risingwave_common::types::DataType; use tokio_stream::StreamExt; use crate::parser::mysql_row_to_owned_row; @@ -187,7 +187,7 @@ mod tests { let d = owned_row.datum_at(2); if let Some(scalar) = d { let v = scalar.into_timestamptz(); - println!("timestamp: {}", v.to_text()); + println!("timestamp: {:?}", v); } } } diff --git a/src/expr/impl/src/scalar/cast.rs b/src/expr/impl/src/scalar/cast.rs index 48218d720080..0c93c0ed15dd 100644 --- a/src/expr/impl/src/scalar/cast.rs +++ b/src/expr/impl/src/scalar/cast.rs @@ -147,8 +147,8 @@ pub fn int_to_bool(input: i32) -> bool { input != 0 } -// For most of the types, cast them to varchar is similar to return their text format. -// So we use this function to cast type to varchar. +/// For most of the types, cast them to varchar is the same as their pgwire "TEXT" format. +/// So we use `ToText` to cast type to varchar. #[function("cast(*int) -> varchar")] #[function("cast(decimal) -> varchar")] #[function("cast(*float) -> varchar")] @@ -177,7 +177,7 @@ pub fn bool_to_varchar(input: bool, writer: &mut impl Write) { .unwrap(); } -/// `bool_out` is different from `general_to_string` to produce a single char. `PostgreSQL` +/// `bool_out` is different from `cast(boolean) -> varchar` to produce a single char. `PostgreSQL` /// uses different variants of bool-to-string in different situations. #[function("bool_out(boolean) -> varchar")] pub fn bool_out(input: bool, writer: &mut impl Write) { diff --git a/src/utils/pgwire/src/types.rs b/src/utils/pgwire/src/types.rs index d4d37e1168ea..c76aa20aac4c 100644 --- a/src/utils/pgwire/src/types.rs +++ b/src/utils/pgwire/src/types.rs @@ -59,6 +59,7 @@ impl Index for Row { } } +/// #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum Format { Binary, From e929c29078e9117158d21f4c67de83d6190fe728 Mon Sep 17 00:00:00 2001 From: Richard Chien Date: Wed, 17 Jul 2024 09:28:56 +0800 Subject: [PATCH 25/70] refactor(dyn-filter): refactor dynamic filter for better readability (#17699) Signed-off-by: Richard Chien --- .../plan_node/generic/dynamic_filter.rs | 18 +-- .../plan_node/stream_dynamic_filter.rs | 48 ++++--- src/stream/src/executor/dynamic_filter.rs | 136 ++++++++++-------- 3 files changed, 106 insertions(+), 96 deletions(-) diff --git a/src/frontend/src/optimizer/plan_node/generic/dynamic_filter.rs b/src/frontend/src/optimizer/plan_node/generic/dynamic_filter.rs index 1f6ab5be98b9..4112bd0a60d8 100644 --- a/src/frontend/src/optimizer/plan_node/generic/dynamic_filter.rs +++ b/src/frontend/src/optimizer/plan_node/generic/dynamic_filter.rs @@ -12,7 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-use fixedbitset::FixedBitSet; use pretty_xmlish::Pretty; use risingwave_common::catalog::Schema; use risingwave_common::util::sort_util::OrderType; @@ -31,9 +30,9 @@ pub struct DynamicFilter { /// The predicate (formed with exactly one of < , <=, >, >=) comparator: ExprType, left_index: usize, - pub left: PlanRef, + left: PlanRef, /// The right input can only have one column. - pub right: PlanRef, + right: PlanRef, } impl DynamicFilter { pub fn comparator(&self) -> ExprType { @@ -108,19 +107,6 @@ impl DynamicFilter { } } - pub fn watermark_columns(&self, right_watermark: bool) -> FixedBitSet { - let mut watermark_columns = FixedBitSet::with_capacity(self.left.schema().len()); - if right_watermark { - match self.comparator { - ExprType::Equal | ExprType::GreaterThan | ExprType::GreaterThanOrEqual => { - watermark_columns.set(self.left_index, true) - } - _ => {} - } - } - watermark_columns - } - fn condition_display(&self) -> (Condition, Schema) { let mut concat_schema = self.left.schema().fields.clone(); concat_schema.extend(self.right.schema().fields.clone()); diff --git a/src/frontend/src/optimizer/plan_node/stream_dynamic_filter.rs b/src/frontend/src/optimizer/plan_node/stream_dynamic_filter.rs index a90cd4b77c66..a6c31b9197eb 100644 --- a/src/frontend/src/optimizer/plan_node/stream_dynamic_filter.rs +++ b/src/frontend/src/optimizer/plan_node/stream_dynamic_filter.rs @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +use fixedbitset::FixedBitSet; use pretty_xmlish::{Pretty, XmlNode}; pub use risingwave_pb::expr::expr_node::Type as ExprType; use risingwave_pb::stream_plan::stream_node::NodeBody; @@ -23,7 +24,7 @@ use super::utils::{ childless_record, column_names_pretty, plan_node_name, watermark_pretty, Distill, }; use super::{generic, ExprRewritable, PlanTreeNodeUnary}; -use crate::expr::{Expr, ExprImpl}; +use crate::expr::Expr; use crate::optimizer::plan_node::expr_visitable::ExprVisitable; use crate::optimizer::plan_node::{PlanBase, PlanTreeNodeBinary, StreamNode}; use crate::optimizer::property::Distribution; @@ -40,8 +41,6 @@ pub struct StreamDynamicFilter { impl StreamDynamicFilter { pub fn new(core: DynamicFilter) -> Self { - let watermark_columns = core.watermark_columns(core.right().watermark_columns()[0]); - // TODO(st1page): here we just check if RHS // is a `StreamNow`. It will be generalized to more cases // by introducing monotonically increasing property of the node in https://github.com/risingwavelabs/risingwave/pull/13984. 
@@ -66,17 +65,18 @@ impl StreamDynamicFilter { ExprType::LessThan | ExprType::LessThanOrEqual ); - let append_only = if condition_always_relax { + let out_append_only = if condition_always_relax { core.left().append_only() } else { false }; + let base = PlanBase::new_stream_with_core( &core, core.left().distribution().clone(), - append_only, + out_append_only, false, // TODO(rc): decide EOWC property - watermark_columns, + Self::derive_watermark_columns(&core), ); let cleaned_by_watermark = Self::cleaned_by_watermark(&core); Self { @@ -87,20 +87,34 @@ impl StreamDynamicFilter { } } - pub fn left_index(&self) -> usize { - self.core.left_index() + fn derive_watermark_columns(core: &DynamicFilter) -> FixedBitSet { + let mut res = FixedBitSet::with_capacity(core.left().schema().len()); + let rhs_watermark_columns = core.right().watermark_columns(); + if rhs_watermark_columns.contains(0) { + match core.comparator() { + // We can derive output watermark only if the output is supposed to be always >= rhs. + // While we have to keep in mind that, the propagation of watermark messages from + // the right input must be delayed until `Update`/`Delete`s are sent to downstream, + // otherwise, we will have watermark messages sent before the `Delete` of old rows. + ExprType::GreaterThan | ExprType::GreaterThanOrEqual => { + res.set(core.left_index(), true) + } + _ => {} + } + } + res } - /// 1. Check the comparator. - /// 2. RHS input should only have 1 columns, which is the watermark column. - /// We check that the watermark should be set. - pub fn cleaned_by_watermark(core: &DynamicFilter) -> bool { - let expr = core.predicate(); - if let Some(ExprImpl::FunctionCall(function_call)) = expr.as_expr_unless_true() { - match function_call.func_type() { + fn cleaned_by_watermark(core: &DynamicFilter) -> bool { + let rhs_watermark_columns = core.right().watermark_columns(); + if rhs_watermark_columns.contains(0) { + match core.comparator() { ExprType::GreaterThan | ExprType::GreaterThanOrEqual => { - let rhs_input = core.right(); - rhs_input.watermark_columns().contains(0) + // For >= and >, watermark on rhs means there's no change that rows older than the watermark will + // ever be `Insert`ed again. So, we can clean up the state table. In this case, future lhs inputs + // that are less than the watermark can be safely ignored, and hence watermark can be propagated to + // downstream. See `derive_watermark_columns`. 
+ true } _ => false, } diff --git a/src/stream/src/executor/dynamic_filter.rs b/src/stream/src/executor/dynamic_filter.rs index 541ddeb1a22a..65530f816295 100644 --- a/src/stream/src/executor/dynamic_filter.rs +++ b/src/stream/src/executor/dynamic_filter.rs @@ -25,7 +25,7 @@ use risingwave_common::util::iter_util::ZipEqDebug; use risingwave_expr::expr::{ build_func_non_strict, InputRefExpression, LiteralExpression, NonStrictExpression, }; -use risingwave_pb::expr::expr_node::Type as ExprNodeType; +use risingwave_pb::expr::expr_node::Type as PbExprNodeType; use risingwave_pb::expr::expr_node::Type::{ GreaterThan, GreaterThanOrEqual, LessThan, LessThanOrEqual, }; @@ -46,7 +46,7 @@ pub struct DynamicFilterExecutor source_l: Option, source_r: Option, key_l: usize, - comparator: ExprNodeType, + comparator: PbExprNodeType, left_table: WatermarkCacheParameterizedStateTable, right_table: StateTable, metrics: Arc, @@ -68,7 +68,7 @@ impl DynamicFilterExecutor, state_table_r: StateTable, metrics: Arc, @@ -96,13 +96,13 @@ impl DynamicFilterExecutor, + filter_condition: Option, ) -> Result<(Vec, Bitmap), StreamExecutorError> { let mut new_ops = Vec::with_capacity(chunk.capacity()); let mut new_visibility = BitmapBuilder::with_capacity(chunk.capacity()); let mut last_res = false; - let eval_results = if let Some(cond) = condition { + let filter_results = if let Some(cond) = filter_condition { Some(cond.eval_infallible(chunk).await) } else { None @@ -111,11 +111,11 @@ impl DynamicFilterExecutor DynamicFilterExecutor { self.left_table.insert(row); @@ -266,7 +272,8 @@ impl DynamicFilterExecutor DynamicFilterExecutor = recovered_value.clone(); - let mut current_epoch_value: Option = recovered_value; + let recovered_rhs = self.recover_rhs().await?; + let recovered_rhs_value = recovered_rhs.as_ref().map(|r| r[0].clone()); + // At the beginning of an epoch, the `committed_rhs_value` == `staging_rhs_value` + let mut committed_rhs_value: Option = recovered_rhs_value.clone(); + let mut staging_rhs_value: Option = recovered_rhs_value; // This is only required to be some if the row arrived during this epoch. - let mut current_epoch_row = recovered_row.clone(); - let mut last_committed_epoch_row = recovered_row; + let mut committed_rhs_row = recovered_rhs.clone(); + let mut staging_rhs_row = recovered_rhs; // The first barrier message should be propagated. yield Message::Barrier(barrier); @@ -313,22 +320,24 @@ impl DynamicFilterExecutor { // Reuse the logic from `FilterExecutor` let chunk = chunk.compact(); // Is this unnecessary work? - let right_val = prev_epoch_value.clone().flatten(); // The condition is `None` if it is always false by virtue of a NULL right // input, so we save evaluating it on the datachunk - let condition = dynamic_cond(right_val).transpose()?; + let filter_condition = + build_cond(committed_rhs_value.clone().flatten()).transpose()?; - let (new_ops, new_visibility) = self.apply_batch(&chunk, condition).await?; + let (new_ops, new_visibility) = + self.apply_batch(&chunk, filter_condition).await?; let columns = chunk.into_parts().0.into_parts().0; @@ -346,24 +355,24 @@ impl DynamicFilterExecutor { - current_epoch_value = Some(row.datum_at(0).to_owned_datum()); - current_epoch_row = Some(row.into_owned_row()); + staging_rhs_value = Some(row.datum_at(0).to_owned_datum()); + staging_rhs_row = Some(row.into_owned_row()); } - _ => { + Op::UpdateDelete | Op::Delete => { // To be consistent, there must be an existing `current_epoch_value` // equivalent to row indicated for // deletion. 
if Some(row.datum_at(0)) - != current_epoch_value.as_ref().map(ToDatumRef::to_datum_ref) + != staging_rhs_value.as_ref().map(ToDatumRef::to_datum_ref) { consistency_panic!( - current = ?current_epoch_value, + current = ?staging_rhs_value, to_delete = ?row, "inconsistent delete", ); } - current_epoch_value = None; - current_epoch_row = None; + staging_rhs_value = None; + staging_rhs_row = None; } } } @@ -372,25 +381,27 @@ impl DynamicFilterExecutor { - if watermark_can_clean_state { - unused_clean_hint = Some(watermark.val.clone()); - buffered_right_watermark = Some(watermark); + if self.cleaned_by_watermark { + staging_state_watermark = Some(watermark.val.clone()); + } + if can_propagate_watermark { + watermark_to_propagate = Some(watermark); } } AlignedMessage::Barrier(barrier) => { - // Flush the difference between the `prev_value` and `current_value` + // Commit the staging RHS value. // // This block is guaranteed to be idempotent even if we may encounter multiple - // barriers since `prev_epoch_value` is always be reset to - // the equivalent of `current_epoch_value` at the end of - // this block. Likewise, `last_committed_epoch_row` will always be equal to - // `current_epoch_row`. - // It is thus guaranteed not to commit state or produce chunks as long as - // no new chunks have arrived since the previous barrier. - let curr: Datum = current_epoch_value.clone().flatten(); - let prev: Datum = prev_epoch_value.flatten(); + // barriers since `committed_rhs_value` is always be reset to the equivalent of + // `staging_rhs_value` at the end of this block. Likewise, `committed_rhs_row` + // will always be equal to `staging_rhs_row`. It is thus guaranteed not to commit + // state or produce chunks as long as no new chunks have arrived since the previous + // barrier. 
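+                        //
+                        // For example (illustrative values only): under a `lhs > rhs` condition, if the committed
+                        // RHS value was 5 and the staging value is 8, LHS rows in the range (5, 8] no longer satisfy
+                        // the condition, and the block below emits `Delete`s for them.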
+ let curr: Datum = staging_rhs_value.clone().flatten(); + let prev: Datum = committed_rhs_value.flatten(); if prev != curr { let (range, _latest_is_lower, is_insert) = self.get_range(&curr, prev); + if !is_insert && self.condition_always_relax { bail!("The optimizer inferred that the right side's change always make the condition more relaxed.\ But the right changes make the conditions stricter."); @@ -415,9 +426,11 @@ impl DynamicFilterExecutor DynamicFilterExecutor (MessageSender, MessageSender, BoxedMessageStream) { let mem_state = MemoryStateStore::new(); create_executor_inner(comparator, mem_state, false).await } async fn create_executor_inner( - comparator: ExprNodeType, + comparator: PbExprNodeType, mem_state: MemoryStateStore, always_relax: bool, ) -> (MessageSender, MessageSender, BoxedMessageStream) { @@ -600,7 +610,7 @@ mod tests { ); let mem_state = MemoryStateStore::new(); let (mut tx_l, mut tx_r, mut dynamic_filter) = - create_executor_inner(ExprNodeType::GreaterThan, mem_state.clone(), false).await; + create_executor_inner(PbExprNodeType::GreaterThan, mem_state.clone(), false).await; // push the init barrier for left and right tx_l.push_barrier(test_epoch(1), false); @@ -623,7 +633,7 @@ mod tests { // Recover executor from state store let (mut tx_l, mut tx_r, mut dynamic_filter) = - create_executor_inner(ExprNodeType::GreaterThan, mem_state.clone(), false).await; + create_executor_inner(PbExprNodeType::GreaterThan, mem_state.clone(), false).await; // push the recovery barrier for left and right tx_l.push_barrier(test_epoch(2), false); @@ -670,7 +680,7 @@ mod tests { // Recover executor from state store let (mut tx_l, mut tx_r, mut dynamic_filter) = - create_executor_inner(ExprNodeType::GreaterThan, mem_state.clone(), false).await; + create_executor_inner(PbExprNodeType::GreaterThan, mem_state.clone(), false).await; // push recovery barrier tx_l.push_barrier(test_epoch(3), false); @@ -756,7 +766,7 @@ mod tests { + 4", ); let (mut tx_l, mut tx_r, mut dynamic_filter) = - create_executor(ExprNodeType::GreaterThan).await; + create_executor(PbExprNodeType::GreaterThan).await; // push the init barrier for left and right tx_l.push_barrier(test_epoch(1), false); @@ -862,7 +872,7 @@ mod tests { + 5", ); let (mut tx_l, mut tx_r, mut dynamic_filter) = - create_executor(ExprNodeType::GreaterThanOrEqual).await; + create_executor(PbExprNodeType::GreaterThanOrEqual).await; // push the init barrier for left and right tx_l.push_barrier(test_epoch(1), false); @@ -968,7 +978,7 @@ mod tests { + 1", ); let (mut tx_l, mut tx_r, mut dynamic_filter) = - create_executor(ExprNodeType::LessThan).await; + create_executor(PbExprNodeType::LessThan).await; // push the init barrier for left and right tx_l.push_barrier(test_epoch(1), false); @@ -1074,7 +1084,7 @@ mod tests { + 0", ); let (mut tx_l, mut tx_r, mut dynamic_filter) = - create_executor(ExprNodeType::LessThanOrEqual).await; + create_executor(PbExprNodeType::LessThanOrEqual).await; // push the init barrier for left and right tx_l.push_barrier(test_epoch(1), false); @@ -1191,7 +1201,7 @@ mod tests { let mem_state = MemoryStateStore::new(); let (mut tx_l, mut tx_r, mut dynamic_filter) = - create_executor_inner(ExprNodeType::LessThanOrEqual, mem_state.clone(), true).await; + create_executor_inner(PbExprNodeType::LessThanOrEqual, mem_state.clone(), true).await; let column_descs = ColumnDesc::unnamed(ColumnId::new(0), DataType::Int64); let table = StorageTable::for_test( mem_state.clone(), From 5037601c6fba922c5bf9e34e1b094cdf7d1bb624 Mon Sep 17 
00:00:00 2001 From: xxchan Date: Wed, 17 Jul 2024 11:09:41 +0800 Subject: [PATCH 26/70] chore: bump typos (#17710) Signed-off-by: xxchan --- .github/workflows/typo.yml | 2 +- .typos.toml | 4 ++++ Makefile.toml | 2 +- src/connector/src/sink/clickhouse.rs | 2 +- src/storage/src/hummock/sstable/block.rs | 2 +- 5 files changed, 8 insertions(+), 4 deletions(-) diff --git a/.github/workflows/typo.yml b/.github/workflows/typo.yml index bf2163cdb029..50f0ce89ac06 100644 --- a/.github/workflows/typo.yml +++ b/.github/workflows/typo.yml @@ -10,4 +10,4 @@ jobs: uses: actions/checkout@v3 - name: Check spelling of the entire repository - uses: crate-ci/typos@v1.20.4 + uses: crate-ci/typos@v1.23.2 diff --git a/.typos.toml b/.typos.toml index a7b5570bb766..225e69a6997b 100644 --- a/.typos.toml +++ b/.typos.toml @@ -30,4 +30,8 @@ extend-exclude = [ "**/Cargo.toml", "**/go.mod", "**/go.sum", + # https://github.com/risingwavelabs/risingwave/blob/0ce6228df6a4da183ae91146f2cdfff1ca9cc6a7/src/common/src/cast/mod.rs#L30 + # We don't want to fix "fals" here, but may want in other places. + # Ideally, we should just ignore that line: https://github.com/crate-ci/typos/issues/316 + "src/common/src/cast/mod.rs", ] diff --git a/Makefile.toml b/Makefile.toml index 16b5e6393e9a..6c392384f518 100644 --- a/Makefile.toml +++ b/Makefile.toml @@ -1138,7 +1138,7 @@ fi private = true category = "RiseDev - Check" description = "Run cargo typos-cli check" -install_crate = { min_version = "1.20.4", crate_name = "typos-cli", binary = "typos", test_arg = [ +install_crate = { min_version = "1.23.2", crate_name = "typos-cli", binary = "typos", test_arg = [ "--help", ], install_command = "binstall" } script = """ diff --git a/src/connector/src/sink/clickhouse.rs b/src/connector/src/sink/clickhouse.rs index bbeb27aa4514..ac4930460ece 100644 --- a/src/connector/src/sink/clickhouse.rs +++ b/src/connector/src/sink/clickhouse.rs @@ -1011,7 +1011,7 @@ pub fn build_fields_name_type_from_schema(schema: &Schema) -> Result Date: Tue, 16 Jul 2024 22:34:46 -0500 Subject: [PATCH 27/70] fix(secret): avoid trying to get the secret_store_private_key when there is no secret (#17712) --- src/meta/service/src/notification_service.rs | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/meta/service/src/notification_service.rs b/src/meta/service/src/notification_service.rs index b2b78ba28f19..5f0abbc8fe0e 100644 --- a/src/meta/service/src/notification_service.rs +++ b/src/meta/service/src/notification_service.rs @@ -168,6 +168,10 @@ impl NotificationServiceImpl { } fn decrypt_secrets(&self, secrets: Vec) -> MetaResult> { + // Skip getting `secret_store_private_key` if there is no secret + if secrets.is_empty() { + return Ok(vec![]); + } let secret_store_private_key = self .env .opts From 89148356dab7941207c56375b34d3a011518306b Mon Sep 17 00:00:00 2001 From: Yuhao Su <31772373+yuhao-su@users.noreply.github.com> Date: Tue, 16 Jul 2024 23:47:59 -0500 Subject: [PATCH 28/70] refactor(sink): use `JsonEncoderConfig` for json encoder (#17706) --- src/connector/src/sink/encoder/json.rs | 278 +++++++++++-------------- 1 file changed, 124 insertions(+), 154 deletions(-) diff --git a/src/connector/src/sink/encoder/json.rs b/src/connector/src/sink/encoder/json.rs index 474811660970..d657b8b9c585 100644 --- a/src/connector/src/sink/encoder/json.rs +++ b/src/connector/src/sink/encoder/json.rs @@ -35,15 +35,19 @@ use super::{ }; use crate::sink::SinkError; -pub struct JsonEncoder { - schema: Schema, - col_indices: Option>, +pub struct JsonEncoderConfig { 
time_handling_mode: TimeHandlingMode, date_handling_mode: DateHandlingMode, timestamp_handling_mode: TimestampHandlingMode, timestamptz_handling_mode: TimestamptzHandlingMode, custom_json_type: CustomJsonType, +} + +pub struct JsonEncoder { + schema: Schema, + col_indices: Option>, kafka_connect: Option, + config: JsonEncoderConfig, } impl JsonEncoder { @@ -55,28 +59,34 @@ impl JsonEncoder { timestamptz_handling_mode: TimestamptzHandlingMode, time_handling_mode: TimeHandlingMode, ) -> Self { - Self { - schema, - col_indices, + let config = JsonEncoderConfig { time_handling_mode, date_handling_mode, timestamp_handling_mode, timestamptz_handling_mode, custom_json_type: CustomJsonType::None, + }; + Self { + schema, + col_indices, kafka_connect: None, + config, } } pub fn new_with_es(schema: Schema, col_indices: Option>) -> Self { - Self { - schema, - col_indices, + let config = JsonEncoderConfig { time_handling_mode: TimeHandlingMode::String, date_handling_mode: DateHandlingMode::String, timestamp_handling_mode: TimestampHandlingMode::String, timestamptz_handling_mode: TimestamptzHandlingMode::UtcWithoutSuffix, custom_json_type: CustomJsonType::Es, + }; + Self { + schema, + col_indices, kafka_connect: None, + config, } } @@ -85,28 +95,34 @@ impl JsonEncoder { col_indices: Option>, map: HashMap, ) -> Self { - Self { - schema, - col_indices, + let config = JsonEncoderConfig { time_handling_mode: TimeHandlingMode::Milli, date_handling_mode: DateHandlingMode::String, timestamp_handling_mode: TimestampHandlingMode::String, timestamptz_handling_mode: TimestamptzHandlingMode::UtcWithoutSuffix, custom_json_type: CustomJsonType::Doris(map), + }; + Self { + schema, + col_indices, kafka_connect: None, + config, } } pub fn new_with_starrocks(schema: Schema, col_indices: Option>) -> Self { - Self { - schema, - col_indices, + let config = JsonEncoderConfig { time_handling_mode: TimeHandlingMode::Milli, date_handling_mode: DateHandlingMode::String, timestamp_handling_mode: TimestampHandlingMode::String, timestamptz_handling_mode: TimestamptzHandlingMode::UtcWithoutSuffix, custom_json_type: CustomJsonType::StarRocks, + }; + Self { + schema, + col_indices, kafka_connect: None, + config, } } @@ -139,16 +155,8 @@ impl RowEncoder for JsonEncoder { for idx in &col_indices { let field = &self.schema[*idx]; let key = field.name.clone(); - let value = datum_to_json_object( - field, - row.datum_at(*idx), - self.date_handling_mode, - self.timestamp_handling_mode, - self.timestamptz_handling_mode, - self.time_handling_mode, - &self.custom_json_type, - ) - .map_err(|e| SinkError::Encode(e.to_report_string()))?; + let value = datum_to_json_object(field, row.datum_at(*idx), &self.config) + .map_err(|e| SinkError::Encode(e.to_report_string()))?; mappings.insert(key, value); } @@ -179,11 +187,7 @@ impl SerTo for Value { fn datum_to_json_object( field: &Field, datum: DatumRef<'_>, - date_handling_mode: DateHandlingMode, - timestamp_handling_mode: TimestampHandlingMode, - timestamptz_handling_mode: TimestamptzHandlingMode, - time_handling_mode: TimeHandlingMode, - custom_json_type: &CustomJsonType, + config: &JsonEncoderConfig, ) -> ArrayResult { let scalar_ref = match datum { None => { @@ -223,7 +227,7 @@ fn datum_to_json_object( json!(v) } // Doris/Starrocks will convert out-of-bounds decimal and -INF, INF, NAN to NULL - (DataType::Decimal, ScalarRefImpl::Decimal(mut v)) => match custom_json_type { + (DataType::Decimal, ScalarRefImpl::Decimal(mut v)) => match &config.custom_json_type { CustomJsonType::Doris(map) => { let s 
= map.get(&field.name).unwrap(); v.rescale(*s as u32); @@ -233,21 +237,23 @@ fn datum_to_json_object( json!(v.to_text()) } }, - (DataType::Timestamptz, ScalarRefImpl::Timestamptz(v)) => match timestamptz_handling_mode { - TimestamptzHandlingMode::UtcString => { - let parsed = v.to_datetime_utc(); - let v = parsed.to_rfc3339_opts(chrono::SecondsFormat::Micros, true); - json!(v) - } - TimestamptzHandlingMode::UtcWithoutSuffix => { - let parsed = v.to_datetime_utc().naive_utc(); - let v = parsed.format("%Y-%m-%d %H:%M:%S%.6f").to_string(); - json!(v) + (DataType::Timestamptz, ScalarRefImpl::Timestamptz(v)) => { + match config.timestamptz_handling_mode { + TimestamptzHandlingMode::UtcString => { + let parsed = v.to_datetime_utc(); + let v = parsed.to_rfc3339_opts(chrono::SecondsFormat::Micros, true); + json!(v) + } + TimestamptzHandlingMode::UtcWithoutSuffix => { + let parsed = v.to_datetime_utc().naive_utc(); + let v = parsed.format("%Y-%m-%d %H:%M:%S%.6f").to_string(); + json!(v) + } + TimestamptzHandlingMode::Micro => json!(v.timestamp_micros()), + TimestamptzHandlingMode::Milli => json!(v.timestamp_millis()), } - TimestamptzHandlingMode::Micro => json!(v.timestamp_micros()), - TimestamptzHandlingMode::Milli => json!(v.timestamp_millis()), - }, - (DataType::Time, ScalarRefImpl::Time(v)) => match time_handling_mode { + } + (DataType::Time, ScalarRefImpl::Time(v)) => match config.time_handling_mode { TimeHandlingMode::Milli => { // todo: just ignore the nanos part to avoid leap second complex json!(v.0.num_seconds_from_midnight() as i64 * 1000) @@ -257,7 +263,7 @@ fn datum_to_json_object( json!(a) } }, - (DataType::Date, ScalarRefImpl::Date(v)) => match date_handling_mode { + (DataType::Date, ScalarRefImpl::Date(v)) => match config.date_handling_mode { DateHandlingMode::FromCe => json!(v.0.num_days_from_ce()), DateHandlingMode::FromEpoch => { let duration = v.0 - NaiveDateTime::UNIX_EPOCH.date(); @@ -268,10 +274,14 @@ fn datum_to_json_object( json!(a) } }, - (DataType::Timestamp, ScalarRefImpl::Timestamp(v)) => match timestamp_handling_mode { - TimestampHandlingMode::Milli => json!(v.0.and_utc().timestamp_millis()), - TimestampHandlingMode::String => json!(v.0.format("%Y-%m-%d %H:%M:%S%.6f").to_string()), - }, + (DataType::Timestamp, ScalarRefImpl::Timestamp(v)) => { + match config.timestamp_handling_mode { + TimestampHandlingMode::Milli => json!(v.0.and_utc().timestamp_millis()), + TimestampHandlingMode::String => { + json!(v.0.format("%Y-%m-%d %H:%M:%S%.6f").to_string()) + } + } + } (DataType::Bytea, ScalarRefImpl::Bytea(v)) => { json!(general_purpose::STANDARD.encode(v)) } @@ -279,7 +289,7 @@ fn datum_to_json_object( (DataType::Interval, ScalarRefImpl::Interval(v)) => { json!(v.as_iso_8601()) } - (DataType::Jsonb, ScalarRefImpl::Jsonb(jsonb_ref)) => match custom_json_type { + (DataType::Jsonb, ScalarRefImpl::Jsonb(jsonb_ref)) => match &config.custom_json_type { CustomJsonType::Es | CustomJsonType::StarRocks => JsonbVal::from(jsonb_ref).take(), CustomJsonType::Doris(_) | CustomJsonType::None => { json!(jsonb_ref.to_string()) @@ -290,21 +300,13 @@ fn datum_to_json_object( let mut vec = Vec::with_capacity(elems.len()); let inner_field = Field::unnamed(Box::::into_inner(datatype)); for sub_datum_ref in elems { - let value = datum_to_json_object( - &inner_field, - sub_datum_ref, - date_handling_mode, - timestamp_handling_mode, - timestamptz_handling_mode, - time_handling_mode, - custom_json_type, - )?; + let value = datum_to_json_object(&inner_field, sub_datum_ref, config)?; vec.push(value); } 
json!(vec) } (DataType::Struct(st), ScalarRefImpl::Struct(struct_ref)) => { - match custom_json_type { + match config.custom_json_type { CustomJsonType::Doris(_) => { // We need to ensure that the order of elements in the json matches the insertion order. let mut map = IndexMap::with_capacity(st.len()); @@ -312,15 +314,7 @@ fn datum_to_json_object( st.iter() .map(|(name, dt)| Field::with_name(dt.clone(), name)), ) { - let value = datum_to_json_object( - &sub_field, - sub_datum_ref, - date_handling_mode, - timestamp_handling_mode, - timestamptz_handling_mode, - time_handling_mode, - custom_json_type, - )?; + let value = datum_to_json_object(&sub_field, sub_datum_ref, config)?; map.insert(sub_field.name.clone(), value); } Value::String( @@ -338,15 +332,7 @@ fn datum_to_json_object( st.iter() .map(|(name, dt)| Field::with_name(dt.clone(), name)), ) { - let value = datum_to_json_object( - &sub_field, - sub_datum_ref, - date_handling_mode, - timestamp_handling_mode, - timestamptz_handling_mode, - time_handling_mode, - custom_json_type, - )?; + let value = datum_to_json_object(&sub_field, sub_datum_ref, config)?; map.insert(sub_field.name.clone(), value); } json!(map) @@ -454,17 +440,21 @@ mod tests { type_name: Default::default(), }; + let config = JsonEncoderConfig { + time_handling_mode: TimeHandlingMode::Milli, + date_handling_mode: DateHandlingMode::FromCe, + timestamp_handling_mode: TimestampHandlingMode::String, + timestamptz_handling_mode: TimestamptzHandlingMode::UtcString, + custom_json_type: CustomJsonType::None, + }; + let boolean_value = datum_to_json_object( &Field { data_type: DataType::Boolean, ..mock_field.clone() }, Some(ScalarImpl::Bool(false).as_scalar_ref_impl()), - DateHandlingMode::FromCe, - TimestampHandlingMode::String, - TimestamptzHandlingMode::UtcString, - TimeHandlingMode::Milli, - &CustomJsonType::None, + &config, ) .unwrap(); assert_eq!(boolean_value, json!(false)); @@ -475,11 +465,7 @@ mod tests { ..mock_field.clone() }, Some(ScalarImpl::Int16(16).as_scalar_ref_impl()), - DateHandlingMode::FromCe, - TimestampHandlingMode::String, - TimestamptzHandlingMode::UtcString, - TimeHandlingMode::Milli, - &CustomJsonType::None, + &config, ) .unwrap(); assert_eq!(int16_value, json!(16)); @@ -490,11 +476,7 @@ mod tests { ..mock_field.clone() }, Some(ScalarImpl::Int64(i64::MAX).as_scalar_ref_impl()), - DateHandlingMode::FromCe, - TimestampHandlingMode::String, - TimestamptzHandlingMode::UtcString, - TimeHandlingMode::Milli, - &CustomJsonType::None, + &config, ) .unwrap(); assert_eq!( @@ -508,11 +490,7 @@ mod tests { ..mock_field.clone() }, Some(ScalarImpl::Serial(i64::MAX.into()).as_scalar_ref_impl()), - DateHandlingMode::FromCe, - TimestampHandlingMode::String, - TimestamptzHandlingMode::UtcString, - TimeHandlingMode::Milli, - &CustomJsonType::None, + &config, ) .unwrap(); assert_eq!( @@ -528,15 +506,19 @@ mod tests { ..mock_field.clone() }, Some(ScalarImpl::Timestamptz(tstz_inner).as_scalar_ref_impl()), - DateHandlingMode::FromCe, - TimestampHandlingMode::String, - TimestamptzHandlingMode::UtcString, - TimeHandlingMode::Milli, - &CustomJsonType::None, + &config, ) .unwrap(); assert_eq!(tstz_value, "2018-01-26T18:30:09.453000Z"); + let unix_wo_suffix_config = JsonEncoderConfig { + time_handling_mode: TimeHandlingMode::Milli, + date_handling_mode: DateHandlingMode::FromCe, + timestamp_handling_mode: TimestampHandlingMode::String, + timestamptz_handling_mode: TimestamptzHandlingMode::UtcWithoutSuffix, + custom_json_type: CustomJsonType::None, + }; + let tstz_inner = 
"2018-01-26T18:30:09.453Z".parse().unwrap(); let tstz_value = datum_to_json_object( &Field { @@ -544,15 +526,18 @@ mod tests { ..mock_field.clone() }, Some(ScalarImpl::Timestamptz(tstz_inner).as_scalar_ref_impl()), - DateHandlingMode::FromCe, - TimestampHandlingMode::String, - TimestamptzHandlingMode::UtcWithoutSuffix, - TimeHandlingMode::Milli, - &CustomJsonType::None, + &unix_wo_suffix_config, ) .unwrap(); assert_eq!(tstz_value, "2018-01-26 18:30:09.453000"); + let timestamp_milli_config = JsonEncoderConfig { + time_handling_mode: TimeHandlingMode::String, + date_handling_mode: DateHandlingMode::FromCe, + timestamp_handling_mode: TimestampHandlingMode::Milli, + timestamptz_handling_mode: TimestamptzHandlingMode::UtcString, + custom_json_type: CustomJsonType::None, + }; let ts_value = datum_to_json_object( &Field { data_type: DataType::Timestamp, @@ -562,11 +547,7 @@ mod tests { ScalarImpl::Timestamp(Timestamp::from_timestamp_uncheck(1000, 0)) .as_scalar_ref_impl(), ), - DateHandlingMode::FromCe, - TimestampHandlingMode::Milli, - TimestamptzHandlingMode::UtcString, - TimeHandlingMode::Milli, - &CustomJsonType::None, + ×tamp_milli_config, ) .unwrap(); assert_eq!(ts_value, json!(1000 * 1000)); @@ -580,11 +561,7 @@ mod tests { ScalarImpl::Timestamp(Timestamp::from_timestamp_uncheck(1000, 0)) .as_scalar_ref_impl(), ), - DateHandlingMode::FromCe, - TimestampHandlingMode::String, - TimestamptzHandlingMode::UtcString, - TimeHandlingMode::Milli, - &CustomJsonType::None, + &config, ) .unwrap(); assert_eq!(ts_value, json!("1970-01-01 00:16:40.000000".to_string())); @@ -599,11 +576,7 @@ mod tests { ScalarImpl::Time(Time::from_num_seconds_from_midnight_uncheck(1000, 0)) .as_scalar_ref_impl(), ), - DateHandlingMode::FromCe, - TimestampHandlingMode::String, - TimestamptzHandlingMode::UtcString, - TimeHandlingMode::Milli, - &CustomJsonType::None, + &config, ) .unwrap(); assert_eq!(time_value, json!(1000 * 1000)); @@ -617,17 +590,20 @@ mod tests { ScalarImpl::Interval(Interval::from_month_day_usec(13, 2, 1000000)) .as_scalar_ref_impl(), ), - DateHandlingMode::FromCe, - TimestampHandlingMode::String, - TimestamptzHandlingMode::UtcString, - TimeHandlingMode::Milli, - &CustomJsonType::None, + &config, ) .unwrap(); assert_eq!(interval_value, json!("P1Y1M2DT0H0M1S")); let mut map = HashMap::default(); map.insert("aaa".to_string(), 5_u8); + let doris_config = JsonEncoderConfig { + time_handling_mode: TimeHandlingMode::String, + date_handling_mode: DateHandlingMode::String, + timestamp_handling_mode: TimestampHandlingMode::String, + timestamptz_handling_mode: TimestamptzHandlingMode::UtcString, + custom_json_type: CustomJsonType::Doris(map), + }; let decimal = datum_to_json_object( &Field { data_type: DataType::Decimal, @@ -635,11 +611,7 @@ mod tests { ..mock_field.clone() }, Some(ScalarImpl::Decimal(Decimal::try_from(1.1111111).unwrap()).as_scalar_ref_impl()), - DateHandlingMode::String, - TimestampHandlingMode::String, - TimestamptzHandlingMode::UtcString, - TimeHandlingMode::Milli, - &CustomJsonType::Doris(map), + &doris_config, ) .unwrap(); assert_eq!(decimal, json!("1.11111")); @@ -650,41 +622,43 @@ mod tests { ..mock_field.clone() }, Some(ScalarImpl::Date(Date::from_ymd_uncheck(1970, 1, 1)).as_scalar_ref_impl()), - DateHandlingMode::FromCe, - TimestampHandlingMode::String, - TimestamptzHandlingMode::UtcString, - TimeHandlingMode::Milli, - &CustomJsonType::None, + &config, ) .unwrap(); assert_eq!(date_value, json!(719163)); + let from_epoch_config = JsonEncoderConfig { + time_handling_mode: 
TimeHandlingMode::Milli, + date_handling_mode: DateHandlingMode::FromEpoch, + timestamp_handling_mode: TimestampHandlingMode::String, + timestamptz_handling_mode: TimestamptzHandlingMode::UtcString, + custom_json_type: CustomJsonType::None, + }; let date_value = datum_to_json_object( &Field { data_type: DataType::Date, ..mock_field.clone() }, Some(ScalarImpl::Date(Date::from_ymd_uncheck(1970, 1, 1)).as_scalar_ref_impl()), - DateHandlingMode::FromEpoch, - TimestampHandlingMode::String, - TimestamptzHandlingMode::UtcString, - TimeHandlingMode::Milli, - &CustomJsonType::None, + &from_epoch_config, ) .unwrap(); assert_eq!(date_value, json!(0)); + let doris_config = JsonEncoderConfig { + time_handling_mode: TimeHandlingMode::String, + date_handling_mode: DateHandlingMode::String, + timestamp_handling_mode: TimestampHandlingMode::String, + timestamptz_handling_mode: TimestamptzHandlingMode::UtcString, + custom_json_type: CustomJsonType::Doris(HashMap::default()), + }; let date_value = datum_to_json_object( &Field { data_type: DataType::Date, ..mock_field.clone() }, Some(ScalarImpl::Date(Date::from_ymd_uncheck(2010, 10, 10)).as_scalar_ref_impl()), - DateHandlingMode::String, - TimestampHandlingMode::String, - TimestamptzHandlingMode::UtcString, - TimeHandlingMode::Milli, - &CustomJsonType::Doris(HashMap::default()), + &doris_config, ) .unwrap(); assert_eq!(date_value, json!("2010-10-10")); @@ -705,11 +679,7 @@ mod tests { ..mock_field.clone() }, Some(ScalarRefImpl::Struct(StructRef::ValueRef { val: &value })), - DateHandlingMode::String, - TimestampHandlingMode::String, - TimestamptzHandlingMode::UtcString, - TimeHandlingMode::Milli, - &CustomJsonType::Doris(HashMap::default()), + &doris_config, ) .unwrap(); assert_eq!(interval_value, json!("{\"v3\":3,\"v2\":2,\"v1\":1}")); From 05330f62873a2127ea98ca153b885c66d226c368 Mon Sep 17 00:00:00 2001 From: Bohan Zhang Date: Wed, 17 Jul 2024 14:12:19 +0800 Subject: [PATCH 29/70] fix: Kinesis: NextToken and StreamName cannot be provided together (#17687) Signed-off-by: tabVersion Signed-off-by: tabVersion Co-authored-by: tabVersion --- .../src/source/kinesis/enumerator/client.rs | 30 +++++++++++++------ 1 file changed, 21 insertions(+), 9 deletions(-) diff --git a/src/connector/src/source/kinesis/enumerator/client.rs b/src/connector/src/source/kinesis/enumerator/client.rs index 840def08f685..423516fa5bd4 100644 --- a/src/connector/src/source/kinesis/enumerator/client.rs +++ b/src/connector/src/source/kinesis/enumerator/client.rs @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-use anyhow::Context as _; +use anyhow::anyhow; use async_trait::async_trait; use aws_sdk_kinesis::types::Shard; use aws_sdk_kinesis::Client as kinesis_client; @@ -52,14 +52,26 @@ impl SplitEnumerator for KinesisSplitEnumerator { let mut shard_collect: Vec = Vec::new(); loop { - let list_shard_output = self - .client - .list_shards() - .set_next_token(next_token) - .stream_name(&self.stream_name) - .send() - .await - .context("failed to list kinesis shards")?; + let mut req = self.client.list_shards(); + if let Some(token) = next_token.take() { + req = req.next_token(token); + } else { + req = req.stream_name(&self.stream_name); + } + + let list_shard_output = match req.send().await { + Ok(output) => output, + Err(e) => { + if let Some(e_inner) = e.as_service_error() + && e_inner.is_expired_next_token_exception() + { + tracing::info!("Kinesis ListShard token expired, retrying..."); + next_token = None; + continue; + } + return Err(anyhow!(e).context("failed to list kinesis shards").into()); + } + }; match list_shard_output.shards { Some(shard) => shard_collect.extend(shard), None => bail!("no shards in stream {}", &self.stream_name), From 42d51538453d3eb751619ec98874022d3239bdf5 Mon Sep 17 00:00:00 2001 From: Noel Kwan <47273164+kwannoel@users.noreply.github.com> Date: Wed, 17 Jul 2024 14:22:27 +0800 Subject: [PATCH 30/70] feat(ci): introduce main-cron bisect (#17596) --- ci/scripts/find-regression.py | 262 ++++++++++++++++++++++++++++++ ci/workflows/main-cron-bisect.yml | 12 ++ ci/workflows/main-cron.yml | 19 ++- docs/dev/src/ci.md | 20 +++ 4 files changed, 307 insertions(+), 6 deletions(-) create mode 100755 ci/scripts/find-regression.py create mode 100644 ci/workflows/main-cron-bisect.yml diff --git a/ci/scripts/find-regression.py b/ci/scripts/find-regression.py new file mode 100755 index 000000000000..e5106c9d892e --- /dev/null +++ b/ci/scripts/find-regression.py @@ -0,0 +1,262 @@ +#!/usr/bin/env python3 + +import subprocess +import unittest +import os +import sys + +''' +@kwannoel +This script is used to find the commit that introduced a regression in the codebase. +It uses binary search to find the regressed commit. +It works as follows: +1. Use the start (inclusive) and end (exclusive) bounds, find the middle commit. + e.g. given commit 0->1(start)->2->3->4(bad), start will be 1, end will be 4. Then the middle commit is (1+4)//2 = 2 + given commit 0->1(start)->2->3(bad)->4, start will be 1, end will be 3. Then the middle commit is (1+3)//2 = 2 + given commit 0->1(start)->2(bad), start will be 1, end will be 2. Then the middle commit is (1+2)//2 = 1. + given commit 0->1(start,bad), start will be 1, end will be 1. We just return the bad commit (1) immediately. +2. Run the pipeline on the middle commit. +3. If the pipeline fails, the regression is in the first half of the commits. Recurse (start, mid) +4. If the pipeline passes, the regression is in the second half of the commits. Recurse (mid+1, end) +5. If start>=end, return start as the regressed commit. + +We won't run the entire pipeline, only steps specified by the CI_STEPS environment variable. + +For step (2), we need to check its outcome and only run the next step, if the outcome is successful. 
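+
+Putting it together (an illustrative trace, using the same hypothetical commits as above):
+given 0->1(start)->2->3->4(bad), the middle commit is 2.
+- If run-2 fails, the bad bound moves to 2; the next middle commit is 1, and if run-1 passes,
+  the good bound moves to 2 == bad, so commit 2 is reported as the regressed commit.
+- If run-2 passes, the good bound moves to 3, and the bounds (3, 4) converge in the same way.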
+''' + + +def format_step(env): + commit = get_bisect_commit(env["GOOD_COMMIT"], env["BAD_COMMIT"]) + step = f''' +cat <<- YAML | buildkite-agent pipeline upload +steps: + - label: "run-{commit}" + key: "run-{commit}" + trigger: "main-cron" + soft_fail: true + build: + branch: {env["BISECT_BRANCH"]} + commit: {commit} + env: + CI_STEPS: {env['CI_STEPS']} + - wait + - label: 'check' + command: | + GOOD_COMMIT={env['GOOD_COMMIT']} BAD_COMMIT={env['BAD_COMMIT']} BISECT_BRANCH={env['BISECT_BRANCH']} CI_STEPS=\'{env['CI_STEPS']}\' ci/scripts/find-regression.py check +YAML''' + return step + + +def report_step(commit): + step = f''' +cat <<- YAML | buildkite-agent pipeline upload +steps: + - label: "Regressed Commit: {commit}" + command: "echo 'Regressed Commit: {commit}'" +YAML''' + print(f"--- reporting regression commit: {commit}") + result = subprocess.run(step, shell=True) + if result.returncode != 0: + print(f"stderr: {result.stderr}") + print(f"stdout: {result.stdout}") + sys.exit(1) + + +# Triggers a buildkite job to run the pipeline on the given commit, with the specified tests. +def run_pipeline(env): + step = format_step(env) + print(f"--- running upload pipeline for step\n{step}") + result = subprocess.run(step, shell=True) + if result.returncode != 0: + print(f"stderr: {result.stderr}") + print(f"stdout: {result.stdout}") + sys.exit(1) + + +# Number of commits for [start, end) +def get_number_of_commits(start, end): + cmd = f"git rev-list --count {start}..{end}" + result = subprocess.run([cmd], shell=True, capture_output=True, text=True) + if result.returncode != 0: + print(f"stderr: {result.stderr}") + print(f"stdout: {result.stdout}") + sys.exit(1) + return int(result.stdout) + + +def get_bisect_commit(start, end): + number_of_commits = get_number_of_commits(start, end) + commit_offset = number_of_commits // 2 + if commit_offset == 0: + return start + + cmd = f"git rev-list --reverse {start}..{end} | head -n {commit_offset} | tail -n 1" + result = subprocess.run([cmd], shell=True, capture_output=True, text=True) + if result.returncode != 0: + print(f"stderr: {result.stderr}") + print(f"stdout: {result.stdout}") + sys.exit(1) + return result.stdout.strip() + + +def get_commit_after(branch, commit): + cmd = f"git log --reverse --ancestry-path {commit}..origin/{branch} --format=%H | head -n 1" + result = subprocess.run([cmd], shell=True, capture_output=True, text=True) + + if result.returncode != 0: + print(f"stderr: {result.stderr}") + print(f"stdout: {result.stdout}") + sys.exit(1) + + return result.stdout.strip() + + +def get_env(): + env = { + "GOOD_COMMIT": os.environ['GOOD_COMMIT'], + "BAD_COMMIT": os.environ['BAD_COMMIT'], + "BISECT_BRANCH": os.environ['BISECT_BRANCH'], + "CI_STEPS": os.environ['CI_STEPS'], + } + + print(f''' +GOOD_COMMIT={env["GOOD_COMMIT"]} +BAD_COMMIT={env["BAD_COMMIT"]} +BISECT_BRANCH={env["BISECT_BRANCH"]} +CI_STEPS={env["CI_STEPS"]} + ''') + + return env + + +def fetch_branch_commits(branch): + cmd = f"git fetch -q origin {branch}" + result = subprocess.run([cmd], shell=True) + if result.returncode != 0: + print(f"stderr: {result.stderr}") + print(f"stdout: {result.stdout}") + sys.exit(1) + + +def main(): + cmd = sys.argv[1] + + if cmd == "start": + print("--- start bisecting") + env = get_env() + fetch_branch_commits(env["BISECT_BRANCH"]) + run_pipeline(env) + elif cmd == "check": + print("--- check pipeline outcome") + env = get_env() + fetch_branch_commits(env["BISECT_BRANCH"]) + commit = get_bisect_commit(env["GOOD_COMMIT"], env["BAD_COMMIT"]) + step 
= f"run-{commit}" + cmd = f"buildkite-agent step get outcome --step {step}" + outcome = subprocess.run(cmd, shell=True, capture_output=True, text=True) + + if outcome.returncode != 0: + print(f"stderr: {outcome.stderr}") + print(f"stdout: {outcome.stdout}") + sys.exit(1) + + outcome = outcome.stdout.strip() + if outcome == "soft_failed": + print(f"commit failed: {commit}") + env["BAD_COMMIT"] = commit + elif outcome == "passed": + print(f"commit passed: {commit}") + env["GOOD_COMMIT"] = get_commit_after(env["BISECT_BRANCH"], commit) + else: + print(f"invalid outcome: {outcome}") + sys.exit(1) + + if env["GOOD_COMMIT"] == env["BAD_COMMIT"]: + report_step(env["GOOD_COMMIT"]) + return + else: + print(f"run next iteration, start: {env['GOOD_COMMIT']}, end: {env['BAD_COMMIT']}") + run_pipeline(env) + else: + print(f"invalid cmd: {cmd}") + sys.exit(1) + + +# For the tests, we use RisingWave's sequence of commits, from earliest to latest: +# 617d23ddcac88ced87b96a2454c9217da0fe7915 +# 72f70960226680e841a8fbdd09c79d74609f27a2 +# 5c7b556ea60d136c5bccf1b1f7e313d2f9c79ef0 +# 9ca415a9998a5e04e021c899fb66d93a17931d4f +class Test(unittest.TestCase): + def test_get_commit_after(self): + fetch_branch_commits("kwannoel/find-regress") + commit = get_commit_after("kwannoel/find-regress", "72f70960226680e841a8fbdd09c79d74609f27a2") + self.assertEqual(commit, "5c7b556ea60d136c5bccf1b1f7e313d2f9c79ef0") + commit2 = get_commit_after("kwannoel/find-regress", "617d23ddcac88ced87b96a2454c9217da0fe7915") + self.assertEqual(commit2, "72f70960226680e841a8fbdd09c79d74609f27a2") + commit3 = get_commit_after("kwannoel/find-regress", "5c7b556ea60d136c5bccf1b1f7e313d2f9c79ef0") + self.assertEqual(commit3, "9ca415a9998a5e04e021c899fb66d93a17931d4f") + + def test_get_number_of_commits(self): + fetch_branch_commits("kwannoel/find-regress") + n = get_number_of_commits("72f70960226680e841a8fbdd09c79d74609f27a2", + "9ca415a9998a5e04e021c899fb66d93a17931d4f") + self.assertEqual(n, 2) + n2 = get_number_of_commits("617d23ddcac88ced87b96a2454c9217da0fe7915", + "9ca415a9998a5e04e021c899fb66d93a17931d4f") + self.assertEqual(n2, 3) + n3 = get_number_of_commits("72f70960226680e841a8fbdd09c79d74609f27a2", + "5c7b556ea60d136c5bccf1b1f7e313d2f9c79ef0") + self.assertEqual(n3, 1) + + def test_get_bisect_commit(self): + fetch_branch_commits("kwannoel/find-regress") + commit = get_bisect_commit("72f70960226680e841a8fbdd09c79d74609f27a2", + "9ca415a9998a5e04e021c899fb66d93a17931d4f") + self.assertEqual(commit, "5c7b556ea60d136c5bccf1b1f7e313d2f9c79ef0") + commit2 = get_bisect_commit("617d23ddcac88ced87b96a2454c9217da0fe7915", + "9ca415a9998a5e04e021c899fb66d93a17931d4f") + self.assertEqual(commit2, "72f70960226680e841a8fbdd09c79d74609f27a2") + commit3 = get_bisect_commit("72f70960226680e841a8fbdd09c79d74609f27a2", + "5c7b556ea60d136c5bccf1b1f7e313d2f9c79ef0") + self.assertEqual(commit3, "72f70960226680e841a8fbdd09c79d74609f27a2") + + def test_format_step(self): + fetch_branch_commits("kwannoel/find-regress") + self.maxDiff = None + env = { + "GOOD_COMMIT": "72f70960226680e841a8fbdd09c79d74609f27a2", + "BAD_COMMIT": "9ca415a9998a5e04e021c899fb66d93a17931d4f", + "BISECT_BRANCH": "kwannoel/find-regress", + "CI_STEPS": "test" + } + step = format_step(env) + self.assertEqual( + step, + ''' +cat <<- YAML | buildkite-agent pipeline upload +steps: + - label: "run-5c7b556ea60d136c5bccf1b1f7e313d2f9c79ef0" + key: "run-5c7b556ea60d136c5bccf1b1f7e313d2f9c79ef0" + trigger: "main-cron" + soft_fail: true + build: + branch: kwannoel/find-regress + commit: 
5c7b556ea60d136c5bccf1b1f7e313d2f9c79ef0 + env: + CI_STEPS: test + - wait + - label: 'check' + command: | + GOOD_COMMIT=72f70960226680e841a8fbdd09c79d74609f27a2 BAD_COMMIT=9ca415a9998a5e04e021c899fb66d93a17931d4f BISECT_BRANCH=kwannoel/find-regress CI_STEPS='test' ci/scripts/find-regression.py check +YAML''' + ) + + +if __name__ == "__main__": + # You can run tests by just doing ./ci/scripts/find-regression.py + if len(sys.argv) == 1: + unittest.main() + else: + main() diff --git a/ci/workflows/main-cron-bisect.yml b/ci/workflows/main-cron-bisect.yml new file mode 100644 index 000000000000..ab8cebd234d7 --- /dev/null +++ b/ci/workflows/main-cron-bisect.yml @@ -0,0 +1,12 @@ +auto-retry: &auto-retry + automatic: + # Agent terminated because the AWS EC2 spot instance killed by AWS. + - signal_reason: agent_stop + limit: 3 + +steps: + - label: "find regressed step" + key: "find-regressed-step" + command: "GOOD_COMMIT=$GOOD_COMMIT BAD_COMMIT=$BAD_COMMIT BISECT_BRANCH=$BISECT_BRANCH CI_STEPS=$CI_STEPS ci/scripts/find-regression.py start" + if: build.env("CI_STEPS") != null + retry: *auto-retry diff --git a/ci/workflows/main-cron.yml b/ci/workflows/main-cron.yml index 0060335d8504..3c71be0f0984 100644 --- a/ci/workflows/main-cron.yml +++ b/ci/workflows/main-cron.yml @@ -8,6 +8,8 @@ steps: - label: "build" command: "ci/scripts/build.sh -p ci-release" key: "build" + if: | + build.env("CI_STEPS") !~ /(^|,)disable-build(,|$$)/ plugins: - docker-compose#v5.1.0: run: rw-build-env @@ -19,6 +21,8 @@ steps: - label: "build other components" command: "ci/scripts/build-other.sh" key: "build-other" + if: | + build.env("CI_STEPS") !~ /(^|,)disable-build(,|$$)/ plugins: - seek-oss/aws-sm#v2.3.1: env: @@ -35,6 +39,8 @@ steps: - label: "build simulation test" command: "ci/scripts/build-simulation.sh" key: "build-simulation" + if: | + build.env("CI_STEPS") !~ /(^|,)disable-build(,|$$)/ plugins: - docker-compose#v5.1.0: run: rw-build-env @@ -46,6 +52,8 @@ steps: - label: "docslt" command: "ci/scripts/docslt.sh" key: "docslt" + if: | + build.env("CI_STEPS") !~ /(^|,)disable-build(,|$$)/ plugins: - docker-compose#v5.1.0: run: rw-build-env @@ -649,8 +657,7 @@ steps: - label: "upload micro-benchmark" key: "upload-micro-benchmarks" if: | - build.branch == "main" - || !(build.pull_request.labels includes "ci/main-cron/run-selected") && build.env("CI_STEPS") == null + !(build.pull_request.labels includes "ci/main-cron/run-selected") && build.env("CI_STEPS") == null || build.pull_request.labels includes "ci/run-micro-benchmarks" || build.env("CI_STEPS") =~ /(^|,)micro-benchmarks?(,|$$)/ command: @@ -993,7 +1000,7 @@ steps: key: "e2e-mongodb-sink-tests" command: "ci/scripts/e2e-mongodb-sink-test.sh -p ci-release" if: | - !(build.pull_request.labels includes "ci/main-cron/skip-ci") && build.env("CI_STEPS") == null + !(build.pull_request.labels includes "ci/main-cron/run-selected") && build.env("CI_STEPS") == null || build.pull_request.labels includes "ci/run-e2e-mongodb-sink-tests" || build.env("CI_STEPS") =~ /(^|,)e2e-mongodb-sink-tests?(,|$$)/ depends_on: @@ -1119,13 +1126,13 @@ steps: # Notification test. - key: "test-notify" - if: build.pull_request.labels includes "ci/main-cron/test-notify" + if: build.pull_request.labels includes "ci/main-cron/test-notify" || build.env("CI_STEPS") =~ /(^|,)test_notify(,|$$)/ command: | bash -c 'echo test && exit -1' # Notification test. 
- key: "test-notify-2" - if: build.pull_request.labels includes "ci/main-cron/test-notify" + if: build.pull_request.labels includes "ci/main-cron/test-notify" || build.env("CI_STEPS") =~ /(^|,)test_notify(,|$$)/ command: | bash -c 'echo test && exit -1' @@ -1138,4 +1145,4 @@ steps: # This should be the LAST part of the main-cron file. - label: "trigger failed test notification" if: build.pull_request.labels includes "ci/main-cron/test-notify" || build.branch == "main" - command: "ci/scripts/notify.py" + command: "ci/scripts/notify.py" \ No newline at end of file diff --git a/docs/dev/src/ci.md b/docs/dev/src/ci.md index 840173766055..0f12d893acae 100644 --- a/docs/dev/src/ci.md +++ b/docs/dev/src/ci.md @@ -19,3 +19,23 @@ To run `e2e-test` and `e2e-source-test` for `main-cron` in your pull request: 1. Add `ci/run-e2e-test`. 2. Add `ci/run-e2e-source-tests`. 3. Add `ci/main-cron/run-selected` to skip all other steps which were not selected with `ci/run-xxx`. + +## Main Cron Bisect Guide + +1. Create a new build via buildkite: https://buildkite.com/risingwavelabs/main-cron-bisect/builds/#new +2. Add the following environment variables: + - `GOOD_COMMIT`: The good commit hash. + - `BAD_COMMIT`: The bad commit hash. + - `BISECT_BRANCH`: The branch name where the bisect will be performed. + - `CI_STEPS`: The `CI_STEPS` to run during the bisect. Separate multiple steps with a comma. + - You can check the labels for this in `main-cron.yml`, under the conditions for each step. + +Example you can try on [buildkite](https://buildkite.com/risingwavelabs/main-cron-bisect/builds/#new): +- Branch: `kwannoel/find-regress` +- Environment variables: + ``` + START_COMMIT=29791ddf16fdf2c2e83ad3a58215f434e610f89a + END_COMMIT=7f36bf17c1d19a1e6b2cdb90491d3c08ae8b0004 + BISECT_BRANCH=kwannoel/test-bisect + CI_STEPS="test-bisect,disable-build" + ``` \ No newline at end of file From 30fd4d8ab8995142eed9b185d769791e20e2bdb1 Mon Sep 17 00:00:00 2001 From: Yuhao Su <31772373+yuhao-su@users.noreply.github.com> Date: Wed, 17 Jul 2024 01:52:56 -0500 Subject: [PATCH 31/70] feat(sink): support encode jsonb data as dynamic json type in sink (#17693) --- src/connector/src/sink/encoder/json.rs | 44 ++++++++++++++++--- src/connector/src/sink/encoder/mod.rs | 25 +++++++++++ .../src/sink/formatter/debezium_json.rs | 7 ++- src/connector/src/sink/formatter/mod.rs | 5 ++- src/connector/src/sink/kafka.rs | 3 +- src/connector/src/sink/mqtt.rs | 7 +-- src/connector/src/sink/nats.rs | 5 ++- src/connector/src/sink/snowflake.rs | 4 +- 8 files changed, 85 insertions(+), 15 deletions(-) diff --git a/src/connector/src/sink/encoder/json.rs b/src/connector/src/sink/encoder/json.rs index d657b8b9c585..3652f38bacbb 100644 --- a/src/connector/src/sink/encoder/json.rs +++ b/src/connector/src/sink/encoder/json.rs @@ -30,8 +30,8 @@ use serde_json::{json, Map, Value}; use thiserror_ext::AsReport; use super::{ - CustomJsonType, DateHandlingMode, KafkaConnectParams, KafkaConnectParamsRef, Result, - RowEncoder, SerTo, TimeHandlingMode, TimestampHandlingMode, TimestamptzHandlingMode, + CustomJsonType, DateHandlingMode, JsonbHandlingMode, KafkaConnectParams, KafkaConnectParamsRef, + Result, RowEncoder, SerTo, TimeHandlingMode, TimestampHandlingMode, TimestamptzHandlingMode, }; use crate::sink::SinkError; @@ -41,6 +41,7 @@ pub struct JsonEncoderConfig { timestamp_handling_mode: TimestampHandlingMode, timestamptz_handling_mode: TimestamptzHandlingMode, custom_json_type: CustomJsonType, + jsonb_handling_mode: JsonbHandlingMode, } pub struct JsonEncoder { 
@@ -58,6 +59,7 @@ impl JsonEncoder { timestamp_handling_mode: TimestampHandlingMode, timestamptz_handling_mode: TimestamptzHandlingMode, time_handling_mode: TimeHandlingMode, + jsonb_handling_mode: JsonbHandlingMode, ) -> Self { let config = JsonEncoderConfig { time_handling_mode, @@ -65,6 +67,7 @@ impl JsonEncoder { timestamp_handling_mode, timestamptz_handling_mode, custom_json_type: CustomJsonType::None, + jsonb_handling_mode, }; Self { schema, @@ -81,6 +84,7 @@ impl JsonEncoder { timestamp_handling_mode: TimestampHandlingMode::String, timestamptz_handling_mode: TimestamptzHandlingMode::UtcWithoutSuffix, custom_json_type: CustomJsonType::Es, + jsonb_handling_mode: JsonbHandlingMode::Dynamic, }; Self { schema, @@ -101,6 +105,7 @@ impl JsonEncoder { timestamp_handling_mode: TimestampHandlingMode::String, timestamptz_handling_mode: TimestamptzHandlingMode::UtcWithoutSuffix, custom_json_type: CustomJsonType::Doris(map), + jsonb_handling_mode: JsonbHandlingMode::String, }; Self { schema, @@ -117,6 +122,7 @@ impl JsonEncoder { timestamp_handling_mode: TimestampHandlingMode::String, timestamptz_handling_mode: TimestamptzHandlingMode::UtcWithoutSuffix, custom_json_type: CustomJsonType::StarRocks, + jsonb_handling_mode: JsonbHandlingMode::Dynamic, }; Self { schema, @@ -289,11 +295,12 @@ fn datum_to_json_object( (DataType::Interval, ScalarRefImpl::Interval(v)) => { json!(v.as_iso_8601()) } - (DataType::Jsonb, ScalarRefImpl::Jsonb(jsonb_ref)) => match &config.custom_json_type { - CustomJsonType::Es | CustomJsonType::StarRocks => JsonbVal::from(jsonb_ref).take(), - CustomJsonType::Doris(_) | CustomJsonType::None => { + + (DataType::Jsonb, ScalarRefImpl::Jsonb(jsonb_ref)) => match config.jsonb_handling_mode { + JsonbHandlingMode::String => { json!(jsonb_ref.to_string()) } + JsonbHandlingMode::Dynamic => JsonbVal::from(jsonb_ref).take(), }, (DataType::List(datatype), ScalarRefImpl::List(list_ref)) => { let elems = list_ref.iter(); @@ -446,6 +453,7 @@ mod tests { timestamp_handling_mode: TimestampHandlingMode::String, timestamptz_handling_mode: TimestamptzHandlingMode::UtcString, custom_json_type: CustomJsonType::None, + jsonb_handling_mode: JsonbHandlingMode::String, }; let boolean_value = datum_to_json_object( @@ -517,6 +525,7 @@ mod tests { timestamp_handling_mode: TimestampHandlingMode::String, timestamptz_handling_mode: TimestamptzHandlingMode::UtcWithoutSuffix, custom_json_type: CustomJsonType::None, + jsonb_handling_mode: JsonbHandlingMode::String, }; let tstz_inner = "2018-01-26T18:30:09.453Z".parse().unwrap(); @@ -537,6 +546,7 @@ mod tests { timestamp_handling_mode: TimestampHandlingMode::Milli, timestamptz_handling_mode: TimestamptzHandlingMode::UtcString, custom_json_type: CustomJsonType::None, + jsonb_handling_mode: JsonbHandlingMode::String, }; let ts_value = datum_to_json_object( &Field { @@ -603,6 +613,7 @@ mod tests { timestamp_handling_mode: TimestampHandlingMode::String, timestamptz_handling_mode: TimestamptzHandlingMode::UtcString, custom_json_type: CustomJsonType::Doris(map), + jsonb_handling_mode: JsonbHandlingMode::String, }; let decimal = datum_to_json_object( &Field { @@ -628,11 +639,12 @@ mod tests { assert_eq!(date_value, json!(719163)); let from_epoch_config = JsonEncoderConfig { - time_handling_mode: TimeHandlingMode::Milli, + time_handling_mode: TimeHandlingMode::String, date_handling_mode: DateHandlingMode::FromEpoch, timestamp_handling_mode: TimestampHandlingMode::String, timestamptz_handling_mode: TimestamptzHandlingMode::UtcString, custom_json_type: 
CustomJsonType::None, + jsonb_handling_mode: JsonbHandlingMode::String, }; let date_value = datum_to_json_object( &Field { @@ -651,6 +663,7 @@ mod tests { timestamp_handling_mode: TimestampHandlingMode::String, timestamptz_handling_mode: TimestamptzHandlingMode::UtcString, custom_json_type: CustomJsonType::Doris(HashMap::default()), + jsonb_handling_mode: JsonbHandlingMode::String, }; let date_value = datum_to_json_object( &Field { @@ -683,6 +696,25 @@ mod tests { ) .unwrap(); assert_eq!(interval_value, json!("{\"v3\":3,\"v2\":2,\"v1\":1}")); + + let encode_jsonb_obj_config = JsonEncoderConfig { + time_handling_mode: TimeHandlingMode::String, + date_handling_mode: DateHandlingMode::String, + timestamp_handling_mode: TimestampHandlingMode::String, + timestamptz_handling_mode: TimestamptzHandlingMode::UtcString, + custom_json_type: CustomJsonType::None, + jsonb_handling_mode: JsonbHandlingMode::Dynamic, + }; + let json_value = datum_to_json_object( + &Field { + data_type: DataType::Jsonb, + ..mock_field.clone() + }, + Some(ScalarImpl::Jsonb(JsonbVal::from(json!([1, 2, 3]))).as_scalar_ref_impl()), + &encode_jsonb_obj_config, + ) + .unwrap(); + assert_eq!(json_value, json!([1, 2, 3])); } #[test] diff --git a/src/connector/src/sink/encoder/mod.rs b/src/connector/src/sink/encoder/mod.rs index 889d0162784b..0a8a9e5abce7 100644 --- a/src/connector/src/sink/encoder/mod.rs +++ b/src/connector/src/sink/encoder/mod.rs @@ -150,6 +150,31 @@ pub enum CustomJsonType { None, } +/// How the jsonb type is encoded. +/// +/// - `String`: encode jsonb as string. `[1, true, "foo"] -> "[1, true, \"foo\"]"` +/// - `Dynamic`: encode jsonb as json type dynamically. `[1, true, "foo"] -> [1, true, "foo"]` +pub enum JsonbHandlingMode { + String, + Dynamic, +} + +impl JsonbHandlingMode { + pub const OPTION_KEY: &'static str = "jsonb.handling.mode"; + + pub fn from_options(options: &BTreeMap) -> Result { + match options.get(Self::OPTION_KEY).map(std::ops::Deref::deref) { + Some("string") | None => Ok(Self::String), + Some("dynamic") => Ok(Self::Dynamic), + Some(v) => Err(super::SinkError::Config(anyhow::anyhow!( + "unrecognized {} value {}", + Self::OPTION_KEY, + v + ))), + } + } +} + #[derive(Debug)] struct FieldEncodeError { message: String, diff --git a/src/connector/src/sink/formatter/debezium_json.rs b/src/connector/src/sink/formatter/debezium_json.rs index 6fff15058bd6..fd4813e78541 100644 --- a/src/connector/src/sink/formatter/debezium_json.rs +++ b/src/connector/src/sink/formatter/debezium_json.rs @@ -21,8 +21,8 @@ use tracing::warn; use super::{Result, SinkFormatter, StreamChunk}; use crate::sink::encoder::{ - DateHandlingMode, JsonEncoder, RowEncoder, TimeHandlingMode, TimestampHandlingMode, - TimestamptzHandlingMode, + DateHandlingMode, JsonEncoder, JsonbHandlingMode, RowEncoder, TimeHandlingMode, + TimestampHandlingMode, TimestamptzHandlingMode, }; use crate::tri; @@ -69,6 +69,7 @@ impl DebeziumJsonFormatter { TimestampHandlingMode::Milli, TimestamptzHandlingMode::UtcString, TimeHandlingMode::Milli, + JsonbHandlingMode::String, ); let val_encoder = JsonEncoder::new( schema.clone(), @@ -77,6 +78,7 @@ impl DebeziumJsonFormatter { TimestampHandlingMode::Milli, TimestamptzHandlingMode::UtcString, TimeHandlingMode::Milli, + JsonbHandlingMode::String, ); Self { schema, @@ -397,6 +399,7 @@ mod tests { TimestampHandlingMode::Milli, TimestamptzHandlingMode::UtcString, TimeHandlingMode::Milli, + JsonbHandlingMode::String, ); let json_chunk = chunk_to_json(chunk, &encoder).unwrap(); let schema_json = 
schema_to_json(&schema, "test_db", "test_table"); diff --git a/src/connector/src/sink/formatter/mod.rs b/src/connector/src/sink/formatter/mod.rs index 4628a925da98..b2e93cba763e 100644 --- a/src/connector/src/sink/formatter/mod.rs +++ b/src/connector/src/sink/formatter/mod.rs @@ -31,7 +31,8 @@ use super::catalog::{SinkEncode, SinkFormat, SinkFormatDesc}; use super::encoder::template::TemplateEncoder; use super::encoder::text::TextEncoder; use super::encoder::{ - DateHandlingMode, KafkaConnectParams, TimeHandlingMode, TimestamptzHandlingMode, + DateHandlingMode, JsonbHandlingMode, KafkaConnectParams, TimeHandlingMode, + TimestamptzHandlingMode, }; use super::redis::{KEY_FORMAT, VALUE_FORMAT}; use crate::sink::encoder::{ @@ -113,6 +114,7 @@ pub trait EncoderBuild: Sized { impl EncoderBuild for JsonEncoder { async fn build(b: EncoderParams<'_>, pk_indices: Option>) -> Result { let timestamptz_mode = TimestamptzHandlingMode::from_options(&b.format_desc.options)?; + let jsonb_handling_mode = JsonbHandlingMode::from_options(&b.format_desc.options)?; let encoder = JsonEncoder::new( b.schema, pk_indices, @@ -120,6 +122,7 @@ impl EncoderBuild for JsonEncoder { TimestampHandlingMode::Milli, timestamptz_mode, TimeHandlingMode::Milli, + jsonb_handling_mode, ); let encoder = if let Some(s) = b.format_desc.options.get("schemas.enable") { match s.to_lowercase().parse::() { diff --git a/src/connector/src/sink/kafka.rs b/src/connector/src/sink/kafka.rs index 617f427ae71f..c6e65fb00c39 100644 --- a/src/connector/src/sink/kafka.rs +++ b/src/connector/src/sink/kafka.rs @@ -590,7 +590,7 @@ mod test { use super::*; use crate::sink::encoder::{ - DateHandlingMode, JsonEncoder, TimeHandlingMode, TimestampHandlingMode, + DateHandlingMode, JsonEncoder, JsonbHandlingMode, TimeHandlingMode, TimestampHandlingMode, TimestamptzHandlingMode, }; use crate::sink::formatter::AppendOnlyFormatter; @@ -778,6 +778,7 @@ mod test { TimestampHandlingMode::Milli, TimestamptzHandlingMode::UtcString, TimeHandlingMode::Milli, + JsonbHandlingMode::String, ), )), ) diff --git a/src/connector/src/sink/mqtt.rs b/src/connector/src/sink/mqtt.rs index 9e86b5ec97e2..072666b01564 100644 --- a/src/connector/src/sink/mqtt.rs +++ b/src/connector/src/sink/mqtt.rs @@ -31,8 +31,8 @@ use with_options::WithOptions; use super::catalog::{SinkEncode, SinkFormat, SinkFormatDesc}; use super::encoder::{ - DateHandlingMode, JsonEncoder, ProtoEncoder, ProtoHeader, RowEncoder, SerTo, TimeHandlingMode, - TimestampHandlingMode, TimestamptzHandlingMode, + DateHandlingMode, JsonEncoder, JsonbHandlingMode, ProtoEncoder, ProtoHeader, RowEncoder, SerTo, + TimeHandlingMode, TimestampHandlingMode, TimestamptzHandlingMode, }; use super::writer::AsyncTruncateSinkWriterExt; use super::{DummySinkCommitCoordinator, SinkWriterParam}; @@ -221,7 +221,7 @@ impl MqttSinkWriter { } let timestamptz_mode = TimestamptzHandlingMode::from_options(&format_desc.options)?; - + let jsonb_handling_mode = JsonbHandlingMode::from_options(&format_desc.options)?; let encoder = match format_desc.format { SinkFormat::AppendOnly => match format_desc.encode { SinkEncode::Json => RowEncoderWrapper::Json(JsonEncoder::new( @@ -231,6 +231,7 @@ impl MqttSinkWriter { TimestampHandlingMode::Milli, timestamptz_mode, TimeHandlingMode::Milli, + jsonb_handling_mode, )), SinkEncode::Protobuf => { let (descriptor, sid) = crate::schema::protobuf::fetch_descriptor( diff --git a/src/connector/src/sink/nats.rs b/src/connector/src/sink/nats.rs index 471ce6129841..aee0a5fbc723 100644 --- 
a/src/connector/src/sink/nats.rs +++ b/src/connector/src/sink/nats.rs @@ -28,7 +28,9 @@ use tokio_retry::strategy::{jitter, ExponentialBackoff}; use tokio_retry::Retry; use with_options::WithOptions; -use super::encoder::{DateHandlingMode, TimeHandlingMode, TimestamptzHandlingMode}; +use super::encoder::{ + DateHandlingMode, JsonbHandlingMode, TimeHandlingMode, TimestamptzHandlingMode, +}; use super::utils::chunk_to_json; use super::{DummySinkCommitCoordinator, SinkWriterParam}; use crate::connector_common::NatsCommon; @@ -151,6 +153,7 @@ impl NatsSinkWriter { TimestampHandlingMode::Milli, TimestamptzHandlingMode::UtcWithoutSuffix, TimeHandlingMode::Milli, + JsonbHandlingMode::String, ), }) } diff --git a/src/connector/src/sink/snowflake.rs b/src/connector/src/sink/snowflake.rs index 1c9d67352247..6c3cc291f58e 100644 --- a/src/connector/src/sink/snowflake.rs +++ b/src/connector/src/sink/snowflake.rs @@ -34,7 +34,8 @@ use uuid::Uuid; use with_options::WithOptions; use super::encoder::{ - JsonEncoder, RowEncoder, TimeHandlingMode, TimestampHandlingMode, TimestamptzHandlingMode, + JsonEncoder, JsonbHandlingMode, RowEncoder, TimeHandlingMode, TimestampHandlingMode, + TimestamptzHandlingMode, }; use super::writer::LogSinkerOf; use super::{SinkError, SinkParam}; @@ -193,6 +194,7 @@ impl SnowflakeSinkWriter { TimestampHandlingMode::String, TimestamptzHandlingMode::UtcString, TimeHandlingMode::String, + JsonbHandlingMode::String, ), // initial value of `epoch` will be set to 0 epoch: 0, From 46057734a9bf49631f56c466fc8ade5c239530f1 Mon Sep 17 00:00:00 2001 From: Li0k Date: Wed, 17 Jul 2024 17:33:56 +0800 Subject: [PATCH 32/70] fix(storage): fix the trivial-move loop caused by config and pick_whole_level (#17721) --- proto/hummock.proto | 2 +- .../hummock/compaction/compaction_config.rs | 2 +- src/meta/src/hummock/compaction/mod.rs | 7 ++-- .../picker/base_level_compaction_picker.rs | 5 ++- .../picker/compaction_task_validator.rs | 22 +++++++++++-- .../picker/intra_compaction_picker.rs | 33 +++++++++++++++++-- .../manager/compaction_group_manager.rs | 2 +- 7 files changed, 58 insertions(+), 15 deletions(-) diff --git a/proto/hummock.proto b/proto/hummock.proto index e19faee10c43..412f552fec5c 100644 --- a/proto/hummock.proto +++ b/proto/hummock.proto @@ -905,7 +905,7 @@ message CompactionConfig { bool enable_emergency_picker = 20; // The limitation of the level count of l0 compaction - uint32 max_l0_compact_level_count = 21; + optional uint32 max_l0_compact_level_count = 21; } message TableStats { diff --git a/src/meta/src/hummock/compaction/compaction_config.rs b/src/meta/src/hummock/compaction/compaction_config.rs index de91bf4f79de..798500e980b6 100644 --- a/src/meta/src/hummock/compaction/compaction_config.rs +++ b/src/meta/src/hummock/compaction/compaction_config.rs @@ -64,7 +64,7 @@ impl CompactionConfigBuilder { compaction_config::level0_overlapping_sub_level_compact_level_count(), tombstone_reclaim_ratio: compaction_config::tombstone_reclaim_ratio(), enable_emergency_picker: compaction_config::enable_emergency_picker(), - max_l0_compact_level_count: compaction_config::max_l0_compact_level_count(), + max_l0_compact_level_count: Some(compaction_config::max_l0_compact_level_count()), }, } } diff --git a/src/meta/src/hummock/compaction/mod.rs b/src/meta/src/hummock/compaction/mod.rs index bf4b608fe59a..b2a386011702 100644 --- a/src/meta/src/hummock/compaction/mod.rs +++ b/src/meta/src/hummock/compaction/mod.rs @@ -26,7 +26,7 @@ use std::fmt::{Debug, Formatter}; use std::sync::Arc; use 
picker::{LevelCompactionPicker, TierCompactionPicker}; -use risingwave_hummock_sdk::{can_concat, CompactionGroupId, HummockCompactionTaskId}; +use risingwave_hummock_sdk::{CompactionGroupId, HummockCompactionTaskId}; use risingwave_pb::hummock::compaction_config::CompactionMode; use risingwave_pb::hummock::hummock_version::Levels; use risingwave_pb::hummock::{CompactTask, CompactionConfig, LevelType}; @@ -140,10 +140,7 @@ impl CompactStatus { return false; } - if task.input_ssts.len() == 1 { - return task.input_ssts[0].level_idx == 0 - && can_concat(&task.input_ssts[0].table_infos); - } else if task.input_ssts.len() != 2 + if task.input_ssts.len() != 2 || task.input_ssts[0].level_type() != LevelType::Nonoverlapping { return false; diff --git a/src/meta/src/hummock/compaction/picker/base_level_compaction_picker.rs b/src/meta/src/hummock/compaction/picker/base_level_compaction_picker.rs index bf05afc6c6e8..62baa16a640e 100644 --- a/src/meta/src/hummock/compaction/picker/base_level_compaction_picker.rs +++ b/src/meta/src/hummock/compaction/picker/base_level_compaction_picker.rs @@ -16,6 +16,7 @@ use std::cell::RefCell; use std::sync::Arc; use itertools::Itertools; +use risingwave_common::config::default::compaction_config; use risingwave_hummock_sdk::compaction_group::hummock_version_ext::HummockLevelsExt; use risingwave_pb::hummock::hummock_version::Levels; use risingwave_pb::hummock::{CompactionConfig, InputLevel, Level, LevelType, OverlappingLevel}; @@ -166,7 +167,9 @@ impl LevelCompactionPicker { self.config.level0_max_compact_file_number, overlap_strategy.clone(), self.developer_config.enable_check_task_level_overlap, - self.config.max_l0_compact_level_count as usize, + self.config + .max_l0_compact_level_count + .unwrap_or(compaction_config::max_l0_compact_level_count()) as usize, ); let mut max_vnode_partition_idx = 0; diff --git a/src/meta/src/hummock/compaction/picker/compaction_task_validator.rs b/src/meta/src/hummock/compaction/picker/compaction_task_validator.rs index c7dd27a6b190..8bdb10c213d1 100644 --- a/src/meta/src/hummock/compaction/picker/compaction_task_validator.rs +++ b/src/meta/src/hummock/compaction/picker/compaction_task_validator.rs @@ -15,6 +15,7 @@ use std::collections::HashMap; use std::sync::Arc; +use risingwave_common::config::default::compaction_config; use risingwave_pb::hummock::CompactionConfig; use super::{CompactionInput, LocalPickerStatistic}; @@ -90,7 +91,12 @@ struct TierCompactionTaskValidationRule { impl CompactionTaskValidationRule for TierCompactionTaskValidationRule { fn validate(&self, input: &CompactionInput, stats: &mut LocalPickerStatistic) -> bool { if input.total_file_count >= self.config.level0_max_compact_file_number - || input.input_levels.len() >= self.config.max_l0_compact_level_count as usize + || input.input_levels.len() + >= self + .config + .max_l0_compact_level_count + .unwrap_or(compaction_config::max_l0_compact_level_count()) + as usize { return true; } @@ -124,7 +130,12 @@ impl CompactionTaskValidationRule for IntraCompactionTaskValidationRule { fn validate(&self, input: &CompactionInput, stats: &mut LocalPickerStatistic) -> bool { if (input.total_file_count >= self.config.level0_max_compact_file_number && input.input_levels.len() > 1) - || input.input_levels.len() >= self.config.max_l0_compact_level_count as usize + || input.input_levels.len() + >= self + .config + .max_l0_compact_level_count + .unwrap_or(compaction_config::max_l0_compact_level_count()) + as usize { return true; } @@ -172,7 +183,12 @@ struct 
BaseCompactionTaskValidationRule { impl CompactionTaskValidationRule for BaseCompactionTaskValidationRule { fn validate(&self, input: &CompactionInput, stats: &mut LocalPickerStatistic) -> bool { if input.total_file_count >= self.config.level0_max_compact_file_number - || input.input_levels.len() >= self.config.max_l0_compact_level_count as usize + || input.input_levels.len() + >= self + .config + .max_l0_compact_level_count + .unwrap_or(compaction_config::max_l0_compact_level_count()) + as usize { return true; } diff --git a/src/meta/src/hummock/compaction/picker/intra_compaction_picker.rs b/src/meta/src/hummock/compaction/picker/intra_compaction_picker.rs index 5cc65bd38a1c..1261cca55089 100644 --- a/src/meta/src/hummock/compaction/picker/intra_compaction_picker.rs +++ b/src/meta/src/hummock/compaction/picker/intra_compaction_picker.rs @@ -14,6 +14,7 @@ use std::sync::Arc; +use risingwave_common::config::default::compaction_config; use risingwave_pb::hummock::hummock_version::Levels; use risingwave_pb::hummock::{CompactionConfig, InputLevel, LevelType, OverlappingLevel}; @@ -54,10 +55,33 @@ impl CompactionPicker for IntraCompactionPicker { if let Some(ret) = self.pick_whole_level(l0, &level_handlers[0], vnode_partition_count, stats) { + if ret.input_levels.len() < 2 { + tracing::error!( + ?ret, + vnode_partition_count, + "pick_whole_level failed to pick enough levels" + ); + return None; + } + return Some(ret); } - self.pick_l0_intra(l0, &level_handlers[0], vnode_partition_count, stats) + if let Some(ret) = self.pick_l0_intra(l0, &level_handlers[0], vnode_partition_count, stats) + { + if ret.input_levels.len() < 2 { + tracing::error!( + ?ret, + vnode_partition_count, + "pick_l0_intra failed to pick enough levels" + ); + return None; + } + + return Some(ret); + } + + None } } @@ -144,7 +168,10 @@ impl IntraCompactionPicker { self.config.level0_max_compact_file_number, overlap_strategy.clone(), self.developer_config.enable_check_task_level_overlap, - self.config.max_l0_compact_level_count as usize, + self.config + .max_l0_compact_level_count + .unwrap_or(compaction_config::max_l0_compact_level_count()) + as usize, ); let l0_select_tables_vec = non_overlap_sub_level_picker @@ -357,7 +384,7 @@ impl WholeLevelCompactionPicker { table_infos: next_level.table_infos.clone(), }); } - if !select_level_inputs.is_empty() { + if select_level_inputs.len() > 1 { let vnode_partition_count = if select_input_size > self.config.sub_level_max_compaction_bytes / 2 { partition_count diff --git a/src/meta/src/hummock/manager/compaction_group_manager.rs b/src/meta/src/hummock/manager/compaction_group_manager.rs index 6fc637148e17..ae2611c70be0 100644 --- a/src/meta/src/hummock/manager/compaction_group_manager.rs +++ b/src/meta/src/hummock/manager/compaction_group_manager.rs @@ -761,7 +761,7 @@ fn update_compaction_config(target: &mut CompactionConfig, items: &[MutableConfi .clone_from(&c.compression_algorithm); } MutableConfig::MaxL0CompactLevelCount(c) => { - target.max_l0_compact_level_count = *c; + target.max_l0_compact_level_count = Some(*c); } } } From e8273cac1178edfcc928ae1ddb3efcc63fe9f99c Mon Sep 17 00:00:00 2001 From: Noel Kwan <47273164+kwannoel@users.noreply.github.com> Date: Wed, 17 Jul 2024 17:38:34 +0800 Subject: [PATCH 33/70] doc: fix `main-cron-bisect` docs (#17720) --- docs/dev/src/ci.md | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/docs/dev/src/ci.md b/docs/dev/src/ci.md index 0f12d893acae..1a34374cc324 100644 --- a/docs/dev/src/ci.md +++ b/docs/dev/src/ci.md @@ 
-22,20 +22,20 @@ To run `e2e-test` and `e2e-source-test` for `main-cron` in your pull request: ## Main Cron Bisect Guide -1. Create a new build via buildkite: https://buildkite.com/risingwavelabs/main-cron-bisect/builds/#new +1. Create a new build via [buildkite](https://buildkite.com/risingwavelabs/main-cron-bisect/builds/#new) 2. Add the following environment variables: - `GOOD_COMMIT`: The good commit hash. - `BAD_COMMIT`: The bad commit hash. - `BISECT_BRANCH`: The branch name where the bisect will be performed. - `CI_STEPS`: The `CI_STEPS` to run during the bisect. Separate multiple steps with a comma. - - You can check the labels for this in `main-cron.yml`, under the conditions for each step. + - You can check the labels for this in [main-cron.yml](https://github.com/risingwavelabs/risingwave/blob/main/ci/workflows/main-cron.yml), + under the conditions for each step. Example you can try on [buildkite](https://buildkite.com/risingwavelabs/main-cron-bisect/builds/#new): -- Branch: `kwannoel/find-regress` - Environment variables: ``` - START_COMMIT=29791ddf16fdf2c2e83ad3a58215f434e610f89a - END_COMMIT=7f36bf17c1d19a1e6b2cdb90491d3c08ae8b0004 + GOOD_COMMIT=29791ddf16fdf2c2e83ad3a58215f434e610f89a + BAD_COMMIT=7f36bf17c1d19a1e6b2cdb90491d3c08ae8b0004 BISECT_BRANCH=kwannoel/test-bisect CI_STEPS="test-bisect,disable-build" ``` \ No newline at end of file From 6e2c82f9d85c4650121638bd6f40cca21a2eb925 Mon Sep 17 00:00:00 2001 From: Richard Chien Date: Wed, 17 Jul 2024 17:53:57 +0800 Subject: [PATCH 34/70] feat(optimizer): add `columns_monotonicity` field for PlanNode (#17600) Signed-off-by: Richard Chien --- .../optimizer/plan_node/generic/cdc_scan.rs | 6 +- .../src/optimizer/plan_node/logical_source.rs | 3 +- src/frontend/src/optimizer/plan_node/mod.rs | 6 +- .../src/optimizer/plan_node/plan_base.rs | 18 ++++ .../src/optimizer/plan_node/stream.rs | 2 + .../plan_node/stream_cdc_table_scan.rs | 1 + .../optimizer/plan_node/stream_changelog.rs | 2 + .../src/optimizer/plan_node/stream_dedup.rs | 1 + .../optimizer/plan_node/stream_delta_join.rs | 3 +- .../src/optimizer/plan_node/stream_dml.rs | 2 + .../plan_node/stream_dynamic_filter.rs | 3 +- .../plan_node/stream_eowc_over_window.rs | 3 + .../optimizer/plan_node/stream_exchange.rs | 4 +- .../src/optimizer/plan_node/stream_expand.rs | 3 +- .../src/optimizer/plan_node/stream_filter.rs | 1 + .../optimizer/plan_node/stream_fs_fetch.rs | 3 +- .../optimizer/plan_node/stream_group_topn.rs | 3 +- .../optimizer/plan_node/stream_hash_agg.rs | 2 + .../optimizer/plan_node/stream_hash_join.rs | 3 +- .../optimizer/plan_node/stream_hop_window.rs | 2 + .../optimizer/plan_node/stream_materialize.rs | 1 + .../src/optimizer/plan_node/stream_now.rs | 7 +- .../optimizer/plan_node/stream_over_window.rs | 2 + .../src/optimizer/plan_node/stream_project.rs | 13 ++- .../optimizer/plan_node/stream_project_set.rs | 9 +- .../optimizer/plan_node/stream_row_id_gen.rs | 1 + .../src/optimizer/plan_node/stream_share.rs | 1 + .../optimizer/plan_node/stream_simple_agg.rs | 11 ++- .../src/optimizer/plan_node/stream_sort.rs | 8 ++ .../src/optimizer/plan_node/stream_source.rs | 3 +- .../optimizer/plan_node/stream_source_scan.rs | 3 +- .../plan_node/stream_stateless_simple_agg.rs | 3 +- .../optimizer/plan_node/stream_table_scan.rs | 3 +- .../plan_node/stream_temporal_join.rs | 7 ++ .../src/optimizer/plan_node/stream_topn.rs | 11 ++- .../src/optimizer/plan_node/stream_union.rs | 3 +- .../src/optimizer/plan_node/stream_values.rs | 3 +- .../plan_node/stream_watermark_filter.rs | 2 + 
.../src/optimizer/property/monotonicity.rs | 91 ++++++++++++++++++- .../src/utils/column_index_mapping.rs | 13 ++- 40 files changed, 232 insertions(+), 34 deletions(-) diff --git a/src/frontend/src/optimizer/plan_node/generic/cdc_scan.rs b/src/frontend/src/optimizer/plan_node/generic/cdc_scan.rs index 2d7d708291e4..ff1018de2e63 100644 --- a/src/frontend/src/optimizer/plan_node/generic/cdc_scan.rs +++ b/src/frontend/src/optimizer/plan_node/generic/cdc_scan.rs @@ -33,7 +33,7 @@ use crate::catalog::ColumnId; use crate::error::Result; use crate::expr::{ExprRewriter, ExprVisitor}; use crate::optimizer::optimizer_context::OptimizerContextRef; -use crate::optimizer::property::FunctionalDependencySet; +use crate::optimizer::property::{FunctionalDependencySet, MonotonicityMap}; use crate::WithOptions; /// [`CdcScan`] reads rows of a table from an external upstream database @@ -125,6 +125,10 @@ impl CdcScan { FixedBitSet::with_capacity(self.get_table_columns().len()) } + pub fn columns_monotonicity(&self) -> MonotonicityMap { + MonotonicityMap::new() + } + pub(crate) fn column_names_with_table_prefix(&self) -> Vec { self.output_col_idx .iter() diff --git a/src/frontend/src/optimizer/plan_node/logical_source.rs b/src/frontend/src/optimizer/plan_node/logical_source.rs index 918db2919e62..50024b4274e7 100644 --- a/src/frontend/src/optimizer/plan_node/logical_source.rs +++ b/src/frontend/src/optimizer/plan_node/logical_source.rs @@ -44,7 +44,7 @@ use crate::optimizer::plan_node::{ ToStreamContext, }; use crate::optimizer::property::Distribution::HashShard; -use crate::optimizer::property::{Distribution, Order, RequiredDist}; +use crate::optimizer::property::{Distribution, MonotonicityMap, Order, RequiredDist}; use crate::utils::{ColIndexMapping, Condition, IndexRewriter}; /// `LogicalSource` returns contents of a table or other equivalent object @@ -229,6 +229,7 @@ impl LogicalSource { true, // `list` will keep listing all objects, it must be append-only false, FixedBitSet::with_capacity(logical_source.column_catalog.len()), + MonotonicityMap::new(), ), core: logical_source, } diff --git a/src/frontend/src/optimizer/plan_node/mod.rs b/src/frontend/src/optimizer/plan_node/mod.rs index b5062398270e..ee2b16265e7a 100644 --- a/src/frontend/src/optimizer/plan_node/mod.rs +++ b/src/frontend/src/optimizer/plan_node/mod.rs @@ -49,7 +49,7 @@ use self::batch::BatchPlanRef; use self::generic::{GenericPlanRef, PhysicalPlanRef}; use self::stream::StreamPlanRef; use self::utils::Distill; -use super::property::{Distribution, FunctionalDependencySet, Order}; +use super::property::{Distribution, FunctionalDependencySet, MonotonicityMap, Order}; use crate::error::{ErrorCode, Result}; use crate::optimizer::ExpressionSimplifyRewriter; use crate::session::current::notice_to_user; @@ -609,6 +609,10 @@ impl StreamPlanRef for PlanRef { fn watermark_columns(&self) -> &FixedBitSet { self.plan_base().watermark_columns() } + + fn columns_monotonicity(&self) -> &MonotonicityMap { + self.plan_base().columns_monotonicity() + } } /// Allow access to all fields defined in [`BatchPlanRef`] for the type-erased plan node. diff --git a/src/frontend/src/optimizer/plan_node/plan_base.rs b/src/frontend/src/optimizer/plan_node/plan_base.rs index 12fba475241c..02c85858967f 100644 --- a/src/frontend/src/optimizer/plan_node/plan_base.rs +++ b/src/frontend/src/optimizer/plan_node/plan_base.rs @@ -59,6 +59,8 @@ pub struct StreamExtra { /// The watermark column indices of the `PlanNode`'s output. 
There could be watermark output from /// this stream operator. watermark_columns: FixedBitSet, + /// The monotonicity of columns in the output. + columns_monotonicity: MonotonicityMap, } impl GetPhysicalCommon for StreamExtra { @@ -168,6 +170,10 @@ impl stream::StreamPlanRef for PlanBase { fn watermark_columns(&self) -> &FixedBitSet { &self.extra.watermark_columns } + + fn columns_monotonicity(&self) -> &MonotonicityMap { + &self.extra.columns_monotonicity + } } impl batch::BatchPlanRef for PlanBase { @@ -222,6 +228,7 @@ impl PlanBase { append_only: bool, emit_on_window_close: bool, watermark_columns: FixedBitSet, + columns_monotonicity: MonotonicityMap, ) -> Self { let id = ctx.next_plan_node_id(); assert_eq!(watermark_columns.len(), schema.len()); @@ -236,6 +243,7 @@ impl PlanBase { append_only, emit_on_window_close, watermark_columns, + columns_monotonicity, }, } } @@ -246,6 +254,7 @@ impl PlanBase { append_only: bool, emit_on_window_close: bool, watermark_columns: FixedBitSet, + columns_monotonicity: MonotonicityMap, ) -> Self { Self::new_stream( core.ctx(), @@ -256,6 +265,7 @@ impl PlanBase { append_only, emit_on_window_close, watermark_columns, + columns_monotonicity, ) } } @@ -383,6 +393,10 @@ impl<'a> PlanBaseRef<'a> { dispatch_plan_base!(self, [Stream], StreamPlanRef::watermark_columns) } + pub(super) fn columns_monotonicity(self) -> &'a MonotonicityMap { + dispatch_plan_base!(self, [Stream], StreamPlanRef::columns_monotonicity) + } + pub(super) fn order(self) -> &'a Order { dispatch_plan_base!(self, [Batch], BatchPlanRef::order) } @@ -428,6 +442,10 @@ impl StreamPlanRef for PlanBaseRef<'_> { fn watermark_columns(&self) -> &FixedBitSet { (*self).watermark_columns() } + + fn columns_monotonicity(&self) -> &MonotonicityMap { + (*self).columns_monotonicity() + } } impl BatchPlanRef for PlanBaseRef<'_> { diff --git a/src/frontend/src/optimizer/plan_node/stream.rs b/src/frontend/src/optimizer/plan_node/stream.rs index 42a599ccd60b..e2df99d13d9f 100644 --- a/src/frontend/src/optimizer/plan_node/stream.rs +++ b/src/frontend/src/optimizer/plan_node/stream.rs @@ -15,6 +15,7 @@ use fixedbitset::FixedBitSet; use super::generic::PhysicalPlanRef; +use crate::optimizer::property::MonotonicityMap; /// A subtrait of [`PhysicalPlanRef`] for stream plans. /// @@ -29,6 +30,7 @@ pub trait StreamPlanRef: PhysicalPlanRef { fn append_only(&self) -> bool; fn emit_on_window_close(&self) -> bool; fn watermark_columns(&self) -> &FixedBitSet; + fn columns_monotonicity(&self) -> &MonotonicityMap; } /// Prelude for stream plan nodes. 
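Not part of the patch itself: the trait change above gives every stream plan node a `columns_monotonicity()` accessor next to `watermark_columns()`. As a reading aid, the sketch below shows how a projection-like operator could populate the new `MonotonicityMap`, mirroring the `StreamProject` changes later in this diff. It is written as if it lived alongside the plan node modules touched here and relies only on calls visible in this patch (`MonotonicityMap::new`/`insert`, indexing by input column, `analyze_monotonicity`); the free function name `derive_columns_monotonicity` and its signature are assumptions made for the example.

use super::stream::prelude::*;
use crate::expr::ExprImpl;
use crate::optimizer::property::{analyze_monotonicity, monotonicity_variants, MonotonicityMap};
use crate::PlanRef;

// Sketch: derive the monotonicity of each output expression of a project-like node.
fn derive_columns_monotonicity(input: &PlanRef, exprs: &[ExprImpl]) -> MonotonicityMap {
    use monotonicity_variants::*;

    let mut map = MonotonicityMap::new();
    for (out_idx, expr) in exprs.iter().enumerate() {
        match analyze_monotonicity(expr) {
            // The expression is monotonic (or constant) on its own.
            Inherent(monotonicity) => {
                map.insert(out_idx, monotonicity);
            }
            // The expression follows one input column, so it inherits whatever
            // monotonicity the input plan already guarantees for that column.
            FollowingInput(input_idx) => {
                map.insert(out_idx, input.columns_monotonicity()[input_idx]);
            }
            // Anything else is left out of the map in this sketch.
            _ => {}
        }
    }
    map
}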
diff --git a/src/frontend/src/optimizer/plan_node/stream_cdc_table_scan.rs b/src/frontend/src/optimizer/plan_node/stream_cdc_table_scan.rs index 9fe334717145..a7aef5195ea5 100644 --- a/src/frontend/src/optimizer/plan_node/stream_cdc_table_scan.rs +++ b/src/frontend/src/optimizer/plan_node/stream_cdc_table_scan.rs @@ -50,6 +50,7 @@ impl StreamCdcTableScan { core.append_only(), false, core.watermark_columns(), + core.columns_monotonicity(), ); Self { base, core } } diff --git a/src/frontend/src/optimizer/plan_node/stream_changelog.rs b/src/frontend/src/optimizer/plan_node/stream_changelog.rs index b02c5eeb0c35..34bfdec28181 100644 --- a/src/frontend/src/optimizer/plan_node/stream_changelog.rs +++ b/src/frontend/src/optimizer/plan_node/stream_changelog.rs @@ -20,6 +20,7 @@ use super::stream::prelude::PhysicalPlanRef; use super::stream::StreamPlanRef; use super::utils::impl_distill_by_unit; use super::{generic, ExprRewritable, PlanBase, PlanTreeNodeUnary, Stream, StreamNode}; +use crate::optimizer::property::MonotonicityMap; use crate::stream_fragmenter::BuildFragmentGraphState; use crate::PlanRef; @@ -48,6 +49,7 @@ impl StreamChangeLog { true, input.emit_on_window_close(), watermark_columns, + MonotonicityMap::new(), // TODO: derive monotonicity ); StreamChangeLog { base, core } } diff --git a/src/frontend/src/optimizer/plan_node/stream_dedup.rs b/src/frontend/src/optimizer/plan_node/stream_dedup.rs index b31415c12550..d642d0f9e7ee 100644 --- a/src/frontend/src/optimizer/plan_node/stream_dedup.rs +++ b/src/frontend/src/optimizer/plan_node/stream_dedup.rs @@ -44,6 +44,7 @@ impl StreamDedup { true, input.emit_on_window_close(), input.watermark_columns().clone(), + input.columns_monotonicity().clone(), ); StreamDedup { base, core } } diff --git a/src/frontend/src/optimizer/plan_node/stream_delta_join.rs b/src/frontend/src/optimizer/plan_node/stream_delta_join.rs index 7a99c8f7955b..f53d4331ae61 100644 --- a/src/frontend/src/optimizer/plan_node/stream_delta_join.rs +++ b/src/frontend/src/optimizer/plan_node/stream_delta_join.rs @@ -27,7 +27,7 @@ use crate::expr::{Expr, ExprRewriter, ExprVisitor}; use crate::optimizer::plan_node::expr_visitable::ExprVisitable; use crate::optimizer::plan_node::utils::IndicesDisplay; use crate::optimizer::plan_node::{EqJoinPredicate, EqJoinPredicateDisplay, TryToStreamPb}; -use crate::optimizer::property::Distribution; +use crate::optimizer::property::{Distribution, MonotonicityMap}; use crate::scheduler::SchedulerResult; use crate::stream_fragmenter::BuildFragmentGraphState; use crate::utils::ColIndexMappingRewriteExt; @@ -76,6 +76,7 @@ impl StreamDeltaJoin { append_only, false, // TODO(rc): derive EOWC property from input watermark_columns, + MonotonicityMap::new(), // TODO: derive monotonicity ); Self { diff --git a/src/frontend/src/optimizer/plan_node/stream_dml.rs b/src/frontend/src/optimizer/plan_node/stream_dml.rs index e77704171829..7b671efa24c2 100644 --- a/src/frontend/src/optimizer/plan_node/stream_dml.rs +++ b/src/frontend/src/optimizer/plan_node/stream_dml.rs @@ -21,6 +21,7 @@ use super::stream::prelude::*; use super::utils::{childless_record, Distill}; use super::{ExprRewritable, PlanBase, PlanRef, PlanTreeNodeUnary, StreamNode}; use crate::optimizer::plan_node::expr_visitable::ExprVisitable; +use crate::optimizer::property::MonotonicityMap; use crate::stream_fragmenter::BuildFragmentGraphState; #[derive(Debug, Clone, PartialEq, Eq, Hash)] @@ -41,6 +42,7 @@ impl StreamDml { append_only, false, // TODO(rc): decide EOWC property 
FixedBitSet::with_capacity(input.schema().len()), // no watermark if dml is allowed + MonotonicityMap::new(), // TODO: derive monotonicity ); Self { diff --git a/src/frontend/src/optimizer/plan_node/stream_dynamic_filter.rs b/src/frontend/src/optimizer/plan_node/stream_dynamic_filter.rs index a6c31b9197eb..f32bd63753d2 100644 --- a/src/frontend/src/optimizer/plan_node/stream_dynamic_filter.rs +++ b/src/frontend/src/optimizer/plan_node/stream_dynamic_filter.rs @@ -27,7 +27,7 @@ use super::{generic, ExprRewritable, PlanTreeNodeUnary}; use crate::expr::Expr; use crate::optimizer::plan_node::expr_visitable::ExprVisitable; use crate::optimizer::plan_node::{PlanBase, PlanTreeNodeBinary, StreamNode}; -use crate::optimizer::property::Distribution; +use crate::optimizer::property::{Distribution, MonotonicityMap}; use crate::optimizer::PlanRef; use crate::stream_fragmenter::BuildFragmentGraphState; @@ -77,6 +77,7 @@ impl StreamDynamicFilter { out_append_only, false, // TODO(rc): decide EOWC property Self::derive_watermark_columns(&core), + MonotonicityMap::new(), // TODO: derive monotonicity ); let cleaned_by_watermark = Self::cleaned_by_watermark(&core); Self { diff --git a/src/frontend/src/optimizer/plan_node/stream_eowc_over_window.rs b/src/frontend/src/optimizer/plan_node/stream_eowc_over_window.rs index 78cb0e3b9d60..4d134df37799 100644 --- a/src/frontend/src/optimizer/plan_node/stream_eowc_over_window.rs +++ b/src/frontend/src/optimizer/plan_node/stream_eowc_over_window.rs @@ -23,6 +23,7 @@ use super::stream::prelude::*; use super::utils::{impl_distill_by_unit, TableCatalogBuilder}; use super::{ExprRewritable, PlanBase, PlanRef, PlanTreeNodeUnary, StreamNode}; use crate::optimizer::plan_node::expr_visitable::ExprVisitable; +use crate::optimizer::property::MonotonicityMap; use crate::stream_fragmenter::BuildFragmentGraphState; use crate::TableCatalog; @@ -58,6 +59,8 @@ impl StreamEowcOverWindow { true, true, watermark_columns, + // we cannot derive monotonicity for any column for the same reason as watermark columns + MonotonicityMap::new(), ); StreamEowcOverWindow { base, core } } diff --git a/src/frontend/src/optimizer/plan_node/stream_exchange.rs b/src/frontend/src/optimizer/plan_node/stream_exchange.rs index 964accc0f69a..802f2e3d227c 100644 --- a/src/frontend/src/optimizer/plan_node/stream_exchange.rs +++ b/src/frontend/src/optimizer/plan_node/stream_exchange.rs @@ -20,7 +20,7 @@ use super::stream::prelude::*; use super::utils::{childless_record, plan_node_name, Distill}; use super::{ExprRewritable, PlanBase, PlanRef, PlanTreeNodeUnary, StreamNode}; use crate::optimizer::plan_node::expr_visitable::ExprVisitable; -use crate::optimizer::property::{Distribution, DistributionDisplay}; +use crate::optimizer::property::{Distribution, DistributionDisplay, MonotonicityMap}; use crate::stream_fragmenter::BuildFragmentGraphState; /// `StreamExchange` imposes a particular distribution on its input @@ -44,6 +44,7 @@ impl StreamExchange { input.append_only(), input.emit_on_window_close(), input.watermark_columns().clone(), + MonotonicityMap::new(), // we lost monotonicity information when shuffling ); StreamExchange { base, @@ -64,6 +65,7 @@ impl StreamExchange { input.append_only(), input.emit_on_window_close(), input.watermark_columns().clone(), + input.columns_monotonicity().clone(), ); StreamExchange { base, diff --git a/src/frontend/src/optimizer/plan_node/stream_expand.rs b/src/frontend/src/optimizer/plan_node/stream_expand.rs index 4f38e95cdfea..fa0268a46fcf 100644 --- 
a/src/frontend/src/optimizer/plan_node/stream_expand.rs +++ b/src/frontend/src/optimizer/plan_node/stream_expand.rs @@ -21,7 +21,7 @@ use super::stream::prelude::*; use super::utils::impl_distill_by_unit; use super::{generic, ExprRewritable, PlanBase, PlanRef, PlanTreeNodeUnary, StreamNode}; use crate::optimizer::plan_node::expr_visitable::ExprVisitable; -use crate::optimizer::property::Distribution; +use crate::optimizer::property::{Distribution, MonotonicityMap}; use crate::stream_fragmenter::BuildFragmentGraphState; #[derive(Debug, Clone, PartialEq, Eq, Hash)] @@ -52,6 +52,7 @@ impl StreamExpand { input.append_only(), input.emit_on_window_close(), watermark_columns, + MonotonicityMap::new(), ); StreamExpand { base, core } } diff --git a/src/frontend/src/optimizer/plan_node/stream_filter.rs b/src/frontend/src/optimizer/plan_node/stream_filter.rs index 586b5ac8a84b..0a3126ffe718 100644 --- a/src/frontend/src/optimizer/plan_node/stream_filter.rs +++ b/src/frontend/src/optimizer/plan_node/stream_filter.rs @@ -42,6 +42,7 @@ impl StreamFilter { input.append_only(), input.emit_on_window_close(), input.watermark_columns().clone(), + input.columns_monotonicity().clone(), ); StreamFilter { base, core } } diff --git a/src/frontend/src/optimizer/plan_node/stream_fs_fetch.rs b/src/frontend/src/optimizer/plan_node/stream_fs_fetch.rs index c57494123672..08516631dc75 100644 --- a/src/frontend/src/optimizer/plan_node/stream_fs_fetch.rs +++ b/src/frontend/src/optimizer/plan_node/stream_fs_fetch.rs @@ -26,7 +26,7 @@ use crate::catalog::source_catalog::SourceCatalog; use crate::optimizer::plan_node::expr_visitable::ExprVisitable; use crate::optimizer::plan_node::utils::{childless_record, Distill}; use crate::optimizer::plan_node::{generic, ExprRewritable, StreamNode}; -use crate::optimizer::property::Distribution; +use crate::optimizer::property::{Distribution, MonotonicityMap}; use crate::stream_fragmenter::BuildFragmentGraphState; #[derive(Debug, Clone, PartialEq, Eq, Hash)] @@ -55,6 +55,7 @@ impl StreamFsFetch { source.catalog.as_ref().map_or(true, |s| s.append_only), false, FixedBitSet::with_capacity(source.column_catalog.len()), + MonotonicityMap::new(), // TODO: derive monotonicity ); Self { diff --git a/src/frontend/src/optimizer/plan_node/stream_group_topn.rs b/src/frontend/src/optimizer/plan_node/stream_group_topn.rs index 0cd8edc996c8..b9230270e634 100644 --- a/src/frontend/src/optimizer/plan_node/stream_group_topn.rs +++ b/src/frontend/src/optimizer/plan_node/stream_group_topn.rs @@ -22,7 +22,7 @@ use super::utils::{plan_node_name, watermark_pretty, Distill}; use super::{generic, ExprRewritable, PlanBase, PlanTreeNodeUnary, StreamNode}; use crate::optimizer::plan_node::expr_visitable::ExprVisitable; use crate::optimizer::plan_node::generic::GenericPlanNode; -use crate::optimizer::property::Order; +use crate::optimizer::property::{MonotonicityMap, Order}; use crate::stream_fragmenter::BuildFragmentGraphState; use crate::PlanRef; @@ -79,6 +79,7 @@ impl StreamGroupTopN { // TODO: https://github.com/risingwavelabs/risingwave/issues/8348 false, watermark_columns, + MonotonicityMap::new(), // TODO: derive monotonicity ); StreamGroupTopN { base, diff --git a/src/frontend/src/optimizer/plan_node/stream_hash_agg.rs b/src/frontend/src/optimizer/plan_node/stream_hash_agg.rs index eb69d5c259bb..2dfad775ecc6 100644 --- a/src/frontend/src/optimizer/plan_node/stream_hash_agg.rs +++ b/src/frontend/src/optimizer/plan_node/stream_hash_agg.rs @@ -24,6 +24,7 @@ use super::{ExprRewritable, PlanBase, PlanRef, 
PlanTreeNodeUnary, StreamNode}; use crate::error::{ErrorCode, Result}; use crate::expr::{ExprRewriter, ExprVisitor}; use crate::optimizer::plan_node::expr_visitable::ExprVisitable; +use crate::optimizer::property::MonotonicityMap; use crate::stream_fragmenter::BuildFragmentGraphState; use crate::utils::{ColIndexMapping, ColIndexMappingRewriteExt, IndexSet}; @@ -93,6 +94,7 @@ impl StreamHashAgg { emit_on_window_close, // in EOWC mode, we produce append only output emit_on_window_close, watermark_columns, + MonotonicityMap::new(), // TODO: derive monotonicity ); StreamHashAgg { base, diff --git a/src/frontend/src/optimizer/plan_node/stream_hash_join.rs b/src/frontend/src/optimizer/plan_node/stream_hash_join.rs index b803fccef7b0..cbce1e1caf45 100644 --- a/src/frontend/src/optimizer/plan_node/stream_hash_join.rs +++ b/src/frontend/src/optimizer/plan_node/stream_hash_join.rs @@ -30,7 +30,7 @@ use crate::expr::{Expr, ExprDisplay, ExprRewriter, ExprVisitor, InequalityInputP use crate::optimizer::plan_node::expr_visitable::ExprVisitable; use crate::optimizer::plan_node::utils::IndicesDisplay; use crate::optimizer::plan_node::{EqJoinPredicate, EqJoinPredicateDisplay}; -use crate::optimizer::property::Distribution; +use crate::optimizer::property::{Distribution, MonotonicityMap}; use crate::stream_fragmenter::BuildFragmentGraphState; use crate::utils::ColIndexMappingRewriteExt; @@ -196,6 +196,7 @@ impl StreamHashJoin { append_only, false, // TODO(rc): derive EOWC property from input watermark_columns, + MonotonicityMap::new(), // TODO: derive monotonicity ); Self { diff --git a/src/frontend/src/optimizer/plan_node/stream_hop_window.rs b/src/frontend/src/optimizer/plan_node/stream_hop_window.rs index a94dfbe788f8..4a50387c50be 100644 --- a/src/frontend/src/optimizer/plan_node/stream_hop_window.rs +++ b/src/frontend/src/optimizer/plan_node/stream_hop_window.rs @@ -21,6 +21,7 @@ use super::utils::{childless_record, watermark_pretty, Distill}; use super::{generic, ExprRewritable, PlanBase, PlanRef, PlanTreeNodeUnary, StreamNode}; use crate::expr::{Expr, ExprImpl, ExprRewriter, ExprVisitor}; use crate::optimizer::plan_node::expr_visitable::ExprVisitable; +use crate::optimizer::property::MonotonicityMap; use crate::stream_fragmenter::BuildFragmentGraphState; use crate::utils::ColIndexMappingRewriteExt; @@ -62,6 +63,7 @@ impl StreamHopWindow { input.append_only(), input.emit_on_window_close(), internal2output.rewrite_bitset(&watermark_columns), + MonotonicityMap::new(), /* hop window start/end jumps, so monotonicity is not propagated */ ); Self { base, diff --git a/src/frontend/src/optimizer/plan_node/stream_materialize.rs b/src/frontend/src/optimizer/plan_node/stream_materialize.rs index e5a2496916ad..865dc71191b4 100644 --- a/src/frontend/src/optimizer/plan_node/stream_materialize.rs +++ b/src/frontend/src/optimizer/plan_node/stream_materialize.rs @@ -59,6 +59,7 @@ impl StreamMaterialize { input.append_only(), input.emit_on_window_close(), input.watermark_columns().clone(), + input.columns_monotonicity().clone(), ); Self { base, input, table } } diff --git a/src/frontend/src/optimizer/plan_node/stream_now.rs b/src/frontend/src/optimizer/plan_node/stream_now.rs index 22a0d2c5fb0f..9ec80d15bac3 100644 --- a/src/frontend/src/optimizer/plan_node/stream_now.rs +++ b/src/frontend/src/optimizer/plan_node/stream_now.rs @@ -26,7 +26,7 @@ use super::utils::{childless_record, Distill, TableCatalogBuilder}; use super::{generic, ExprRewritable, PlanBase, StreamNode}; use 
crate::optimizer::plan_node::expr_visitable::ExprVisitable; use crate::optimizer::plan_node::utils::column_names_pretty; -use crate::optimizer::property::Distribution; +use crate::optimizer::property::{Distribution, Monotonicity, MonotonicityMap}; use crate::stream_fragmenter::BuildFragmentGraphState; #[derive(Debug, Clone, PartialEq, Eq, Hash)] @@ -39,12 +39,17 @@ impl StreamNow { pub fn new(core: generic::Now) -> Self { let mut watermark_columns = FixedBitSet::with_capacity(1); watermark_columns.set(0, true); + + let mut columns_monotonicity = MonotonicityMap::new(); + columns_monotonicity.insert(0, Monotonicity::NonDecreasing); + let base = PlanBase::new_stream_with_core( &core, Distribution::Single, core.mode.is_generate_series(), // append only core.mode.is_generate_series(), // emit on window close watermark_columns, + columns_monotonicity, ); Self { base, core } } diff --git a/src/frontend/src/optimizer/plan_node/stream_over_window.rs b/src/frontend/src/optimizer/plan_node/stream_over_window.rs index be6c63bcb50d..6b0beaa9f99c 100644 --- a/src/frontend/src/optimizer/plan_node/stream_over_window.rs +++ b/src/frontend/src/optimizer/plan_node/stream_over_window.rs @@ -23,6 +23,7 @@ use super::stream::prelude::*; use super::utils::{impl_distill_by_unit, TableCatalogBuilder}; use super::{generic, ExprRewritable, PlanBase, PlanRef, PlanTreeNodeUnary, StreamNode}; use crate::optimizer::plan_node::expr_visitable::ExprVisitable; +use crate::optimizer::property::MonotonicityMap; use crate::stream_fragmenter::BuildFragmentGraphState; use crate::TableCatalog; @@ -45,6 +46,7 @@ impl StreamOverWindow { false, // general over window cannot be append-only false, watermark_columns, + MonotonicityMap::new(), // TODO: derive monotonicity ); StreamOverWindow { base, core } } diff --git a/src/frontend/src/optimizer/plan_node/stream_project.rs b/src/frontend/src/optimizer/plan_node/stream_project.rs index e5828a326706..ef879627c66b 100644 --- a/src/frontend/src/optimizer/plan_node/stream_project.rs +++ b/src/frontend/src/optimizer/plan_node/stream_project.rs @@ -22,7 +22,7 @@ use super::utils::{childless_record, watermark_pretty, Distill}; use super::{generic, ExprRewritable, PlanBase, PlanRef, PlanTreeNodeUnary, StreamNode}; use crate::expr::{Expr, ExprImpl, ExprRewriter, ExprVisitor}; use crate::optimizer::plan_node::expr_visitable::ExprVisitable; -use crate::optimizer::property::{analyze_monotonicity, monotonicity_variants}; +use crate::optimizer::property::{analyze_monotonicity, monotonicity_variants, MonotonicityMap}; use crate::stream_fragmenter::BuildFragmentGraphState; use crate::utils::ColIndexMappingRewriteExt; @@ -82,19 +82,21 @@ impl StreamProject { let mut watermark_derivations = vec![]; let mut nondecreasing_exprs = vec![]; let mut out_watermark_columns = FixedBitSet::with_capacity(core.exprs.len()); + let mut out_monotonicity_map = MonotonicityMap::new(); for (expr_idx, expr) in core.exprs.iter().enumerate() { use monotonicity_variants::*; match analyze_monotonicity(expr) { - Inherent(Constant) => { - // XXX(rc): we can produce one watermark on each recovery for this case. 
- } Inherent(monotonicity) => { - if monotonicity.is_non_decreasing() { + out_monotonicity_map.insert(expr_idx, monotonicity); + if monotonicity.is_non_decreasing() && !monotonicity.is_constant() { + // TODO(rc): may be we should also derive watermark for constant later nondecreasing_exprs.push(expr_idx); // to produce watermarks out_watermark_columns.insert(expr_idx); } } FollowingInput(input_idx) => { + let in_monotonicity = input.columns_monotonicity()[input_idx]; + out_monotonicity_map.insert(expr_idx, in_monotonicity); if input.watermark_columns().contains(input_idx) { watermark_derivations.push((input_idx, expr_idx)); // to propagate watermarks out_watermark_columns.insert(expr_idx); @@ -111,6 +113,7 @@ impl StreamProject { input.append_only(), input.emit_on_window_close(), out_watermark_columns, + out_monotonicity_map, ); StreamProject { diff --git a/src/frontend/src/optimizer/plan_node/stream_project_set.rs b/src/frontend/src/optimizer/plan_node/stream_project_set.rs index 4630e1c62c83..5735c4b9d564 100644 --- a/src/frontend/src/optimizer/plan_node/stream_project_set.rs +++ b/src/frontend/src/optimizer/plan_node/stream_project_set.rs @@ -22,7 +22,7 @@ use super::utils::impl_distill_by_unit; use super::{generic, ExprRewritable, PlanBase, PlanRef, PlanTreeNodeUnary, StreamNode}; use crate::expr::{ExprRewriter, ExprVisitor}; use crate::optimizer::plan_node::expr_visitable::ExprVisitable; -use crate::optimizer::property::{analyze_monotonicity, monotonicity_variants}; +use crate::optimizer::property::{analyze_monotonicity, monotonicity_variants, MonotonicityMap}; use crate::stream_fragmenter::BuildFragmentGraphState; use crate::utils::ColIndexMappingRewriteExt; @@ -53,11 +53,9 @@ impl StreamProjectSet { use monotonicity_variants::*; match analyze_monotonicity(expr) { - Inherent(Constant) => { - // XXX(rc): we can produce one watermark on each recovery for this case. 
- } Inherent(monotonicity) => { - if monotonicity.is_non_decreasing() { + if monotonicity.is_non_decreasing() && !monotonicity.is_constant() { + // TODO(rc): may be we should also derive watermark for constant later // FIXME(rc): we need to check expr is not table function nondecreasing_exprs.push(expr_idx); // to produce watermarks out_watermark_columns.insert(out_expr_idx); @@ -81,6 +79,7 @@ impl StreamProjectSet { input.append_only(), input.emit_on_window_close(), out_watermark_columns, + MonotonicityMap::new(), // TODO: derive monotonicity ); StreamProjectSet { base, diff --git a/src/frontend/src/optimizer/plan_node/stream_row_id_gen.rs b/src/frontend/src/optimizer/plan_node/stream_row_id_gen.rs index bf4bafeed26d..36b96f4dad36 100644 --- a/src/frontend/src/optimizer/plan_node/stream_row_id_gen.rs +++ b/src/frontend/src/optimizer/plan_node/stream_row_id_gen.rs @@ -50,6 +50,7 @@ impl StreamRowIdGen { input.append_only(), input.emit_on_window_close(), input.watermark_columns().clone(), + input.columns_monotonicity().clone(), ); Self { base, diff --git a/src/frontend/src/optimizer/plan_node/stream_share.rs b/src/frontend/src/optimizer/plan_node/stream_share.rs index 5bf575f622bc..b082d82b022d 100644 --- a/src/frontend/src/optimizer/plan_node/stream_share.rs +++ b/src/frontend/src/optimizer/plan_node/stream_share.rs @@ -44,6 +44,7 @@ impl StreamShare { input.append_only(), input.emit_on_window_close(), input.watermark_columns().clone(), + input.columns_monotonicity().clone(), ) }; diff --git a/src/frontend/src/optimizer/plan_node/stream_simple_agg.rs b/src/frontend/src/optimizer/plan_node/stream_simple_agg.rs index 28a377ca04cd..6ecaa4c308f5 100644 --- a/src/frontend/src/optimizer/plan_node/stream_simple_agg.rs +++ b/src/frontend/src/optimizer/plan_node/stream_simple_agg.rs @@ -23,7 +23,7 @@ use super::utils::{childless_record, plan_node_name, Distill}; use super::{ExprRewritable, PlanBase, PlanRef, PlanTreeNodeUnary, StreamNode}; use crate::expr::{ExprRewriter, ExprVisitor}; use crate::optimizer::plan_node::expr_visitable::ExprVisitable; -use crate::optimizer::property::Distribution; +use crate::optimizer::property::{Distribution, MonotonicityMap}; use crate::stream_fragmenter::BuildFragmentGraphState; #[derive(Debug, Clone, PartialEq, Eq, Hash)] @@ -50,7 +50,14 @@ impl StreamSimpleAgg { let watermark_columns = FixedBitSet::with_capacity(core.output_len()); // Simple agg executor might change the append-only behavior of the stream. 
- let base = PlanBase::new_stream_with_core(&core, dist, false, false, watermark_columns); + let base = PlanBase::new_stream_with_core( + &core, + dist, + false, + false, + watermark_columns, + MonotonicityMap::new(), + ); StreamSimpleAgg { base, core, diff --git a/src/frontend/src/optimizer/plan_node/stream_sort.rs b/src/frontend/src/optimizer/plan_node/stream_sort.rs index 6b45f8fd35a6..c4acd275f123 100644 --- a/src/frontend/src/optimizer/plan_node/stream_sort.rs +++ b/src/frontend/src/optimizer/plan_node/stream_sort.rs @@ -24,6 +24,7 @@ use super::stream::prelude::*; use super::utils::{childless_record, Distill, TableCatalogBuilder}; use super::{ExprRewritable, PlanBase, PlanRef, PlanTreeNodeUnary, StreamNode}; use crate::optimizer::plan_node::expr_visitable::ExprVisitable; +use crate::optimizer::property::{Monotonicity, MonotonicityMap}; use crate::stream_fragmenter::BuildFragmentGraphState; use crate::TableCatalog; @@ -53,8 +54,14 @@ impl StreamEowcSort { let stream_key = input.stream_key().map(|v| v.to_vec()); let fd_set = input.functional_dependency().clone(); let dist = input.distribution().clone(); + let mut watermark_columns = FixedBitSet::with_capacity(input.schema().len()); watermark_columns.insert(sort_column_index); + + // StreamEowcSort makes the sorting watermark column non-decreasing + let mut columns_monotonicity = MonotonicityMap::new(); + columns_monotonicity.insert(sort_column_index, Monotonicity::NonDecreasing); + let base = PlanBase::new_stream( input.ctx(), schema, @@ -64,6 +71,7 @@ impl StreamEowcSort { true, true, watermark_columns, + columns_monotonicity, ); Self { base, diff --git a/src/frontend/src/optimizer/plan_node/stream_source.rs b/src/frontend/src/optimizer/plan_node/stream_source.rs index 7b0703aa8436..980df7911c7f 100644 --- a/src/frontend/src/optimizer/plan_node/stream_source.rs +++ b/src/frontend/src/optimizer/plan_node/stream_source.rs @@ -29,7 +29,7 @@ use super::{generic, ExprRewritable, PlanBase, StreamNode}; use crate::catalog::source_catalog::SourceCatalog; use crate::optimizer::plan_node::expr_visitable::ExprVisitable; use crate::optimizer::plan_node::utils::column_names_pretty; -use crate::optimizer::property::Distribution; +use crate::optimizer::property::{Distribution, MonotonicityMap}; use crate::stream_fragmenter::BuildFragmentGraphState; /// [`StreamSource`] represents a table/connector source at the very beginning of the graph. 
@@ -64,6 +64,7 @@ impl StreamSource { core.catalog.as_ref().map_or(true, |s| s.append_only), false, FixedBitSet::with_capacity(core.column_catalog.len()), + MonotonicityMap::new(), ); Self { base, core } } diff --git a/src/frontend/src/optimizer/plan_node/stream_source_scan.rs b/src/frontend/src/optimizer/plan_node/stream_source_scan.rs index b947cee641d4..47689c103879 100644 --- a/src/frontend/src/optimizer/plan_node/stream_source_scan.rs +++ b/src/frontend/src/optimizer/plan_node/stream_source_scan.rs @@ -32,7 +32,7 @@ use crate::catalog::source_catalog::SourceCatalog; use crate::optimizer::plan_node::expr_visitable::ExprVisitable; use crate::optimizer::plan_node::utils::{childless_record, Distill}; use crate::optimizer::plan_node::{generic, ExprRewritable, StreamNode}; -use crate::optimizer::property::Distribution; +use crate::optimizer::property::{Distribution, MonotonicityMap}; use crate::scheduler::SchedulerResult; use crate::stream_fragmenter::BuildFragmentGraphState; use crate::{Explain, TableCatalog}; @@ -75,6 +75,7 @@ impl StreamSourceScan { core.catalog.as_ref().map_or(true, |s| s.append_only), false, FixedBitSet::with_capacity(core.column_catalog.len()), + MonotonicityMap::new(), ); Self { base, core } diff --git a/src/frontend/src/optimizer/plan_node/stream_stateless_simple_agg.rs b/src/frontend/src/optimizer/plan_node/stream_stateless_simple_agg.rs index 8ce0997b7fe1..93c56efad3d5 100644 --- a/src/frontend/src/optimizer/plan_node/stream_stateless_simple_agg.rs +++ b/src/frontend/src/optimizer/plan_node/stream_stateless_simple_agg.rs @@ -22,7 +22,7 @@ use super::utils::impl_distill_by_unit; use super::{ExprRewritable, PlanBase, PlanRef, PlanTreeNodeUnary, StreamNode}; use crate::expr::{ExprRewriter, ExprVisitor}; use crate::optimizer::plan_node::expr_visitable::ExprVisitable; -use crate::optimizer::property::RequiredDist; +use crate::optimizer::property::{MonotonicityMap, RequiredDist}; use crate::stream_fragmenter::BuildFragmentGraphState; /// Streaming stateless simple agg. 
@@ -57,6 +57,7 @@ impl StreamStatelessSimpleAgg { input.append_only(), input.emit_on_window_close(), watermark_columns, + MonotonicityMap::new(), ); StreamStatelessSimpleAgg { base, core } } diff --git a/src/frontend/src/optimizer/plan_node/stream_table_scan.rs b/src/frontend/src/optimizer/plan_node/stream_table_scan.rs index 1e93514f6c0f..2255194dbee6 100644 --- a/src/frontend/src/optimizer/plan_node/stream_table_scan.rs +++ b/src/frontend/src/optimizer/plan_node/stream_table_scan.rs @@ -31,7 +31,7 @@ use crate::catalog::ColumnId; use crate::expr::{ExprRewriter, ExprVisitor, FunctionCall}; use crate::optimizer::plan_node::expr_visitable::ExprVisitable; use crate::optimizer::plan_node::utils::{IndicesDisplay, TableCatalogBuilder}; -use crate::optimizer::property::{Distribution, DistributionDisplay}; +use crate::optimizer::property::{Distribution, DistributionDisplay, MonotonicityMap}; use crate::scheduler::SchedulerResult; use crate::stream_fragmenter::BuildFragmentGraphState; use crate::TableCatalog; @@ -74,6 +74,7 @@ impl StreamTableScan { core.append_only(), false, core.watermark_columns(), + MonotonicityMap::new(), ); Self { base, diff --git a/src/frontend/src/optimizer/plan_node/stream_temporal_join.rs b/src/frontend/src/optimizer/plan_node/stream_temporal_join.rs index f94dbba36cb7..7f50d4b27e62 100644 --- a/src/frontend/src/optimizer/plan_node/stream_temporal_join.rs +++ b/src/frontend/src/optimizer/plan_node/stream_temporal_join.rs @@ -69,12 +69,19 @@ impl StreamTemporalJoin { .rewrite_bitset(core.left.watermark_columns()), ); + let columns_monotonicity = core.i2o_col_mapping().rewrite_monotonicity_map( + &core + .l2i_col_mapping() + .rewrite_monotonicity_map(core.left.columns_monotonicity()), + ); + let base = PlanBase::new_stream_with_core( &core, dist, append_only, false, // TODO(rc): derive EOWC property from input watermark_columns, + columns_monotonicity, ); Self { diff --git a/src/frontend/src/optimizer/plan_node/stream_topn.rs b/src/frontend/src/optimizer/plan_node/stream_topn.rs index 9581c07ff297..80ca9141033c 100644 --- a/src/frontend/src/optimizer/plan_node/stream_topn.rs +++ b/src/frontend/src/optimizer/plan_node/stream_topn.rs @@ -21,7 +21,7 @@ use super::stream::prelude::*; use super::utils::{plan_node_name, Distill}; use super::{generic, ExprRewritable, PlanBase, PlanRef, PlanTreeNodeUnary, StreamNode}; use crate::optimizer::plan_node::expr_visitable::ExprVisitable; -use crate::optimizer::property::{Distribution, Order}; +use crate::optimizer::property::{Distribution, MonotonicityMap, Order}; use crate::stream_fragmenter::BuildFragmentGraphState; /// `StreamTopN` implements [`super::LogicalTopN`] to find the top N elements with a heap @@ -42,7 +42,14 @@ impl StreamTopN { }; let watermark_columns = FixedBitSet::with_capacity(input.schema().len()); - let base = PlanBase::new_stream_with_core(&core, dist, false, false, watermark_columns); + let base = PlanBase::new_stream_with_core( + &core, + dist, + false, + false, + watermark_columns, + MonotonicityMap::new(), + ); StreamTopN { base, core } } diff --git a/src/frontend/src/optimizer/plan_node/stream_union.rs b/src/frontend/src/optimizer/plan_node/stream_union.rs index 1c269ec0c5ad..2e424fc0604b 100644 --- a/src/frontend/src/optimizer/plan_node/stream_union.rs +++ b/src/frontend/src/optimizer/plan_node/stream_union.rs @@ -25,7 +25,7 @@ use super::{generic, ExprRewritable, PlanRef}; use crate::optimizer::plan_node::expr_visitable::ExprVisitable; use crate::optimizer::plan_node::generic::GenericPlanNode; use 
crate::optimizer::plan_node::{PlanBase, PlanTreeNode, StreamNode}; -use crate::optimizer::property::Distribution; +use crate::optimizer::property::{Distribution, MonotonicityMap}; use crate::stream_fragmenter::BuildFragmentGraphState; /// `StreamUnion` implements [`super::LogicalUnion`] @@ -61,6 +61,7 @@ impl StreamUnion { inputs.iter().all(|x| x.append_only()), inputs.iter().all(|x| x.emit_on_window_close()), watermark_columns, + MonotonicityMap::new(), ); StreamUnion { base, core } diff --git a/src/frontend/src/optimizer/plan_node/stream_values.rs b/src/frontend/src/optimizer/plan_node/stream_values.rs index 05cb1659b96e..0a71c208c32e 100644 --- a/src/frontend/src/optimizer/plan_node/stream_values.rs +++ b/src/frontend/src/optimizer/plan_node/stream_values.rs @@ -23,7 +23,7 @@ use super::utils::{childless_record, Distill}; use super::{ExprRewritable, LogicalValues, PlanBase, StreamNode}; use crate::expr::{Expr, ExprImpl, ExprVisitor}; use crate::optimizer::plan_node::expr_visitable::ExprVisitable; -use crate::optimizer::property::Distribution; +use crate::optimizer::property::{Distribution, MonotonicityMap}; use crate::stream_fragmenter::BuildFragmentGraphState; /// `StreamValues` implements `LogicalValues.to_stream()` @@ -48,6 +48,7 @@ impl StreamValues { true, false, FixedBitSet::with_capacity(logical.schema().len()), + MonotonicityMap::new(), ); Self { base, logical } } diff --git a/src/frontend/src/optimizer/plan_node/stream_watermark_filter.rs b/src/frontend/src/optimizer/plan_node/stream_watermark_filter.rs index ffb08776b3fe..7ea0dbaf6dd9 100644 --- a/src/frontend/src/optimizer/plan_node/stream_watermark_filter.rs +++ b/src/frontend/src/optimizer/plan_node/stream_watermark_filter.rs @@ -49,6 +49,8 @@ impl StreamWatermarkFilter { input.append_only(), false, // TODO(rc): decide EOWC property watermark_columns, + // watermark filter preserves input order and hence monotonicity + input.columns_monotonicity().clone(), ); Self::with_base(base, input, watermark_descs) } diff --git a/src/frontend/src/optimizer/property/monotonicity.rs b/src/frontend/src/optimizer/property/monotonicity.rs index 87f74c25b83f..d3091ca029db 100644 --- a/src/frontend/src/optimizer/property/monotonicity.rs +++ b/src/frontend/src/optimizer/property/monotonicity.rs @@ -12,6 +12,9 @@ // See the License for the specific language governing permissions and // limitations under the License. +use std::collections::BTreeMap; +use std::ops::Index; + use enum_as_inner::EnumAsInner; use risingwave_common::types::DataType; use risingwave_pb::expr::expr_node::Type as ExprType; @@ -42,8 +45,29 @@ impl MonotonicityDerivation { } } -/// Represents the monotonicity of a column. `NULL`s are considered largest when analyzing monotonicity. -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, EnumAsInner)] +/// Represents the monotonicity of a column. +/// +/// Monotonicity is a property of the output column of stream node that describes the the order +/// of the values in the column. One [`Monotonicity`] value is associated with one column, so +/// each stream node should have a [`MonotonicityMap`] to describe the monotonicity of all its +/// output columns. +/// +/// For operator that yields append-only stream, the monotonicity being `NonDecreasing` means +/// that it will never yield a row smaller than any previously yielded row. 
+///
+/// For an operator that yields an append-only stream, the monotonicity being `NonDecreasing` means
+/// that it will never yield a change that has a smaller value than any previously yielded change,
+/// ignoring the `Op`. So if such an operator yields a `NonDecreasing` column, `Delete`s and `UpdateDelete`s
+/// can only happen on the last emitted row (or the last rows with the same value on the column). This
+/// is especially useful for the `StreamNow` operator with `UpdateCurrent` mode, in which case only
+/// one output row is actively maintained and the value is non-decreasing.
+///
+/// The monotonicity property is considered under the default order type, i.e., ASC NULLS LAST. This means
+/// that `NULL`s are considered largest when analyzing monotonicity.
+///
+/// For distributed operators, the monotonicity describes the property of the output column of
+/// each shard of the operator.
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
 pub enum Monotonicity {
     Constant,
     NonDecreasing,
@@ -52,6 +76,24 @@ pub enum Monotonicity {
 }
 
 impl Monotonicity {
+    pub fn is_constant(self) -> bool {
+        matches!(self, Monotonicity::Constant)
+    }
+
+    pub fn is_non_decreasing(self) -> bool {
+        // we don't use `EnumAsInner` here because we need to include `Constant`
+        matches!(self, Monotonicity::NonDecreasing | Monotonicity::Constant)
+    }
+
+    pub fn is_non_increasing(self) -> bool {
+        // similar to `is_non_decreasing`
+        matches!(self, Monotonicity::NonIncreasing | Monotonicity::Constant)
+    }
+
+    pub fn is_unknown(self) -> bool {
+        matches!(self, Monotonicity::Unknown)
+    }
+
     pub fn inverse(self) -> Self {
         use Monotonicity::*;
         match self {
@@ -271,3 +313,48 @@ impl MonotonicityAnalyzer {
         Inherent(Unknown)
     }
 }
+
+/// A map from column index to its monotonicity.
+#[derive(Debug, Default, Clone, PartialEq, Eq, Hash)]
+pub struct MonotonicityMap(BTreeMap<usize, Monotonicity>);
+
+impl MonotonicityMap {
+    pub fn new() -> Self {
+        MonotonicityMap(BTreeMap::new())
+    }
+
+    pub fn insert(&mut self, idx: usize, monotonicity: Monotonicity) {
+        if monotonicity != Monotonicity::Unknown {
+            self.0.insert(idx, monotonicity);
+        }
+    }
+
+    pub fn iter(&self) -> impl Iterator<Item = (usize, Monotonicity)> + '_ {
+        self.0
+            .iter()
+            .map(|(idx, monotonicity)| (*idx, *monotonicity))
+    }
+}
+
+impl Index<usize> for MonotonicityMap {
+    type Output = Monotonicity;
+
+    fn index(&self, idx: usize) -> &Self::Output {
+        self.0.get(&idx).unwrap_or(&Monotonicity::Unknown)
+    }
+}
+
+impl IntoIterator for MonotonicityMap {
+    type IntoIter = std::collections::btree_map::IntoIter<usize, Monotonicity>;
+    type Item = (usize, Monotonicity);
+
+    fn into_iter(self) -> Self::IntoIter {
+        self.0.into_iter()
+    }
+}
+
+impl FromIterator<(usize, Monotonicity)> for MonotonicityMap {
+    fn from_iter<T: IntoIterator<Item = (usize, Monotonicity)>>(iter: T) -> Self {
+        MonotonicityMap(iter.into_iter().collect())
+    }
+}
diff --git a/src/frontend/src/utils/column_index_mapping.rs b/src/frontend/src/utils/column_index_mapping.rs
index 08343eb9f09c..4a7a729eb9b7 100644
--- a/src/frontend/src/utils/column_index_mapping.rs
+++ b/src/frontend/src/utils/column_index_mapping.rs
@@ -20,7 +20,8 @@ use risingwave_common::util::sort_util::ColumnOrder;
 
 use crate::expr::{Expr, ExprImpl, ExprRewriter, InputRef};
 use crate::optimizer::property::{
-    Distribution, FunctionalDependency, FunctionalDependencySet, Order, RequiredDist,
+    Distribution, FunctionalDependency, FunctionalDependencySet, MonotonicityMap, Order,
+    RequiredDist,
 };
 
 /// Extension trait for [`ColIndexMapping`] to rewrite frontend structures.
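As an aside, a minimal sketch of how the `MonotonicityMap` introduced above behaves, based only on the definitions shown in `monotonicity.rs`; it is illustrative and not part of the diff. The `rewrite_monotonicity_map` helper added in the next hunk then simply remaps the column indices and drops entries for columns that the mapping cannot map.

// Illustrative sketch only; assumes the `Monotonicity` / `MonotonicityMap`
// definitions above, nothing else.
use crate::optimizer::property::{Monotonicity, MonotonicityMap};

fn monotonicity_map_sketch() {
    let mut map = MonotonicityMap::new();
    map.insert(0, Monotonicity::NonDecreasing);
    map.insert(1, Monotonicity::Unknown); // not stored: `insert` skips `Unknown`

    // Indexing a column without an entry falls back to `Unknown`.
    assert_eq!(map[0], Monotonicity::NonDecreasing);
    assert_eq!(map[1], Monotonicity::Unknown);
    assert_eq!(map[42], Monotonicity::Unknown);

    // `Constant` counts as both non-decreasing and non-increasing.
    assert!(Monotonicity::Constant.is_non_decreasing());
    assert!(Monotonicity::Constant.is_non_increasing());
    assert!(!Monotonicity::Constant.is_unknown());
}

Because `insert` drops `Unknown` entries the map stays sparse, which is why many stream nodes in this series can cheaply pass `MonotonicityMap::new()` when nothing is known about their output columns.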
@@ -186,6 +187,16 @@ impl ColIndexMapping { } ret } + + pub fn rewrite_monotonicity_map(&self, map: &MonotonicityMap) -> MonotonicityMap { + let mut new_map = MonotonicityMap::new(); + for (i, monotonicity) in map.iter() { + if let Some(mapped_i) = self.try_map(i) { + new_map.insert(mapped_i, monotonicity); + } + } + new_map + } } impl ExprRewriter for ColIndexMapping { From 4321a81042a8069c1c0fbaed3da9e26b0a8a518e Mon Sep 17 00:00:00 2001 From: Richard Chien Date: Thu, 18 Jul 2024 11:05:25 +0800 Subject: [PATCH 35/70] feat(dyn-filter): derive `condition_always_relax` from column monotonicity (#17704) Signed-off-by: Richard Chien --- .../plan_node/stream_dynamic_filter.rs | 25 +++---------------- .../optimizer/plan_node/stream_exchange.rs | 19 +++++++++----- 2 files changed, 17 insertions(+), 27 deletions(-) diff --git a/src/frontend/src/optimizer/plan_node/stream_dynamic_filter.rs b/src/frontend/src/optimizer/plan_node/stream_dynamic_filter.rs index f32bd63753d2..1d01650d68a0 100644 --- a/src/frontend/src/optimizer/plan_node/stream_dynamic_filter.rs +++ b/src/frontend/src/optimizer/plan_node/stream_dynamic_filter.rs @@ -23,11 +23,11 @@ use super::stream::prelude::*; use super::utils::{ childless_record, column_names_pretty, plan_node_name, watermark_pretty, Distill, }; -use super::{generic, ExprRewritable, PlanTreeNodeUnary}; +use super::{generic, ExprRewritable}; use crate::expr::Expr; use crate::optimizer::plan_node::expr_visitable::ExprVisitable; use crate::optimizer::plan_node::{PlanBase, PlanTreeNodeBinary, StreamNode}; -use crate::optimizer::property::{Distribution, MonotonicityMap}; +use crate::optimizer::property::MonotonicityMap; use crate::optimizer::PlanRef; use crate::stream_fragmenter::BuildFragmentGraphState; @@ -41,25 +41,8 @@ pub struct StreamDynamicFilter { impl StreamDynamicFilter { pub fn new(core: DynamicFilter) -> Self { - // TODO(st1page): here we just check if RHS - // is a `StreamNow`. It will be generalized to more cases - // by introducing monotonically increasing property of the node in https://github.com/risingwavelabs/risingwave/pull/13984. 
- let right_monotonically_increasing = { - if let Some(e) = core.right().as_stream_exchange() - && *e.distribution() == Distribution::Broadcast - { - if e.input().as_stream_now().is_some() { - true - } else if let Some(proj) = e.input().as_stream_project() { - proj.input().as_stream_now().is_some() - } else { - false - } - } else { - false - } - }; - let condition_always_relax = right_monotonically_increasing + let right_non_decreasing = core.right().columns_monotonicity()[0].is_non_decreasing(); + let condition_always_relax = right_non_decreasing && matches!( core.comparator(), ExprType::LessThan | ExprType::LessThanOrEqual diff --git a/src/frontend/src/optimizer/plan_node/stream_exchange.rs b/src/frontend/src/optimizer/plan_node/stream_exchange.rs index 802f2e3d227c..878e34d577b3 100644 --- a/src/frontend/src/optimizer/plan_node/stream_exchange.rs +++ b/src/frontend/src/optimizer/plan_node/stream_exchange.rs @@ -20,7 +20,9 @@ use super::stream::prelude::*; use super::utils::{childless_record, plan_node_name, Distill}; use super::{ExprRewritable, PlanBase, PlanRef, PlanTreeNodeUnary, StreamNode}; use crate::optimizer::plan_node::expr_visitable::ExprVisitable; -use crate::optimizer::property::{Distribution, DistributionDisplay, MonotonicityMap}; +use crate::optimizer::property::{ + Distribution, DistributionDisplay, MonotonicityMap, RequiredDist, +}; use crate::stream_fragmenter::BuildFragmentGraphState; /// `StreamExchange` imposes a particular distribution on its input @@ -34,17 +36,23 @@ pub struct StreamExchange { impl StreamExchange { pub fn new(input: PlanRef, dist: Distribution) -> Self { - // Dispatch executor won't change the append-only behavior of the stream. + let columns_monotonicity = if input.distribution().satisfies(&RequiredDist::single()) { + // If the input is a singleton, the monotonicity will be preserved during shuffle + // since we use ordered channel/buffer when exchanging data. + input.columns_monotonicity().clone() + } else { + MonotonicityMap::new() + }; let base = PlanBase::new_stream( input.ctx(), input.schema().clone(), input.stream_key().map(|v| v.to_vec()), input.functional_dependency().clone(), dist, - input.append_only(), + input.append_only(), // append-only property won't change input.emit_on_window_close(), input.watermark_columns().clone(), - MonotonicityMap::new(), // we lost monotonicity information when shuffling + columns_monotonicity, ); StreamExchange { base, @@ -55,14 +63,13 @@ impl StreamExchange { pub fn new_no_shuffle(input: PlanRef) -> Self { let ctx = input.ctx(); - // Dispatch executor won't change the append-only behavior of the stream. 
let base = PlanBase::new_stream( ctx, input.schema().clone(), input.stream_key().map(|v| v.to_vec()), input.functional_dependency().clone(), input.distribution().clone(), - input.append_only(), + input.append_only(), // append-only property won't change input.emit_on_window_close(), input.watermark_columns().clone(), input.columns_monotonicity().clone(), From 428563381674e1e25a21d2f61523938fc4655bb5 Mon Sep 17 00:00:00 2001 From: Li0k Date: Thu, 18 Jul 2024 12:22:47 +0800 Subject: [PATCH 36/70] fix(storage): fix multi builder data loss (#17730) --- src/storage/src/hummock/sstable/multi_builder.rs | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/storage/src/hummock/sstable/multi_builder.rs b/src/storage/src/hummock/sstable/multi_builder.rs index 5ce80457e3dc..97b448faec8d 100644 --- a/src/storage/src/hummock/sstable/multi_builder.rs +++ b/src/storage/src/hummock/sstable/multi_builder.rs @@ -306,7 +306,10 @@ where self.seal_current().await?; try_join_all(self.concurrent_upload_join_handle.into_iter()) .await - .map_err(HummockError::sstable_upload_error)?; + .map_err(HummockError::sstable_upload_error)? + .into_iter() + .collect::>>()?; + Ok(self.sst_outputs) } } From 332802596d2eefa46e65912f5def5f25f0d4d5a6 Mon Sep 17 00:00:00 2001 From: William Wen <44139337+wenym1@users.noreply.github.com> Date: Thu, 18 Jul 2024 15:40:59 +0800 Subject: [PATCH 37/70] fix(storage): fix spill table id too strict assertion (#17736) --- .../src/hummock/event_handler/uploader/mod.rs | 7 ++++++- .../hummock/event_handler/uploader/spiller.rs | 16 +++++++++------- 2 files changed, 15 insertions(+), 8 deletions(-) diff --git a/src/storage/src/hummock/event_handler/uploader/mod.rs b/src/storage/src/hummock/event_handler/uploader/mod.rs index 101f54541fed..6c7ded7b3bf0 100644 --- a/src/storage/src/hummock/event_handler/uploader/mod.rs +++ b/src/storage/src/hummock/event_handler/uploader/mod.rs @@ -965,7 +965,12 @@ impl UploaderData { .map(|task_id| { let (sst, spill_table_ids) = self.spilled_data.remove(task_id).expect("should exist"); - assert_eq!(spill_table_ids, table_ids); + assert!( + spill_table_ids.is_subset(&table_ids), + "spilled tabled ids {:?} not a subset of sync table id {:?}", + spill_table_ids, + table_ids + ); sst }) .collect(); diff --git a/src/storage/src/hummock/event_handler/uploader/spiller.rs b/src/storage/src/hummock/event_handler/uploader/spiller.rs index a4caa3c05fe3..ba04d85856ac 100644 --- a/src/storage/src/hummock/event_handler/uploader/spiller.rs +++ b/src/storage/src/hummock/event_handler/uploader/spiller.rs @@ -248,7 +248,7 @@ mod tests { uploader.add_imm(instance_id2, imm2_4_1.clone()); // uploader state: - // table_id1: + // table_id1: table_id2: // instance_id1_1: instance_id1_2: instance_id2 // epoch1 imm1_1_1 imm1_2_1 | imm2_1 | // epoch2 imms1_1_2(size 3) | | @@ -314,12 +314,12 @@ mod tests { uploader.local_seal_epoch(instance_id2, u64::MAX, SealCurrentEpochOptions::for_test()); // uploader state: - // table_id1: + // table_id1: table_id2: // instance_id1_1: instance_id1_2: instance_id2 // epoch1 spill(imm1_1_1, imm1_2_1, size 2) | spill(imm2_1, size 1) | // epoch2 spill(imms1_1_2, size 3) | | // epoch3 spill(imms_1_2_3, size 4) | | - // epoch4 spill(imm1_1_4, imm1_2_4, size 2) | spill(imm2_4_1, size 1), imm2_4_2 | + // epoch4 spill(imm1_1_4, imm1_2_4, size 2) spill(imm2_4_1, size 1), imm2_4_2 | let (sync_tx1_1, sync_rx1_1) = oneshot::channel(); uploader.start_sync_epoch(epoch1, sync_tx1_1, HashSet::from_iter([table_id1])); @@ -339,10 +339,6 @@ mod tests { 
vec![imm2_4_2.batch_id()], )])); - let (sync_tx4, mut sync_rx4) = oneshot::channel(); - uploader.start_sync_epoch(epoch4, sync_tx4, HashSet::from_iter([table_id1, table_id2])); - await_start2_4_2.await; - finish_tx2_4_1.send(()).unwrap(); finish_tx3.send(()).unwrap(); finish_tx2.send(()).unwrap(); @@ -399,6 +395,12 @@ mod tests { finish_tx2_1.send(()).unwrap(); let sst = uploader.next_uploaded_sst().await; assert_eq!(&imm_ids1_4, sst.imm_ids()); + + // trigger the sync after the spill task is finished and acked to cover the case + let (sync_tx4, mut sync_rx4) = oneshot::channel(); + uploader.start_sync_epoch(epoch4, sync_tx4, HashSet::from_iter([table_id1, table_id2])); + await_start2_4_2.await; + let sst = uploader.next_uploaded_sst().await; assert_eq!(&imm_ids2_1, sst.imm_ids()); let sst = uploader.next_uploaded_sst().await; From 0883df550e7dc7358a509fd4d7f205f8b4c62241 Mon Sep 17 00:00:00 2001 From: StrikeW Date: Thu, 18 Jul 2024 16:59:15 +0800 Subject: [PATCH 38/70] fix(mysql-cdc): validate mysql version less then 8.4 (#17728) Co-authored-by: Chengyou Liu <35356271+cyliu0@users.noreply.github.com> --- .../connector/source/common/MySqlValidator.java | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/java/connector-node/risingwave-connector-service/src/main/java/com/risingwave/connector/source/common/MySqlValidator.java b/java/connector-node/risingwave-connector-service/src/main/java/com/risingwave/connector/source/common/MySqlValidator.java index d20a18185a74..8c122f0f365e 100644 --- a/java/connector-node/risingwave-connector-service/src/main/java/com/risingwave/connector/source/common/MySqlValidator.java +++ b/java/connector-node/risingwave-connector-service/src/main/java/com/risingwave/connector/source/common/MySqlValidator.java @@ -63,7 +63,15 @@ public MySqlValidator( @Override public void validateDbConfig() { try { - // TODO: check database server version + // Check whether MySQL version is less than 8.4, + // since MySQL 8.4 introduces some breaking changes: + // https://dev.mysql.com/doc/relnotes/mysql/8.4/en/news-8-4-0.html#mysqld-8-4-0-deprecation-removal + var major = jdbcConnection.getMetaData().getDatabaseMajorVersion(); + var minor = jdbcConnection.getMetaData().getDatabaseMinorVersion(); + + if ((major > 8) || (major == 8 && minor >= 4)) { + throw ValidatorUtils.failedPrecondition("MySQL version should be less than 8.4"); + } validateBinlogConfig(); } catch (SQLException e) { throw ValidatorUtils.internalError(e.getMessage()); From 2e1d910bc18b5a9351d961de92cb8b832b52ddad Mon Sep 17 00:00:00 2001 From: Dylan Date: Thu, 18 Jul 2024 23:20:07 +0800 Subject: [PATCH 39/70] feat(batch): support batch read s3 parquet file (#17673) --- proto/batch_plan.proto | 21 ++++++ src/batch/src/executor/mod.rs | 2 + src/batch/src/executor/s3_file_scan.rs | 40 ++++++++++- .../src/optimizer/logical_optimization.rs | 14 ++-- .../optimizer/plan_node/batch_file_scan.rs | 21 +++++- .../optimizer/plan_node/generic/file_scan.rs | 15 ++++- .../optimizer/plan_node/logical_file_scan.rs | 4 +- .../src/scheduler/distributed/stage.rs | 9 +++ src/frontend/src/scheduler/local.rs | 34 ++++++++++ src/frontend/src/scheduler/plan_fragmenter.rs | 66 ++++++++++++++++--- 10 files changed, 204 insertions(+), 22 deletions(-) diff --git a/proto/batch_plan.proto b/proto/batch_plan.proto index a7a70246f0f7..b4b8eee563e2 100644 --- a/proto/batch_plan.proto +++ b/proto/batch_plan.proto @@ -66,6 +66,26 @@ message SourceNode { map secret_refs = 6; } +message FileScanNode { + enum FileFormat { + 
FILE_FORMAT_UNSPECIFIED = 0; + PARQUET = 1; + } + + enum StorageType { + STORAGE_TYPE_UNSPECIFIED = 0; + S3 = 1; + } + + repeated plan_common.ColumnDesc columns = 1; + FileFormat file_format = 2; + StorageType storage_type = 3; + string s3_region = 4; + string s3_access_key = 5; + string s3_secret_key = 6; + string file_location = 7; +} + message ProjectNode { repeated expr.ExprNode select_list = 1; } @@ -344,6 +364,7 @@ message PlanNode { SortOverWindowNode sort_over_window = 35; MaxOneRowNode max_one_row = 36; LogRowSeqScanNode log_row_seq_scan = 37; + FileScanNode file_scan = 38; // The following nodes are used for testing. bool block_executor = 100; bool busy_loop_executor = 101; diff --git a/src/batch/src/executor/mod.rs b/src/batch/src/executor/mod.rs index 3e2c2a8396a0..3a64901c64a0 100644 --- a/src/batch/src/executor/mod.rs +++ b/src/batch/src/executor/mod.rs @@ -87,6 +87,7 @@ pub use values::*; use self::log_row_seq_scan::LogStoreRowSeqScanExecutorBuilder; use self::test_utils::{BlockExecutorBuilder, BusyLoopExecutorBuilder}; use crate::error::Result; +use crate::executor::s3_file_scan::FileScanExecutorBuilder; use crate::executor::sys_row_seq_scan::SysRowSeqScanExecutorBuilder; use crate::task::{BatchTaskContext, ShutdownToken, TaskId}; @@ -241,6 +242,7 @@ impl<'a, C: BatchTaskContext> ExecutorBuilder<'a, C> { NodeBody::Source => SourceExecutor, NodeBody::SortOverWindow => SortOverWindowExecutor, NodeBody::MaxOneRow => MaxOneRowExecutor, + NodeBody::FileScan => FileScanExecutorBuilder, // Follow NodeBody only used for test NodeBody::BlockExecutor => BlockExecutorBuilder, NodeBody::BusyLoopExecutor => BusyLoopExecutorBuilder, diff --git a/src/batch/src/executor/s3_file_scan.rs b/src/batch/src/executor/s3_file_scan.rs index 7c56788f85ae..082df3585340 100644 --- a/src/batch/src/executor/s3_file_scan.rs +++ b/src/batch/src/executor/s3_file_scan.rs @@ -17,11 +17,15 @@ use futures_async_stream::try_stream; use futures_util::stream::StreamExt; use parquet::arrow::ProjectionMask; use risingwave_common::array::arrow::IcebergArrowConvert; -use risingwave_common::catalog::Schema; +use risingwave_common::catalog::{Field, Schema}; use risingwave_connector::source::iceberg::parquet_file_reader::create_parquet_stream_builder; +use risingwave_pb::batch_plan::file_scan_node; +use risingwave_pb::batch_plan::file_scan_node::StorageType; +use risingwave_pb::batch_plan::plan_node::NodeBody; use crate::error::BatchError; -use crate::executor::{DataChunk, Executor}; +use crate::executor::{BoxedExecutor, BoxedExecutorBuilder, DataChunk, Executor, ExecutorBuilder}; +use crate::task::BatchTaskContext; #[derive(PartialEq, Debug)] pub enum FileFormat { @@ -55,7 +59,6 @@ impl Executor for S3FileScanExecutor { } impl S3FileScanExecutor { - #![expect(dead_code)] pub fn new( file_format: FileFormat, location: String, @@ -113,3 +116,34 @@ impl S3FileScanExecutor { } } } + +pub struct FileScanExecutorBuilder {} + +#[async_trait::async_trait] +impl BoxedExecutorBuilder for FileScanExecutorBuilder { + async fn new_boxed_executor( + source: &ExecutorBuilder<'_, C>, + _inputs: Vec, + ) -> crate::error::Result { + let file_scan_node = try_match_expand!( + source.plan_node().get_node_body().unwrap(), + NodeBody::FileScan + )?; + + assert_eq!(file_scan_node.storage_type, StorageType::S3 as i32); + + Ok(Box::new(S3FileScanExecutor::new( + match file_scan_node::FileFormat::try_from(file_scan_node.file_format).unwrap() { + file_scan_node::FileFormat::Parquet => FileFormat::Parquet, + file_scan_node::FileFormat::Unspecified 
=> unreachable!(), + }, + file_scan_node.file_location.clone(), + file_scan_node.s3_region.clone(), + file_scan_node.s3_access_key.clone(), + file_scan_node.s3_secret_key.clone(), + source.context.get_config().developer.chunk_size, + Schema::from_iter(file_scan_node.columns.iter().map(Field::from)), + source.plan_node().get_identity().clone(), + ))) + } +} diff --git a/src/frontend/src/optimizer/logical_optimization.rs b/src/frontend/src/optimizer/logical_optimization.rs index 4f95bde0b852..de9db6f8f22d 100644 --- a/src/frontend/src/optimizer/logical_optimization.rs +++ b/src/frontend/src/optimizer/logical_optimization.rs @@ -126,10 +126,14 @@ static STREAM_GENERATE_SERIES_WITH_NOW: LazyLock = LazyLock:: ) }); -static TABLE_FUNCTION_TO_PROJECT_SET: LazyLock = LazyLock::new(|| { +static TABLE_FUNCTION_CONVERT: LazyLock = LazyLock::new(|| { OptimizationStage::new( - "Table Function To Project Set", - vec![TableFunctionToProjectSetRule::create()], + "Table Function Convert", + vec![ + // Apply file scan rule first + TableFunctionToFileScanRule::create(), + TableFunctionToProjectSetRule::create(), + ], ApplyOrder::TopDown, ) }); @@ -592,7 +596,7 @@ impl LogicalOptimizer { // Should be applied before converting table function to project set. plan = plan.optimize_by_rules(&STREAM_GENERATE_SERIES_WITH_NOW); // In order to unnest a table function, we need to convert it into a `project_set` first. - plan = plan.optimize_by_rules(&TABLE_FUNCTION_TO_PROJECT_SET); + plan = plan.optimize_by_rules(&TABLE_FUNCTION_CONVERT); plan = Self::subquery_unnesting(plan, enable_share_plan, explain_trace, &ctx)?; if has_logical_max_one_row(plan.clone()) { @@ -700,7 +704,7 @@ impl LogicalOptimizer { // Table function should be converted into `file_scan` before `project_set`. plan = plan.optimize_by_rules(&TABLE_FUNCTION_TO_FILE_SCAN); // In order to unnest a table function, we need to convert it into a `project_set` first. - plan = plan.optimize_by_rules(&TABLE_FUNCTION_TO_PROJECT_SET); + plan = plan.optimize_by_rules(&TABLE_FUNCTION_CONVERT); plan = Self::subquery_unnesting(plan, false, explain_trace, &ctx)?; diff --git a/src/frontend/src/optimizer/plan_node/batch_file_scan.rs b/src/frontend/src/optimizer/plan_node/batch_file_scan.rs index 826f39441294..649c178855ef 100644 --- a/src/frontend/src/optimizer/plan_node/batch_file_scan.rs +++ b/src/frontend/src/optimizer/plan_node/batch_file_scan.rs @@ -13,7 +13,9 @@ // limitations under the License. 
use pretty_xmlish::XmlNode; +use risingwave_pb::batch_plan::file_scan_node::{FileFormat, StorageType}; use risingwave_pb::batch_plan::plan_node::NodeBody; +use risingwave_pb::batch_plan::FileScanNode; use super::batch::prelude::*; use super::utils::{childless_record, column_names_pretty, Distill}; @@ -75,7 +77,24 @@ impl ToDistributedBatch for BatchFileScan { impl ToBatchPb for BatchFileScan { fn to_batch_prost_body(&self) -> NodeBody { - todo!() + NodeBody::FileScan(FileScanNode { + columns: self + .core + .columns() + .into_iter() + .map(|col| col.to_protobuf()) + .collect(), + file_format: match self.core.file_format { + generic::FileFormat::Parquet => FileFormat::Parquet as i32, + }, + storage_type: match self.core.storage_type { + generic::StorageType::S3 => StorageType::S3 as i32, + }, + s3_region: self.core.s3_region.clone(), + s3_access_key: self.core.s3_access_key.clone(), + s3_secret_key: self.core.s3_secret_key.clone(), + file_location: self.core.file_location.clone(), + }) } } diff --git a/src/frontend/src/optimizer/plan_node/generic/file_scan.rs b/src/frontend/src/optimizer/plan_node/generic/file_scan.rs index f8ed20c12072..419780adc491 100644 --- a/src/frontend/src/optimizer/plan_node/generic/file_scan.rs +++ b/src/frontend/src/optimizer/plan_node/generic/file_scan.rs @@ -13,7 +13,7 @@ // limitations under the License. use educe::Educe; -use risingwave_common::catalog::Schema; +use risingwave_common::catalog::{ColumnDesc, ColumnId, Schema}; use super::GenericPlanNode; use crate::optimizer::optimizer_context::OptimizerContextRef; @@ -62,3 +62,16 @@ impl GenericPlanNode for FileScan { FunctionalDependencySet::new(self.schema.len()) } } + +impl FileScan { + pub fn columns(&self) -> Vec { + self.schema + .fields + .iter() + .enumerate() + .map(|(i, f)| { + ColumnDesc::named(f.name.clone(), ColumnId::new(i as i32), f.data_type.clone()) + }) + .collect() + } +} diff --git a/src/frontend/src/optimizer/plan_node/logical_file_scan.rs b/src/frontend/src/optimizer/plan_node/logical_file_scan.rs index df41023bb484..5e5964d0f086 100644 --- a/src/frontend/src/optimizer/plan_node/logical_file_scan.rs +++ b/src/frontend/src/optimizer/plan_node/logical_file_scan.rs @@ -106,13 +106,13 @@ impl ToBatch for LogicalFileScan { impl ToStream for LogicalFileScan { fn to_stream(&self, _ctx: &mut ToStreamContext) -> Result { - bail!("FileScan is not supported in streaming mode") + bail!("file_scan function is not supported in streaming mode") } fn logical_rewrite_for_stream( &self, _ctx: &mut RewriteStreamContext, ) -> Result<(PlanRef, ColIndexMapping)> { - bail!("FileScan is not supported in streaming mode") + bail!("file_scan function is not supported in streaming mode") } } diff --git a/src/frontend/src/scheduler/distributed/stage.rs b/src/frontend/src/scheduler/distributed/stage.rs index 8301c5a5b9d5..ce0d03356240 100644 --- a/src/frontend/src/scheduler/distributed/stage.rs +++ b/src/frontend/src/scheduler/distributed/stage.rs @@ -405,6 +405,15 @@ impl StageRunner { expr_context.clone(), )); } + } else if let Some(_file_scan_info) = self.stage.file_scan_info.as_ref() { + let task_id = PbTaskId { + query_id: self.stage.query_id.id.clone(), + stage_id: self.stage.id, + task_id: 0_u64, + }; + let plan_fragment = self.create_plan_fragment(0_u64, Some(PartitionInfo::File)); + let worker = self.choose_worker(&plan_fragment, 0_u32, self.stage.dml_table_id)?; + futures.push(self.schedule_task(task_id, plan_fragment, worker, expr_context.clone())); } else { for id in 0..self.stage.parallelism.unwrap() { let 
task_id = PbTaskId { diff --git a/src/frontend/src/scheduler/local.rs b/src/frontend/src/scheduler/local.rs index 966358da3ded..99f1b6eebda1 100644 --- a/src/frontend/src/scheduler/local.rs +++ b/src/frontend/src/scheduler/local.rs @@ -402,6 +402,40 @@ impl LocalQueryExecution { }; sources.push(exchange_source); } + } else if let Some(_file_scan_info) = &second_stage.file_scan_info { + let second_stage_plan_node = self.convert_plan_node( + &second_stage.root, + &mut None, + Some(PartitionInfo::File), + next_executor_id.clone(), + )?; + let second_stage_plan_fragment = PlanFragment { + root: Some(second_stage_plan_node), + exchange_info: Some(ExchangeInfo { + mode: DistributionMode::Single as i32, + ..Default::default() + }), + }; + let local_execute_plan = LocalExecutePlan { + plan: Some(second_stage_plan_fragment), + epoch: Some(self.snapshot.batch_query_epoch()), + tracing_context: tracing_context.clone(), + }; + // NOTE: select a random work node here. + let worker_node = self.worker_node_manager.next_random_worker()?; + let exchange_source = ExchangeSource { + task_output_id: Some(TaskOutputId { + task_id: Some(PbTaskId { + task_id: 0_u64, + stage_id: exchange_source_stage_id, + query_id: self.query.query_id.id.clone(), + }), + output_id: 0, + }), + host: Some(worker_node.host.as_ref().unwrap().clone()), + local_execute_plan: Some(Plan(local_execute_plan)), + }; + sources.push(exchange_source); } else { let second_stage_plan_node = self.convert_plan_node( &second_stage.root, diff --git a/src/frontend/src/scheduler/plan_fragmenter.rs b/src/frontend/src/scheduler/plan_fragmenter.rs index ee6f49486c1b..9ca9697fe3f5 100644 --- a/src/frontend/src/scheduler/plan_fragmenter.rs +++ b/src/frontend/src/scheduler/plan_fragmenter.rs @@ -433,6 +433,12 @@ pub struct TablePartitionInfo { pub enum PartitionInfo { Table(TablePartitionInfo), Source(Vec), + File, +} + +#[derive(Clone, Debug)] +pub struct FileScanInfo { + // Currently we only support one file, so we don't need to support any partition info. } /// Fragment part of `Query`. @@ -446,6 +452,7 @@ pub struct QueryStage { /// Indicates whether this stage contains a table scan node and the table's information if so. pub table_scan_info: Option, pub source_info: Option, + pub file_scan_info: Option, pub has_lookup_join: bool, pub dml_table_id: Option, pub session_id: SessionId, @@ -469,16 +476,21 @@ impl QueryStage { self.has_lookup_join } - pub fn clone_with_exchange_info(&self, exchange_info: Option) -> Self { + pub fn clone_with_exchange_info( + &self, + exchange_info: Option, + parallelism: Option, + ) -> Self { if let Some(exchange_info) = exchange_info { return Self { query_id: self.query_id.clone(), id: self.id, root: self.root.clone(), exchange_info: Some(exchange_info), - parallelism: self.parallelism, + parallelism, table_scan_info: self.table_scan_info.clone(), source_info: self.source_info.clone(), + file_scan_info: self.file_scan_info.clone(), has_lookup_join: self.has_lookup_join, dml_table_id: self.dml_table_id, session_id: self.session_id, @@ -509,6 +521,7 @@ impl QueryStage { parallelism: Some(task_parallelism), table_scan_info: self.table_scan_info.clone(), source_info: Some(source_info), + file_scan_info: self.file_scan_info.clone(), has_lookup_join: self.has_lookup_join, dml_table_id: self.dml_table_id, session_id: self.session_id, @@ -555,6 +568,7 @@ struct QueryStageBuilder { /// See also [`QueryStage::table_scan_info`]. 
table_scan_info: Option, source_info: Option, + file_scan_file: Option, has_lookup_join: bool, dml_table_id: Option, session_id: SessionId, @@ -572,6 +586,7 @@ impl QueryStageBuilder { exchange_info: Option, table_scan_info: Option, source_info: Option, + file_scan_file: Option, has_lookup_join: bool, dml_table_id: Option, session_id: SessionId, @@ -586,6 +601,7 @@ impl QueryStageBuilder { children_stages: vec![], table_scan_info, source_info, + file_scan_file, has_lookup_join, dml_table_id, session_id, @@ -608,6 +624,7 @@ impl QueryStageBuilder { parallelism: self.parallelism, table_scan_info: self.table_scan_info, source_info: self.source_info, + file_scan_info: self.file_scan_file, has_lookup_join: self.has_lookup_join, dml_table_id: self.dml_table_id, session_id: self.session_id, @@ -700,15 +717,10 @@ impl StageGraph { // If the stage has parallelism, it means it's a complete stage. complete_stages.insert( stage.id, - Arc::new(stage.clone_with_exchange_info(exchange_info)), + Arc::new(stage.clone_with_exchange_info(exchange_info, stage.parallelism)), ); None - } else { - assert!(matches!( - stage.source_info, - Some(SourceScanInfo::Incomplete(_)) - )); - + } else if matches!(stage.source_info, Some(SourceScanInfo::Incomplete(_))) { let complete_source_info = stage .source_info .as_ref() @@ -741,6 +753,13 @@ impl StageGraph { let parallelism = complete_stage.parallelism; complete_stages.insert(stage.id, complete_stage); parallelism + } else { + assert!(matches!(&stage.file_scan_info, Some(FileScanInfo {}))); + complete_stages.insert( + stage.id, + Arc::new(stage.clone_with_exchange_info(exchange_info, Some(1))), + ); + None }; for child_stage_id in self.child_edges.get(&stage.id).unwrap_or(&HashSet::new()) { @@ -854,6 +873,13 @@ impl BatchPlanFragmenter { } else { None }; + + let file_scan_info = if table_scan_info.is_none() && source_info.is_none() { + Self::collect_stage_file_scan(root.clone())? + } else { + None + }; + let mut has_lookup_join = false; let parallelism = match root.distribution() { Distribution::Single => { @@ -901,12 +927,14 @@ impl BatchPlanFragmenter { lookup_join_parallelism } else if source_info.is_some() { 0 + } else if file_scan_info.is_some() { + 1 } else { self.batch_parallelism } } }; - if source_info.is_none() && parallelism == 0 { + if source_info.is_none() && file_scan_info.is_none() && parallelism == 0 { return Err(BatchError::EmptyWorkerNodes.into()); } let parallelism = if parallelism == 0 { @@ -922,6 +950,7 @@ impl BatchPlanFragmenter { exchange_info, table_scan_info, source_info, + file_scan_info, has_lookup_join, dml_table_id, root.ctx().session_ctx().session_id(), @@ -1055,6 +1084,23 @@ impl BatchPlanFragmenter { .transpose() } + fn collect_stage_file_scan(node: PlanRef) -> SchedulerResult> { + if node.node_type() == PlanNodeType::BatchExchange { + // Do not visit next stage. + return Ok(None); + } + + if let Some(_batch_file_scan) = node.as_batch_file_scan() { + // Currently the file scan only support one file, so we just need a empty struct. + return Ok(Some(FileScanInfo {})); + } + + node.inputs() + .into_iter() + .find_map(|n| Self::collect_stage_file_scan(n).transpose()) + .transpose() + } + /// Check whether this stage contains a table scan node and the table's information if so. 
/// /// If there are multiple scan nodes in this stage, they must have the same distribution, but From af97625b736b1578de28d39fe7ad3656f643f449 Mon Sep 17 00:00:00 2001 From: lmatz Date: Fri, 19 Jul 2024 09:32:43 +0800 Subject: [PATCH 40/70] chore: update docker image from `v1.10.0-rc.1` to `v1.10.0-rc.3` (#17750) --- docker/docker-compose-distributed-etcd.yml | 2 +- docker/docker-compose-distributed.yml | 2 +- docker/docker-compose-etcd.yml | 2 +- docker/docker-compose-with-azblob.yml | 2 +- docker/docker-compose-with-gcs.yml | 2 +- docker/docker-compose-with-local-fs.yml | 2 +- docker/docker-compose-with-obs.yml | 2 +- docker/docker-compose-with-oss.yml | 2 +- docker/docker-compose-with-s3.yml | 2 +- docker/docker-compose-with-sqlite.yml | 2 +- docker/docker-compose.yml | 2 +- 11 files changed, 11 insertions(+), 11 deletions(-) diff --git a/docker/docker-compose-distributed-etcd.yml b/docker/docker-compose-distributed-etcd.yml index 16382a54a2b7..1521134c9c49 100644 --- a/docker/docker-compose-distributed-etcd.yml +++ b/docker/docker-compose-distributed-etcd.yml @@ -1,7 +1,7 @@ --- version: "3" x-image: &image - image: ${RW_IMAGE:-risingwavelabs/risingwave:v1.10.0-rc.1} + image: ${RW_IMAGE:-risingwavelabs/risingwave:v1.10.0-rc.3} services: compactor-0: <<: *image diff --git a/docker/docker-compose-distributed.yml b/docker/docker-compose-distributed.yml index 843d15c39cfd..05f78c97b45a 100644 --- a/docker/docker-compose-distributed.yml +++ b/docker/docker-compose-distributed.yml @@ -1,6 +1,6 @@ --- x-image: &image - image: ${RW_IMAGE:-risingwavelabs/risingwave:v1.10.0-rc.1} + image: ${RW_IMAGE:-risingwavelabs/risingwave:v1.10.0-rc.3} services: compactor-0: <<: *image diff --git a/docker/docker-compose-etcd.yml b/docker/docker-compose-etcd.yml index 5cca2a704d9b..a5319193f181 100644 --- a/docker/docker-compose-etcd.yml +++ b/docker/docker-compose-etcd.yml @@ -1,7 +1,7 @@ --- version: "3" x-image: &image - image: ${RW_IMAGE:-risingwavelabs/risingwave:v1.10.0-rc.1} + image: ${RW_IMAGE:-risingwavelabs/risingwave:v1.10.0-rc.3} services: risingwave-standalone: <<: *image diff --git a/docker/docker-compose-with-azblob.yml b/docker/docker-compose-with-azblob.yml index ed056d9ad1c5..93d89c0ea285 100644 --- a/docker/docker-compose-with-azblob.yml +++ b/docker/docker-compose-with-azblob.yml @@ -1,6 +1,6 @@ --- x-image: &image - image: ${RW_IMAGE:-risingwavelabs/risingwave:v1.10.0-rc.1} + image: ${RW_IMAGE:-risingwavelabs/risingwave:v1.10.0-rc.3} services: risingwave-standalone: <<: *image diff --git a/docker/docker-compose-with-gcs.yml b/docker/docker-compose-with-gcs.yml index 28e52286df39..fe2b37d045de 100644 --- a/docker/docker-compose-with-gcs.yml +++ b/docker/docker-compose-with-gcs.yml @@ -1,6 +1,6 @@ --- x-image: &image - image: ${RW_IMAGE:-risingwavelabs/risingwave:v1.10.0-rc.1} + image: ${RW_IMAGE:-risingwavelabs/risingwave:v1.10.0-rc.3} services: risingwave-standalone: <<: *image diff --git a/docker/docker-compose-with-local-fs.yml b/docker/docker-compose-with-local-fs.yml index d6f0699c22b2..abd06cdc198a 100644 --- a/docker/docker-compose-with-local-fs.yml +++ b/docker/docker-compose-with-local-fs.yml @@ -1,6 +1,6 @@ --- x-image: &image - image: ${RW_IMAGE:-risingwavelabs/risingwave:v1.10.0-rc.1} + image: ${RW_IMAGE:-risingwavelabs/risingwave:v1.10.0-rc.3} services: risingwave-standalone: <<: *image diff --git a/docker/docker-compose-with-obs.yml b/docker/docker-compose-with-obs.yml index 72201df7e7dd..c28827447802 100644 --- a/docker/docker-compose-with-obs.yml +++ 
b/docker/docker-compose-with-obs.yml @@ -1,6 +1,6 @@ --- x-image: &image - image: ${RW_IMAGE:-risingwavelabs/risingwave:v1.10.0-rc.1} + image: ${RW_IMAGE:-risingwavelabs/risingwave:v1.10.0-rc.3} services: risingwave-standalone: <<: *image diff --git a/docker/docker-compose-with-oss.yml b/docker/docker-compose-with-oss.yml index 1fba5ea52f34..26366b5dd827 100644 --- a/docker/docker-compose-with-oss.yml +++ b/docker/docker-compose-with-oss.yml @@ -1,6 +1,6 @@ --- x-image: &image - image: ${RW_IMAGE:-risingwavelabs/risingwave:v1.10.0-rc.1} + image: ${RW_IMAGE:-risingwavelabs/risingwave:v1.10.0-rc.3} services: risingwave-standalone: <<: *image diff --git a/docker/docker-compose-with-s3.yml b/docker/docker-compose-with-s3.yml index a0cdcf5ef73d..3a2a4ffde4d3 100644 --- a/docker/docker-compose-with-s3.yml +++ b/docker/docker-compose-with-s3.yml @@ -1,6 +1,6 @@ --- x-image: &image - image: ${RW_IMAGE:-risingwavelabs/risingwave:v1.10.0-rc.1} + image: ${RW_IMAGE:-risingwavelabs/risingwave:v1.10.0-rc.3} services: risingwave-standalone: <<: *image diff --git a/docker/docker-compose-with-sqlite.yml b/docker/docker-compose-with-sqlite.yml index 0c3a04a661f3..eb575efcc5e0 100644 --- a/docker/docker-compose-with-sqlite.yml +++ b/docker/docker-compose-with-sqlite.yml @@ -1,6 +1,6 @@ --- x-image: &image - image: ${RW_IMAGE:-risingwavelabs/risingwave:v1.10.0-rc.1} + image: ${RW_IMAGE:-risingwavelabs/risingwave:v1.10.0-rc.3} services: risingwave-standalone: <<: *image diff --git a/docker/docker-compose.yml b/docker/docker-compose.yml index dc9ad11f0eeb..929dd23c2b65 100644 --- a/docker/docker-compose.yml +++ b/docker/docker-compose.yml @@ -1,6 +1,6 @@ --- x-image: &image - image: ${RW_IMAGE:-risingwavelabs/risingwave:v1.10.0-rc.1} + image: ${RW_IMAGE:-risingwavelabs/risingwave:v1.10.0-rc.3} services: risingwave-standalone: <<: *image From 2a52dd3a23c5ff29102fa21c1bf65a8ede81b68a Mon Sep 17 00:00:00 2001 From: Shanicky Chen Date: Fri, 19 Jul 2024 10:59:44 +0800 Subject: [PATCH 41/70] fix: Refactor `auto_parallelism.rs` to initialize `session` after killing compute node (#17751) --- .../scale/auto_parallelism.rs | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/src/tests/simulation/tests/integration_tests/scale/auto_parallelism.rs b/src/tests/simulation/tests/integration_tests/scale/auto_parallelism.rs index 982b85d48f4c..beb20ad8c0b2 100644 --- a/src/tests/simulation/tests/integration_tests/scale/auto_parallelism.rs +++ b/src/tests/simulation/tests/integration_tests/scale/auto_parallelism.rs @@ -217,7 +217,6 @@ async fn test_active_online() -> Result<()> { true, ); let mut cluster = Cluster::start(config.clone()).await?; - let mut session = cluster.start_session(); // Keep one worker reserved for adding later. cluster @@ -229,6 +228,8 @@ async fn test_active_online() -> Result<()> { )) .await; + let mut session = cluster.start_session(); + session.run("create table t (v1 int);").await?; session .run("create materialized view m as select count(*) from t;") @@ -303,7 +304,6 @@ async fn test_auto_parallelism_control_with_fixed_and_auto_helper( enable_auto_parallelism_control, ); let mut cluster = Cluster::start(config.clone()).await?; - let mut session = cluster.start_session(); // Keep one worker reserved for adding later. 
let select_worker = "compute-2"; @@ -316,6 +316,8 @@ async fn test_auto_parallelism_control_with_fixed_and_auto_helper( )) .await; + let mut session = cluster.start_session(); + session.run("create table t (v1 int);").await?; session @@ -490,10 +492,6 @@ async fn test_compatibility_with_low_level() -> Result<()> { true, ); let mut cluster = Cluster::start(config.clone()).await?; - let mut session = cluster.start_session(); - session - .run("SET streaming_use_arrangement_backfill = false;") - .await?; // Keep one worker reserved for adding later. let select_worker = "compute-2"; @@ -506,6 +504,11 @@ async fn test_compatibility_with_low_level() -> Result<()> { )) .await; + let mut session = cluster.start_session(); + session + .run("SET streaming_use_arrangement_backfill = false;") + .await?; + session.run("create table t(v int);").await?; // single fragment downstream @@ -631,7 +634,6 @@ async fn test_compatibility_with_low_level_and_arrangement_backfill() -> Result< true, ); let mut cluster = Cluster::start(config.clone()).await?; - let mut session = cluster.start_session(); // Keep one worker reserved for adding later. let select_worker = "compute-2"; @@ -644,6 +646,8 @@ async fn test_compatibility_with_low_level_and_arrangement_backfill() -> Result< )) .await; + let mut session = cluster.start_session(); + session.run("create table t(v int);").await?; // Streaming arrangement backfill From 50ca1003e022967057b78e3d49bc9858b4bbde84 Mon Sep 17 00:00:00 2001 From: Bugen Zhao Date: Fri, 19 Jul 2024 13:35:47 +0800 Subject: [PATCH 42/70] fix(risedev): always use dev profile for `risedev-dev` in `ci-start` (#17756) Signed-off-by: Bugen Zhao --- Makefile.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Makefile.toml b/Makefile.toml index 6c392384f518..a1f6ad5421bf 100644 --- a/Makefile.toml +++ b/Makefile.toml @@ -1261,7 +1261,7 @@ echo If you still feel this is not enough, you may copy $(tput setaf 4)risedev$( [tasks.ci-start] category = "RiseDev - CI" dependencies = ["clean-data", "pre-start-dev"] -command = "target/${BUILD_MODE_DIR}/risedev-dev" +command = "target/debug/risedev-dev" # `risedev-dev` is always built in dev profile args = ["${@}"] description = "Clean data and start a full RisingWave dev cluster using risedev-dev" From dba2c53daa648e50006b3ef0162ee41b2d206f62 Mon Sep 17 00:00:00 2001 From: Richard Chien Date: Fri, 19 Jul 2024 14:02:35 +0800 Subject: [PATCH 43/70] feat(temporal-filter): support more now expressions in temporal filter pattern (#17745) Signed-off-by: Richard Chien --- .../tests/testdata/input/temporal_filter.yaml | 58 ++++++++++++------- .../tests/testdata/output/expr.yaml | 2 +- .../testdata/output/temporal_filter.yaml | 33 ++++++++++- src/frontend/src/expr/mod.rs | 48 ++++----------- .../src/optimizer/plan_node/logical_filter.rs | 21 ++----- 5 files changed, 86 insertions(+), 76 deletions(-) diff --git a/src/frontend/planner_test/tests/testdata/input/temporal_filter.yaml b/src/frontend/planner_test/tests/testdata/input/temporal_filter.yaml index 6bd62c1ce4d6..ce8fc00e1549 100644 --- a/src/frontend/planner_test/tests/testdata/input/temporal_filter.yaml +++ b/src/frontend/planner_test/tests/testdata/input/temporal_filter.yaml @@ -3,96 +3,96 @@ create table t1 (ts timestamp with time zone); select * from t1 where ts + interval '1 hour' > now(); expected_outputs: - - stream_plan + - stream_plan - name: Temporal filter works on complex columns on LHS (part 2) sql: | create table t1 (ts timestamp with time zone, time_to_live interval); select * 
from t1 where ts + time_to_live * 1.5 > now(); expected_outputs: - - stream_plan + - stream_plan - name: Temporal filter works on complex columns on LHS (part 2, flipped) sql: | create table t1 (ts timestamp with time zone, additional_time_to_live interval); select * from t1 where now() - interval '15 minutes' < ts + additional_time_to_live * 1.5; expected_outputs: - - stream_plan + - stream_plan - name: Temporal filter with `now()` in upper bound sql: |- create table t1 (ts timestamp with time zone); select * from t1 where now() - interval '15 minutes' > ts; expected_outputs: - - stream_plan - - stream_dist_plan + - stream_plan + - stream_dist_plan - name: Temporal filter with equal condition sql: |- create table t1 (ts timestamp with time zone); - select * from t1 where date_trunc('week', now()) = date_trunc('week',ts); + select * from t1 where date_trunc('week', now()) = date_trunc('week',ts); expected_outputs: - - stream_plan - - stream_dist_plan + - stream_plan + - stream_dist_plan - name: Temporal filter with `now()` in upper bound on append only table sql: |- create table t1 (ts timestamp with time zone) APPEND ONLY; select * from t1 where now() - interval '15 minutes' > ts; expected_outputs: - - stream_plan - - stream_dist_plan + - stream_plan + - stream_dist_plan - name: Temporal filter reorders now expressions correctly sql: | create table t1 (ts timestamp with time zone); select * from t1 where ts < now() - interval '1 hour' and ts >= now() - interval '2 hour'; expected_outputs: - - stream_plan - - stream_dist_plan + - stream_plan + - stream_dist_plan - name: Temporal filter in on clause for inner join's left side sql: | create table t1 (a int, ta timestamp with time zone); create table t2 (b int, tb timestamp with time zone); select * from t1 join t2 on a = b AND ta < now() - interval '1 hour' and ta >= now() - interval '2 hour'; expected_outputs: - - stream_plan + - stream_plan - name: Temporal filter in on clause for left join's left side sql: | create table t1 (a int, ta timestamp with time zone); create table t2 (b int, tb timestamp with time zone); select * from t1 left join t2 on a = b AND ta < now() - interval '1 hour' and ta >= now() - interval '2 hour'; expected_outputs: - - stream_error + - stream_error - name: Temporal filter in on clause for right join's left side sql: | create table t1 (a int, ta timestamp with time zone); create table t2 (b int, tb timestamp with time zone); select * from t1 right join t2 on a = b AND ta < now() - interval '1 hour' and ta >= now() - interval '2 hour'; expected_outputs: - - stream_plan + - stream_plan - name: Temporal filter in on clause for full join's left side sql: | create table t1 (a int, ta timestamp with time zone); create table t2 (b int, tb timestamp with time zone); select * from t1 full join t2 on a = b AND ta < now() - interval '1 hour' and ta >= now() - interval '2 hour'; expected_outputs: - - stream_error + - stream_error - name: Temporal filter in on clause for left join's right side sql: | create table t1 (a int, ta timestamp with time zone); create table t2 (b int, tb timestamp with time zone); select * from t1 left join t2 on a = b AND tb < now() - interval '1 hour' and tb >= now() - interval '2 hour'; expected_outputs: - - stream_plan + - stream_plan - name: Temporal filter in on clause for right join's right side sql: | create table t1 (a int, ta timestamp with time zone); create table t2 (b int, tb timestamp with time zone); select * from t1 right join t2 on a = b AND tb < now() - interval '1 hour' and tb >= 
now() - interval '2 hour'; expected_outputs: - - stream_error + - stream_error - name: Temporal filter after temporal join sql: | create table stream(id1 int, a1 int, b1 int, v1 timestamp with time zone) APPEND ONLY; create table version(id2 int, a2 int, b2 int, primary key (id2)); select id1, a1, id2, v1 from stream left join version FOR SYSTEM_TIME AS OF PROCTIME() on id1 = id2 where v1 > now(); expected_outputs: - - stream_plan + - stream_plan - name: Temporal filter with or predicate sql: | create table t1 (ts timestamp with time zone); @@ -116,4 +116,22 @@ create table t (t timestamp with time zone, a int); select * from t where (t > NOW() - INTERVAL '1 hour' OR t is NULL OR a < 1) AND (t < NOW() - INTERVAL '1 hour' OR a > 1); expected_outputs: - - stream_plan \ No newline at end of file + - stream_plan +- name: Non-trivial now expression + sql: | + create table t (ts timestamp with time zone, a int); + select * from t where ts + interval '1 hour' > date_trunc('day', now()); + expected_outputs: + - stream_plan +- name: Non-trivial now expression 2 + sql: | + create table t (ts timestamp with time zone, a int); + select * from t where ts + interval '1 hour' > date_trunc('day', ('2024-07-18 00:00:00+00:00'::timestamptz - ('2024-07-18 00:00:00+00:00'::timestamptz - now()))); + expected_outputs: + - stream_plan +- name: Non-monotonic now expression + sql: | + create table t (ts timestamp with time zone, a int); + select * from t where a > extract(hour from now()); + expected_outputs: + - stream_error diff --git a/src/frontend/planner_test/tests/testdata/output/expr.yaml b/src/frontend/planner_test/tests/testdata/output/expr.yaml index f88f7c4d69b7..d24d0f0eeba1 100644 --- a/src/frontend/planner_test/tests/testdata/output/expr.yaml +++ b/src/frontend/planner_test/tests/testdata/output/expr.yaml @@ -520,7 +520,7 @@ sql: | create table t (v1 timestamp with time zone, v2 timestamp with time zone); select * from t where v1 >= now() or v2 >= now(); - stream_error: Conditions containing now must be of the form `input_expr cmp now() [+- const_expr]` or `now() [+- const_expr] cmp input_expr`, where `input_expr` references a column and contains no `now()`. + stream_error: Conditions containing now must be in the form of `input_expr cmp now_expr` or `now_expr cmp input_expr`, where `input_expr` references a column and contains no `now()`, and `now_expr` is a non-decreasing expression contains `now()`. 
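Illustrative sketch (not part of the patch): the planner-test output above reflects the relaxed rule introduced in this series, where the `now()` side of a temporal-filter comparison may be any expression that contains `now()` and no column reference, rather than literally `now() [+- const_expr]`; how the non-decreasing property of that expression is verified is not shown in this hunk. The standalone Rust below only mirrors the operand classification with made-up types (`ToyExpr`, `classify`) and omits the comparison-operator reversal handled by the real `as_now_comparison_cond`.

enum ToyExpr {
    InputRef(usize),           // a column reference
    Now,                       // now()
    DateTrunc(Box<ToyExpr>),   // date_trunc('day', <expr>)
    AddInterval(Box<ToyExpr>), // <expr> + interval '...'
}

impl ToyExpr {
    fn has_now(&self) -> bool {
        match self {
            ToyExpr::Now => true,
            ToyExpr::InputRef(_) => false,
            ToyExpr::DateTrunc(e) | ToyExpr::AddInterval(e) => e.has_now(),
        }
    }

    fn has_input_ref(&self) -> bool {
        match self {
            ToyExpr::InputRef(_) => true,
            ToyExpr::Now => false,
            ToyExpr::DateTrunc(e) | ToyExpr::AddInterval(e) => e.has_input_ref(),
        }
    }
}

/// Returns Some((input_side, now_side)) when the operands match
/// `input_expr cmp now_expr`, flipping the sides if needed.
fn classify(lhs: ToyExpr, rhs: ToyExpr) -> Option<(ToyExpr, ToyExpr)> {
    if !lhs.has_now() && lhs.has_input_ref() && rhs.has_now() && !rhs.has_input_ref() {
        Some((lhs, rhs))
    } else if lhs.has_now() && !lhs.has_input_ref() && !rhs.has_now() && rhs.has_input_ref() {
        Some((rhs, lhs))
    } else {
        None
    }
}

fn main() {
    // `ts + interval '1 hour' > date_trunc('day', now())` is now classifiable.
    let ok = classify(
        ToyExpr::AddInterval(Box::new(ToyExpr::InputRef(0))),
        ToyExpr::DateTrunc(Box::new(ToyExpr::Now)),
    );
    assert!(ok.is_some());

    // A plain column-to-column comparison is not a now-comparison.
    let not_now = classify(ToyExpr::InputRef(0), ToyExpr::InputRef(1));
    assert!(not_now.is_none());
}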
- name: now inside HAVING clause sql: | create table t (v1 timestamp with time zone, v2 int); diff --git a/src/frontend/planner_test/tests/testdata/output/temporal_filter.yaml b/src/frontend/planner_test/tests/testdata/output/temporal_filter.yaml index 7bbd43ce3c35..edc3bb6c364c 100644 --- a/src/frontend/planner_test/tests/testdata/output/temporal_filter.yaml +++ b/src/frontend/planner_test/tests/testdata/output/temporal_filter.yaml @@ -81,7 +81,7 @@ - name: Temporal filter with equal condition sql: |- create table t1 (ts timestamp with time zone); - select * from t1 where date_trunc('week', now()) = date_trunc('week',ts); + select * from t1 where date_trunc('week', now()) = date_trunc('week',ts); stream_plan: |- StreamMaterialize { columns: [ts, t1._row_id(hidden), $expr1(hidden)], stream_key: [t1._row_id, $expr1], pk_columns: [t1._row_id, $expr1], pk_conflict: NoCheck } └─StreamExchange { dist: HashShard(t1._row_id, $expr1) } @@ -460,3 +460,34 @@ └─StreamShare { id: 2 } └─StreamFilter { predicate: IsNotNull(t.a) AND (((Not(IsNull(t.t)) AND Not((t.a < 1:Int32))) OR IsNull(t.t)) OR (t.a < 1:Int32)) } └─StreamTableScan { table: t, columns: [t.t, t.a, t._row_id], stream_scan_type: ArrangementBackfill, stream_key: [t._row_id], pk: [_row_id], dist: UpstreamHashShard(t._row_id) } +- name: Non-trivial now expression + sql: | + create table t (ts timestamp with time zone, a int); + select * from t where ts + interval '1 hour' > date_trunc('day', now()); + stream_plan: |- + StreamMaterialize { columns: [ts, a, t._row_id(hidden)], stream_key: [t._row_id], pk_columns: [t._row_id], pk_conflict: NoCheck } + └─StreamProject { exprs: [t.ts, t.a, t._row_id] } + └─StreamDynamicFilter { predicate: ($expr1 > $expr2), output_watermarks: [$expr1], output: [t.ts, t.a, $expr1, t._row_id], cleaned_by_watermark: true } + ├─StreamProject { exprs: [t.ts, t.a, AddWithTimeZone(t.ts, '01:00:00':Interval, 'UTC':Varchar) as $expr1, t._row_id] } + │ └─StreamTableScan { table: t, columns: [t.ts, t.a, t._row_id], stream_scan_type: ArrangementBackfill, stream_key: [t._row_id], pk: [_row_id], dist: UpstreamHashShard(t._row_id) } + └─StreamExchange { dist: Broadcast } + └─StreamProject { exprs: [DateTrunc('day':Varchar, now, 'UTC':Varchar) as $expr2], output_watermarks: [$expr2] } + └─StreamNow { output: [now] } +- name: Non-trivial now expression 2 + sql: | + create table t (ts timestamp with time zone, a int); + select * from t where ts + interval '1 hour' > date_trunc('day', ('2024-07-18 00:00:00+00:00'::timestamptz - ('2024-07-18 00:00:00+00:00'::timestamptz - now()))); + stream_plan: |- + StreamMaterialize { columns: [ts, a, t._row_id(hidden)], stream_key: [t._row_id], pk_columns: [t._row_id], pk_conflict: NoCheck } + └─StreamProject { exprs: [t.ts, t.a, t._row_id] } + └─StreamDynamicFilter { predicate: ($expr1 > $expr2), output: [t.ts, t.a, $expr1, t._row_id] } + ├─StreamProject { exprs: [t.ts, t.a, AddWithTimeZone(t.ts, '01:00:00':Interval, 'UTC':Varchar) as $expr1, t._row_id] } + │ └─StreamTableScan { table: t, columns: [t.ts, t.a, t._row_id], stream_scan_type: ArrangementBackfill, stream_key: [t._row_id], pk: [_row_id], dist: UpstreamHashShard(t._row_id) } + └─StreamExchange { dist: Broadcast } + └─StreamProject { exprs: [DateTrunc('day':Varchar, SubtractWithTimeZone('2024-07-18 00:00:00+00:00':Timestamptz, ('2024-07-18 00:00:00+00:00':Timestamptz - now), 'UTC':Varchar), 'UTC':Varchar) as $expr2] } + └─StreamNow { output: [now] } +- name: Non-monotonic now expression + sql: | + create table t (ts timestamp with time 
zone, a int); + select * from t where a > extract(hour from now()); + stream_error: Conditions containing now must be in the form of `input_expr cmp now_expr` or `now_expr cmp input_expr`, where `input_expr` references a column and contains no `now()`, and `now_expr` is a non-decreasing expression contains `now()`. diff --git a/src/frontend/src/expr/mod.rs b/src/frontend/src/expr/mod.rs index 19ffbaed92e2..444740c9400f 100644 --- a/src/frontend/src/expr/mod.rs +++ b/src/frontend/src/expr/mod.rs @@ -768,45 +768,36 @@ impl ExprImpl { } } - /// Accepts expressions of the form `input_expr cmp now() [+- const_expr]` or - /// `now() [+- const_expr] cmp input_expr`, where `input_expr` contains an - /// `InputRef` and contains no `now()`. + /// Accepts expressions of the form `input_expr cmp now_expr` or `now_expr cmp input_expr`, + /// where `input_expr` contains an `InputRef` and contains no `now()`, and `now_expr` + /// contains a `now()` but no `InputRef`. /// /// Canonicalizes to the first ordering and returns `(input_expr, cmp, now_expr)` pub fn as_now_comparison_cond(&self) -> Option<(ExprImpl, ExprType, ExprImpl)> { if let ExprImpl::FunctionCall(function_call) = self { match function_call.func_type() { - ty @ (ExprType::LessThan + ty @ (ExprType::Equal + | ExprType::LessThan | ExprType::LessThanOrEqual | ExprType::GreaterThan | ExprType::GreaterThanOrEqual) => { let (_, op1, op2) = function_call.clone().decompose_as_binary(); - if op1.count_nows() == 0 + if !op1.has_now() && op1.has_input_ref() - && op2.count_nows() > 0 - && op2.is_now_offset() + && op2.has_now() + && !op2.has_input_ref() { Some((op1, ty, op2)) - } else if op2.count_nows() == 0 + } else if op1.has_now() + && !op1.has_input_ref() + && !op2.has_now() && op2.has_input_ref() - && op1.count_nows() > 0 - && op1.is_now_offset() { Some((op2, Self::reverse_comparison(ty), op1)) } else { None } } - ty @ ExprType::Equal => { - let (_, op1, op2) = function_call.clone().decompose_as_binary(); - if op1.count_nows() == 0 && op1.has_input_ref() && op2.count_nows() > 0 { - Some((op1, ty, op2)) - } else if op2.count_nows() == 0 && op2.has_input_ref() && op1.count_nows() > 0 { - Some((op2, Self::reverse_comparison(ty), op1)) - } else { - None - } - } _ => None, } } else { @@ -862,23 +853,6 @@ impl ExprImpl { } } - /// Checks if expr is of the form `now() [+- const_expr]` - fn is_now_offset(&self) -> bool { - if let ExprImpl::Now(_) = self { - true - } else if let ExprImpl::FunctionCall(f) = self { - match f.func_type() { - ExprType::Add | ExprType::Subtract => { - let (_, lhs, rhs) = f.clone().decompose_as_binary(); - lhs.is_now_offset() && rhs.is_const() - } - _ => false, - } - } else { - false - } - } - /// Returns the `InputRef` and offset of a predicate if it matches /// the form `InputRef [+- const_expr]`, else returns None. 
fn as_input_offset(&self) -> Option<(usize, Option<(ExprType, ExprImpl)>)> { diff --git a/src/frontend/src/optimizer/plan_node/logical_filter.rs b/src/frontend/src/optimizer/plan_node/logical_filter.rs index 04cc2cb12a68..25062ee0eebc 100644 --- a/src/frontend/src/optimizer/plan_node/logical_filter.rs +++ b/src/frontend/src/optimizer/plan_node/logical_filter.rs @@ -197,24 +197,11 @@ impl ToStream for LogicalFilter { let new_input = self.input().to_stream(ctx)?; let predicate = self.predicate(); - let has_now = predicate - .conjunctions - .iter() - .any(|cond| cond.count_nows() > 0); - if has_now { - if predicate - .conjunctions - .iter() - .any(|expr| expr.count_nows() > 0 && expr.as_now_comparison_cond().is_none()) - { - bail!( - "Conditions containing now must be of the form `input_expr cmp now() [+- const_expr]` or \ - `now() [+- const_expr] cmp input_expr`, where `input_expr` references a column \ - and contains no `now()`." - ); - } + if predicate.conjunctions.iter().any(|cond| cond.has_now()) { bail!( - "All `now()` exprs were valid, but the condition must have at least one now expr as a lower bound." + "Conditions containing now must be in the form of `input_expr cmp now_expr` or \ + `now_expr cmp input_expr`, where `input_expr` references a column and contains \ + no `now()`, and `now_expr` is a non-decreasing expression contains `now()`." ); } let mut new_logical = self.core.clone(); From bd3b9a194efa1b78ab025f440147cba73ce20026 Mon Sep 17 00:00:00 2001 From: William Wen <44139337+wenym1@users.noreply.github.com> Date: Fri, 19 Jul 2024 14:29:25 +0800 Subject: [PATCH 44/70] refactor: actor wait barrier manager inject barrier (#17613) --- src/batch/src/task/env.rs | 2 +- .../src/rpc/service/exchange_service.rs | 13 +- src/meta/src/barrier/mod.rs | 18 +-- src/meta/src/barrier/recovery.rs | 13 +- src/stream/src/executor/dispatch.rs | 96 +++++++++---- src/stream/src/executor/exchange/input.rs | 116 ++++++++++----- src/stream/src/executor/exchange/output.rs | 2 +- src/stream/src/executor/exchange/permit.rs | 4 +- src/stream/src/executor/integration_tests.rs | 87 ++++++++--- src/stream/src/executor/merge.rs | 136 +++++++++++++----- src/stream/src/executor/mod.rs | 118 +++++++++++---- src/stream/src/executor/receiver.rs | 51 +++++-- src/stream/src/task/barrier_manager.rs | 107 ++++++++++++-- .../src/task/barrier_manager/managed_state.rs | 48 +++---- src/stream/src/task/barrier_manager/tests.rs | 95 ++---------- src/stream/src/task/env.rs | 2 +- src/stream/src/task/mod.rs | 1 + 17 files changed, 588 insertions(+), 321 deletions(-) diff --git a/src/batch/src/task/env.rs b/src/batch/src/task/env.rs index ecb7a3a8d3eb..6c4f32ac92e7 100644 --- a/src/batch/src/task/env.rs +++ b/src/batch/src/task/env.rs @@ -112,7 +112,7 @@ impl BatchEnvironment { BatchManagerMetrics::for_test(), u64::MAX, )), - server_addr: "127.0.0.1:5688".parse().unwrap(), + server_addr: "127.0.0.1:2333".parse().unwrap(), config: Arc::new(BatchConfig::default()), worker_id: WorkerNodeId::default(), state_store: StateStoreImpl::shared_in_memory_store(Arc::new( diff --git a/src/compute/src/rpc/service/exchange_service.rs b/src/compute/src/rpc/service/exchange_service.rs index f44f6f7552bb..e4082a88ea9e 100644 --- a/src/compute/src/rpc/service/exchange_service.rs +++ b/src/compute/src/rpc/service/exchange_service.rs @@ -24,7 +24,7 @@ use risingwave_pb::task_service::{ permits, GetDataRequest, GetDataResponse, GetStreamRequest, GetStreamResponse, PbPermits, }; use 
risingwave_stream::executor::exchange::permit::{MessageWithPermits, Receiver}; -use risingwave_stream::executor::Message; +use risingwave_stream::executor::DispatcherMessage; use risingwave_stream::task::LocalStreamManager; use thiserror_ext::AsReport; use tokio_stream::wrappers::ReceiverStream; @@ -169,21 +169,14 @@ impl ExchangeServiceImpl { Either::Left(permits_to_add) => { permits.add_permits(permits_to_add); } - Either::Right(MessageWithPermits { - mut message, - permits, - }) => { - // Erase the mutation of the barrier to avoid decoding in remote side. - if let Message::Barrier(barrier) = &mut message { - barrier.mutation = None; - } + Either::Right(MessageWithPermits { message, permits }) => { let proto = message.to_protobuf(); // forward the acquired permit to the downstream let response = GetStreamResponse { message: Some(proto), permits: Some(PbPermits { value: permits }), }; - let bytes = Message::get_encoded_len(&response); + let bytes = DispatcherMessage::get_encoded_len(&response); yield response; diff --git a/src/meta/src/barrier/mod.rs b/src/meta/src/barrier/mod.rs index 1a792b5ebfab..1c05497665fe 100644 --- a/src/meta/src/barrier/mod.rs +++ b/src/meta/src/barrier/mod.rs @@ -579,7 +579,7 @@ impl GlobalBarrierManager { let paused = self.take_pause_on_bootstrap().await.unwrap_or(false); let paused_reason = paused.then_some(PausedReason::Manual); - self.recovery(paused_reason).instrument(span).await; + self.recovery(paused_reason, None).instrument(span).await; } self.context.set_status(BarrierManagerStatus::Running); @@ -789,12 +789,6 @@ impl GlobalBarrierManager { } async fn failure_recovery(&mut self, err: MetaError) { - self.context - .tracker - .lock() - .await - .abort_all(&err, &self.context) - .await; self.checkpoint_control.clear_on_err(&err).await; self.pending_non_checkpoint_barriers.clear(); @@ -813,7 +807,7 @@ impl GlobalBarrierManager { // No need to clean dirty tables for barrier recovery, // The foreground stream job should cleanup their own tables. - self.recovery(None).instrument(span).await; + self.recovery(None, Some(err)).instrument(span).await; self.context.set_status(BarrierManagerStatus::Running); } else { panic!("failed to execute barrier: {}", err.as_report()); @@ -822,12 +816,6 @@ impl GlobalBarrierManager { async fn adhoc_recovery(&mut self) { let err = MetaErrorInner::AdhocRecovery.into(); - self.context - .tracker - .lock() - .await - .abort_all(&err, &self.context) - .await; self.checkpoint_control.clear_on_err(&err).await; self.context @@ -842,7 +830,7 @@ impl GlobalBarrierManager { // No need to clean dirty tables for barrier recovery, // The foreground stream job should cleanup their own tables. - self.recovery(None).instrument(span).await; + self.recovery(None, Some(err)).instrument(span).await; self.context.set_status(BarrierManagerStatus::Running); } } diff --git a/src/meta/src/barrier/recovery.rs b/src/meta/src/barrier/recovery.rs index 4bb9d2f669c0..f9bba534e561 100644 --- a/src/meta/src/barrier/recovery.rs +++ b/src/meta/src/barrier/recovery.rs @@ -43,7 +43,7 @@ use crate::controller::catalog::ReleaseContext; use crate::manager::{ActiveStreamingWorkerNodes, MetadataManager, WorkerId}; use crate::model::{MetadataModel, MigrationPlan, TableFragments, TableParallelism}; use crate::stream::{build_actor_connector_splits, RescheduleOptions, TableResizePolicy}; -use crate::{model, MetaResult}; +use crate::{model, MetaError, MetaResult}; impl GlobalBarrierManager { // Retry base interval in milliseconds. 
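Illustrative sketch (not part of the patch): the hunks above change `recovery()` to take the triggering error and move the `tracker.abort_all(..)` cleanup inside the retried closure, so the cleanup re-runs on every recovery attempt instead of only once before the retry loop starts. The self-contained Rust below shows that shape with hypothetical stand-ins (`abort_all_tracked_jobs`, `try_recover`) for the barrier-manager internals; it assumes the `tokio` and `tokio-retry` crates that the real code already uses.

use std::time::Duration;

use tokio_retry::strategy::{jitter, ExponentialBackoff};
use tokio_retry::Retry;

async fn abort_all_tracked_jobs(err: &str) {
    // Stand-in for `tracker.lock().await.abort_all(err, &self.context)`.
    println!("aborting tracked jobs due to: {err}");
}

async fn try_recover() -> Result<(), String> {
    // Stand-in for the real recovery body (resetting compute nodes, etc.).
    Ok(())
}

async fn recovery(err: Option<String>) {
    let strategy = ExponentialBackoff::from_millis(100)
        .max_delay(Duration::from_secs(5))
        .map(jitter);
    Retry::spawn(strategy, || async {
        // Runs on every attempt, not just once before retrying begins.
        if let Some(err) = &err {
            abort_all_tracked_jobs(err).await;
        }
        try_recover().await
    })
    .await
    .expect("recovery should eventually succeed");
}

#[tokio::main]
async fn main() {
    recovery(Some("injected failure".to_string())).await;
}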
@@ -224,7 +224,7 @@ impl GlobalBarrierManager { /// the cluster or `risectl` command. Used for debugging purpose. /// /// Returns the new state of the barrier manager after recovery. - pub async fn recovery(&mut self, paused_reason: Option) { + pub async fn recovery(&mut self, paused_reason: Option, err: Option) { let prev_epoch = TracedEpoch::new( self.context .hummock_manager @@ -246,6 +246,15 @@ impl GlobalBarrierManager { let new_state = tokio_retry::Retry::spawn(retry_strategy, || { async { let recovery_result: MetaResult<_> = try { + if let Some(err) = &err { + self.context + .tracker + .lock() + .await + .abort_all(err, &self.context) + .await; + } + self.context .clean_dirty_streaming_jobs() .await diff --git a/src/stream/src/executor/dispatch.rs b/src/stream/src/executor/dispatch.rs index cf788e9ebd7f..9f452dc1863b 100644 --- a/src/stream/src/executor/dispatch.rs +++ b/src/stream/src/executor/dispatch.rs @@ -31,7 +31,9 @@ use tokio::time::Instant; use tracing::{event, Instrument}; use super::exchange::output::{new_output, BoxedOutput}; -use super::{AddMutation, TroublemakerExecutor, UpdateMutation}; +use super::{ + AddMutation, DispatcherBarrier, DispatcherMessage, TroublemakerExecutor, UpdateMutation, +}; use crate::executor::prelude::*; use crate::executor::StreamConsumer; use crate::task::{DispatcherId, SharedContext}; @@ -142,7 +144,9 @@ impl DispatchExecutorInner { .map(Ok) .try_for_each_concurrent(limit, |dispatcher| async { let start_time = Instant::now(); - dispatcher.dispatch_barrier(barrier.clone()).await?; + dispatcher + .dispatch_barrier(barrier.clone().into_dispatcher()) + .await?; dispatcher .actor_output_buffer_blocking_duration_ns .inc_by(start_time.elapsed().as_nanos() as u64); @@ -497,7 +501,7 @@ macro_rules! impl_dispatcher { } } - pub async fn dispatch_barrier(&mut self, barrier: Barrier) -> StreamResult<()> { + pub async fn dispatch_barrier(&mut self, barrier: DispatcherBarrier) -> StreamResult<()> { match self { $( Self::$variant_name(inner) => inner.dispatch_barrier(barrier).await, )* } @@ -561,7 +565,7 @@ pub trait Dispatcher: Debug + 'static { /// Dispatch a data chunk to downstream actors. fn dispatch_data(&mut self, chunk: StreamChunk) -> impl DispatchFuture<'_>; /// Dispatch a barrier to downstream actors, generally by broadcasting it. - fn dispatch_barrier(&mut self, barrier: Barrier) -> impl DispatchFuture<'_>; + fn dispatch_barrier(&mut self, barrier: DispatcherBarrier) -> impl DispatchFuture<'_>; /// Dispatch a watermark to downstream actors, generally by broadcasting it. fn dispatch_watermark(&mut self, watermark: Watermark) -> impl DispatchFuture<'_>; @@ -591,7 +595,7 @@ pub trait Dispatcher: Debug + 'static { /// always unlimited. 
async fn broadcast_concurrent( outputs: impl IntoIterator, - message: Message, + message: DispatcherMessage, ) -> StreamResult<()> { futures::future::try_join_all( outputs @@ -637,21 +641,24 @@ impl Dispatcher for RoundRobinDataDispatcher { chunk.project(&self.output_indices) }; - self.outputs[self.cur].send(Message::Chunk(chunk)).await?; + self.outputs[self.cur] + .send(DispatcherMessage::Chunk(chunk)) + .await?; self.cur += 1; self.cur %= self.outputs.len(); Ok(()) } - async fn dispatch_barrier(&mut self, barrier: Barrier) -> StreamResult<()> { + async fn dispatch_barrier(&mut self, barrier: DispatcherBarrier) -> StreamResult<()> { // always broadcast barrier - broadcast_concurrent(&mut self.outputs, Message::Barrier(barrier)).await + broadcast_concurrent(&mut self.outputs, DispatcherMessage::Barrier(barrier)).await } async fn dispatch_watermark(&mut self, watermark: Watermark) -> StreamResult<()> { if let Some(watermark) = watermark.transform_with_indices(&self.output_indices) { // always broadcast watermark - broadcast_concurrent(&mut self.outputs, Message::Watermark(watermark)).await?; + broadcast_concurrent(&mut self.outputs, DispatcherMessage::Watermark(watermark)) + .await?; } Ok(()) } @@ -725,15 +732,16 @@ impl Dispatcher for HashDataDispatcher { self.outputs.extend(outputs); } - async fn dispatch_barrier(&mut self, barrier: Barrier) -> StreamResult<()> { + async fn dispatch_barrier(&mut self, barrier: DispatcherBarrier) -> StreamResult<()> { // always broadcast barrier - broadcast_concurrent(&mut self.outputs, Message::Barrier(barrier)).await + broadcast_concurrent(&mut self.outputs, DispatcherMessage::Barrier(barrier)).await } async fn dispatch_watermark(&mut self, watermark: Watermark) -> StreamResult<()> { if let Some(watermark) = watermark.transform_with_indices(&self.output_indices) { // always broadcast watermark - broadcast_concurrent(&mut self.outputs, Message::Watermark(watermark)).await?; + broadcast_concurrent(&mut self.outputs, DispatcherMessage::Watermark(watermark)) + .await?; } Ok(()) } @@ -818,7 +826,9 @@ impl Dispatcher for HashDataDispatcher { "send = \n{:#?}", new_stream_chunk ); - output.send(Message::Chunk(new_stream_chunk)).await?; + output + .send(DispatcherMessage::Chunk(new_stream_chunk)) + .await?; } StreamResult::Ok(()) }), @@ -888,18 +898,26 @@ impl Dispatcher for BroadcastDispatcher { } else { chunk.project(&self.output_indices) }; - broadcast_concurrent(self.outputs.values_mut(), Message::Chunk(chunk)).await + broadcast_concurrent(self.outputs.values_mut(), DispatcherMessage::Chunk(chunk)).await } - async fn dispatch_barrier(&mut self, barrier: Barrier) -> StreamResult<()> { + async fn dispatch_barrier(&mut self, barrier: DispatcherBarrier) -> StreamResult<()> { // always broadcast barrier - broadcast_concurrent(self.outputs.values_mut(), Message::Barrier(barrier)).await + broadcast_concurrent( + self.outputs.values_mut(), + DispatcherMessage::Barrier(barrier), + ) + .await } async fn dispatch_watermark(&mut self, watermark: Watermark) -> StreamResult<()> { if let Some(watermark) = watermark.transform_with_indices(&self.output_indices) { // always broadcast watermark - broadcast_concurrent(self.outputs.values_mut(), Message::Watermark(watermark)).await?; + broadcast_concurrent( + self.outputs.values_mut(), + DispatcherMessage::Watermark(watermark), + ) + .await?; } Ok(()) } @@ -970,10 +988,12 @@ impl Dispatcher for SimpleDispatcher { assert!(self.output.len() <= 2); } - async fn dispatch_barrier(&mut self, barrier: Barrier) -> StreamResult<()> { + 
async fn dispatch_barrier(&mut self, barrier: DispatcherBarrier) -> StreamResult<()> { // Only barrier is allowed to be dispatched to multiple outputs during migration. for output in &mut self.output { - output.send(Message::Barrier(barrier.clone())).await?; + output + .send(DispatcherMessage::Barrier(barrier.clone())) + .await?; } Ok(()) } @@ -992,7 +1012,7 @@ impl Dispatcher for SimpleDispatcher { } else { chunk.project(&self.output_indices) }; - output.send(Message::Chunk(chunk)).await + output.send(DispatcherMessage::Chunk(chunk)).await } async fn dispatch_watermark(&mut self, watermark: Watermark) -> StreamResult<()> { @@ -1003,7 +1023,7 @@ impl Dispatcher for SimpleDispatcher { .expect("expect exactly one output"); if let Some(watermark) = watermark.transform_with_indices(&self.output_indices) { - output.send(Message::Watermark(watermark)).await?; + output.send(DispatcherMessage::Watermark(watermark)).await?; } Ok(()) } @@ -1044,23 +1064,25 @@ mod tests { use crate::executor::exchange::output::Output; use crate::executor::exchange::permit::channel_for_test; use crate::executor::receiver::ReceiverExecutor; + use crate::executor::{BarrierInner as Barrier, MessageInner as Message}; + use crate::task::barrier_test_utils::LocalBarrierTestEnv; use crate::task::test_utils::helper_make_local_actor; #[derive(Debug)] pub struct MockOutput { actor_id: ActorId, - data: Arc>>, + data: Arc>>, } impl MockOutput { - pub fn new(actor_id: ActorId, data: Arc>>) -> Self { + pub fn new(actor_id: ActorId, data: Arc>>) -> Self { Self { actor_id, data } } } #[async_trait] impl Output for MockOutput { - async fn send(&mut self, message: Message) -> StreamResult<()> { + async fn send(&mut self, message: DispatcherMessage) -> StreamResult<()> { self.data.lock().unwrap().push(message); Ok(()) } @@ -1154,7 +1176,11 @@ mod tests { let (tx, rx) = channel_for_test(); let actor_id = 233; let fragment_id = 666; - let input = Executor::new(Default::default(), ReceiverExecutor::for_test(rx).boxed()); + let barrier_test_env = LocalBarrierTestEnv::for_test().await; + let input = Executor::new( + Default::default(), + ReceiverExecutor::for_test(233, rx, barrier_test_env.shared_context.clone()).boxed(), + ); let ctx = Arc::new(SharedContext::for_test()); let metrics = Arc::new(StreamingMetrics::unused()); @@ -1245,7 +1271,10 @@ mod tests { actor_new_dispatchers: Default::default(), }, )); - tx.send(Message::Barrier(b1)).await.unwrap(); + barrier_test_env.inject_barrier(&b1, [], [actor_id]); + tx.send(Message::Barrier(b1.clone().into_dispatcher())) + .await + .unwrap(); executor.next().await.unwrap().unwrap(); // 5. Check downstream. @@ -1261,7 +1290,9 @@ mod tests { try_recv!(old_simple).unwrap().as_barrier().unwrap(); // Untouched. // 6. Send another barrier. - tx.send(Message::Barrier(Barrier::new_test_barrier(test_epoch(2)))) + let b2 = Barrier::new_test_barrier(test_epoch(2)); + barrier_test_env.inject_barrier(&b2, [], [actor_id]); + tx.send(Message::Barrier(b2.into_dispatcher())) .await .unwrap(); executor.next().await.unwrap().unwrap(); @@ -1299,7 +1330,10 @@ mod tests { actor_new_dispatchers: Default::default(), }, )); - tx.send(Message::Barrier(b3)).await.unwrap(); + barrier_test_env.inject_barrier(&b3, [], [actor_id]); + tx.send(Message::Barrier(b3.into_dispatcher())) + .await + .unwrap(); executor.next().await.unwrap().unwrap(); // 10. Check downstream. @@ -1309,7 +1343,9 @@ mod tests { try_recv!(new_simple).unwrap().as_barrier().unwrap(); // Since it's just added, it won't receive the chunk. // 11. 
Send another barrier. - tx.send(Message::Barrier(Barrier::new_test_barrier(test_epoch(4)))) + let b4 = Barrier::new_test_barrier(test_epoch(4)); + barrier_test_env.inject_barrier(&b4, [], [actor_id]); + tx.send(Message::Barrier(b4.into_dispatcher())) .await .unwrap(); executor.next().await.unwrap().unwrap(); @@ -1403,7 +1439,7 @@ mod tests { } else { let message = guard.first().unwrap(); let real_chunk = match message { - Message::Chunk(chunk) => chunk, + DispatcherMessage::Chunk(chunk) => chunk, _ => panic!(), }; real_chunk diff --git a/src/stream/src/executor/exchange/input.rs b/src/stream/src/executor/exchange/input.rs index bd348e65defc..be64af33acd7 100644 --- a/src/stream/src/executor/exchange/input.rs +++ b/src/stream/src/executor/exchange/input.rs @@ -22,10 +22,12 @@ use pin_project::pin_project; use risingwave_common::util::addr::{is_local_address, HostAddr}; use risingwave_pb::task_service::{permits, GetStreamResponse}; use risingwave_rpc_client::ComputeClientPool; +use tokio::sync::mpsc; use super::error::ExchangeChannelClosed; use super::permit::Receiver; use crate::executor::prelude::*; +use crate::executor::{DispatcherBarrier, DispatcherMessage}; use crate::task::{ FragmentId, LocalBarrierManager, SharedContext, UpDownActorIds, UpDownFragmentIds, }; @@ -64,23 +66,80 @@ pub struct LocalInput { } type LocalInputStreamInner = impl MessageStream; +async fn process_msg<'a>( + msg: DispatcherMessage, + get_mutation_subscriber: impl for<'b> FnOnce( + &'b DispatcherBarrier, + ) + -> &'a mut mpsc::UnboundedReceiver + + 'a, +) -> StreamExecutorResult { + let barrier = match msg { + DispatcherMessage::Chunk(c) => { + return Ok(Message::Chunk(c)); + } + DispatcherMessage::Barrier(b) => b, + DispatcherMessage::Watermark(watermark) => { + return Ok(Message::Watermark(watermark)); + } + }; + let mutation_subscriber = get_mutation_subscriber(&barrier); + + let mutation = mutation_subscriber + .recv() + .await + .ok_or_else(|| anyhow!("failed to receive mutation of barrier {:?}", barrier)) + .map(|(prev_epoch, mutation)| { + assert_eq!(prev_epoch, barrier.epoch.prev); + mutation + })?; + Ok(Message::Barrier(Barrier { + epoch: barrier.epoch, + mutation, + kind: barrier.kind, + tracing_context: barrier.tracing_context, + passed_actors: barrier.passed_actors, + })) +} + impl LocalInput { - pub fn new(channel: Receiver, actor_id: ActorId) -> Self { + pub fn new( + channel: Receiver, + upstream_actor_id: ActorId, + self_actor_id: ActorId, + local_barrier_manager: LocalBarrierManager, + ) -> Self { Self { - inner: Self::run(channel, actor_id), - actor_id, + inner: Self::run( + channel, + upstream_actor_id, + self_actor_id, + local_barrier_manager, + ), + actor_id: upstream_actor_id, } } #[try_stream(ok = Message, error = StreamExecutorError)] - async fn run(mut channel: Receiver, actor_id: ActorId) { - let span: await_tree::Span = format!("LocalInput (actor {actor_id})").into(); + async fn run( + mut channel: Receiver, + upstream_actor_id: ActorId, + self_actor_id: ActorId, + local_barrier_manager: LocalBarrierManager, + ) { + let span: await_tree::Span = format!("LocalInput (actor {upstream_actor_id})").into(); + let mut mutation_subscriber = None; while let Some(msg) = channel.recv().verbose_instrument_await(span.clone()).await { - yield msg; + yield process_msg(msg, |barrier| { + mutation_subscriber.get_or_insert_with(|| { + local_barrier_manager.subscribe_barrier_mutation(self_actor_id, barrier) + }) + }) + .await?; } // Always emit an error outside the loop. 
This is because we use barrier as the control // message to stop the stream. Reaching here means the channel is closed unexpectedly. - Err(ExchangeChannelClosed::local_input(actor_id))? + Err(ExchangeChannelClosed::local_input(upstream_actor_id))? } } @@ -170,11 +229,11 @@ impl RemoteInput { match data_res { Ok(GetStreamResponse { message, permits }) => { let msg = message.unwrap(); - let bytes = Message::get_encoded_len(&msg); + let bytes = DispatcherMessage::get_encoded_len(&msg); exchange_frag_recv_size_metrics.inc_by(bytes as u64); - let msg_res = Message::from_protobuf(&msg); + let msg_res = DispatcherMessage::from_protobuf(&msg); if let Some(add_back_permits) = match permits.unwrap().value { // For records, batch the permits we received to reduce the backward // `AddPermits` messages. @@ -196,35 +255,14 @@ impl RemoteInput { .context("RemoteInput backward permits channel closed.")?; } - let mut msg = msg_res.context("RemoteInput decode message error")?; - - // Read barrier mutation from local barrier manager and attach it to the barrier message. - if cfg!(not(test)) { - if let Message::Barrier(barrier) = &mut msg { - assert!( - barrier.mutation.is_none(), - "Mutation should be erased in remote side" - ); - let mutation_subscriber = - mutation_subscriber.get_or_insert_with(|| { - local_barrier_manager - .subscribe_barrier_mutation(self_actor_id, barrier) - }); - - let mutation = mutation_subscriber - .recv() - .await - .ok_or_else(|| { - anyhow!("failed to receive mutation of barrier {:?}", barrier) - }) - .map(|(prev_epoch, mutation)| { - assert_eq!(prev_epoch, barrier.epoch.prev); - mutation - })?; - barrier.mutation = mutation; - } - } - yield msg; + let msg = msg_res.context("RemoteInput decode message error")?; + + yield process_msg(msg, |barrier| { + mutation_subscriber.get_or_insert_with(|| { + local_barrier_manager.subscribe_barrier_mutation(self_actor_id, barrier) + }) + }) + .await?; } Err(e) => Err(ExchangeChannelClosed::remote_input(up_down_ids.0, Some(e)))?, @@ -270,6 +308,8 @@ pub(crate) fn new_input( LocalInput::new( context.take_receiver((upstream_actor_id, actor_id))?, upstream_actor_id, + actor_id, + context.local_barrier_manager.clone(), ) .boxed_input() } else { diff --git a/src/stream/src/executor/exchange/output.rs b/src/stream/src/executor/exchange/output.rs index 41b4b5b84475..145286f561e1 100644 --- a/src/stream/src/executor/exchange/output.rs +++ b/src/stream/src/executor/exchange/output.rs @@ -22,7 +22,7 @@ use risingwave_common::util::addr::is_local_address; use super::error::ExchangeChannelClosed; use super::permit::Sender; use crate::error::StreamResult; -use crate::executor::Message; +use crate::executor::DispatcherMessage as Message; use crate::task::{ActorId, SharedContext}; /// `Output` provides an interface for `Dispatcher` to send data into downstream actors. diff --git a/src/stream/src/executor/exchange/permit.rs b/src/stream/src/executor/exchange/permit.rs index 159494355cff..8c86eb275381 100644 --- a/src/stream/src/executor/exchange/permit.rs +++ b/src/stream/src/executor/exchange/permit.rs @@ -19,7 +19,7 @@ use std::sync::Arc; use risingwave_pb::task_service::permits; use tokio::sync::{mpsc, AcquireError, Semaphore, SemaphorePermit}; -use crate::executor::Message; +use crate::executor::DispatcherMessage as Message; /// Message with its required permits. 
/// @@ -214,7 +214,7 @@ mod tests { use futures::FutureExt; use super::*; - use crate::executor::Barrier; + use crate::executor::DispatcherBarrier as Barrier; #[test] fn test_channel_close() { diff --git a/src/stream/src/executor/integration_tests.rs b/src/stream/src/executor/integration_tests.rs index 8f73fe26ee7d..0b7415adac38 100644 --- a/src/stream/src/executor/integration_tests.rs +++ b/src/stream/src/executor/integration_tests.rs @@ -34,7 +34,8 @@ use crate::executor::monitor::StreamingMetrics; use crate::executor::test_utils::agg_executor::{ generate_agg_schema, new_boxed_simple_agg_executor, }; -use crate::task::{LocalBarrierManager, SharedContext}; +use crate::executor::{BarrierInner as Barrier, MessageInner as Message}; +use crate::task::barrier_test_utils::LocalBarrierTestEnv; /// This test creates a merger-dispatcher pair, and run a sum. Each chunk /// has 0~9 elements. We first insert the 10 chunks, then delete them, @@ -45,9 +46,19 @@ async fn test_merger_sum_aggr() { time_zone: String::from("UTC"), }; - let actor_ctx = ActorContext::for_test(0); + let barrier_test_env = LocalBarrierTestEnv::for_test().await; + let mut next_actor_id = 0; + let next_actor_id = &mut next_actor_id; + let mut actors = HashSet::new(); + let mut gen_next_actor_id = || { + *next_actor_id += 1; + actors.insert(*next_actor_id); + *next_actor_id + }; // `make_actor` build an actor to do local aggregation - let make_actor = |input_rx| { + let mut make_actor = |input_rx| { + let actor_id = gen_next_actor_id(); + let actor_ctx = ActorContext::for_test(actor_id); let input_schema = Schema { fields: vec![Field::unnamed(DataType::Int64)], }; @@ -57,7 +68,8 @@ async fn test_merger_sum_aggr() { pk_indices: PkIndices::new(), identity: "ReceiverExecutor".to_string(), }, - ReceiverExecutor::for_test(input_rx).boxed(), + ReceiverExecutor::for_test(actor_id, input_rx, barrier_test_env.shared_context.clone()) + .boxed(), ); let agg_calls = vec![ AggCall::from_pretty("(count:int8)"), @@ -72,13 +84,17 @@ async fn test_merger_sum_aggr() { input: aggregator.boxed(), channel: Box::new(LocalOutput::new(233, tx)), }; + let actor = Actor::new( consumer, vec![], StreamingMetrics::unused().into(), - actor_ctx.clone(), + actor_ctx, expr_context.clone(), - LocalBarrierManager::for_test(), + barrier_test_env + .shared_context + .local_barrier_manager + .clone(), ); (actor, rx) }; @@ -90,7 +106,6 @@ async fn test_merger_sum_aggr() { let mut inputs = vec![]; let mut outputs = vec![]; - let ctx = Arc::new(SharedContext::for_test()); let metrics = Arc::new(StreamingMetrics::unused()); // create 17 local aggregation actors @@ -103,6 +118,8 @@ async fn test_merger_sum_aggr() { } // create a round robin dispatcher, which dispatches messages to the actors + + let actor_id = gen_next_actor_id(); let (input, rx) = channel_for_test(); let receiver_op = Executor::new( ExecutorInfo { @@ -111,7 +128,7 @@ async fn test_merger_sum_aggr() { pk_indices: PkIndices::new(), identity: "ReceiverExecutor".to_string(), }, - ReceiverExecutor::for_test(rx).boxed(), + ReceiverExecutor::for_test(actor_id, rx, barrier_test_env.shared_context.clone()).boxed(), ); let dispatcher = DispatchExecutor::new( receiver_op, @@ -122,7 +139,7 @@ async fn test_merger_sum_aggr() { ))], 0, 0, - ctx, + barrier_test_env.shared_context.clone(), metrics, config::default::developer::stream_chunk_size(), ); @@ -130,12 +147,17 @@ async fn test_merger_sum_aggr() { dispatcher, vec![], StreamingMetrics::unused().into(), - actor_ctx.clone(), + ActorContext::for_test(actor_id), 
expr_context.clone(), - LocalBarrierManager::for_test(), + barrier_test_env + .shared_context + .local_barrier_manager + .clone(), ); handles.push(tokio::spawn(actor.run())); + let actor_ctx = ActorContext::for_test(gen_next_actor_id()); + // use a merge operator to collect data from dispatchers before sending them to aggregator let merger = Executor::new( ExecutorInfo { @@ -147,7 +169,12 @@ async fn test_merger_sum_aggr() { pk_indices: PkIndices::new(), identity: "MergeExecutor".to_string(), }, - MergeExecutor::for_test(outputs).boxed(), + MergeExecutor::for_test( + actor_ctx.id, + outputs, + barrier_test_env.shared_context.clone(), + ) + .boxed(), ); // for global aggregator, we need to sum data and sum row count @@ -192,13 +219,18 @@ async fn test_merger_sum_aggr() { StreamingMetrics::unused().into(), actor_ctx.clone(), expr_context.clone(), - LocalBarrierManager::for_test(), + barrier_test_env + .shared_context + .local_barrier_manager + .clone(), ); handles.push(tokio::spawn(actor.run())); let mut epoch = test_epoch(1); + let b1 = Barrier::new_test_barrier(epoch); + barrier_test_env.inject_barrier(&b1, [], actors.clone()); input - .send(Message::Barrier(Barrier::new_test_barrier(epoch))) + .send(Message::Barrier(b1.into_dispatcher())) .await .unwrap(); epoch.inc_epoch(); @@ -211,17 +243,19 @@ async fn test_merger_sum_aggr() { ); input.send(Message::Chunk(chunk)).await.unwrap(); } + let b = Barrier::new_test_barrier(epoch); + barrier_test_env.inject_barrier(&b, [], actors.clone()); input - .send(Message::Barrier(Barrier::new_test_barrier(epoch))) + .send(Message::Barrier(b.into_dispatcher())) .await .unwrap(); epoch.inc_epoch(); } + let b = Barrier::new_test_barrier(epoch) + .with_mutation(Mutation::Stop(actors.clone().into_iter().collect())); + barrier_test_env.inject_barrier(&b, [], actors); input - .send(Message::Barrier( - Barrier::new_test_barrier(epoch) - .with_mutation(Mutation::Stop([0].into_iter().collect())), - )) + .send(Message::Barrier(b.into_dispatcher())) .await .unwrap(); @@ -241,7 +275,7 @@ struct MockConsumer { } impl StreamConsumer for MockConsumer { - type BarrierStream = impl Stream> + Send; + type BarrierStream = impl Stream> + Send; fn execute(self: Box) -> Self::BarrierStream { let mut input = self.input.execute(); @@ -268,7 +302,7 @@ pub struct SenderConsumer { } impl StreamConsumer for SenderConsumer { - type BarrierStream = impl Stream> + Send; + type BarrierStream = impl Stream> + Send; fn execute(self: Box) -> Self::BarrierStream { let mut input = self.input.execute(); @@ -279,7 +313,16 @@ impl StreamConsumer for SenderConsumer { let msg = item?; let barrier = msg.as_barrier().cloned(); - channel.send(msg).await.expect("failed to send message"); + channel + .send(match msg { + Message::Chunk(chunk) => DispatcherMessage::Chunk(chunk), + Message::Barrier(barrier) => { + DispatcherMessage::Barrier(barrier.into_dispatcher()) + } + Message::Watermark(watermark) => DispatcherMessage::Watermark(watermark), + }) + .await + .expect("failed to send message"); if let Some(barrier) = barrier { yield barrier; diff --git a/src/stream/src/executor/merge.rs b/src/stream/src/executor/merge.rs index 19124fe8c22d..60f88e866c6b 100644 --- a/src/stream/src/executor/merge.rs +++ b/src/stream/src/executor/merge.rs @@ -74,20 +74,32 @@ impl MergeExecutor { } #[cfg(test)] - pub fn for_test(inputs: Vec) -> Self { + pub fn for_test( + actor_id: ActorId, + inputs: Vec, + shared_context: Arc, + ) -> Self { use super::exchange::input::LocalInput; use 
crate::executor::exchange::input::Input; Self::new( - ActorContext::for_test(114), + ActorContext::for_test(actor_id), 514, 1919, inputs .into_iter() .enumerate() - .map(|(idx, input)| LocalInput::new(input, idx as ActorId).boxed_input()) + .map(|(idx, input)| { + LocalInput::new( + input, + idx as ActorId, + actor_id, + shared_context.local_barrier_manager.clone(), + ) + .boxed_input() + }) .collect(), - SharedContext::for_test().into(), + shared_context, 810, StreamingMetrics::unused().into(), ) @@ -474,10 +486,11 @@ mod tests { use tonic::{Request, Response, Status, Streaming}; use super::*; - use crate::executor::exchange::input::RemoteInput; + use crate::executor::exchange::input::{Input, RemoteInput}; use crate::executor::exchange::permit::channel_for_test; + use crate::executor::{BarrierInner as Barrier, MessageInner as Message}; + use crate::task::barrier_test_utils::LocalBarrierTestEnv; use crate::task::test_utils::helper_make_local_actor; - use crate::task::LocalBarrierManager; fn build_test_chunk(epoch: u64) -> StreamChunk { // The number of items in `ops` is the epoch count. @@ -495,64 +508,80 @@ mod tests { txs.push(tx); rxs.push(rx); } - let merger = MergeExecutor::for_test(rxs); + let barrier_test_env = LocalBarrierTestEnv::for_test().await; + let merger = MergeExecutor::for_test(233, rxs, barrier_test_env.shared_context.clone()); + let actor_id = merger.actor_context.id; let mut handles = Vec::with_capacity(CHANNEL_NUMBER); - let epochs = (10..1000u64).step_by(10).collect_vec(); + let epochs = (10..1000u64) + .step_by(10) + .map(|idx| (idx, test_epoch(idx))) + .collect_vec(); + let mut prev_epoch = 0; + let prev_epoch = &mut prev_epoch; + let barriers: HashMap<_, _> = epochs + .iter() + .map(|(_, epoch)| { + let barrier = Barrier::with_prev_epoch_for_test(*epoch, *prev_epoch); + *prev_epoch = *epoch; + barrier_test_env.inject_barrier(&barrier, [], [actor_id]); + (*epoch, barrier) + }) + .collect(); + let b2 = Barrier::with_prev_epoch_for_test(test_epoch(1000), *prev_epoch) + .with_mutation(Mutation::Stop(HashSet::default())); + barrier_test_env.inject_barrier(&b2, [], [actor_id]); for (tx_id, tx) in txs.into_iter().enumerate() { let epochs = epochs.clone(); + let barriers = barriers.clone(); + let b2 = b2.clone(); let handle = tokio::spawn(async move { - for epoch in epochs { - if epoch % 20 == 0 { - tx.send(Message::Chunk(build_test_chunk(epoch))) + for (idx, epoch) in epochs { + if idx % 20 == 0 { + tx.send(Message::Chunk(build_test_chunk(idx))) .await .unwrap(); } else { tx.send(Message::Watermark(Watermark { - col_idx: (epoch as usize / 20 + tx_id) % CHANNEL_NUMBER, + col_idx: (idx as usize / 20 + tx_id) % CHANNEL_NUMBER, data_type: DataType::Int64, - val: ScalarImpl::Int64(epoch as i64), + val: ScalarImpl::Int64(idx as i64), })) .await .unwrap(); } - tx.send(Message::Barrier(Barrier::new_test_barrier(test_epoch( - epoch, - )))) - .await - .unwrap(); + tx.send(Message::Barrier(barriers[&epoch].clone().into_dispatcher())) + .await + .unwrap(); sleep(Duration::from_millis(1)).await; } - tx.send(Message::Barrier( - Barrier::new_test_barrier(test_epoch(1000)) - .with_mutation(Mutation::Stop(HashSet::default())), - )) - .await - .unwrap(); + tx.send(Message::Barrier(b2.clone().into_dispatcher())) + .await + .unwrap(); }); handles.push(handle); } let mut merger = merger.boxed().execute(); - for epoch in epochs { + for (idx, epoch) in epochs { // expect n chunks - if epoch % 20 == 0 { + if idx % 20 == 0 { for _ in 0..CHANNEL_NUMBER { 
assert_matches!(merger.next().await.unwrap().unwrap(), Message::Chunk(chunk) => { - assert_eq!(chunk.ops().len() as u64, epoch); + assert_eq!(chunk.ops().len() as u64, idx); }); } - } else if epoch as usize / 20 >= CHANNEL_NUMBER - 1 { + } else if idx as usize / 20 >= CHANNEL_NUMBER - 1 { for _ in 0..CHANNEL_NUMBER { assert_matches!(merger.next().await.unwrap().unwrap(), Message::Watermark(watermark) => { - assert_eq!(watermark.val, ScalarImpl::Int64((epoch - 20 * (CHANNEL_NUMBER as u64 - 1)) as i64)); + assert_eq!(watermark.val, ScalarImpl::Int64((idx - 20 * (CHANNEL_NUMBER as u64 - 1)) as i64)); }); } } // expect a barrier assert_matches!(merger.next().await.unwrap().unwrap(), Message::Barrier(Barrier{epoch:barrier_epoch,mutation:_,..}) => { - assert_eq!(barrier_epoch.curr, test_epoch(epoch)); + assert_eq!(barrier_epoch.curr, epoch); }); } assert_matches!( @@ -572,7 +601,8 @@ mod tests { async fn test_configuration_change() { let actor_id = 233; let (untouched, old, new) = (234, 235, 238); // upstream actors - let ctx = Arc::new(SharedContext::for_test()); + let barrier_test_env = LocalBarrierTestEnv::for_test().await; + let ctx = barrier_test_env.shared_context.clone(); let metrics = Arc::new(StreamingMetrics::unused()); // 1. Register info in context. @@ -628,9 +658,21 @@ mod tests { } }; } + + macro_rules! assert_recv_pending { + () => { + assert!(merge + .next() + .now_or_never() + .flatten() + .transpose() + .unwrap() + .is_none()); + }; + } macro_rules! recv { () => { - merge.next().now_or_never().flatten().transpose().unwrap() + merge.next().await.transpose().unwrap() }; } @@ -638,7 +680,7 @@ mod tests { send!([untouched, old], Message::Chunk(StreamChunk::default())); recv!().unwrap().as_chunk().unwrap(); // We should be able to receive the chunk twice. recv!().unwrap().as_chunk().unwrap(); - assert!(recv!().is_none()); + assert_recv_pending!(); // 4. Send a configuration change barrier. let merge_updates = maplit::hashmap! { @@ -661,23 +703,31 @@ mod tests { actor_new_dispatchers: Default::default(), }, )); - send!([untouched, old], Message::Barrier(b1.clone())); - assert!(recv!().is_none()); // We should not receive the barrier, since merger is waiting for the new upstream new. + barrier_test_env.inject_barrier(&b1, [], [actor_id]); + send!( + [untouched, old], + Message::Barrier(b1.clone().into_dispatcher()) + ); + assert_recv_pending!(); // We should not receive the barrier, since merger is waiting for the new upstream new. - send!([new], Message::Barrier(b1.clone())); + send!([new], Message::Barrier(b1.clone().into_dispatcher())); recv!().unwrap().as_barrier().unwrap(); // We should now receive the barrier. // 5. Send a chunk. send!([untouched, new], Message::Chunk(StreamChunk::default())); recv!().unwrap().as_chunk().unwrap(); // We should be able to receive the chunk twice, since old is removed. 
recv!().unwrap().as_chunk().unwrap(); - assert!(recv!().is_none()); + assert_recv_pending!(); } struct FakeExchangeService { rpc_called: Arc, } + fn exchange_client_test_barrier() -> crate::executor::Barrier { + Barrier::new_test_barrier(test_epoch(1)) + } + #[async_trait::async_trait] impl ExchangeService for FakeExchangeService { type GetDataStream = ReceiverStream>; @@ -711,7 +761,7 @@ mod tests { .await .unwrap(); // send barrier - let barrier = Barrier::new_test_barrier(test_epoch(1)); + let barrier = exchange_client_test_barrier(); tx.send(Ok(GetStreamResponse { message: Some(StreamMessage { stream_message: Some( @@ -755,10 +805,12 @@ mod tests { sleep(Duration::from_secs(1)).await; assert!(server_run.load(Ordering::SeqCst)); + let test_env = LocalBarrierTestEnv::for_test().await; + let remote_input = { let pool = ComputeClientPool::default(); RemoteInput::new( - LocalBarrierManager::for_test(), + test_env.shared_context.local_barrier_manager.clone(), pool, addr.into(), (0, 0), @@ -768,6 +820,12 @@ mod tests { ) }; + test_env.inject_barrier( + &exchange_client_test_barrier(), + [], + [remote_input.actor_id()], + ); + pin_mut!(remote_input); assert_matches!(remote_input.next().await.unwrap().unwrap(), Message::Chunk(chunk) => { diff --git a/src/stream/src/executor/mod.rs b/src/stream/src/executor/mod.rs index dc63e62b1d58..a1ef0691d14e 100644 --- a/src/stream/src/executor/mod.rs +++ b/src/stream/src/executor/mod.rs @@ -41,9 +41,9 @@ use risingwave_pb::stream_plan::stream_message::StreamMessage; use risingwave_pb::stream_plan::update_mutation::{DispatcherUpdate, MergeUpdate}; use risingwave_pb::stream_plan::{ BarrierMutation, CombinedMutation, CreateSubscriptionMutation, Dispatchers, - DropSubscriptionMutation, PauseMutation, PbAddMutation, PbBarrier, PbDispatcher, - PbStreamMessage, PbUpdateMutation, PbWatermark, ResumeMutation, SourceChangeSplitMutation, - StopMutation, ThrottleMutation, + DropSubscriptionMutation, PauseMutation, PbAddMutation, PbBarrier, PbBarrierMutation, + PbDispatcher, PbStreamMessage, PbUpdateMutation, PbWatermark, ResumeMutation, + SourceChangeSplitMutation, StopMutation, ThrottleMutation, }; use smallvec::SmallVec; @@ -297,20 +297,28 @@ pub enum Mutation { }, } +/// The generic type `M` is the mutation type of the barrier. +/// +/// For barrier of in the dispatcher, `M` is `()`, which means the mutation is erased. +/// For barrier flowing within the streaming actor, `M` is the normal `BarrierMutationType`. #[derive(Debug, Clone)] -pub struct Barrier { +pub struct BarrierInner { pub epoch: EpochPair, - pub mutation: Option>, + pub mutation: M, pub kind: BarrierKind, /// Tracing context for the **current** epoch of this barrier. - tracing_context: TracingContext, + pub tracing_context: TracingContext, /// The actors that this barrier has passed locally. Used for debugging only. pub passed_actors: Vec, } -impl Barrier { +pub type BarrierMutationType = Option>; +pub type Barrier = BarrierInner; +pub type DispatcherBarrier = BarrierInner<()>; + +impl BarrierInner { /// Create a plain barrier. 
pub fn new_test_barrier(epoch: u64) -> Self { Self { @@ -331,6 +339,18 @@ impl Barrier { passed_actors: Default::default(), } } +} + +impl Barrier { + pub fn into_dispatcher(self) -> DispatcherBarrier { + DispatcherBarrier { + epoch: self.epoch, + mutation: (), + kind: self.kind, + tracing_context: self.tracing_context, + passed_actors: self.passed_actors, + } + } #[must_use] pub fn with_mutation(self, mutation: Mutation) -> Self { @@ -493,7 +513,7 @@ impl Barrier { } } -impl PartialEq for Barrier { +impl PartialEq for BarrierInner { fn eq(&self, other: &Self) -> bool { self.epoch == other.epoch && self.mutation == other.mutation } @@ -751,50 +771,72 @@ impl Mutation { } } -impl Barrier { - pub fn to_protobuf(&self) -> PbBarrier { - let Barrier { +impl BarrierInner { + fn to_protobuf_inner(&self, barrier_fn: impl FnOnce(&M) -> Option) -> PbBarrier { + let Self { epoch, mutation, kind, passed_actors, tracing_context, .. - } = self.clone(); + } = self; PbBarrier { epoch: Some(PbEpoch { curr: epoch.curr, prev: epoch.prev, }), - mutation: mutation.map(|m| BarrierMutation { - mutation: Some(m.to_protobuf()), + mutation: Some(PbBarrierMutation { + mutation: barrier_fn(mutation), }), tracing_context: tracing_context.to_protobuf(), - kind: kind as _, - passed_actors, + kind: *kind as _, + passed_actors: passed_actors.clone(), } } - pub fn from_protobuf(prost: &PbBarrier) -> StreamExecutorResult { - let mutation = prost - .mutation - .as_ref() - .map(|m| Mutation::from_protobuf(m.mutation.as_ref().unwrap())) - .transpose()? - .map(Arc::new); + fn from_protobuf_inner( + prost: &PbBarrier, + mutation_from_pb: impl FnOnce(Option<&PbMutation>) -> StreamExecutorResult, + ) -> StreamExecutorResult { let epoch = prost.get_epoch()?; - Ok(Barrier { + Ok(Self { kind: prost.kind(), epoch: EpochPair::new(epoch.curr, epoch.prev), - mutation, + mutation: mutation_from_pb( + prost + .mutation + .as_ref() + .and_then(|mutation| mutation.mutation.as_ref()), + )?, passed_actors: prost.get_passed_actors().clone(), tracing_context: TracingContext::from_protobuf(&prost.tracing_context), }) } } +impl DispatcherBarrier { + pub fn to_protobuf(&self) -> PbBarrier { + self.to_protobuf_inner(|_| None) + } +} + +impl Barrier { + pub fn to_protobuf(&self) -> PbBarrier { + self.to_protobuf_inner(|mutation| mutation.as_ref().map(|mutation| mutation.to_protobuf())) + } + + pub fn from_protobuf(prost: &PbBarrier) -> StreamExecutorResult { + Self::from_protobuf_inner(prost, |mutation| { + mutation + .map(|m| Mutation::from_protobuf(m).map(Arc::new)) + .transpose() + }) + } +} + #[derive(Debug, PartialEq, Eq, Clone)] pub struct Watermark { pub col_idx: usize, @@ -871,12 +913,15 @@ impl Watermark { } #[derive(Debug, EnumAsInner, PartialEq, Clone)] -pub enum Message { +pub enum MessageInner { Chunk(StreamChunk), - Barrier(Barrier), + Barrier(BarrierInner), Watermark(Watermark), } +pub type Message = MessageInner; +pub type DispatcherMessage = MessageInner<()>; + impl From for Message { fn from(chunk: StreamChunk) -> Self { Message::Chunk(chunk) @@ -910,7 +955,9 @@ impl Message { }) if mutation.as_ref().unwrap().is_stop() ) } +} +impl DispatcherMessage { pub fn to_protobuf(&self) -> PbStreamMessage { let prost = match self { Self::Chunk(stream_chunk) => { @@ -927,10 +974,21 @@ impl Message { pub fn from_protobuf(prost: &PbStreamMessage) -> StreamExecutorResult { let res = match prost.get_stream_message()? 
{ - StreamMessage::StreamChunk(chunk) => Message::Chunk(StreamChunk::from_protobuf(chunk)?), - StreamMessage::Barrier(barrier) => Message::Barrier(Barrier::from_protobuf(barrier)?), + StreamMessage::StreamChunk(chunk) => Self::Chunk(StreamChunk::from_protobuf(chunk)?), + StreamMessage::Barrier(barrier) => Self::Barrier( + DispatcherBarrier::from_protobuf_inner(barrier, |mutation| { + if mutation.is_some() { + if cfg!(debug_assertions) { + panic!("should not receive message of barrier with mutation"); + } else { + warn!(?barrier, "receive message of barrier with mutation"); + } + } + Ok(()) + })?, + ), StreamMessage::Watermark(watermark) => { - Message::Watermark(Watermark::from_protobuf(watermark)?) + Self::Watermark(Watermark::from_protobuf(watermark)?) } }; Ok(res) diff --git a/src/stream/src/executor/receiver.rs b/src/stream/src/executor/receiver.rs index effe773d5459..58700d2a1350 100644 --- a/src/stream/src/executor/receiver.rs +++ b/src/stream/src/executor/receiver.rs @@ -72,16 +72,26 @@ impl ReceiverExecutor { } #[cfg(test)] - pub fn for_test(input: super::exchange::permit::Receiver) -> Self { + pub fn for_test( + actor_id: ActorId, + input: super::exchange::permit::Receiver, + shared_context: Arc, + ) -> Self { use super::exchange::input::LocalInput; use crate::executor::exchange::input::Input; Self::new( - ActorContext::for_test(114), + ActorContext::for_test(actor_id), 514, 1919, - LocalInput::new(input, 0).boxed_input(), - SharedContext::for_test().into(), + LocalInput::new( + input, + 0, + actor_id, + shared_context.local_barrier_manager.clone(), + ) + .boxed_input(), + shared_context, 810, StreamingMetrics::unused().into(), ) @@ -194,7 +204,8 @@ mod tests { use risingwave_pb::stream_plan::update_mutation::MergeUpdate; use super::*; - use crate::executor::UpdateMutation; + use crate::executor::{MessageInner as Message, UpdateMutation}; + use crate::task::barrier_test_utils::LocalBarrierTestEnv; use crate::task::test_utils::helper_make_local_actor; #[tokio::test] @@ -202,7 +213,9 @@ mod tests { let actor_id = 233; let (old, new) = (114, 514); // old and new upstream actor id - let ctx = Arc::new(SharedContext::for_test()); + let barrier_test_env = LocalBarrierTestEnv::for_test().await; + + let ctx = barrier_test_env.shared_context.clone(); let metrics = Arc::new(StreamingMetrics::unused()); // 1. Register info in context. @@ -261,21 +274,28 @@ mod tests { } }; } - macro_rules! recv { + macro_rules! assert_recv_pending { () => { - receiver + assert!(receiver .next() .now_or_never() .flatten() .transpose() .unwrap() + .is_none()); + }; + } + + macro_rules! recv { + () => { + receiver.next().await.transpose().unwrap() }; } // 3. Send a chunk. send!([old], Message::Chunk(StreamChunk::default())); recv!().unwrap().as_chunk().unwrap(); // We should be able to receive the chunk. - assert!(recv!().is_none()); + assert_recv_pending!(); // 4. Send a configuration change barrier. let merge_updates = maplit::hashmap! { @@ -298,19 +318,22 @@ mod tests { actor_new_dispatchers: Default::default(), }, )); - send!([new], Message::Barrier(b1.clone())); - assert!(recv!().is_none()); // We should not receive the barrier, as new is not the upstream. - send!([old], Message::Barrier(b1.clone())); + barrier_test_env.inject_barrier(&b1, [], [actor_id]); + + send!([new], Message::Barrier(b1.clone().into_dispatcher())); + assert_recv_pending!(); // We should not receive the barrier, as new is not the upstream. 
+ + send!([old], Message::Barrier(b1.clone().into_dispatcher())); recv!().unwrap().as_barrier().unwrap(); // We should now receive the barrier. // 5. Send a chunk to the removed upstream. send_error!([old], Message::Chunk(StreamChunk::default())); - assert!(recv!().is_none()); + assert_recv_pending!(); // 6. Send a chunk to the added upstream. send!([new], Message::Chunk(StreamChunk::default())); recv!().unwrap().as_chunk().unwrap(); // We should be able to receive the chunk. - assert!(recv!().is_none()); + assert_recv_pending!(); } } diff --git a/src/stream/src/task/barrier_manager.rs b/src/stream/src/task/barrier_manager.rs index b0ce6ad30540..1fd932d96675 100644 --- a/src/stream/src/task/barrier_manager.rs +++ b/src/stream/src/task/barrier_manager.rs @@ -63,7 +63,10 @@ use risingwave_pb::stream_service::{ use crate::executor::exchange::permit::Receiver; use crate::executor::monitor::StreamingMetrics; -use crate::executor::{Actor, Barrier, DispatchExecutor, Mutation, StreamExecutorError}; +use crate::executor::{ + Actor, Barrier, BarrierInner, DispatchExecutor, DispatcherBarrier, Mutation, + StreamExecutorError, +}; use crate::task::barrier_manager::managed_state::ManagedBarrierStateDebugInfo; use crate::task::barrier_manager::progress::BackfillState; @@ -182,12 +185,12 @@ impl CreateActorContext { } } -pub(super) type SubscribeMutationItem = (u64, Option>); +pub(crate) type SubscribeMutationItem = (u64, Option>); pub(super) enum LocalBarrierEvent { ReportActorCollected { actor_id: ActorId, - barrier: Barrier, + epoch: EpochPair, }, ReportCreateProgress { current_epoch: u64, @@ -508,8 +511,8 @@ impl LocalBarrierWorker { fn handle_barrier_event(&mut self, event: LocalBarrierEvent) { match event { - LocalBarrierEvent::ReportActorCollected { actor_id, barrier } => { - self.collect(actor_id, &barrier) + LocalBarrierEvent::ReportActorCollected { actor_id, epoch } => { + self.collect(actor_id, epoch) } LocalBarrierEvent::ReportCreateProgress { current_epoch, @@ -764,8 +767,8 @@ impl LocalBarrierWorker { /// When a [`crate::executor::StreamConsumer`] (typically [`crate::executor::DispatchExecutor`]) get a barrier, it should report /// and collect this barrier with its own `actor_id` using this function. - fn collect(&mut self, actor_id: ActorId, barrier: &Barrier) { - self.state.collect(actor_id, barrier) + fn collect(&mut self, actor_id: ActorId, epoch: EpochPair) { + self.state.collect(actor_id, epoch) } /// When a actor exit unexpectedly, the error is reported using this function. The control stream @@ -904,10 +907,10 @@ impl LocalBarrierManager { /// When a [`crate::executor::StreamConsumer`] (typically [`crate::executor::DispatchExecutor`]) get a barrier, it should report /// and collect this barrier with its own `actor_id` using this function. 
- pub fn collect(&self, actor_id: ActorId, barrier: &Barrier) { + pub fn collect(&self, actor_id: ActorId, barrier: &BarrierInner) { self.send_event(LocalBarrierEvent::ReportActorCollected { actor_id, - barrier: barrier.clone(), + epoch: barrier.epoch, }) } @@ -923,7 +926,7 @@ impl LocalBarrierManager { pub fn subscribe_barrier_mutation( &self, actor_id: ActorId, - first_barrier: &Barrier, + first_barrier: &DispatcherBarrier, ) -> mpsc::UnboundedReceiver { let (tx, rx) = mpsc::unbounded_channel(); self.send_event(LocalBarrierEvent::SubscribeBarrierMutation { @@ -996,7 +999,7 @@ pub fn try_find_root_actor_failure<'a>( #[cfg(test)] impl LocalBarrierManager { - pub(super) fn spawn_for_test() -> EventSender { + fn spawn_for_test() -> EventSender { use std::sync::atomic::AtomicU64; let (tx, rx) = unbounded_channel(); let _join_handle = LocalBarrierWorker::spawn( @@ -1028,3 +1031,85 @@ impl LocalBarrierManager { rx.await.unwrap() } } + +#[cfg(test)] +pub(crate) mod barrier_test_utils { + use std::sync::Arc; + + use assert_matches::assert_matches; + use futures::StreamExt; + use risingwave_pb::stream_service::streaming_control_stream_request::InitRequest; + use risingwave_pb::stream_service::{ + streaming_control_stream_request, streaming_control_stream_response, InjectBarrierRequest, + StreamingControlStreamRequest, StreamingControlStreamResponse, + }; + use tokio::sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender}; + use tokio_stream::wrappers::UnboundedReceiverStream; + use tonic::Status; + + use crate::executor::Barrier; + use crate::task::barrier_manager::{ControlStreamHandle, EventSender, LocalActorOperation}; + use crate::task::{ActorId, LocalBarrierManager, SharedContext}; + + pub(crate) struct LocalBarrierTestEnv { + pub shared_context: Arc, + pub(super) actor_op_tx: EventSender, + pub request_tx: UnboundedSender>, + pub response_rx: UnboundedReceiver>, + } + + impl LocalBarrierTestEnv { + pub(crate) async fn for_test() -> Self { + let actor_op_tx = LocalBarrierManager::spawn_for_test(); + + let (request_tx, request_rx) = unbounded_channel(); + let (response_tx, mut response_rx) = unbounded_channel(); + + actor_op_tx.send_event(LocalActorOperation::NewControlStream { + handle: ControlStreamHandle::new( + response_tx, + UnboundedReceiverStream::new(request_rx).boxed(), + ), + init_request: InitRequest { prev_epoch: 0 }, + }); + + assert_matches!( + response_rx.recv().await.unwrap().unwrap().response.unwrap(), + streaming_control_stream_response::Response::Init(_) + ); + + let shared_context = actor_op_tx + .send_and_await(LocalActorOperation::GetCurrentSharedContext) + .await + .unwrap(); + + Self { + shared_context, + actor_op_tx, + request_tx, + response_rx, + } + } + + pub(crate) fn inject_barrier( + &self, + barrier: &Barrier, + actor_to_send: impl IntoIterator, + actor_to_collect: impl IntoIterator, + ) { + self.request_tx + .send(Ok(StreamingControlStreamRequest { + request: Some(streaming_control_stream_request::Request::InjectBarrier( + InjectBarrierRequest { + request_id: "".to_string(), + barrier: Some(barrier.to_protobuf()), + actor_ids_to_send: actor_to_send.into_iter().collect(), + actor_ids_to_collect: actor_to_collect.into_iter().collect(), + table_ids_to_sync: vec![], + }, + )), + })) + .unwrap(); + } + } +} diff --git a/src/stream/src/task/barrier_manager/managed_state.rs b/src/stream/src/task/barrier_manager/managed_state.rs index f4a3fb31c03c..ae1a576fe7c4 100644 --- a/src/stream/src/task/barrier_manager/managed_state.rs +++ 
b/src/stream/src/task/barrier_manager/managed_state.rs @@ -476,21 +476,21 @@ impl ManagedBarrierState { } /// Collect a `barrier` from the actor with `actor_id`. - pub(super) fn collect(&mut self, actor_id: ActorId, barrier: &Barrier) { + pub(super) fn collect(&mut self, actor_id: ActorId, epoch: EpochPair) { tracing::debug!( target: "events::stream::barrier::manager::collect", - epoch = ?barrier.epoch, actor_id, state = ?self.epoch_barrier_state_map, + ?epoch, actor_id, state = ?self.epoch_barrier_state_map, "collect_barrier", ); - match self.epoch_barrier_state_map.get_mut(&barrier.epoch.prev) { + match self.epoch_barrier_state_map.get_mut(&epoch.prev) { None => { // If the barrier's state is stashed, this occurs exclusively in scenarios where the barrier has not been // injected by the barrier manager, or the barrier message is blocked at the `RemoteInput` side waiting for injection. // Given these conditions, it's inconceivable for an actor to attempt collect at this point. panic!( "cannot collect new actor barrier {:?} at current state: None", - barrier.epoch, + epoch, ) } Some(&mut BarrierState { @@ -506,15 +506,15 @@ impl ManagedBarrierState { assert!( exist, "the actor doesn't exist. actor_id: {:?}, curr_epoch: {:?}", - actor_id, barrier.epoch.curr + actor_id, epoch.curr ); - assert_eq!(curr_epoch, barrier.epoch.curr); - self.may_have_collected_all(barrier.epoch.prev); + assert_eq!(curr_epoch, epoch.curr); + self.may_have_collected_all(epoch.prev); } Some(BarrierState { inner, .. }) => { panic!( "cannot collect new actor barrier {:?} at current state: {:?}", - barrier.epoch, inner + epoch, inner ) } } @@ -723,8 +723,8 @@ mod tests { managed_barrier_state.transform_to_issued(&barrier1, actor_ids_to_collect1, HashSet::new()); managed_barrier_state.transform_to_issued(&barrier2, actor_ids_to_collect2, HashSet::new()); managed_barrier_state.transform_to_issued(&barrier3, actor_ids_to_collect3, HashSet::new()); - managed_barrier_state.collect(1, &barrier1); - managed_barrier_state.collect(2, &barrier1); + managed_barrier_state.collect(1, barrier1.epoch); + managed_barrier_state.collect(2, barrier1.epoch); assert_eq!( managed_barrier_state.pop_next_completed_epoch().await, test_epoch(0) @@ -737,9 +737,9 @@ mod tests { .0, &test_epoch(1) ); - managed_barrier_state.collect(1, &barrier2); - managed_barrier_state.collect(1, &barrier3); - managed_barrier_state.collect(2, &barrier2); + managed_barrier_state.collect(1, barrier2.epoch); + managed_barrier_state.collect(1, barrier3.epoch); + managed_barrier_state.collect(2, barrier2.epoch); assert_eq!( managed_barrier_state.pop_next_completed_epoch().await, test_epoch(1) @@ -752,8 +752,8 @@ mod tests { .0, { &test_epoch(2) } ); - managed_barrier_state.collect(2, &barrier3); - managed_barrier_state.collect(3, &barrier3); + managed_barrier_state.collect(2, barrier3.epoch); + managed_barrier_state.collect(3, barrier3.epoch); assert_eq!( managed_barrier_state.pop_next_completed_epoch().await, test_epoch(2) @@ -774,12 +774,12 @@ mod tests { managed_barrier_state.transform_to_issued(&barrier2, actor_ids_to_collect2, HashSet::new()); managed_barrier_state.transform_to_issued(&barrier3, actor_ids_to_collect3, HashSet::new()); - managed_barrier_state.collect(1, &barrier1); - managed_barrier_state.collect(1, &barrier2); - managed_barrier_state.collect(1, &barrier3); - managed_barrier_state.collect(2, &barrier1); - managed_barrier_state.collect(2, &barrier2); - managed_barrier_state.collect(2, &barrier3); + managed_barrier_state.collect(1, barrier1.epoch); + 
managed_barrier_state.collect(1, barrier2.epoch); + managed_barrier_state.collect(1, barrier3.epoch); + managed_barrier_state.collect(2, barrier1.epoch); + managed_barrier_state.collect(2, barrier2.epoch); + managed_barrier_state.collect(2, barrier3.epoch); assert_eq!( managed_barrier_state .epoch_barrier_state_map @@ -788,8 +788,8 @@ mod tests { .0, &0 ); - managed_barrier_state.collect(3, &barrier1); - managed_barrier_state.collect(3, &barrier2); + managed_barrier_state.collect(3, barrier1.epoch); + managed_barrier_state.collect(3, barrier2.epoch); assert_eq!( managed_barrier_state .epoch_barrier_state_map @@ -798,7 +798,7 @@ mod tests { .0, &0 ); - managed_barrier_state.collect(4, &barrier1); + managed_barrier_state.collect(4, barrier1.epoch); assert_eq!( managed_barrier_state.pop_next_completed_epoch().await, test_epoch(0) diff --git a/src/stream/src/task/barrier_manager/tests.rs b/src/stream/src/task/barrier_manager/tests.rs index 60b3867a1a0c..0d1bb159ea5f 100644 --- a/src/stream/src/task/barrier_manager/tests.rs +++ b/src/stream/src/task/barrier_manager/tests.rs @@ -17,44 +17,21 @@ use std::iter::once; use std::pin::pin; use std::task::Poll; -use assert_matches::assert_matches; use futures::future::join_all; use futures::FutureExt; use risingwave_common::util::epoch::test_epoch; -use risingwave_pb::stream_service::{streaming_control_stream_request, InjectBarrierRequest}; -use tokio_stream::wrappers::UnboundedReceiverStream; use super::*; +use crate::task::barrier_test_utils::LocalBarrierTestEnv; #[tokio::test] async fn test_managed_barrier_collection() -> StreamResult<()> { - let actor_op_tx = LocalBarrierManager::spawn_for_test(); + let mut test_env = LocalBarrierTestEnv::for_test().await; - let (request_tx, request_rx) = unbounded_channel(); - let (response_tx, mut response_rx) = unbounded_channel(); - - actor_op_tx.send_event(LocalActorOperation::NewControlStream { - handle: ControlStreamHandle::new( - response_tx, - UnboundedReceiverStream::new(request_rx).boxed(), - ), - init_request: InitRequest { prev_epoch: 0 }, - }); - - assert_matches!( - response_rx.recv().await.unwrap().unwrap().response.unwrap(), - streaming_control_stream_response::Response::Init(_) - ); - - let context = actor_op_tx - .send_and_await(LocalActorOperation::GetCurrentSharedContext) - .await - .unwrap(); - - let manager = &context.local_barrier_manager; + let manager = &test_env.shared_context.local_barrier_manager; let register_sender = |actor_id: u32| { - let actor_op_tx = &actor_op_tx; + let actor_op_tx = &test_env.actor_op_tx; async move { let (barrier_tx, barrier_rx) = unbounded_channel(); actor_op_tx @@ -79,19 +56,7 @@ async fn test_managed_barrier_collection() -> StreamResult<()> { let barrier = Barrier::new_test_barrier(curr_epoch); let epoch = barrier.epoch.prev; - request_tx - .send(Ok(StreamingControlStreamRequest { - request: Some(streaming_control_stream_request::Request::InjectBarrier( - InjectBarrierRequest { - request_id: "".to_string(), - barrier: Some(barrier.to_protobuf()), - actor_ids_to_send: actor_ids.clone(), - actor_ids_to_collect: actor_ids, - table_ids_to_sync: vec![], - }, - )), - })) - .unwrap(); + test_env.inject_barrier(&barrier, actor_ids.clone(), actor_ids); // Collect barriers from actors let collected_barriers = join_all(rxs.iter_mut().map(|(actor_id, rx)| async move { @@ -101,7 +66,7 @@ async fn test_managed_barrier_collection() -> StreamResult<()> { })) .await; - let mut await_epoch_future = pin!(response_rx.recv().map(|result| { + let mut await_epoch_future = 
pin!(test_env.response_rx.recv().map(|result| { let resp: StreamingControlStreamResponse = result.unwrap().unwrap(); let resp = resp.response.unwrap(); match resp { @@ -124,33 +89,12 @@ async fn test_managed_barrier_collection() -> StreamResult<()> { #[tokio::test] async fn test_managed_barrier_collection_separately() -> StreamResult<()> { - let actor_op_tx = LocalBarrierManager::spawn_for_test(); - - let (request_tx, request_rx) = unbounded_channel(); - let (response_tx, mut response_rx) = unbounded_channel(); + let mut test_env = LocalBarrierTestEnv::for_test().await; - actor_op_tx.send_event(LocalActorOperation::NewControlStream { - handle: ControlStreamHandle::new( - response_tx, - UnboundedReceiverStream::new(request_rx).boxed(), - ), - init_request: InitRequest { prev_epoch: 0 }, - }); - - assert_matches!( - response_rx.recv().await.unwrap().unwrap().response.unwrap(), - streaming_control_stream_response::Response::Init(_) - ); - - let context = actor_op_tx - .send_and_await(LocalActorOperation::GetCurrentSharedContext) - .await - .unwrap(); - - let manager = &context.local_barrier_manager; + let manager = &test_env.shared_context.local_barrier_manager; let register_sender = |actor_id: u32| { - let actor_op_tx = &actor_op_tx; + let actor_op_tx = &test_env.actor_op_tx; async move { let (barrier_tx, barrier_rx) = unbounded_channel(); actor_op_tx @@ -179,27 +123,16 @@ async fn test_managed_barrier_collection_separately() -> StreamResult<()> { // Prepare the barrier let curr_epoch = test_epoch(2); - let barrier = Barrier::new_test_barrier(curr_epoch); + let barrier = Barrier::new_test_barrier(curr_epoch).with_stop(); - let mut mutation_subscriber = manager.subscribe_barrier_mutation(extra_actor_id, &barrier); + let mut mutation_subscriber = + manager.subscribe_barrier_mutation(extra_actor_id, &barrier.clone().into_dispatcher()); // Read the mutation after receiving the barrier from remote input. 
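    // The subscription is created before the barrier is injected, so the first poll
    // below must still be pending; the mutation only becomes observable once
    // `inject_barrier` pushes the barrier through the control stream.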
let mut mutation_reader = pin!(mutation_subscriber.recv()); assert!(poll_fn(|cx| Poll::Ready(mutation_reader.as_mut().poll(cx).is_pending())).await); - request_tx - .send(Ok(StreamingControlStreamRequest { - request: Some(streaming_control_stream_request::Request::InjectBarrier( - InjectBarrierRequest { - request_id: "".to_string(), - barrier: Some(barrier.to_protobuf()), - actor_ids_to_send, - actor_ids_to_collect, - table_ids_to_sync: vec![], - }, - )), - })) - .unwrap(); + test_env.inject_barrier(&barrier, actor_ids_to_send, actor_ids_to_collect); let (epoch, mutation) = mutation_reader.await.unwrap(); assert_eq!((epoch, &mutation), (barrier.epoch.prev, &barrier.mutation)); @@ -215,7 +148,7 @@ async fn test_managed_barrier_collection_separately() -> StreamResult<()> { })) .await; - let mut await_epoch_future = pin!(response_rx.recv().map(|result| { + let mut await_epoch_future = pin!(test_env.response_rx.recv().map(|result| { let resp: StreamingControlStreamResponse = result.unwrap().unwrap(); let resp = resp.response.unwrap(); match resp { diff --git a/src/stream/src/task/env.rs b/src/stream/src/task/env.rs index a47eb8279224..75c64f9ae8bf 100644 --- a/src/stream/src/task/env.rs +++ b/src/stream/src/task/env.rs @@ -89,7 +89,7 @@ impl StreamEnvironment { use risingwave_dml::dml_manager::DmlManager; use risingwave_storage::monitor::MonitoredStorageMetrics; StreamEnvironment { - server_addr: "127.0.0.1:5688".parse().unwrap(), + server_addr: "127.0.0.1:2333".parse().unwrap(), config: Arc::new(StreamingConfig::default()), worker_id: WorkerNodeId::default(), state_store: StateStoreImpl::shared_in_memory_store(Arc::new( diff --git a/src/stream/src/task/mod.rs b/src/stream/src/task/mod.rs index 77f21b52406f..9a337e1dc0ab 100644 --- a/src/stream/src/task/mod.rs +++ b/src/stream/src/task/mod.rs @@ -28,6 +28,7 @@ mod barrier_manager; mod env; mod stream_manager; +pub(crate) use barrier_manager::SubscribeMutationItem; pub use barrier_manager::*; pub use env::*; pub use stream_manager::*; From 9335967a8558aaf4c822cfc5c37b2305dee22839 Mon Sep 17 00:00:00 2001 From: William Wen <44139337+wenym1@users.noreply.github.com> Date: Fri, 19 Jul 2024 17:38:37 +0800 Subject: [PATCH 45/70] refactor(meta): only store CreateStreamingJob command in tracker (#17742) --- src/meta/src/barrier/command.rs | 160 ++++++++++++-------------- src/meta/src/barrier/info.rs | 20 ++-- src/meta/src/barrier/mod.rs | 92 ++++++++------- src/meta/src/barrier/progress.rs | 157 ++++++++++++------------- src/meta/src/barrier/recovery.rs | 12 +- src/meta/src/barrier/schedule.rs | 6 +- src/meta/src/manager/metadata.rs | 11 ++ src/meta/src/stream/stream_manager.rs | 12 +- 8 files changed, 234 insertions(+), 236 deletions(-) diff --git a/src/meta/src/barrier/command.rs b/src/meta/src/barrier/command.rs index 486e314ae046..4cfc21fd8f1c 100644 --- a/src/meta/src/barrier/command.rs +++ b/src/meta/src/barrier/command.rs @@ -41,7 +41,9 @@ use risingwave_pb::stream_service::WaitEpochCommitRequest; use thiserror_ext::AsReport; use tracing::warn; -use super::info::{CommandActorChanges, CommandFragmentChanges, InflightActorInfo}; +use super::info::{ + CommandActorChanges, CommandFragmentChanges, CommandNewFragmentInfo, InflightActorInfo, +}; use super::trace::TracedEpoch; use crate::barrier::GlobalBarrierManagerContext; use crate::manager::{DdlType, MetadataManager, StreamingJob, WorkerId}; @@ -107,7 +109,7 @@ impl ReplaceTablePlan { fn actor_changes(&self) -> CommandActorChanges { let mut fragment_changes = HashMap::new(); for fragment in 
self.new_table_fragments.fragments.values() { - let fragment_change = CommandFragmentChanges::NewFragment { + let fragment_change = CommandFragmentChanges::NewFragment(CommandNewFragmentInfo { new_actors: fragment .actors .iter() @@ -130,7 +132,7 @@ impl ReplaceTablePlan { .map(|table_id| TableId::new(*table_id)) .collect(), is_injectable: TableFragments::is_injectable(fragment.fragment_type_mask), - }; + }); assert!(fragment_changes .insert(fragment.fragment_id, fragment_change) .is_none()); @@ -144,6 +146,54 @@ impl ReplaceTablePlan { } } +#[derive(Debug, Clone)] +pub struct CreateStreamingJobCommandInfo { + pub table_fragments: TableFragments, + /// Refer to the doc on [`MetadataManager::get_upstream_root_fragments`] for the meaning of "root". + pub upstream_root_actors: HashMap>, + pub dispatchers: HashMap>, + pub init_split_assignment: SplitAssignment, + pub definition: String, + pub ddl_type: DdlType, + pub create_type: CreateType, + pub streaming_job: StreamingJob, + pub internal_tables: Vec
, +} + +impl CreateStreamingJobCommandInfo { + fn new_fragment_info(&self) -> impl Iterator + '_ { + self.table_fragments.fragments.values().map(|fragment| { + ( + fragment.fragment_id, + CommandNewFragmentInfo { + new_actors: fragment + .actors + .iter() + .map(|actor| { + ( + actor.actor_id, + self.table_fragments + .actor_status + .get(&actor.actor_id) + .expect("should exist") + .get_parallel_unit() + .expect("should set") + .worker_node_id, + ) + }) + .collect(), + table_ids: fragment + .state_table_ids + .iter() + .map(|table_id| TableId::new(*table_id)) + .collect(), + is_injectable: TableFragments::is_injectable(fragment.fragment_type_mask), + }, + ) + }) + } +} + /// [`Command`] is the input of [`crate::barrier::GlobalBarrierManager`]. For different commands, /// it will build different barriers to send, and may do different stuffs after the barrier is /// collected. @@ -187,16 +237,7 @@ pub enum Command { /// for a while** until the `finish` channel is signaled, then the state of `TableFragments` /// will be set to `Created`. CreateStreamingJob { - streaming_job: StreamingJob, - internal_tables: Vec
, - table_fragments: TableFragments, - /// Refer to the doc on [`MetadataManager::get_upstream_root_fragments`] for the meaning of "root". - upstream_root_actors: HashMap>, - dispatchers: HashMap>, - init_split_assignment: SplitAssignment, - definition: String, - ddl_type: DdlType, - create_type: CreateType, + info: CreateStreamingJobCommandInfo, /// This is for create SINK into table. replace_table: Option, }, @@ -279,43 +320,13 @@ impl Command { .collect(), }), Command::CreateStreamingJob { - table_fragments, + info, replace_table, - .. } => { - let fragment_changes = table_fragments - .fragments - .values() - .map(|fragment| { - ( - fragment.fragment_id, - CommandFragmentChanges::NewFragment { - new_actors: fragment - .actors - .iter() - .map(|actor| { - ( - actor.actor_id, - table_fragments - .actor_status - .get(&actor.actor_id) - .expect("should exist") - .get_parallel_unit() - .expect("should set") - .worker_node_id, - ) - }) - .collect(), - table_ids: fragment - .state_table_ids - .iter() - .map(|table_id| TableId::new(*table_id)) - .collect(), - is_injectable: TableFragments::is_injectable( - fragment.fragment_type_mask, - ), - }, - ) + let fragment_changes = info + .new_fragment_info() + .map(|(fragment_id, info)| { + (fragment_id, CommandFragmentChanges::NewFragment(info)) }) .collect(); let mut changes = CommandActorChanges { fragment_changes }; @@ -460,10 +471,6 @@ impl CommandContext { _span: span, } } - - pub fn metadata_manager(&self) -> &MetadataManager { - &self.barrier_manager_context.metadata_manager - } } impl CommandContext { @@ -521,11 +528,14 @@ impl CommandContext { })), Command::CreateStreamingJob { - table_fragments, - dispatchers, - init_split_assignment: split_assignment, + info: + CreateStreamingJobCommandInfo { + table_fragments, + dispatchers, + init_split_assignment: split_assignment, + .. + }, replace_table, - .. } => { let actor_dispatchers = dispatchers .iter() @@ -818,20 +828,6 @@ impl CommandContext { } } - /// For `CreateStreamingJob`, returns the actors of the `StreamScan`, and `StreamValue` nodes. For other commands, - /// returns an empty set. - pub fn actors_to_track(&self) -> HashSet { - match &self.command { - Command::CreateStreamingJob { - table_fragments, .. - } => table_fragments - .tracking_progress_actor_ids() - .into_iter() - .collect(), - _ => Default::default(), - } - } - /// For `CancelStreamingJob`, returns the table id of the target table. pub fn table_to_cancel(&self) -> Option { match &self.command { @@ -840,16 +836,6 @@ impl CommandContext { } } - /// For `CreateStreamingJob`, returns the table id of the target table. - pub fn table_to_create(&self) -> Option { - match &self.command { - Command::CreateStreamingJob { - table_fragments, .. - } => Some(table_fragments.table_id()), - _ => None, - } - } - /// Clean up actors in CNs if needed, used by drop, cancel and reschedule commands. async fn clean_up(&self, actors: Vec) -> MetaResult<()> { self.barrier_manager_context @@ -992,14 +978,16 @@ impl CommandContext { } Command::CreateStreamingJob { - table_fragments, - dispatchers, - upstream_root_actors, - init_split_assignment, - definition: _, + info, replace_table, - .. } => { + let CreateStreamingJobCommandInfo { + table_fragments, + dispatchers, + upstream_root_actors, + init_split_assignment, + .. 
+ } = info; match &self.barrier_manager_context.metadata_manager { MetadataManager::V1(mgr) => { let mut dependent_table_actors = diff --git a/src/meta/src/barrier/info.rs b/src/meta/src/barrier/info.rs index f6617b9ceef4..a6bdcdace6e5 100644 --- a/src/meta/src/barrier/info.rs +++ b/src/meta/src/barrier/info.rs @@ -22,13 +22,16 @@ use crate::barrier::Command; use crate::manager::{ActiveStreamingWorkerNodes, ActorInfos, InflightFragmentInfo, WorkerId}; use crate::model::{ActorId, FragmentId}; +#[derive(Debug, Clone)] +pub(crate) struct CommandNewFragmentInfo { + pub new_actors: HashMap, + pub table_ids: HashSet, + pub is_injectable: bool, +} + #[derive(Debug, Clone)] pub(crate) enum CommandFragmentChanges { - NewFragment { - new_actors: HashMap, - table_ids: HashSet, - is_injectable: bool, - }, + NewFragment(CommandNewFragmentInfo), Reschedule { new_actors: HashMap, to_remove: HashSet, @@ -149,11 +152,12 @@ impl InflightActorInfo { let mut to_add = HashMap::new(); for (fragment_id, change) in fragment_changes { match change { - CommandFragmentChanges::NewFragment { + CommandFragmentChanges::NewFragment(CommandNewFragmentInfo { new_actors, table_ids, is_injectable, - } => { + .. + }) => { for (actor_id, node_id) in &new_actors { assert!(to_add .insert(*actor_id, (*node_id, is_injectable)) @@ -232,7 +236,7 @@ impl InflightActorInfo { let mut all_to_remove = HashSet::new(); for (fragment_id, changes) in fragment_changes.fragment_changes { match changes { - CommandFragmentChanges::NewFragment { .. } => {} + CommandFragmentChanges::NewFragment(_) => {} CommandFragmentChanges::Reschedule { to_remove, .. } => { let info = self .fragment_infos diff --git a/src/meta/src/barrier/mod.rs b/src/meta/src/barrier/mod.rs index 1c05497665fe..71b7ce60affc 100644 --- a/src/meta/src/barrier/mod.rs +++ b/src/meta/src/barrier/mod.rs @@ -23,7 +23,9 @@ use std::time::Duration; use anyhow::Context; use arc_swap::ArcSwap; use fail::fail_point; +use futures::future::try_join_all; use itertools::Itertools; +use parking_lot::Mutex; use prometheus::HistogramTimer; use risingwave_common::bail; use risingwave_common::catalog::TableId; @@ -44,16 +46,14 @@ use risingwave_pb::stream_service::barrier_complete_response::CreateMviewProgres use risingwave_pb::stream_service::BarrierCompleteResponse; use thiserror_ext::AsReport; use tokio::sync::oneshot::{Receiver, Sender}; -use tokio::sync::Mutex; use tokio::task::JoinHandle; use tracing::{error, info, warn, Instrument}; use self::command::CommandContext; use self::notifier::Notifier; -use self::progress::TrackingCommand; use crate::barrier::info::InflightActorInfo; use crate::barrier::notifier::BarrierInfo; -use crate::barrier::progress::CreateMviewProgressTracker; +use crate::barrier::progress::{CreateMviewProgressTracker, TrackingJob}; use crate::barrier::rpc::ControlStreamManager; use crate::barrier::state::BarrierManagerState; use crate::error::MetaErrorInner; @@ -77,7 +77,9 @@ mod schedule; mod state; mod trace; -pub use self::command::{BarrierKind, Command, ReplaceTablePlan, Reschedule}; +pub use self::command::{ + BarrierKind, Command, CreateStreamingJobCommandInfo, ReplaceTablePlan, Reschedule, +}; pub use self::rpc::StreamRpcManager; pub use self::schedule::BarrierScheduler; pub use self::trace::TracedEpoch; @@ -456,7 +458,7 @@ impl GlobalBarrierManager { let active_streaming_nodes = ActiveStreamingWorkerNodes::uninitialized(); - let tracker = CreateMviewProgressTracker::new(); + let tracker = CreateMviewProgressTracker::default(); let context = 
GlobalBarrierManagerContext { status: Arc::new(ArcSwap::new(Arc::new(BarrierManagerStatus::Starting))), @@ -789,6 +791,7 @@ impl GlobalBarrierManager { } async fn failure_recovery(&mut self, err: MetaError) { + self.context.tracker.lock().abort_all(); self.checkpoint_control.clear_on_err(&err).await; self.pending_non_checkpoint_barriers.clear(); @@ -816,6 +819,7 @@ impl GlobalBarrierManager { async fn adhoc_recovery(&mut self) { let err = MetaErrorInner::AdhocRecovery.into(); + self.context.tracker.lock().abort_all(); self.checkpoint_control.clear_on_err(&err).await; self.context @@ -846,13 +850,14 @@ impl GlobalBarrierManagerContext { .. } = node; assert!(state.node_to_collect.is_empty()); - let resps = state.resps; let wait_commit_timer = self.metrics.barrier_wait_commit_latency.start_timer(); - let create_mview_progress = resps + let create_mview_progress = state + .resps .iter() .flat_map(|resp| resp.create_mview_progress.iter().cloned()) .collect(); - if let Err(e) = self.update_snapshot(&command_ctx, resps).await { + + if let Err(e) = self.update_snapshot(&command_ctx, state).await { for notifier in notifiers { notifier.notify_collection_failed(e.clone()); } @@ -861,9 +866,15 @@ impl GlobalBarrierManagerContext { notifiers.into_iter().for_each(|notifier| { notifier.notify_collected(); }); - let has_remaining = self + + let (has_remaining, finished_jobs) = self .update_tracking_jobs(command_ctx.clone(), create_mview_progress) - .await?; + .await; + try_join_all(finished_jobs.into_iter().map(|finished_job| { + let metadata_manager = &self.metadata_manager; + async move { finished_job.pre_finish(metadata_manager).await } + })) + .await?; let duration_sec = enqueue_time.stop_and_record(); self.report_complete_event(duration_sec, &command_ctx); wait_commit_timer.observe_duration(); @@ -879,7 +890,7 @@ impl GlobalBarrierManagerContext { async fn update_snapshot( &self, command_ctx: &CommandContext, - resps: Vec, + state: BarrierEpochState, ) -> MetaResult<()> { { { @@ -894,7 +905,7 @@ impl GlobalBarrierManagerContext { match &command_ctx.kind { BarrierKind::Initial => {} BarrierKind::Checkpoint(epochs) => { - let commit_info = collect_commit_epoch_info(resps, command_ctx, epochs); + let commit_info = collect_commit_epoch_info(state, command_ctx, epochs); new_snapshot = self.hummock_manager.commit_epoch(commit_info).await?; } BarrierKind::Barrier => { @@ -926,23 +937,18 @@ impl GlobalBarrierManagerContext { &self, command_ctx: Arc, create_mview_progress: Vec, - ) -> MetaResult { + ) -> (bool, Vec) { { { // Notify about collected. let version_stats = self.hummock_manager.get_version_stats().await; - let mut tracker = self.tracker.lock().await; + let mut tracker = self.tracker.lock(); // Save `finished_commands` for Create MVs. let finished_commands = { let mut commands = vec![]; // Add the command to tracker. - if let Some(command) = tracker.add( - TrackingCommand { - context: command_ctx.clone(), - }, - &version_stats, - ) { + if let Some(command) = tracker.add(&command_ctx, &version_stats) { // Those with no actors to track can be finished immediately. 
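                        // (`add` hands the tracking job back immediately when there are
                        // no backfill actors to report progress, so it can be finished at
                        // the next completed checkpoint without waiting for
                        // `CreateMviewProgress` updates.)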
commands.push(command); } @@ -969,11 +975,11 @@ impl GlobalBarrierManagerContext { tracker.cancel_command(table_id); } - let has_remaining_job = tracker - .finish_jobs(command_ctx.kind.is_checkpoint()) - .await?; - - Ok(has_remaining_job) + if command_ctx.kind.is_checkpoint() { + (false, tracker.take_finished_jobs()) + } else { + (tracker.has_pending_finished_jobs(), vec![]) + } } } } @@ -1099,7 +1105,7 @@ impl GlobalBarrierManagerContext { } pub async fn get_ddl_progress(&self) -> Vec { - let mut ddl_progress = self.tracker.lock().await.gen_ddl_progress(); + let mut ddl_progress = self.tracker.lock().gen_ddl_progress(); // If not in tracker, means the first barrier not collected yet. // In that case just return progress 0. match &self.metadata_manager { @@ -1142,10 +1148,11 @@ impl GlobalBarrierManagerContext { pub type BarrierManagerRef = GlobalBarrierManagerContext; fn collect_commit_epoch_info( - resps: Vec, + state: BarrierEpochState, command_ctx: &CommandContext, epochs: &Vec, ) -> CommitEpochInfo { + let resps = state.resps; let mut sst_to_worker: HashMap = HashMap::new(); let mut synced_ssts: Vec = vec![]; let mut table_watermarks = Vec::with_capacity(resps.len()); @@ -1163,22 +1170,21 @@ fn collect_commit_epoch_info( table_watermarks.push(resp.table_watermarks); old_value_ssts.extend(resp.old_value_sstables); } - let new_table_fragment_info = if let Command::CreateStreamingJob { - table_fragments, .. - } = &command_ctx.command - { - Some(NewTableFragmentInfo { - table_id: table_fragments.table_id(), - mv_table_id: table_fragments.mv_table_id().map(TableId::new), - internal_table_ids: table_fragments - .internal_table_ids() - .into_iter() - .map(TableId::new) - .collect(), - }) - } else { - None - }; + let new_table_fragment_info = + if let Command::CreateStreamingJob { info, .. } = &command_ctx.command { + let table_fragments = &info.table_fragments; + Some(NewTableFragmentInfo { + table_id: table_fragments.table_id(), + mv_table_id: table_fragments.mv_table_id().map(TableId::new), + internal_table_ids: table_fragments + .internal_table_ids() + .into_iter() + .map(TableId::new) + .collect(), + }) + } else { + None + }; let table_new_change_log = build_table_change_log_delta( old_value_ssts.into_iter(), diff --git a/src/meta/src/barrier/progress.rs b/src/meta/src/barrier/progress.rs index 5fdf875486fd..fc562c67e71f 100644 --- a/src/meta/src/barrier/progress.rs +++ b/src/meta/src/barrier/progress.rs @@ -14,6 +14,7 @@ use std::collections::hash_map::Entry; use std::collections::HashMap; +use std::mem::take; use std::sync::Arc; use risingwave_common::catalog::TableId; @@ -25,12 +26,12 @@ use risingwave_pb::hummock::HummockVersionStats; use risingwave_pb::stream_service::barrier_complete_response::CreateMviewProgress; use super::command::CommandContext; -use crate::barrier::{Command, GlobalBarrierManagerContext}; +use crate::barrier::{Command, CreateStreamingJobCommandInfo, ReplaceTablePlan}; use crate::manager::{ DdlType, MetadataManager, MetadataManagerV1, MetadataManagerV2, StreamingJob, }; use crate::model::{ActorId, TableFragments}; -use crate::{MetaError, MetaResult}; +use crate::MetaResult; type ConsumedRows = u64; @@ -160,28 +161,16 @@ pub enum TrackingJob { } impl TrackingJob { - /// Returns whether the `TrackingJob` requires a checkpoint to complete. - pub(crate) fn is_checkpoint_required(&self) -> bool { - match self { - // Recovered tracking job is always a streaming job, - // It requires a checkpoint to complete. 
- TrackingJob::RecoveredV1(_) | TrackingJob::RecoveredV2(_) => true, - TrackingJob::New(command) => { - command.context.kind.is_initial() || command.context.kind.is_checkpoint() - } - } - } - - pub(crate) async fn pre_finish(&self) -> MetaResult<()> { + pub(crate) async fn pre_finish(&self, metadata_manager: &MetadataManager) -> MetaResult<()> { match &self { - TrackingJob::New(command) => match &command.context.command { - Command::CreateStreamingJob { + TrackingJob::New(command) => { + let CreateStreamingJobCommandInfo { table_fragments, streaming_job, internal_tables, - replace_table, .. - } => match command.context.metadata_manager() { + } = &command.info; + match metadata_manager { MetadataManager::V1(mgr) => { mgr.fragment_manager .mark_table_fragments_created(table_fragments.table_id()) @@ -193,13 +182,15 @@ impl TrackingJob { } MetadataManager::V2(mgr) => { mgr.catalog_controller - .finish_streaming_job(streaming_job.id() as i32, replace_table.clone()) + .finish_streaming_job( + streaming_job.id() as i32, + command.replace_table_info.clone(), + ) .await?; Ok(()) } - }, - _ => Ok(()), - }, + } + } TrackingJob::RecoveredV1(recovered) => { let manager = &recovered.metadata_manager; manager @@ -226,11 +217,11 @@ impl TrackingJob { } } - pub(crate) fn table_to_create(&self) -> Option { + pub(crate) fn table_to_create(&self) -> TableId { match self { - TrackingJob::New(command) => command.context.table_to_create(), - TrackingJob::RecoveredV1(recovered) => Some(recovered.fragments.table_id()), - TrackingJob::RecoveredV2(recovered) => Some((recovered.id as u32).into()), + TrackingJob::New(command) => command.info.table_fragments.table_id(), + TrackingJob::RecoveredV1(recovered) => recovered.fragments.table_id(), + TrackingJob::RecoveredV2(recovered) => (recovered.id as u32).into(), } } } @@ -241,7 +232,7 @@ impl std::fmt::Debug for TrackingJob { TrackingJob::New(command) => write!( f, "TrackingJob::New({:?})", - command.context.table_to_create() + command.info.table_fragments.table_id() ), TrackingJob::RecoveredV1(recovered) => { write!( @@ -271,8 +262,8 @@ pub struct RecoveredTrackingJobV2 { /// The command tracking by the [`CreateMviewProgressTracker`]. pub(super) struct TrackingCommand { - /// The context of the command. - pub context: Arc, + pub info: CreateStreamingJobCommandInfo, + pub replace_table_info: Option, } /// Tracking is done as follows: @@ -280,6 +271,7 @@ pub(super) struct TrackingCommand { /// 2. For each stream job, there are several actors which run its tasks. /// 3. With `progress_map` we can use the ID of the `StreamJob` to view its progress. /// 4. With `actor_map` we can use an actor's `ActorId` to find the ID of the `StreamJob`. +#[derive(Default)] pub(super) struct CreateMviewProgressTracker { /// Progress of the create-mview DDL indicated by the `TableId`. progress_map: HashMap, @@ -403,14 +395,6 @@ impl CreateMviewProgressTracker { } } - pub fn new() -> Self { - Self { - progress_map: Default::default(), - actor_map: Default::default(), - finished_jobs: Vec::new(), - } - } - pub fn gen_ddl_progress(&self) -> HashMap { self.progress_map .iter() @@ -436,38 +420,26 @@ impl CreateMviewProgressTracker { /// If not checkpoint, jobs which do not require checkpoint can be finished. /// /// Returns whether there are still remaining stashed jobs to finish. 
- pub(super) async fn finish_jobs(&mut self, checkpoint: bool) -> MetaResult { + pub(super) fn take_finished_jobs(&mut self) -> Vec { tracing::trace!(finished_jobs=?self.finished_jobs, progress_map=?self.progress_map, "finishing jobs"); - for job in self - .finished_jobs - .extract_if(|job| checkpoint || !job.is_checkpoint_required()) - { - // The command is ready to finish. We can now call `pre_finish`. - job.pre_finish().await?; - } - Ok(!self.finished_jobs.is_empty()) + take(&mut self.finished_jobs) + } + + pub(super) fn has_pending_finished_jobs(&self) -> bool { + !self.finished_jobs.is_empty() } pub(super) fn cancel_command(&mut self, id: TableId) { let _ = self.progress_map.remove(&id); - self.finished_jobs - .retain(|x| x.table_to_create() != Some(id)); + self.finished_jobs.retain(|x| x.table_to_create() != id); self.actor_map.retain(|_, table_id| *table_id != id); } /// Notify all tracked commands that error encountered and clear them. - pub async fn abort_all(&mut self, err: &MetaError, context: &GlobalBarrierManagerContext) { + pub fn abort_all(&mut self) { self.actor_map.clear(); self.finished_jobs.clear(); self.progress_map.clear(); - match &context.metadata_manager { - MetadataManager::V1(mgr) => { - mgr.notify_finish_failed(err).await; - } - MetadataManager::V2(mgr) => { - mgr.notify_finish_failed(err).await; - } - } } /// Add a new create-mview DDL command to track. @@ -475,32 +447,43 @@ impl CreateMviewProgressTracker { /// If the actors to track is empty, return the given command as it can be finished immediately. pub fn add( &mut self, - command: TrackingCommand, + command_ctx: &Arc, version_stats: &HummockVersionStats, ) -> Option { - let actors = command.context.actors_to_track(); - if actors.is_empty() { - // The command can be finished immediately. - return Some(TrackingJob::New(command)); - } + let (info, actors, replace_table_info) = if let Command::CreateStreamingJob { + info, + replace_table, + } = &command_ctx.command + { + let CreateStreamingJobCommandInfo { + table_fragments, .. + } = info; + let actors = table_fragments.tracking_progress_actor_ids(); + if actors.is_empty() { + // The command can be finished immediately. + return Some(TrackingJob::New(TrackingCommand { + info: info.clone(), + replace_table_info: replace_table.clone(), + })); + } + (info.clone(), actors, replace_table.clone()) + } else { + return None; + }; - let ( - creating_mv_id, - upstream_mv_count, - upstream_total_key_count, - definition, - ddl_type, - create_type, - ) = if let Command::CreateStreamingJob { + let CreateStreamingJobCommandInfo { table_fragments, - dispatchers, upstream_root_actors, + dispatchers, definition, ddl_type, create_type, .. - } = &command.context.command - { + } = &info; + + let creating_mv_id = table_fragments.table_id(); + + let (upstream_mv_count, upstream_total_key_count, ddl_type, create_type) = { // Keep track of how many times each upstream MV appears. 
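            // Each appearance adds that upstream's key count once more to
            // `upstream_total_key_count` below, which serves as the denominator when
            // estimating backfill progress for the DDL.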
let mut upstream_mv_count = HashMap::new(); for (table_id, actors) in upstream_root_actors { @@ -524,15 +507,11 @@ impl CreateMviewProgressTracker { }) .sum(); ( - table_fragments.table_id(), upstream_mv_count, upstream_total_key_count, - definition.to_string(), ddl_type, create_type, ) - } else { - unreachable!("Must be CreateStreamingJob."); }; for &actor in &actors { @@ -543,7 +522,7 @@ impl CreateMviewProgressTracker { actors, upstream_mv_count, upstream_total_key_count, - definition, + definition.clone(), ); if *ddl_type == DdlType::Sink && *create_type == CreateType::Background { // We return the original tracking job immediately. @@ -551,11 +530,21 @@ impl CreateMviewProgressTracker { // We don't need to wait for sink to finish backfill. // This still contains the notifiers, so we can tell listeners // that the sink job has been created. - Some(TrackingJob::New(command)) + Some(TrackingJob::New(TrackingCommand { + info, + replace_table_info, + })) } else { - let old = self - .progress_map - .insert(creating_mv_id, (progress, TrackingJob::New(command))); + let old = self.progress_map.insert( + creating_mv_id, + ( + progress, + TrackingJob::New(TrackingCommand { + info, + replace_table_info, + }), + ), + ); assert!(old.is_none()); None } diff --git a/src/meta/src/barrier/recovery.rs b/src/meta/src/barrier/recovery.rs index f9bba534e561..92c13e543a73 100644 --- a/src/meta/src/barrier/recovery.rs +++ b/src/meta/src/barrier/recovery.rs @@ -152,8 +152,7 @@ impl GlobalBarrierManagerContext { let version_stats = self.hummock_manager.get_version_stats().await; // If failed, enter recovery mode. { - let mut tracker = self.tracker.lock().await; - *tracker = + *self.tracker.lock() = CreateMviewProgressTracker::recover_v1(version_stats, table_mview_map, mgr.clone()); } Ok(()) @@ -180,8 +179,7 @@ impl GlobalBarrierManagerContext { let version_stats = self.hummock_manager.get_version_stats().await; // If failed, enter recovery mode. { - let mut tracker = self.tracker.lock().await; - *tracker = + *self.tracker.lock() = CreateMviewProgressTracker::recover_v2(mview_map, version_stats, mgr.clone()); } Ok(()) @@ -248,10 +246,8 @@ impl GlobalBarrierManager { let recovery_result: MetaResult<_> = try { if let Some(err) = &err { self.context - .tracker - .lock() - .await - .abort_all(err, &self.context) + .metadata_manager + .notify_finish_failed(err) .await; } diff --git a/src/meta/src/barrier/schedule.rs b/src/meta/src/barrier/schedule.rs index 7af683c18fc7..d2a2febbec2f 100644 --- a/src/meta/src/barrier/schedule.rs +++ b/src/meta/src/barrier/schedule.rs @@ -184,10 +184,8 @@ impl BarrierScheduler { pub fn try_cancel_scheduled_create(&self, table_id: TableId) -> bool { let queue = &mut self.inner.queue.lock(); if let Some(idx) = queue.queue.iter().position(|scheduled| { - if let Command::CreateStreamingJob { - table_fragments, .. - } = &scheduled.command - && table_fragments.table_id() == table_id + if let Command::CreateStreamingJob { info, .. 
} = &scheduled.command + && info.table_fragments.table_id() == table_id { true } else { diff --git a/src/meta/src/manager/metadata.rs b/src/meta/src/manager/metadata.rs index ecd9d4971d2b..a2aab4371b17 100644 --- a/src/meta/src/manager/metadata.rs +++ b/src/meta/src/manager/metadata.rs @@ -855,6 +855,17 @@ impl MetadataManager { MetadataManager::V2(mgr) => mgr.wait_streaming_job_finished(job.id() as _).await, } } + + pub(crate) async fn notify_finish_failed(&self, err: &MetaError) { + match self { + MetadataManager::V1(mgr) => { + mgr.notify_finish_failed(err).await; + } + MetadataManager::V2(mgr) => { + mgr.notify_finish_failed(err).await; + } + } + } } impl MetadataManagerV2 { diff --git a/src/meta/src/stream/stream_manager.rs b/src/meta/src/stream/stream_manager.rs index 2756d71a8a6c..260d1ed537ac 100644 --- a/src/meta/src/stream/stream_manager.rs +++ b/src/meta/src/stream/stream_manager.rs @@ -29,7 +29,9 @@ use tokio::sync::{oneshot, Mutex}; use tracing::Instrument; use super::{Locations, RescheduleOptions, ScaleControllerRef, TableResizePolicy}; -use crate::barrier::{BarrierScheduler, Command, ReplaceTablePlan, StreamRpcManager}; +use crate::barrier::{ + BarrierScheduler, Command, CreateStreamingJobCommandInfo, ReplaceTablePlan, StreamRpcManager, +}; use crate::manager::{DdlType, MetaSrvEnv, MetadataManager, NotificationVersion, StreamingJob}; use crate::model::{ActorId, FragmentId, MetadataModel, TableFragments, TableParallelism}; use crate::stream::{to_build_actor_info, SourceManagerRef}; @@ -474,7 +476,7 @@ impl GlobalStreamManager { .await?, ); - let command = Command::CreateStreamingJob { + let info = CreateStreamingJobCommandInfo { table_fragments, upstream_root_actors, dispatchers, @@ -483,9 +485,13 @@ impl GlobalStreamManager { streaming_job: streaming_job.clone(), internal_tables: internal_tables.into_values().collect_vec(), ddl_type, - replace_table: replace_table_command, create_type, }; + + let command = Command::CreateStreamingJob { + info, + replace_table: replace_table_command, + }; tracing::debug!("sending Command::CreateStreamingJob"); let result: MetaResult = try { self.barrier_scheduler.run_command(command).await?; From d0ba17b4d27d69757954182cf2b6f0e977264561 Mon Sep 17 00:00:00 2001 From: Bugen Zhao Date: Fri, 19 Jul 2024 17:53:52 +0800 Subject: [PATCH 46/70] feat(meta): pass streaming error score to meta to locate cluster-level root error (#17685) Signed-off-by: Bugen Zhao --- e2e_test/error_ui/simple/recovery.slt | 2 +- src/error/src/tonic.rs | 16 +++ src/error/src/tonic/extra.rs | 50 +++++++++ src/meta/src/barrier/rpc.rs | 45 +++++++- src/rpc_client/src/error.rs | 11 +- src/stream/src/task/barrier_manager.rs | 146 +++++++++++++++---------- 6 files changed, 206 insertions(+), 64 deletions(-) create mode 100644 src/error/src/tonic/extra.rs diff --git a/e2e_test/error_ui/simple/recovery.slt b/e2e_test/error_ui/simple/recovery.slt index e3830be6d25c..526e9d2f1022 100644 --- a/e2e_test/error_ui/simple/recovery.slt +++ b/e2e_test/error_ui/simple/recovery.slt @@ -25,7 +25,7 @@ with error as ( limit 1 ) select -case when error like '%Actor % exited unexpectedly: Executor error: %Numeric out of range%' then 'ok' +case when error like '%get error from control stream, in worker node %: %Actor % exited unexpectedly: Executor error: %Numeric out of range%' then 'ok' else error end as result from error; diff --git a/src/error/src/tonic.rs b/src/error/src/tonic.rs index 11cbb7106365..4e3476c460fd 100644 --- a/src/error/src/tonic.rs +++ b/src/error/src/tonic.rs @@ -12,6 +12,8 
@@ // See the License for the specific language governing permissions and // limitations under the License. +pub mod extra; + use std::borrow::Cow; use std::error::Error; use std::sync::Arc; @@ -24,6 +26,7 @@ use tonic::metadata::{MetadataMap, MetadataValue}; const ERROR_KEY: &str = "risingwave-error-bin"; /// The service name that the error is from. Used to provide better error message. +// TODO: also make it a field of `Extra`? type ServiceName = Cow<'static, str>; /// The error produced by the gRPC server and sent to the client on the wire. @@ -31,6 +34,7 @@ type ServiceName = Cow<'static, str>; struct ServerError { error: serde_error::Error, service_name: Option, + extra: extra::Extra, } impl std::fmt::Display for ServerError { @@ -43,6 +47,10 @@ impl std::error::Error for ServerError { fn source(&self) -> Option<&(dyn Error + 'static)> { self.error.source() } + + fn provide<'a>(&'a self, request: &mut std::error::Request<'a>) { + self.extra.provide(request); + } } fn to_status(error: &T, code: tonic::Code, service_name: Option) -> tonic::Status @@ -55,6 +63,7 @@ where let source = ServerError { error: serde_error::Error::new(error), service_name, + extra: extra::Extra::new(error), }; let serialized = bincode::serialize(&source).unwrap(); @@ -204,6 +213,13 @@ impl std::error::Error for TonicStatusWrapper { // Delegate to `self.inner` as if we're transparent. self.inner.source() } + + fn provide<'a>(&'a self, request: &mut std::error::Request<'a>) { + // The source error, typically a `ServerError`, may provide additional information through `extra`. + if let Some(source) = self.source() { + source.provide(request); + } + } } #[cfg(test)] diff --git a/src/error/src/tonic/extra.rs b/src/error/src/tonic/extra.rs new file mode 100644 index 000000000000..dbf6b80e2912 --- /dev/null +++ b/src/error/src/tonic/extra.rs @@ -0,0 +1,50 @@ +// Copyright 2024 RisingWave Labs +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use serde::{Deserialize, Serialize}; + +/// The score of the error. +/// +/// Currently, it's used to identify the root cause of streaming pipeline failures, i.e., which actor +/// led to the failure. +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)] +pub struct Score(pub i32); + +/// Extra fields in errors that can be passed through the gRPC boundary. +/// +/// - Field being set to `None` means it is not available. +/// - To add a new field, also update the `provide` method. +#[derive(Debug, Default, Clone, Serialize, Deserialize)] +pub(super) struct Extra { + pub score: Option, +} + +impl Extra { + /// Create a new [`Extra`] by [requesting](std::error::request_ref) each field from the given error. + pub fn new(error: &T) -> Self + where + T: ?Sized + std::error::Error, + { + Self { + score: std::error::request_value(error), + } + } + + /// Provide each field to the given [request](std::error::Request). 
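    ///
    /// The receiving side can then read a field back out of the error chain, as the meta
    /// service does when picking the root cause among per-worker errors. A sketch
    /// (requires the unstable `error_generic_member_access` feature):
    ///
    /// ```ignore
    /// let score: Option<Score> = std::error::request_value::<Score>(&err);
    /// ```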
+ pub fn provide<'a>(&'a self, request: &mut std::error::Request<'a>) { + if let Some(score) = self.score { + request.provide_value(score); + } + } +} diff --git a/src/meta/src/barrier/rpc.rs b/src/meta/src/barrier/rpc.rs index c1a337bde046..dd30a3fe9e00 100644 --- a/src/meta/src/barrier/rpc.rs +++ b/src/meta/src/barrier/rpc.rs @@ -536,17 +536,54 @@ where Err(results_err) } -fn merge_node_rpc_errors( +fn merge_node_rpc_errors( message: &str, errors: impl IntoIterator, ) -> MetaError { + use std::error::request_value; use std::fmt::Write; + use risingwave_common::error::tonic::extra::Score; + + let errors = errors.into_iter().collect_vec(); + + if errors.is_empty() { + return anyhow!(message.to_owned()).into(); + } + + // Create the error from the single error. + let single_error = |(worker_id, e)| { + anyhow::Error::from(e) + .context(format!("{message}, in worker node {worker_id}")) + .into() + }; + + if errors.len() == 1 { + return single_error(errors.into_iter().next().unwrap()); + } + + // Find the error with the highest score. + let max_score = errors + .iter() + .filter_map(|(_, e)| request_value::(e)) + .max(); + + if let Some(max_score) = max_score { + let mut errors = errors; + let max_scored = errors + .extract_if(|(_, e)| request_value::(e) == Some(max_score)) + .next() + .unwrap(); + + return single_error(max_scored); + } + + // The errors do not have scores, so simply concatenate them. let concat: String = errors .into_iter() - .fold(format!("{message}:"), |mut s, (w, e)| { - write!(&mut s, " worker node {}, {};", w, e.as_report()).unwrap(); + .fold(format!("{message}: "), |mut s, (w, e)| { + write!(&mut s, " in worker node {}, {};", w, e.as_report()).unwrap(); s }); - anyhow::anyhow!(concat).into() + anyhow!(concat).into() } diff --git a/src/rpc_client/src/error.rs b/src/rpc_client/src/error.rs index 5626912c2f88..c5c5613a32a4 100644 --- a/src/rpc_client/src/error.rs +++ b/src/rpc_client/src/error.rs @@ -28,7 +28,14 @@ pub enum RpcError { TransportError(Box), #[error(transparent)] - GrpcStatus(Box), + GrpcStatus( + #[from] + // Typically it does not have a backtrace, + // but this is to let `thiserror` generate `provide` implementation to make `Extra` work. + // See `risingwave_error::tonic::extra`. + #[backtrace] + Box, + ), #[error(transparent)] MetaAddressParse(#[from] MetaAddressStrategyParseError), @@ -61,7 +68,7 @@ macro_rules! impl_from_status { $( #[doc = "Convert a gRPC status from " $service " service into an [`RpcError`]."] pub fn [](s: tonic::Status) -> Self { - Self::grpc_status(s.with_client_side_service_name(stringify!($service))) + Box::new(s.with_client_side_service_name(stringify!($service))).into() } )* } diff --git a/src/stream/src/task/barrier_manager.rs b/src/stream/src/task/barrier_manager.rs index 1fd932d96675..c9c998e81a72 100644 --- a/src/stream/src/task/barrier_manager.rs +++ b/src/stream/src/task/barrier_manager.rs @@ -23,6 +23,7 @@ use futures::stream::{BoxStream, FuturesUnordered}; use futures::StreamExt; use itertools::Itertools; use parking_lot::Mutex; +use risingwave_common::error::tonic::extra::Score; use risingwave_pb::stream_service::barrier_complete_response::{ GroupedSstableInfo, PbCreateMviewProgress, }; @@ -372,7 +373,7 @@ pub(super) struct LocalBarrierWorker { actor_failure_rx: UnboundedReceiver<(ActorId, StreamError)>, /// Cached result of [`Self::try_find_root_failure`]. 
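    /// Once a root failure has been scored and selected, later failure notifications
    /// reuse it instead of re-scanning `failure_actors`, keeping the reported cause stable.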
- cached_root_failure: Option, + cached_root_failure: Option, } impl LocalBarrierWorker { @@ -796,7 +797,10 @@ impl LocalBarrierWorker { /// This is similar to [`Self::notify_actor_failure`], but since there's not always an actor failure, /// the given `err` will be used if there's no root failure found. async fn notify_other_failure(&mut self, err: StreamError, message: impl Into) { - let root_err = self.try_find_root_failure().await.unwrap_or(err); + let root_err = self + .try_find_root_failure() + .await + .unwrap_or_else(|| ScoredStreamError::new(err)); self.control_stream_handle.reset_stream_with_err( anyhow!(root_err) @@ -818,10 +822,11 @@ impl LocalBarrierWorker { /// Collect actor errors for a while and find the one that might be the root cause. /// /// Returns `None` if there's no actor error received. - async fn try_find_root_failure(&mut self) -> Option { + async fn try_find_root_failure(&mut self) -> Option { if self.cached_root_failure.is_some() { return self.cached_root_failure.clone(); } + // fetch more actor errors within a timeout let _ = tokio::time::timeout(Duration::from_secs(3), async { while let Some((actor_id, error)) = self.actor_failure_rx.recv().await { @@ -829,7 +834,13 @@ impl LocalBarrierWorker { } }) .await; - self.cached_root_failure = try_find_root_actor_failure(self.failure_actors.values()); + + // Find the error with highest score. + self.cached_root_failure = self + .failure_actors + .values() + .map(|e| ScoredStreamError::new(e.clone())) + .max_by_key(|e| e.score); self.cached_root_failure.clone() } @@ -938,63 +949,84 @@ impl LocalBarrierManager { } } -/// Tries to find the root cause of actor failures, based on hard-coded rules. -/// -/// Returns `None` if the input is empty. -pub fn try_find_root_actor_failure<'a>( - actor_errors: impl IntoIterator, -) -> Option { - // Explicitly list all error kinds here to notice developers to update this function when - // there are changes in error kinds. - - fn stream_executor_error_score(e: &StreamExecutorError) -> i32 { - use crate::executor::error::ErrorKind; - match e.inner() { - // `ChannelClosed` or `ExchangeChannelClosed` is likely to be caused by actor exit - // and not the root cause. - ErrorKind::ChannelClosed(_) | ErrorKind::ExchangeChannelClosed(_) => 1, - - // Normal errors. - ErrorKind::Uncategorized(_) - | ErrorKind::Storage(_) - | ErrorKind::ArrayError(_) - | ErrorKind::ExprError(_) - | ErrorKind::SerdeError(_) - | ErrorKind::SinkError(_) - | ErrorKind::RpcError(_) - | ErrorKind::AlignBarrier(_, _) - | ErrorKind::ConnectorError(_) - | ErrorKind::DmlError(_) - | ErrorKind::NotImplemented(_) => 999, - } +/// A [`StreamError`] with a score, used to find the root cause of actor failures. +#[derive(Debug, Clone)] +struct ScoredStreamError { + error: StreamError, + score: Score, +} + +impl std::fmt::Display for ScoredStreamError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + self.error.fmt(f) } +} - fn stream_error_score(e: &StreamError) -> i32 { - use crate::error::ErrorKind; - match e.inner() { - // `UnexpectedExit` wraps the original error. Score on the inner error. - ErrorKind::UnexpectedExit { source, .. } => stream_error_score(source), - - // `BarrierSend` is likely to be caused by actor exit and not the root cause. - ErrorKind::BarrierSend { .. } => 1, - - // Executor errors first. - ErrorKind::Executor(ee) => 2000 + stream_executor_error_score(ee), - - // Then other errors. 
- ErrorKind::Uncategorized(_) - | ErrorKind::Storage(_) - | ErrorKind::Expression(_) - | ErrorKind::Array(_) - | ErrorKind::Sink(_) - | ErrorKind::Secret(_) => 1000, - } +impl std::error::Error for ScoredStreamError { + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + self.error.source() + } + + fn provide<'a>(&'a self, request: &mut std::error::Request<'a>) { + self.error.provide(request); + // HIGHLIGHT: Provide the score to make it retrievable from meta service. + request.provide_value(self.score); } +} + +impl ScoredStreamError { + /// Score the given error based on hard-coded rules. + fn new(error: StreamError) -> Self { + // Explicitly list all error kinds here to notice developers to update this function when + // there are changes in error kinds. + + fn stream_executor_error_score(e: &StreamExecutorError) -> i32 { + use crate::executor::error::ErrorKind; + match e.inner() { + // `ChannelClosed` or `ExchangeChannelClosed` is likely to be caused by actor exit + // and not the root cause. + ErrorKind::ChannelClosed(_) | ErrorKind::ExchangeChannelClosed(_) => 1, + + // Normal errors. + ErrorKind::Uncategorized(_) + | ErrorKind::Storage(_) + | ErrorKind::ArrayError(_) + | ErrorKind::ExprError(_) + | ErrorKind::SerdeError(_) + | ErrorKind::SinkError(_) + | ErrorKind::RpcError(_) + | ErrorKind::AlignBarrier(_, _) + | ErrorKind::ConnectorError(_) + | ErrorKind::DmlError(_) + | ErrorKind::NotImplemented(_) => 999, + } + } - actor_errors - .into_iter() - .max_by_key(|&e| stream_error_score(e)) - .cloned() + fn stream_error_score(e: &StreamError) -> i32 { + use crate::error::ErrorKind; + match e.inner() { + // `UnexpectedExit` wraps the original error. Score on the inner error. + ErrorKind::UnexpectedExit { source, .. } => stream_error_score(source), + + // `BarrierSend` is likely to be caused by actor exit and not the root cause. + ErrorKind::BarrierSend { .. } => 1, + + // Executor errors first. + ErrorKind::Executor(ee) => 2000 + stream_executor_error_score(ee), + + // Then other errors. 
+ ErrorKind::Uncategorized(_) + | ErrorKind::Storage(_) + | ErrorKind::Expression(_) + | ErrorKind::Array(_) + | ErrorKind::Sink(_) + | ErrorKind::Secret(_) => 1000, + } + } + + let score = Score(stream_error_score(&error)); + Self { error, score } + } } #[cfg(test)] From 625c5aaec41f954abb14f13df20a60464678e14b Mon Sep 17 00:00:00 2001 From: stonepage <40830455+st1page@users.noreply.github.com> Date: Fri, 19 Jul 2024 19:44:45 +0800 Subject: [PATCH 47/70] fix: watermark filter use commited epoch to read global watermark (#17724) --- e2e_test/streaming/watermark.slt | 2 +- src/stream/src/executor/watermark_filter.rs | 100 ++++++++++++-------- 2 files changed, 63 insertions(+), 39 deletions(-) diff --git a/e2e_test/streaming/watermark.slt b/e2e_test/streaming/watermark.slt index 5d8f189dfd96..a7bc67066fd1 100644 --- a/e2e_test/streaming/watermark.slt +++ b/e2e_test/streaming/watermark.slt @@ -21,7 +21,7 @@ statement ok insert into t values ('2023-05-06 16:56:01', 1); skipif in-memory -sleep 10s +sleep 20s skipif in-memory query TI diff --git a/src/stream/src/executor/watermark_filter.rs b/src/stream/src/executor/watermark_filter.rs index 2aca1251dd05..cf0546491767 100644 --- a/src/stream/src/executor/watermark_filter.rs +++ b/src/stream/src/executor/watermark_filter.rs @@ -77,6 +77,7 @@ impl Execute for WatermarkFilterExecutor { self.execute_inner().boxed() } } +const UPDATE_GLOBAL_WATERMARK_FREQUENCY_WHEN_IDLE: usize = 5; impl WatermarkFilterExecutor { #[try_stream(ok = Message, error = StreamExecutorError)] @@ -99,13 +100,18 @@ impl WatermarkFilterExecutor { let mut input = input.execute(); let first_barrier = expect_first_barrier(&mut input).await?; + let prev_epoch = first_barrier.epoch.prev; table.init_epoch(first_barrier.epoch); // The first barrier message should be propagated. yield Message::Barrier(first_barrier); // Initiate and yield the first watermark. - let mut current_watermark = - Self::get_global_max_watermark(&table, &global_watermark_table).await?; + let mut current_watermark = Self::get_global_max_watermark( + &table, + &global_watermark_table, + HummockReadEpoch::Committed(prev_epoch), + ) + .await?; let mut last_checkpoint_watermark = None; @@ -119,7 +125,7 @@ impl WatermarkFilterExecutor { // If the input is idle let mut idle_input = true; - + let mut barrier_num_during_idle = 0; #[for_await] for msg in input { let msg = msg?; @@ -208,6 +214,9 @@ impl WatermarkFilterExecutor { } } Message::Barrier(barrier) => { + let prev_epoch = barrier.epoch.prev; + let is_checkpoint = barrier.kind.is_checkpoint(); + let mut need_update_global_max_watermark = false; // Update the vnode bitmap for state tables of all agg calls if asked. if let Some(vnode_bitmap) = barrier.as_update_vnode_bitmap(ctx.id) { let other_vnodes_bitmap = Arc::new( @@ -220,15 +229,11 @@ impl WatermarkFilterExecutor { // Take the global max watermark when scaling happens. if previous_vnode_bitmap != vnode_bitmap { - current_watermark = - Self::get_global_max_watermark(&table, &global_watermark_table) - .await?; + need_update_global_max_watermark = true; } } - if barrier.kind.is_checkpoint() - && last_checkpoint_watermark != current_watermark - { + if is_checkpoint && last_checkpoint_watermark != current_watermark { last_checkpoint_watermark.clone_from(¤t_watermark); // Persist the watermark when checkpoint arrives. 
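                        // One row per owned vnode, so that `get_global_max_watermark` can
                        // read the values written by other parallel executors through the
                        // shared `StorageTable` at the committed epoch (or with `NoWait`
                        // during idle alignment).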
if let Some(watermark) = current_watermark.clone() { @@ -242,39 +247,59 @@ impl WatermarkFilterExecutor { } table.commit(barrier.epoch).await?; + yield Message::Barrier(barrier); + + if need_update_global_max_watermark { + current_watermark = Self::get_global_max_watermark( + &table, + &global_watermark_table, + HummockReadEpoch::Committed(prev_epoch), + ) + .await?; + } - if barrier.kind.is_checkpoint() { + if is_checkpoint { if idle_input { - // Align watermark - let global_max_watermark = - Self::get_global_max_watermark(&table, &global_watermark_table) - .await?; - - current_watermark = if let Some(global_max_watermark) = - global_max_watermark.clone() - && let Some(watermark) = current_watermark.clone() + barrier_num_during_idle += 1; + + if barrier_num_during_idle + == UPDATE_GLOBAL_WATERMARK_FREQUENCY_WHEN_IDLE { - Some(cmp::max_by( - watermark, - global_max_watermark, - DefaultOrd::default_cmp, - )) - } else { - current_watermark.or(global_max_watermark) - }; - if let Some(watermark) = current_watermark.clone() { - yield Message::Watermark(Watermark::new( - event_time_col_idx, - watermark_type.clone(), - watermark, - )); + barrier_num_during_idle = 0; + // Align watermark + // NOTE(st1page): Should be `NoWait` because it could lead to a degradation of concurrent checkpoint situations, as it would require waiting for the previous epoch + let global_max_watermark = Self::get_global_max_watermark( + &table, + &global_watermark_table, + HummockReadEpoch::NoWait(prev_epoch), + ) + .await?; + + current_watermark = if let Some(global_max_watermark) = + global_max_watermark.clone() + && let Some(watermark) = current_watermark.clone() + { + Some(cmp::max_by( + watermark, + global_max_watermark, + DefaultOrd::default_cmp, + )) + } else { + current_watermark.or(global_max_watermark) + }; + if let Some(watermark) = current_watermark.clone() { + yield Message::Watermark(Watermark::new( + event_time_col_idx, + watermark_type.clone(), + watermark, + )); + } } } else { idle_input = true; + barrier_num_during_idle = 0; } } - - yield Message::Barrier(barrier); } } } @@ -301,8 +326,8 @@ impl WatermarkFilterExecutor { async fn get_global_max_watermark( table: &StateTable, global_watermark_table: &StorageTable, + wait_epoch: HummockReadEpoch, ) -> StreamExecutorResult> { - let epoch = table.epoch(); let handle_watermark_row = |watermark_row: Option| match watermark_row { Some(row) => { if row.len() == 1 { @@ -319,9 +344,8 @@ impl WatermarkFilterExecutor { .iter_vnodes() .map(|vnode| async move { let pk = row::once(vnode.to_datum()); - let watermark_row: Option = global_watermark_table - .get_row(pk, HummockReadEpoch::NoWait(epoch)) - .await?; + let watermark_row: Option = + global_watermark_table.get_row(pk, wait_epoch).await?; handle_watermark_row(watermark_row) }); let local_watermark_iter_futures = table.vnodes().iter_vnodes().map(|vnode| async move { From ecad0cf2a11ffc7058d1a272e78629f2a2569f32 Mon Sep 17 00:00:00 2001 From: Noel Kwan <47273164+kwannoel@users.noreply.github.com> Date: Mon, 22 Jul 2024 13:35:13 +0800 Subject: [PATCH 48/70] fix(ci): fix trigger condition for `s3-v2-source-check` step (#17738) --- ci/workflows/main-cron.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ci/workflows/main-cron.yml b/ci/workflows/main-cron.yml index 3c71be0f0984..1f264957f395 100644 --- a/ci/workflows/main-cron.yml +++ b/ci/workflows/main-cron.yml @@ -512,7 +512,7 @@ steps: key: "s3-v2-source-check-parquet-file" command: "ci/scripts/s3-source-test.sh -p ci-release -s 
fs_parquet_source.py" if: | - !(build.pull_request.labels includes "ci/main-cron/skip-ci") && build.env("CI_STEPS") == null + !(build.pull_request.labels includes "ci/main-cron/run-selected") && build.env("CI_STEPS") == null || build.pull_request.labels includes "ci/run-s3-source-tests" || build.env("CI_STEPS") =~ /(^|,)s3-source-tests?(,|$$)/ depends_on: build From 4f3f7f7240d12569fdd49a24c74c332b2caadd8a Mon Sep 17 00:00:00 2001 From: xxchan Date: Mon, 22 Jul 2024 14:39:50 +0800 Subject: [PATCH 49/70] refactor(types): simplify ArrayImpl::from_protobuf (#17763) --- src/common/src/array/mod.rs | 1 - src/common/src/array/primitive_array.rs | 16 ++- src/common/src/array/proto_reader.rs | 154 ++++++++---------------- src/common/src/array/value_reader.rs | 88 -------------- src/common/src/types/datetime.rs | 28 ++++- src/common/src/types/interval.rs | 15 ++- src/common/src/types/mod.rs | 12 +- src/common/src/types/scalar_impl.rs | 2 +- src/common/src/types/timestamptz.rs | 11 +- 9 files changed, 118 insertions(+), 209 deletions(-) delete mode 100644 src/common/src/array/value_reader.rs diff --git a/src/common/src/array/mod.rs b/src/common/src/array/mod.rs index f2ae95aa71ef..89b3b0626678 100644 --- a/src/common/src/array/mod.rs +++ b/src/common/src/array/mod.rs @@ -35,7 +35,6 @@ mod stream_chunk_iter; pub mod stream_record; pub mod struct_array; mod utf8_array; -mod value_reader; use std::convert::From; use std::hash::{Hash, Hasher}; diff --git a/src/common/src/array/primitive_array.rs b/src/common/src/array/primitive_array.rs index 9ae7912e9887..45e1bc0ddd84 100644 --- a/src/common/src/array/primitive_array.rs +++ b/src/common/src/array/primitive_array.rs @@ -13,9 +13,11 @@ // limitations under the License. use std::fmt::Debug; -use std::io::Write; +use std::io::{Cursor, Write}; use std::mem::size_of; +use anyhow::Context; +use byteorder::{BigEndian, ReadBytesExt}; use risingwave_common_estimate_size::{EstimateSize, ZeroHeapSize}; use risingwave_pb::common::buffer::CompressionType; use risingwave_pb::common::Buffer; @@ -50,6 +52,7 @@ where // item methods fn to_protobuf(self, output: &mut T) -> ArrayResult; + fn from_protobuf(cur: &mut Cursor<&[u8]>) -> ArrayResult; } macro_rules! impl_array_methods { @@ -81,7 +84,7 @@ macro_rules! impl_array_methods { } macro_rules! impl_primitive_for_native_types { - ($({ $naive_type:ty, $scalar_type:ident } ),*) => { + ($({ $naive_type:ty, $scalar_type:ident, $read_fn:ident } ),*) => { $( impl PrimitiveArrayItemType for $naive_type { impl_array_methods!($naive_type, $scalar_type, $scalar_type); @@ -89,6 +92,12 @@ macro_rules! impl_primitive_for_native_types { fn to_protobuf(self, output: &mut T) -> ArrayResult { NativeType::to_protobuf(self, output) } + fn from_protobuf(cur: &mut Cursor<&[u8]>) -> ArrayResult { + let v = cur + .$read_fn::() + .context("failed to read value from buffer")?; + Ok(v.into()) + } } )* } @@ -106,6 +115,9 @@ macro_rules! 
impl_primitive_for_others { fn to_protobuf(self, output: &mut T) -> ArrayResult { <$scalar_type>::to_protobuf(self, output) } + fn from_protobuf(cur: &mut Cursor<&[u8]>) -> ArrayResult { + <$scalar_type>::from_protobuf(cur) + } } )* } diff --git a/src/common/src/array/proto_reader.rs b/src/common/src/array/proto_reader.rs index 7c3b05437770..aa296900190d 100644 --- a/src/common/src/array/proto_reader.rs +++ b/src/common/src/array/proto_reader.rs @@ -16,43 +16,32 @@ use std::io::{Cursor, Read}; use anyhow::Context; use byteorder::{BigEndian, ReadBytesExt}; -use paste::paste; use risingwave_pb::data::PbArrayType; use super::*; -use crate::array::value_reader::{PrimitiveValueReader, VarSizedValueReader}; impl ArrayImpl { pub fn from_protobuf(array: &PbArray, cardinality: usize) -> ArrayResult { - use crate::array::value_reader::*; let array = match array.array_type() { PbArrayType::Unspecified => unreachable!(), - PbArrayType::Int16 => read_numeric_array::(array, cardinality)?, - PbArrayType::Int32 => read_numeric_array::(array, cardinality)?, - PbArrayType::Int64 => read_numeric_array::(array, cardinality)?, - PbArrayType::Serial => { - read_numeric_array::(array, cardinality)? - } - PbArrayType::Float32 => read_numeric_array::(array, cardinality)?, - PbArrayType::Float64 => read_numeric_array::(array, cardinality)?, + PbArrayType::Int16 => read_primitive_array::(array, cardinality)?, + PbArrayType::Int32 => read_primitive_array::(array, cardinality)?, + PbArrayType::Int64 => read_primitive_array::(array, cardinality)?, + PbArrayType::Serial => read_primitive_array::(array, cardinality)?, + PbArrayType::Float32 => read_primitive_array::(array, cardinality)?, + PbArrayType::Float64 => read_primitive_array::(array, cardinality)?, PbArrayType::Bool => read_bool_array(array, cardinality)?, - PbArrayType::Utf8 => { - read_string_array::(array, cardinality)? - } - PbArrayType::Decimal => { - read_numeric_array::(array, cardinality)? - } - PbArrayType::Date => read_date_array(array, cardinality)?, - PbArrayType::Time => read_time_array(array, cardinality)?, - PbArrayType::Timestamp => read_timestamp_array(array, cardinality)?, - PbArrayType::Timestamptz => read_timestamptz_array(array, cardinality)?, - PbArrayType::Interval => read_interval_array(array, cardinality)?, + PbArrayType::Utf8 => read_string_array::(array, cardinality)?, + PbArrayType::Decimal => read_primitive_array::(array, cardinality)?, + PbArrayType::Date => read_primitive_array::(array, cardinality)?, + PbArrayType::Time => read_primitive_array::