From 75735e2b15c82429d69529683d56136fc80de1bd Mon Sep 17 00:00:00 2001 From: Zihao Xu Date: Fri, 5 Jan 2024 03:08:25 -0500 Subject: [PATCH 01/20] feat(sql-udf): support basic anonymous sql udf (#14139) Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- e2e_test/udf/sql_udf.slt | 192 ++++++++++++++++++ proto/catalog.proto | 1 + src/frontend/src/binder/expr/column.rs | 12 ++ src/frontend/src/binder/expr/function.rs | 168 ++++++++++++++- src/frontend/src/binder/expr/mod.rs | 8 + src/frontend/src/binder/mod.rs | 7 +- src/frontend/src/catalog/function_catalog.rs | 2 + src/frontend/src/catalog/root_catalog.rs | 2 +- .../src/expr/user_defined_function.rs | 2 + src/frontend/src/handler/create_function.rs | 2 + .../src/handler/create_sql_function.rs | 181 +++++++++++++++++ src/frontend/src/handler/mod.rs | 44 +++- .../migration/src/m20230908_072257_init.rs | 2 + src/meta/model_v2/src/function.rs | 2 + src/meta/src/controller/mod.rs | 1 + src/sqlparser/README.md | 4 +- src/sqlparser/tests/sqlparser_postgres.rs | 47 +++++ 17 files changed, 653 insertions(+), 24 deletions(-) create mode 100644 e2e_test/udf/sql_udf.slt create mode 100644 src/frontend/src/handler/create_sql_function.rs diff --git a/e2e_test/udf/sql_udf.slt b/e2e_test/udf/sql_udf.slt new file mode 100644 index 0000000000000..8b89010f70a93 --- /dev/null +++ b/e2e_test/udf/sql_udf.slt @@ -0,0 +1,192 @@ +statement ok +SET RW_IMPLICIT_FLUSH TO true; + +# Create an anonymous function with double dollar as clause +statement ok +create function add(INT, INT) returns int language sql as $$select $1 + $2$$; + +# Create an anonymous function with single quote as clause +statement ok +create function sub(INT, INT) returns int language sql as 'select $1 - $2'; + +# Create an anonymous function that calls other pre-defined sql udfs +statement ok +create function add_sub_binding() returns int language sql as 'select add(1, 1) + sub(2, 2)'; + +# Create an anonymous function that calls built-in functions +# Note that double dollar signs should be used otherwise the parsing will fail, as illutrates below +statement ok +create function call_regexp_replace() returns varchar language sql as $$select regexp_replace('💩💩💩💩💩foo🤔️bar亲爱的😭baz这不是爱情❤️‍🔥', 'baz(...)', '这是🥵', 'ic')$$; + +statement error Expected end of statement, found: 💩 +create function call_regexp_replace() returns varchar language sql as 'select regexp_replace('💩💩💩💩💩foo🤔️bar亲爱的😭baz这不是爱情❤️‍🔥', 'baz(...)', '这是🥵', 'ic')'; + +# Create an anonymous function with return expression +statement ok +create function add_return(INT, INT) returns int language sql return $1 + $2; + +statement ok +create function add_return_binding() returns int language sql return add_return(1, 1) + add_return(1, 1); + +# Recursive definition is forbidden +statement error recursive definition is forbidden, please recheck your function syntax +create function recursive(INT, INT) returns int language sql as 'select recursive($1, $2) + recursive($1, $2)'; + +# Create a wrapper function for `add` & `sub` +statement ok +create function add_sub_wrapper(INT, INT) returns int language sql as 'select add($1, $2) + sub($1, $2) + 114512'; + +# Call the defined sql udf +query I +select add(1, -1); +---- +0 + +query I +select sub(1, 1); +---- +0 + +query I +select add_sub_binding(); +---- +2 + +query III +select add(1, -1), sub(1, 1), add_sub_binding(); +---- +0 0 2 + +query I +select add_return(1, 1); +---- +2 + +query I +select add_return_binding(); +---- +4 + +query T +select call_regexp_replace(); +---- +💩💩💩💩💩foo🤔️bar亲爱的😭这是🥵爱情❤️‍🔥 + +query I +select add_sub_wrapper(1, 1); +---- +114514 + +# Create a mock table +statement ok +create table t1 (c1 INT, c2 INT); + +# Insert some data into the mock table +statement ok +insert into t1 values (1, 1), (2, 2), (3, 3), (4, 4), (5, 5); + +query III +select sub(c1, c2), c1, c2, add(c1, c2) from t1 order by c1 asc; +---- +0 1 1 2 +0 2 2 4 +0 3 3 6 +0 4 4 8 +0 5 5 10 + +query I +select c1, c2, add_return(c1, c2) from t1 order by c1 asc; +---- +1 1 2 +2 2 4 +3 3 6 +4 4 8 +5 5 10 + +# Invalid function body syntax +statement error Expected an expression:, found: EOF at the end +create function add_error(INT, INT) returns int language sql as $$select $1 + $2 +$$; + +# Multiple type interleaving sql udf +statement ok +create function add_sub(INT, FLOAT, INT) returns float language sql as $$select -$1 + $2 - $3$$; + +statement ok +create function add_sub_return(INT, FLOAT, INT) returns float language sql return -$1 + $2 - $3; + +# Note: need EXPLICIT type cast in order to call the multiple types interleaving sql udf +query I +select add_sub(1::INT, 5.1415926::FLOAT, 1::INT); +---- +3.1415926 + +# Without EXPLICIT type cast +statement error unsupported function: "add_sub" +select add_sub(1, 3.14, 2); + +# Same as above, need EXPLICIT type cast to make the binding works +query I +select add_sub_return(1::INT, 5.1415926::FLOAT, 1::INT); +---- +3.1415926 + +query III +select add(1, -1), sub(1, 1), add_sub(1::INT, 5.1415926::FLOAT, 1::INT); +---- +0 0 3.1415926 + +# Create another mock table +statement ok +create table t2 (c1 INT, c2 FLOAT, c3 INT); + +statement ok +insert into t2 values (1, 3.14, 2), (2, 4.44, 5), (20, 10.30, 02); + +query IIIIII +select c1, c2, c3, add(c1, c3), sub(c1, c3), add_sub(c1::INT, c2::FLOAT, c3::INT) from t2 order by c1 asc; +---- +1 3.14 2 3 -1 0.14000000000000012 +2 4.44 5 7 -3 -2.5599999999999996 +20 10.3 2 22 18 -11.7 + +query IIIIII +select c1, c2, c3, add(c1, c3), sub(c1, c3), add_sub_return(c1::INT, c2::FLOAT, c3::INT) from t2 order by c1 asc; +---- +1 3.14 2 3 -1 0.14000000000000012 +2 4.44 5 7 -3 -2.5599999999999996 +20 10.3 2 22 18 -11.7 + +# Drop the functions +statement ok +drop function add; + +statement ok +drop function sub; + +statement ok +drop function add_sub_binding; + +statement ok +drop function add_sub; + +statement ok +drop function add_sub_return; + +statement ok +drop function add_return; + +statement ok +drop function add_return_binding; + +statement ok +drop function call_regexp_replace; + +statement ok +drop function add_sub_wrapper; + +# Drop the mock table +statement ok +drop table t1; + +statement ok +drop table t2; diff --git a/proto/catalog.proto b/proto/catalog.proto index 01a7893383232..ec7c68a3802ba 100644 --- a/proto/catalog.proto +++ b/proto/catalog.proto @@ -218,6 +218,7 @@ message Function { string language = 7; string link = 8; string identifier = 10; + optional string body = 14; oneof kind { ScalarFunction scalar = 11; diff --git a/src/frontend/src/binder/expr/column.rs b/src/frontend/src/binder/expr/column.rs index 16053208ec8d3..2f2a8d9335256 100644 --- a/src/frontend/src/binder/expr/column.rs +++ b/src/frontend/src/binder/expr/column.rs @@ -37,6 +37,18 @@ impl Binder { } }; + // Special check for sql udf + // Note: The check in `bind_column` is to inline the identifiers, + // which, in the context of sql udf, will NOT be perceived as normal + // columns, but the actual named input parameters. + // Thus, we need to figure out if the current "column name" corresponds + // to the name of the defined sql udf parameters stored in `udf_context`. + // If so, we will treat this bind as an special bind, the actual expression + // stored in `udf_context` will then be bound instead of binding the non-existing column. + if let Some(expr) = self.udf_context.get(&column_name) { + return self.bind_expr(expr.clone()); + } + match self .context .get_column_binding_indices(&table_name, &column_name) diff --git a/src/frontend/src/binder/expr/function.rs b/src/frontend/src/binder/expr/function.rs index c8558b0756b5d..7244a8527f857 100644 --- a/src/frontend/src/binder/expr/function.rs +++ b/src/frontend/src/binder/expr/function.rs @@ -15,7 +15,7 @@ use std::collections::HashMap; use std::iter::once; use std::str::FromStr; -use std::sync::LazyLock; +use std::sync::{Arc, LazyLock}; use bk_tree::{metrics, BKTree}; use itertools::Itertools; @@ -30,13 +30,15 @@ use risingwave_expr::window_function::{ Frame, FrameBound, FrameBounds, FrameExclusion, WindowFuncKind, }; use risingwave_sqlparser::ast::{ - self, Function, FunctionArg, FunctionArgExpr, Ident, WindowFrameBound, WindowFrameExclusion, - WindowFrameUnits, WindowSpec, + self, Expr as AstExpr, Function, FunctionArg, FunctionArgExpr, Ident, SelectItem, SetExpr, + Statement, WindowFrameBound, WindowFrameExclusion, WindowFrameUnits, WindowSpec, }; +use risingwave_sqlparser::parser::ParserError; use thiserror_ext::AsReport; use crate::binder::bind_context::Clause; use crate::binder::{Binder, BoundQuery, BoundSetExpr}; +use crate::catalog::function_catalog::FunctionCatalog; use crate::expr::{ AggCall, Expr, ExprImpl, ExprType, FunctionCall, FunctionCallWithLambda, Literal, Now, OrderBy, Subquery, SubqueryKind, TableFunction, TableFunctionType, UserDefinedFunction, WindowFunction, @@ -117,6 +119,9 @@ impl Binder { return self.bind_array_transform(f); } + // Used later in sql udf expression evaluation + let args = f.args.clone(); + let inputs = f .args .into_iter() @@ -149,6 +154,73 @@ impl Binder { return Ok(TableFunction::new(function_type, inputs)?.into()); } + /// TODO: add name related logic + /// NOTE: need to think of a way to prevent naming conflict + /// e.g., when existing column names conflict with parameter names in sql udf + fn create_udf_context( + args: &[FunctionArg], + _catalog: &Arc, + ) -> Result> { + let mut ret: HashMap = HashMap::new(); + for (i, current_arg) in args.iter().enumerate() { + if let FunctionArg::Unnamed(arg) = current_arg { + let FunctionArgExpr::Expr(e) = arg else { + return Err( + ErrorCode::InvalidInputSyntax("invalid syntax".to_string()).into() + ); + }; + // if catalog.arg_names.is_some() { + // todo!() + // } + ret.insert(format!("${}", i + 1), e.clone()); + continue; + } + return Err(ErrorCode::InvalidInputSyntax("invalid syntax".to_string()).into()); + } + Ok(ret) + } + + fn extract_udf_expression(ast: Vec) -> Result { + if ast.len() != 1 { + return Err(ErrorCode::InvalidInputSyntax( + "the query for sql udf should contain only one statement".to_string(), + ) + .into()); + } + + // Extract the expression out + let Statement::Query(query) = ast[0].clone() else { + return Err(ErrorCode::InvalidInputSyntax( + "invalid function definition, please recheck the syntax".to_string(), + ) + .into()); + }; + + let SetExpr::Select(select) = query.body else { + return Err(ErrorCode::InvalidInputSyntax( + "missing `select` body for sql udf expression, please recheck the syntax" + .to_string(), + ) + .into()); + }; + + if select.projection.len() != 1 { + return Err(ErrorCode::InvalidInputSyntax( + "`projection` should contain only one `SelectItem`".to_string(), + ) + .into()); + } + + let SelectItem::UnnamedExpr(expr) = select.projection[0].clone() else { + return Err(ErrorCode::InvalidInputSyntax( + "expect `UnnamedExpr` for `projection`".to_string(), + ) + .into()); + }; + + Ok(expr) + } + // user defined function // TODO: resolve schema name https://github.com/risingwavelabs/risingwave/issues/12422 if let Ok(schema) = self.first_valid_schema() @@ -158,13 +230,89 @@ impl Binder { ) { use crate::catalog::function_catalog::FunctionKind::*; - match &func.kind { - Scalar { .. } => return Ok(UserDefinedFunction::new(func.clone(), inputs).into()), - Table { .. } => { - self.ensure_table_function_allowed()?; - return Ok(TableFunction::new_user_defined(func.clone(), inputs).into()); + if func.language == "sql" { + if func.body.is_none() { + return Err(ErrorCode::InvalidInputSyntax( + "`body` must exist for sql udf".to_string(), + ) + .into()); + } + // This represents the current user defined function is `language sql` + let parse_result = risingwave_sqlparser::parser::Parser::parse_sql( + func.body.as_ref().unwrap().as_str(), + ); + if let Err(ParserError::ParserError(err)) | Err(ParserError::TokenizerError(err)) = + parse_result + { + // Here we just return the original parse error message + return Err(ErrorCode::InvalidInputSyntax(err).into()); + } + debug_assert!(parse_result.is_ok()); + + // We can safely unwrap here + let ast = parse_result.unwrap(); + + let mut clean_flag = true; + + // We need to check if the `udf_context` is empty first, consider the following example: + // - create function add(INT, INT) returns int language sql as 'select $1 + $2'; + // - create function add_wrapper(INT, INT) returns int language sql as 'select add($1, $2)'; + // - select add_wrapper(1, 1); + // When binding `add($1, $2)` in `add_wrapper`, the input args are [$1, $2] instead of + // the original [1, 1], thus we need to check `udf_context` to see if the input + // args already exist in the context. If so, we do NOT need to create the context again. + // Otherwise the current `udf_context` will be corrupted. + if self.udf_context.is_empty() { + // The actual inline logic for sql udf + if let Ok(context) = create_udf_context(&args, &Arc::clone(func)) { + self.udf_context = context; + } else { + return Err(ErrorCode::InvalidInputSyntax( + "failed to create the `udf_context`, please recheck your function definition and syntax".to_string() + ) + .into()); + } + } else { + // If the `udf_context` is not empty, this means the current binding + // function is not the root binding sql udf, thus we should NOT + // clean the context after binding. + clean_flag = false; + } + + if let Ok(expr) = extract_udf_expression(ast) { + let bind_result = self.bind_expr(expr); + // Clean the `udf_context` after inlining, + // which makes sure the subsequent binding will not be affected + if clean_flag { + self.udf_context.clear(); + } + return bind_result; + } else { + return Err(ErrorCode::InvalidInputSyntax( + "failed to parse the input query and extract the udf expression, + please recheck the syntax" + .to_string(), + ) + .into()); + } + } else { + // Note that `language` may be empty for external udf + if !func.language.is_empty() { + debug_assert!( + func.language == "python" || func.language == "java", + "only `python` and `java` are currently supported for general udf" + ); + } + match &func.kind { + Scalar { .. } => { + return Ok(UserDefinedFunction::new(func.clone(), inputs).into()) + } + Table { .. } => { + self.ensure_table_function_allowed()?; + return Ok(TableFunction::new_user_defined(func.clone(), inputs).into()); + } + Aggregate => todo!("support UDAF"), } - Aggregate => todo!("support UDAF"), } } @@ -1213,7 +1361,7 @@ impl Binder { static FUNCTIONS_BKTREE: LazyLock> = LazyLock::new(|| { let mut tree = BKTree::new(metrics::Levenshtein); - // TODO: Also hint other functinos, e,g, Agg or UDF. + // TODO: Also hint other functinos, e.g., Agg or UDF. for k in HANDLES.keys() { tree.add(*k); } diff --git a/src/frontend/src/binder/expr/mod.rs b/src/frontend/src/binder/expr/mod.rs index 179469e545c2b..cacd2d80dcfe4 100644 --- a/src/frontend/src/binder/expr/mod.rs +++ b/src/frontend/src/binder/expr/mod.rs @@ -378,6 +378,14 @@ impl Binder { } fn bind_parameter(&mut self, index: u64) -> Result { + // Special check for sql udf + // Note: This is specific to anonymous sql udf, since the + // parameters will be parsed and treated as `Parameter`. + // For detailed explanation, consider checking `bind_column`. + if let Some(expr) = self.udf_context.get(&format!("${index}")) { + return self.bind_expr(expr.clone()); + } + Ok(Parameter::new(index, self.param_types.clone()).into()) } diff --git a/src/frontend/src/binder/mod.rs b/src/frontend/src/binder/mod.rs index be0f441c0743a..6ba891aa6b513 100644 --- a/src/frontend/src/binder/mod.rs +++ b/src/frontend/src/binder/mod.rs @@ -21,7 +21,7 @@ use risingwave_common::error::Result; use risingwave_common::session_config::{ConfigMap, SearchPath}; use risingwave_common::types::DataType; use risingwave_common::util::iter_util::ZipEqDebug; -use risingwave_sqlparser::ast::Statement; +use risingwave_sqlparser::ast::{Expr as AstExpr, Statement}; mod bind_context; mod bind_param; @@ -115,6 +115,10 @@ pub struct Binder { included_relations: HashSet, param_types: ParameterTypes, + + /// The mapping from sql udf parameters to ast expressions + /// Note: The expressions are constructed during runtime, correspond to the actual users' input + udf_context: HashMap, } /// `ParameterTypes` is used to record the types of the parameters during binding. It works @@ -216,6 +220,7 @@ impl Binder { shared_views: HashMap::new(), included_relations: HashSet::new(), param_types: ParameterTypes::new(param_types), + udf_context: HashMap::new(), } } diff --git a/src/frontend/src/catalog/function_catalog.rs b/src/frontend/src/catalog/function_catalog.rs index 7197821b33ce6..d0f037bcb47b5 100644 --- a/src/frontend/src/catalog/function_catalog.rs +++ b/src/frontend/src/catalog/function_catalog.rs @@ -30,6 +30,7 @@ pub struct FunctionCatalog { pub return_type: DataType, pub language: String, pub identifier: String, + pub body: Option, pub link: String, } @@ -63,6 +64,7 @@ impl From<&PbFunction> for FunctionCatalog { return_type: prost.return_type.as_ref().expect("no return type").into(), language: prost.language.clone(), identifier: prost.identifier.clone(), + body: prost.body.clone(), link: prost.link.clone(), } } diff --git a/src/frontend/src/catalog/root_catalog.rs b/src/frontend/src/catalog/root_catalog.rs index bd411577476ba..9d2045f5dc61f 100644 --- a/src/frontend/src/catalog/root_catalog.rs +++ b/src/frontend/src/catalog/root_catalog.rs @@ -88,7 +88,7 @@ impl<'a> SchemaPath<'a> { /// - catalog (root catalog) /// - database catalog /// - schema catalog -/// - function catalog +/// - function catalog (i.e., user defined function) /// - table/sink/source/index/view catalog /// - column catalog pub struct Catalog { diff --git a/src/frontend/src/expr/user_defined_function.rs b/src/frontend/src/expr/user_defined_function.rs index abd39fdbbc0c4..165774d1acb4b 100644 --- a/src/frontend/src/expr/user_defined_function.rs +++ b/src/frontend/src/expr/user_defined_function.rs @@ -55,6 +55,8 @@ impl UserDefinedFunction { return_type, language: udf.get_language().clone(), identifier: udf.get_identifier().clone(), + // TODO: Ensure if we need `body` here + body: None, link: udf.get_link().clone(), }; diff --git a/src/frontend/src/handler/create_function.rs b/src/frontend/src/handler/create_function.rs index c623b67eecb60..4557b71223b98 100644 --- a/src/frontend/src/handler/create_function.rs +++ b/src/frontend/src/handler/create_function.rs @@ -44,6 +44,7 @@ pub async fn handle_create_function( if temporary { bail_not_implemented!("CREATE TEMPORARY FUNCTION"); } + // e.g., `language [ python / java / ...etc]` let language = match params.language { Some(lang) => { let lang = lang.real_value().to_lowercase(); @@ -161,6 +162,7 @@ pub async fn handle_create_function( return_type: Some(return_type.into()), language, identifier, + body: None, link, owner: session.user_id(), }; diff --git a/src/frontend/src/handler/create_sql_function.rs b/src/frontend/src/handler/create_sql_function.rs new file mode 100644 index 0000000000000..834e0bec3135d --- /dev/null +++ b/src/frontend/src/handler/create_sql_function.rs @@ -0,0 +1,181 @@ +// Copyright 2024 RisingWave Labs +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use itertools::Itertools; +use pgwire::pg_response::StatementType; +use risingwave_common::catalog::FunctionId; +use risingwave_common::types::DataType; +use risingwave_pb::catalog::function::{Kind, ScalarFunction, TableFunction}; +use risingwave_pb::catalog::Function; +use risingwave_sqlparser::ast::{ + CreateFunctionBody, FunctionDefinition, ObjectName, OperateFunctionArg, +}; +use risingwave_sqlparser::parser::{Parser, ParserError}; + +use super::*; +use crate::catalog::CatalogError; +use crate::{bind_data_type, Binder}; + +pub async fn handle_create_sql_function( + handler_args: HandlerArgs, + or_replace: bool, + temporary: bool, + name: ObjectName, + args: Option>, + returns: Option, + params: CreateFunctionBody, +) -> Result { + if or_replace { + bail_not_implemented!("CREATE OR REPLACE FUNCTION"); + } + + if temporary { + bail_not_implemented!("CREATE TEMPORARY FUNCTION"); + } + + let language = "sql".to_string(); + // Just a basic sanity check for language + if !matches!(params.language, Some(lang) if lang.real_value().to_lowercase() == "sql") { + return Err(ErrorCode::InvalidParameterValue( + "`language` for sql udf must be `sql`".to_string(), + ) + .into()); + } + + // SQL udf function supports both single quote (i.e., as 'select $1 + $2') + // and double dollar (i.e., as $$select $1 + $2$$) for as clause + let body = match ¶ms.as_ { + Some(FunctionDefinition::SingleQuotedDef(s)) => s.clone(), + Some(FunctionDefinition::DoubleDollarDef(s)) => s.clone(), + None => { + if params.return_.is_none() { + return Err(ErrorCode::InvalidParameterValue( + "AS or RETURN must be specified".to_string(), + ) + .into()); + } + // Otherwise this is a return expression + // Note: this is a current work around, and we are assuming return sql udf + // will NOT involve complex syntax, so just reuse the logic for select definition + format!("select {}", ¶ms.return_.unwrap().to_string()) + } + }; + + // We do NOT allow recursive calling inside sql udf + // Since there does not exist the base case for this definition + if body.contains(format!("{}(", name.real_value()).as_str()) { + return Err(ErrorCode::InvalidInputSyntax( + "recursive definition is forbidden, please recheck your function syntax".to_string(), + ) + .into()); + } + + // Sanity check for link, this must be none with sql udf function + if let Some(CreateFunctionUsing::Link(_)) = params.using { + return Err(ErrorCode::InvalidParameterValue( + "USING must NOT be specified with sql udf function".to_string(), + ) + .into()); + }; + + // Get return type for the current sql udf function + let return_type; + let kind = match returns { + Some(CreateFunctionReturns::Value(data_type)) => { + return_type = bind_data_type(&data_type)?; + Kind::Scalar(ScalarFunction {}) + } + Some(CreateFunctionReturns::Table(columns)) => { + if columns.len() == 1 { + // return type is the original type for single column + return_type = bind_data_type(&columns[0].data_type)?; + } else { + // return type is a struct for multiple columns + let datatypes = columns + .iter() + .map(|c| bind_data_type(&c.data_type)) + .collect::>>()?; + let names = columns + .iter() + .map(|c| c.name.real_value()) + .collect::>(); + return_type = DataType::new_struct(datatypes, names); + } + Kind::Table(TableFunction {}) + } + None => { + return Err(ErrorCode::InvalidParameterValue( + "return type must be specified".to_string(), + ) + .into()) + } + }; + + let mut arg_types = vec![]; + for arg in args.unwrap_or_default() { + arg_types.push(bind_data_type(&arg.data_type)?); + } + + // resolve database and schema id + let session = &handler_args.session; + let db_name = session.database(); + let (schema_name, function_name) = Binder::resolve_schema_qualified_name(db_name, name)?; + let (database_id, schema_id) = session.get_database_and_schema_id_for_create(schema_name)?; + + // check if function exists + if (session.env().catalog_reader().read_guard()) + .get_schema_by_id(&database_id, &schema_id)? + .get_function_by_name_args(&function_name, &arg_types) + .is_some() + { + let name = format!( + "{function_name}({})", + arg_types.iter().map(|t| t.to_string()).join(",") + ); + return Err(CatalogError::Duplicated("function", name).into()); + } + + // Parse function body here + // Note that the parsing here is just basic syntax / semantic check, the result will NOT be stored + // e.g., The provided function body contains invalid syntax, return type mismatch, ..., etc. + let parse_result = Parser::parse_sql(body.as_str()); + if let Err(ParserError::ParserError(err)) | Err(ParserError::TokenizerError(err)) = parse_result + { + // Here we just return the original parse error message + return Err(ErrorCode::InvalidInputSyntax(err).into()); + } else { + debug_assert!(parse_result.is_ok()); + } + + // Create the actual function, will be stored in function catalog + let function = Function { + id: FunctionId::placeholder().0, + schema_id, + database_id, + name: function_name, + kind: Some(kind), + arg_types: arg_types.into_iter().map(|t| t.into()).collect(), + return_type: Some(return_type.into()), + language, + identifier: "".to_string(), + body: Some(body), + link: "".to_string(), + owner: session.user_id(), + }; + + let catalog_writer = session.catalog_writer()?; + catalog_writer.create_function(function).await?; + + Ok(PgResponse::empty_result(StatementType::CREATE_FUNCTION)) +} diff --git a/src/frontend/src/handler/mod.rs b/src/frontend/src/handler/mod.rs index d2ec467df424f..495d9ebd3729a 100644 --- a/src/frontend/src/handler/mod.rs +++ b/src/frontend/src/handler/mod.rs @@ -53,6 +53,7 @@ pub mod create_mv; pub mod create_schema; pub mod create_sink; pub mod create_source; +pub mod create_sql_function; pub mod create_table; pub mod create_table_as; pub mod create_user; @@ -205,16 +206,39 @@ pub async fn handle( returns, params, } => { - create_function::handle_create_function( - handler_args, - or_replace, - temporary, - name, - args, - returns, - params, - ) - .await + // For general udf, `language` clause could be ignored + // refer: https://github.com/risingwavelabs/risingwave/pull/10608 + if params.language.is_none() + || !params + .language + .as_ref() + .unwrap() + .real_value() + .eq_ignore_ascii_case("sql") + { + // User defined function with external source (e.g., language [ python / java ]) + create_function::handle_create_function( + handler_args, + or_replace, + temporary, + name, + args, + returns, + params, + ) + .await + } else { + create_sql_function::handle_create_sql_function( + handler_args, + or_replace, + temporary, + name, + args, + returns, + params, + ) + .await + } } Statement::CreateTable { name, diff --git a/src/meta/model_v2/migration/src/m20230908_072257_init.rs b/src/meta/model_v2/migration/src/m20230908_072257_init.rs index 5b3d55ab83bfb..bc9ce2b08c32b 100644 --- a/src/meta/model_v2/migration/src/m20230908_072257_init.rs +++ b/src/meta/model_v2/migration/src/m20230908_072257_init.rs @@ -708,6 +708,7 @@ impl MigrationTrait for Migration { .col(ColumnDef::new(Function::Language).string().not_null()) .col(ColumnDef::new(Function::Link).string().not_null()) .col(ColumnDef::new(Function::Identifier).string().not_null()) + .col(ColumnDef::new(Function::Body).string()) .col(ColumnDef::new(Function::Kind).string().not_null()) .foreign_key( &mut ForeignKey::create() @@ -1099,6 +1100,7 @@ enum Function { Language, Link, Identifier, + Body, Kind, } diff --git a/src/meta/model_v2/src/function.rs b/src/meta/model_v2/src/function.rs index 71391a3cc27b0..5976685893afb 100644 --- a/src/meta/model_v2/src/function.rs +++ b/src/meta/model_v2/src/function.rs @@ -41,6 +41,7 @@ pub struct Model { pub language: String, pub link: String, pub identifier: String, + pub body: Option, pub kind: FunctionKind, } @@ -94,6 +95,7 @@ impl From for ActiveModel { language: Set(function.language), link: Set(function.link), identifier: Set(function.identifier), + body: Set(function.body), kind: Set(function.kind.unwrap().into()), } } diff --git a/src/meta/src/controller/mod.rs b/src/meta/src/controller/mod.rs index d6d6891a12e8d..037f9e3417163 100644 --- a/src/meta/src/controller/mod.rs +++ b/src/meta/src/controller/mod.rs @@ -277,6 +277,7 @@ impl From> for PbFunction { language: value.0.language, link: value.0.link, identifier: value.0.identifier, + body: value.0.body, kind: Some(value.0.kind.into()), } } diff --git a/src/sqlparser/README.md b/src/sqlparser/README.md index 20a5ac0c8e6ab..2a829102710ba 100644 --- a/src/sqlparser/README.md +++ b/src/sqlparser/README.md @@ -4,5 +4,5 @@ This parser is a fork of . ## Add a new test case -1. Copy an item in the yaml file and edit the `input` to the sql you want to test -2. Run `./risedev do-apply-parser-test` to regenerate the `formatted_sql` whicih is the expected output \ No newline at end of file +1. Copy an item in the yaml file and edit the `input` to the sql you want to test. +2. Run `./risedev do-apply-parser-test` to regenerate the `formatted_sql` which is the expected output. \ No newline at end of file diff --git a/src/sqlparser/tests/sqlparser_postgres.rs b/src/sqlparser/tests/sqlparser_postgres.rs index 2d97834ad23b5..99e2c185fdcff 100644 --- a/src/sqlparser/tests/sqlparser_postgres.rs +++ b/src/sqlparser/tests/sqlparser_postgres.rs @@ -768,6 +768,53 @@ fn parse_create_function() { } ); + let sql = "CREATE FUNCTION sub(INT, INT) RETURNS INT LANGUAGE SQL AS $$select $1 - $2;$$"; + assert_eq!( + verified_stmt(sql), + Statement::CreateFunction { + or_replace: false, + temporary: false, + name: ObjectName(vec![Ident::new_unchecked("sub")]), + args: Some(vec![ + OperateFunctionArg::unnamed(DataType::Int), + OperateFunctionArg::unnamed(DataType::Int), + ]), + returns: Some(CreateFunctionReturns::Value(DataType::Int)), + params: CreateFunctionBody { + language: Some("SQL".into()), + as_: Some(FunctionDefinition::DoubleDollarDef( + "select $1 - $2;".into() + )), + ..Default::default() + } + }, + ); + + // Anonymous return sql udf parsing test + let sql = "CREATE FUNCTION return_test(INT, INT) RETURNS INT LANGUAGE SQL RETURN $1 + $2"; + assert_eq!( + verified_stmt(sql), + Statement::CreateFunction { + or_replace: false, + temporary: false, + name: ObjectName(vec![Ident::new_unchecked("return_test")]), + args: Some(vec![ + OperateFunctionArg::unnamed(DataType::Int), + OperateFunctionArg::unnamed(DataType::Int), + ]), + returns: Some(CreateFunctionReturns::Value(DataType::Int)), + params: CreateFunctionBody { + language: Some("SQL".into()), + return_: Some(Expr::BinaryOp { + left: Box::new(Expr::Parameter { index: 1 }), + op: BinaryOperator::Plus, + right: Box::new(Expr::Parameter { index: 2 }), + }), + ..Default::default() + } + }, + ); + let sql = "CREATE OR REPLACE FUNCTION add(a INT, IN b INT = 1) RETURNS INT LANGUAGE SQL IMMUTABLE RETURN a + b"; assert_eq!( verified_stmt(sql), From dc0402aefdd9b4a182617d3650326bb7c513d9a9 Mon Sep 17 00:00:00 2001 From: Bugen Zhao Date: Fri, 5 Jan 2024 17:27:20 +0800 Subject: [PATCH 02/20] chore(dashboard): cleanup dead code for old dashboard (#14386) Signed-off-by: Bugen Zhao --- dashboard/components/StatusLamp.js | 29 - dashboard/lib/color.js | 90 -- dashboard/lib/graaphEngine/canvasEngine.js | 598 ----------- dashboard/lib/graaphEngine/svgEngine.js | 198 ---- dashboard/lib/str.js | 24 - dashboard/lib/streamPlan/parser.js | 428 -------- dashboard/lib/streamPlan/streamChartHelper.js | 945 ------------------ dashboard/test/algo.test.js | 109 -- 8 files changed, 2421 deletions(-) delete mode 100644 dashboard/components/StatusLamp.js delete mode 100644 dashboard/lib/color.js delete mode 100644 dashboard/lib/graaphEngine/canvasEngine.js delete mode 100644 dashboard/lib/graaphEngine/svgEngine.js delete mode 100644 dashboard/lib/str.js delete mode 100644 dashboard/lib/streamPlan/parser.js delete mode 100644 dashboard/lib/streamPlan/streamChartHelper.js delete mode 100644 dashboard/test/algo.test.js diff --git a/dashboard/components/StatusLamp.js b/dashboard/components/StatusLamp.js deleted file mode 100644 index 6cadd9e2e7764..0000000000000 --- a/dashboard/components/StatusLamp.js +++ /dev/null @@ -1,29 +0,0 @@ -/* - * Copyright 2024 RisingWave Labs - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * - */ -export default function StatusLamp(props) { - return ( -
- ) -} diff --git a/dashboard/lib/color.js b/dashboard/lib/color.js deleted file mode 100644 index 8b6d7d0fccd6a..0000000000000 --- a/dashboard/lib/color.js +++ /dev/null @@ -1,90 +0,0 @@ -/* - * Copyright 2024 RisingWave Labs - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * - */ -const two = [ - ["#6FE0D3", "#E09370"], - ["#75D0E0", "#E0A575"], - ["#75BAE0", "#E0B175"], - ["#77A5E0", "#E0BC77"], - ["#768CE0", "#E0C577"], - ["#7575E0", "#E0CD75"], - ["#A978E0", "#E0DF79"], - ["#C977E0", "#B9E077"], - ["#E072D7", "#92E072"], - ["#E069A4", "#69E069"], - ["#E06469", "#65E086"], - ["#E07860", "#60E0B2"], - ["#E08159", "#5AE0CE"], - ["#E09C5C", "#5CC5E0"], - ["#E0B763", "#6395E0"], - ["#E0CE5A", "#6B5AE0"], - ["#C8E051", "#AA51E0"], - ["#92E06F", "#E070DB"], - ["#79E085", "#E07998"], - ["#80E0B1", "#E08C80"], - ["#91DBE0", "#E0B292"], -] - -const twoGradient = [ - ["#1976d2", "#a6c9ff"], - ["#FFF38A", "#E0D463"], - ["#A3ACFF", "#7983DF"], - ["#A6C9FF", "#7BA3DF"], - ["#FFBE8C", "#E09B65"], - ["#FFD885", "#E0B65D"], - ["#9EE2FF", "#73BEDF"], - ["#DAFF8F", "#B8E066"], - ["#FFC885", "#E0A65D"], - ["#9EFCFF", "#74DCDF"], - ["#FBFF8C", "#DBE065"], - ["#9CFFDE", "#71DFBB"], - ["#FFAF91", "#E08869"], - ["#B699FF", "#9071E0"], - ["#9EFFB6", "#74DF8F"], - ["#FFA19C", "#E07872"], - ["#AEFF9C", "#85DF71"], - ["#FF96B9", "#E06D94"], - ["#FFE785", "#E0C75F"], - ["#FF94FB", "#E06BDC"], - ["#DA99FF", "#B66FE0"], - ["#8F93FF", "#666AE0"], -] - -const five = [ - ["#A8936C", "#F5D190", "#8B84F5", "#9AA84A", "#E1F578"], - ["#A87C6A", "#F5AB8E", "#82CBF5", "#A89348", "#F5D876"], - ["#A87490", "#F59DCB", "#90F5C7", "#A87752", "#F5B584"], - ["#856FA8", "#B995F5", "#BAF58A", "#A84D5B", "#F57D8E"], - ["#7783A8", "#A2B4F5", "#F5EE95", "#9C56A8", "#E589F5"], - ["#74A895", "#9DF5D4", "#F5BF91", "#526CA8", "#84A6F5"], - ["#74A878", "#9DF5A3", "#F5A290", "#5298A8", "#84DFF5"], - ["#94A877", "#D2F5A2", "#F596B6", "#56A88C", "#89F5D0"], - ["#A8A072", "#F5E79A", "#CD8DF5", "#5DA851", "#92F582"], - ["#A89176", "#F5CD9F", "#92A3F5", "#A8A554", "#F5F087"], - ["#A8726A", "#F59B8E", "#83ECF5", "#A88948", "#F5CB76"], -] -export function TwoColor(index) { - return two[index % two.length] -} - -export function FiveColor(index) { - return five[index % five.length] -} - -let s = Math.random() * 100 -export function TwoGradient(index) { - return twoGradient[(Math.round(s) + index) % two.length] -} diff --git a/dashboard/lib/graaphEngine/canvasEngine.js b/dashboard/lib/graaphEngine/canvasEngine.js deleted file mode 100644 index df661b53c5ec6..0000000000000 --- a/dashboard/lib/graaphEngine/canvasEngine.js +++ /dev/null @@ -1,598 +0,0 @@ -/* - * Copyright 2024 RisingWave Labs - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * - */ -import { fabric } from "fabric" - -// Disable cache to improve performance. -fabric.Object.prototype.objectCaching = false -fabric.Object.prototype.statefullCache = false -fabric.Object.prototype.noScaleCache = true -fabric.Object.prototype.needsItsOwnCache = () => false - -export class DrawElement { - /** - * @param {{svgElement: d3.Selection}} props - */ - constructor(props) { - /** - * @type {{svgElement: d3.Selection}} - */ - this.props = props - if (props.canvasElement) { - props.engine.canvas.add(props.canvasElement) - props.canvasElement.on("mouse:down", (e) => { - console.log(e) - }) - } - - this.eventHandler = new Map() - } - - // TODO: this method is for migrating from d3.js to fabric.js. - // This should be replaced by a more suitable way. - _attrMap(key, value) { - return [key, value] - } - - // TODO: this method is for migrating from d3.js to fabric.js. - // This should be replaced by a more suitable way. - attr(key, value) { - let setting = this._attrMap(key, value) - if (setting && setting.length === 2) { - this.props.canvasElement && - this.props.canvasElement.set(setting[0], setting[1]) - } - return this - } - - _afterPosition() { - let ele = this.props.canvasElement - ele && this.props.engine._addDrawElement(this) - } - - // TODO: this method is for migrating from d3.js to fabric.js. - // This should be replaced by a more suitable way. - position(x, y) { - this.props.canvasElement.set("left", x) - this.props.canvasElement.set("top", y) - this._afterPosition() - return this - } - - on(event, callback) { - this.eventHandler.set(event, callback) - return this - } - - getEventHandler(event) { - return this.eventHandler.get(event) - } - - // TODO: this method is for migrating from d3.js to fabric.js. - // This should be replaced by a more suitable way. - style(key, value) { - return this.attr(key, value) - } - - classed(clazz, flag) { - this.props.engine.classedElement(clazz, this, flag) - return this - } -} - -export class Group extends DrawElement { - /** - * @param {{engine: CanvasEngine}} props - */ - constructor(props) { - super(props) - - this.appendFunc = { - g: this._appendGroup, - circle: this._appendCircle, - rect: this._appendRect, - text: this._appendText, - path: this._appendPath, - polygon: this._appendPolygan, - } - - this.basicSetting = { - engine: props.engine, - } - } - - _appendGroup = () => { - return new Group(this.basicSetting) - } - - _appendCircle = () => { - return new Circle({ - ...this.basicSetting, - ...{ - canvasElement: new fabric.Circle({ - selectable: false, - hoverCursor: "pointer", - }), - }, - }) - } - - _appendRect = () => { - return new Rectangle({ - ...this.basicSetting, - ...{ - canvasElement: new fabric.Rect({ - selectable: false, - hoverCursor: "pointer", - }), - }, - }) - } - - _appendText = () => { - return (content) => - new Text({ - ...this.basicSetting, - ...{ - canvasElement: new fabric.Text(content || "undefined", { - selectable: false, - textAlign: "justify-center", - }), - }, - }) - } - - _appendPath = () => { - return (d) => - new Path({ - ...this.basicSetting, - ...{ - canvasElement: new fabric.Path(d, { selectable: false }), - }, - }) - } - - _appendPolygan = () => { - return new Polygan(this.basicSetting) - } - - append = (type) => { - return this.appendFunc[type]() - } -} - -export class Rectangle extends DrawElement { - /** - * @param {{g: fabric.Group}} props - */ - constructor(props) { - super(props) - this.props = props - } - - init(x, y, width, height) { - let ele = this.props.canvasElement - ele.set("left", x) - ele.set("top", y) - ele.set("width", width) - ele.set("height", height) - super._afterPosition() - return this - } - - _attrMap(key, value) { - if (key === "rx") { - this.props.canvasElement.set("rx", value) - this.props.canvasElement.set("ry", value) - return false - } - return [key, value] - } -} - -export class Circle extends DrawElement { - /** - * @param {{svgElement: d3.Selection}} props - */ - constructor(props) { - super(props) - this.props = props - this.radius = 0 - } - - init(x, y, r) { - this.props.canvasElement.set("left", x - r) - this.props.canvasElement.set("top", y - r) - this.props.canvasElement.set("radius", r) - super._afterPosition() - return this - } - - _attrMap(key, value) { - if (key === "r") { - this.radius = value - return ["radius", value] - } - if (key === "cx") { - return ["left", value - this.radius] - } - if (key === "cy") { - return ["top", value - this.radius] - } - return [key, value] - } -} - -export class Text extends DrawElement { - /** - * @param {{svgElement: d3.Selection, any, null, undefined>}} props - */ - constructor(props) { - super(props) - this.props = props - } - - position(x, y) { - let e = this.props.canvasElement - e.set("top", y) - e.set("left", x) - super._afterPosition() - return this - } - - _attrMap(key, value) { - if (key === "text-anchor") { - return ["textAlign", value] - } - if (key === "font-size") { - return ["fontSize", value] - } - return [key, value] - } - - text(content) { - return this - } - - getWidth() {} -} - -export class Polygan extends DrawElement { - constructor(props) { - super(props) - this.props = props - } -} - -export class Path extends DrawElement { - constructor(props) { - super(props) - this.props = props - this.strokeWidth = 1 - super._afterPosition() - } - - _attrMap(key, value) { - if (key === "fill") { - return ["fill", value === "none" ? false : value] - } - if (key === "stroke-width") { - this.props.canvasElement.set( - "top", - this.props.canvasElement.get("top") - value / 2 - ) - return ["strokeWidth", value] - } - if (key === "stroke-dasharray") { - return ["strokeDashArray", value.split(",")] - } - if (key === "layer") { - if (value === "back") { - this.props.canvasElement.canvas.sendToBack(this.props.canvasElement) - } - return false - } - return [key, value] - } -} - -// TODO: Use rbtree -class CordMapper { - constructor() { - this.map = new Map() - } - - rangeQuery(start, end) { - let rtn = new Set() - for (let [k, s] of this.map.entries()) { - if (start <= k && k <= end) { - s.forEach((v) => rtn.add(v)) - } - } - return rtn - } - - insert(k, v) { - if (this.map.has(k)) { - this.map.get(k).add(v) - } else { - this.map.set(k, new Set([v])) - } - } -} - -class GridMapper { - constructor() { - this.xMap = new CordMapper() - this.yMap = new CordMapper() - this.gs = 100 // grid size - } - - _getKey(value) { - return Math.round(value / this.gs) - } - - addObject(minX, maxX, minY, maxY, ele) { - for (let i = minX; i <= maxX + this.gs; i += this.gs) { - this.xMap.insert(this._getKey(i), ele) - } - for (let i = minY; i <= maxY + this.gs; i += this.gs) { - this.yMap.insert(this._getKey(i), ele) - } - } - - areaQuery(minX, maxX, minY, maxY) { - let xs = this.xMap.rangeQuery(this._getKey(minX), this._getKey(maxX)) - let ys = this.yMap.rangeQuery(this._getKey(minY), this._getKey(maxY)) - let rtn = new Set() - xs.forEach((e) => { - if (ys.has(e)) { - rtn.add(e) - } - }) - return rtn - } -} - -export class CanvasEngine { - /** - * @param {string} canvasId The DOM id of the canvas - * @param {number} height the height of the canvas - * @param {number} width the width of the canvas - */ - constructor(canvasId, height, width) { - let canvas = new fabric.Canvas(canvasId) - canvas.selection = false // improve performance - - this.height = height - this.width = width - this.canvas = canvas - this.clazzMap = new Map() - this.topGroup = new Group({ engine: this }) - this.gridMapper = new GridMapper() - this.canvasElementToDrawElement = new Map() - - let that = this - canvas.on("mouse:wheel", function (opt) { - var evt = opt.e - if (evt.ctrlKey === true) { - var delta = opt.e.deltaY - var zoom = canvas.getZoom() - zoom *= 0.999 ** delta - if (zoom > 10) zoom = 10 - if (zoom < 0.03) zoom = 0.03 - canvas.zoomToPoint({ x: opt.e.offsetX, y: opt.e.offsetY }, zoom) - that._refreshView() - evt.preventDefault() - evt.stopPropagation() - } else { - that.moveCamera(-evt.deltaX, -evt.deltaY) - evt.preventDefault() - evt.stopPropagation() - } - }) - - canvas.on("mouse:down", function (opt) { - var evt = opt.e - this.isDragging = true - this.selection = false - this.lastPosX = evt.clientX - this.lastPosY = evt.clientY - - that._handleClickEvent(opt.target) - }) - - canvas.on("mouse:move", function (opt) { - if (this.isDragging) { - var e = opt.e - that.moveCamera(e.clientX - this.lastPosX, e.clientY - this.lastPosY) - this.lastPosX = e.clientX - this.lastPosY = e.clientY - } - }) - canvas.on("mouse:up", function (opt) { - this.setViewportTransform(this.viewportTransform) - this.isDragging = false - this.selection = true - }) - } - - /** - * Move the current view point. - * @param {number} deltaX - * @param {number} deltaY - */ - async moveCamera(deltaX, deltaY) { - this.canvas.setZoom(this.canvas.getZoom()) // essential for rendering (seems like a bug) - let vpt = this.canvas.viewportTransform - vpt[4] += deltaX - vpt[5] += deltaY - this._refreshView() - } - - /** - * Invoke the click handler of an object. - * @param {fabric.Object} target - */ - async _handleClickEvent(target) { - if (target === null) { - return - } - let ele = this.canvasElementToDrawElement.get(target) - let func = ele.getEventHandler("click") - if (func) { - func() - } - } - - /** - * Set the objects in the current view point visible. - * And set other objects not visible. - */ - async _refreshView() { - const padding = 50 // Make the rendering area a little larger. - let vpt = this.canvas.viewportTransform - let zoom = this.canvas.getZoom() - let cameraWidth = this.width - let cameraHeight = this.height - let minX = -vpt[4] - padding - let maxX = -vpt[4] + cameraWidth + padding - let minY = -vpt[5] - padding - let maxY = -vpt[5] + cameraHeight + padding - let visibleSet = this.gridMapper.areaQuery( - minX / zoom, - maxX / zoom, - minY / zoom, - maxY / zoom - ) - - this.canvas.getObjects().forEach((e) => { - if (visibleSet.has(e)) { - e.visible = true - } else { - e.visible = false - } - }) - - this.canvas.requestRenderAll() - } - - /** - * Register an element to the engine. This should - * be called when a DrawElement instance is added - * to the canvas. - * @param {DrawElement} ele - */ - _addDrawElement(ele) { - let canvasElement = ele.props.canvasElement - this.canvasElementToDrawElement.set(canvasElement, ele) - this.gridMapper.addObject( - canvasElement.left, - canvasElement.left + canvasElement.width, - canvasElement.top, - canvasElement.top + canvasElement.height, - canvasElement - ) - } - - /** - * Assign a class to an object or remove a class from it. - * @param {string} clazz class name - * @param {DrawElement} element target object - * @param {boolean} flag true if the object is assigned, otherwise - * remove the class from the object - */ - classedElement(clazz, element, flag) { - if (!flag) { - this.clazzMap.has(clazz) && this.clazzMap.get(clazz).delete(element) - } else { - if (this.clazzMap.has(clazz)) { - this.clazzMap.get(clazz).add(element) - } else { - this.clazzMap.set(clazz, new Set([element])) - } - } - } - - /** - * Move current view point to the object specified by - * the selector. The selector is the class of the - * target object for now. - * @param {string} selector The class of the target object - */ - locateTo(selector) { - // - let selectorSet = this.clazzMap.get(selector) - if (selectorSet) { - let arr = Array.from(selectorSet) - if (arr.length > 0) { - let ele = arr[0] - let x = ele.props.canvasElement.get("left") - let y = ele.props.canvasElement.get("top") - let scale = 0.6 - this.canvas.setZoom(scale) - let vpt = this.canvas.viewportTransform - vpt[4] = (-x + this.width * 0.5) * scale - vpt[5] = (-y + this.height * 0.5) * scale - this.canvas.requestRenderAll() - this._refreshView() - } - } - } - - /** - * Move current view point to (0, 0) - */ - resetCamera() { - let zoom = this.canvas.getZoom() - zoom *= 0.999 - this.canvas.setZoom(zoom) - let vpt = this.canvas.viewportTransform - vpt[4] = 0 - vpt[5] = 0 - this.canvas.requestRenderAll() - this._refreshView() - } - - /** - * Dispose the current canvas. Remove all the objects to - * free memory. All objects in the canvas will be removed. - */ - cleanGraph() { - console.log("clean called") - this.canvas.dispose() - } - - /** - * Resize the canvas. This is called when the browser size - * is changed, such that the canvas can fix the current - * size of the browser. - * - * Note that the outer div box of the canvas will be set - * according to the parameters. However, the width and - * height of the canvas is double times of the parameters. - * This is the feature of fabric.js to keep the canvas - * in high resolution all the time. - * - * @param {number} width the width of the canvas - * @param {number} height the height of the canvas - */ - resize(width, height) { - this.width = width - this.height = height - this.canvas.setDimensions({ width: this.width, height: this.height }) - } -} diff --git a/dashboard/lib/graaphEngine/svgEngine.js b/dashboard/lib/graaphEngine/svgEngine.js deleted file mode 100644 index 8102b79df0a4b..0000000000000 --- a/dashboard/lib/graaphEngine/svgEngine.js +++ /dev/null @@ -1,198 +0,0 @@ -/* - * Copyright 2024 RisingWave Labs - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * - */ -import * as d3 from "d3" - -export class DrawElement { - /** - * @param {{svgElement: d3.Selection}} props - */ - constructor(props) { - /** - * @type {{svgElement: d3.Selection}} - */ - this.props = props - this.appendFunc = { - g: this._appendGroup, - circle: this._appendCircle, - rect: this._appendRect, - text: this._appendText, - path: this._appendPath, - polygon: this._appendPolygan, - } - } - - _appendGroup = () => { - return new Group({ svgElement: this.props.svgElement.append("g") }) - } - - _appendCircle = () => { - return new Circle({ svgElement: this.props.svgElement.append("circle") }) - } - - _appendRect = () => { - return new Rectangle({ svgElement: this.props.svgElement.append("rect") }) - } - - _appendText = () => { - return new Text({ svgElement: this.props.svgElement.append("text") }) - } - - _appendPath = () => { - return new Path({ svgElement: this.props.svgElement.append("path") }) - } - - _appendPolygan = () => { - return new Polygan({ svgElement: this.props.svgElement.append("polygon") }) - } - - on = (event, callback) => { - this.props.svgElement.on(event, callback) - return this - } - - style = (key, value) => { - this.props.svgElement.style(key, value) - return this - } - - classed = (clazz, flag) => { - this.props.svgElement.classed(clazz, flag) - return this - } - - attr = (key, value) => { - this.props.svgElement.attr(key, value) - return this - } - - append = (type) => { - return this.appendFunc[type]() - } -} - -export class Group extends DrawElement { - /** - * @param {{svgElement: d3.Selection}} props - */ - constructor(props) { - super(props) - } -} - -export class Rectangle extends DrawElement { - /** - * @param {{svgElement: d3.Selection}} props - */ - constructor(props) { - super(props) - } -} - -export class Circle extends DrawElement { - /** - * @param {{svgElement: d3.Selection}} props - */ - constructor(props) { - super(props) - } -} - -export class Text extends DrawElement { - /** - * @param {{svgElement: d3.Selection, any, null, undefined>}} props - */ - constructor(props) { - super(props) - this.props = props - } - - text(content) { - this.props.svgElement.text(content) - return this - } - - getWidth() { - return this.props.svgElement.node().getComputedTextLength() - } -} - -export class Polygan extends DrawElement { - constructor(props) { - super(props) - this.props = props - } -} - -export class Path extends DrawElement { - constructor(props) { - super(props) - } -} - -const originalZoom = new d3.ZoomTransform(0.5, 0, 0) - -export class SvgEngine { - /** - * @param {{g: d3.Selection}} props - */ - constructor(svgRef, height, width) { - this.height = height - this.width = width - this.svgRef = svgRef - - d3.select(svgRef).selectAll("*").remove() - this.svg = d3.select(svgRef).attr("viewBox", [0, 0, width, height]) - - this._g = this.svg.append("g").attr("class", "top") - this.topGroup = new Group({ svgElement: this._g }) - - this.transform - this.zoom = d3.zoom().on("zoom", (e) => { - this.transform = e.transform - this._g.attr("transform", e.transform) - }) - - this.svg.call(this.zoom).call(this.zoom.transform, originalZoom) - this.svg.on("pointermove", (event) => { - this.transform.invert(d3.pointer(event)) - }) - } - - locateTo(selector) { - let selection = d3.select(selector) - if (!selection.empty()) { - this.svg - .call(this.zoom) - .call( - this.zoom.transform, - new d3.ZoomTransform( - 0.7, - -0.7 * selection.attr("x"), - 0.7 * (-selection.attr("y") + this.height / 2) - ) - ) - } - } - - resetCamera() { - this.svg.call(this.zoom.transform, originalZoom) - } - - cleanGraph() { - d3.select(this.svgRef).selectAll("*").remove() - } -} diff --git a/dashboard/lib/str.js b/dashboard/lib/str.js deleted file mode 100644 index 52af7892f4a75..0000000000000 --- a/dashboard/lib/str.js +++ /dev/null @@ -1,24 +0,0 @@ -/* - * Copyright 2024 RisingWave Labs - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * - */ -export function capitalize(sentence) { - let words = sentence.split(" ") - let s = "" - for (let word of words) { - s += word.charAt(0).toUpperCase() + word.slice(1, word.length) - } - return s -} diff --git a/dashboard/lib/streamPlan/parser.js b/dashboard/lib/streamPlan/parser.js deleted file mode 100644 index 050e0f05619a4..0000000000000 --- a/dashboard/lib/streamPlan/parser.js +++ /dev/null @@ -1,428 +0,0 @@ -/* - * Copyright 2024 RisingWave Labs - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * - */ -import { graphBfs } from "../algo" - -let cnt = 0 -function generateNewNodeId() { - return "g" + ++cnt -} - -function getNodeId(nodeProto, actorId) { - return ( - actorId + - ":" + - (nodeProto.operatorId === undefined - ? generateNewNodeId() - : "o" + nodeProto.operatorId) - ) -} - -class Node { - constructor(id, actorId, nodeProto) { - this.id = id - /** - * @type {any} - */ - this.nodeProto = nodeProto - /** - * @type {Array} - */ - this.output = output - /** - * @type {Node} - */ - this.rootNode = rootNode - /** - * @type {number} - */ - this.fragmentId = fragmentId - /** - * @type {string} - */ - this.computeNodeAddress = computeNodeAddress - /** - * @type {Array} - */ - this.representedActorList = null - /** - * @type {Set} - */ - this.representedWorkNodes = null - } -} - -export default class StreamPlanParser { - /** - * - * @param {[{node: any, actors: []}]} data raw response from the meta node - */ - constructor(data, shownActorList) { - this.actorId2Proto = new Map() - /** - * @type {Set} - * @private - */ - this.actorIdTomviewNodes = new Map() - this.shownActorSet = new Set(shownActorList) - - for (let computeNodeData of data) { - for (let singleActorProto of computeNodeData.actors) { - if ( - shownActorList && - !this.shownActorSet.has(singleActorProto.actorId) - ) { - continue - } - this.actorId2Proto.set(singleActorProto.actorId, { - computeNodeAddress: `${computeNodeData.node.host.host}:${computeNodeData.node.host.port}`, - ...singleActorProto, - }) - } - } - - this.parsedNodeMap = new Map() - this.parsedActorMap = new Map() - - for (let [_, singleActorProto] of this.actorId2Proto.entries()) { - this.parseActor(singleActorProto) - } - - this.parsedActorList = [] - for (let [_, actor] of this.parsedActorMap.entries()) { - this.parsedActorList.push(actor) - } - - /** @type {Set} */ - this.fragmentRepresentedActors = this._constructRepresentedActorList() - - /** @type {Map} */ - this.mvTableIdToSingleViewActorList = this._constructSingleViewMvList() - - /** @type {Map} */ - this.mvTableIdToChainViewActorList = this._constructChainViewMvList() - } - - /** - * Randomly select a actor to represent its - * fragment, and append a property named `representedActorList` - * to store all the other actors in the same fragment. - * - * Actors are degree of parallelism of a fragment, such that one of - * the actor in a fragment can represent all the other actor in - * the same fragment. - * - * @returns A Set containing actors representing its fragment. - */ - _constructRepresentedActorList() { - const fragmentId2actorList = new Map() - let fragmentRepresentedActors = new Set() - for (let actor of this.parsedActorList) { - if (!fragmentId2actorList.has(actor.fragmentId)) { - fragmentRepresentedActors.add(actor) - fragmentId2actorList.set(actor.fragmentId, [actor]) - } else { - fragmentId2actorList.get(actor.fragmentId).push(actor) - } - } - - for (let actor of fragmentRepresentedActors) { - actor.representedActorList = fragmentId2actorList - .get(actor.fragmentId) - .sort((x) => x.actorId) - actor.representedWorkNodes = new Set() - for (let representedActor of actor.representedActorList) { - representedActor.representedActorList = actor.representedActorList - actor.representedWorkNodes.add(representedActor.computeNodeAddress) - } - } - return fragmentRepresentedActors - } - - _constructChainViewMvList() { - let mvTableIdToChainViewActorList = new Map() - let shellNodes = new Map() - const getShellNode = (actorId) => { - if (shellNodes.has(actorId)) { - return shellNodes.get(actorId) - } - let shellNode = { - id: actorId, - parentNodes: [], - nextNodes: [], - } - for (let node of this.parsedActorMap.get(actorId).output) { - let nextNode = getShellNode(node.actorId) - nextNode.parentNodes.push(shellNode) - shellNode.nextNodes.push(nextNode) - } - shellNodes.set(actorId, shellNode) - return shellNode - } - - for (let actorId of this.actorId2Proto.keys()) { - getShellNode(actorId) - } - - for (let [actorId, mviewNode] of this.actorIdTomviewNodes.entries()) { - let list = new Set() - let shellNode = getShellNode(actorId) - graphBfs(shellNode, (n) => { - list.add(n.id) - }) - graphBfs( - shellNode, - (n) => { - list.add(n.id) - }, - "parentNodes" - ) - for (let actor of this.parsedActorMap.get(actorId).representedActorList) { - list.add(actor.actorId) - } - mvTableIdToChainViewActorList.set(mviewNode.typeInfo.tableId, [ - ...list.values(), - ]) - } - - return mvTableIdToChainViewActorList - } - - _constructSingleViewMvList() { - let mvTableIdToSingleViewActorList = new Map() - let shellNodes = new Map() - const getShellNode = (actorId) => { - if (shellNodes.has(actorId)) { - return shellNodes.get(actorId) - } - let shellNode = { - id: actorId, - parentNodes: [], - } - for (let node of this.parsedActorMap.get(actorId).output) { - getShellNode(node.actorId).parentNodes.push(shellNode) - } - shellNodes.set(actorId, shellNode) - return shellNode - } - for (let actor of this.parsedActorList) { - getShellNode(actor.actorId) - } - - for (let actorId of this.actorId2Proto.keys()) { - getShellNode(actorId) - } - - for (let [actorId, mviewNode] of this.actorIdTomviewNodes.entries()) { - let list = [] - let shellNode = getShellNode(actorId) - graphBfs( - shellNode, - (n) => { - list.push(n.id) - if (shellNode.id !== n.id && this.actorIdTomviewNodes.has(n.id)) { - return true // stop to traverse its next nodes - } - }, - "parentNodes" - ) - for (let actor of this.parsedActorMap.get(actorId).representedActorList) { - list.push(actor.actorId) - } - mvTableIdToSingleViewActorList.set(mviewNode.typeInfo.tableId, list) - } - - return mvTableIdToSingleViewActorList - } - - newDispatcher(actorId, type, downstreamActorId) { - return new Dispatcher(actorId, type, downstreamActorId, { - operatorId: 100000 + actorId, - }) - } - - /** - * Parse raw data from meta node to an actor - * @param {{ - * actorId: number, - * fragmentId: number, - * nodes: any, - * dispatcher?: {type: string}, - * downstreamActorId?: any - * }} actorProto - * @returns {Actor} - */ - parseActor(actorProto) { - let actorId = actorProto.actorId - if (this.parsedActorMap.has(actorId)) { - return this.parsedActorMap.get(actorId) - } - - let actor = new Actor( - actorId, - [], - null, - actorProto.fragmentId, - actorProto.computeNodeAddress - ) - - let rootNode - this.parsedActorMap.set(actorId, actor) - if (actorProto.dispatcher && actorProto.dispatcher[0].type) { - let nodeBeforeDispatcher = this.parseNode(actor.actorId, actorProto.nodes) - rootNode = this.newDispatcher( - actor.actorId, - actorProto.dispatcher[0].type, - actorProto.downstreamActorId - ) - rootNode.nextNodes = [nodeBeforeDispatcher] - } else { - rootNode = this.parseNode(actorId, actorProto.nodes) - } - actor.rootNode = rootNode - - return actor - } - - parseNode(actorId, nodeProto) { - let id = getNodeId(nodeProto, actorId) - if (this.parsedNodeMap.has(id)) { - return this.parsedNodeMap.get(id) - } - let newNode = new StreamNode(id, actorId, nodeProto) - this.parsedNodeMap.set(id, newNode) - - if (nodeProto.input !== undefined) { - for (let nextNodeProto of nodeProto.input) { - newNode.nextNodes.push(this.parseNode(actorId, nextNodeProto)) - } - } - - if (newNode.type === "merge" && newNode.typeInfo.upstreamActorId) { - for (let upStreamActorId of newNode.typeInfo.upstreamActorId) { - if (!this.actorId2Proto.has(upStreamActorId)) { - continue - } - this.parseActor(this.actorId2Proto.get(upStreamActorId)).output.push( - newNode - ) - } - } - - if (newNode.type === "streamScan" && newNode.typeInfo.upstreamActorIds) { - for (let upStreamActorId of newNode.typeInfo.upstreamActorIds) { - if (!this.actorId2Proto.has(upStreamActorId)) { - continue - } - this.parseActor(this.actorId2Proto.get(upStreamActorId)).output.push( - newNode - ) - } - } - - if (newNode.type === "materialize") { - this.actorIdTomviewNodes.set(actorId, newNode) - } - - return newNode - } - - getActor(actorId) { - return this.parsedActorMap.get(actorId) - } - - getOperator(operatorId) { - return this.parsedNodeMap.get(operatorId) - } - - /** - * @returns {Array} - */ - getParsedActorList() { - return this.parsedActorList - } -} diff --git a/dashboard/lib/streamPlan/streamChartHelper.js b/dashboard/lib/streamPlan/streamChartHelper.js deleted file mode 100644 index 5610273cda3de..0000000000000 --- a/dashboard/lib/streamPlan/streamChartHelper.js +++ /dev/null @@ -1,945 +0,0 @@ -/* - * Copyright 2024 RisingWave Labs - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * - */ -import * as d3 from "d3" -import { cloneDeep, max } from "lodash" -import { getConnectedComponent, treeBfs } from "../algo" -import * as color from "../color" -import { Group } from "../graaphEngine/canvasEngine" -import { newNumberArray } from "../util" -import StreamPlanParser, { Actor } from "./parser" -// Actor constant -// -// ======================================================= -// ^ -// | actorBoxPadding -// v -// --┌───────────┐ -// | │ node │ -// | │<--->radius│>───────\ -// | │ │ │ -// | └───────────┘ │ -// | │ -// | ┌───────────┐ │ ┌───────────┐ -// | │ node │ │ │ node │ -// widthUnit│<--->radius│>───────┼────────>│<--->radius│ -// | │ │ │ │ │ -// | └───────────┘ │ └───────────┘ -// | │ -// | ┌───────────┐ │ -// | │ node │ │ -// | │<--->radius│>───────/ -// | │ │ -// ---└───────────┘ -// |-----------------heightUnit---------------| -// - -const SCALE_FACTOR = 0.5 - -const operatorNodeRadius = 30 * SCALE_FACTOR // the radius of the tree nodes in an actor -const operatorNodeStrokeWidth = 5 * SCALE_FACTOR // the stroke width of the link of the tree nodes in an actor -const widthUnit = 230 * SCALE_FACTOR // the width of a tree node in an actor -const heightUnit = 250 * SCALE_FACTOR // the height of a tree layer in an actor -const actorBoxPadding = 100 * SCALE_FACTOR // box padding -const actorBoxStroke = 15 * SCALE_FACTOR // the width of the stroke of the box -const internalLinkStrokeWidth = 30 * SCALE_FACTOR // the width of the link between nodes -const actorBoxRadius = 20 * SCALE_FACTOR - -// Stream Plan constant -const gapBetweenRow = 100 * SCALE_FACTOR -const gapBetweenLayer = 300 * SCALE_FACTOR -const gapBetweenFlowChart = 500 * SCALE_FACTOR -const outgoingLinkStrokeWidth = 20 * SCALE_FACTOR -const outgoingLinkBgStrokeWidth = 40 * SCALE_FACTOR - -// Draw linking effect -const bendGap = 50 * SCALE_FACTOR // try example at: http://bl.ocks.org/d3indepth/b6d4845973089bc1012dec1674d3aff8 -const connectionGap = 20 * SCALE_FACTOR - -// Others -const fontSize = 30 * SCALE_FACTOR -const outGoingLinkBgColor = "#eee" - -/** - * Construct an id for a link in actor box. - * You may use this method to query and get the svg element - * of the link. - * @param {{id: number}} node1 a node (operator) in an actor box - * @param {{id: number}} node2 a node (operator) in an actor box - * @returns {string} The link id - */ -function constructInternalLinkId(node1, node2) { - return ( - "node-" + - (node1.id > node2.id - ? node1.id + "-" + node2.id - : node2.id + "-" + node1.id) - ) -} - -/** - * Construct an id for a node (operator) in an actor box. - * You may use this method to query and get the svg element - * of the link. - * @param {{id: number}} node a node (operator) in an actor box - * @returns {string} The node id - */ -function constructOperatorNodeId(node) { - return "node-" + node.id -} - -function hashIpv4Index(addr) { - let [ip, port] = addr.split(":") - let s = "" - ip.split(".").map((x) => (s += x)) - return Number(s + port) -} - -export function computeNodeAddrToSideColor(addr) { - return color.TwoGradient(hashIpv4Index(addr))[1] -} - -/** - * Work flow - * 1. Get the layout for actor boxes (Calculate the base coordination of each actor box) - * 2. Get the layout for operators in each actor box - * 3. Draw all actor boxes - * 4. Draw link between actor boxes - * - * - * Dependencies - * layoutActorBox <- dagLayout <- drawActorBox <- drawFlow - * [ The layout of the ] [ The layout of ] [ Draw an actor ] [ Draw many actors ] - * [ operators in an ] [ actors in a ] [ in specified ] [ and links between ] - * [ actor. ] [ stream plan ] [ place ] [ them. ] - * - */ -export class StreamChartHelper { - /** - * - * @param {Group} g The group element in canvas engine - * @param {*} data The raw response from the meta node - * @param {(e, node) => void} onNodeClick The callback function triggered when a node is click - * @param {(e, actor) => void} onActorClick - * @param {{type: string, node: {host: {host: string, port: number}}, id?: number}} selectedWokerNode - * @param {Array} shownActorIdList - */ - constructor( - g, - data, - onNodeClick, - onActorClick, - selectedWokerNode, - shownActorIdList - ) { - this.topGroup = g - this.streamPlan = new StreamPlanParser(data, shownActorIdList) - this.onNodeClick = onNodeClick - this.onActorClick = onActorClick - this.selectedWokerNode = selectedWokerNode - this.selectedWokerNodeStr = this.selectedWokerNode - ? selectedWokerNode.host.host + ":" + selectedWokerNode.host.port - : "Show All" - } - - getMvTableIdToSingleViewActorList() { - return this.streamPlan.mvTableIdToSingleViewActorList - } - - getMvTableIdToChainViewActorList() { - return this.streamPlan.mvTableIdToChainViewActorList - } - - /** - * @param {Actor} actor - * @returns - */ - isInSelectedActor(actor) { - if (this.selectedWokerNodeStr === "Show All") { - // show all - return true - } else { - return actor.representedWorkNodes.has(this.selectedWokerNodeStr) - } - } - - _mainColor(actor) { - let addr = actor.representedWorkNodes.has(this.selectedWokerNodeStr) - ? this.selectedWokerNodeStr - : actor.computeNodeAddress - return color.TwoGradient(hashIpv4Index(addr))[0] - } - - _sideColor(actor) { - let addr = actor.representedWorkNodes.has(this.selectedWokerNodeStr) - ? this.selectedWokerNodeStr - : actor.computeNodeAddress - return color.TwoGradient(hashIpv4Index(addr))[1] - } - - _operatorColor = (actor, operator) => { - return this.isInSelectedActor(actor) && operator.type === "mviewNode" - ? this._mainColor(actor) - : "#eee" - } - _actorBoxBackgroundColor = (actor) => { - return this.isInSelectedActor(actor) ? this._sideColor(actor) : "#eee" - } - _actorOutgoinglinkColor = (actor) => { - return this.isInSelectedActor(actor) ? this._mainColor(actor) : "#fff" - } - - // - // A simple DAG layout algorithm. - // The layout is built based on two rules. - // 1. The link should have at two turnning points. - // 2. The turnning point of a link should be placed - // at the margin after the layer of its starting point. - // ------------------------------------------------------- - // Example 1: (X)-(Z) and (Y)-(Z) is valid. - // Row 0 (X)---------------->(Z) - // | - // Row 1 | - // | - // Row 2 (Y)---/ - // Layer 1 | Layer 2 | Layer 3 - // ------------------------------------------------------- - // Example 2: (A)-(B) is not valid. - // Row 0 (X) /---------\ (Z) - // | | - // Row 1 (A)---/ (Y) |-->(B) - // - // Layer 1 | Layer 2 | Layer 3 - // ------------------------------------------------------- - // Example 3: (C)-(Z) is not valid - // Row 0 (X) /-->(Z) - // | - // Row 1 (C)-------------/ - // - // Layer 1 | Layer 2 | Layer 3 - // ------------------------------------------------------- - // Note that the layer of each node can be different - // For example: - // Row 0 ( 1) ( 3) ( 5) ( 2) ( 9) - // Row 1 ( 4) ( 6) (10) - // Row 2 ( 7) ( 8) - // Layer 0 | Layer 1 | Layer 2 | Layer 3 | Layer 4 | - // - // Row 0 ( 1) ( 3) ( 5) ( 2) ( 9) - // Row 1 ( 4) ( 6) (10) - // Row 2 ( 7) ( 8) - // Layer 0 | Layer 1 | Layer 2 | Layer 3 | Layer 4 | - /** - * Topological sort - * @param {Array} nodes An array of node: {nextNodes: [...]} - * @returns {Map} position of each node - */ - dagLayout(nodes) { - let sorted = [] - let _nodes = [] - let node2dagNode = new Map() - const visit = (n) => { - if (n.temp) { - throw Error("This is not a DAG") - } - if (!n.perm) { - n.temp = true - let maxG = -1 - for (let nextNode of n.node.nextNodes) { - node2dagNode.get(nextNode).isInput = false - n.isOutput = false - let g = visit(node2dagNode.get(nextNode)) - if (g > maxG) { - maxG = g - } - } - n.temp = false - n.perm = true - n.g = maxG + 1 - sorted.unshift(n.node) - } - return n.g - } - for (let node of nodes) { - let dagNode = { - node: node, - temp: false, - perm: false, - isInput: true, - isOutput: true, - } - node2dagNode.set(node, dagNode) - _nodes.push(dagNode) - } - let maxLayer = 0 - for (let node of _nodes) { - let g = visit(node) - if (g > maxLayer) { - maxLayer = g - } - } - // use the bottom up strategy to construct generation number - // makes the generation number of root node the samllest - // to make the computation easier, need to flip it back. - for (let node of _nodes) { - // node.g = node.isInput ? 0 : (maxLayer - node.g); // TODO: determine which is more suitable - node.g = maxLayer - node.g - } - - let layers = [] - for (let i = 0; i < maxLayer + 1; ++i) { - layers.push({ - nodes: [], - occupyRow: new Set(), - }) - } - let node2Layer = new Map() - let node2Row = new Map() - for (let node of _nodes) { - layers[node.g].nodes.push(node.node) - node2Layer.set(node.node, node.g) - } - - // layers to rtn - let rtn = new Map() - - const putNodeInPosition = (node, row) => { - node2Row.set(node, row) - layers[node2Layer.get(node)].occupyRow.add(row) - } - - const occupyLine = (ls, le, r) => { - // layer start, layer end, row - for (let i = ls; i <= le; ++i) { - layers[i].occupyRow.add(r) - } - } - - const hasOccupied = (layer, row) => layers[layer].occupyRow.has(row) - - const isStraightLineOccupied = (ls, le, r) => { - // layer start, layer end, row - if (r < 0) { - return false - } - for (let i = ls; i <= le; ++i) { - if (hasOccupied(i, r)) { - return true - } - } - return false - } - - for (let node of nodes) { - node.nextNodes.sort((a, b) => node2Layer.get(b) - node2Layer.get(a)) - } - - for (let layer of layers) { - for (let node of layer.nodes) { - if (!node2Row.has(node)) { - // checking node is not placed. - for (let nextNode of node.nextNodes) { - if (node2Row.has(nextNode)) { - continue - } - let r = -1 - while ( - isStraightLineOccupied( - node2Layer.get(node), - node2Layer.get(nextNode), - ++r - ) - ) {} - putNodeInPosition(node, r) - putNodeInPosition(nextNode, r) - occupyLine( - node2Layer.get(node) + 1, - node2Layer.get(nextNode) - 1, - r - ) - break - } - if (!node2Row.has(node)) { - let r = -1 - while (hasOccupied(node2Layer.get(node), ++r)) {} - putNodeInPosition(node, r) - } - } - // checking node is placed in some position - for (let nextNode of node.nextNodes) { - if (node2Row.has(nextNode)) { - continue - } - // check straight line position first - let r = node2Row.get(node) - if ( - !isStraightLineOccupied( - node2Layer.get(node) + 1, - node2Layer.get(nextNode), - r - ) - ) { - putNodeInPosition(nextNode, r) - occupyLine( - node2Layer.get(node) + 1, - node2Layer.get(nextNode) - 1, - r - ) - continue - } - // check lowest available position - r = -1 - while ( - isStraightLineOccupied( - node2Layer.get(node) + 1, - node2Layer.get(nextNode), - ++r - ) - ) {} - putNodeInPosition(nextNode, r) - occupyLine(node2Layer.get(node) + 1, node2Layer.get(nextNode) - 1, r) - } - } - } - for (let node of nodes) { - rtn.set(node.id, [node2Layer.get(node), node2Row.get(node)]) - } - - return rtn - } - - /** - * Calculate the position of each node in the actor box. - * @param {{id: any, nextNodes: [], x: number, y: number}} rootNode The root node of an actor box (dispatcher) - * @returns {[width, height]} The size of the actor box - */ - calculateActorBoxSize(rootNode) { - let rootNodeCopy = cloneDeep(rootNode) - return this.layoutActorBox(rootNodeCopy, 0, 0) - } - - /** - * Calculate the position of each node (operator) in the actor box. - * This will change the node's position - * @param {{id: any, nextNodes: [], x: number, y: number}} rootNode The root node of an actor box (dispatcher) - * @param {number} baseX The x coordination of the top-left corner of the actor box - * @param {number} baseY The y coordination of the top-left corner of the actor box - * @returns {[width, height]} The size of the actor box - */ - layoutActorBox(rootNode, baseX, baseY) { - // calculate nodes' required width - let maxLayer = 0 - const getRequiredWidth = (node, layer) => { - if (node.width !== undefined) { - return node.width - } - - if (layer > maxLayer) { - maxLayer = layer - } - - node.layer = layer - - let requiredWidth = 0 - for (let nextNode of node.nextNodes) { - requiredWidth += getRequiredWidth(nextNode, layer + 1) - } - - node.isLeaf = requiredWidth === 0 - - node.width = requiredWidth > 0 ? requiredWidth : widthUnit - - return node.width - } - - getRequiredWidth(rootNode, 0) - - // calculate nodes' position - rootNode.x = baseX || 0 - rootNode.y = baseY || 0 - let leafY = rootNode.x - heightUnit * maxLayer - treeBfs(rootNode, (c) => { - let tmpY = c.y - c.width / 2 - for (let nextNode of c.nextNodes) { - nextNode.x = nextNode.isLeaf ? leafY : c.x - heightUnit - nextNode.y = tmpY + nextNode.width / 2 - tmpY += nextNode.width - } - }) - - // calculate box size - let minX = Infinity - let maxX = -Infinity - let minY = Infinity - let maxY = -Infinity - treeBfs(rootNode, (node) => { - if (node.x > maxX) { - maxX = node.x - } - if (node.x < minX) { - minX = node.x - } - if (node.y > maxY) { - maxY = node.y - } - if (node.y < minY) { - minY = node.y - } - }) - let boxWidth = maxX - minX - let boxHeight = maxY - minY - return [boxWidth + actorBoxPadding * 2, boxHeight + actorBoxPadding * 2] - } - - /** - * @param {{ - * g: Group, - * rootNode: {id: any, nextNodes: []}, - * nodeColor: string, - * strokeColor?: string, - * baseX?: number, - * baseY?: number - * }} props - * @param {Group} props.g The group element in canvas engine - * @param {{id: any, nextNodes: []}} props.rootNode The root node of the tree in the actor - * @param {string} props.nodeColor [optional] The filled color of nodes. - * @param {string} props.strokeColor [optional] The color of the stroke. - * @param {number} props.baseX [optional] The x coordination of the lef-top corner. default: 0 - * @param {number} props.baseY [optional] The y coordination of the lef-top corner. default: 0 - * @returns {Group} The group element of this tree - */ - drawActorBox(props) { - if (props.g === undefined) { - throw Error("Invalid Argument: Target group cannot be undefined.") - } - - const actor = props.actor - const group = props.g.append("g") - const rootNode = props.rootNode || [] - const baseX = props.x === undefined ? 0 : props.x - const baseY = props.y === undefined ? 0 : props.y - const strokeColor = props.strokeColor || "white" - const linkColor = props.linkColor || "gray" - - group.attr("class", actor.computeNodeAddress) - - const [boxWidth, boxHeight] = this.calculateActorBoxSize(rootNode) - this.layoutActorBox( - rootNode, - baseX + boxWidth - actorBoxPadding, - baseY + boxHeight / 2 - ) - - const onNodeClicked = (e, node, actor) => { - this.onNodeClick && this.onNodeClick(e, node, actor) - } - - const onActorClick = (e, actor) => { - this.onActorClick && this.onActorClick(e, actor) - } - - /** - * @param {Group} g actor box group - * @param {number} x top-right corner of the label - * @param {number} y top-right corner of the label - * @param {Array} actorIds - * @param {string} color - * @returns {number} width of this label - */ - const drawActorIdLabel = (g, x, y, actorIds, color) => { - y = y - actorBoxStroke - let actorStr = actorIds.toString() - let padding = 15 - // let height = fontSize + 2 * padding; - let gap = 30 - // let polygon = g.append("polygon"); - let textEle = g - .append("text")(actorStr) - .attr("font-size", fontSize) - .position(x - padding - 5, y + padding) - let width = textEle.getWidth() + 2 * padding - // polygon.attr("points", `${x},${y} ${x - width - gap},${y}, ${x - width},${y + height}, ${x},${y + height}`) - // .attr("fill", color); - return width + gap - } - - // draw box - group.attr("id", "actor-" + actor.actorId) - let actorRect = group.append("rect") - for (let representedActor of actor.representedActorList) { - actorRect.classed("actor-" + representedActor.actorId, true) - } - actorRect.classed("fragment-" + actor.fragmentId, true) - actorRect - .init(baseX, baseY, boxWidth, boxHeight) - .attr("fill", this._actorBoxBackgroundColor(actor)) - .attr("rx", actorBoxRadius) - .attr("stroke-width", actorBoxStroke) - .on("click", (e) => onActorClick(e, actor)) - - group - .append("text")(`Fragment ${actor.fragmentId}`) - .position(baseX, baseY - actorBoxStroke - fontSize) - .attr("font-size", fontSize) - - // draw compute node label - let computeNodeToActorIds = new Map() - for (let representedActor of actor.representedActorList) { - if (computeNodeToActorIds.has(representedActor.computeNodeAddress)) { - computeNodeToActorIds - .get(representedActor.computeNodeAddress) - .push(representedActor.actorId) - } else { - computeNodeToActorIds.set(representedActor.computeNodeAddress, [ - representedActor.actorId, - ]) - } - } - let labelStartX = baseX + actorBoxStroke - for (let [addr, actorIds] of computeNodeToActorIds.entries()) { - let w = drawActorIdLabel( - group, - labelStartX, - baseY + boxHeight, - actorIds, - color.TwoGradient(hashIpv4Index(addr))[1] - ) - labelStartX -= w - } - - // draw links - const linkData = [] - treeBfs(rootNode, (c) => { - for (let nextNode of c.nextNodes) { - linkData.push({ - sourceNode: c, - nextNode: nextNode, - source: [c.x, c.y], - target: [nextNode.x, nextNode.y], - }) - } - }) - const linkGen = d3.linkHorizontal() - for (let link of linkData) { - group - .append("path")(linkGen(link)) - .attr( - "stroke-dasharray", - `${internalLinkStrokeWidth / 2},${internalLinkStrokeWidth / 2}` - ) - // .attr("d", linkGen(link)) - .attr("fill", "none") - .attr("class", "actor-" + actor.actorId) - .classed("internal-link", true) - .attr("id", constructInternalLinkId(link.sourceNode, link.nextNode)) - .style("stroke-width", internalLinkStrokeWidth) - .attr("stroke", linkColor) - } - - // draw nodes - treeBfs(rootNode, (node) => { - node.d3Selection = group - .append("circle") - .init(node.x, node.y, operatorNodeRadius) - .attr("id", constructOperatorNodeId(node)) - .attr("stroke", strokeColor) - .attr("fill", this._operatorColor(actor, node)) - .style("cursor", "pointer") - .style("stroke-width", operatorNodeStrokeWidth) - .on("click", (e) => onNodeClicked(e, node, actor)) - group - .append("text")(node.type ? node.type : node.dispatcherType) - .position(node.x, node.y + operatorNodeRadius + 10) - .attr("font-size", fontSize) - }) - - return { - g: group, - x: baseX - boxWidth - actorBoxPadding, - y: baseY - boxHeight / 2 - actorBoxPadding, - width: boxWidth + actorBoxPadding * 2, - height: boxHeight + actorBoxPadding * 2, - } - } - /** - * - * @param {{ - * g: Group, - * actorDagList: Array, - * baseX?: number, - * baseY?: number - * }} props - * @param {Group} props.g The target group contains this group. - * @param {Array} props.actorDagList A list of dag nodes constructed from actors - * { id: actor.actorId, nextNodes: [], actor: actor } - * @param {number} props.baseX [optional] The x coordination of left-top corner. default: 0. - * @param {number} props.baseY [optional] The y coordination of left-top corner. default: 0. - * @returns {{group: Group, width: number, height: number}} The size of the flow - */ - drawFlow(props) { - if (props.g === undefined) { - throw Error("Invalid Argument: Target group cannot be undefined.") - } - - const g = props.g - const actorDagList = props.actorDagList || [] - const baseX = props.baseX || 0 - const baseY = props.baseY || 0 - - let layoutPositionMapper = this.dagLayout(actorDagList) - const actors = [] - for (let actorDag of actorDagList) { - actors.push(actorDag.actor) - } - - // calculate actor box size - for (let actor of actors) { - ;[actor.boxWidth, actor.boxHeight] = this.calculateActorBoxSize( - actor.rootNode - ) - ;[actor.layer, actor.row] = layoutPositionMapper.get(actor.actorId) - } - - // calculate the minimum required width of each layer and row - let maxRow = 0 - let maxLayer = 0 - for (let actor of actors) { - maxLayer = max([actor.layer, maxLayer]) - maxRow = max([actor.row, maxRow]) - } - let rowGap = newNumberArray(maxRow + 1) - let layerGap = newNumberArray(maxLayer + 1) - for (let actor of actors) { - layerGap[actor.layer] = max([layerGap[actor.layer], actor.boxWidth]) - rowGap[actor.row] = max([rowGap[actor.row], actor.boxHeight]) - } - let row2y = newNumberArray(maxRow + 1) - let layer2x = newNumberArray(maxLayer + 1) - row2y = row2y.map((_, r) => { - if (r === 0) { - return 0 - } - let rtn = 0 - for (let i = 0; i < r; ++i) { - rtn += rowGap[i] + gapBetweenRow - } - return rtn - }) - layer2x = layer2x.map((_, l) => { - if (l === 0) { - return 0 - } - let rtn = 0 - for (let i = 0; i < l; ++i) { - rtn += layerGap[i] + gapBetweenLayer - } - return rtn - }) - - // Draw fragment (represent by one actor) - const group = g.append("g") - const linkLayerBackground = group.append("g") - const linkLayer = group.append("g") - const fragmentLayer = group.append("g") - linkLayerBackground.attr("class", "linkLayerBackground") - linkLayer.attr("class", "linkLayer") - fragmentLayer.attr("class", "fragmentLayer") - - let actorBoxList = [] - for (let actor of actors) { - let actorBox = this.drawActorBox({ - actor: actor, - g: fragmentLayer, - rootNode: actor.rootNode, - x: baseX + layer2x[actor.layer], - y: baseY + row2y[actor.row], - strokeColor: "white", - linkColor: "white", - }) - actorBoxList.push(actorBox) - } - - // Draw link between (represent by one actor) - const getLinkBetweenPathStr = (start, end, compensation) => { - const lineGen = d3.line().curve(d3.curveBasis) - let pathStr = lineGen([ - end, - [ - start[0] + - compensation + - actorBoxPadding + - connectionGap + - bendGap * 2, - end[1], - ], - [ - start[0] + compensation + actorBoxPadding + connectionGap + bendGap, - end[1], - ], - [ - start[0] + compensation + actorBoxPadding + connectionGap + bendGap, - start[1], - ], - [start[0] + compensation + actorBoxPadding + connectionGap, start[1]], - start, - ]) - return pathStr - } - - let linkData = [] - for (let actor of actors) { - for (let outputNode of actor.output) { - linkData.push({ - actor: actor, - d: getLinkBetweenPathStr( - [actor.rootNode.x, actor.rootNode.y], - [outputNode.x, outputNode.y], - layerGap[actor.layer] - actor.boxWidth - ), - }) - } - } - - for (let s of linkData) { - linkLayer - .append("path")(s.d) - .attr( - "stroke-dasharray", - `${outgoingLinkStrokeWidth},${outgoingLinkStrokeWidth}` - ) - .attr("fill", "none") - .attr("class", "actor-" + s.actor.actorId) - .classed("outgoing-link", true) - .style("stroke-width", outgoingLinkStrokeWidth) - .attr("stroke", this._actorOutgoinglinkColor(s.actor)) - .attr("layer", "back") - } - - for (let s of linkData) { - linkLayerBackground - .append("path")(s.d) - .attr("fill", "none") - .style("stroke-width", outgoingLinkBgStrokeWidth) - .attr("class", "actor-" + s.actor.actorId) - .classed("outgoing-link-bg", true) - .attr("stroke", outGoingLinkBgColor) - .attr("layer", "back") - } - - // calculate box size - let width = 0 - let height = 0 - for (let actorBox of actorBoxList) { - let biggestX = actorBox.x - baseX + actorBox.width - let biggestY = actorBox.y - baseY + actorBox.height - width = max([biggestX, width]) - height = max([biggestY, height]) - } - - group.attr("class", "flowchart") - return { - g: group, - width: width, - height: height, - } - } - - /** - * A flow is an extracted connected component of actors of - * the raw response from the meta node. This method will first - * merge actors in the same fragment using some identifier - * (currently it is the id of the operator before the dispatcher). - * And then use `drawFlow()` to draw each connected component. - */ - drawManyFlow() { - const g = this.topGroup - const baseX = 0 - const baseY = 0 - - g.attr("id", "") - - let fragmentRepresentedActors = this.streamPlan.fragmentRepresentedActors - // get dag layout of these actors - let dagNodeMap = new Map() - for (let actor of fragmentRepresentedActors) { - actor.rootNode.actorId = actor.actorId - treeBfs(actor.rootNode, (node) => { - node.actorId = actor.actorId - }) - dagNodeMap.set(actor.actorId, { - id: actor.actorId, - nextNodes: [], - actor: actor, - }) - } - for (let actor of fragmentRepresentedActors) { - for (let outputActorNode of actor.output) { - let outputDagNode = dagNodeMap.get(outputActorNode.actorId) - if (outputDagNode) { - // the output actor node is in a represented actor - dagNodeMap.get(actor.actorId).nextNodes.push(outputDagNode) - } - } - } - let actorDagNodes = [] - for (let id of dagNodeMap.keys()) { - actorDagNodes.push(dagNodeMap.get(id)) - } - - let actorsList = getConnectedComponent(actorDagNodes) - - let y = baseY - for (let actorDagList of actorsList) { - let flowChart = this.drawFlow({ - g: g, - baseX: baseX, - baseY: y, - actorDagList: actorDagList, - }) - y += flowChart.height + gapBetweenFlowChart - } - } -} - -/** - * create a graph view based on raw input from the meta node, - * and append the svg component to the giving svg group. - * @param {Group} g The parent group contain the graph. - * @param {any} data Raw response from the meta node. e.g. [{node: {...}, actors: {...}}, ...] - * @param {(clickEvent, node, actor) => void} onNodeClick callback when a node (operator) is clicked. - * @param {{type: string, node: {host: {host: string, port: number}}, id?: number}} selectedWokerNode - * @returns {StreamChartHelper} - */ -export default function createView( - engine, - data, - onNodeClick, - onActorClick, - selectedWokerNode, - shownActorIdList -) { - console.log(shownActorIdList, "shownActorList") - let streamChartHelper = new StreamChartHelper( - engine.topGroup, - data, - onNodeClick, - onActorClick, - selectedWokerNode, - shownActorIdList - ) - streamChartHelper.drawManyFlow() - return streamChartHelper -} diff --git a/dashboard/test/algo.test.js b/dashboard/test/algo.test.js deleted file mode 100644 index 9fe47669fed6e..0000000000000 --- a/dashboard/test/algo.test.js +++ /dev/null @@ -1,109 +0,0 @@ -/* - * Copyright 2024 RisingWave Labs - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * - */ -import { Node } from "../lib/algo" -import { StreamChartHelper } from "../lib/streamPlan/streamChartHelper" - -describe("Algo", () => { - it("should generate right dag layout", () => { - // fake data - let nodes = [] - for (let i = 0; i < 10; ++i) { - nodes.push(new Node([], i + 1)) - } - const n = (i) => nodes[i - 1] - n(1).nextNodes = [n(2), n(3)] - n(2).nextNodes = [n(9)] - n(3).nextNodes = [n(5), n(10)] - n(4).nextNodes = [n(5)] - n(5).nextNodes = [n(6), n(7)] - n(6).nextNodes = [n(9), n(10)] - n(7).nextNodes = [n(8)] - - let dagPositionMapper = new StreamChartHelper().dagLayout(nodes) - - // construct map - let maxLayer = 0 - let maxRow = 0 - for (let node of dagPositionMapper.keys()) { - let pos = dagPositionMapper.get(node) - maxLayer = pos[0] > maxLayer ? pos[0] : maxLayer - maxRow = pos[1] > maxRow ? pos[1] : maxRow - } - let m = [] - for (let i = 0; i < maxLayer + 1; ++i) { - m.push([]) - for (let r = 0; r < maxRow + 1; ++r) { - m[i].push([]) - } - } - for (let node of dagPositionMapper.keys()) { - let pos = dagPositionMapper.get(node) - m[pos[0]][pos[1]] = node - } - - // search - const _search = (l, r, d) => { - // Layer, Row - if (l > maxLayer || r > maxRow || r < 0) { - return false - } - if (m[l][r].id !== undefined) { - return m[l][r].id === d - } - return _search(l + 1, r, d) - } - - const canReach = (node, nextNode) => { - let pos = dagPositionMapper.get(node) - for (let r = 0; r <= maxRow; ++r) { - if (_search(pos[0] + 1, r, nextNode.id)) { - return true - } - } - return false - } - - //check all links - let ok = true - for (let node of nodes) { - for (let nextNode of node.nextNodes) { - if (!canReach(node, nextNode)) { - console.error( - `Failed to connect node ${node.id} to node ${nextNode.id}` - ) - ok = false - break - } - } - if (!ok) { - break - } - } - - // visualization - // let s = ""; - // for(let r = maxRow; r >= 0; --r){ - // for(let l = 0; l <= maxLayer; ++l){ - // s += `\t${m[l][r].id ? m[l][r].id : " "}` - // } - // s += "\n" - // } - // console.log(s); - - expect(ok).toEqual(true) - }) -}) From ed4101f86b33b18ad6d40cdbf4a3f8367c0a273c Mon Sep 17 00:00:00 2001 From: Wallace Date: Fri, 5 Jan 2024 18:08:23 +0800 Subject: [PATCH 03/20] feat(storage): reduce size of compaction task (#14160) Signed-off-by: Little-Wallace --- .../compaction/selector/level_selector.rs | 2 +- .../compactor/fast_compactor_runner.rs | 1 - src/storage/src/hummock/compactor/mod.rs | 1 - src/storage/src/hummock/sstable/builder.rs | 4 - .../src/hummock/sstable/multi_builder.rs | 75 +++++-------------- 5 files changed, 21 insertions(+), 62 deletions(-) diff --git a/src/meta/src/hummock/compaction/selector/level_selector.rs b/src/meta/src/hummock/compaction/selector/level_selector.rs index 736c10fac69c6..e2b1cc65f3c2a 100644 --- a/src/meta/src/hummock/compaction/selector/level_selector.rs +++ b/src/meta/src/hummock/compaction/selector/level_selector.rs @@ -120,7 +120,7 @@ impl DynamicLevelSelectorCore { Box::new(MinOverlappingPicker::new( picker_info.select_level, picker_info.target_level, - self.config.max_bytes_for_level_base, + self.config.max_bytes_for_level_base / 2, overlap_strategy, )) } diff --git a/src/storage/src/hummock/compactor/fast_compactor_runner.rs b/src/storage/src/hummock/compactor/fast_compactor_runner.rs index 1391414b91bdd..653a00f21b7c2 100644 --- a/src/storage/src/hummock/compactor/fast_compactor_runner.rs +++ b/src/storage/src/hummock/compactor/fast_compactor_runner.rs @@ -317,7 +317,6 @@ impl CompactorRunner { builder_factory, context.compactor_metrics.clone(), Some(task_progress.clone()), - task_config.is_target_l0_or_lbase, task_config.table_vnode_partition.clone(), ); assert_eq!( diff --git a/src/storage/src/hummock/compactor/mod.rs b/src/storage/src/hummock/compactor/mod.rs index a663c7c3819a6..d82172650e19c 100644 --- a/src/storage/src/hummock/compactor/mod.rs +++ b/src/storage/src/hummock/compactor/mod.rs @@ -265,7 +265,6 @@ impl Compactor { builder_factory, self.context.compactor_metrics.clone(), task_progress.clone(), - self.task_config.is_target_l0_or_lbase, self.task_config.table_vnode_partition.clone(), ); let compaction_statistics = compact_and_build_sst( diff --git a/src/storage/src/hummock/sstable/builder.rs b/src/storage/src/hummock/sstable/builder.rs index b246d15b03fcf..9202b3ec28788 100644 --- a/src/storage/src/hummock/sstable/builder.rs +++ b/src/storage/src/hummock/sstable/builder.rs @@ -717,10 +717,6 @@ impl SstableBuilder { self.approximate_len() >= self.options.capacity } - pub fn reach_max_sst_size(&self) -> bool { - self.approximate_len() as u64 >= self.options.max_sst_size - } - fn finalize_last_table_stats(&mut self) { if self.table_ids.is_empty() || self.last_table_id.is_none() { return; diff --git a/src/storage/src/hummock/sstable/multi_builder.rs b/src/storage/src/hummock/sstable/multi_builder.rs index a1946f56a33de..42a19866fc467 100644 --- a/src/storage/src/hummock/sstable/multi_builder.rs +++ b/src/storage/src/hummock/sstable/multi_builder.rs @@ -72,13 +72,11 @@ where task_progress: Option>, last_table_id: u32, - is_target_level_l0_or_lbase: bool, table_partition_vnode: BTreeMap, split_weight_by_vnode: u32, /// When vnode of the coming key is greater than `largest_vnode_in_current_partition`, we will /// switch SST. largest_vnode_in_current_partition: usize, - last_vnode: usize, } impl CapacitySplitTableBuilder @@ -91,7 +89,6 @@ where builder_factory: F, compactor_metrics: Arc, task_progress: Option>, - is_target_level_l0_or_lbase: bool, table_partition_vnode: BTreeMap, ) -> Self { Self { @@ -101,11 +98,9 @@ where compactor_metrics, task_progress, last_table_id: 0, - is_target_level_l0_or_lbase, table_partition_vnode, split_weight_by_vnode: 0, largest_vnode_in_current_partition: VirtualNode::MAX.to_index(), - last_vnode: 0, } } @@ -117,11 +112,9 @@ where compactor_metrics: Arc::new(CompactorMetrics::unused()), task_progress: None, last_table_id: 0, - is_target_level_l0_or_lbase: false, table_partition_vnode: BTreeMap::default(), split_weight_by_vnode: 0, largest_vnode_in_current_partition: VirtualNode::MAX.to_index(), - last_vnode: 0, } } @@ -181,7 +174,7 @@ where value: HummockValue<&[u8]>, is_new_user_key: bool, ) -> HummockResult<()> { - let (switch_builder, vnode_changed) = self.check_table_and_vnode_change(&full_key.user_key); + let switch_builder = self.check_switch_builder(&full_key.user_key); // We use this `need_seal_current` flag to store whether we need to call `seal_current` and // then call `seal_current` later outside the `if let` instead of calling @@ -195,15 +188,7 @@ where let mut last_range_tombstone_epoch = HummockEpoch::MAX; if let Some(builder) = self.current_builder.as_mut() { if is_new_user_key { - if switch_builder { - need_seal_current = true; - } else if builder.reach_capacity() { - if !self.is_target_level_l0_or_lbase || builder.reach_max_sst_size() { - need_seal_current = true; - } else { - need_seal_current = self.is_target_level_l0_or_lbase && vnode_changed; - } - } + need_seal_current = switch_builder || builder.reach_capacity(); } if need_seal_current && let Some(event) = builder.last_range_tombstone() @@ -253,9 +238,8 @@ where builder.add(full_key, value).await } - pub fn check_table_and_vnode_change(&mut self, user_key: &UserKey<&[u8]>) -> (bool, bool) { + pub fn check_switch_builder(&mut self, user_key: &UserKey<&[u8]>) -> bool { let mut switch_builder = false; - let mut vnode_changed = false; if user_key.table_id.table_id != self.last_table_id { let new_vnode_partition_count = self.table_partition_vnode.get(&user_key.table_id.table_id); @@ -272,8 +256,6 @@ where // table_id change self.last_table_id = user_key.table_id.table_id; switch_builder = true; - self.last_vnode = 0; - vnode_changed = true; if self.split_weight_by_vnode > 1 { self.largest_vnode_in_current_partition = VirtualNode::COUNT / (self.split_weight_by_vnode as usize) - 1; @@ -285,10 +267,6 @@ where } if self.largest_vnode_in_current_partition != VirtualNode::MAX.to_index() { let key_vnode = user_key.get_vnode_id(); - if key_vnode != self.last_vnode { - self.last_vnode = key_vnode; - vnode_changed = true; - } if key_vnode > self.largest_vnode_in_current_partition { // vnode partition change switch_builder = true; @@ -303,11 +281,10 @@ where ((key_vnode - small_segments_area) / (basic + 1) + 1) * (basic + 1) + small_segments_area }) - 1; - self.last_vnode = key_vnode; debug_assert!(key_vnode <= self.largest_vnode_in_current_partition); } } - (switch_builder, vnode_changed) + switch_builder } pub fn need_flush(&self) -> bool { @@ -616,7 +593,6 @@ mod tests { LocalTableBuilderFactory::new(1001, mock_sstable_store(), opts), Arc::new(CompactorMetrics::unused()), None, - false, BTreeMap::default(), ); let full_key = FullKey::for_test( @@ -713,7 +689,6 @@ mod tests { LocalTableBuilderFactory::new(1001, mock_sstable_store(), opts), Arc::new(CompactorMetrics::unused()), None, - false, BTreeMap::default(), ); del_iter.rewind().await.unwrap(); @@ -750,7 +725,6 @@ mod tests { LocalTableBuilderFactory::new(1001, mock_sstable_store(), opts), Arc::new(CompactorMetrics::unused()), None, - false, BTreeMap::default(), ); builder @@ -870,56 +844,47 @@ mod tests { LocalTableBuilderFactory::new(1001, mock_sstable_store(), opts), Arc::new(CompactorMetrics::unused()), None, - false, table_partition_vnode, ); let mut table_key = VirtualNode::from_index(0).to_be_bytes().to_vec(); table_key.extend_from_slice("a".as_bytes()); - let (switch_builder, vnode_changed) = - builder.check_table_and_vnode_change(&UserKey::for_test(TableId::from(1), &table_key)); + let switch_builder = + builder.check_switch_builder(&UserKey::for_test(TableId::from(1), &table_key)); assert!(switch_builder); - assert!(vnode_changed); { let mut table_key = VirtualNode::from_index(62).to_be_bytes().to_vec(); table_key.extend_from_slice("a".as_bytes()); - let (switch_builder, vnode_changed) = builder - .check_table_and_vnode_change(&UserKey::for_test(TableId::from(1), &table_key)); + let switch_builder = + builder.check_switch_builder(&UserKey::for_test(TableId::from(1), &table_key)); assert!(!switch_builder); - assert!(vnode_changed); let mut table_key = VirtualNode::from_index(63).to_be_bytes().to_vec(); table_key.extend_from_slice("a".as_bytes()); - let (switch_builder, vnode_changed) = builder - .check_table_and_vnode_change(&UserKey::for_test(TableId::from(1), &table_key)); + let switch_builder = + builder.check_switch_builder(&UserKey::for_test(TableId::from(1), &table_key)); assert!(!switch_builder); - assert!(vnode_changed); let mut table_key = VirtualNode::from_index(64).to_be_bytes().to_vec(); table_key.extend_from_slice("a".as_bytes()); - let (switch_builder, vnode_changed) = builder - .check_table_and_vnode_change(&UserKey::for_test(TableId::from(1), &table_key)); + let switch_builder = + builder.check_switch_builder(&UserKey::for_test(TableId::from(1), &table_key)); assert!(switch_builder); - assert!(vnode_changed); } - let (switch_builder, vnode_changed) = - builder.check_table_and_vnode_change(&UserKey::for_test(TableId::from(2), &table_key)); + let switch_builder = + builder.check_switch_builder(&UserKey::for_test(TableId::from(2), &table_key)); assert!(switch_builder); - assert!(vnode_changed); - let (switch_builder, vnode_changed) = - builder.check_table_and_vnode_change(&UserKey::for_test(TableId::from(3), &table_key)); + let switch_builder = + builder.check_switch_builder(&UserKey::for_test(TableId::from(3), &table_key)); assert!(switch_builder); - assert!(vnode_changed); - let (switch_builder, vnode_changed) = - builder.check_table_and_vnode_change(&UserKey::for_test(TableId::from(4), &table_key)); + let switch_builder = + builder.check_switch_builder(&UserKey::for_test(TableId::from(4), &table_key)); assert!(switch_builder); - assert!(vnode_changed); - let (switch_builder, vnode_changed) = - builder.check_table_and_vnode_change(&UserKey::for_test(TableId::from(5), &table_key)); + let switch_builder = + builder.check_switch_builder(&UserKey::for_test(TableId::from(5), &table_key)); assert!(!switch_builder); - assert!(!vnode_changed); } } From 7954da3e7ea628d4d27d840287c31bd5148d2c9a Mon Sep 17 00:00:00 2001 From: stonepage <40830455+st1page@users.noreply.github.com> Date: Fri, 5 Jan 2024 18:14:55 +0800 Subject: [PATCH 04/20] feat: support multiple temporal filter with or (#14382) --- .../tests/testdata/input/temporal_filter.yaml | 6 + .../testdata/output/temporal_filter.yaml | 123 +++++++++++------- .../src/optimizer/logical_optimization.rs | 6 +- src/frontend/src/optimizer/rule/mod.rs | 2 + .../stream/filter_with_now_to_join_rule.rs | 14 -- src/frontend/src/optimizer/rule/stream/mod.rs | 1 + .../rule/stream/split_now_and_rule.rs | 84 ++++++++++++ 7 files changed, 177 insertions(+), 59 deletions(-) create mode 100644 src/frontend/src/optimizer/rule/stream/split_now_and_rule.rs diff --git a/src/frontend/planner_test/tests/testdata/input/temporal_filter.yaml b/src/frontend/planner_test/tests/testdata/input/temporal_filter.yaml index 8df9d78869f04..6bd62c1ce4d61 100644 --- a/src/frontend/planner_test/tests/testdata/input/temporal_filter.yaml +++ b/src/frontend/planner_test/tests/testdata/input/temporal_filter.yaml @@ -109,5 +109,11 @@ sql: | create table t1 (ts timestamp with time zone); select * from t1 where ts + interval '1 hour' > now() or ts > ' 2023-12-18 00:00:00+00:00' or ts is null; + expected_outputs: + - stream_plan +- name: Many Temporal filter with or predicate + sql: | + create table t (t timestamp with time zone, a int); + select * from t where (t > NOW() - INTERVAL '1 hour' OR t is NULL OR a < 1) AND (t < NOW() - INTERVAL '1 hour' OR a > 1); expected_outputs: - stream_plan \ No newline at end of file diff --git a/src/frontend/planner_test/tests/testdata/output/temporal_filter.yaml b/src/frontend/planner_test/tests/testdata/output/temporal_filter.yaml index a8799aba8a4f4..1f5934ce76378 100644 --- a/src/frontend/planner_test/tests/testdata/output/temporal_filter.yaml +++ b/src/frontend/planner_test/tests/testdata/output/temporal_filter.yaml @@ -218,56 +218,43 @@ create table t1 (ts timestamp with time zone); select * from t1 where ts < now() - interval '1 hour' and ts >= now() - interval '2 hour'; stream_plan: |- - StreamMaterialize { columns: [ts, t1._row_id(hidden)], stream_key: [t1._row_id], pk_columns: [t1._row_id], pk_conflict: NoCheck } - └─StreamDynamicFilter { predicate: (t1.ts < $expr2), output: [t1.ts, t1._row_id], condition_always_relax: true } - ├─StreamDynamicFilter { predicate: (t1.ts >= $expr1), output_watermarks: [t1.ts], output: [t1.ts, t1._row_id], cleaned_by_watermark: true } + StreamMaterialize { columns: [ts, t1._row_id(hidden)], stream_key: [t1._row_id], pk_columns: [t1._row_id], pk_conflict: NoCheck, watermark_columns: [ts] } + └─StreamDynamicFilter { predicate: (t1.ts >= $expr2), output_watermarks: [t1.ts], output: [t1.ts, t1._row_id], cleaned_by_watermark: true } + ├─StreamDynamicFilter { predicate: (t1.ts < $expr1), output: [t1.ts, t1._row_id], condition_always_relax: true } │ ├─StreamTableScan { table: t1, columns: [t1.ts, t1._row_id], pk: [t1._row_id], dist: UpstreamHashShard(t1._row_id) } │ └─StreamExchange { dist: Broadcast } - │ └─StreamProject { exprs: [SubtractWithTimeZone(now, '02:00:00':Interval, 'UTC':Varchar) as $expr1], output_watermarks: [$expr1] } + │ └─StreamProject { exprs: [SubtractWithTimeZone(now, '01:00:00':Interval, 'UTC':Varchar) as $expr1], output_watermarks: [$expr1] } │ └─StreamNow { output: [now] } └─StreamExchange { dist: Broadcast } - └─StreamProject { exprs: [SubtractWithTimeZone(now, '01:00:00':Interval, 'UTC':Varchar) as $expr2], output_watermarks: [$expr2] } + └─StreamProject { exprs: [SubtractWithTimeZone(now, '02:00:00':Interval, 'UTC':Varchar) as $expr2], output_watermarks: [$expr2] } └─StreamNow { output: [now] } stream_dist_plan: |+ Fragment 0 - StreamMaterialize { columns: [ts, t1._row_id(hidden)], stream_key: [t1._row_id], pk_columns: [t1._row_id], pk_conflict: NoCheck } + StreamMaterialize { columns: [ts, t1._row_id(hidden)], stream_key: [t1._row_id], pk_columns: [t1._row_id], pk_conflict: NoCheck, watermark_columns: [ts] } ├── materialized table: 4294967294 - └── StreamDynamicFilter { predicate: (t1.ts < $expr2), output: [t1.ts, t1._row_id], condition_always_relax: true } + └── StreamDynamicFilter { predicate: (t1.ts >= $expr2), output_watermarks: [t1.ts], output: [t1.ts, t1._row_id], cleaned_by_watermark: true } ├── left table: 0 ├── right table: 1 - ├── StreamDynamicFilter { predicate: (t1.ts >= $expr1), output_watermarks: [t1.ts], output: [t1.ts, t1._row_id], cleaned_by_watermark: true } - │ ├── left table: 2 - │ ├── right table: 3 - │ ├── StreamTableScan { table: t1, columns: [t1.ts, t1._row_id], pk: [t1._row_id], dist: UpstreamHashShard(t1._row_id) } - │ │ ├── state table: 4 + ├── StreamDynamicFilter { predicate: (t1.ts < $expr1), output: [t1.ts, t1._row_id], condition_always_relax: true } { left table: 2, right table: 3 } + │ ├── StreamTableScan { table: t1, columns: [t1.ts, t1._row_id], pk: [t1._row_id], dist: UpstreamHashShard(t1._row_id) } { state table: 4 } │ │ ├── Upstream │ │ └── BatchPlanNode │ └── StreamExchange Broadcast from 1 └── StreamExchange Broadcast from 2 Fragment 1 - StreamProject { exprs: [SubtractWithTimeZone(now, '02:00:00':Interval, 'UTC':Varchar) as $expr1], output_watermarks: [$expr1] } + StreamProject { exprs: [SubtractWithTimeZone(now, '01:00:00':Interval, 'UTC':Varchar) as $expr1], output_watermarks: [$expr1] } └── StreamNow { output: [now] } { state table: 5 } Fragment 2 - StreamProject { exprs: [SubtractWithTimeZone(now, '01:00:00':Interval, 'UTC':Varchar) as $expr2], output_watermarks: [$expr2] } + StreamProject { exprs: [SubtractWithTimeZone(now, '02:00:00':Interval, 'UTC':Varchar) as $expr2], output_watermarks: [$expr2] } └── StreamNow { output: [now] } { state table: 6 } - Table 0 - ├── columns: [ t1_ts, t1__row_id ] - ├── primary key: [ $0 ASC, $1 ASC ] - ├── value indices: [ 0, 1 ] - ├── distribution key: [ 1 ] - └── read pk prefix len hint: 1 + Table 0 { columns: [ t1_ts, t1__row_id ], primary key: [ $0 ASC, $1 ASC ], value indices: [ 0, 1 ], distribution key: [ 1 ], read pk prefix len hint: 1 } Table 1 { columns: [ $expr2 ], primary key: [], value indices: [ 0 ], distribution key: [], read pk prefix len hint: 0 } - Table 2 - ├── columns: [ t1_ts, t1__row_id ] - ├── primary key: [ $0 ASC, $1 ASC ] - ├── value indices: [ 0, 1 ] - ├── distribution key: [ 1 ] - └── read pk prefix len hint: 1 + Table 2 { columns: [ t1_ts, t1__row_id ], primary key: [ $0 ASC, $1 ASC ], value indices: [ 0, 1 ], distribution key: [ 1 ], read pk prefix len hint: 1 } Table 3 { columns: [ $expr1 ], primary key: [], value indices: [ 0 ], distribution key: [], read pk prefix len hint: 0 } @@ -283,12 +270,7 @@ Table 6 { columns: [ now ], primary key: [], value indices: [ 0 ], distribution key: [], read pk prefix len hint: 0 } - Table 4294967294 - ├── columns: [ ts, t1._row_id ] - ├── primary key: [ $1 ASC ] - ├── value indices: [ 0, 1 ] - ├── distribution key: [ 1 ] - └── read pk prefix len hint: 1 + Table 4294967294 { columns: [ ts, t1._row_id ], primary key: [ $1 ASC ], value indices: [ 0, 1 ], distribution key: [ 1 ], read pk prefix len hint: 1 } - name: Temporal filter in on clause for inner join's left side sql: | @@ -300,14 +282,14 @@ └─StreamExchange { dist: HashShard(t1.a, t1._row_id, t2._row_id) } └─StreamHashJoin { type: Inner, predicate: t1.a = t2.b, output: [t1.a, t1.ta, t2.b, t2.tb, t1._row_id, t2._row_id] } ├─StreamExchange { dist: HashShard(t1.a) } - │ └─StreamDynamicFilter { predicate: (t1.ta < $expr2), output: [t1.a, t1.ta, t1._row_id], condition_always_relax: true } - │ ├─StreamDynamicFilter { predicate: (t1.ta >= $expr1), output_watermarks: [t1.ta], output: [t1.a, t1.ta, t1._row_id], cleaned_by_watermark: true } + │ └─StreamDynamicFilter { predicate: (t1.ta >= $expr2), output_watermarks: [t1.ta], output: [t1.a, t1.ta, t1._row_id], cleaned_by_watermark: true } + │ ├─StreamDynamicFilter { predicate: (t1.ta < $expr1), output: [t1.a, t1.ta, t1._row_id], condition_always_relax: true } │ │ ├─StreamTableScan { table: t1, columns: [t1.a, t1.ta, t1._row_id], pk: [t1._row_id], dist: UpstreamHashShard(t1._row_id) } │ │ └─StreamExchange { dist: Broadcast } - │ │ └─StreamProject { exprs: [SubtractWithTimeZone(now, '02:00:00':Interval, 'UTC':Varchar) as $expr1], output_watermarks: [$expr1] } + │ │ └─StreamProject { exprs: [SubtractWithTimeZone(now, '01:00:00':Interval, 'UTC':Varchar) as $expr1], output_watermarks: [$expr1] } │ │ └─StreamNow { output: [now] } │ └─StreamExchange { dist: Broadcast } - │ └─StreamProject { exprs: [SubtractWithTimeZone(now, '01:00:00':Interval, 'UTC':Varchar) as $expr2], output_watermarks: [$expr2] } + │ └─StreamProject { exprs: [SubtractWithTimeZone(now, '02:00:00':Interval, 'UTC':Varchar) as $expr2], output_watermarks: [$expr2] } │ └─StreamNow { output: [now] } └─StreamExchange { dist: HashShard(t2.b) } └─StreamTableScan { table: t2, columns: [t2.b, t2.tb, t2._row_id], pk: [t2._row_id], dist: UpstreamHashShard(t2._row_id) } @@ -331,14 +313,14 @@ ├─StreamExchange { dist: HashShard(t2.b) } │ └─StreamTableScan { table: t2, columns: [t2.b, t2.tb, t2._row_id], pk: [t2._row_id], dist: UpstreamHashShard(t2._row_id) } └─StreamExchange { dist: HashShard(t1.a) } - └─StreamDynamicFilter { predicate: (t1.ta < $expr2), output: [t1.a, t1.ta, t1._row_id], condition_always_relax: true } - ├─StreamDynamicFilter { predicate: (t1.ta >= $expr1), output_watermarks: [t1.ta], output: [t1.a, t1.ta, t1._row_id], cleaned_by_watermark: true } + └─StreamDynamicFilter { predicate: (t1.ta >= $expr2), output_watermarks: [t1.ta], output: [t1.a, t1.ta, t1._row_id], cleaned_by_watermark: true } + ├─StreamDynamicFilter { predicate: (t1.ta < $expr1), output: [t1.a, t1.ta, t1._row_id], condition_always_relax: true } │ ├─StreamTableScan { table: t1, columns: [t1.a, t1.ta, t1._row_id], pk: [t1._row_id], dist: UpstreamHashShard(t1._row_id) } │ └─StreamExchange { dist: Broadcast } - │ └─StreamProject { exprs: [SubtractWithTimeZone(now, '02:00:00':Interval, 'UTC':Varchar) as $expr1], output_watermarks: [$expr1] } + │ └─StreamProject { exprs: [SubtractWithTimeZone(now, '01:00:00':Interval, 'UTC':Varchar) as $expr1], output_watermarks: [$expr1] } │ └─StreamNow { output: [now] } └─StreamExchange { dist: Broadcast } - └─StreamProject { exprs: [SubtractWithTimeZone(now, '01:00:00':Interval, 'UTC':Varchar) as $expr2], output_watermarks: [$expr2] } + └─StreamProject { exprs: [SubtractWithTimeZone(now, '02:00:00':Interval, 'UTC':Varchar) as $expr2], output_watermarks: [$expr2] } └─StreamNow { output: [now] } - name: Temporal filter in on clause for full join's left side sql: | @@ -360,14 +342,14 @@ ├─StreamExchange { dist: HashShard(t1.a) } │ └─StreamTableScan { table: t1, columns: [t1.a, t1.ta, t1._row_id], pk: [t1._row_id], dist: UpstreamHashShard(t1._row_id) } └─StreamExchange { dist: HashShard(t2.b) } - └─StreamDynamicFilter { predicate: (t2.tb < $expr2), output: [t2.b, t2.tb, t2._row_id], condition_always_relax: true } - ├─StreamDynamicFilter { predicate: (t2.tb >= $expr1), output_watermarks: [t2.tb], output: [t2.b, t2.tb, t2._row_id], cleaned_by_watermark: true } + └─StreamDynamicFilter { predicate: (t2.tb >= $expr2), output_watermarks: [t2.tb], output: [t2.b, t2.tb, t2._row_id], cleaned_by_watermark: true } + ├─StreamDynamicFilter { predicate: (t2.tb < $expr1), output: [t2.b, t2.tb, t2._row_id], condition_always_relax: true } │ ├─StreamTableScan { table: t2, columns: [t2.b, t2.tb, t2._row_id], pk: [t2._row_id], dist: UpstreamHashShard(t2._row_id) } │ └─StreamExchange { dist: Broadcast } - │ └─StreamProject { exprs: [SubtractWithTimeZone(now, '02:00:00':Interval, 'UTC':Varchar) as $expr1], output_watermarks: [$expr1] } + │ └─StreamProject { exprs: [SubtractWithTimeZone(now, '01:00:00':Interval, 'UTC':Varchar) as $expr1], output_watermarks: [$expr1] } │ └─StreamNow { output: [now] } └─StreamExchange { dist: Broadcast } - └─StreamProject { exprs: [SubtractWithTimeZone(now, '01:00:00':Interval, 'UTC':Varchar) as $expr2], output_watermarks: [$expr2] } + └─StreamProject { exprs: [SubtractWithTimeZone(now, '02:00:00':Interval, 'UTC':Varchar) as $expr2], output_watermarks: [$expr2] } └─StreamNow { output: [now] } - name: Temporal filter in on clause for right join's right side sql: | @@ -462,3 +444,56 @@ └─StreamShare { id: 2 } └─StreamFilter { predicate: (((Not((t1.ts > '2023-12-18 00:00:00+00:00':Timestamptz)) AND Not(IsNull(t1.ts))) OR (t1.ts > '2023-12-18 00:00:00+00:00':Timestamptz)) OR IsNull(t1.ts)) } └─StreamTableScan { table: t1, columns: [t1.ts, t1._row_id], pk: [t1._row_id], dist: UpstreamHashShard(t1._row_id) } +- name: Many Temporal filter with or predicate + sql: | + create table t (t timestamp with time zone, a int); + select * from t where (t > NOW() - INTERVAL '1 hour' OR t is NULL OR a < 1) AND (t < NOW() - INTERVAL '1 hour' OR a > 1); + stream_plan: |- + StreamMaterialize { columns: [t, a, $src(hidden), t._row_id(hidden), $src#1(hidden)], stream_key: [t._row_id, $src, $src#1], pk_columns: [t._row_id, $src, $src#1], pk_conflict: NoCheck } + └─StreamUnion { all: true } + ├─StreamExchange { dist: HashShard(t._row_id, $src, 0:Int32) } + │ └─StreamProject { exprs: [t.t, t.a, $src, t._row_id, 0:Int32] } + │ └─StreamDynamicFilter { predicate: (t.t < $expr2), output: [t.t, t.a, t._row_id, $src], condition_always_relax: true } + │ ├─StreamFilter { predicate: Not((t.a > 1:Int32)) } + │ │ └─StreamShare { id: 13 } + │ │ └─StreamUnion { all: true } + │ │ ├─StreamExchange { dist: HashShard(t._row_id, 0:Int32) } + │ │ │ └─StreamProject { exprs: [t.t, t.a, t._row_id, 0:Int32], output_watermarks: [t.t] } + │ │ │ └─StreamDynamicFilter { predicate: (t.t > $expr1), output_watermarks: [t.t], output: [t.t, t.a, t._row_id], cleaned_by_watermark: true } + │ │ │ ├─StreamFilter { predicate: (Not((t.a > 1:Int32)) OR (t.a > 1:Int32)) AND Not(IsNull(t.t)) AND Not((t.a < 1:Int32)) } + │ │ │ │ └─StreamShare { id: 2 } + │ │ │ │ └─StreamFilter { predicate: (Not((t.a > 1:Int32)) OR (t.a > 1:Int32)) AND (((Not(IsNull(t.t)) AND Not((t.a < 1:Int32))) OR IsNull(t.t)) OR (t.a < 1:Int32)) } + │ │ │ │ └─StreamTableScan { table: t, columns: [t.t, t.a, t._row_id], pk: [t._row_id], dist: UpstreamHashShard(t._row_id) } + │ │ │ └─StreamExchange { dist: Broadcast } + │ │ │ └─StreamProject { exprs: [SubtractWithTimeZone(now, '01:00:00':Interval, 'UTC':Varchar) as $expr1], output_watermarks: [$expr1] } + │ │ │ └─StreamNow { output: [now] } + │ │ └─StreamExchange { dist: HashShard(t._row_id, 1:Int32) } + │ │ └─StreamProject { exprs: [t.t, t.a, t._row_id, 1:Int32] } + │ │ └─StreamFilter { predicate: (Not((t.a > 1:Int32)) OR (t.a > 1:Int32)) AND (IsNull(t.t) OR (t.a < 1:Int32)) } + │ │ └─StreamShare { id: 2 } + │ │ └─StreamFilter { predicate: (Not((t.a > 1:Int32)) OR (t.a > 1:Int32)) AND (((Not(IsNull(t.t)) AND Not((t.a < 1:Int32))) OR IsNull(t.t)) OR (t.a < 1:Int32)) } + │ │ └─StreamTableScan { table: t, columns: [t.t, t.a, t._row_id], pk: [t._row_id], dist: UpstreamHashShard(t._row_id) } + │ └─StreamExchange { dist: Broadcast } + │ └─StreamProject { exprs: [SubtractWithTimeZone(now, '01:00:00':Interval, 'UTC':Varchar) as $expr2], output_watermarks: [$expr2] } + │ └─StreamNow { output: [now] } + └─StreamExchange { dist: HashShard(t._row_id, $src, 1:Int32) } + └─StreamProject { exprs: [t.t, t.a, $src, t._row_id, 1:Int32] } + └─StreamFilter { predicate: (t.a > 1:Int32) } + └─StreamShare { id: 13 } + └─StreamUnion { all: true } + ├─StreamExchange { dist: HashShard(t._row_id, 0:Int32) } + │ └─StreamProject { exprs: [t.t, t.a, t._row_id, 0:Int32], output_watermarks: [t.t] } + │ └─StreamDynamicFilter { predicate: (t.t > $expr1), output_watermarks: [t.t], output: [t.t, t.a, t._row_id], cleaned_by_watermark: true } + │ ├─StreamFilter { predicate: (Not((t.a > 1:Int32)) OR (t.a > 1:Int32)) AND Not(IsNull(t.t)) AND Not((t.a < 1:Int32)) } + │ │ └─StreamShare { id: 2 } + │ │ └─StreamFilter { predicate: (Not((t.a > 1:Int32)) OR (t.a > 1:Int32)) AND (((Not(IsNull(t.t)) AND Not((t.a < 1:Int32))) OR IsNull(t.t)) OR (t.a < 1:Int32)) } + │ │ └─StreamTableScan { table: t, columns: [t.t, t.a, t._row_id], pk: [t._row_id], dist: UpstreamHashShard(t._row_id) } + │ └─StreamExchange { dist: Broadcast } + │ └─StreamProject { exprs: [SubtractWithTimeZone(now, '01:00:00':Interval, 'UTC':Varchar) as $expr1], output_watermarks: [$expr1] } + │ └─StreamNow { output: [now] } + └─StreamExchange { dist: HashShard(t._row_id, 1:Int32) } + └─StreamProject { exprs: [t.t, t.a, t._row_id, 1:Int32] } + └─StreamFilter { predicate: (Not((t.a > 1:Int32)) OR (t.a > 1:Int32)) AND (IsNull(t.t) OR (t.a < 1:Int32)) } + └─StreamShare { id: 2 } + └─StreamFilter { predicate: (Not((t.a > 1:Int32)) OR (t.a > 1:Int32)) AND (((Not(IsNull(t.t)) AND Not((t.a < 1:Int32))) OR IsNull(t.t)) OR (t.a < 1:Int32)) } + └─StreamTableScan { table: t, columns: [t.t, t.a, t._row_id], pk: [t._row_id], dist: UpstreamHashShard(t._row_id) } diff --git a/src/frontend/src/optimizer/logical_optimization.rs b/src/frontend/src/optimizer/logical_optimization.rs index b513d4904669c..db5dc8ceca7d2 100644 --- a/src/frontend/src/optimizer/logical_optimization.rs +++ b/src/frontend/src/optimizer/logical_optimization.rs @@ -240,7 +240,11 @@ static BUSHY_TREE_JOIN_ORDERING: LazyLock = LazyLock::new(|| static FILTER_WITH_NOW_TO_JOIN: LazyLock = LazyLock::new(|| { OptimizationStage::new( "Push down filter with now into a left semijoin", - vec![SplitNowOrRule::create(), FilterWithNowToJoinRule::create()], + vec![ + SplitNowAndRule::create(), + SplitNowOrRule::create(), + FilterWithNowToJoinRule::create(), + ], ApplyOrder::TopDown, ) }); diff --git a/src/frontend/src/optimizer/rule/mod.rs b/src/frontend/src/optimizer/rule/mod.rs index d59bde580b1d1..acde2f7b72eb6 100644 --- a/src/frontend/src/optimizer/rule/mod.rs +++ b/src/frontend/src/optimizer/rule/mod.rs @@ -90,6 +90,7 @@ pub use top_n_on_index_rule::*; mod stream; pub use stream::bushy_tree_join_ordering_rule::*; pub use stream::filter_with_now_to_join_rule::*; +pub use stream::split_now_and_rule::*; pub use stream::split_now_or_rule::*; pub use stream::stream_project_merge_rule::*; mod trivial_project_to_values_rule; @@ -190,6 +191,7 @@ macro_rules! for_all_rules { , { AggProjectMergeRule } , { UnionMergeRule } , { DagToTreeRule } + , { SplitNowAndRule } , { SplitNowOrRule } , { FilterWithNowToJoinRule } , { TopNOnIndexRule } diff --git a/src/frontend/src/optimizer/rule/stream/filter_with_now_to_join_rule.rs b/src/frontend/src/optimizer/rule/stream/filter_with_now_to_join_rule.rs index b5f1d46d51743..498696589c81b 100644 --- a/src/frontend/src/optimizer/rule/stream/filter_with_now_to_join_rule.rs +++ b/src/frontend/src/optimizer/rule/stream/filter_with_now_to_join_rule.rs @@ -13,7 +13,6 @@ // limitations under the License. use risingwave_common::types::DataType; -use risingwave_pb::expr::expr_node::Type; use risingwave_pb::plan_common::JoinType; use crate::expr::{ @@ -55,11 +54,6 @@ impl Rule for FilterWithNowToJoinRule { } }); - // We want to put `input_expr >/>= now_expr` before `input_expr u8 { - match cmp { - Type::GreaterThan | Type::GreaterThanOrEqual => 0, - Type::LessThan | Type::LessThanOrEqual => 1, - _ => 2, - } -} - struct NowAsInputRef { index: usize, } diff --git a/src/frontend/src/optimizer/rule/stream/mod.rs b/src/frontend/src/optimizer/rule/stream/mod.rs index 3b088d93d64ec..cc86298e766e8 100644 --- a/src/frontend/src/optimizer/rule/stream/mod.rs +++ b/src/frontend/src/optimizer/rule/stream/mod.rs @@ -14,5 +14,6 @@ pub(crate) mod bushy_tree_join_ordering_rule; pub(crate) mod filter_with_now_to_join_rule; +pub(crate) mod split_now_and_rule; pub(crate) mod split_now_or_rule; pub(crate) mod stream_project_merge_rule; diff --git a/src/frontend/src/optimizer/rule/stream/split_now_and_rule.rs b/src/frontend/src/optimizer/rule/stream/split_now_and_rule.rs new file mode 100644 index 0000000000000..f82e4a8fdd304 --- /dev/null +++ b/src/frontend/src/optimizer/rule/stream/split_now_and_rule.rs @@ -0,0 +1,84 @@ +// Copyright 2024 RisingWave Labs +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use crate::optimizer::plan_node::{LogicalFilter, PlanTreeNodeUnary}; +use crate::optimizer::rule::{BoxedRule, Rule}; +use crate::optimizer::PlanRef; +use crate::utils::Condition; + +/// Split `LogicalFilter` with many AND conjunctions with now into multiple `LogicalFilter`, prepared for `SplitNowOrRule` +/// +/// Before: +/// ```text +/// `LogicalFilter` +/// (now() or c11 or c12 ..) and (now() or c21 or c22 ...) and .. and other exprs +/// | +/// Input +/// ``` +/// +/// After: +/// ```text +/// `LogicalFilter`(now() or c11 or c12 ..) +/// | +/// `LogicalFilter`(now() or c21 or c22 ...) +/// | +/// ...... +/// | +/// `LogicalFilter` other exprs +/// | +/// Input +/// ``` +pub struct SplitNowAndRule {} +impl Rule for SplitNowAndRule { + fn apply(&self, plan: PlanRef) -> Option { + let filter: &LogicalFilter = plan.as_logical_filter()?; + let input = filter.input(); + if filter.predicate().conjunctions.len() == 1 { + return None; + } + + if filter + .predicate() + .conjunctions + .iter() + .all(|e| e.count_nows() == 0) + { + return None; + } + + let [with_now, others] = + filter + .predicate() + .clone() + .group_by::<_, 2>(|e| if e.count_nows() > 0 { 0 } else { 1 }); + + let mut plan = LogicalFilter::create(input, others); + for e in with_now { + plan = LogicalFilter::new( + plan, + Condition { + conjunctions: vec![e], + }, + ) + .into(); + } + Some(plan) + } +} + +impl SplitNowAndRule { + pub fn create() -> BoxedRule { + Box::new(SplitNowAndRule {}) + } +} From a01a30dece65f0391210c1a60e6ef17d4270581a Mon Sep 17 00:00:00 2001 From: Yingjun Wu Date: Sat, 6 Jan 2024 21:47:24 -0800 Subject: [PATCH 05/20] chore: Update README.md (#14371) Co-authored-by: hengm3467 <100685635+hengm3467@users.noreply.github.com> --- README.md | 22 ++++++++++++++++++---- 1 file changed, 18 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index 13d90a85ed7c7..f9f4265593695 100644 --- a/README.md +++ b/README.md @@ -60,7 +60,8 @@ -RisingWave is a distributed SQL streaming database that enables simple, efficient, and reliable processing of streaming data. +RisingWave is a distributed SQL streaming database engineered to provide the simplest and most cost-efficient approach for processing and managing streaming data with utmost reliability. + ![RisingWave](https://github.com/risingwavelabs/risingwave-docs/blob/main/docs/images/new_archi_grey.png) @@ -98,7 +99,7 @@ For **Kubernetes deployments**, please refer to [Kubernetes with Helm](https://d ## Why RisingWave for stream processing? -RisingWave specializes in providing **incrementally updated, consistent materialized views** — a persistent data structure that represents the results of stream processing. RisingWave significantly reduces the complexity of building stream processing applications by allowing developers to express intricate stream processing logic through cascaded materialized views. Furthermore, it allows users to persist data directly within the system, eliminating the need to deliver results to external databases for storage and query serving. +RisingWave provides users with a comprehensive set of frequently used stream processing features, including exactly-once consistency, [time window functions](https://docs.risingwave.com/docs/current/sql-function-time-window/), [watermarks](https://docs.risingwave.com/docs/current/watermarks/), and more. It specializes in providing **incrementally updated, consistent materialized views** — a persistent data structure that represents the results of stream processing. RisingWave significantly reduces the complexity of building stream processing applications by allowing developers to express intricate stream processing logic through cascaded materialized views. Furthermore, it allows users to persist data directly within the system, eliminating the need to deliver results to external databases for storage and query serving. ![Real-time Data Pipelines without or with RisingWave](https://github.com/risingwavelabs/risingwave/assets/100685635/414afbb7-5187-410f-9ba4-9a640c8c6306) @@ -122,12 +123,25 @@ Compared to existing stream processing systems like [Apache Flink](https://flink * **Instant failure recovery** * RisingWave's state management mechanism also allows it to recover from failure in seconds, not minutes or hours. +### RisingWave as a database +RisingWave is fundamentally a database that **extends beyond basic streaming data processing capabilities**. It excels in **the effective management of streaming data**, making it a trusted choice for data persistence and powering online applications. RisingWave offers an extensive range of database capabilities, which include: + +* High availability +* Serving highly concurrent queries +* Role-based access control (RBAC) +* Integration with data modeling tools, such as [dbt](https://docs.risingwave.com/docs/current/use-dbt/) +* Integration with database management tools, such as [Dbeaver](https://docs.risingwave.com/docs/current/dbeaver-integration/) +* Integration with BI tools, such as [Grafana](https://docs.risingwave.com/docs/current/grafana-integration/) +* Schema change +* Processing of semi-structured data + + ## RisingWave's limitations RisingWave isn’t a panacea for all data engineering hurdles. It has its own set of limitations: * **No programmable interfaces** - * RisingWave does not provide low-level APIs in languages like Java and Scala, and does not allow users to manage internal states manually (unless you want to hack!). For coding in Java, Scala, and other languages, please consider using RisingWave's User-Defined Functions (UDF). + * RisingWave does not provide low-level APIs in languages like Java and Scala, and does not allow users to manage internal states manually (unless you want to hack!). _For coding in Java, Python, and other languages, please consider using RisingWave's [User-Defined Functions (UDF)](https://docs.risingwave.com/docs/current/user-defined-functions/)_. * **No support for transaction processing** - * RisingWave isn’t cut out for transactional workloads, thus it’s not a viable substitute for operational databases dedicated to transaction processing. However, it supports read-only transactions, ensuring data freshness and consistency. It also comprehends the transactional semantics of upstream database Change Data Capture (CDC). + * RisingWave isn’t cut out for transactional workloads, thus it’s not a viable substitute for operational databases dedicated to transaction processing. _However, it supports [read-only transactions](https://docs.risingwave.com/docs/current/transactions/#read-only-transactions), ensuring data freshness and consistency. It also comprehends the transactional semantics of upstream database [Change Data Capture (CDC)](https://docs.risingwave.com/docs/current/transactions/#transactions-within-a-cdc-table)_. * **Not tailored for ad-hoc analytical queries** * RisingWave's row store design is tailored for optimal stream processing performance rather than interactive analytical workloads. Hence, it's not a suitable replacement for OLAP databases. Yet, a reliable integration with many OLAP databases exists, and a collaborative use of RisingWave and OLAP databases is a common practice among many users. From 33b81af40e21d833966125a11bf804e9f477995c Mon Sep 17 00:00:00 2001 From: Yingjun Wu Date: Sun, 7 Jan 2024 19:07:37 -0800 Subject: [PATCH 06/20] chore: Update README.md (#14399) --- README.md | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/README.md b/README.md index f9f4265593695..11793ce064647 100644 --- a/README.md +++ b/README.md @@ -60,7 +60,7 @@ -RisingWave is a distributed SQL streaming database engineered to provide the simplest and most cost-efficient approach for processing and managing streaming data with utmost reliability. +RisingWave is a distributed SQL streaming database engineered to provide the simplest and most cost-efficient approach for processing and managing streaming data with utmost reliability. ![RisingWave](https://github.com/risingwavelabs/risingwave-docs/blob/main/docs/images/new_archi_grey.png) @@ -142,8 +142,6 @@ RisingWave isn’t a panacea for all data engineering hurdles. It has its own se * RisingWave does not provide low-level APIs in languages like Java and Scala, and does not allow users to manage internal states manually (unless you want to hack!). _For coding in Java, Python, and other languages, please consider using RisingWave's [User-Defined Functions (UDF)](https://docs.risingwave.com/docs/current/user-defined-functions/)_. * **No support for transaction processing** * RisingWave isn’t cut out for transactional workloads, thus it’s not a viable substitute for operational databases dedicated to transaction processing. _However, it supports [read-only transactions](https://docs.risingwave.com/docs/current/transactions/#read-only-transactions), ensuring data freshness and consistency. It also comprehends the transactional semantics of upstream database [Change Data Capture (CDC)](https://docs.risingwave.com/docs/current/transactions/#transactions-within-a-cdc-table)_. -* **Not tailored for ad-hoc analytical queries** - * RisingWave's row store design is tailored for optimal stream processing performance rather than interactive analytical workloads. Hence, it's not a suitable replacement for OLAP databases. Yet, a reliable integration with many OLAP databases exists, and a collaborative use of RisingWave and OLAP databases is a common practice among many users. ## In-production use cases From dc32ac07e3928bb3cf9815d23cce2f786abf76b5 Mon Sep 17 00:00:00 2001 From: Kexiang Wang Date: Sun, 7 Jan 2024 22:21:24 -0500 Subject: [PATCH 07/20] feat(catalog): add relpersistence in pg_class (#14400) --- .../planner_test/tests/testdata/output/subquery.yaml | 2 +- .../src/catalog/system_catalog/pg_catalog/pg_class.rs | 8 ++++++-- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/src/frontend/planner_test/tests/testdata/output/subquery.yaml b/src/frontend/planner_test/tests/testdata/output/subquery.yaml index 076a3ceb8fac1..554525b391efd 100644 --- a/src/frontend/planner_test/tests/testdata/output/subquery.yaml +++ b/src/frontend/planner_test/tests/testdata/output/subquery.yaml @@ -238,7 +238,7 @@ ├─LogicalFilter { predicate: In($expr1, 'r':Varchar, 'p':Varchar, 'v':Varchar, 'm':Varchar, 'S':Varchar, 'f':Varchar, '':Varchar) AND (rw_schemas.name <> 'pg_catalog':Varchar) AND Not(RegexpEq(rw_schemas.name, '^pg_toast':Varchar)) AND (rw_schemas.name <> 'information_schema':Varchar) } │ └─LogicalJoin { type: LeftOuter, on: (rw_schemas.id = rw_tables.schema_id), output: all } │ ├─LogicalShare { id: 16 } - │ │ └─LogicalProject { exprs: [rw_tables.id, rw_tables.name, rw_tables.schema_id, rw_tables.owner, Case(('table':Varchar = 'table':Varchar), 'r':Varchar, ('table':Varchar = 'system table':Varchar), 'r':Varchar, ('table':Varchar = 'index':Varchar), 'i':Varchar, ('table':Varchar = 'view':Varchar), 'v':Varchar, ('table':Varchar = 'materialized view':Varchar), 'm':Varchar) as $expr1, 0:Int32, 0:Int32, Array as $expr2] } + │ │ └─LogicalProject { exprs: [rw_tables.id, rw_tables.name, rw_tables.schema_id, rw_tables.owner, 'p':Varchar, Case(('table':Varchar = 'table':Varchar), 'r':Varchar, ('table':Varchar = 'system table':Varchar), 'r':Varchar, ('table':Varchar = 'index':Varchar), 'i':Varchar, ('table':Varchar = 'view':Varchar), 'v':Varchar, ('table':Varchar = 'materialized view':Varchar), 'm':Varchar) as $expr1, 0:Int32, 0:Int32, Array as $expr2] } │ │ └─LogicalShare { id: 14 } │ │ └─LogicalUnion { all: true } │ │ ├─LogicalUnion { all: true } diff --git a/src/frontend/src/catalog/system_catalog/pg_catalog/pg_class.rs b/src/frontend/src/catalog/system_catalog/pg_catalog/pg_class.rs index e09543dd81c6d..6ff7abac7187a 100644 --- a/src/frontend/src/catalog/system_catalog/pg_catalog/pg_class.rs +++ b/src/frontend/src/catalog/system_catalog/pg_catalog/pg_class.rs @@ -25,6 +25,8 @@ pub static PG_CLASS_COLUMNS: LazyLock>> = LazyLo (DataType::Varchar, "relname"), (DataType::Int32, "relnamespace"), (DataType::Int32, "relowner"), + (DataType::Varchar, "relpersistence"), /* p = permanent table, u = unlogged table, t = + * temporary table */ (DataType::Varchar, "relkind"), /* r = ordinary table, i = index, S = sequence, t = * TOAST table, v = view, m = materialized view, c = * composite type, f = foreign table, p = partitioned @@ -38,11 +40,12 @@ pub static PG_CLASS_COLUMNS: LazyLock>> = LazyLo /// The catalog `pg_class` catalogs tables and most everything else that has columns or is otherwise /// similar to a table. Ref: [`https://www.postgresql.org/docs/current/catalog-pg-class.html`] /// todo: should we add internal tables as well? -pub static PG_CLASS: LazyLock = LazyLock::new(|| BuiltinView { +pub static PG_CLASS: LazyLock = LazyLock::new(|| { + BuiltinView { name: "pg_class", schema: PG_CATALOG_SCHEMA_NAME, columns: &PG_CLASS_COLUMNS, - sql: "SELECT id AS oid, name AS relname, schema_id AS relnamespace, owner AS relowner, \ + sql: "SELECT id AS oid, name AS relname, schema_id AS relnamespace, owner AS relowner, 'p' as relpersistence, \ CASE \ WHEN relation_type = 'table' THEN 'r' \ WHEN relation_type = 'system table' THEN 'r' \ @@ -56,4 +59,5 @@ pub static PG_CLASS: LazyLock = LazyLock::new(|| BuiltinView { FROM rw_catalog.rw_relations\ " .to_string(), +} }); From 10f7b7736c1c6f783730e66f6ab1fc0ccd6ea39d Mon Sep 17 00:00:00 2001 From: StrikeW Date: Mon, 8 Jan 2024 11:22:41 +0800 Subject: [PATCH 08/20] feat(cdc): support transaction for shared cdc source (#14375) --- .../connector/source/core/DbzCdcEngine.java | 10 +- .../source/core/DbzCdcEventConsumer.java | 132 +++++--- proto/connector_service.proto | 1 + src/batch/src/executor/source.rs | 1 + src/connector/src/parser/avro/parser.rs | 1 + .../src/parser/debezium/debezium_parser.rs | 77 ++++- src/connector/src/parser/mod.rs | 5 +- src/connector/src/parser/plain_parser.rs | 297 +++++++++++++++++- src/connector/src/parser/unified/debezium.rs | 77 +++-- src/connector/src/source/base.rs | 11 + .../src/source/cdc/source/message.rs | 3 + src/connector/src/source/cdc/source/reader.rs | 2 +- src/connector/src/source/test_source.rs | 2 +- .../plan_node/stream_cdc_table_scan.rs | 7 + .../src/executor/source/fetch_executor.rs | 1 + .../src/executor/source/fs_source_executor.rs | 1 + .../src/executor/source/source_executor.rs | 1 + 17 files changed, 551 insertions(+), 78 deletions(-) diff --git a/java/connector-node/risingwave-connector-service/src/main/java/com/risingwave/connector/source/core/DbzCdcEngine.java b/java/connector-node/risingwave-connector-service/src/main/java/com/risingwave/connector/source/core/DbzCdcEngine.java index 311d329ffeb57..61d1f6284a67f 100644 --- a/java/connector-node/risingwave-connector-service/src/main/java/com/risingwave/connector/source/core/DbzCdcEngine.java +++ b/java/connector-node/risingwave-connector-service/src/main/java/com/risingwave/connector/source/core/DbzCdcEngine.java @@ -14,7 +14,8 @@ package com.risingwave.connector.source.core; -import static io.debezium.schema.AbstractTopicNamingStrategy.TOPIC_HEARTBEAT_PREFIX; +import static io.debezium.config.CommonConnectorConfig.TOPIC_PREFIX; +import static io.debezium.schema.AbstractTopicNamingStrategy.*; import com.risingwave.connector.api.source.CdcEngine; import com.risingwave.proto.ConnectorServiceProto; @@ -36,11 +37,14 @@ public DbzCdcEngine( long sourceId, Properties config, DebeziumEngine.CompletionCallback completionCallback) { - var dbzHeartbeatPrefix = config.getProperty(TOPIC_HEARTBEAT_PREFIX.name()); + var heartbeatTopicPrefix = config.getProperty(TOPIC_HEARTBEAT_PREFIX.name()); + var topicPrefix = config.getProperty(TOPIC_PREFIX.name()); + var transactionTopic = String.format("%s.%s", topicPrefix, DEFAULT_TRANSACTION_TOPIC); var consumer = new DbzCdcEventConsumer( sourceId, - dbzHeartbeatPrefix, + heartbeatTopicPrefix, + transactionTopic, new ArrayBlockingQueue<>(DEFAULT_QUEUE_CAPACITY)); // Builds a debezium engine but not start it diff --git a/java/connector-node/risingwave-connector-service/src/main/java/com/risingwave/connector/source/core/DbzCdcEventConsumer.java b/java/connector-node/risingwave-connector-service/src/main/java/com/risingwave/connector/source/core/DbzCdcEventConsumer.java index ac46691780c39..f0880d52c8b57 100644 --- a/java/connector-node/risingwave-connector-service/src/main/java/com/risingwave/connector/source/core/DbzCdcEventConsumer.java +++ b/java/connector-node/risingwave-connector-service/src/main/java/com/risingwave/connector/source/core/DbzCdcEventConsumer.java @@ -34,6 +34,12 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +enum EventType { + HEARTBEAT, + TRANSACTION, + DATA, +} + public class DbzCdcEventConsumer implements DebeziumEngine.ChangeConsumer> { static final Logger LOG = LoggerFactory.getLogger(DbzCdcEventConsumer.class); @@ -42,14 +48,18 @@ public class DbzCdcEventConsumer private final long sourceId; private final JsonConverter converter; private final String heartbeatTopicPrefix; + private final String transactionTopic; DbzCdcEventConsumer( long sourceId, String heartbeatTopicPrefix, - BlockingQueue store) { + String transactionTopic, + BlockingQueue queue) { this.sourceId = sourceId; - this.outputChannel = store; + this.outputChannel = queue; this.heartbeatTopicPrefix = heartbeatTopicPrefix; + this.transactionTopic = transactionTopic; + LOG.info("heartbeat topic: {}, trnx topic: {}", heartbeatTopicPrefix, transactionTopic); // The default JSON converter will output the schema field in the JSON which is unnecessary // to source parser, we use a customized JSON converter to avoid outputting the `schema` @@ -64,6 +74,16 @@ public class DbzCdcEventConsumer this.converter = jsonConverter; } + private EventType getEventType(SourceRecord record) { + if (isHeartbeatEvent(record)) { + return EventType.HEARTBEAT; + } else if (isTransactionMetaEvent(record)) { + return EventType.TRANSACTION; + } else { + return EventType.DATA; + } + } + private boolean isHeartbeatEvent(SourceRecord record) { String topic = record.topic(); return topic != null @@ -71,6 +91,11 @@ private boolean isHeartbeatEvent(SourceRecord record) { && topic.startsWith(heartbeatTopicPrefix); } + private boolean isTransactionMetaEvent(SourceRecord record) { + String topic = record.topic(); + return topic != null && topic.equals(transactionTopic); + } + @Override public void handleBatch( List> events, @@ -79,10 +104,12 @@ public void handleBatch( var respBuilder = GetEventStreamResponse.newBuilder(); for (ChangeEvent event : events) { var record = event.value(); - boolean isHeartbeat = isHeartbeatEvent(record); + EventType eventType = getEventType(record); DebeziumOffset offset = new DebeziumOffset( - record.sourcePartition(), record.sourceOffset(), isHeartbeat); + record.sourcePartition(), + record.sourceOffset(), + (eventType == EventType.HEARTBEAT)); // serialize the offset to a JSON, so that kernel doesn't need to // aware its layout String offsetStr = ""; @@ -98,43 +125,68 @@ var record = event.value(); .setOffset(offsetStr) .setPartition(String.valueOf(sourceId)); - if (isHeartbeat) { - var message = msgBuilder.build(); - LOG.debug("heartbeat => {}", message.getOffset()); - respBuilder.addEvents(message); - } else { - - // Topic naming conventions - // - PG: serverName.schemaName.tableName - // - MySQL: serverName.databaseName.tableName - // We can extract the full table name from the topic - var fullTableName = record.topic().substring(record.topic().indexOf('.') + 1); - - // ignore null record - if (record.value() == null) { - committer.markProcessed(event); - continue; - } - // get upstream event time from the "source" field - var sourceStruct = ((Struct) record.value()).getStruct("source"); - long sourceTsMs = - sourceStruct == null - ? System.currentTimeMillis() - : sourceStruct.getInt64("ts_ms"); - byte[] payload = - converter.fromConnectData( - record.topic(), record.valueSchema(), record.value()); - msgBuilder - .setFullTableName(fullTableName) - .setPayload(new String(payload, StandardCharsets.UTF_8)) - .setSourceTsMs(sourceTsMs) - .build(); - var message = msgBuilder.build(); - LOG.debug("record => {}", message.getPayload()); - - respBuilder.addEvents(message); - committer.markProcessed(event); + switch (eventType) { + case HEARTBEAT: + { + var message = msgBuilder.build(); + LOG.debug("heartbeat => {}", message.getOffset()); + respBuilder.addEvents(message); + break; + } + case TRANSACTION: + { + long trxTs = ((Struct) record.value()).getInt64("ts_ms"); + byte[] payload = + converter.fromConnectData( + record.topic(), record.valueSchema(), record.value()); + var message = + msgBuilder + .setIsTransactionMeta(true) + .setPayload(new String(payload, StandardCharsets.UTF_8)) + .setSourceTsMs(trxTs) + .build(); + LOG.debug("transaction => {}", message); + respBuilder.addEvents(message); + break; + } + case DATA: + { + // Topic naming conventions + // - PG: serverName.schemaName.tableName + // - MySQL: serverName.databaseName.tableName + // We can extract the full table name from the topic + var fullTableName = + record.topic().substring(record.topic().indexOf('.') + 1); + + // ignore null record + if (record.value() == null) { + break; + } + // get upstream event time from the "source" field + var sourceStruct = ((Struct) record.value()).getStruct("source"); + long sourceTsMs = + sourceStruct == null + ? System.currentTimeMillis() + : sourceStruct.getInt64("ts_ms"); + byte[] payload = + converter.fromConnectData( + record.topic(), record.valueSchema(), record.value()); + var message = + msgBuilder + .setFullTableName(fullTableName) + .setPayload(new String(payload, StandardCharsets.UTF_8)) + .setSourceTsMs(sourceTsMs) + .build(); + LOG.debug("record => {}", message.getPayload()); + respBuilder.addEvents(message); + break; + } + default: + break; } + + // mark the event as processed + committer.markProcessed(event); } // skip empty batch diff --git a/proto/connector_service.proto b/proto/connector_service.proto index 465af0d2a55a8..49fca31d1330d 100644 --- a/proto/connector_service.proto +++ b/proto/connector_service.proto @@ -161,6 +161,7 @@ message CdcMessage { string offset = 3; string full_table_name = 4; int64 source_ts_ms = 5; + bool is_transaction_meta = 6; } enum SourceType { diff --git a/src/batch/src/executor/source.rs b/src/batch/src/executor/source.rs index a60398ef12f7b..2714d5335b906 100644 --- a/src/batch/src/executor/source.rs +++ b/src/batch/src/executor/source.rs @@ -146,6 +146,7 @@ impl SourceExecutor { self.metrics, self.source_ctrl_opts.clone(), None, + ConnectorProperties::default(), )); let stream = self .connector_source diff --git a/src/connector/src/parser/avro/parser.rs b/src/connector/src/parser/avro/parser.rs index 28e37fd0e0935..10e000a4fdab7 100644 --- a/src/connector/src/parser/avro/parser.rs +++ b/src/connector/src/parser/avro/parser.rs @@ -288,6 +288,7 @@ mod test { )?), rw_columns: Vec::default(), source_ctx: Default::default(), + transaction_meta_builder: None, }) } diff --git a/src/connector/src/parser/debezium/debezium_parser.rs b/src/connector/src/parser/debezium/debezium_parser.rs index b0a18e8c930d1..0f79677860f8d 100644 --- a/src/connector/src/parser/debezium/debezium_parser.rs +++ b/src/connector/src/parser/debezium/debezium_parser.rs @@ -109,7 +109,9 @@ impl DebeziumParser { Err(err) => { // Only try to access transaction control message if the row operation access failed // to make it a fast path. - if let Ok(transaction_control) = row_op.transaction_control() { + if let Ok(transaction_control) = + row_op.transaction_control(&self.source_ctx.connector_props) + { Ok(ParseResult::TransactionControl(transaction_control)) } else { Err(err)? @@ -151,3 +153,76 @@ impl ByteStreamSourceParser for DebeziumParser { self.parse_inner(key, payload, writer).await } } + +#[cfg(test)] +mod tests { + use std::ops::Deref; + use std::sync::Arc; + + use risingwave_common::catalog::{ColumnCatalog, ColumnDesc, ColumnId}; + + use super::*; + use crate::parser::{SourceStreamChunkBuilder, TransactionControl}; + use crate::source::{ConnectorProperties, DataType}; + + #[tokio::test] + async fn test_parse_transaction_metadata() { + let schema = vec![ + ColumnCatalog { + column_desc: ColumnDesc::named("payload", ColumnId::placeholder(), DataType::Jsonb), + is_hidden: false, + }, + ColumnCatalog::offset_column(), + ColumnCatalog::cdc_table_name_column(), + ]; + + let columns = schema + .iter() + .map(|c| SourceColumnDesc::from(&c.column_desc)) + .collect::>(); + + let props = SpecificParserConfig { + key_encoding_config: None, + encoding_config: EncodingProperties::Json(JsonProperties { + use_schema_registry: false, + }), + protocol_config: ProtocolProperties::Debezium, + }; + let mut source_ctx = SourceContext::default(); + source_ctx.connector_props = ConnectorProperties::PostgresCdc(Box::default()); + let mut parser = DebeziumParser::new(props, columns.clone(), Arc::new(source_ctx)) + .await + .unwrap(); + let mut builder = SourceStreamChunkBuilder::with_capacity(columns, 0); + + // "id":"35352:3962948040" Postgres transaction ID itself and LSN of given operation separated by colon, i.e. the format is txID:LSN + let begin_msg = r#"{"schema":null,"payload":{"status":"BEGIN","id":"35352:3962948040","event_count":null,"data_collections":null,"ts_ms":1704269323180}}"#; + let commit_msg = r#"{"schema":null,"payload":{"status":"END","id":"35352:3962950064","event_count":11,"data_collections":[{"data_collection":"public.orders_tx","event_count":5},{"data_collection":"public.person","event_count":6}],"ts_ms":1704269323180}}"#; + let res = parser + .parse_one_with_txn( + None, + Some(begin_msg.as_bytes().to_vec()), + builder.row_writer(), + ) + .await; + match res { + Ok(ParseResult::TransactionControl(TransactionControl::Begin { id })) => { + assert_eq!(id.deref(), "35352"); + } + _ => panic!("unexpected parse result: {:?}", res), + } + let res = parser + .parse_one_with_txn( + None, + Some(commit_msg.as_bytes().to_vec()), + builder.row_writer(), + ) + .await; + match res { + Ok(ParseResult::TransactionControl(TransactionControl::Commit { id })) => { + assert_eq!(id.deref(), "35352"); + } + _ => panic!("unexpected parse result: {:?}", res), + } + } +} diff --git a/src/connector/src/parser/mod.rs b/src/connector/src/parser/mod.rs index 8e878f19ef123..1c165b45660e9 100644 --- a/src/connector/src/parser/mod.rs +++ b/src/connector/src/parser/mod.rs @@ -158,7 +158,7 @@ pub struct SourceStreamChunkRowWriter<'a> { /// The meta data of the original message for a row writer. /// /// Extracted from the `SourceMessage`. -#[derive(Clone, Copy)] +#[derive(Clone, Copy, Debug)] pub struct MessageMeta<'a> { meta: &'a SourceMeta, split_id: &'a str, @@ -665,6 +665,7 @@ async fn into_chunk_stream(mut parser: P, data_stream if let Some(Transaction { id: current_id, .. }) = ¤t_transaction { tracing::warn!(current_id, id, "already in transaction"); } + tracing::debug!("begin upstream transaction: id={}", id); current_transaction = Some(Transaction { id, len: 0 }); } TransactionControl::Commit { id } => { @@ -672,6 +673,7 @@ async fn into_chunk_stream(mut parser: P, data_stream if current_id != Some(&id) { tracing::warn!(?current_id, id, "transaction id mismatch"); } + tracing::debug!("commit upstream transaction: id={}", id); current_transaction = None; } } @@ -692,6 +694,7 @@ async fn into_chunk_stream(mut parser: P, data_stream // If we are not in a transaction, we should yield the chunk now. if current_transaction.is_none() { yield_asap = false; + yield StreamChunkWithState { chunk: builder.take(0), split_offset_mapping: Some(std::mem::take(&mut split_offset_mapping)), diff --git a/src/connector/src/parser/plain_parser.rs b/src/connector/src/parser/plain_parser.rs index f113c279f2ef6..4f15234be8371 100644 --- a/src/connector/src/parser/plain_parser.rs +++ b/src/connector/src/parser/plain_parser.rs @@ -20,12 +20,14 @@ use super::{ SourceStreamChunkRowWriter, SpecificParserConfig, }; use crate::parser::bytes_parser::BytesAccessBuilder; +use crate::parser::simd_json_parser::DebeziumJsonAccessBuilder; +use crate::parser::unified::debezium::parse_transaction_meta; use crate::parser::unified::upsert::UpsertChangeEvent; use crate::parser::unified::util::apply_row_operation_on_stream_chunk_writer_with_op; use crate::parser::unified::{AccessImpl, ChangeEventOperation}; use crate::parser::upsert_parser::get_key_column_name; -use crate::parser::{BytesProperties, ParserFormat}; -use crate::source::{SourceColumnDesc, SourceContext, SourceContextRef}; +use crate::parser::{BytesProperties, ParseResult, ParserFormat}; +use crate::source::{SourceColumnDesc, SourceContext, SourceContextRef, SourceMeta}; #[derive(Debug)] pub struct PlainParser { @@ -33,6 +35,8 @@ pub struct PlainParser { pub payload_builder: AccessBuilderImpl, pub(crate) rw_columns: Vec, pub source_ctx: SourceContextRef, + // parsing transaction metadata for shared cdc source + pub transaction_meta_builder: Option, } impl PlainParser { @@ -64,11 +68,16 @@ impl PlainParser { ))); } }; + + let transaction_meta_builder = Some(AccessBuilderImpl::DebeziumJson( + DebeziumJsonAccessBuilder::new()?, + )); Ok(Self { key_builder, payload_builder, rw_columns, source_ctx, + transaction_meta_builder, }) } @@ -77,7 +86,25 @@ impl PlainParser { key: Option>, payload: Option>, mut writer: SourceStreamChunkRowWriter<'_>, - ) -> Result<()> { + ) -> Result { + // if the message is transaction metadata, parse it and return + if let Some(msg_meta) = writer.row_meta + && let SourceMeta::DebeziumCdc(cdc_meta) = msg_meta.meta + && cdc_meta.is_transaction_meta + && let Some(data) = payload + { + let accessor = self + .transaction_meta_builder + .as_mut() + .expect("expect transaction metadata access builder") + .generate_accessor(data) + .await?; + return match parse_transaction_meta(&accessor, &self.source_ctx.connector_props) { + Ok(transaction_control) => Ok(ParseResult::TransactionControl(transaction_control)), + Err(err) => Err(err)?, + }; + } + // reuse upsert component but always insert let mut row_op: UpsertChangeEvent, AccessImpl<'_, '_>> = UpsertChangeEvent::default(); @@ -94,8 +121,14 @@ impl PlainParser { row_op = row_op.with_value(self.payload_builder.generate_accessor(data).await?); } - apply_row_operation_on_stream_chunk_writer_with_op(row_op, &mut writer, change_event_op) - .map_err(Into::into) + Ok( + apply_row_operation_on_stream_chunk_writer_with_op( + row_op, + &mut writer, + change_event_op, + ) + .map(|_| ParseResult::Rows)?, + ) } } @@ -113,11 +146,263 @@ impl ByteStreamSourceParser for PlainParser { } async fn parse_one<'a>( + &'a mut self, + _key: Option>, + _payload: Option>, + _writer: SourceStreamChunkRowWriter<'a>, + ) -> Result<()> { + unreachable!("should call `parse_one_with_txn` instead") + } + + async fn parse_one_with_txn<'a>( &'a mut self, key: Option>, payload: Option>, writer: SourceStreamChunkRowWriter<'a>, - ) -> Result<()> { + ) -> Result { + tracing::info!("parse_one_with_txn"); self.parse_inner(key, payload, writer).await } } + +#[cfg(test)] +mod tests { + use std::ops::Deref; + use std::sync::Arc; + + use futures::executor::block_on; + use futures::StreamExt; + use futures_async_stream::try_stream; + use itertools::Itertools; + use risingwave_common::catalog::{ColumnCatalog, ColumnDesc, ColumnId}; + + use super::*; + use crate::parser::{MessageMeta, SourceStreamChunkBuilder, TransactionControl}; + use crate::source::cdc::DebeziumCdcMeta; + use crate::source::{ConnectorProperties, DataType, SourceMessage, SplitId}; + + #[tokio::test] + async fn test_emit_transactional_chunk() { + let schema = vec![ + ColumnCatalog { + column_desc: ColumnDesc::named("payload", ColumnId::placeholder(), DataType::Jsonb), + is_hidden: false, + }, + ColumnCatalog::offset_column(), + ColumnCatalog::cdc_table_name_column(), + ]; + + let columns = schema + .iter() + .map(|c| SourceColumnDesc::from(&c.column_desc)) + .collect::>(); + + let mut source_ctx = SourceContext::default(); + source_ctx.connector_props = ConnectorProperties::PostgresCdc(Box::default()); + let source_ctx = Arc::new(source_ctx); + // format plain encode json parser + let parser = PlainParser::new( + SpecificParserConfig::DEFAULT_PLAIN_JSON, + columns.clone(), + source_ctx.clone(), + ) + .await + .unwrap(); + + let mut transactional = false; + // for untransactional source, we expect emit a chunk for each message batch + let message_stream = source_message_stream(transactional); + let chunk_stream = crate::parser::into_chunk_stream(parser, message_stream.boxed()); + let output: std::result::Result, _> = block_on(chunk_stream.collect::>()) + .into_iter() + .collect(); + let output = output + .unwrap() + .into_iter() + .filter(|c| c.chunk.cardinality() > 0) + .enumerate() + .map(|(i, c)| { + if i == 0 { + // begin + 3 data messages + assert_eq!(4, c.chunk.cardinality()); + } + if i == 1 { + // 2 data messages + 1 end + assert_eq!(3, c.chunk.cardinality()); + } + c.chunk + }) + .collect_vec(); + + // 2 chunks for 2 message batches + assert_eq!(2, output.len()); + + // format plain encode json parser + let parser = PlainParser::new( + SpecificParserConfig::DEFAULT_PLAIN_JSON, + columns.clone(), + source_ctx, + ) + .await + .unwrap(); + + // for transactional source, we expect emit a single chunk for the transaction + transactional = true; + let message_stream = source_message_stream(transactional); + let chunk_stream = crate::parser::into_chunk_stream(parser, message_stream.boxed()); + let output: std::result::Result, _> = block_on(chunk_stream.collect::>()) + .into_iter() + .collect(); + let output = output + .unwrap() + .into_iter() + .filter(|c| c.chunk.cardinality() > 0) + .map(|c| { + // 5 data messages in a single chunk + assert_eq!(5, c.chunk.cardinality()); + c.chunk + }) + .collect_vec(); + + // a single transactional chunk + assert_eq!(1, output.len()); + } + + #[try_stream(ok = Vec, error = anyhow::Error)] + async fn source_message_stream(transactional: bool) { + let begin_msg = r#"{"schema":null,"payload":{"status":"BEGIN","id":"35352:3962948040","event_count":null,"data_collections":null,"ts_ms":1704269323180}}"#; + let commit_msg = r#"{"schema":null,"payload":{"status":"END","id":"35352:3962950064","event_count":11,"data_collections":[{"data_collection":"public.orders_tx","event_count":5},{"data_collection":"public.person","event_count":6}],"ts_ms":1704269323180}}"#; + let data_batches = vec![ + vec![ + r#"{ "schema": null, "payload": {"after": {"customer_name": "a1", "order_date": "2020-01-30", "order_id": 10021, "order_status": false, "price": "50.50", "product_id": 102}, "before": null, "op": "c", "source": {"connector": "postgresql", "db": "mydb", "lsn": 3963199336, "name": "RW_CDC_1001", "schema": "public", "sequence": "[\"3963198512\",\"3963199336\"]", "snapshot": "false", "table": "orders_tx", "ts_ms": 1704355505506, "txId": 35352, "version": "2.4.2.Final", "xmin": null}, "transaction": {"data_collection_order": 1, "id": "35392:3963199336", "total_order": 1}, "ts_ms": 1704355839905} }"#, + r#"{ "schema": null, "payload": {"after": {"customer_name": "a2", "order_date": "2020-02-30", "order_id": 10022, "order_status": false, "price": "50.50", "product_id": 102}, "before": null, "op": "c", "source": {"connector": "postgresql", "db": "mydb", "lsn": 3963199336, "name": "RW_CDC_1001", "schema": "public", "sequence": "[\"3963198512\",\"3963199336\"]", "snapshot": "false", "table": "orders_tx", "ts_ms": 1704355505506, "txId": 35352, "version": "2.4.2.Final", "xmin": null}, "transaction": {"data_collection_order": 1, "id": "35392:3963199336", "total_order": 1}, "ts_ms": 1704355839905} }"#, + r#"{ "schema": null, "payload": {"after": {"customer_name": "a3", "order_date": "2020-03-30", "order_id": 10023, "order_status": false, "price": "50.50", "product_id": 102}, "before": null, "op": "c", "source": {"connector": "postgresql", "db": "mydb", "lsn": 3963199336, "name": "RW_CDC_1001", "schema": "public", "sequence": "[\"3963198512\",\"3963199336\"]", "snapshot": "false", "table": "orders_tx", "ts_ms": 1704355505506, "txId": 35352, "version": "2.4.2.Final", "xmin": null}, "transaction": {"data_collection_order": 1, "id": "35392:3963199336", "total_order": 1}, "ts_ms": 1704355839905} }"#, + ], + vec![ + r#"{ "schema": null, "payload": {"after": {"customer_name": "a4", "order_date": "2020-04-30", "order_id": 10024, "order_status": false, "price": "50.50", "product_id": 102}, "before": null, "op": "c", "source": {"connector": "postgresql", "db": "mydb", "lsn": 3963199336, "name": "RW_CDC_1001", "schema": "public", "sequence": "[\"3963198512\",\"3963199336\"]", "snapshot": "false", "table": "orders_tx", "ts_ms": 1704355505506, "txId": 35352, "version": "2.4.2.Final", "xmin": null}, "transaction": {"data_collection_order": 1, "id": "35392:3963199336", "total_order": 1}, "ts_ms": 1704355839905} }"#, + r#"{ "schema": null, "payload": {"after": {"customer_name": "a5", "order_date": "2020-05-30", "order_id": 10025, "order_status": false, "price": "50.50", "product_id": 102}, "before": null, "op": "c", "source": {"connector": "postgresql", "db": "mydb", "lsn": 3963199336, "name": "RW_CDC_1001", "schema": "public", "sequence": "[\"3963198512\",\"3963199336\"]", "snapshot": "false", "table": "orders_tx", "ts_ms": 1704355505506, "txId": 35352, "version": "2.4.2.Final", "xmin": null}, "transaction": {"data_collection_order": 1, "id": "35392:3963199336", "total_order": 1}, "ts_ms": 1704355839905} }"#, + ], + ]; + for (i, batch) in data_batches.iter().enumerate() { + let mut source_msg_batch = vec![]; + if i == 0 { + // put begin message at first + source_msg_batch.push(SourceMessage { + meta: SourceMeta::DebeziumCdc(DebeziumCdcMeta { + full_table_name: "orders".to_string(), + source_ts_ms: 0, + is_transaction_meta: transactional, + }), + split_id: SplitId::from("1001"), + offset: "0".into(), + key: None, + payload: Some(begin_msg.as_bytes().to_vec()), + }); + } + // put data messages + for data_msg in batch { + source_msg_batch.push(SourceMessage { + meta: SourceMeta::DebeziumCdc(DebeziumCdcMeta { + full_table_name: "orders".to_string(), + source_ts_ms: 0, + is_transaction_meta: false, + }), + split_id: SplitId::from("1001"), + offset: "0".into(), + key: None, + payload: Some(data_msg.as_bytes().to_vec()), + }); + } + if i == data_batches.len() - 1 { + // put commit message at last + source_msg_batch.push(SourceMessage { + meta: SourceMeta::DebeziumCdc(DebeziumCdcMeta { + full_table_name: "orders".to_string(), + source_ts_ms: 0, + is_transaction_meta: transactional, + }), + split_id: SplitId::from("1001"), + offset: "0".into(), + key: None, + payload: Some(commit_msg.as_bytes().to_vec()), + }); + } + yield source_msg_batch; + } + } + + #[tokio::test] + async fn test_parse_transaction_metadata() { + let schema = vec![ + ColumnCatalog { + column_desc: ColumnDesc::named("payload", ColumnId::placeholder(), DataType::Jsonb), + is_hidden: false, + }, + ColumnCatalog::offset_column(), + ColumnCatalog::cdc_table_name_column(), + ]; + + let columns = schema + .iter() + .map(|c| SourceColumnDesc::from(&c.column_desc)) + .collect::>(); + + // format plain encode json parser + let mut source_ctx = SourceContext::default(); + source_ctx.connector_props = ConnectorProperties::MysqlCdc(Box::default()); + let mut parser = PlainParser::new( + SpecificParserConfig::DEFAULT_PLAIN_JSON, + columns.clone(), + Arc::new(source_ctx), + ) + .await + .unwrap(); + let mut builder = SourceStreamChunkBuilder::with_capacity(columns, 0); + + // "id":"35352:3962948040" Postgres transaction ID itself and LSN of given operation separated by colon, i.e. the format is txID:LSN + let begin_msg = r#"{"schema":null,"payload":{"status":"BEGIN","id":"3E11FA47-71CA-11E1-9E33-C80AA9429562:23","event_count":null,"data_collections":null,"ts_ms":1704269323180}}"#; + let commit_msg = r#"{"schema":null,"payload":{"status":"END","id":"3E11FA47-71CA-11E1-9E33-C80AA9429562:23","event_count":11,"data_collections":[{"data_collection":"public.orders_tx","event_count":5},{"data_collection":"public.person","event_count":6}],"ts_ms":1704269323180}}"#; + + let cdc_meta = SourceMeta::DebeziumCdc(DebeziumCdcMeta { + full_table_name: "orders".to_string(), + source_ts_ms: 0, + is_transaction_meta: true, + }); + let msg_meta = MessageMeta { + meta: &cdc_meta, + split_id: "1001", + offset: "", + }; + + let expect_tx_id = "3E11FA47-71CA-11E1-9E33-C80AA9429562:23"; + let res = parser + .parse_one_with_txn( + None, + Some(begin_msg.as_bytes().to_vec()), + builder.row_writer().with_meta(msg_meta), + ) + .await; + match res { + Ok(ParseResult::TransactionControl(TransactionControl::Begin { id })) => { + assert_eq!(id.deref(), expect_tx_id); + } + _ => panic!("unexpected parse result: {:?}", res), + } + let res = parser + .parse_one_with_txn( + None, + Some(commit_msg.as_bytes().to_vec()), + builder.row_writer().with_meta(msg_meta), + ) + .await; + match res { + Ok(ParseResult::TransactionControl(TransactionControl::Commit { id })) => { + assert_eq!(id.deref(), expect_tx_id); + } + _ => panic!("unexpected parse result: {:?}", res), + } + + let output = builder.take(10); + assert_eq!(0, output.cardinality()); + } +} diff --git a/src/connector/src/parser/unified/debezium.rs b/src/connector/src/parser/unified/debezium.rs index 6163cfe5c486c..7291b1b359735 100644 --- a/src/connector/src/parser/unified/debezium.rs +++ b/src/connector/src/parser/unified/debezium.rs @@ -12,11 +12,12 @@ // See the License for the specific language governing permissions and // limitations under the License. +use anyhow::anyhow; use risingwave_common::types::{DataType, Datum, ScalarImpl}; use super::{Access, AccessError, ChangeEvent, ChangeEventOperation}; use crate::parser::TransactionControl; -use crate::source::SourceColumnDesc; +use crate::source::{ConnectorProperties, SourceColumnDesc}; pub struct DebeziumChangeEvent { value_accessor: Option, @@ -26,8 +27,8 @@ pub struct DebeziumChangeEvent { const BEFORE: &str = "before"; const AFTER: &str = "after"; const OP: &str = "op"; -const TRANSACTION_STATUS: &str = "status"; -const TRANSACTION_ID: &str = "id"; +pub const TRANSACTION_STATUS: &str = "status"; +pub const TRANSACTION_ID: &str = "id"; pub const DEBEZIUM_READ_OP: &str = "r"; pub const DEBEZIUM_CREATE_OP: &str = "c"; @@ -37,6 +38,44 @@ pub const DEBEZIUM_DELETE_OP: &str = "d"; pub const DEBEZIUM_TRANSACTION_STATUS_BEGIN: &str = "BEGIN"; pub const DEBEZIUM_TRANSACTION_STATUS_COMMIT: &str = "END"; +pub fn parse_transaction_meta( + accessor: &impl Access, + connector_props: &ConnectorProperties, +) -> std::result::Result { + if let (Some(ScalarImpl::Utf8(status)), Some(ScalarImpl::Utf8(id))) = ( + accessor.access(&[TRANSACTION_STATUS], Some(&DataType::Varchar))?, + accessor.access(&[TRANSACTION_ID], Some(&DataType::Varchar))?, + ) { + // The id field has different meanings for different databases: + // PG: txID:LSN + // MySQL: source_id:transaction_id (e.g. 3E11FA47-71CA-11E1-9E33-C80AA9429562:23) + match status.as_ref() { + DEBEZIUM_TRANSACTION_STATUS_BEGIN => match *connector_props { + ConnectorProperties::PostgresCdc(_) => { + let (tx_id, _) = id.split_once(':').unwrap(); + return Ok(TransactionControl::Begin { id: tx_id.into() }); + } + ConnectorProperties::MysqlCdc(_) => return Ok(TransactionControl::Begin { id }), + _ => {} + }, + DEBEZIUM_TRANSACTION_STATUS_COMMIT => match *connector_props { + ConnectorProperties::PostgresCdc(_) => { + let (tx_id, _) = id.split_once(':').unwrap(); + return Ok(TransactionControl::Commit { id: tx_id.into() }); + } + ConnectorProperties::MysqlCdc(_) => return Ok(TransactionControl::Commit { id }), + _ => {} + }, + _ => {} + } + } + + Err(AccessError::Undefined { + name: "transaction status".into(), + path: TRANSACTION_STATUS.into(), + }) +} + impl DebeziumChangeEvent where A: Access, @@ -61,28 +100,16 @@ where /// Returns the transaction metadata if exists. /// /// See the [doc](https://debezium.io/documentation/reference/2.3/connectors/postgresql.html#postgresql-transaction-metadata) of Debezium for more details. - pub(crate) fn transaction_control(&self) -> Result { - if let Some(accessor) = &self.value_accessor { - if let (Some(ScalarImpl::Utf8(status)), Some(ScalarImpl::Utf8(id))) = ( - accessor.access(&[TRANSACTION_STATUS], Some(&DataType::Varchar))?, - accessor.access(&[TRANSACTION_ID], Some(&DataType::Varchar))?, - ) { - match status.as_ref() { - DEBEZIUM_TRANSACTION_STATUS_BEGIN => { - return Ok(TransactionControl::Begin { id }) - } - DEBEZIUM_TRANSACTION_STATUS_COMMIT => { - return Ok(TransactionControl::Commit { id }) - } - _ => {} - } - } - } - - Err(AccessError::Undefined { - name: "transaction status".into(), - path: Default::default(), - }) + pub(crate) fn transaction_control( + &self, + connector_props: &ConnectorProperties, + ) -> Result { + let Some(accessor) = &self.value_accessor else { + return Err(AccessError::Other(anyhow!( + "value_accessor must be provided to parse transaction metadata" + ))); + }; + parse_transaction_meta(accessor, connector_props) } } diff --git a/src/connector/src/source/base.rs b/src/connector/src/source/base.rs index b6093a351783b..0374a9484918f 100644 --- a/src/connector/src/source/base.rs +++ b/src/connector/src/source/base.rs @@ -159,6 +159,7 @@ pub struct SourceContext { pub source_info: SourceInfo, pub metrics: Arc, pub source_ctrl_opts: SourceCtrlOpts, + pub connector_props: ConnectorProperties, error_suppressor: Option>>, } impl SourceContext { @@ -169,6 +170,7 @@ impl SourceContext { metrics: Arc, source_ctrl_opts: SourceCtrlOpts, connector_client: Option, + connector_props: ConnectorProperties, ) -> Self { Self { connector_client, @@ -180,6 +182,7 @@ impl SourceContext { metrics, source_ctrl_opts, error_suppressor: None, + connector_props, } } @@ -191,6 +194,7 @@ impl SourceContext { source_ctrl_opts: SourceCtrlOpts, connector_client: Option, error_suppressor: Arc>, + connector_props: ConnectorProperties, ) -> Self { let mut ctx = Self::new( actor_id, @@ -199,6 +203,7 @@ impl SourceContext { metrics, source_ctrl_opts, connector_client, + connector_props, ); ctx.error_suppressor = Some(error_suppressor); ctx @@ -382,6 +387,12 @@ pub trait SplitReader: Sized + Send { for_all_sources!(impl_connector_properties); +impl Default for ConnectorProperties { + fn default() -> Self { + ConnectorProperties::Test(Box::default()) + } +} + impl ConnectorProperties { pub fn is_new_fs_connector_b_tree_map(with_properties: &BTreeMap) -> bool { with_properties diff --git a/src/connector/src/source/cdc/source/message.rs b/src/connector/src/source/cdc/source/message.rs index 04518f7088b4b..28fe52c52cd1e 100644 --- a/src/connector/src/source/cdc/source/message.rs +++ b/src/connector/src/source/cdc/source/message.rs @@ -22,6 +22,8 @@ pub struct DebeziumCdcMeta { pub full_table_name: String, // extracted from `payload.source.ts_ms`, the time that the change event was made in the database pub source_ts_ms: i64, + // Whether the message is a transaction metadata + pub is_transaction_meta: bool, } impl From for SourceMessage { @@ -38,6 +40,7 @@ impl From for SourceMessage { meta: SourceMeta::DebeziumCdc(DebeziumCdcMeta { full_table_name: message.full_table_name, source_ts_ms: message.source_ts_ms, + is_transaction_meta: message.is_transaction_meta, }), } } diff --git a/src/connector/src/source/cdc/source/reader.rs b/src/connector/src/source/cdc/source/reader.rs index 938ce36223bec..cb9c7dae3d114 100644 --- a/src/connector/src/source/cdc/source/reader.rs +++ b/src/connector/src/source/cdc/source/reader.rs @@ -203,7 +203,7 @@ impl CommonSplitReader for CdcSplitReader { while let Some(result) = rx.recv().await { let GetEventStreamResponse { events, .. } = result?; - tracing::trace!("receive events {:?}", events.len()); + tracing::trace!("receive {} cdc events ", events.len()); metrics .connector_source_rows_received .with_label_values(&[source_type.as_str_name(), &source_id]) diff --git a/src/connector/src/source/test_source.rs b/src/connector/src/source/test_source.rs index 26c51c63540b9..6c10ff9934eef 100644 --- a/src/connector/src/source/test_source.rs +++ b/src/connector/src/source/test_source.rs @@ -115,7 +115,7 @@ pub fn registry_test_source(box_source: BoxSource) -> TestSourceRegistryGuard { pub const TEST_CONNECTOR: &str = "test"; -#[derive(Clone, Debug, WithOptions)] +#[derive(Clone, Debug, Default, WithOptions)] pub struct TestSourceProperties { properties: HashMap, } diff --git a/src/frontend/src/optimizer/plan_node/stream_cdc_table_scan.rs b/src/frontend/src/optimizer/plan_node/stream_cdc_table_scan.rs index c26a1ed41aaf0..24bc2dd5f0b60 100644 --- a/src/frontend/src/optimizer/plan_node/stream_cdc_table_scan.rs +++ b/src/frontend/src/optimizer/plan_node/stream_cdc_table_scan.rs @@ -326,6 +326,7 @@ mod tests { async fn test_cdc_filter_expr() { let t1_json = JsonbVal::from_str(r#"{ "before": null, "after": { "v": 111, "v2": 222.2 }, "source": { "version": "2.2.0.Alpha3", "connector": "mysql", "name": "dbserver1", "ts_ms": 1678428689000, "snapshot": "false", "db": "inventory", "sequence": null, "table": "t1", "server_id": 223344, "gtid": null, "file": "mysql-bin.000003", "pos": 774, "row": 0, "thread": 8, "query": null }, "op": "c", "ts_ms": 1678428689389, "transaction": null }"#).unwrap(); let t2_json = JsonbVal::from_str(r#"{ "before": null, "after": { "v": 333, "v2": 666.6 }, "source": { "version": "2.2.0.Alpha3", "connector": "mysql", "name": "dbserver1", "ts_ms": 1678428689000, "snapshot": "false", "db": "inventory", "sequence": null, "table": "t2", "server_id": 223344, "gtid": null, "file": "mysql-bin.000003", "pos": 884, "row": 0, "thread": 8, "query": null }, "op": "c", "ts_ms": 1678428689389, "transaction": null }"#).unwrap(); + let trx_json = JsonbVal::from_str(r#"{"data_collections": null, "event_count": null, "id": "35319:3962662584", "status": "BEGIN", "ts_ms": 1704263537068}"#).unwrap(); let row1 = OwnedRow::new(vec![ Some(t1_json.into()), Some(r#"{"file": "1.binlog", "pos": 100}"#.into()), @@ -335,6 +336,11 @@ mod tests { Some(r#"{"file": "2.binlog", "pos": 100}"#.into()), ]); + let row3 = OwnedRow::new(vec![ + Some(trx_json.into()), + Some(r#"{"file": "3.binlog", "pos": 100}"#.into()), + ]); + let filter_expr = StreamCdcTableScan::build_cdc_filter_expr("t1"); assert_eq!( filter_expr.eval_row(&row1).await.unwrap(), @@ -344,5 +350,6 @@ mod tests { filter_expr.eval_row(&row2).await.unwrap(), Some(ScalarImpl::Bool(false)) ); + assert_eq!(filter_expr.eval_row(&row3).await.unwrap(), None) } } diff --git a/src/stream/src/executor/source/fetch_executor.rs b/src/stream/src/executor/source/fetch_executor.rs index 633dce3b0b9a8..3aa885cfffe1b 100644 --- a/src/stream/src/executor/source/fetch_executor.rs +++ b/src/stream/src/executor/source/fetch_executor.rs @@ -176,6 +176,7 @@ impl FsFetchExecutor { self.source_ctrl_opts.clone(), self.connector_params.connector_client.clone(), self.actor_ctx.error_suppressor.clone(), + source_desc.source.config.clone(), ) } diff --git a/src/stream/src/executor/source/fs_source_executor.rs b/src/stream/src/executor/source/fs_source_executor.rs index 85967a253ba91..6275ef5d116f6 100644 --- a/src/stream/src/executor/source/fs_source_executor.rs +++ b/src/stream/src/executor/source/fs_source_executor.rs @@ -104,6 +104,7 @@ impl FsSourceExecutor { self.source_ctrl_opts.clone(), None, self.actor_ctx.error_suppressor.clone(), + source_desc.source.config.clone(), ); source_desc .source diff --git a/src/stream/src/executor/source/source_executor.rs b/src/stream/src/executor/source/source_executor.rs index 433d54431ced6..1bb61789f1359 100644 --- a/src/stream/src/executor/source/source_executor.rs +++ b/src/stream/src/executor/source/source_executor.rs @@ -106,6 +106,7 @@ impl SourceExecutor { self.source_ctrl_opts.clone(), self.connector_params.connector_client.clone(), self.actor_ctx.error_suppressor.clone(), + source_desc.source.config.clone(), ); source_desc .source From ba1e2e145092046632deeee7383b7f876a137b95 Mon Sep 17 00:00:00 2001 From: Kexiang Wang Date: Sun, 7 Jan 2024 22:37:11 -0500 Subject: [PATCH 09/20] fix(connector): fix gcs source connector (#14373) --- ci/scripts/notify.py | 1 + .../s3-source-test-for-opendal-fs-engine.sh | 2 +- ci/workflows/main-cron.yml | 26 ++++++++++++++++++- e2e_test/s3/gcs_source.py | 2 +- src/connector/src/source/base.rs | 4 ++- .../filesystem/opendal_source/gcs_source.rs | 8 ++++-- .../source/filesystem/opendal_source/mod.rs | 11 +++++--- src/connector/with_options_source.yaml | 5 ++++ 8 files changed, 50 insertions(+), 9 deletions(-) diff --git a/ci/scripts/notify.py b/ci/scripts/notify.py index 5266998b0045f..9160d01675832 100755 --- a/ci/scripts/notify.py +++ b/ci/scripts/notify.py @@ -98,6 +98,7 @@ def get_mock_test_status(test): "e2e-clickhouse-sink-tests": "hard_failed", "e2e-pulsar-sink-tests": "", "s3-source-test-for-opendal-fs-engine": "", + "s3-source-tests": "", "pulsar-source-tests": "", "connector-node-integration-test": "" } diff --git a/ci/scripts/s3-source-test-for-opendal-fs-engine.sh b/ci/scripts/s3-source-test-for-opendal-fs-engine.sh index 6fbbdb35e0e45..355489acf2512 100755 --- a/ci/scripts/s3-source-test-for-opendal-fs-engine.sh +++ b/ci/scripts/s3-source-test-for-opendal-fs-engine.sh @@ -30,7 +30,7 @@ cargo make ci-start ci-3cn-3fe-opendal-fs-backend echo "--- Run test" python3 -m pip install minio psycopg2-binary -python3 e2e_test/s3/$script.py +python3 e2e_test/s3/$script echo "--- Kill cluster" rm -rf /tmp/rw_ci diff --git a/ci/workflows/main-cron.yml b/ci/workflows/main-cron.yml index 653578e4688e2..d931c3af16660 100644 --- a/ci/workflows/main-cron.yml +++ b/ci/workflows/main-cron.yml @@ -490,6 +490,7 @@ steps: retry: *auto-retry - label: "PosixFs source on OpenDAL fs engine (csv parser)" + key: "s3-source-test-for-opendal-fs-engine" command: "ci/scripts/s3-source-test.sh -p ci-release -s 'posix_fs_source.py csv_without_header'" if: | !(build.pull_request.labels includes "ci/main-cron/skip-ci") && build.env("CI_STEPS") == null @@ -507,7 +508,7 @@ steps: - label: "S3 source on OpenDAL fs engine" key: "s3-source-test-for-opendal-fs-engine" - command: "ci/scripts/s3-source-test-for-opendal-fs-engine.sh -p ci-release -s run" + command: "ci/scripts/s3-source-test-for-opendal-fs-engine.sh -p ci-release -s run.csv" if: | !(build.pull_request.labels includes "ci/main-cron/skip-ci") && build.env("CI_STEPS") == null || build.pull_request.labels includes "ci/run-s3-source-tests" @@ -527,6 +528,29 @@ steps: timeout_in_minutes: 20 retry: *auto-retry + # TODO(Kexiang): Enable this test after we have a GCS_SOURCE_TEST_CONF. + # - label: "GCS source on OpenDAL fs engine" + # key: "s3-source-test-for-opendal-fs-engine" + # command: "ci/scripts/s3-source-test-for-opendal-fs-engine.sh -p ci-release -s gcs.csv" + # if: | + # !(build.pull_request.labels includes "ci/main-cron/skip-ci") && build.env("CI_STEPS") == null + # || build.pull_request.labels includes "ci/run-s3-source-tests" + # || build.env("CI_STEPS") =~ /(^|,)s3-source-tests?(,|$$)/ + # depends_on: build + # plugins: + # - seek-oss/aws-sm#v2.3.1: + # env: + # S3_SOURCE_TEST_CONF: ci_s3_source_test_aws + # - docker-compose#v4.9.0: + # run: rw-build-env + # config: ci/docker-compose.yml + # mount-buildkite-agent: true + # environment: + # - S3_SOURCE_TEST_CONF + # - ./ci/plugins/upload-failure-logs + # timeout_in_minutes: 20 + # retry: *auto-retry + - label: "pulsar source check" key: "pulsar-source-tests" command: "ci/scripts/pulsar-source-test.sh -p ci-release" diff --git a/e2e_test/s3/gcs_source.py b/e2e_test/s3/gcs_source.py index c917f2c2d33fd..5e1144266fb23 100644 --- a/e2e_test/s3/gcs_source.py +++ b/e2e_test/s3/gcs_source.py @@ -57,7 +57,7 @@ def _encode(): connector = 'gcs', match_pattern = '{prefix}*.{fmt}', gcs.bucket_name = '{config['GCS_BUCKET']}', - gcs.credentials = '{credential}', + gcs.credential = '{credential}', ) FORMAT PLAIN ENCODE {_encode()};''') total_rows = file_num * item_num_per_file diff --git a/src/connector/src/source/base.rs b/src/connector/src/source/base.rs index 0374a9484918f..8743972ea8e6e 100644 --- a/src/connector/src/source/base.rs +++ b/src/connector/src/source/base.rs @@ -42,7 +42,7 @@ use super::google_pubsub::GooglePubsubMeta; use super::kafka::KafkaMeta; use super::monitor::SourceMetrics; use super::nexmark::source::message::NexmarkMeta; -use super::{OPENDAL_S3_CONNECTOR, POSIX_FS_CONNECTOR}; +use super::{GCS_CONNECTOR, OPENDAL_S3_CONNECTOR, POSIX_FS_CONNECTOR}; use crate::parser::ParserConfig; pub(crate) use crate::source::common::CommonSplitReader; use crate::source::filesystem::FsPageItem; @@ -400,6 +400,7 @@ impl ConnectorProperties { .map(|s| { s.eq_ignore_ascii_case(OPENDAL_S3_CONNECTOR) || s.eq_ignore_ascii_case(POSIX_FS_CONNECTOR) + || s.eq_ignore_ascii_case(GCS_CONNECTOR) }) .unwrap_or(false) } @@ -410,6 +411,7 @@ impl ConnectorProperties { .map(|s| { s.eq_ignore_ascii_case(OPENDAL_S3_CONNECTOR) || s.eq_ignore_ascii_case(POSIX_FS_CONNECTOR) + || s.eq_ignore_ascii_case(GCS_CONNECTOR) }) .unwrap_or(false) } diff --git a/src/connector/src/source/filesystem/opendal_source/gcs_source.rs b/src/connector/src/source/filesystem/opendal_source/gcs_source.rs index 7d9c2bec4429b..d6f7b44bff591 100644 --- a/src/connector/src/source/filesystem/opendal_source/gcs_source.rs +++ b/src/connector/src/source/filesystem/opendal_source/gcs_source.rs @@ -32,9 +32,13 @@ impl OpendalEnumerator { builder.bucket(&gcs_properties.bucket_name); // if credential env is set, use it. Otherwise, ADC will be used. - let cred = gcs_properties.credential; - if let Some(cred) = cred { + if let Some(cred) = gcs_properties.credential { builder.credential(&cred); + } else { + let cred = std::env::var("GOOGLE_APPLICATION_CREDENTIALS"); + if let Ok(cred) = cred { + builder.credential(&cred); + } } if let Some(service_account) = gcs_properties.service_account { diff --git a/src/connector/src/source/filesystem/opendal_source/mod.rs b/src/connector/src/source/filesystem/opendal_source/mod.rs index d6223c467d08b..e0c5a22f1fd90 100644 --- a/src/connector/src/source/filesystem/opendal_source/mod.rs +++ b/src/connector/src/source/filesystem/opendal_source/mod.rs @@ -38,10 +38,15 @@ pub const POSIX_FS_CONNECTOR: &str = "posix_fs"; pub struct GcsProperties { #[serde(rename = "gcs.bucket_name")] pub bucket_name: String, + + /// The base64 encoded credential key. If not set, ADC will be used. #[serde(rename = "gcs.credential")] pub credential: Option, + + /// If credential/ADC is not set. The service account can be used to provide the credential info. #[serde(rename = "gcs.service_account", default)] pub service_account: Option, + #[serde(rename = "match_pattern", default)] pub match_pattern: Option, @@ -107,7 +112,7 @@ pub struct OpendalS3Properties { #[serde(flatten)] pub s3_properties: S3PropertiesCommon, - // The following are only supported by s3_v2 (opendal) source. + /// The following are only supported by s3_v2 (opendal) source. #[serde(rename = "s3.assume_role", default)] pub assume_role: Option, @@ -131,11 +136,11 @@ impl SourceProperties for OpendalS3Properties { #[derive(Clone, Debug, Deserialize, PartialEq, WithOptions)] pub struct PosixFsProperties { - // The root directly of the files to search. The files will be searched recursively. + /// The root directly of the files to search. The files will be searched recursively. #[serde(rename = "posix_fs.root")] pub root: String, - // The regex pattern to match files under root directory. + /// The regex pattern to match files under root directory. #[serde(rename = "match_pattern", default)] pub match_pattern: Option, diff --git a/src/connector/with_options_source.yaml b/src/connector/with_options_source.yaml index 98a45599b56f2..187780cd23826 100644 --- a/src/connector/with_options_source.yaml +++ b/src/connector/with_options_source.yaml @@ -22,9 +22,11 @@ GcsProperties: required: true - name: gcs.credential field_type: String + comments: The base64 encoded credential key. If not set, ADC will be used. required: false - name: gcs.service_account field_type: String + comments: If credential/ADC is not set. The service account can be used to provide the credential info. required: false default: Default::default - name: match_pattern @@ -453,15 +455,18 @@ OpendalS3Properties: required: false - name: s3.assume_role field_type: String + comments: The following are only supported by s3_v2 (opendal) source. required: false default: Default::default PosixFsProperties: fields: - name: posix_fs.root field_type: String + comments: The root directly of the files to search. The files will be searched recursively. required: true - name: match_pattern field_type: String + comments: The regex pattern to match files under root directory. required: false default: Default::default PubsubProperties: From 975bfac19a5b97eead6c61d16515ccb334e3214b Mon Sep 17 00:00:00 2001 From: Bugen Zhao Date: Mon, 8 Jan 2024 12:07:43 +0800 Subject: [PATCH 10/20] refactor(dashboard): refine type annotation and renames for layout (#14383) Signed-off-by: Bugen Zhao --- ...yGraph.tsx => FragmentDependencyGraph.tsx} | 10 +- dashboard/components/FragmentGraph.tsx | 279 +++++++++--------- ...mGraph.tsx => RelationDependencyGraph.tsx} | 49 +-- dashboard/lib/layout.ts | 149 +++++----- dashboard/pages/dependency_graph.tsx | 8 +- dashboard/pages/fragment_graph.tsx | 20 +- 6 files changed, 252 insertions(+), 263 deletions(-) rename dashboard/components/{DependencyGraph.tsx => FragmentDependencyGraph.tsx} (96%) rename dashboard/components/{StreamGraph.tsx => RelationDependencyGraph.tsx} (82%) diff --git a/dashboard/components/DependencyGraph.tsx b/dashboard/components/FragmentDependencyGraph.tsx similarity index 96% rename from dashboard/components/DependencyGraph.tsx rename to dashboard/components/FragmentDependencyGraph.tsx index f85204e438fb5..553c40ec53f92 100644 --- a/dashboard/components/DependencyGraph.tsx +++ b/dashboard/components/FragmentDependencyGraph.tsx @@ -3,11 +3,12 @@ import * as d3 from "d3" import { Dag, DagLink, DagNode, zherebko } from "d3-dag" import { cloneDeep } from "lodash" import { useCallback, useEffect, useRef, useState } from "react" +import { Position } from "../lib/layout" const nodeRadius = 5 const edgeRadius = 12 -export default function DependencyGraph({ +export default function FragmentDependencyGraph({ mvDependency, svgWidth, selectedId, @@ -18,7 +19,7 @@ export default function DependencyGraph({ selectedId: string | undefined onSelectedIdChange: (id: string) => void | undefined }) { - const svgRef = useRef() + const svgRef = useRef(null) const [svgHeight, setSvgHeight] = useState("0px") const MARGIN_X = 10 const MARGIN_Y = 2 @@ -47,7 +48,7 @@ export default function DependencyGraph({ // How to draw edges const curveStyle = d3.curveMonotoneY const line = d3 - .line<{ x: number; y: number }>() + .line() .curve(curveStyle) .x(({ x }) => x + MARGIN_X) .y(({ y }) => y) @@ -85,8 +86,7 @@ export default function DependencyGraph({ sel .attr( "transform", - ({ x, y }: { x: number; y: number }) => - `translate(${x + MARGIN_X}, ${y})` + ({ x, y }: Position) => `translate(${x + MARGIN_X}, ${y})` ) .attr("fill", (d: any) => isSelected(d) ? theme.colors.blue["500"] : theme.colors.gray["500"] diff --git a/dashboard/components/FragmentGraph.tsx b/dashboard/components/FragmentGraph.tsx index 891a6d21df4b0..d9ca56fea69e8 100644 --- a/dashboard/components/FragmentGraph.tsx +++ b/dashboard/components/FragmentGraph.tsx @@ -15,9 +15,10 @@ import * as d3 from "d3" import { cloneDeep } from "lodash" import { Fragment, useCallback, useEffect, useRef, useState } from "react" import { - ActorBox, - ActorBoxPosition, - generateBoxLinks, + FragmentBox, + FragmentBoxPosition, + Position, + generateBoxEdges, layout, } from "../lib/layout" import { PlanNodeDatum } from "../pages/fragment_graph" @@ -25,10 +26,17 @@ import BackPressureTable from "./BackPressureTable" const ReactJson = loadable(() => import("react-json-view")) -interface Point { - x: number - y: number -} +type FragmentLayout = { + id: string + layoutRoot: d3.HierarchyPointNode + width: number + height: number + actorIds: string[] +} & Position + +type Enter = Type extends d3.Selection + ? d3.Selection + : never function treeLayoutFlip( root: d3.HierarchyNode, @@ -40,10 +48,10 @@ function treeLayoutFlip( const treeRoot = tree(root) // Flip back x, y - treeRoot.each((d: Point) => ([d.x, d.y] = [d.y, d.x])) + treeRoot.each((d: Position) => ([d.x, d.y] = [d.y, d.x])) // LTR -> RTL - treeRoot.each((d: Point) => (d.x = -d.x)) + treeRoot.each((d: Position) => (d.x = -d.x)) return treeRoot } @@ -78,10 +86,10 @@ function boundBox( const nodeRadius = 12 const nodeMarginX = nodeRadius * 6 const nodeMarginY = nodeRadius * 4 -const actorMarginX = nodeRadius -const actorMarginY = nodeRadius -const actorDistanceX = nodeRadius * 5 -const actorDistanceY = nodeRadius * 5 +const fragmentMarginX = nodeRadius +const fragmentMarginY = nodeRadius +const fragmentDistanceX = nodeRadius * 5 +const fragmentDistanceY = nodeRadius * 5 export default function FragmentGraph({ planNodeDependencies, @@ -89,34 +97,27 @@ export default function FragmentGraph({ selectedFragmentId, }: { planNodeDependencies: Map> - fragmentDependency: ActorBox[] + fragmentDependency: FragmentBox[] selectedFragmentId: string | undefined }) { - const svgRef = useRef() + const svgRef = useRef(null) const { isOpen, onOpen, onClose } = useDisclosure() const [currentStreamNode, setCurrentStreamNode] = useState() const openPlanNodeDetail = useCallback( - () => (node: d3.HierarchyNode) => { - setCurrentStreamNode(node.data) + (node: PlanNodeDatum) => { + setCurrentStreamNode(node) onOpen() }, - [onOpen] - )() + [onOpen, setCurrentStreamNode] + ) const planNodeDependencyDagCallback = useCallback(() => { const deps = cloneDeep(planNodeDependencies) const fragmentDependencyDag = cloneDeep(fragmentDependency) - const layoutActorResult = new Map< - string, - { - layoutRoot: d3.HierarchyPointNode - width: number - height: number - extraInfo: string - } - >() + + const layoutFragmentResult = new Map() const includedFragmentIds = new Set() for (const [fragmentId, fragmentRoot] of deps) { const layoutRoot = treeLayoutFlip(fragmentRoot, { @@ -125,95 +126,86 @@ export default function FragmentGraph({ }) let { width, height } = boundBox(layoutRoot, { margin: { - left: nodeRadius * 4 + actorMarginX, - right: nodeRadius * 4 + actorMarginX, - top: nodeRadius * 3 + actorMarginY, - bottom: nodeRadius * 4 + actorMarginY, + left: nodeRadius * 4 + fragmentMarginX, + right: nodeRadius * 4 + fragmentMarginX, + top: nodeRadius * 3 + fragmentMarginY, + bottom: nodeRadius * 4 + fragmentMarginY, }, }) - layoutActorResult.set(fragmentId, { + layoutFragmentResult.set(fragmentId, { layoutRoot, width, height, - extraInfo: `Actor ${fragmentRoot.data.actor_ids?.join(", ")}` || "", + actorIds: fragmentRoot.data.actorIds ?? [], }) includedFragmentIds.add(fragmentId) } + const fragmentLayout = layout( fragmentDependencyDag.map(({ width: _1, height: _2, id, ...data }) => { - const { width, height } = layoutActorResult.get(id)! + const { width, height } = layoutFragmentResult.get(id)! return { width, height, id, ...data } }), - actorDistanceX, - actorDistanceY + fragmentDistanceX, + fragmentDistanceY ) - const fragmentLayoutPosition = new Map() - fragmentLayout.forEach(({ id, x, y }: ActorBoxPosition) => { + const fragmentLayoutPosition = new Map() + fragmentLayout.forEach(({ id, x, y }: FragmentBoxPosition) => { fragmentLayoutPosition.set(id, { x, y }) }) - const layoutResult = [] - for (const [fragmentId, result] of layoutActorResult) { + + const layoutResult: FragmentLayout[] = [] + for (const [fragmentId, result] of layoutFragmentResult) { const { x, y } = fragmentLayoutPosition.get(fragmentId)! layoutResult.push({ id: fragmentId, x, y, ...result }) } + let svgWidth = 0 let svgHeight = 0 layoutResult.forEach(({ x, y, width, height }) => { svgHeight = Math.max(svgHeight, y + height + 50) svgWidth = Math.max(svgWidth, x + width) }) - const links = generateBoxLinks(fragmentLayout) + const edges = generateBoxEdges(fragmentLayout) + return { layoutResult, - fragmentLayout, svgWidth, svgHeight, - links, + edges, includedFragmentIds, } }, [planNodeDependencies, fragmentDependency]) - type PlanNodeDesc = { - layoutRoot: d3.HierarchyPointNode - width: number - height: number - x: number - y: number - id: string - extraInfo: string - } - const { svgWidth, svgHeight, - links, - fragmentLayout: fragmentDependencyDag, - layoutResult: planNodeDependencyDag, + edges: fragmentEdgeLayout, + layoutResult: fragmentLayout, includedFragmentIds, } = planNodeDependencyDagCallback() useEffect(() => { - if (planNodeDependencyDag) { + if (fragmentLayout) { const svgNode = svgRef.current const svgSelection = d3.select(svgNode) // How to draw edges const treeLink = d3 - .linkHorizontal() - .x((d: Point) => d.x) - .y((d: Point) => d.y) + .linkHorizontal() + .x((d: Position) => d.x) + .y((d: Position) => d.y) - const isSelected = (d: any) => d === selectedFragmentId + const isSelected = (id: string) => id === selectedFragmentId - const applyActor = ( - gSel: d3.Selection - ) => { + // Fragments + const applyFragment = (gSel: FragmentSelection) => { gSel.attr("transform", ({ x, y }) => `translate(${x}, ${y})`) - // Actor text (fragment id) - let text = gSel.select(".actor-text-frag-id") + // Fragment text line 1 (fragment id) + let text = gSel.select(".text-frag-id") if (text.empty()) { - text = gSel.append("text").attr("class", "actor-text-frag-id") + text = gSel.append("text").attr("class", "text-frag-id") } text @@ -221,38 +213,38 @@ export default function FragmentGraph({ .text(({ id }) => `Fragment ${id}`) .attr("font-family", "inherit") .attr("text-anchor", "end") - .attr("dy", ({ height }) => height - actorMarginY + 12) - .attr("dx", ({ width }) => width - actorMarginX) + .attr("dy", ({ height }) => height - fragmentMarginY + 12) + .attr("dx", ({ width }) => width - fragmentMarginX) .attr("fill", "black") .attr("font-size", 12) - // Actor text (actors) - let text2 = gSel.select(".actor-text-actor-id") + // Fragment text line 2 (actor ids) + let text2 = gSel.select(".text-actor-id") if (text2.empty()) { - text2 = gSel.append("text").attr("class", "actor-text-actor-id") + text2 = gSel.append("text").attr("class", "text-actor-id") } text2 .attr("fill", "black") - .text(({ extraInfo }) => extraInfo) + .text(({ actorIds }) => `Actor ${actorIds.join(", ")}`) .attr("font-family", "inherit") .attr("text-anchor", "end") - .attr("dy", ({ height }) => height - actorMarginY + 24) - .attr("dx", ({ width }) => width - actorMarginX) + .attr("dy", ({ height }) => height - fragmentMarginY + 24) + .attr("dx", ({ width }) => width - fragmentMarginX) .attr("fill", "black") .attr("font-size", 12) - // Actor bounding box + // Fragment bounding box let boundingBox = gSel.select(".bounding-box") if (boundingBox.empty()) { boundingBox = gSel.append("rect").attr("class", "bounding-box") } boundingBox - .attr("width", ({ width }) => width - actorMarginX * 2) - .attr("height", ({ height }) => height - actorMarginY * 2) - .attr("x", actorMarginX) - .attr("y", actorMarginY) + .attr("width", ({ width }) => width - fragmentMarginX * 2) + .attr("height", ({ height }) => height - fragmentMarginY * 2) + .attr("x", fragmentMarginX) + .attr("y", fragmentMarginY) .attr("fill", "white") .attr("stroke-width", ({ id }) => (isSelected(id) ? 3 : 1)) .attr("rx", 5) @@ -260,56 +252,43 @@ export default function FragmentGraph({ isSelected(id) ? theme.colors.blue[500] : theme.colors.gray[500] ) - // Actor links - let linkSelection = gSel.select(".links") - if (linkSelection.empty()) { - linkSelection = gSel.append("g").attr("class", "links") + // Stream node edges + let edgeSelection = gSel.select(".edges") + if (edgeSelection.empty()) { + edgeSelection = gSel.append("g").attr("class", "edges") } - const applyLink = ( - sel: d3.Selection< - SVGPathElement, - d3.HierarchyPointLink, - SVGGElement, - PlanNodeDesc - > - ) => sel.attr("d", treeLink) - - const createLink = ( - sel: d3.Selection< - d3.EnterElement, - d3.HierarchyPointLink, - SVGGElement, - PlanNodeDesc - > - ) => { + const applyEdge = (sel: EdgeSelection) => sel.attr("d", treeLink) + + const createEdge = (sel: Enter) => { sel .append("path") .attr("fill", "none") .attr("stroke", theme.colors.gray[700]) .attr("stroke-width", 1.5) - .call(applyLink) + .call(applyEdge) return sel } - const links = linkSelection + const edges = edgeSelection .selectAll("path") .data(({ layoutRoot }) => layoutRoot.links()) + type EdgeSelection = typeof edges - links.enter().call(createLink) - links.call(applyLink) - links.exit().remove() + edges.enter().call(createEdge) + edges.call(applyEdge) + edges.exit().remove() - // Actor nodes + // Stream nodes in fragment let nodes = gSel.select(".nodes") if (nodes.empty()) { nodes = gSel.append("g").attr("class", "nodes") } - const applyActorNode = (g: any) => { - g.attr("transform", (d: any) => `translate(${d.x},${d.y})`) + const applyStreamNode = (g: StreamNodeSelection) => { + g.attr("transform", (d) => `translate(${d.x},${d.y})`) - let circle = g.select("circle") + let circle = g.select("circle") if (circle.empty()) { circle = g.append("circle") } @@ -318,16 +297,16 @@ export default function FragmentGraph({ .attr("fill", theme.colors.blue[500]) .attr("r", nodeRadius) .style("cursor", "pointer") - .on("click", (_d: any, i: any) => openPlanNodeDetail(i)) + .on("click", (_d, i) => openPlanNodeDetail(i.data)) - let text = g.select("text") + let text = g.select("text") if (text.empty()) { text = g.append("text") } text .attr("fill", "black") - .text((d: any) => d.data.name) + .text((d) => d.data.name) .attr("font-family", "inherit") .attr("text-anchor", "middle") .attr("dy", nodeRadius * 1.8) @@ -338,67 +317,77 @@ export default function FragmentGraph({ return g } - const createActorNode = (sel: any) => - sel.append("g").attr("class", "actor-node").call(applyActorNode) + const createStreamNode = (sel: Enter) => + sel.append("g").attr("class", "stream-node").call(applyStreamNode) - const actorNodeSelection = nodes - .selectAll(".actor-node") + const streamNodeSelection = nodes + .selectAll(".stream-node") .data(({ layoutRoot }) => layoutRoot.descendants()) + type StreamNodeSelection = typeof streamNodeSelection - actorNodeSelection.exit().remove() - actorNodeSelection.enter().call(createActorNode) - actorNodeSelection.call(applyActorNode) + streamNodeSelection.exit().remove() + streamNodeSelection.enter().call(createStreamNode) + streamNodeSelection.call(applyStreamNode) } - const createActor = ( - sel: d3.Selection - ) => { - const gSel = sel.append("g").attr("class", "actor").call(applyActor) + const createFragment = (sel: Enter) => { + const gSel = sel + .append("g") + .attr("class", "fragment") + .call(applyFragment) return gSel } - const actorSelection = svgSelection - .select(".actors") - .selectAll(".actor") - .data(planNodeDependencyDag) + const fragmentSelection = svgSelection + .select(".fragments") + .selectAll(".fragment") + .data(fragmentLayout) + type FragmentSelection = typeof fragmentSelection - actorSelection.enter().call(createActor) - actorSelection.call(applyActor) - actorSelection.exit().remove() + fragmentSelection.enter().call(createFragment) + fragmentSelection.call(applyFragment) + fragmentSelection.exit().remove() + // Fragment Edges const edgeSelection = svgSelection - .select(".actor-links") - .selectAll(".actor-link") - .data(links) + .select(".fragment-edges") + .selectAll(".fragment-edge") + .data(fragmentEdgeLayout) + type EdgeSelection = typeof edgeSelection const curveStyle = d3.curveMonotoneX const line = d3 - .line<{ x: number; y: number }>() + .line() .curve(curveStyle) .x(({ x }) => x) .y(({ y }) => y) - const applyEdge = (sel: any) => + const applyEdge = (sel: EdgeSelection) => sel - .attr("d", ({ points }: any) => line(points)) + .attr("d", ({ points }) => line(points)) .attr("fill", "none") - .attr("stroke-width", (d: any) => + .attr("stroke-width", (d) => isSelected(d.source) || isSelected(d.target) ? 2 : 1 ) - .attr("stroke", (d: any) => + .attr("stroke", (d) => isSelected(d.source) || isSelected(d.target) ? theme.colors.blue["500"] : theme.colors.gray["300"] ) - const createEdge = (sel: any) => - sel.append("path").attr("class", "actor-link").call(applyEdge) + const createEdge = (sel: Enter) => + sel.append("path").attr("class", "fragment-edge").call(applyEdge) edgeSelection.enter().call(createEdge) edgeSelection.call(applyEdge) edgeSelection.exit().remove() } - }, [planNodeDependencyDag, links, selectedFragmentId, openPlanNodeDetail]) + }, [ + fragmentLayout, + fragmentEdgeLayout, + selectedFragmentId, + openPlanNodeDetail, + ]) return ( @@ -431,8 +420,8 @@ export default function FragmentGraph({ - - + + diff --git a/dashboard/components/StreamGraph.tsx b/dashboard/components/RelationDependencyGraph.tsx similarity index 82% rename from dashboard/components/StreamGraph.tsx rename to dashboard/components/RelationDependencyGraph.tsx index 890bce6c98ad8..99d40ca2615fd 100644 --- a/dashboard/components/StreamGraph.tsx +++ b/dashboard/components/RelationDependencyGraph.tsx @@ -19,14 +19,15 @@ import { theme } from "@chakra-ui/react" import * as d3 from "d3" import { useCallback, useEffect, useRef } from "react" import { - ActorPoint, - ActorPointPosition, + FragmentPoint, + FragmentPointPosition, + Position, flipLayoutPoint, - generatePointLinks, + generatePointEdges, } from "../lib/layout" function boundBox( - actorPosition: ActorPointPosition[], + fragmentPosition: FragmentPointPosition[], nodeRadius: number ): { width: number @@ -34,7 +35,7 @@ function boundBox( } { let width = 0 let height = 0 - for (const { x, y, data } of actorPosition) { + for (const { x, y } of fragmentPosition) { width = Math.max(width, x + nodeRadius) height = Math.max(height, y + nodeRadius) } @@ -46,11 +47,11 @@ const rowMargin = 200 const nodeRadius = 10 const layoutMargin = 100 -export function StreamGraph({ +export default function RelationDependencyGraph({ nodes, selectedId, }: { - nodes: ActorPoint[] + nodes: FragmentPoint[] selectedId?: string }) { const svgRef = useRef() @@ -61,12 +62,15 @@ export function StreamGraph({ layerMargin, rowMargin, nodeRadius - ).map(({ x, y, ...data }) => ({ - x: x + layoutMargin, - y: y + layoutMargin, - ...data, - })) - const links = generatePointLinks(layoutMap) + ).map( + ({ x, y, ...data }) => + ({ + x: x + layoutMargin, + y: y + layoutMargin, + ...data, + } as FragmentPointPosition) + ) + const links = generatePointEdges(layoutMap) const { width, height } = boundBox(layoutMap, nodeRadius) return { layoutMap, @@ -85,7 +89,7 @@ export function StreamGraph({ const curveStyle = d3.curveMonotoneY const line = d3 - .line<{ x: number; y: number }>() + .line() .curve(curveStyle) .x(({ x }) => x) .y(({ y }) => y) @@ -120,13 +124,10 @@ export function StreamGraph({ edgeSelection.enter().call(createEdge) edgeSelection.call(applyEdge) - const applyNode = (g: any) => { - g.attr( - "transform", - ({ x, y }: ActorPointPosition) => `translate(${x},${y})` - ) + const applyNode = (g: NodeSelection) => { + g.attr("transform", ({ x, y }) => `translate(${x},${y})`) - let circle = g.select("circle") + let circle = g.select("circle") if (circle.empty()) { circle = g.append("circle") } @@ -134,18 +135,18 @@ export function StreamGraph({ circle .attr("r", nodeRadius) .style("cursor", "pointer") - .attr("fill", ({ id }: ActorPointPosition) => + .attr("fill", ({ id }) => isSelected(id) ? theme.colors.blue["500"] : theme.colors.gray["500"] ) - let text = g.select("text") + let text = g.select("text") if (text.empty()) { text = g.append("text") } text .attr("fill", "black") - .text(({ data: { name } }: ActorPointPosition) => name) + .text(({ name }) => name) .attr("font-family", "inherit") .attr("text-anchor", "middle") .attr("dy", nodeRadius * 2) @@ -161,6 +162,8 @@ export function StreamGraph({ const g = svgSelection.select(".boxes") const nodeSelection = g.selectAll(".node").data(layoutMap) + type NodeSelection = typeof nodeSelection + nodeSelection.enter().call(createNode) nodeSelection.call(applyNode) nodeSelection.exit().remove() diff --git a/dashboard/lib/layout.ts b/dashboard/lib/layout.ts index dd2871051e683..c0e61faeddbf1 100644 --- a/dashboard/lib/layout.ts +++ b/dashboard/lib/layout.ts @@ -211,104 +211,103 @@ function dagLayout(nodes: GraphNode[]) { /** * @param fragments - * @returns Layer and row of the actor + * @returns Layer and row of the fragment */ function gridLayout( - fragments: Array -): Map { - // turn ActorBox to GraphNode - let actorBoxIdToActorBox = new Map() + fragments: Array +): Map { + // turn FragmentBox to GraphNode + let idToBox = new Map() for (let fragment of fragments) { - actorBoxIdToActorBox.set(fragment.id, fragment) + idToBox.set(fragment.id, fragment) } - let nodeToActorBoxId = new Map() - let actorBoxIdToNode = new Map() - const getActorBoxNode = (actorboxId: String): GraphNode => { - let rtn = actorBoxIdToNode.get(actorboxId) + let nodeToId = new Map() + let idToNode = new Map() + const getNode = (id: String): GraphNode => { + let rtn = idToNode.get(id) if (rtn !== undefined) { return rtn } let newNode = { nextNodes: new Array(), } - let ab = actorBoxIdToActorBox.get(actorboxId) + let ab = idToBox.get(id) if (ab === undefined) { - throw Error(`no such id ${actorboxId}`) + throw Error(`no such id ${id}`) } for (let id of ab.parentIds) { - // newNode.nextNodes.push(getActorBoxNode(id)) - getActorBoxNode(id).nextNodes.push(newNode) + getNode(id).nextNodes.push(newNode) } - actorBoxIdToNode.set(actorboxId, newNode) - nodeToActorBoxId.set(newNode, actorboxId) + idToNode.set(id, newNode) + nodeToId.set(newNode, id) return newNode } for (let fragment of fragments) { - getActorBoxNode(fragment.id) + getNode(fragment.id) } // run daglayout on GraphNode - let rtn = new Map() + let rtn = new Map() let allNodes = new Array() - for (let _n of nodeToActorBoxId.keys()) { + for (let _n of nodeToId.keys()) { allNodes.push(_n) } let resultMap = dagLayout(allNodes) for (let item of resultMap) { - let abId = nodeToActorBoxId.get(item[0]) - if (!abId) { - throw Error(`no corresponding actorboxid of node ${item[0]}`) + let id = nodeToId.get(item[0]) + if (!id) { + throw Error(`no corresponding fragment id of node ${item[0]}`) } - let ab = actorBoxIdToActorBox.get(abId) - if (!ab) { - throw Error(`actorbox id ${abId} is not present in actorBoxIdToActorBox`) + let fb = idToBox.get(id) + if (!fb) { + throw Error(`fragment id ${id} is not present in idToBox`) } - rtn.set(ab, item[1]) + rtn.set(fb, item[1]) } return rtn } -export interface ActorBox { +export interface FragmentBox { id: string name: string - order: number // preference order, actor box with larger order will be placed at right + order: number // preference order, fragment box with larger order will be placed at right width: number height: number parentIds: string[] fragment?: TableFragments_Fragment } -export interface ActorPoint { +export interface FragmentPoint { id: string name: string - order: number // preference order, actor box with larger order will be placed at right + order: number // preference order, fragment box with larger order will be placed at right parentIds: string[] } -export interface ActorBoxPosition { - id: string +export interface Position { x: number y: number - data: ActorBox } -export interface ActorPointPosition { - id: string - x: number - y: number - data: ActorPoint +export type FragmentBoxPosition = FragmentBox & Position +export type FragmentPointPosition = FragmentPoint & Position + +export interface Edge { + points: Array + source: string + target: string } /** * @param fragments - * @returns the coordination of the top-left corner of the actor box + * @returns the coordination of the top-left corner of the fragment box */ export function layout( - fragments: Array, + fragments: Array, layerMargin: number, rowMargin: number -): ActorBoxPosition[] { +): FragmentBoxPosition[] { let layoutMap = gridLayout(fragments) let layerRequiredWidth = new Map() let rowRequiredHeight = new Map() @@ -316,16 +315,16 @@ export function layout( maxRow = 0 for (let item of layoutMap) { - let ab = item[0], + let fb = item[0], layer = item[1][0], row = item[1][1] let currentWidth = layerRequiredWidth.get(layer) || 0 - if (ab.width > currentWidth) { - layerRequiredWidth.set(layer, ab.width) + if (fb.width > currentWidth) { + layerRequiredWidth.set(layer, fb.width) } let currentHeight = rowRequiredHeight.get(row) || 0 - if (ab.height > currentHeight) { - rowRequiredHeight.set(row, ab.height) + if (fb.height > currentHeight) { + rowRequiredHeight.set(row, fb.height) } maxLayer = max([layer, maxLayer]) || 0 @@ -373,17 +372,16 @@ export function layout( getCumulativeMargin(i, rowMargin, rowCumulativeHeight, rowRequiredHeight) } - let rtn: Array = [] + let rtn: Array = [] for (let [data, [layer, row]] of layoutMap) { let x = layerCumulativeWidth.get(layer) let y = rowCumulativeHeight.get(row) if (x !== undefined && y !== undefined) { rtn.push({ - id: data.id, x, y, - data, + ...data, }) } else { throw Error(`x of layer ${layer}: ${x}, y of row ${row}: ${y} `) @@ -393,30 +391,29 @@ export function layout( } export function flipLayout( - fragments: Array, + fragments: Array, layerMargin: number, rowMargin: number -): ActorBoxPosition[] { +): FragmentBoxPosition[] { const fragments_ = cloneDeep(fragments) for (let fragment of fragments_) { ;[fragment.width, fragment.height] = [fragment.height, fragment.width] } - const actorPosition = layout(fragments_, rowMargin, layerMargin) - return actorPosition.map(({ id, x, y, data }) => ({ - id, - data, + const fragmentPosition = layout(fragments_, rowMargin, layerMargin) + return fragmentPosition.map(({ x, y, ...data }) => ({ x: y, y: x, + ...data, })) } export function layoutPoint( - fragments: Array, + fragments: Array, layerMargin: number, rowMargin: number, nodeRadius: number -): ActorPointPosition[] { - const fragmentBoxes: Array = [] +): FragmentPointPosition[] { + const fragmentBoxes: Array = [] for (let { id, name, order, parentIds, ...others } of fragments) { fragmentBoxes.push({ id, @@ -429,42 +426,40 @@ export function layoutPoint( }) } const result = layout(fragmentBoxes, layerMargin, rowMargin) - return result.map(({ id, x, y, data }) => ({ - id, - data, + return result.map(({ x, y, ...data }) => ({ x: x + nodeRadius, y: y + nodeRadius, + ...data, })) } export function flipLayoutPoint( - fragments: Array, + fragments: Array, layerMargin: number, rowMargin: number, nodeRadius: number -): ActorPointPosition[] { - const actorPosition = layoutPoint( +): FragmentPointPosition[] { + const fragmentPosition = layoutPoint( fragments, rowMargin, layerMargin, nodeRadius ) - return actorPosition.map(({ id, x, y, data }) => ({ - id, - data, + return fragmentPosition.map(({ x, y, ...data }) => ({ x: y, y: x, + ...data, })) } -export function generatePointLinks(layoutMap: ActorPointPosition[]) { +export function generatePointEdges(layoutMap: FragmentPointPosition[]): Edge[] { const links = [] - const fragmentMap = new Map() + const fragmentMap = new Map() for (const x of layoutMap) { fragmentMap.set(x.id, x) } for (const fragment of layoutMap) { - for (const parentId of fragment.data.parentIds) { + for (const parentId of fragment.parentIds) { const parentFragment = fragmentMap.get(parentId)! links.push({ points: [ @@ -479,24 +474,24 @@ export function generatePointLinks(layoutMap: ActorPointPosition[]) { return links } -export function generateBoxLinks(layoutMap: ActorBoxPosition[]) { +export function generateBoxEdges(layoutMap: FragmentBoxPosition[]): Edge[] { const links = [] - const fragmentMap = new Map() + const fragmentMap = new Map() for (const x of layoutMap) { fragmentMap.set(x.id, x) } for (const fragment of layoutMap) { - for (const parentId of fragment.data.parentIds) { + for (const parentId of fragment.parentIds) { const parentFragment = fragmentMap.get(parentId)! links.push({ points: [ { - x: fragment.x + fragment.data.width / 2, - y: fragment.y + fragment.data.height / 2, + x: fragment.x + fragment.width / 2, + y: fragment.y + fragment.height / 2, }, { - x: parentFragment.x + parentFragment.data.width / 2, - y: parentFragment.y + parentFragment.data.height / 2, + x: parentFragment.x + parentFragment.width / 2, + y: parentFragment.y + parentFragment.height / 2, }, ], source: fragment.id, diff --git a/dashboard/pages/dependency_graph.tsx b/dashboard/pages/dependency_graph.tsx index 40a7d4c5c897a..33a3a7f29b029 100644 --- a/dashboard/pages/dependency_graph.tsx +++ b/dashboard/pages/dependency_graph.tsx @@ -21,15 +21,15 @@ import Head from "next/head" import Link from "next/link" import { useRouter } from "next/router" import { Fragment, useCallback, useEffect, useState } from "react" -import { StreamGraph } from "../components/StreamGraph" +import RelationDependencyGraph from "../components/RelationDependencyGraph" import Title from "../components/Title" import useErrorToast from "../hook/useErrorToast" -import { ActorPoint } from "../lib/layout" +import { FragmentPoint } from "../lib/layout" import { Relation, getRelations, relationIsStreamingJob } from "./api/streaming" const SIDEBAR_WIDTH = "200px" -function buildDependencyAsEdges(list: Relation[]): ActorPoint[] { +function buildDependencyAsEdges(list: Relation[]): FragmentPoint[] { const edges = [] const relationSet = new Set(list.map((r) => r.id)) for (const r of reverse(sortBy(list, "id"))) { @@ -122,7 +122,7 @@ export default function StreamingGraph() { > Graph {mvDependency && ( - diff --git a/dashboard/pages/fragment_graph.tsx b/dashboard/pages/fragment_graph.tsx index c5e8accab51be..18042c6450dc3 100644 --- a/dashboard/pages/fragment_graph.tsx +++ b/dashboard/pages/fragment_graph.tsx @@ -33,11 +33,11 @@ import _ from "lodash" import Head from "next/head" import { useRouter } from "next/router" import { Fragment, useCallback, useEffect, useState } from "react" -import DependencyGraph from "../components/DependencyGraph" +import FragmentDependencyGraph from "../components/FragmentDependencyGraph" import FragmentGraph from "../components/FragmentGraph" import Title from "../components/Title" import useErrorToast from "../hook/useErrorToast" -import { ActorBox } from "../lib/layout" +import { FragmentBox } from "../lib/layout" import { TableFragments, TableFragments_Fragment } from "../proto/gen/meta" import { Dispatcher, StreamNode } from "../proto/gen/stream_plan" import useFetch from "./api/fetch" @@ -53,7 +53,7 @@ export interface PlanNodeDatum { children?: PlanNodeDatum[] operatorId: string | number node: StreamNode | DispatcherNode - actor_ids?: string[] + actorIds?: string[] } function buildPlanNodeDependency( @@ -94,15 +94,17 @@ function buildPlanNodeDependency( return d3.hierarchy({ name: dispatcherName, - actor_ids: fragment.actors.map((a) => a.actorId.toString()), + actorIds: fragment.actors.map((a) => a.actorId.toString()), children: firstActor.nodes ? [hierarchyActorNode(firstActor.nodes)] : [], operatorId: "dispatcher", node: dispatcherNode, }) } -function buildFragmentDependencyAsEdges(fragments: TableFragments): ActorBox[] { - const nodes: ActorBox[] = [] +function buildFragmentDependencyAsEdges( + fragments: TableFragments +): FragmentBox[] { + const nodes: FragmentBox[] = [] const actorToFragmentMapping = new Map() for (const fragmentId in fragments.fragments) { const fragment = fragments.fragments[fragmentId] @@ -128,8 +130,8 @@ function buildFragmentDependencyAsEdges(fragments: TableFragments): ActorBox[] { width: 0, height: 0, order: fragment.fragmentId, - fragment: fragment, - } as ActorBox) + fragment, + } as FragmentBox) } return nodes } @@ -326,7 +328,7 @@ export default function Streaming() { Fragments {fragmentDependencyDag && ( - From 4f6295a3da86cfc4ca3e08667e53f76ed6ecff8f Mon Sep 17 00:00:00 2001 From: Bugen Zhao Date: Mon, 8 Jan 2024 13:01:56 +0800 Subject: [PATCH 11/20] feat: embed trace collector & jaeger ui to dev dashboard (#14220) Signed-off-by: Bugen Zhao Co-authored-by: BugenZhao --- Cargo.lock | 89 +++++++++++++++++++++++++++++++ Cargo.toml | 1 + dashboard/components/Layout.tsx | 3 ++ src/cmd/src/lib.rs | 8 +-- src/cmd_all/src/bin/risingwave.rs | 5 +- src/cmd_all/src/playground.rs | 11 ++++ src/cmd_all/src/standalone.rs | 44 ++++++++++++--- src/common/src/config.rs | 33 ++++++++++-- src/common/src/lib.rs | 4 +- src/common/src/opts.rs | 24 +++++++++ src/common/src/util/meta_addr.rs | 16 ++++++ src/compute/src/lib.rs | 10 ++++ src/config/example.toml | 4 ++ src/frontend/src/lib.rs | 10 ++++ src/meta/Cargo.toml | 1 + src/meta/node/Cargo.toml | 1 + src/meta/node/src/lib.rs | 21 +++++++- src/meta/node/src/server.rs | 10 ++++ src/meta/src/dashboard/mod.rs | 35 ++++++------ src/meta/src/manager/env.rs | 9 ++++ src/storage/compactor/src/lib.rs | 10 ++++ src/utils/runtime/src/logger.rs | 37 ++++++++++++- src/workspace-hack/Cargo.toml | 8 +++ 23 files changed, 354 insertions(+), 40 deletions(-) create mode 100644 src/common/src/opts.rs diff --git a/Cargo.lock b/Cargo.lock index a6f46c1d34fba..2d3142158edcc 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2954,6 +2954,24 @@ dependencies = [ "sqlparser", ] +[[package]] +name = "datasize" +version = "0.2.15" +source = "git+https://github.com/BugenZhao/datasize-rs?rev=8192cf2d751119a6a30e2ef67e5eb252f8e5b3e5#8192cf2d751119a6a30e2ef67e5eb252f8e5b3e5" +dependencies = [ + "datasize_derive", +] + +[[package]] +name = "datasize_derive" +version = "0.2.15" +source = "git+https://github.com/BugenZhao/datasize-rs?rev=8192cf2d751119a6a30e2ef67e5eb252f8e5b3e5#8192cf2d751119a6a30e2ef67e5eb252f8e5b3e5" +dependencies = [ + "proc-macro2", + "quote", + "syn 1.0.109", +] + [[package]] name = "debugid" version = "0.8.0" @@ -6460,6 +6478,26 @@ dependencies = [ "windows-sys 0.48.0", ] +[[package]] +name = "otlp-embedded" +version = "0.0.1" +source = "git+https://github.com/risingwavelabs/otlp-embedded?rev=58c1f003484449d7c6dd693b348bf19dd44889cb#58c1f003484449d7c6dd693b348bf19dd44889cb" +dependencies = [ + "axum", + "datasize", + "hex", + "itertools 0.12.0", + "madsim-tonic", + "madsim-tonic-build", + "prost 0.12.1", + "rust-embed", + "schnellru", + "serde", + "serde_json", + "tokio", + "tracing", +] + [[package]] name = "ouroboros" version = "0.17.2" @@ -8908,6 +8946,7 @@ dependencies = [ "mime_guess", "num-integer", "num-traits", + "otlp-embedded", "parking_lot 0.12.1", "prometheus", "prometheus-http-query", @@ -8976,6 +9015,7 @@ dependencies = [ "madsim-etcd-client", "madsim-tokio", "madsim-tonic", + "otlp-embedded", "prometheus-http-query", "redact", "regex", @@ -9558,6 +9598,41 @@ dependencies = [ "zeroize", ] +[[package]] +name = "rust-embed" +version = "8.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "810294a8a4a0853d4118e3b94bb079905f2107c7fe979d8f0faae98765eb6378" +dependencies = [ + "rust-embed-impl", + "rust-embed-utils", + "walkdir", +] + +[[package]] +name = "rust-embed-impl" +version = "8.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bfc144a1273124a67b8c1d7cd19f5695d1878b31569c0512f6086f0f4676604e" +dependencies = [ + "proc-macro2", + "quote", + "rust-embed-utils", + "syn 2.0.37", + "walkdir", +] + +[[package]] +name = "rust-embed-utils" +version = "8.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "816ccd4875431253d6bb54b804bcff4369cbde9bae33defde25fdf6c2ef91d40" +dependencies = [ + "mime_guess", + "sha2", + "walkdir", +] + [[package]] name = "rust-ini" version = "0.20.0" @@ -9751,6 +9826,17 @@ dependencies = [ "windows-sys 0.48.0", ] +[[package]] +name = "schnellru" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "772575a524feeb803e5b0fcbc6dd9f367e579488197c94c6e4023aad2305774d" +dependencies = [ + "ahash 0.8.6", + "cfg-if", + "hashbrown 0.13.2", +] + [[package]] name = "scoped-tls" version = "1.0.1" @@ -12427,6 +12513,7 @@ dependencies = [ "aws-sigv4", "aws-smithy-runtime", "aws-smithy-types", + "axum", "base64 0.21.4", "bit-vec", "bitflags 2.4.0", @@ -12457,6 +12544,7 @@ dependencies = [ "generic-array", "governor", "hashbrown 0.12.3", + "hashbrown 0.13.2", "hashbrown 0.14.0", "hmac", "hyper", @@ -12544,6 +12632,7 @@ dependencies = [ "url", "uuid", "whoami", + "zeroize", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index a46874f9e3da6..43d9aa56d2647 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -118,6 +118,7 @@ hashbrown = { version = "0.14.0", features = [ criterion = { version = "0.5", features = ["async_futures"] } tonic = { package = "madsim-tonic", version = "0.4.1" } tonic-build = { package = "madsim-tonic-build", version = "0.4.2" } +otlp-embedded = { git = "https://github.com/risingwavelabs/otlp-embedded", rev = "58c1f003484449d7c6dd693b348bf19dd44889cb" } prost = { version = "0.12" } icelake = { git = "https://github.com/icelake-io/icelake", rev = "3f7b53ba5b563524212c25810345d1314678e7fc", features = [ "prometheus", diff --git a/dashboard/components/Layout.tsx b/dashboard/components/Layout.tsx index a0bc7acd730d4..7be9f727bddec 100644 --- a/dashboard/components/Layout.tsx +++ b/dashboard/components/Layout.tsx @@ -161,6 +161,9 @@ function Layout({ children }: { children: React.ReactNode }) { Diagnose + + Traces +
Settings diff --git a/src/cmd/src/lib.rs b/src/cmd/src/lib.rs index 85691fbb6864e..ce110c9effc17 100644 --- a/src/cmd/src/lib.rs +++ b/src/cmd/src/lib.rs @@ -42,22 +42,22 @@ risingwave_expr_impl::enable!(); // Entry point functions. pub fn compute(opts: ComputeNodeOpts) { - init_risingwave_logger(LoggerSettings::new("compute")); + init_risingwave_logger(LoggerSettings::from_opts(&opts)); main_okk(risingwave_compute::start(opts)); } pub fn meta(opts: MetaNodeOpts) { - init_risingwave_logger(LoggerSettings::new("meta")); + init_risingwave_logger(LoggerSettings::from_opts(&opts)); main_okk(risingwave_meta_node::start(opts)); } pub fn frontend(opts: FrontendOpts) { - init_risingwave_logger(LoggerSettings::new("frontend")); + init_risingwave_logger(LoggerSettings::from_opts(&opts)); main_okk(risingwave_frontend::start(opts)); } pub fn compactor(opts: CompactorOpts) { - init_risingwave_logger(LoggerSettings::new("compactor")); + init_risingwave_logger(LoggerSettings::from_opts(&opts)); main_okk(risingwave_compactor::start(opts)); } diff --git a/src/cmd_all/src/bin/risingwave.rs b/src/cmd_all/src/bin/risingwave.rs index 05c7094a5c08c..a1e3a1b5f7063 100644 --- a/src/cmd_all/src/bin/risingwave.rs +++ b/src/cmd_all/src/bin/risingwave.rs @@ -192,7 +192,7 @@ fn main() -> Result<()> { } fn playground(opts: PlaygroundOpts) { - let settings = risingwave_rt::LoggerSettings::new("playground") + let settings = risingwave_rt::LoggerSettings::from_opts(&opts) .with_target("risingwave_storage", Level::WARN) .with_thread_name(true); risingwave_rt::init_risingwave_logger(settings); @@ -200,7 +200,8 @@ fn playground(opts: PlaygroundOpts) { } fn standalone(opts: StandaloneOpts) { - let settings = risingwave_rt::LoggerSettings::new("standalone") + let opts = risingwave_cmd_all::parse_standalone_opt_args(&opts); + let settings = risingwave_rt::LoggerSettings::from_opts(&opts) .with_target("risingwave_storage", Level::WARN) .with_thread_name(true); risingwave_rt::init_risingwave_logger(settings); diff --git a/src/cmd_all/src/playground.rs b/src/cmd_all/src/playground.rs index e7a018d1db92b..70039264f6ef7 100644 --- a/src/cmd_all/src/playground.rs +++ b/src/cmd_all/src/playground.rs @@ -20,6 +20,7 @@ use std::sync::LazyLock; use anyhow::Result; use clap::Parser; +use risingwave_common::util::meta_addr::MetaAddressStrategy; use tempfile::TempPath; use tokio::signal; @@ -143,6 +144,16 @@ pub struct PlaygroundOpts { profile: String, } +impl risingwave_common::opts::Opts for PlaygroundOpts { + fn name() -> &'static str { + "playground" + } + + fn meta_addr(&self) -> MetaAddressStrategy { + "http://0.0.0.0:5690".parse().unwrap() // hard-coded + } +} + pub async fn playground(opts: PlaygroundOpts) -> Result<()> { let profile = opts.profile; diff --git a/src/cmd_all/src/standalone.rs b/src/cmd_all/src/standalone.rs index 4ea5fe624be2d..a51fb03120313 100644 --- a/src/cmd_all/src/standalone.rs +++ b/src/cmd_all/src/standalone.rs @@ -14,6 +14,7 @@ use anyhow::Result; use clap::Parser; +use risingwave_common::util::meta_addr::MetaAddressStrategy; use risingwave_compactor::CompactorOpts; use risingwave_compute::ComputeNodeOpts; use risingwave_frontend::FrontendOpts; @@ -66,7 +67,27 @@ pub struct ParsedStandaloneOpts { pub compactor_opts: Option, } -fn parse_opt_args(opts: &StandaloneOpts) -> ParsedStandaloneOpts { +impl risingwave_common::opts::Opts for ParsedStandaloneOpts { + fn name() -> &'static str { + "standalone" + } + + fn meta_addr(&self) -> MetaAddressStrategy { + if let Some(opts) = self.meta_opts.as_ref() { + opts.meta_addr() + } else if let Some(opts) = self.compute_opts.as_ref() { + opts.meta_addr() + } else if let Some(opts) = self.frontend_opts.as_ref() { + opts.meta_addr() + } else if let Some(opts) = self.compactor_opts.as_ref() { + opts.meta_addr() + } else { + unreachable!("at least one service should be specified as checked during parsing") + } + } +} + +pub fn parse_standalone_opt_args(opts: &StandaloneOpts) -> ParsedStandaloneOpts { let meta_opts = opts.meta_opts.as_ref().map(|s| { let mut s = split(s).unwrap(); s.insert(0, "meta-node".into()); @@ -123,6 +144,15 @@ fn parse_opt_args(opts: &StandaloneOpts) -> ParsedStandaloneOpts { meta_opts.prometheus_host = Some(prometheus_listener_addr.clone()); } } + + if meta_opts.is_none() + && compute_opts.is_none() + && frontend_opts.is_none() + && compactor_opts.is_none() + { + panic!("No service is specified to start."); + } + ParsedStandaloneOpts { meta_opts, compute_opts, @@ -131,15 +161,15 @@ fn parse_opt_args(opts: &StandaloneOpts) -> ParsedStandaloneOpts { } } -pub async fn standalone(opts: StandaloneOpts) -> Result<()> { - tracing::info!("launching Risingwave in standalone mode"); - - let ParsedStandaloneOpts { +pub async fn standalone( + ParsedStandaloneOpts { meta_opts, compute_opts, frontend_opts, compactor_opts, - } = parse_opt_args(&opts); + }: ParsedStandaloneOpts, +) -> Result<()> { + tracing::info!("launching Risingwave in standalone mode"); if let Some(opts) = meta_opts { tracing::info!("starting meta-node thread with cli args: {:?}", opts); @@ -215,7 +245,7 @@ mod test { assert_eq!(actual, opts); // Test parsing into node-level opts. - let actual = parse_opt_args(&opts); + let actual = parse_standalone_opt_args(&opts); check( actual, expect![[r#" diff --git a/src/common/src/config.rs b/src/common/src/config.rs index 8bad7818a69b2..1c50c07d341aa 100644 --- a/src/common/src/config.rs +++ b/src/common/src/config.rs @@ -157,6 +157,10 @@ pub struct RwConfig { pub unrecognized: Unrecognized, } +serde_with::with_prefix!(meta_prefix "meta_"); +serde_with::with_prefix!(streaming_prefix "stream_"); +serde_with::with_prefix!(batch_prefix "batch_"); + #[derive(Copy, Clone, Debug, Default, ValueEnum, Serialize, Deserialize)] pub enum MetaBackend { #[default] @@ -305,6 +309,9 @@ pub struct MetaConfig { /// Keeps the latest N events per channel. #[serde(default = "default::meta::event_log_channel_max_size")] pub event_log_channel_max_size: u32, + + #[serde(default, with = "meta_prefix")] + pub developer: MetaDeveloperConfig, } #[derive(Clone, Debug, Default)] @@ -371,6 +378,22 @@ impl<'de> Deserialize<'de> for DefaultParallelism { } } +/// The subsections `[meta.developer]`. +/// +/// It is put at [`MetaConfig::developer`]. +#[derive(Clone, Debug, Serialize, Deserialize, DefaultFromSerde)] +pub struct MetaDeveloperConfig { + /// The number of traces to be cached in-memory by the tracing collector + /// embedded in the meta node. + #[serde(default = "default::developer::meta_cached_traces_num")] + pub cached_traces_num: u32, + + /// The maximum memory usage in bytes for the tracing collector embedded + /// in the meta node. + #[serde(default = "default::developer::meta_cached_traces_memory_limit_bytes")] + pub cached_traces_memory_limit_bytes: usize, +} + /// The section `[server]` in `risingwave.toml`. #[derive(Clone, Debug, Serialize, Deserialize, DefaultFromSerde)] pub struct ServerConfig { @@ -747,9 +770,6 @@ pub struct HeapProfilingConfig { pub dir: String, } -serde_with::with_prefix!(streaming_prefix "stream_"); -serde_with::with_prefix!(batch_prefix "batch_"); - /// The subsections `[streaming.developer]`. /// /// It is put at [`StreamingConfig::developer`]. @@ -1304,6 +1324,13 @@ pub mod default { } pub mod developer { + pub fn meta_cached_traces_num() -> u32 { + 256 + } + + pub fn meta_cached_traces_memory_limit_bytes() -> usize { + 1 << 27 // 128 MiB + } pub fn batch_output_channel_size() -> usize { 64 diff --git a/src/common/src/lib.rs b/src/common/src/lib.rs index 7bb359a13c6b9..21b3f393f7c25 100644 --- a/src/common/src/lib.rs +++ b/src/common/src/lib.rs @@ -70,6 +70,8 @@ pub mod log; pub mod memory; pub mod metrics; pub mod monitor; +pub mod opts; +pub mod range; pub mod row; pub mod session_config; pub mod system_param; @@ -79,8 +81,6 @@ pub mod transaction; pub mod types; pub mod vnode_mapping; -pub mod range; - pub mod test_prelude { pub use super::array::{DataChunkTestExt, StreamChunkTestExt}; pub use super::catalog::test_utils::ColumnDescTestExt; diff --git a/src/common/src/opts.rs b/src/common/src/opts.rs new file mode 100644 index 0000000000000..6c04fe7420fc4 --- /dev/null +++ b/src/common/src/opts.rs @@ -0,0 +1,24 @@ +// Copyright 2024 RisingWave Labs +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use crate::util::meta_addr::MetaAddressStrategy; + +/// Accessor trait for a component's command-line options. +pub trait Opts { + /// The name of the component. + fn name() -> &'static str; + + /// The address to the meta node. + fn meta_addr(&self) -> MetaAddressStrategy; +} diff --git a/src/common/src/util/meta_addr.rs b/src/common/src/util/meta_addr.rs index 9f32244bd08e6..286d4d45f5ca7 100644 --- a/src/common/src/util/meta_addr.rs +++ b/src/common/src/util/meta_addr.rs @@ -82,6 +82,22 @@ impl fmt::Display for MetaAddressStrategy { } } +impl MetaAddressStrategy { + /// Returns `Some` if there's exactly one address. + pub fn exactly_one(&self) -> Option<&http::Uri> { + match self { + MetaAddressStrategy::LoadBalance(lb) => Some(lb), + MetaAddressStrategy::List(list) => { + if list.len() == 1 { + list.first() + } else { + None + } + } + } + } +} + #[cfg(test)] mod tests { use super::*; diff --git a/src/compute/src/lib.rs b/src/compute/src/lib.rs index e82a617cc1d7d..342448066dd06 100644 --- a/src/compute/src/lib.rs +++ b/src/compute/src/lib.rs @@ -135,6 +135,16 @@ pub struct ComputeNodeOpts { pub heap_profiling_dir: Option, } +impl risingwave_common::opts::Opts for ComputeNodeOpts { + fn name() -> &'static str { + "compute" + } + + fn meta_addr(&self) -> MetaAddressStrategy { + self.meta_address.clone() + } +} + #[derive(Copy, Clone, Debug, Default, ValueEnum, Serialize, Deserialize)] pub enum Role { Serving, diff --git a/src/config/example.toml b/src/config/example.toml index 867bb530c4b4c..5f4f91abd32e6 100644 --- a/src/config/example.toml +++ b/src/config/example.toml @@ -65,6 +65,10 @@ level0_max_compact_file_number = 96 tombstone_reclaim_ratio = 40 enable_emergency_picker = true +[meta.developer] +meta_cached_traces_num = 256 +meta_cached_traces_memory_limit_bytes = 134217728 + [batch] enable_barrier_read = false statement_timeout_in_sec = 3600 diff --git a/src/frontend/src/lib.rs b/src/frontend/src/lib.rs index 0cbcac72e729f..8cf2eb0e331e1 100644 --- a/src/frontend/src/lib.rs +++ b/src/frontend/src/lib.rs @@ -139,6 +139,16 @@ pub struct FrontendOpts { pub enable_barrier_read: Option, } +impl risingwave_common::opts::Opts for FrontendOpts { + fn name() -> &'static str { + "frontend" + } + + fn meta_addr(&self) -> MetaAddressStrategy { + self.meta_addr.clone() + } +} + impl Default for FrontendOpts { fn default() -> Self { FrontendOpts::parse_from(iter::empty::()) diff --git a/src/meta/Cargo.toml b/src/meta/Cargo.toml index 3f6e9907c8c4a..c97dfab2d429a 100644 --- a/src/meta/Cargo.toml +++ b/src/meta/Cargo.toml @@ -41,6 +41,7 @@ memcomparable = { version = "0.2" } mime_guess = "2" num-integer = "0.1" num-traits = "0.2" +otlp-embedded = { workspace = true } parking_lot = { version = "0.12", features = ["arc_lock"] } prometheus = "0.13" prometheus-http-query = "0.8" diff --git a/src/meta/node/Cargo.toml b/src/meta/node/Cargo.toml index f0f0bc1874522..4c1237dc16d24 100644 --- a/src/meta/node/Cargo.toml +++ b/src/meta/node/Cargo.toml @@ -20,6 +20,7 @@ either = "1" etcd-client = { workspace = true } futures = { version = "0.3", default-features = false, features = ["alloc"] } itertools = "0.12" +otlp-embedded = { workspace = true } prometheus-http-query = "0.8" redact = "0.1.5" regex = "1" diff --git a/src/meta/node/src/lib.rs b/src/meta/node/src/lib.rs index 1a3baf23053c6..f6aa1be0d08f5 100644 --- a/src/meta/node/src/lib.rs +++ b/src/meta/node/src/lib.rs @@ -24,6 +24,7 @@ use clap::Parser; pub use error::{MetaError, MetaResult}; use redact::Secret; use risingwave_common::config::OverrideConfig; +use risingwave_common::util::meta_addr::MetaAddressStrategy; use risingwave_common::util::resource_util; use risingwave_common::{GIT_SHA, RW_VERSION}; use risingwave_common_heap_profiling::HeapProfiler; @@ -43,8 +44,9 @@ pub struct MetaNodeOpts { #[clap(long, env = "RW_VPC_SECURITY_GROUP_ID")] security_group_id: Option, + // TODO: use `SocketAddr` #[clap(long, env = "RW_LISTEN_ADDR", default_value = "127.0.0.1:5690")] - listen_addr: String, + pub listen_addr: String, /// The address for contacting this instance of the service. /// This would be synonymous with the service's "public address" @@ -164,6 +166,18 @@ pub struct MetaNodeOpts { pub heap_profiling_dir: Option, } +impl risingwave_common::opts::Opts for MetaNodeOpts { + fn name() -> &'static str { + "meta" + } + + fn meta_addr(&self) -> MetaAddressStrategy { + format!("http://{}", self.listen_addr) + .parse() + .expect("invalid listen address") + } +} + use std::future::Future; use std::pin::Pin; @@ -302,6 +316,11 @@ pub fn start(opts: MetaNodeOpts) -> Pin + Send>> { event_log_enabled: config.meta.event_log_enabled, event_log_channel_max_size: config.meta.event_log_channel_max_size, advertise_addr: opts.advertise_addr, + cached_traces_num: config.meta.developer.cached_traces_num, + cached_traces_memory_limit_bytes: config + .meta + .developer + .cached_traces_memory_limit_bytes, }, config.system.into_init_system_params(), ) diff --git a/src/meta/node/src/server.rs b/src/meta/node/src/server.rs index ef14c3689b63c..d7bd0208b3873 100644 --- a/src/meta/node/src/server.rs +++ b/src/meta/node/src/server.rs @@ -19,6 +19,7 @@ use either::Either; use etcd_client::ConnectOptions; use futures::future::join_all; use itertools::Itertools; +use otlp_embedded::TraceServiceServer; use regex::Regex; use risingwave_common::monitor::connection::{RouterExt, TcpConfig}; use risingwave_common::telemetry::manager::TelemetryManager; @@ -489,6 +490,13 @@ pub async fn start_service_as_election_leader( )), MetadataManager::V2(_) => None, }; + + let trace_state = otlp_embedded::State::new(otlp_embedded::Config { + max_length: opts.cached_traces_num, + max_memory_usage: opts.cached_traces_memory_limit_bytes, + }); + let trace_srv = otlp_embedded::TraceServiceImpl::new(trace_state.clone()); + #[cfg(not(madsim))] let dashboard_task = if let Some(ref dashboard_addr) = address_info.dashboard_addr { let dashboard_service = crate::dashboard::DashboardService { @@ -499,6 +507,7 @@ pub async fn start_service_as_election_leader( compute_clients: ComputeClientPool::default(), ui_path: address_info.ui_path, diagnose_command, + trace_state, }; let task = tokio::spawn(dashboard_service.serve()); Some(task) @@ -814,6 +823,7 @@ pub async fn start_service_as_election_leader( .add_service(CloudServiceServer::new(cloud_srv)) .add_service(SinkCoordinationServiceServer::new(sink_coordination_srv)) .add_service(EventLogServiceServer::new(event_log_srv)) + .add_service(TraceServiceServer::new(trace_srv)) .monitored_serve_with_shutdown( address_info.listen_addr, "grpc-meta-leader-service", diff --git a/src/meta/src/dashboard/mod.rs b/src/meta/src/dashboard/mod.rs index 3b806e859f2b1..98b5c11c62f77 100644 --- a/src/meta/src/dashboard/mod.rs +++ b/src/meta/src/dashboard/mod.rs @@ -30,7 +30,7 @@ use axum::Router; use hyper::Request; use parking_lot::Mutex; use risingwave_rpc_client::ComputeClientPool; -use tower::ServiceBuilder; +use tower::{ServiceBuilder, ServiceExt}; use tower_http::add_extension::AddExtensionLayer; use tower_http::cors::{self, CorsLayer}; use tower_http::services::ServeDir; @@ -47,6 +47,7 @@ pub struct DashboardService { pub compute_clients: ComputeClientPool, pub ui_path: Option, pub diagnose_command: Option, + pub trace_state: otlp_embedded::StateRef, } pub type Service = Arc; @@ -403,22 +404,15 @@ impl DashboardService { ) .layer(cors_layer); - let app = if let Some(ui_path) = ui_path { - let static_file_router = Router::new().nest_service( - "/", - get_service(ServeDir::new(ui_path)).handle_error(|e| async move { - ( - StatusCode::INTERNAL_SERVER_ERROR, - format!("Unhandled internal error: {e}",), - ) - }), - ); - Router::new() - .fallback_service(static_file_router) - .nest("/api", api_router) + let trace_ui_router = otlp_embedded::ui_app(srv.trace_state.clone(), "/trace/"); + + let dashboard_router = if let Some(ui_path) = ui_path { + get_service(ServeDir::new(ui_path)) + .handle_error(|e| async move { match e {} }) + .boxed_clone() } else { let cache = Arc::new(Mutex::new(HashMap::new())); - let service = tower::service_fn(move |req: Request| { + tower::service_fn(move |req: Request| { let cache = cache.clone(); async move { proxy::proxy(req, cache).await.or_else(|err| { @@ -429,12 +423,15 @@ impl DashboardService { .into_response()) }) } - }); - Router::new() - .fallback_service(service) - .nest("/api", api_router) + }) + .boxed_clone() }; + let app = Router::new() + .fallback_service(dashboard_router) + .nest("/api", api_router) + .nest("/trace", trace_ui_router); + axum::Server::bind(&srv.dashboard_addr) .serve(app.into_make_service()) .await diff --git a/src/meta/src/manager/env.rs b/src/meta/src/manager/env.rs index 13cfa03bc3a1b..769cbbe198aff 100644 --- a/src/meta/src/manager/env.rs +++ b/src/meta/src/manager/env.rs @@ -195,6 +195,13 @@ pub struct MetaOpts { pub event_log_enabled: bool, pub event_log_channel_max_size: u32, pub advertise_addr: String, + + /// The number of traces to be cached in-memory by the tracing collector + /// embedded in the meta node. + pub cached_traces_num: u32, + /// The maximum memory usage in bytes for the tracing collector embedded + /// in the meta node. + pub cached_traces_memory_limit_bytes: usize, } impl MetaOpts { @@ -242,6 +249,8 @@ impl MetaOpts { event_log_enabled: false, event_log_channel_max_size: 1, advertise_addr: "".to_string(), + cached_traces_num: 1, + cached_traces_memory_limit_bytes: usize::MAX, } } } diff --git a/src/storage/compactor/src/lib.rs b/src/storage/compactor/src/lib.rs index 024c41ecb9620..e95ba1ff6a39a 100644 --- a/src/storage/compactor/src/lib.rs +++ b/src/storage/compactor/src/lib.rs @@ -90,6 +90,16 @@ pub struct CompactorOpts { pub proxy_rpc_endpoint: String, } +impl risingwave_common::opts::Opts for CompactorOpts { + fn name() -> &'static str { + "compactor" + } + + fn meta_addr(&self) -> MetaAddressStrategy { + self.meta_address.clone() + } +} + use std::future::Future; use std::pin::Pin; diff --git a/src/utils/runtime/src/logger.rs b/src/utils/runtime/src/logger.rs index a835d0bc6e4a5..11e82150de4a3 100644 --- a/src/utils/runtime/src/logger.rs +++ b/src/utils/runtime/src/logger.rs @@ -18,6 +18,7 @@ use std::path::PathBuf; use either::Either; use risingwave_common::metrics::MetricsLayer; use risingwave_common::util::deployment::Deployment; +use risingwave_common::util::env_var::env_var_is_true; use risingwave_common::util::query_log::*; use thiserror_ext::AsReport; use tracing::level_filters::LevelFilter as Level; @@ -30,7 +31,7 @@ use tracing_subscriber::prelude::*; use tracing_subscriber::{filter, EnvFilter}; pub struct LoggerSettings { - /// The name of the service. + /// The name of the service. Used to identify the service in distributed tracing. name: String, /// Enable tokio console output. enable_tokio_console: bool, @@ -44,6 +45,8 @@ pub struct LoggerSettings { targets: Vec<(String, tracing::metadata::LevelFilter)>, /// Override the default level. default_level: Option, + /// The endpoint of the tracing collector in OTLP gRPC protocol. + tracing_endpoint: Option, } impl Default for LoggerSettings { @@ -53,6 +56,29 @@ impl Default for LoggerSettings { } impl LoggerSettings { + /// Create a new logger settings from the given command-line options. + /// + /// If env var `RW_TRACING_ENDPOINT` is not set, the meta address will be used + /// as the default tracing endpoint, which means that the embedded tracing + /// collector will be used. This can be disabled by setting env var + /// `RW_DISABLE_EMBEDDED_TRACING` to `true`. + pub fn from_opts(opts: &O) -> Self { + let mut settings = Self::new(O::name()); + if settings.tracing_endpoint.is_none() // no explicit endpoint + && !env_var_is_true("RW_DISABLE_EMBEDDED_TRACING") // not disabled by env var + && let Some(addr) = opts.meta_addr().exactly_one() // meta address is valid + && !Deployment::current().is_ci() + // not in CI + { + // Use embedded collector in the meta service. + // TODO: when there's multiple meta nodes for high availability, we may send + // to a wrong node here. + settings.tracing_endpoint = Some(addr.to_string()); + } + settings + } + + /// Create a new logger settings with the given service name. pub fn new(name: impl Into) -> Self { Self { name: name.into(), @@ -62,6 +88,7 @@ impl LoggerSettings { with_thread_name: false, targets: vec![], default_level: None, + tracing_endpoint: std::env::var("RW_TRACING_ENDPOINT").ok(), } } @@ -98,6 +125,12 @@ impl LoggerSettings { self.default_level = Some(level.into()); self } + + /// Overrides the tracing endpoint. + pub fn with_tracing_endpoint(mut self, endpoint: impl Into) -> Self { + self.tracing_endpoint = Some(endpoint.into()); + self + } } /// Init logger for RisingWave binaries. @@ -353,7 +386,7 @@ pub fn init_risingwave_logger(settings: LoggerSettings) { // Tracing layer #[cfg(not(madsim))] - if let Ok(endpoint) = std::env::var("RW_TRACING_ENDPOINT") { + if let Some(endpoint) = settings.tracing_endpoint { println!("tracing enabled, exported to `{endpoint}`"); use opentelemetry::{sdk, KeyValue}; diff --git a/src/workspace-hack/Cargo.toml b/src/workspace-hack/Cargo.toml index 70f9c75efb9fa..09e5b2e6ce6c7 100644 --- a/src/workspace-hack/Cargo.toml +++ b/src/workspace-hack/Cargo.toml @@ -28,6 +28,7 @@ aws-sdk-s3 = { version = "1" } aws-sigv4 = { version = "1", features = ["http0-compat", "sign-eventstream", "sigv4a"] } aws-smithy-runtime = { version = "1", default-features = false, features = ["client", "rt-tokio", "tls-rustls"] } aws-smithy-types = { version = "1", default-features = false, features = ["byte-stream-poll-next", "http-body-0-4-x", "hyper-0-14-x", "rt-tokio"] } +axum = { version = "0.6" } base64 = { version = "0.21", features = ["alloc"] } bit-vec = { version = "0.6" } bitflags = { version = "2", default-features = false, features = ["serde", "std"] } @@ -57,6 +58,7 @@ futures-util = { version = "0.3", features = ["channel", "io", "sink"] } generic-array = { version = "0.14", default-features = false, features = ["more_lengths", "zeroize"] } governor = { version = "0.6", default-features = false, features = ["dashmap", "jitter", "std"] } hashbrown-582f2526e08bb6a0 = { package = "hashbrown", version = "0.14", features = ["nightly", "raw"] } +hashbrown-594e8ee84c453af0 = { package = "hashbrown", version = "0.13", features = ["raw"] } hashbrown-5ef9efb8ec2df382 = { package = "hashbrown", version = "0.12", features = ["nightly", "raw"] } hmac = { version = "0.12", default-features = false, features = ["reset"] } hyper = { version = "0.14", features = ["full"] } @@ -142,6 +144,7 @@ unicode-normalization = { version = "0.1" } url = { version = "2", features = ["serde"] } uuid = { version = "1", features = ["fast-rng", "serde", "v4"] } whoami = { version = "1" } +zeroize = { version = "1" } [build-dependencies] ahash = { version = "0.8" } @@ -152,9 +155,11 @@ bitflags = { version = "2", default-features = false, features = ["serde", "std" bytes = { version = "1", features = ["serde"] } cc = { version = "1", default-features = false, features = ["parallel"] } deranged = { version = "0.3", default-features = false, features = ["powerfmt", "serde", "std"] } +digest = { version = "0.10", features = ["mac", "oid", "std"] } either = { version = "1", features = ["serde"] } fixedbitset = { version = "0.4" } frunk_core = { version = "0.4", default-features = false, features = ["std"] } +generic-array = { version = "0.14", default-features = false, features = ["more_lengths", "zeroize"] } hashbrown-582f2526e08bb6a0 = { package = "hashbrown", version = "0.14", features = ["nightly", "raw"] } itertools = { version = "0.11" } lazy_static = { version = "1", default-features = false, features = ["spin_no_std"] } @@ -179,11 +184,14 @@ regex-automata = { version = "0.4", default-features = false, features = ["dfa", regex-syntax = { version = "0.8" } serde = { version = "1", features = ["alloc", "derive", "rc"] } serde_json = { version = "1", features = ["alloc", "raw_value"] } +sha2 = { version = "0.10", features = ["oid"] } +subtle = { version = "2" } syn-dff4ba8e3ae991db = { package = "syn", version = "1", features = ["extra-traits", "full", "visit", "visit-mut"] } syn-f595c2ba2a3f28df = { package = "syn", version = "2", features = ["extra-traits", "fold", "full", "visit", "visit-mut"] } time = { version = "0.3", features = ["local-offset", "macros", "serde-well-known"] } time-macros = { version = "0.2", default-features = false, features = ["formatting", "parsing", "serde"] } toml_datetime = { version = "0.6", default-features = false, features = ["serde"] } toml_edit = { version = "0.19", features = ["serde"] } +zeroize = { version = "1" } ### END HAKARI SECTION From 1aa9938fdbcbc046b80905822574f4a17aed1fe5 Mon Sep 17 00:00:00 2001 From: Noel Kwan <47273164+kwannoel@users.noreply.github.com> Date: Mon, 8 Jan 2024 13:33:13 +0800 Subject: [PATCH 12/20] feat(meta): remove stream job progress for sink (#14388) --- src/meta/src/barrier/mod.rs | 4 +--- src/meta/src/barrier/progress.rs | 30 ------------------------------ 2 files changed, 1 insertion(+), 33 deletions(-) diff --git a/src/meta/src/barrier/mod.rs b/src/meta/src/barrier/mod.rs index 854de53f37708..5f5227e8090d1 100644 --- a/src/meta/src/barrier/mod.rs +++ b/src/meta/src/barrier/mod.rs @@ -1177,9 +1177,7 @@ impl GlobalBarrierManager { // Update the progress of all commands. for progress in resps.iter().flat_map(|r| &r.create_mview_progress) { // Those with actors complete can be finished immediately. - if let Some(command) = tracker.update(progress, &version_stats) - && !command.tracks_sink() - { + if let Some(command) = tracker.update(progress, &version_stats) { tracing::trace!(?progress, "finish progress"); commands.push(command); } else { diff --git a/src/meta/src/barrier/progress.rs b/src/meta/src/barrier/progress.rs index fede60d5eb939..ba1e11c9c6fa3 100644 --- a/src/meta/src/barrier/progress.rs +++ b/src/meta/src/barrier/progress.rs @@ -223,13 +223,6 @@ impl TrackingJob { TrackingJob::Recovered(recovered) => Some(recovered.fragments.table_id()), } } - - pub(crate) fn tracks_sink(&self) -> bool { - match self { - TrackingJob::New(command) => command.tracks_sink(), - TrackingJob::Recovered(_) => false, - } - } } pub struct RecoveredTrackingJob { @@ -247,15 +240,6 @@ pub(super) struct TrackingCommand { pub notifiers: Vec, } -impl TrackingCommand { - pub fn tracks_sink(&self) -> bool { - match &self.context.command { - Command::CreateStreamingJob { ddl_type, .. } => *ddl_type == DdlType::Sink, - _ => false, - } - } -} - /// Track the progress of all creating mviews. When creation is done, `notify_finished` will be /// called on registered notifiers. /// @@ -443,20 +427,6 @@ impl CreateMviewProgressTracker { definition, ); if *ddl_type == DdlType::Sink { - // First we duplicate a separate tracking job for sink. - // This does not need notifiers, it is solely used for - // tracking the backfill progress of sink. - // It will still be removed from progress map when - // backfill completes. - let tracking_job = TrackingJob::New(TrackingCommand { - context: command.context.clone(), - notifiers: vec![], - }); - let old = self - .progress_map - .insert(creating_mv_id, (progress, tracking_job)); - assert!(old.is_none()); - // We return the original tracking job immediately. // This is because sink can be decoupled with backfill progress. // We don't need to wait for sink to finish backfill. From 512e88497b423a05f2eb65821f9e0c800f07ace6 Mon Sep 17 00:00:00 2001 From: Dylan Date: Mon, 8 Jan 2024 13:35:47 +0800 Subject: [PATCH 13/20] feat(dml): sent dml data from the same session to a fixed worker node/channel (#14380) --- e2e_test/batch/transaction/same_session.slt | 27 ++++++ proto/batch_plan.proto | 9 ++ src/batch/src/executor/delete.rs | 7 +- src/batch/src/executor/insert.rs | 7 +- src/batch/src/executor/update.rs | 7 +- src/compute/tests/integration_tests.rs | 2 + .../src/optimizer/plan_node/batch_delete.rs | 1 + .../src/optimizer/plan_node/batch_insert.rs | 1 + .../src/optimizer/plan_node/batch_update.rs | 1 + .../src/scheduler/distributed/stage.rs | 83 +++++++++---------- src/frontend/src/scheduler/local.rs | 6 +- src/frontend/src/scheduler/plan_fragmenter.rs | 9 ++ src/source/src/dml_manager.rs | 17 +++- src/source/src/table.rs | 15 +++- src/stream/src/executor/dml.rs | 46 +++++++--- src/stream/src/executor/stream_reader.rs | 9 +- 16 files changed, 177 insertions(+), 70 deletions(-) create mode 100644 e2e_test/batch/transaction/same_session.slt diff --git a/e2e_test/batch/transaction/same_session.slt b/e2e_test/batch/transaction/same_session.slt new file mode 100644 index 0000000000000..2593c4d338d03 --- /dev/null +++ b/e2e_test/batch/transaction/same_session.slt @@ -0,0 +1,27 @@ +statement ok +create table t (id int primary key); + +statement ok +insert into t select i from generate_series(1, 100, 1) i; + +statement ok +flush + +# we don't use flush between delete and insert to test in the same session whether delete and insert overlap. +statement ok +delete from t; + +statement ok +insert into t select i from generate_series(1, 100, 1) i; + +statement ok +flush + +# Should be no overlap +query I +select count(*) from t; +---- +100 + +statement ok +drop table t; \ No newline at end of file diff --git a/proto/batch_plan.proto b/proto/batch_plan.proto index bf7fab1ae37f8..f6164f12226bf 100644 --- a/proto/batch_plan.proto +++ b/proto/batch_plan.proto @@ -88,6 +88,9 @@ message InsertNode { // be filled in streaming. optional uint32 row_id_index = 3; bool returning = 4; + + // Session id is used to ensure that dml data from the same session should be sent to a fixed worker node and channel. + uint32 session_id = 7; } message DeleteNode { @@ -96,6 +99,9 @@ message DeleteNode { // Version of the table. uint64 table_version_id = 3; bool returning = 2; + + // Session id is used to ensure that dml data from the same session should be sent to a fixed worker node and channel. + uint32 session_id = 4; } message UpdateNode { @@ -107,6 +113,9 @@ message UpdateNode { bool returning = 3; // The columns indices in the input schema, representing the columns need to send to streamDML exeuctor. repeated uint32 update_column_indices = 5; + + // Session id is used to ensure that dml data from the same session should be sent to a fixed worker node and channel. + uint32 session_id = 6; } message ValuesNode { diff --git a/src/batch/src/executor/delete.rs b/src/batch/src/executor/delete.rs index b0a7499ae161a..c5d7d06c42335 100644 --- a/src/batch/src/executor/delete.rs +++ b/src/batch/src/executor/delete.rs @@ -44,6 +44,7 @@ pub struct DeleteExecutor { identity: String, returning: bool, txn_id: TxnId, + session_id: u32, } impl DeleteExecutor { @@ -55,6 +56,7 @@ impl DeleteExecutor { chunk_size: usize, identity: String, returning: bool, + session_id: u32, ) -> Self { let table_schema = child.schema().clone(); let txn_id = dml_manager.gen_txn_id(); @@ -74,6 +76,7 @@ impl DeleteExecutor { identity, returning, txn_id, + session_id, } } } @@ -110,7 +113,7 @@ impl DeleteExecutor { self.child.schema().data_types(), "bad delete schema" ); - let mut write_handle = table_dml_handle.write_handle(self.txn_id)?; + let mut write_handle = table_dml_handle.write_handle(self.session_id, self.txn_id)?; write_handle.begin()?; @@ -182,6 +185,7 @@ impl BoxedExecutorBuilder for DeleteExecutor { source.context.get_config().developer.chunk_size, source.plan_node().get_identity().clone(), delete_node.returning, + delete_node.session_id, ))) } } @@ -247,6 +251,7 @@ mod tests { 1024, "DeleteExecutor".to_string(), false, + 0, )); let handle = tokio::spawn(async move { diff --git a/src/batch/src/executor/insert.rs b/src/batch/src/executor/insert.rs index 7536f160be32f..d236a25561029 100644 --- a/src/batch/src/executor/insert.rs +++ b/src/batch/src/executor/insert.rs @@ -52,6 +52,7 @@ pub struct InsertExecutor { row_id_index: Option, returning: bool, txn_id: TxnId, + session_id: u32, } impl InsertExecutor { @@ -67,6 +68,7 @@ impl InsertExecutor { sorted_default_columns: Vec<(usize, BoxedExpression)>, row_id_index: Option, returning: bool, + session_id: u32, ) -> Self { let table_schema = child.schema().clone(); let txn_id = dml_manager.gen_txn_id(); @@ -89,6 +91,7 @@ impl InsertExecutor { row_id_index, returning, txn_id, + session_id, } } } @@ -116,7 +119,7 @@ impl InsertExecutor { let table_dml_handle = self .dml_manager .table_dml_handle(self.table_id, self.table_version_id)?; - let mut write_handle = table_dml_handle.write_handle(self.txn_id)?; + let mut write_handle = table_dml_handle.write_handle(self.session_id, self.txn_id)?; write_handle.begin()?; @@ -253,6 +256,7 @@ impl BoxedExecutorBuilder for InsertExecutor { sorted_default_columns, insert_node.row_id_index.as_ref().map(|index| *index as _), insert_node.returning, + insert_node.session_id, ))) } } @@ -348,6 +352,7 @@ mod tests { vec![], row_id_index, false, + 0, )); let handle = tokio::spawn(async move { let mut stream = insert_executor.execute(); diff --git a/src/batch/src/executor/update.rs b/src/batch/src/executor/update.rs index b0e1b4750cfcf..1706a5f5cba7a 100644 --- a/src/batch/src/executor/update.rs +++ b/src/batch/src/executor/update.rs @@ -49,6 +49,7 @@ pub struct UpdateExecutor { returning: bool, txn_id: TxnId, update_column_indices: Vec, + session_id: u32, } impl UpdateExecutor { @@ -63,6 +64,7 @@ impl UpdateExecutor { identity: String, returning: bool, update_column_indices: Vec, + session_id: u32, ) -> Self { let chunk_size = chunk_size.next_multiple_of(2); let table_schema = child.schema().clone(); @@ -86,6 +88,7 @@ impl UpdateExecutor { returning, txn_id, update_column_indices, + session_id, } } } @@ -134,7 +137,7 @@ impl UpdateExecutor { let mut builder = DataChunkBuilder::new(data_types, self.chunk_size); let mut write_handle: risingwave_source::WriteHandle = - table_dml_handle.write_handle(self.txn_id)?; + table_dml_handle.write_handle(self.session_id, self.txn_id)?; write_handle.begin()?; // Transform the data chunk to a stream chunk, then write to the source. @@ -246,6 +249,7 @@ impl BoxedExecutorBuilder for UpdateExecutor { source.plan_node().get_identity().clone(), update_node.returning, update_column_indices, + update_node.session_id, ))) } } @@ -321,6 +325,7 @@ mod tests { "UpdateExecutor".to_string(), false, vec![0, 1], + 0, )); let handle = tokio::spawn(async move { diff --git a/src/compute/tests/integration_tests.rs b/src/compute/tests/integration_tests.rs index 60cc689616c60..490e90d174013 100644 --- a/src/compute/tests/integration_tests.rs +++ b/src/compute/tests/integration_tests.rs @@ -245,6 +245,7 @@ async fn test_table_materialize() -> StreamResult<()> { vec![], Some(row_id_index), false, + 0, )); let value_indices = (0..column_descs.len()).collect_vec(); @@ -366,6 +367,7 @@ async fn test_table_materialize() -> StreamResult<()> { 1024, "DeleteExecutor".to_string(), false, + 0, )); curr_epoch += 1; diff --git a/src/frontend/src/optimizer/plan_node/batch_delete.rs b/src/frontend/src/optimizer/plan_node/batch_delete.rs index c960eb1d83c90..d1fc6f1947d19 100644 --- a/src/frontend/src/optimizer/plan_node/batch_delete.rs +++ b/src/frontend/src/optimizer/plan_node/batch_delete.rs @@ -71,6 +71,7 @@ impl ToBatchPb for BatchDelete { table_id: self.core.table_id.table_id(), table_version_id: self.core.table_version_id, returning: self.core.returning, + session_id: self.base.ctx().session_ctx().session_id().0 as u32, }) } } diff --git a/src/frontend/src/optimizer/plan_node/batch_insert.rs b/src/frontend/src/optimizer/plan_node/batch_insert.rs index 4a280471fe199..caf7c449358b5 100644 --- a/src/frontend/src/optimizer/plan_node/batch_insert.rs +++ b/src/frontend/src/optimizer/plan_node/batch_insert.rs @@ -101,6 +101,7 @@ impl ToBatchPb for BatchInsert { }, row_id_index: self.core.row_id_index.map(|index| index as _), returning: self.core.returning, + session_id: self.base.ctx().session_ctx().session_id().0 as u32, }) } } diff --git a/src/frontend/src/optimizer/plan_node/batch_update.rs b/src/frontend/src/optimizer/plan_node/batch_update.rs index 5b3a6a8739fc7..b2e7e1913fb3a 100644 --- a/src/frontend/src/optimizer/plan_node/batch_update.rs +++ b/src/frontend/src/optimizer/plan_node/batch_update.rs @@ -84,6 +84,7 @@ impl ToBatchPb for BatchUpdate { table_version_id: self.core.table_version_id, returning: self.core.returning, update_column_indices, + session_id: self.base.ctx().session_ctx().session_id().0 as u32, }) } } diff --git a/src/frontend/src/scheduler/distributed/stage.rs b/src/frontend/src/scheduler/distributed/stage.rs index cc86d55f2b988..2d0df049da3fa 100644 --- a/src/frontend/src/scheduler/distributed/stage.rs +++ b/src/frontend/src/scheduler/distributed/stage.rs @@ -27,7 +27,6 @@ use futures::stream::Fuse; use futures::{stream, StreamExt, TryStreamExt}; use futures_async_stream::for_await; use itertools::Itertools; -use rand::seq::SliceRandom; use risingwave_batch::executor::ExecutorBuilder; use risingwave_batch::task::{ShutdownMsg, ShutdownSender, ShutdownToken, TaskId as TaskIdBatch}; use risingwave_common::array::DataChunk; @@ -698,51 +697,51 @@ impl StageRunner { dml_table_id: Option, ) -> SchedulerResult> { let plan_node = plan_fragment.root.as_ref().expect("fail to get plan node"); - let vnode_mapping = match dml_table_id { - Some(table_id) => Some(self.get_table_dml_vnode_mapping(&table_id)?), - None => { - if let Some(distributed_lookup_join_node) = - Self::find_distributed_lookup_join_node(plan_node) - { - let fragment_id = self.get_fragment_id( - &distributed_lookup_join_node - .inner_side_table_desc - .as_ref() - .unwrap() - .table_id - .into(), - )?; - let id2pu_vec = self - .worker_node_manager - .fragment_mapping(fragment_id)? - .iter_unique() - .collect_vec(); - - let pu = id2pu_vec[task_id as usize]; - let candidates = self - .worker_node_manager - .manager - .get_workers_by_parallel_unit_ids(&[pu])?; - return Ok(Some(candidates[0].clone())); - } else { - None - } - } - }; - let worker_node = match vnode_mapping { - Some(mapping) => { - let parallel_unit_ids = mapping.iter_unique().collect_vec(); - let candidates = self - .worker_node_manager - .manager - .get_workers_by_parallel_unit_ids(¶llel_unit_ids)?; - Some(candidates.choose(&mut rand::thread_rng()).unwrap().clone()) + if let Some(table_id) = dml_table_id { + let vnode_mapping = self.get_table_dml_vnode_mapping(&table_id)?; + let parallel_unit_ids = vnode_mapping.iter_unique().collect_vec(); + let candidates = self + .worker_node_manager + .manager + .get_workers_by_parallel_unit_ids(¶llel_unit_ids)?; + if candidates.is_empty() { + return Err(SchedulerError::EmptyWorkerNodes); } - None => None, + return Ok(Some( + candidates[self.stage.session_id.0 as usize % candidates.len()].clone(), + )); }; - Ok(worker_node) + if let Some(distributed_lookup_join_node) = + Self::find_distributed_lookup_join_node(plan_node) + { + let fragment_id = self.get_fragment_id( + &distributed_lookup_join_node + .inner_side_table_desc + .as_ref() + .unwrap() + .table_id + .into(), + )?; + let id2pu_vec = self + .worker_node_manager + .fragment_mapping(fragment_id)? + .iter_unique() + .collect_vec(); + + let pu = id2pu_vec[task_id as usize]; + let candidates = self + .worker_node_manager + .manager + .get_workers_by_parallel_unit_ids(&[pu])?; + if candidates.is_empty() { + return Err(SchedulerError::EmptyWorkerNodes); + } + Ok(Some(candidates[0].clone())) + } else { + Ok(None) + } } fn find_distributed_lookup_join_node( diff --git a/src/frontend/src/scheduler/local.rs b/src/frontend/src/scheduler/local.rs index 63ebe1b443bb7..95b36e50b4978 100644 --- a/src/frontend/src/scheduler/local.rs +++ b/src/frontend/src/scheduler/local.rs @@ -24,7 +24,6 @@ use futures::StreamExt; use futures_async_stream::try_stream; use itertools::Itertools; use pgwire::pg_server::BoxedError; -use rand::seq::SliceRandom; use risingwave_batch::executor::ExecutorBuilder; use risingwave_batch::task::{ShutdownToken, TaskId}; use risingwave_common::array::DataChunk; @@ -581,7 +580,10 @@ impl LocalQueryExecution { .worker_node_manager .manager .get_workers_by_parallel_unit_ids(¶llel_unit_ids)?; - candidates.choose(&mut rand::thread_rng()).unwrap().clone() + if candidates.is_empty() { + return Err(SchedulerError::EmptyWorkerNodes); + } + candidates[stage.session_id.0 as usize % candidates.len()].clone() }; Ok(vec![worker_node]) } else { diff --git a/src/frontend/src/scheduler/plan_fragmenter.rs b/src/frontend/src/scheduler/plan_fragmenter.rs index 09453b9cfe446..e40282cbacf86 100644 --- a/src/frontend/src/scheduler/plan_fragmenter.rs +++ b/src/frontend/src/scheduler/plan_fragmenter.rs @@ -22,6 +22,7 @@ use anyhow::anyhow; use async_recursion::async_recursion; use enum_as_inner::EnumAsInner; use itertools::Itertools; +use pgwire::pg_server::SessionId; use risingwave_common::buffer::{Bitmap, BitmapBuilder}; use risingwave_common::catalog::TableDesc; use risingwave_common::error::RwError; @@ -364,6 +365,7 @@ pub struct QueryStage { pub source_info: Option, pub has_lookup_join: bool, pub dml_table_id: Option, + pub session_id: SessionId, /// Used to generate exchange information when complete source scan information. children_exchange_distribution: Option>, @@ -395,6 +397,7 @@ impl QueryStage { source_info: self.source_info.clone(), has_lookup_join: self.has_lookup_join, dml_table_id: self.dml_table_id, + session_id: self.session_id, children_exchange_distribution: self.children_exchange_distribution.clone(), }; } @@ -423,6 +426,7 @@ impl QueryStage { source_info: Some(source_info), has_lookup_join: self.has_lookup_join, dml_table_id: self.dml_table_id, + session_id: self.session_id, children_exchange_distribution: None, } } @@ -467,6 +471,7 @@ struct QueryStageBuilder { source_info: Option, has_lookup_join: bool, dml_table_id: Option, + session_id: SessionId, children_exchange_distribution: HashMap, } @@ -482,6 +487,7 @@ impl QueryStageBuilder { source_info: Option, has_lookup_join: bool, dml_table_id: Option, + session_id: SessionId, ) -> Self { Self { query_id, @@ -494,6 +500,7 @@ impl QueryStageBuilder { source_info, has_lookup_join, dml_table_id, + session_id, children_exchange_distribution: HashMap::new(), } } @@ -514,6 +521,7 @@ impl QueryStageBuilder { source_info: self.source_info, has_lookup_join: self.has_lookup_join, dml_table_id: self.dml_table_id, + session_id: self.session_id, children_exchange_distribution, }); @@ -809,6 +817,7 @@ impl BatchPlanFragmenter { source_info, has_lookup_join, dml_table_id, + root.ctx().session_ctx().session_id(), ); self.visit_node(root, &mut builder, None)?; diff --git a/src/source/src/dml_manager.rs b/src/source/src/dml_manager.rs index af34003f4bde1..b4b03f9798c56 100644 --- a/src/source/src/dml_manager.rs +++ b/src/source/src/dml_manager.rs @@ -180,6 +180,7 @@ mod tests { use super::*; const TEST_TRANSACTION_ID: TxnId = 0; + const TEST_SESSION_ID: u32 = 0; #[tokio::test] async fn test_register_and_drop() { @@ -206,7 +207,9 @@ mod tests { let table_dml_handle = dml_manager .table_dml_handle(table_id, table_version_id) .unwrap(); - let mut write_handle = table_dml_handle.write_handle(TEST_TRANSACTION_ID).unwrap(); + let mut write_handle = table_dml_handle + .write_handle(TEST_SESSION_ID, TEST_TRANSACTION_ID) + .unwrap(); write_handle.begin().unwrap(); // Should be able to write to the table. @@ -219,7 +222,9 @@ mod tests { write_handle.write_chunk(chunk()).await.unwrap_err(); // Unless we create a new write handle. - let mut write_handle = table_dml_handle.write_handle(TEST_TRANSACTION_ID).unwrap(); + let mut write_handle = table_dml_handle + .write_handle(TEST_SESSION_ID, TEST_TRANSACTION_ID) + .unwrap(); write_handle.begin().unwrap(); write_handle.write_chunk(chunk()).await.unwrap(); @@ -254,7 +259,9 @@ mod tests { let table_dml_handle = dml_manager .table_dml_handle(table_id, old_version_id) .unwrap(); - let mut write_handle = table_dml_handle.write_handle(TEST_TRANSACTION_ID).unwrap(); + let mut write_handle = table_dml_handle + .write_handle(TEST_SESSION_ID, TEST_TRANSACTION_ID) + .unwrap(); write_handle.begin().unwrap(); // Should be able to write to the table. @@ -278,7 +285,9 @@ mod tests { let table_dml_handle = dml_manager .table_dml_handle(table_id, new_version_id) .unwrap(); - let mut write_handle = table_dml_handle.write_handle(TEST_TRANSACTION_ID).unwrap(); + let mut write_handle = table_dml_handle + .write_handle(TEST_SESSION_ID, TEST_TRANSACTION_ID) + .unwrap(); write_handle.begin().unwrap(); write_handle.write_chunk(new_chunk()).await.unwrap(); } diff --git a/src/source/src/table.rs b/src/source/src/table.rs index 503807283f465..ba0292e4f6caf 100644 --- a/src/source/src/table.rs +++ b/src/source/src/table.rs @@ -78,7 +78,7 @@ impl TableDmlHandle { TableStreamReader { rx } } - pub fn write_handle(&self, txn_id: TxnId) -> Result { + pub fn write_handle(&self, session_id: u32, txn_id: TxnId) -> Result { // The `changes_txs` should not be empty normally, since we ensured that the channels // between the `TableDmlHandle` and the `SourceExecutor`s are ready before we making the // table catalog visible to the users. However, when we're recovering, it's possible @@ -94,9 +94,11 @@ impl TableDmlHandle { ))); } let len = guard.changes_txs.len(); + // Use session id instead of txn_id to choose channel so that we can preserve transaction order in the same session. + // PS: only hold if there's no scaling on the table. let sender = guard .changes_txs - .get((txn_id % len as u64) as usize) + .get((session_id % len as u32) as usize) .context("no available table reader in streaming source executors")? .clone(); @@ -298,6 +300,7 @@ mod tests { use super::*; const TEST_TRANSACTION_ID: TxnId = 0; + const TEST_SESSION_ID: u32 = 0; fn new_table_dml_handle() -> TableDmlHandle { TableDmlHandle::new( @@ -310,7 +313,9 @@ mod tests { async fn test_table_dml_handle() -> Result<()> { let table_dml_handle = Arc::new(new_table_dml_handle()); let mut reader = table_dml_handle.stream_reader().into_stream(); - let mut write_handle = table_dml_handle.write_handle(TEST_TRANSACTION_ID).unwrap(); + let mut write_handle = table_dml_handle + .write_handle(TEST_SESSION_ID, TEST_TRANSACTION_ID) + .unwrap(); write_handle.begin().unwrap(); assert_matches!(reader.next().await.unwrap()?, TxnMsg::Begin(_)); @@ -354,7 +359,9 @@ mod tests { async fn test_write_handle_rollback_on_drop() -> Result<()> { let table_dml_handle = Arc::new(new_table_dml_handle()); let mut reader = table_dml_handle.stream_reader().into_stream(); - let mut write_handle = table_dml_handle.write_handle(TEST_TRANSACTION_ID).unwrap(); + let mut write_handle = table_dml_handle + .write_handle(TEST_SESSION_ID, TEST_TRANSACTION_ID) + .unwrap(); write_handle.begin().unwrap(); assert_matches!(reader.next().await.unwrap()?, TxnMsg::Begin(_)); diff --git a/src/stream/src/executor/dml.rs b/src/stream/src/executor/dml.rs index 435192974bffc..7e0b50c51a52a 100644 --- a/src/stream/src/executor/dml.rs +++ b/src/stream/src/executor/dml.rs @@ -192,6 +192,21 @@ impl DmlExecutor { batch_group.iter().map(|c| c.cardinality()).sum::(); if txn_buffer_cardinality >= self.chunk_size { + // txn buffer is too large, so yield batch group first to preserve the transaction order in the same session. + if !batch_group.is_empty() { + let vec = mem::take(&mut batch_group); + for chunk in vec { + for (op, row) in chunk.rows() { + if let Some(chunk) = builder.append_row(op, row) { + yield Message::Chunk(chunk); + } + } + } + if let Some(chunk) = builder.take() { + yield Message::Chunk(chunk); + } + } + // txn buffer isn't small, so yield. for chunk in txn_buffer.vec { yield Message::Chunk(chunk); @@ -202,21 +217,23 @@ impl DmlExecutor { // txn buffer is small and batch group has space. batch_group.extend(txn_buffer.vec); } else { - // txn buffer is small and batch group has no space, so yield the large one. - if txn_buffer_cardinality < batch_group_cardinality { - mem::swap(&mut txn_buffer.vec, &mut batch_group); - } - - for chunk in txn_buffer.vec { - for (op, row) in chunk.rows() { - if let Some(chunk) = builder.append_row(op, row) { - yield Message::Chunk(chunk); + // txn buffer is small and batch group has no space, so yield the batch group first to preserve the transaction order in the same session. + if !batch_group.is_empty() { + let vec = mem::take(&mut batch_group); + for chunk in vec { + for (op, row) in chunk.rows() { + if let Some(chunk) = builder.append_row(op, row) { + yield Message::Chunk(chunk); + } } } + if let Some(chunk) = builder.take() { + yield Message::Chunk(chunk); + } } - if let Some(chunk) = builder.take() { - yield Message::Chunk(chunk); - } + + // put txn buffer into the batch group + mem::swap(&mut txn_buffer.vec, &mut batch_group); } } TxnMsg::Rollback(txn_id) => { @@ -288,6 +305,7 @@ mod tests { use crate::executor::test_utils::MockSource; const TEST_TRANSACTION_ID: TxnId = 0; + const TEST_SESSION_ID: u32 = 0; #[tokio::test] async fn test_dml_executor() { @@ -357,7 +375,9 @@ mod tests { let table_dml_handle = dml_manager .table_dml_handle(table_id, INITIAL_TABLE_VERSION_ID) .unwrap(); - let mut write_handle = table_dml_handle.write_handle(TEST_TRANSACTION_ID).unwrap(); + let mut write_handle = table_dml_handle + .write_handle(TEST_SESSION_ID, TEST_TRANSACTION_ID) + .unwrap(); // Message from batch write_handle.begin().unwrap(); diff --git a/src/stream/src/executor/stream_reader.rs b/src/stream/src/executor/stream_reader.rs index c8d84926bd6ad..de490f730dea8 100644 --- a/src/stream/src/executor/stream_reader.rs +++ b/src/stream/src/executor/stream_reader.rs @@ -152,6 +152,7 @@ mod tests { const TEST_TRANSACTION_ID1: TxnId = 0; const TEST_TRANSACTION_ID2: TxnId = 1; + const TEST_SESSION_ID: u32 = 0; const TEST_DML_CHANNEL_INIT_PERMITS: usize = 32768; #[tokio::test] @@ -162,8 +163,12 @@ mod tests { let source_stream = table_dml_handle.stream_reader().into_data_stream_for_test(); - let mut write_handle1 = table_dml_handle.write_handle(TEST_TRANSACTION_ID1).unwrap(); - let mut write_handle2 = table_dml_handle.write_handle(TEST_TRANSACTION_ID2).unwrap(); + let mut write_handle1 = table_dml_handle + .write_handle(TEST_SESSION_ID, TEST_TRANSACTION_ID1) + .unwrap(); + let mut write_handle2 = table_dml_handle + .write_handle(TEST_SESSION_ID, TEST_TRANSACTION_ID2) + .unwrap(); let barrier_stream = barrier_to_message_stream(barrier_rx).boxed(); let stream = From 322c24de852911b66a19588a0ab66fd5af76b226 Mon Sep 17 00:00:00 2001 From: Noel Kwan <47273164+kwannoel@users.noreply.github.com> Date: Mon, 8 Jan 2024 14:07:35 +0800 Subject: [PATCH 14/20] chore(ci): fix duplicated key in `main-cron` workflow (#14408) --- ci/workflows/main-cron.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ci/workflows/main-cron.yml b/ci/workflows/main-cron.yml index d931c3af16660..b5533ae149e0b 100644 --- a/ci/workflows/main-cron.yml +++ b/ci/workflows/main-cron.yml @@ -490,7 +490,7 @@ steps: retry: *auto-retry - label: "PosixFs source on OpenDAL fs engine (csv parser)" - key: "s3-source-test-for-opendal-fs-engine" + key: "s3-source-test-for-opendal-fs-engine-csv-parser" command: "ci/scripts/s3-source-test.sh -p ci-release -s 'posix_fs_source.py csv_without_header'" if: | !(build.pull_request.labels includes "ci/main-cron/skip-ci") && build.env("CI_STEPS") == null From d2f1bafd33918418c8e68727351a153ad5dde45f Mon Sep 17 00:00:00 2001 From: Noel Kwan <47273164+kwannoel@users.noreply.github.com> Date: Mon, 8 Jan 2024 14:28:14 +0800 Subject: [PATCH 15/20] feat(sqlsmith): refactor test runners (#14405) --- src/tests/simulation/src/main.rs | 15 +- src/tests/sqlsmith/src/bin/main.rs | 2 +- src/tests/sqlsmith/src/lib.rs | 2 +- src/tests/sqlsmith/src/runner.rs | 681 ------------------ src/tests/sqlsmith/src/test_runners/README.md | 0 src/tests/sqlsmith/src/test_runners/diff.rs | 206 ++++++ .../sqlsmith/src/test_runners/fuzzing.rs | 180 +++++ src/tests/sqlsmith/src/test_runners/mod.rs | 26 + src/tests/sqlsmith/src/test_runners/utils.rs | 353 +++++++++ 9 files changed, 777 insertions(+), 688 deletions(-) delete mode 100644 src/tests/sqlsmith/src/runner.rs create mode 100644 src/tests/sqlsmith/src/test_runners/README.md create mode 100644 src/tests/sqlsmith/src/test_runners/diff.rs create mode 100644 src/tests/sqlsmith/src/test_runners/fuzzing.rs create mode 100644 src/tests/sqlsmith/src/test_runners/mod.rs create mode 100644 src/tests/sqlsmith/src/test_runners/utils.rs diff --git a/src/tests/simulation/src/main.rs b/src/tests/simulation/src/main.rs index 4c2c1da7fa341..2d198239a7359 100644 --- a/src/tests/simulation/src/main.rs +++ b/src/tests/simulation/src/main.rs @@ -201,7 +201,7 @@ async fn main() { .await .unwrap(); if let Some(outdir) = args.generate_sqlsmith_queries { - risingwave_sqlsmith::runner::generate( + risingwave_sqlsmith::test_runners::generate( rw.pg_client(), &args.files, count, @@ -212,7 +212,7 @@ async fn main() { return; } if args.run_differential_tests { - risingwave_sqlsmith::runner::run_differential_testing( + risingwave_sqlsmith::test_runners::run_differential_testing( rw.pg_client(), &args.files, count, @@ -223,8 +223,13 @@ async fn main() { return; } - risingwave_sqlsmith::runner::run(rw.pg_client(), &args.files, count, Some(seed)) - .await; + risingwave_sqlsmith::test_runners::run( + rw.pg_client(), + &args.files, + count, + Some(seed), + ) + .await; }) .await; return; @@ -237,7 +242,7 @@ async fn main() { let rw = RisingWave::connect("frontend".into(), "dev".into()) .await .unwrap(); - risingwave_sqlsmith::runner::run_pre_generated(rw.pg_client(), &outdir).await; + risingwave_sqlsmith::test_runners::run_pre_generated(rw.pg_client(), &outdir).await; }) .await; return; diff --git a/src/tests/sqlsmith/src/bin/main.rs b/src/tests/sqlsmith/src/bin/main.rs index 79df7f6932a30..6aaa1f60f1500 100644 --- a/src/tests/sqlsmith/src/bin/main.rs +++ b/src/tests/sqlsmith/src/bin/main.rs @@ -21,7 +21,7 @@ use std::time::Duration; use clap::Parser as ClapParser; use risingwave_sqlsmith::print_function_table; -use risingwave_sqlsmith::runner::{generate, run, run_differential_testing}; +use risingwave_sqlsmith::test_runners::{generate, run, run_differential_testing}; use tokio_postgres::NoTls; #[derive(ClapParser, Debug, Clone)] diff --git a/src/tests/sqlsmith/src/lib.rs b/src/tests/sqlsmith/src/lib.rs index 23f6454bc9da5..2d8c23a52b740 100644 --- a/src/tests/sqlsmith/src/lib.rs +++ b/src/tests/sqlsmith/src/lib.rs @@ -37,8 +37,8 @@ use risingwave_sqlparser::parser::Parser; use crate::sql_gen::SqlGenerator; pub mod reducer; -pub mod runner; mod sql_gen; +pub mod test_runners; mod utils; pub mod validation; pub use validation::is_permissible_error; diff --git a/src/tests/sqlsmith/src/runner.rs b/src/tests/sqlsmith/src/runner.rs deleted file mode 100644 index b095cf4e3e964..0000000000000 --- a/src/tests/sqlsmith/src/runner.rs +++ /dev/null @@ -1,681 +0,0 @@ -// Copyright 2024 RisingWave Labs -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -//! Provides E2E Test runner functionality. - -use anyhow::{anyhow, bail}; -use itertools::Itertools; -use rand::rngs::SmallRng; -use rand::{Rng, SeedableRng}; -#[cfg(madsim)] -use rand_chacha::ChaChaRng; -use risingwave_sqlparser::ast::Statement; -use similar::{ChangeTag, TextDiff}; -use tokio::time::{sleep, timeout, Duration}; -use tokio_postgres::error::Error as PgError; -use tokio_postgres::{Client, SimpleQueryMessage}; - -use crate::utils::read_file_contents; -use crate::validation::{is_permissible_error, is_recovery_in_progress_error}; -use crate::{ - differential_sql_gen, generate_update_statements, insert_sql_gen, mview_sql_gen, - parse_create_table_statements, parse_sql, session_sql_gen, sql_gen, Table, -}; - -type PgResult = std::result::Result; -type Result = anyhow::Result; - -/// e2e test runner for pre-generated queries from sqlsmith -pub async fn run_pre_generated(client: &Client, outdir: &str) { - let timeout_duration = 12; // allow for some variance. - let queries_path = format!("{}/queries.sql", outdir); - let queries = read_file_contents(queries_path).unwrap(); - for statement in parse_sql(&queries) { - let sql = statement.to_string(); - tracing::info!("[EXECUTING STATEMENT]: {}", sql); - run_query(timeout_duration, client, &sql).await.unwrap(); - } - tracing::info!("[EXECUTION SUCCESS]"); -} - -/// Query Generator -/// If we encounter an expected error, just skip. -/// If we encounter an unexpected error, -/// Sqlsmith should stop execution, but writeout ddl and queries so far. -/// If query takes too long -> cancel it, **mark it as error**. -/// NOTE(noel): It will still fail if DDL creation fails. -pub async fn generate( - client: &Client, - testdata: &str, - count: usize, - _outdir: &str, - seed: Option, -) { - let timeout_duration = 5; - - set_variable(client, "RW_IMPLICIT_FLUSH", "TRUE").await; - set_variable(client, "QUERY_MODE", "DISTRIBUTED").await; - tracing::info!("Set session variables"); - - let mut rng = generate_rng(seed); - let base_tables = create_base_tables(testdata, client).await.unwrap(); - - let rows_per_table = 50; - let max_rows_inserted = rows_per_table * base_tables.len(); - let inserts = populate_tables(client, &mut rng, base_tables.clone(), rows_per_table).await; - tracing::info!("Populated base tables"); - - let (tables, mviews) = create_mviews(&mut rng, base_tables.clone(), client) - .await - .unwrap(); - - // Generate an update for some inserts, on the corresponding table. - update_base_tables(client, &mut rng, &base_tables, &inserts).await; - - test_sqlsmith( - client, - &mut rng, - tables.clone(), - base_tables.clone(), - max_rows_inserted, - ) - .await; - tracing::info!("Passed sqlsmith tests"); - - tracing::info!("Ran updates"); - - let mut generated_queries = 0; - for _ in 0..count { - test_session_variable(client, &mut rng).await; - let sql = sql_gen(&mut rng, tables.clone()); - tracing::info!("[EXECUTING TEST_BATCH]: {}", sql); - let result = run_query(timeout_duration, client, sql.as_str()).await; - match result { - Err(_e) => { - generated_queries += 1; - tracing::info!("Generated {} batch queries", generated_queries); - tracing::error!("Unrecoverable error encountered."); - return; - } - Ok(0) => { - generated_queries += 1; - } - _ => {} - } - } - tracing::info!("Generated {} batch queries", generated_queries); - - let mut generated_queries = 0; - for _ in 0..count { - test_session_variable(client, &mut rng).await; - let (sql, table) = mview_sql_gen(&mut rng, tables.clone(), "stream_query"); - tracing::info!("[EXECUTING TEST_STREAM]: {}", sql); - let result = run_query(timeout_duration, client, sql.as_str()).await; - match result { - Err(_e) => { - generated_queries += 1; - tracing::info!("Generated {} stream queries", generated_queries); - tracing::error!("Unrecoverable error encountered."); - return; - } - Ok(0) => { - generated_queries += 1; - } - _ => {} - } - tracing::info!("[EXECUTING DROP MVIEW]: {}", &format_drop_mview(&table)); - drop_mview_table(&table, client).await; - } - tracing::info!("Generated {} stream queries", generated_queries); - - drop_tables(&mviews, testdata, client).await; -} - -/// e2e test runner for sqlsmith -pub async fn run(client: &Client, testdata: &str, count: usize, seed: Option) { - let mut rng = generate_rng(seed); - - set_variable(client, "RW_IMPLICIT_FLUSH", "TRUE").await; - set_variable(client, "QUERY_MODE", "DISTRIBUTED").await; - tracing::info!("Set session variables"); - - let base_tables = create_base_tables(testdata, client).await.unwrap(); - - let rows_per_table = 50; - let inserts = populate_tables(client, &mut rng, base_tables.clone(), rows_per_table).await; - tracing::info!("Populated base tables"); - - let (tables, mviews) = create_mviews(&mut rng, base_tables.clone(), client) - .await - .unwrap(); - tracing::info!("Created tables"); - - // Generate an update for some inserts, on the corresponding table. - update_base_tables(client, &mut rng, &base_tables, &inserts).await; - tracing::info!("Ran updates"); - - let max_rows_inserted = rows_per_table * base_tables.len(); - test_sqlsmith( - client, - &mut rng, - tables.clone(), - base_tables.clone(), - max_rows_inserted, - ) - .await; - tracing::info!("Passed sqlsmith tests"); - - test_batch_queries(client, &mut rng, tables.clone(), count) - .await - .unwrap(); - tracing::info!("Passed batch queries"); - test_stream_queries(client, &mut rng, tables.clone(), count) - .await - .unwrap(); - tracing::info!("Passed stream queries"); - - drop_tables(&mviews, testdata, client).await; - tracing::info!("[EXECUTION SUCCESS]"); -} - -/// Differential testing for batch and stream -pub async fn run_differential_testing( - client: &Client, - testdata: &str, - count: usize, - seed: Option, -) -> Result<()> { - let mut rng = generate_rng(seed); - - set_variable(client, "RW_IMPLICIT_FLUSH", "TRUE").await; - set_variable(client, "QUERY_MODE", "DISTRIBUTED").await; - tracing::info!("Set session variables"); - - let base_tables = create_base_tables(testdata, client).await.unwrap(); - - let rows_per_table = 50; - let inserts = populate_tables(client, &mut rng, base_tables.clone(), rows_per_table).await; - tracing::info!("Populated base tables"); - - let (tables, mviews) = create_mviews(&mut rng, base_tables.clone(), client) - .await - .unwrap(); - tracing::info!("Created tables"); - - // Generate an update for some inserts, on the corresponding table. - update_base_tables(client, &mut rng, &base_tables, &inserts).await; - tracing::info!("Ran updates"); - - for i in 0..count { - diff_stream_and_batch(&mut rng, tables.clone(), client, i).await? - } - - drop_tables(&mviews, testdata, client).await; - tracing::info!("[EXECUTION SUCCESS]"); - Ok(()) -} - -fn generate_rng(seed: Option) -> impl Rng { - #[cfg(madsim)] - if let Some(seed) = seed { - ChaChaRng::seed_from_u64(seed) - } else { - ChaChaRng::from_rng(SmallRng::from_entropy()).unwrap() - } - #[cfg(not(madsim))] - if let Some(seed) = seed { - SmallRng::seed_from_u64(seed) - } else { - SmallRng::from_entropy() - } -} - -async fn update_base_tables( - client: &Client, - rng: &mut R, - base_tables: &[Table], - inserts: &[Statement], -) { - let update_statements = generate_update_statements(rng, base_tables, inserts).unwrap(); - for update_statement in update_statements { - let sql = update_statement.to_string(); - tracing::info!("[EXECUTING UPDATES]: {}", &sql); - client.simple_query(&sql).await.unwrap(); - } -} - -async fn populate_tables( - client: &Client, - rng: &mut R, - base_tables: Vec, - row_count: usize, -) -> Vec { - let inserts = insert_sql_gen(rng, base_tables, row_count); - for insert in &inserts { - tracing::info!("[EXECUTING INSERT]: {}", insert); - client.simple_query(insert).await.unwrap(); - } - inserts - .iter() - .map(|s| parse_sql(s).into_iter().next().unwrap()) - .collect_vec() -} - -/// Sanity checks for sqlsmith -async fn test_sqlsmith( - client: &Client, - rng: &mut R, - tables: Vec
, - base_tables: Vec
, - row_count: usize, -) { - // Test inserted rows should be at least 50% population count, - // otherwise we don't have sufficient data in our system. - // ENABLE: https://github.com/risingwavelabs/risingwave/issues/3844 - test_population_count(client, base_tables, row_count).await; - tracing::info!("passed population count test"); - - let threshold = 0.50; // permit at most 50% of queries to be skipped. - let sample_size = 20; - - let skipped_percentage = test_batch_queries(client, rng, tables.clone(), sample_size) - .await - .unwrap(); - tracing::info!( - "percentage of skipped batch queries = {}, threshold: {}", - skipped_percentage, - threshold - ); - if skipped_percentage > threshold { - panic!("skipped batch queries exceeded threshold."); - } - - let skipped_percentage = test_stream_queries(client, rng, tables.clone(), sample_size) - .await - .unwrap(); - tracing::info!( - "percentage of skipped stream queries = {}, threshold: {}", - skipped_percentage, - threshold - ); - if skipped_percentage > threshold { - panic!("skipped stream queries exceeded threshold."); - } -} - -async fn set_variable(client: &Client, variable: &str, value: &str) -> String { - let s = format!("SET {variable} TO {value}"); - tracing::info!("[EXECUTING SET_VAR]: {}", s); - client.simple_query(&s).await.unwrap(); - s -} - -async fn test_session_variable(client: &Client, rng: &mut R) -> String { - let session_sql = session_sql_gen(rng); - tracing::info!("[EXECUTING TEST SESSION_VAR]: {}", session_sql); - client.simple_query(session_sql.as_str()).await.unwrap(); - session_sql -} - -/// Expects at least 50% of inserted rows included. -async fn test_population_count(client: &Client, base_tables: Vec
, expected_count: usize) { - let mut actual_count = 0; - for t in base_tables { - let q = format!("select * from {};", t.name); - let rows = client.simple_query(&q).await.unwrap(); - actual_count += rows.len(); - } - if actual_count < expected_count / 2 { - panic!( - "expected at least 50% rows included.\ - Total {} rows, only had {} rows", - expected_count, actual_count, - ) - } -} - -/// Test batch queries, returns skipped query statistics -/// Runs in distributed mode, since queries can be complex and cause overflow in local execution -/// mode. -async fn test_batch_queries( - client: &Client, - rng: &mut R, - tables: Vec
, - sample_size: usize, -) -> Result { - let mut skipped = 0; - for _ in 0..sample_size { - test_session_variable(client, rng).await; - let sql = sql_gen(rng, tables.clone()); - tracing::info!("[TEST BATCH]: {}", sql); - skipped += run_query(30, client, &sql).await?; - } - Ok(skipped as f64 / sample_size as f64) -} - -/// Test stream queries, returns skipped query statistics -async fn test_stream_queries( - client: &Client, - rng: &mut R, - tables: Vec
, - sample_size: usize, -) -> Result { - let mut skipped = 0; - - for _ in 0..sample_size { - test_session_variable(client, rng).await; - let (sql, table) = mview_sql_gen(rng, tables.clone(), "stream_query"); - tracing::info!("[TEST STREAM]: {}", sql); - skipped += run_query(12, client, &sql).await?; - tracing::info!("[TEST DROP MVIEW]: {}", &format_drop_mview(&table)); - drop_mview_table(&table, client).await; - } - Ok(skipped as f64 / sample_size as f64) -} - -fn get_seed_table_sql(testdata: &str) -> String { - let seed_files = ["tpch.sql", "nexmark.sql", "alltypes.sql"]; - seed_files - .iter() - .map(|filename| read_file_contents(format!("{}/{}", testdata, filename)).unwrap()) - .collect::() -} - -/// Create the tables defined in testdata, along with some mviews. -/// TODO: Generate indexes and sinks. -async fn create_base_tables(testdata: &str, client: &Client) -> Result> { - tracing::info!("Preparing tables..."); - - let sql = get_seed_table_sql(testdata); - let (base_tables, statements) = parse_create_table_statements(sql); - let mut mvs_and_base_tables = vec![]; - mvs_and_base_tables.extend_from_slice(&base_tables); - - for stmt in &statements { - let create_sql = stmt.to_string(); - tracing::info!("[EXECUTING CREATE TABLE]: {}", &create_sql); - client.simple_query(&create_sql).await.unwrap(); - } - - Ok(base_tables) -} - -/// Create the tables defined in testdata, along with some mviews. -/// TODO: Generate indexes and sinks. -async fn create_mviews( - rng: &mut impl Rng, - mvs_and_base_tables: Vec
, - client: &Client, -) -> Result<(Vec
, Vec
)> { - let mut mvs_and_base_tables = mvs_and_base_tables; - let mut mviews = vec![]; - // Generate some mviews - for i in 0..20 { - let (create_sql, table) = - mview_sql_gen(rng, mvs_and_base_tables.clone(), &format!("m{}", i)); - tracing::info!("[EXECUTING CREATE MVIEW]: {}", &create_sql); - let skip_count = run_query(6, client, &create_sql).await?; - if skip_count == 0 { - mvs_and_base_tables.push(table.clone()); - mviews.push(table); - } - } - Ok((mvs_and_base_tables, mviews)) -} - -fn format_drop_mview(mview: &Table) -> String { - format!("DROP MATERIALIZED VIEW IF EXISTS {}", mview.name) -} - -/// Drops mview tables. -async fn drop_mview_table(mview: &Table, client: &Client) { - client - .simple_query(&format_drop_mview(mview)) - .await - .unwrap(); -} - -/// Drops mview tables and seed tables -async fn drop_tables(mviews: &[Table], testdata: &str, client: &Client) { - tracing::info!("Cleaning tables..."); - - for mview in mviews.iter().rev() { - drop_mview_table(mview, client).await; - } - - let seed_files = ["drop_tpch.sql", "drop_nexmark.sql", "drop_alltypes.sql"]; - let sql = seed_files - .iter() - .map(|filename| read_file_contents(format!("{}/{}", testdata, filename)).unwrap()) - .collect::(); - - for stmt in sql.lines() { - client.simple_query(stmt).await.unwrap(); - } -} - -/// Validate client responses, returning a count of skipped queries, number of result rows. -fn validate_response( - response: PgResult>, -) -> Result<(i64, Vec)> { - match response { - Ok(rows) => Ok((0, rows)), - Err(e) => { - // Permit runtime errors conservatively. - if let Some(e) = e.as_db_error() - && is_permissible_error(&e.to_string()) - { - tracing::info!("[SKIPPED ERROR]: {:#?}", e); - return Ok((1, vec![])); - } - // consolidate error reason for deterministic test - tracing::info!("[UNEXPECTED ERROR]: {:#?}", e); - Err(anyhow!("Encountered unexpected error: {e}")) - } - } -} - -async fn run_query(timeout_duration: u64, client: &Client, query: &str) -> Result { - let (skipped_count, _) = run_query_inner(timeout_duration, client, query).await?; - Ok(skipped_count) -} -/// Run query, handle permissible errors -/// For recovery error, just do bounded retry. -/// For other errors, validate them accordingly, skipping if they are permitted. -/// Otherwise just return success. -/// If takes too long return the query which timed out + execution time + timeout error -/// Returns: Number of skipped queries, number of rows returned. -async fn run_query_inner( - timeout_duration: u64, - client: &Client, - query: &str, -) -> Result<(i64, Vec)> { - let query_task = client.simple_query(query); - let result = timeout(Duration::from_secs(timeout_duration), query_task).await; - let response = match result { - Ok(r) => r, - Err(_) => bail!( - "[UNEXPECTED ERROR] Query timeout after {timeout_duration}s:\n{:?}", - query - ), - }; - if let Err(e) = &response - && let Some(e) = e.as_db_error() - { - if is_recovery_in_progress_error(&e.to_string()) { - let tries = 5; - let interval = 1; - for _ in 0..tries { - // retry 5 times - sleep(Duration::from_secs(interval)).await; - let query_task = client.simple_query(query); - let response = timeout(Duration::from_secs(timeout_duration), query_task).await; - match response { - Ok(Ok(r)) => { - return Ok((0, r)); - } - Err(_) => bail!( - "[UNEXPECTED ERROR] Query timeout after {timeout_duration}s:\n{:?}", - query - ), - _ => {} - } - } - bail!("[UNEXPECTED ERROR] Failed to recover after {tries} tries with interval {interval}s") - } else { - return validate_response(response); - } - } - let rows = response?; - Ok((0, rows)) -} - -/// Create the tables defined in testdata, along with some mviews. -/// Just test number of rows for now. -/// TODO(kwannoel): Test row contents as well. That requires us to run a batch query -/// with `select * ORDER BY `. -async fn diff_stream_and_batch( - rng: &mut impl Rng, - mvs_and_base_tables: Vec
, - client: &Client, - i: usize, -) -> Result<()> { - // Generate some mviews - let mview_name = format!("stream_{}", i); - let (batch, stream, table) = differential_sql_gen(rng, mvs_and_base_tables, &mview_name)?; - diff_stream_and_batch_with_sqls(client, i, &batch, &stream, &mview_name, &table).await -} - -async fn diff_stream_and_batch_with_sqls( - client: &Client, - i: usize, - batch: &str, - stream: &str, - mview_name: &str, - table: &Table, -) -> Result<()> { - tracing::info!("[RUN CREATE MVIEW id={}]: {}", i, stream); - let skip_count = run_query(12, client, stream).await?; - if skip_count > 0 { - tracing::info!("[RUN DROP MVIEW id={}]: {}", i, &format_drop_mview(table)); - drop_mview_table(table, client).await; - return Ok(()); - } - - let select = format!("SELECT * FROM {}", &mview_name); - tracing::info!("[RUN SELECT * FROM MVIEW id={}]: {}", i, select); - let (skip_count, stream_result) = run_query_inner(12, client, &select).await?; - if skip_count > 0 { - bail!("SQL should not fail: {:?}", select) - } - - tracing::info!("[RUN - BATCH QUERY id={}]: {}", i, &batch); - let (skip_count, batch_result) = run_query_inner(12, client, batch).await?; - if skip_count > 0 { - tracing::info!( - "[DIFF - DROP MVIEW id={}]: {}", - i, - &format_drop_mview(table) - ); - drop_mview_table(table, client).await; - return Ok(()); - } - let n_stream_rows = stream_result.len(); - let n_batch_rows = batch_result.len(); - let formatted_stream_rows = format_rows(&batch_result); - let formatted_batch_rows = format_rows(&stream_result); - tracing::debug!( - "[COMPARE - STREAM_FORMATTED_ROW id={}]: {formatted_stream_rows}", - i, - ); - tracing::debug!( - "[COMPARE - BATCH_FORMATTED_ROW id={}]: {formatted_batch_rows}", - i, - ); - - let diff = TextDiff::from_lines(&formatted_batch_rows, &formatted_stream_rows); - - let diff: String = diff - .iter_all_changes() - .filter_map(|change| match change.tag() { - ChangeTag::Delete => Some(format!("-{}", change)), - ChangeTag::Insert => Some(format!("+{}", change)), - ChangeTag::Equal => None, - }) - .collect(); - - if diff.is_empty() { - tracing::info!("[RUN DROP MVIEW id={}]: {}", i, format_drop_mview(table)); - tracing::info!("[PASSED DIFF id={}, rows_compared={n_stream_rows}]", i); - - drop_mview_table(table, client).await; - Ok(()) - } else { - bail!( - " -Different results for batch and stream: - -BATCH SQL: -{batch} - -STREAM SQL: -{stream} - -SELECT FROM STREAM SQL: -{select} - -BATCH_ROW_LEN: -{n_batch_rows} - -STREAM_ROW_LEN: -{n_stream_rows} - -BATCH_ROWS: -{formatted_batch_rows} - -STREAM_ROWS: -{formatted_stream_rows} - -ROW DIFF (+/-): -{diff} -", - ) - } -} - -/// Format + sort rows so they can be diffed. -fn format_rows(rows: &[SimpleQueryMessage]) -> String { - rows.iter() - .filter_map(|r| match r { - SimpleQueryMessage::Row(row) => { - let n_cols = row.columns().len(); - let formatted_row: String = (0..n_cols) - .map(|i| { - format!( - "{:#?}", - match row.get(i) { - Some(s) => s, - _ => "NULL", - } - ) - }) - .join(", "); - Some(formatted_row) - } - SimpleQueryMessage::CommandComplete(_n_rows) => None, - _ => unreachable!(), - }) - .sorted() - .join("\n") -} diff --git a/src/tests/sqlsmith/src/test_runners/README.md b/src/tests/sqlsmith/src/test_runners/README.md new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/src/tests/sqlsmith/src/test_runners/diff.rs b/src/tests/sqlsmith/src/test_runners/diff.rs new file mode 100644 index 0000000000000..f05c9f3521d79 --- /dev/null +++ b/src/tests/sqlsmith/src/test_runners/diff.rs @@ -0,0 +1,206 @@ +// Copyright 2024 RisingWave Labs +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//! Provides E2E Test runner functionality. + +use anyhow::bail; +use itertools::Itertools; +use rand::Rng; +#[cfg(madsim)] +use rand_chacha::ChaChaRng; +use similar::{ChangeTag, TextDiff}; +use tokio_postgres::{Client, SimpleQueryMessage}; + +use crate::test_runners::utils::{ + create_base_tables, create_mviews, drop_mview_table, drop_tables, format_drop_mview, + generate_rng, populate_tables, run_query, run_query_inner, set_variable, update_base_tables, + Result, +}; +use crate::{differential_sql_gen, Table}; + +/// Differential testing for batch and stream +pub async fn run_differential_testing( + client: &Client, + testdata: &str, + count: usize, + seed: Option, +) -> Result<()> { + let mut rng = generate_rng(seed); + + set_variable(client, "RW_IMPLICIT_FLUSH", "TRUE").await; + set_variable(client, "QUERY_MODE", "DISTRIBUTED").await; + tracing::info!("Set session variables"); + + let base_tables = create_base_tables(testdata, client).await.unwrap(); + + let rows_per_table = 50; + let inserts = populate_tables(client, &mut rng, base_tables.clone(), rows_per_table).await; + tracing::info!("Populated base tables"); + + let (tables, mviews) = create_mviews(&mut rng, base_tables.clone(), client) + .await + .unwrap(); + tracing::info!("Created tables"); + + // Generate an update for some inserts, on the corresponding table. + update_base_tables(client, &mut rng, &base_tables, &inserts).await; + tracing::info!("Ran updates"); + + for i in 0..count { + diff_stream_and_batch(&mut rng, tables.clone(), client, i).await? + } + + drop_tables(&mviews, testdata, client).await; + tracing::info!("[EXECUTION SUCCESS]"); + Ok(()) +} + +/// Create the tables defined in testdata, along with some mviews. +/// Just test number of rows for now. +/// TODO(kwannoel): Test row contents as well. That requires us to run a batch query +/// with `select * ORDER BY `. +async fn diff_stream_and_batch( + rng: &mut impl Rng, + mvs_and_base_tables: Vec
, + client: &Client, + i: usize, +) -> Result<()> { + // Generate some mviews + let mview_name = format!("stream_{}", i); + let (batch, stream, table) = differential_sql_gen(rng, mvs_and_base_tables, &mview_name)?; + diff_stream_and_batch_with_sqls(client, i, &batch, &stream, &mview_name, &table).await +} + +async fn diff_stream_and_batch_with_sqls( + client: &Client, + i: usize, + batch: &str, + stream: &str, + mview_name: &str, + table: &Table, +) -> Result<()> { + tracing::info!("[RUN CREATE MVIEW id={}]: {}", i, stream); + let skip_count = run_query(12, client, stream).await?; + if skip_count > 0 { + tracing::info!("[RUN DROP MVIEW id={}]: {}", i, &format_drop_mview(table)); + drop_mview_table(table, client).await; + return Ok(()); + } + + let select = format!("SELECT * FROM {}", &mview_name); + tracing::info!("[RUN SELECT * FROM MVIEW id={}]: {}", i, select); + let (skip_count, stream_result) = run_query_inner(12, client, &select).await?; + if skip_count > 0 { + bail!("SQL should not fail: {:?}", select) + } + + tracing::info!("[RUN - BATCH QUERY id={}]: {}", i, &batch); + let (skip_count, batch_result) = run_query_inner(12, client, batch).await?; + if skip_count > 0 { + tracing::info!( + "[DIFF - DROP MVIEW id={}]: {}", + i, + &format_drop_mview(table) + ); + drop_mview_table(table, client).await; + return Ok(()); + } + let n_stream_rows = stream_result.len(); + let n_batch_rows = batch_result.len(); + let formatted_stream_rows = format_rows(&batch_result); + let formatted_batch_rows = format_rows(&stream_result); + tracing::debug!( + "[COMPARE - STREAM_FORMATTED_ROW id={}]: {formatted_stream_rows}", + i, + ); + tracing::debug!( + "[COMPARE - BATCH_FORMATTED_ROW id={}]: {formatted_batch_rows}", + i, + ); + + let diff = TextDiff::from_lines(&formatted_batch_rows, &formatted_stream_rows); + + let diff: String = diff + .iter_all_changes() + .filter_map(|change| match change.tag() { + ChangeTag::Delete => Some(format!("-{}", change)), + ChangeTag::Insert => Some(format!("+{}", change)), + ChangeTag::Equal => None, + }) + .collect(); + + if diff.is_empty() { + tracing::info!("[RUN DROP MVIEW id={}]: {}", i, format_drop_mview(table)); + tracing::info!("[PASSED DIFF id={}, rows_compared={n_stream_rows}]", i); + + drop_mview_table(table, client).await; + Ok(()) + } else { + bail!( + " +Different results for batch and stream: + +BATCH SQL: +{batch} + +STREAM SQL: +{stream} + +SELECT FROM STREAM SQL: +{select} + +BATCH_ROW_LEN: +{n_batch_rows} + +STREAM_ROW_LEN: +{n_stream_rows} + +BATCH_ROWS: +{formatted_batch_rows} + +STREAM_ROWS: +{formatted_stream_rows} + +ROW DIFF (+/-): +{diff} +", + ) + } +} + +/// Format + sort rows so they can be diffed. +fn format_rows(rows: &[SimpleQueryMessage]) -> String { + rows.iter() + .filter_map(|r| match r { + SimpleQueryMessage::Row(row) => { + let n_cols = row.columns().len(); + let formatted_row: String = (0..n_cols) + .map(|i| { + format!( + "{:#?}", + match row.get(i) { + Some(s) => s, + _ => "NULL", + } + ) + }) + .join(", "); + Some(formatted_row) + } + SimpleQueryMessage::CommandComplete(_n_rows) => None, + _ => unreachable!(), + }) + .sorted() + .join("\n") +} diff --git a/src/tests/sqlsmith/src/test_runners/fuzzing.rs b/src/tests/sqlsmith/src/test_runners/fuzzing.rs new file mode 100644 index 0000000000000..b29a78b4b07d1 --- /dev/null +++ b/src/tests/sqlsmith/src/test_runners/fuzzing.rs @@ -0,0 +1,180 @@ +// Copyright 2024 RisingWave Labs +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//! Provides E2E Test runner functionality. + +#[cfg(madsim)] +use rand_chacha::ChaChaRng; +use tokio_postgres::Client; + +use crate::test_runners::utils::{ + create_base_tables, create_mviews, drop_mview_table, drop_tables, format_drop_mview, + generate_rng, populate_tables, run_query, set_variable, test_batch_queries, + test_session_variable, test_sqlsmith, test_stream_queries, update_base_tables, +}; +use crate::utils::read_file_contents; +use crate::{mview_sql_gen, parse_sql, sql_gen}; + +/// e2e test runner for pre-generated queries from sqlsmith +pub async fn run_pre_generated(client: &Client, outdir: &str) { + let timeout_duration = 12; // allow for some variance. + let queries_path = format!("{}/queries.sql", outdir); + let queries = read_file_contents(queries_path).unwrap(); + for statement in parse_sql(&queries) { + let sql = statement.to_string(); + tracing::info!("[EXECUTING STATEMENT]: {}", sql); + run_query(timeout_duration, client, &sql).await.unwrap(); + } + tracing::info!("[EXECUTION SUCCESS]"); +} + +/// Query Generator +/// If we encounter an expected error, just skip. +/// If we encounter an unexpected error, +/// Sqlsmith should stop execution, but writeout ddl and queries so far. +/// If query takes too long -> cancel it, **mark it as error**. +/// NOTE(noel): It will still fail if DDL creation fails. +pub async fn generate( + client: &Client, + testdata: &str, + count: usize, + _outdir: &str, + seed: Option, +) { + let timeout_duration = 5; + + set_variable(client, "RW_IMPLICIT_FLUSH", "TRUE").await; + set_variable(client, "QUERY_MODE", "DISTRIBUTED").await; + tracing::info!("Set session variables"); + + let mut rng = generate_rng(seed); + let base_tables = create_base_tables(testdata, client).await.unwrap(); + + let rows_per_table = 50; + let max_rows_inserted = rows_per_table * base_tables.len(); + let inserts = populate_tables(client, &mut rng, base_tables.clone(), rows_per_table).await; + tracing::info!("Populated base tables"); + + let (tables, mviews) = create_mviews(&mut rng, base_tables.clone(), client) + .await + .unwrap(); + + // Generate an update for some inserts, on the corresponding table. + update_base_tables(client, &mut rng, &base_tables, &inserts).await; + + test_sqlsmith( + client, + &mut rng, + tables.clone(), + base_tables.clone(), + max_rows_inserted, + ) + .await; + tracing::info!("Passed sqlsmith tests"); + + tracing::info!("Ran updates"); + + let mut generated_queries = 0; + for _ in 0..count { + test_session_variable(client, &mut rng).await; + let sql = sql_gen(&mut rng, tables.clone()); + tracing::info!("[EXECUTING TEST_BATCH]: {}", sql); + let result = run_query(timeout_duration, client, sql.as_str()).await; + match result { + Err(_e) => { + generated_queries += 1; + tracing::info!("Generated {} batch queries", generated_queries); + tracing::error!("Unrecoverable error encountered."); + return; + } + Ok(0) => { + generated_queries += 1; + } + _ => {} + } + } + tracing::info!("Generated {} batch queries", generated_queries); + + let mut generated_queries = 0; + for _ in 0..count { + test_session_variable(client, &mut rng).await; + let (sql, table) = mview_sql_gen(&mut rng, tables.clone(), "stream_query"); + tracing::info!("[EXECUTING TEST_STREAM]: {}", sql); + let result = run_query(timeout_duration, client, sql.as_str()).await; + match result { + Err(_e) => { + generated_queries += 1; + tracing::info!("Generated {} stream queries", generated_queries); + tracing::error!("Unrecoverable error encountered."); + return; + } + Ok(0) => { + generated_queries += 1; + } + _ => {} + } + tracing::info!("[EXECUTING DROP MVIEW]: {}", &format_drop_mview(&table)); + drop_mview_table(&table, client).await; + } + tracing::info!("Generated {} stream queries", generated_queries); + + drop_tables(&mviews, testdata, client).await; +} + +/// e2e test runner for sqlsmith +pub async fn run(client: &Client, testdata: &str, count: usize, seed: Option) { + let mut rng = generate_rng(seed); + + set_variable(client, "RW_IMPLICIT_FLUSH", "TRUE").await; + set_variable(client, "QUERY_MODE", "DISTRIBUTED").await; + tracing::info!("Set session variables"); + + let base_tables = create_base_tables(testdata, client).await.unwrap(); + + let rows_per_table = 50; + let inserts = populate_tables(client, &mut rng, base_tables.clone(), rows_per_table).await; + tracing::info!("Populated base tables"); + + let (tables, mviews) = create_mviews(&mut rng, base_tables.clone(), client) + .await + .unwrap(); + tracing::info!("Created tables"); + + // Generate an update for some inserts, on the corresponding table. + update_base_tables(client, &mut rng, &base_tables, &inserts).await; + tracing::info!("Ran updates"); + + let max_rows_inserted = rows_per_table * base_tables.len(); + test_sqlsmith( + client, + &mut rng, + tables.clone(), + base_tables.clone(), + max_rows_inserted, + ) + .await; + tracing::info!("Passed sqlsmith tests"); + + test_batch_queries(client, &mut rng, tables.clone(), count) + .await + .unwrap(); + tracing::info!("Passed batch queries"); + test_stream_queries(client, &mut rng, tables.clone(), count) + .await + .unwrap(); + tracing::info!("Passed stream queries"); + + drop_tables(&mviews, testdata, client).await; + tracing::info!("[EXECUTION SUCCESS]"); +} diff --git a/src/tests/sqlsmith/src/test_runners/mod.rs b/src/tests/sqlsmith/src/test_runners/mod.rs new file mode 100644 index 0000000000000..9347f070b58e7 --- /dev/null +++ b/src/tests/sqlsmith/src/test_runners/mod.rs @@ -0,0 +1,26 @@ +// Copyright 2024 RisingWave Labs +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//! Contains test runners: +//! - fuzzing: For crashing testing the database with generated batch, stream queries. +//! - differential testing: For testing the database with generated batch, +//! stream queries and comparing their results. + +mod diff; +mod fuzzing; + +mod utils; + +pub use diff::run_differential_testing; +pub use fuzzing::{generate, run, run_pre_generated}; diff --git a/src/tests/sqlsmith/src/test_runners/utils.rs b/src/tests/sqlsmith/src/test_runners/utils.rs new file mode 100644 index 0000000000000..98f29df490446 --- /dev/null +++ b/src/tests/sqlsmith/src/test_runners/utils.rs @@ -0,0 +1,353 @@ +// Copyright 2024 RisingWave Labs +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use anyhow::{anyhow, bail}; +use itertools::Itertools; +use rand::rngs::SmallRng; +use rand::{Rng, SeedableRng}; +#[cfg(madsim)] +use rand_chacha::ChaChaRng; +use risingwave_sqlparser::ast::Statement; +use tokio::time::{sleep, timeout, Duration}; +use tokio_postgres::error::Error as PgError; +use tokio_postgres::{Client, SimpleQueryMessage}; + +use crate::utils::read_file_contents; +use crate::validation::{is_permissible_error, is_recovery_in_progress_error}; +use crate::{ + generate_update_statements, insert_sql_gen, mview_sql_gen, parse_create_table_statements, + parse_sql, session_sql_gen, sql_gen, Table, +}; + +pub(super) type PgResult = std::result::Result; +pub(super) type Result = anyhow::Result; + +pub(super) async fn update_base_tables( + client: &Client, + rng: &mut R, + base_tables: &[Table], + inserts: &[Statement], +) { + let update_statements = generate_update_statements(rng, base_tables, inserts).unwrap(); + for update_statement in update_statements { + let sql = update_statement.to_string(); + tracing::info!("[EXECUTING UPDATES]: {}", &sql); + client.simple_query(&sql).await.unwrap(); + } +} + +pub(super) async fn populate_tables( + client: &Client, + rng: &mut R, + base_tables: Vec
, + row_count: usize, +) -> Vec { + let inserts = insert_sql_gen(rng, base_tables, row_count); + for insert in &inserts { + tracing::info!("[EXECUTING INSERT]: {}", insert); + client.simple_query(insert).await.unwrap(); + } + inserts + .iter() + .map(|s| parse_sql(s).into_iter().next().unwrap()) + .collect_vec() +} + +pub(super) async fn set_variable(client: &Client, variable: &str, value: &str) -> String { + let s = format!("SET {variable} TO {value}"); + tracing::info!("[EXECUTING SET_VAR]: {}", s); + client.simple_query(&s).await.unwrap(); + s +} + +/// Sanity checks for sqlsmith +pub(super) async fn test_sqlsmith( + client: &Client, + rng: &mut R, + tables: Vec
, + base_tables: Vec
, + row_count: usize, +) { + // Test inserted rows should be at least 50% population count, + // otherwise we don't have sufficient data in our system. + // ENABLE: https://github.com/risingwavelabs/risingwave/issues/3844 + test_population_count(client, base_tables, row_count).await; + tracing::info!("passed population count test"); + + let threshold = 0.50; // permit at most 50% of queries to be skipped. + let sample_size = 20; + + let skipped_percentage = test_batch_queries(client, rng, tables.clone(), sample_size) + .await + .unwrap(); + tracing::info!( + "percentage of skipped batch queries = {}, threshold: {}", + skipped_percentage, + threshold + ); + if skipped_percentage > threshold { + panic!("skipped batch queries exceeded threshold."); + } + + let skipped_percentage = test_stream_queries(client, rng, tables.clone(), sample_size) + .await + .unwrap(); + tracing::info!( + "percentage of skipped stream queries = {}, threshold: {}", + skipped_percentage, + threshold + ); + if skipped_percentage > threshold { + panic!("skipped stream queries exceeded threshold."); + } +} + +pub(super) async fn test_session_variable(client: &Client, rng: &mut R) -> String { + let session_sql = session_sql_gen(rng); + tracing::info!("[EXECUTING TEST SESSION_VAR]: {}", session_sql); + client.simple_query(session_sql.as_str()).await.unwrap(); + session_sql +} + +/// Expects at least 50% of inserted rows included. +pub(super) async fn test_population_count( + client: &Client, + base_tables: Vec
, + expected_count: usize, +) { + let mut actual_count = 0; + for t in base_tables { + let q = format!("select * from {};", t.name); + let rows = client.simple_query(&q).await.unwrap(); + actual_count += rows.len(); + } + if actual_count < expected_count / 2 { + panic!( + "expected at least 50% rows included.\ + Total {} rows, only had {} rows", + expected_count, actual_count, + ) + } +} + +/// Test batch queries, returns skipped query statistics +/// Runs in distributed mode, since queries can be complex and cause overflow in local execution +/// mode. +pub(super) async fn test_batch_queries( + client: &Client, + rng: &mut R, + tables: Vec
, + sample_size: usize, +) -> Result { + let mut skipped = 0; + for _ in 0..sample_size { + test_session_variable(client, rng).await; + let sql = sql_gen(rng, tables.clone()); + tracing::info!("[TEST BATCH]: {}", sql); + skipped += run_query(30, client, &sql).await?; + } + Ok(skipped as f64 / sample_size as f64) +} + +/// Test stream queries, returns skipped query statistics +pub(super) async fn test_stream_queries( + client: &Client, + rng: &mut R, + tables: Vec
, + sample_size: usize, +) -> Result { + let mut skipped = 0; + + for _ in 0..sample_size { + test_session_variable(client, rng).await; + let (sql, table) = mview_sql_gen(rng, tables.clone(), "stream_query"); + tracing::info!("[TEST STREAM]: {}", sql); + skipped += run_query(12, client, &sql).await?; + tracing::info!("[TEST DROP MVIEW]: {}", &format_drop_mview(&table)); + drop_mview_table(&table, client).await; + } + Ok(skipped as f64 / sample_size as f64) +} + +pub(super) fn get_seed_table_sql(testdata: &str) -> String { + let seed_files = ["tpch.sql", "nexmark.sql", "alltypes.sql"]; + seed_files + .iter() + .map(|filename| read_file_contents(format!("{}/{}", testdata, filename)).unwrap()) + .collect::() +} + +/// Create the tables defined in testdata, along with some mviews. +/// TODO: Generate indexes and sinks. +pub(super) async fn create_base_tables(testdata: &str, client: &Client) -> Result> { + tracing::info!("Preparing tables..."); + + let sql = get_seed_table_sql(testdata); + let (base_tables, statements) = parse_create_table_statements(sql); + let mut mvs_and_base_tables = vec![]; + mvs_and_base_tables.extend_from_slice(&base_tables); + + for stmt in &statements { + let create_sql = stmt.to_string(); + tracing::info!("[EXECUTING CREATE TABLE]: {}", &create_sql); + client.simple_query(&create_sql).await.unwrap(); + } + + Ok(base_tables) +} + +/// Create the tables defined in testdata, along with some mviews. +/// TODO: Generate indexes and sinks. +pub(super) async fn create_mviews( + rng: &mut impl Rng, + mvs_and_base_tables: Vec
, + client: &Client, +) -> Result<(Vec
, Vec
)> { + let mut mvs_and_base_tables = mvs_and_base_tables; + let mut mviews = vec![]; + // Generate some mviews + for i in 0..20 { + let (create_sql, table) = + mview_sql_gen(rng, mvs_and_base_tables.clone(), &format!("m{}", i)); + tracing::info!("[EXECUTING CREATE MVIEW]: {}", &create_sql); + let skip_count = run_query(6, client, &create_sql).await?; + if skip_count == 0 { + mvs_and_base_tables.push(table.clone()); + mviews.push(table); + } + } + Ok((mvs_and_base_tables, mviews)) +} + +pub(super) fn format_drop_mview(mview: &Table) -> String { + format!("DROP MATERIALIZED VIEW IF EXISTS {}", mview.name) +} + +/// Drops mview tables. +pub(super) async fn drop_mview_table(mview: &Table, client: &Client) { + client + .simple_query(&format_drop_mview(mview)) + .await + .unwrap(); +} + +/// Drops mview tables and seed tables +pub(super) async fn drop_tables(mviews: &[Table], testdata: &str, client: &Client) { + tracing::info!("Cleaning tables..."); + + for mview in mviews.iter().rev() { + drop_mview_table(mview, client).await; + } + + let seed_files = ["drop_tpch.sql", "drop_nexmark.sql", "drop_alltypes.sql"]; + let sql = seed_files + .iter() + .map(|filename| read_file_contents(format!("{}/{}", testdata, filename)).unwrap()) + .collect::(); + + for stmt in sql.lines() { + client.simple_query(stmt).await.unwrap(); + } +} + +/// Validate client responses, returning a count of skipped queries, number of result rows. +pub(super) fn validate_response( + response: PgResult>, +) -> Result<(i64, Vec)> { + match response { + Ok(rows) => Ok((0, rows)), + Err(e) => { + // Permit runtime errors conservatively. + if let Some(e) = e.as_db_error() + && is_permissible_error(&e.to_string()) + { + tracing::info!("[SKIPPED ERROR]: {:#?}", e); + return Ok((1, vec![])); + } + // consolidate error reason for deterministic test + tracing::info!("[UNEXPECTED ERROR]: {:#?}", e); + Err(anyhow!("Encountered unexpected error: {e}")) + } + } +} + +pub(super) async fn run_query(timeout_duration: u64, client: &Client, query: &str) -> Result { + let (skipped_count, _) = run_query_inner(timeout_duration, client, query).await?; + Ok(skipped_count) +} +/// Run query, handle permissible errors +/// For recovery error, just do bounded retry. +/// For other errors, validate them accordingly, skipping if they are permitted. +/// Otherwise just return success. +/// If takes too long return the query which timed out + execution time + timeout error +/// Returns: Number of skipped queries, number of rows returned. +pub(super) async fn run_query_inner( + timeout_duration: u64, + client: &Client, + query: &str, +) -> Result<(i64, Vec)> { + let query_task = client.simple_query(query); + let result = timeout(Duration::from_secs(timeout_duration), query_task).await; + let response = match result { + Ok(r) => r, + Err(_) => bail!( + "[UNEXPECTED ERROR] Query timeout after {timeout_duration}s:\n{:?}", + query + ), + }; + if let Err(e) = &response + && let Some(e) = e.as_db_error() + { + if is_recovery_in_progress_error(&e.to_string()) { + let tries = 5; + let interval = 1; + for _ in 0..tries { + // retry 5 times + sleep(Duration::from_secs(interval)).await; + let query_task = client.simple_query(query); + let response = timeout(Duration::from_secs(timeout_duration), query_task).await; + match response { + Ok(Ok(r)) => { + return Ok((0, r)); + } + Err(_) => bail!( + "[UNEXPECTED ERROR] Query timeout after {timeout_duration}s:\n{:?}", + query + ), + _ => {} + } + } + bail!("[UNEXPECTED ERROR] Failed to recover after {tries} tries with interval {interval}s") + } else { + return validate_response(response); + } + } + let rows = response?; + Ok((0, rows)) +} + +pub(super) fn generate_rng(seed: Option) -> impl Rng { + #[cfg(madsim)] + if let Some(seed) = seed { + ChaChaRng::seed_from_u64(seed) + } else { + ChaChaRng::from_rng(SmallRng::from_entropy()).unwrap() + } + #[cfg(not(madsim))] + if let Some(seed) = seed { + SmallRng::seed_from_u64(seed) + } else { + SmallRng::from_entropy() + } +} From f33dac4b2c97eb90bd1e55ddaa8c37c3aa0ad3c4 Mon Sep 17 00:00:00 2001 From: Kevin Axel Date: Mon, 8 Jan 2024 14:39:24 +0800 Subject: [PATCH 16/20] feat(udf): support implicit cast for UDF arguments (#14338) Signed-off-by: Kevin Axel Co-authored-by: Zihao Xu --- e2e_test/udf/udf.slt | 61 +++++++++--------- src/expr/core/src/sig/mod.rs | 32 ++++++++-- src/expr/impl/tests/sig.rs | 2 +- src/frontend/src/binder/expr/function.rs | 8 +-- src/frontend/src/catalog/root_catalog.rs | 29 +++++++++ src/frontend/src/catalog/schema_catalog.rs | 66 +++++++++++++++++++- src/frontend/src/expr/mod.rs | 4 +- src/frontend/src/expr/type_inference/func.rs | 16 +++-- src/frontend/src/expr/type_inference/mod.rs | 2 +- 9 files changed, 168 insertions(+), 52 deletions(-) diff --git a/e2e_test/udf/udf.slt b/e2e_test/udf/udf.slt index d3f88e8f5b5d8..096a605709d67 100644 --- a/e2e_test/udf/udf.slt +++ b/e2e_test/udf/udf.slt @@ -115,7 +115,7 @@ select hex_to_dec('000000000000000000000000000000000000000000c0f6346334241a61f90 233276425899864771438119478 query I -select float_to_decimal('-1e-10'::float8); +select float_to_decimal('-1e-10'); ---- -0.0000000001000000000000000036 @@ -138,17 +138,17 @@ NULL false query T -select jsonb_concat(ARRAY['null'::jsonb, '1'::jsonb, '"str"'::jsonb, '{}'::jsonb]); +select jsonb_concat(ARRAY['null', '1', '"str"', '{}'::jsonb]); ---- [null, 1, "str", {}] query T -select jsonb_array_identity(ARRAY[null, '1'::jsonb, '"str"'::jsonb, '{}'::jsonb]); +select jsonb_array_identity(ARRAY[null, '1', '"str"', '{}'::jsonb]); ---- {NULL,1,"\"str\"","{}"} query T -select jsonb_array_struct_identity(ROW(ARRAY[null, '1'::jsonb, '"str"'::jsonb, '{}'::jsonb], 4)::struct); +select jsonb_array_struct_identity(ROW(ARRAY[null, '1', '"str"', '{}'::jsonb], 4)::struct); ---- ("{NULL,1,""\\""str\\"""",""{}""}",4) @@ -156,18 +156,18 @@ query T select (return_all( true, 1 ::smallint, - 1 ::int, - 1 ::bigint, - 1 ::float4, - 1 ::float8, - 12345678901234567890.12345678 ::decimal, + 1, + 1, + 1, + 1, + 12345678901234567890.12345678, date '2023-06-01', time '01:02:03.456789', timestamp '2023-06-01 01:02:03.456789', interval '1 month 2 days 3 seconds', 'string', - 'bytes'::bytea, - '{"key":1}'::jsonb, + 'bytes', + '{"key":1}', row(1, 2)::struct )).*; ---- @@ -177,11 +177,11 @@ query T select (return_all_arrays( array[null, true], array[null, 1 ::smallint], - array[null, 1 ::int], + array[null, 1], array[null, 1 ::bigint], array[null, 1 ::float4], array[null, 1 ::float8], - array[null, 12345678901234567890.12345678 ::decimal], + array[null, 12345678901234567890.12345678], array[null, date '2023-06-01'], array[null, time '01:02:03.456789'], array[null, timestamp '2023-06-01 01:02:03.456789'], @@ -197,21 +197,21 @@ select (return_all_arrays( # test large string output query I select length((return_all( - null::boolean, - null::smallint, - null::int, - null::bigint, - null::float4, - null::float8, - null::decimal, - null::date, - null::time, - null::timestamp, - null::interval, - repeat('a', 100000)::varchar, + null, + null, + null, + null, + null, + null, + null, + null, + null, + null, + null, + repeat('a', 100000), repeat('a', 100000)::bytea, - null::jsonb, - null::struct + null, + null )).varchar); ---- 100000 @@ -253,16 +253,13 @@ select count(*) from series(1000000); ---- 1000000 -# TODO: support argument implicit cast for UDF -# e.g. extract_tcp_info(E'\\x45'); - query T -select extract_tcp_info(E'\\x45000034a8a8400040065b8ac0a8000ec0a80001035d20b6d971b900000000080020200493310000020405b4' :: bytea); +select extract_tcp_info(E'\\x45000034a8a8400040065b8ac0a8000ec0a80001035d20b6d971b900000000080020200493310000020405b4'); ---- (192.168.0.14,192.168.0.1,861,8374) query TTII -select (extract_tcp_info(E'\\x45000034a8a8400040065b8ac0a8000ec0a80001035d20b6d971b900000000080020200493310000020405b4' :: BYTEA)).*; +select (extract_tcp_info(E'\\x45000034a8a8400040065b8ac0a8000ec0a80001035d20b6d971b900000000080020200493310000020405b4')).*; ---- 192.168.0.14 192.168.0.1 861 8374 diff --git a/src/expr/core/src/sig/mod.rs b/src/expr/core/src/sig/mod.rs index e509e1bd6322b..4366b90230cd9 100644 --- a/src/expr/core/src/sig/mod.rs +++ b/src/expr/core/src/sig/mod.rs @@ -14,6 +14,7 @@ //! Metadata of expressions. +use std::borrow::Cow; use std::collections::HashMap; use std::fmt; use std::sync::LazyLock; @@ -47,7 +48,7 @@ pub struct FunctionRegistry(HashMap>); impl FunctionRegistry { /// Inserts a function signature. pub fn insert(&mut self, sig: FuncSign) { - let list = self.0.entry(sig.name).or_default(); + let list = self.0.entry(sig.name.clone()).or_default(); if sig.is_aggregate() { // merge retractable and append-only aggregate if let Some(existing) = list @@ -85,6 +86,22 @@ impl FunctionRegistry { list.push(sig); } + /// Remove a function signature from registry. + pub fn remove(&mut self, sig: FuncSign) -> Option { + let pos = self + .0 + .get_mut(&sig.name)? + .iter() + .positions(|s| s.inputs_type == sig.inputs_type && s.ret_type == sig.ret_type) + .rev() + .collect_vec(); + let mut ret = None; + for p in pos { + ret = Some(self.0.get_mut(&sig.name)?.swap_remove(p)); + } + ret + } + /// Returns a function signature with the same type, argument types and return type. /// Deprecated functions are included. pub fn get( @@ -300,11 +317,12 @@ impl FuncSign { } } -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +#[derive(Debug, Clone, PartialEq, Eq, Hash)] pub enum FuncName { Scalar(ScalarFunctionType), Table(TableFunctionType), Aggregate(AggregateFunctionType), + Udf(String), } impl From for FuncName { @@ -333,11 +351,12 @@ impl fmt::Display for FuncName { impl FuncName { /// Returns the name of the function in `UPPER_CASE` style. - pub fn as_str_name(&self) -> &'static str { + pub fn as_str_name(&self) -> Cow<'static, str> { match self { - Self::Scalar(ty) => ty.as_str_name(), - Self::Table(ty) => ty.as_str_name(), - Self::Aggregate(ty) => ty.to_protobuf().as_str_name(), + Self::Scalar(ty) => ty.as_str_name().into(), + Self::Table(ty) => ty.as_str_name().into(), + Self::Aggregate(ty) => ty.to_protobuf().as_str_name().into(), + Self::Udf(name) => name.clone().into(), } } @@ -437,6 +456,7 @@ pub enum FuncBuilder { /// `None` means equal to the return type. append_only_state_type: Option, }, + Udf, } /// Register a function into global registry. diff --git a/src/expr/impl/tests/sig.rs b/src/expr/impl/tests/sig.rs index c021d3363fc5a..2dc8aacdb203f 100644 --- a/src/expr/impl/tests/sig.rs +++ b/src/expr/impl/tests/sig.rs @@ -29,7 +29,7 @@ fn test_func_sig_map() { } new_map - .entry(sig.name) + .entry(sig.name.clone()) .or_default() .entry(sig.inputs_type.to_vec()) .or_default() diff --git a/src/frontend/src/binder/expr/function.rs b/src/frontend/src/binder/expr/function.rs index 7244a8527f857..341e254b221fb 100644 --- a/src/frontend/src/binder/expr/function.rs +++ b/src/frontend/src/binder/expr/function.rs @@ -122,7 +122,7 @@ impl Binder { // Used later in sql udf expression evaluation let args = f.args.clone(); - let inputs = f + let mut inputs = f .args .into_iter() .map(|arg| self.bind_function_arg(arg)) @@ -224,12 +224,10 @@ impl Binder { // user defined function // TODO: resolve schema name https://github.com/risingwavelabs/risingwave/issues/12422 if let Ok(schema) = self.first_valid_schema() - && let Some(func) = schema.get_function_by_name_args( - &function_name, - &inputs.iter().map(|arg| arg.return_type()).collect_vec(), - ) + && let Some(func) = schema.get_function_by_name_inputs(&function_name, &mut inputs) { use crate::catalog::function_catalog::FunctionKind::*; + if func.language == "sql" { if func.body.is_none() { return Err(ErrorCode::InvalidInputSyntax( diff --git a/src/frontend/src/catalog/root_catalog.rs b/src/frontend/src/catalog/root_catalog.rs index 9d2045f5dc61f..3b9a9722a1f8f 100644 --- a/src/frontend/src/catalog/root_catalog.rs +++ b/src/frontend/src/catalog/root_catalog.rs @@ -37,6 +37,7 @@ use crate::catalog::system_catalog::{ }; use crate::catalog::table_catalog::TableCatalog; use crate::catalog::{DatabaseId, IndexCatalog, SchemaId}; +use crate::expr::{Expr, ExprImpl}; #[derive(Copy, Clone)] pub enum SchemaPath<'a> { @@ -753,6 +754,34 @@ impl Catalog { .ok_or_else(|| CatalogError::NotFound("connection", connection_name.to_string())) } + pub fn get_function_by_name_inputs<'a>( + &self, + db_name: &str, + schema_path: SchemaPath<'a>, + function_name: &str, + inputs: &mut [ExprImpl], + ) -> CatalogResult<(&Arc, &'a str)> { + schema_path + .try_find(|schema_name| { + Ok(self + .get_schema_by_name(db_name, schema_name)? + .get_function_by_name_inputs(function_name, inputs)) + })? + .ok_or_else(|| { + CatalogError::NotFound( + "function", + format!( + "{}({})", + function_name, + inputs + .iter() + .map(|a| a.return_type().to_string()) + .join(", ") + ), + ) + }) + } + pub fn get_function_by_name_args<'a>( &self, db_name: &str, diff --git a/src/frontend/src/catalog/schema_catalog.rs b/src/frontend/src/catalog/schema_catalog.rs index ab47d4f708edc..340b77f7aa066 100644 --- a/src/frontend/src/catalog/schema_catalog.rs +++ b/src/frontend/src/catalog/schema_catalog.rs @@ -16,9 +16,11 @@ use std::collections::hash_map::Entry::{Occupied, Vacant}; use std::collections::HashMap; use std::sync::Arc; +use itertools::Itertools; use risingwave_common::catalog::{valid_table_name, FunctionId, IndexId, TableId}; use risingwave_common::types::DataType; use risingwave_connector::sink::catalog::SinkCatalog; +pub use risingwave_expr::sig::*; use risingwave_pb::catalog::{ PbConnection, PbFunction, PbIndex, PbSchema, PbSink, PbSource, PbTable, PbView, }; @@ -32,6 +34,7 @@ use crate::catalog::system_catalog::SystemTableCatalog; use crate::catalog::table_catalog::TableCatalog; use crate::catalog::view_catalog::ViewCatalog; use crate::catalog::{ConnectionId, DatabaseId, SchemaId, SinkId, SourceId, ViewId}; +use crate::expr::{infer_type_name, infer_type_with_sigmap, Expr, ExprImpl}; use crate::user::UserId; #[derive(Clone, Debug)] @@ -50,6 +53,7 @@ pub struct SchemaCatalog { indexes_by_table_id: HashMap>>, view_by_name: HashMap>, view_by_id: HashMap>, + function_registry: FunctionRegistry, function_by_name: HashMap, Arc>>, function_by_id: HashMap>, connection_by_name: HashMap>, @@ -320,6 +324,23 @@ impl SchemaCatalog { self.view_by_id.insert(id, view_ref); } + pub fn get_func_sign(func: &FunctionCatalog) -> FuncSign { + FuncSign { + name: FuncName::Udf(func.name.clone()), + inputs_type: func + .arg_types + .iter() + .map(|t| t.clone().into()) + .collect_vec(), + variadic: false, + ret_type: func.return_type.clone().into(), + build: FuncBuilder::Udf, + // dummy type infer, will not use this result + type_infer: |_| Ok(DataType::Boolean), + deprecated: false, + } + } + pub fn create_function(&mut self, prost: &PbFunction) { let name = prost.name.clone(); let id = prost.id; @@ -327,6 +348,8 @@ impl SchemaCatalog { let args = function.arg_types.clone(); let function_ref = Arc::new(function); + self.function_registry + .insert(Self::get_func_sign(&function_ref)); self.function_by_name .entry(name) .or_default() @@ -342,6 +365,11 @@ impl SchemaCatalog { .function_by_id .remove(&id) .expect("function not found by id"); + + self.function_registry + .remove(Self::get_func_sign(&function_ref)) + .expect("function not found in registry"); + self.function_by_name .get_mut(&function_ref.name) .expect("function not found by name") @@ -537,12 +565,47 @@ impl SchemaCatalog { self.function_by_id.get(&function_id) } + pub fn get_function_by_name_inputs( + &self, + name: &str, + inputs: &mut [ExprImpl], + ) -> Option<&Arc> { + infer_type_with_sigmap( + FuncName::Udf(name.to_string()), + inputs, + &self.function_registry, + ) + .ok()?; + let args = inputs.iter().map(|x| x.return_type()).collect_vec(); + self.function_by_name.get(name)?.get(&args) + } + pub fn get_function_by_name_args( &self, name: &str, args: &[DataType], ) -> Option<&Arc> { - self.function_by_name.get(name)?.get(args) + let args = args.iter().map(|x| Some(x.clone())).collect_vec(); + let func = infer_type_name( + &self.function_registry, + FuncName::Udf(name.to_string()), + &args, + ) + .ok()?; + + let args = func + .inputs_type + .iter() + .filter_map(|x| { + if let SigDataType::Exact(t) = x { + Some(t.clone()) + } else { + None + } + }) + .collect_vec(); + + self.function_by_name.get(name)?.get(&args) } pub fn get_functions_by_name(&self, name: &str) -> Option>> { @@ -619,6 +682,7 @@ impl From<&PbSchema> for SchemaCatalog { system_table_by_name: HashMap::new(), view_by_name: HashMap::new(), view_by_id: HashMap::new(), + function_registry: FunctionRegistry::default(), function_by_name: HashMap::new(), function_by_id: HashMap::new(), connection_by_name: HashMap::new(), diff --git a/src/frontend/src/expr/mod.rs b/src/frontend/src/expr/mod.rs index 2a83bf63063ca..b38aaf735c2bd 100644 --- a/src/frontend/src/expr/mod.rs +++ b/src/frontend/src/expr/mod.rs @@ -66,8 +66,8 @@ pub use session_timezone::{SessionTimezone, TimestamptzExprFinder}; pub use subquery::{Subquery, SubqueryKind}; pub use table_function::{TableFunction, TableFunctionType}; pub use type_inference::{ - align_types, cast_map_array, cast_ok, cast_sigs, infer_some_all, infer_type, least_restrictive, - CastContext, CastSig, FuncSign, + align_types, cast_map_array, cast_ok, cast_sigs, infer_some_all, infer_type, infer_type_name, + infer_type_with_sigmap, least_restrictive, CastContext, CastSig, FuncSign, }; pub use user_defined_function::UserDefinedFunction; pub use utils::*; diff --git a/src/frontend/src/expr/type_inference/func.rs b/src/frontend/src/expr/type_inference/func.rs index d4637cbf7178f..cccb1f0905a42 100644 --- a/src/frontend/src/expr/type_inference/func.rs +++ b/src/frontend/src/expr/type_inference/func.rs @@ -28,7 +28,11 @@ use crate::expr::{cast_ok, is_row_function, Expr as _, ExprImpl, ExprType, Funct /// is not supported on backend. /// /// It also mutates the `inputs` by adding necessary casts. -pub fn infer_type(func_name: FuncName, inputs: &mut [ExprImpl]) -> Result { +pub fn infer_type_with_sigmap( + func_name: FuncName, + inputs: &mut [ExprImpl], + sig_map: &FunctionRegistry, +) -> Result { // special cases if let FuncName::Scalar(func_type) = func_name && let Some(res) = infer_type_for_special(func_type, inputs).transpose() @@ -46,7 +50,7 @@ pub fn infer_type(func_name: FuncName, inputs: &mut [ExprImpl]) -> Result Some(e.return_type()), }) .collect_vec(); - let sig = infer_type_name(&FUNCTION_REGISTRY, func_name, &actuals)?; + let sig = infer_type_name(sig_map, func_name, &actuals)?; // add implicit casts to inputs for (expr, t) in inputs.iter_mut().zip_eq_fast(&sig.inputs_type) { @@ -67,6 +71,10 @@ pub fn infer_type(func_name: FuncName, inputs: &mut [ExprImpl]) -> Result Result { + infer_type_with_sigmap(func_name, inputs, &FUNCTION_REGISTRY) +} + pub fn infer_some_all( mut func_types: Vec, inputs: &mut Vec, @@ -608,12 +616,12 @@ fn infer_type_for_special( /// 4e in `PostgreSQL`. See [`narrow_category`] for details. /// 5. Attempt to narrow down candidates by assuming all arguments are same type. This covers Rule /// 4f in `PostgreSQL`. See [`narrow_same_type`] for details. -fn infer_type_name<'a>( +pub fn infer_type_name<'a>( sig_map: &'a FunctionRegistry, func_name: FuncName, inputs: &[Option], ) -> Result<&'a FuncSign> { - let candidates = sig_map.get_with_arg_nums(func_name, inputs.len()); + let candidates = sig_map.get_with_arg_nums(func_name.clone(), inputs.len()); // Binary operators have a special `unknown` handling rule for exact match. We do not // distinguish operators from functions as of now. diff --git a/src/frontend/src/expr/type_inference/mod.rs b/src/frontend/src/expr/type_inference/mod.rs index 9f496f8c3d750..5f191a898614c 100644 --- a/src/frontend/src/expr/type_inference/mod.rs +++ b/src/frontend/src/expr/type_inference/mod.rs @@ -21,4 +21,4 @@ pub use cast::{ align_types, cast_map_array, cast_ok, cast_ok_base, cast_sigs, least_restrictive, CastContext, CastSig, }; -pub use func::{infer_some_all, infer_type, FuncSign}; +pub use func::{infer_some_all, infer_type, infer_type_name, infer_type_with_sigmap, FuncSign}; From 569b5897891ce50ce88809343f2e57db0f8516cf Mon Sep 17 00:00:00 2001 From: Yufan Song <33971064+yufansong@users.noreply.github.com> Date: Sun, 7 Jan 2024 22:56:48 -0800 Subject: [PATCH 17/20] fix: column index mapping bug of stream_delta_join (#14398) --- ...in_upstream_with_index_different_types.slt | 47 +++++++++++++++++++ proto/stream_plan.proto | 2 + .../optimizer/plan_node/stream_delta_join.rs | 10 ++++ src/stream/src/from_proto/lookup.rs | 8 +++- 4 files changed, 66 insertions(+), 1 deletion(-) create mode 100644 e2e_test/streaming/delta_join/delta_join_upstream_with_index_different_types.slt diff --git a/e2e_test/streaming/delta_join/delta_join_upstream_with_index_different_types.slt b/e2e_test/streaming/delta_join/delta_join_upstream_with_index_different_types.slt new file mode 100644 index 0000000000000..76cb0314e3a48 --- /dev/null +++ b/e2e_test/streaming/delta_join/delta_join_upstream_with_index_different_types.slt @@ -0,0 +1,47 @@ +statement ok +set rw_implicit_flush = true; + +statement ok +set rw_streaming_enable_delta_join = true; + +statement ok +create table A (k1 numeric, k2 smallint, v int); + +statement ok +create index Ak1 on A(k1) include(k1,k2,v); + +statement ok +create table B (k1 numeric, k2 smallint, v int); + +statement ok +create index Bk1 on B(k1) include(k1,k2,v); + +statement ok +insert into A values(1, 2, 4); + +statement ok +insert into B values(1, 2, 4); + +statement ok +create MATERIALIZED VIEW m1 as select A.v, B.v as Bv from A join B using(k1); + + +query I +SELECT * from m1; +---- +4 4 + +statement ok +drop MATERIALIZED VIEW m1; + +statement ok +drop index Ak1; + +statement ok +drop index Bk1; + +statement ok +drop table A; + +statement ok +drop table B; \ No newline at end of file diff --git a/proto/stream_plan.proto b/proto/stream_plan.proto index 04ba2246bb859..a168ea163f5b5 100644 --- a/proto/stream_plan.proto +++ b/proto/stream_plan.proto @@ -576,6 +576,8 @@ message ArrangementInfo { repeated plan_common.ColumnDesc column_descs = 2; // Used to build storage table by stream lookup join of delta join. plan_common.StorageTableDesc table_desc = 4; + // Output index columns + repeated uint32 output_col_idx = 5; } // Special node for shared state, which will only be produced in fragmenter. ArrangeNode will diff --git a/src/frontend/src/optimizer/plan_node/stream_delta_join.rs b/src/frontend/src/optimizer/plan_node/stream_delta_join.rs index 6cadd8a31b9e3..25b45ac24c73a 100644 --- a/src/frontend/src/optimizer/plan_node/stream_delta_join.rs +++ b/src/frontend/src/optimizer/plan_node/stream_delta_join.rs @@ -182,6 +182,11 @@ impl StreamNode for StreamDeltaJoin { .map(ColumnDesc::to_protobuf) .collect(), table_desc: Some(left_table_desc.to_protobuf()), + output_col_idx: left_table + .output_col_idx + .iter() + .map(|&v| v as u32) + .collect(), }), right_info: Some(ArrangementInfo { // TODO: remove it @@ -193,6 +198,11 @@ impl StreamNode for StreamDeltaJoin { .map(ColumnDesc::to_protobuf) .collect(), table_desc: Some(right_table_desc.to_protobuf()), + output_col_idx: right_table + .output_col_idx + .iter() + .map(|&v| v as u32) + .collect(), }), output_indices: self.core.output_indices.iter().map(|&x| x as u32).collect(), }) diff --git a/src/stream/src/from_proto/lookup.rs b/src/stream/src/from_proto/lookup.rs index a35f0a4390b34..1c1733ae7e4ba 100644 --- a/src/stream/src/from_proto/lookup.rs +++ b/src/stream/src/from_proto/lookup.rs @@ -72,7 +72,13 @@ impl ExecutorBuilder for LookupExecutorBuilder { .iter() .map(ColumnDesc::from) .collect_vec(); - let column_ids = column_descs.iter().map(|x| x.column_id).collect_vec(); + + let column_ids = lookup + .get_arrangement_table_info()? + .get_output_col_idx() + .iter() + .map(|&idx| column_descs[idx as usize].column_id) + .collect_vec(); // Use indices based on full table instead of streaming executor output. let pk_indices = table_desc From 886371aa40dd09a1f6954ee250ae7ed8d045f987 Mon Sep 17 00:00:00 2001 From: Noel Kwan <47273164+kwannoel@users.noreply.github.com> Date: Mon, 8 Jan 2024 15:54:50 +0800 Subject: [PATCH 18/20] chore(ci): enable differential tests (#14411) --- ci/scripts/notify.py | 1 + ci/workflows/main-cron.yml | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/ci/scripts/notify.py b/ci/scripts/notify.py index 9160d01675832..4fbaf16799fa6 100755 --- a/ci/scripts/notify.py +++ b/ci/scripts/notify.py @@ -13,6 +13,7 @@ "test-notify-2": ["noelkwan", "noelkwan"], "backfill-tests": ["noelkwan"], "backwards-compat-tests": ["noelkwan"], + "sqlsmith-differential-tests": ["noelkwan"], "fuzz-test": ["noelkwan"], "e2e-test-release": ["zhi"], "e2e-iceberg-sink-tests": ["renjie"], diff --git a/ci/workflows/main-cron.yml b/ci/workflows/main-cron.yml index b5533ae149e0b..56d4695beafbc 100644 --- a/ci/workflows/main-cron.yml +++ b/ci/workflows/main-cron.yml @@ -650,6 +650,7 @@ steps: # Sqlsmith differential testing - label: "Sqlsmith Differential Testing" + key: "sqlsmith-differential-tests" command: "ci/scripts/sqlsmith-differential-test.sh -p ci-release" if: | !(build.pull_request.labels includes "ci/main-cron/skip-ci") && build.env("CI_STEPS") == null @@ -663,7 +664,6 @@ steps: config: ci/docker-compose.yml mount-buildkite-agent: true timeout_in_minutes: 40 - soft_fail: true - label: "Backfill tests" key: "backfill-tests" From f52c04619f8bac0b29cb3c12541fd1b5fdc2b6db Mon Sep 17 00:00:00 2001 From: congyi wang <58715567+wcy-fdu@users.noreply.github.com> Date: Mon, 8 Jan 2024 17:36:24 +0800 Subject: [PATCH 19/20] chore: update rw version with hdfs (#14412) --- docker/docker-compose-with-hdfs.yml | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/docker/docker-compose-with-hdfs.yml b/docker/docker-compose-with-hdfs.yml index 9088d09626a5c..ccf0f433c450b 100644 --- a/docker/docker-compose-with-hdfs.yml +++ b/docker/docker-compose-with-hdfs.yml @@ -2,7 +2,7 @@ version: "3" services: compactor-0: - image: ghcr.io/risingwavelabs/risingwave:RW_1.1_HADOOP2-x86_64 + image: ghcr.io/risingwavelabs/risingwave:RisingWave_v1.5.4_HDFS_2.7-x86_64 command: - compactor-node - "--listen-addr" @@ -42,7 +42,7 @@ services: reservations: memory: 1G compute-node-0: - image: "ghcr.io/risingwavelabs/risingwave:RW_1.1_HADOOP2-x86_64" + image: "ghcr.io/risingwavelabs/risingwave:RisingWave_v1.5.4_HDFS_2.7-x86_64" command: - compute-node - "--listen-addr" @@ -132,7 +132,7 @@ services: retries: 5 restart: always frontend-node-0: - image: "ghcr.io/risingwavelabs/risingwave:RW_1.1_HADOOP2-x86_64" + image: "ghcr.io/risingwavelabs/risingwave:RisingWave_v1.5.4_HDFS_2.7-x86_64" command: - frontend-node - "--listen-addr" @@ -195,7 +195,7 @@ services: retries: 5 restart: always meta-node-0: - image: "ghcr.io/risingwavelabs/risingwave:RW_1.1_HADOOP2-x86_64" + image: "ghcr.io/risingwavelabs/risingwave:RisingWave_v1.5.4_HDFS_2.7-x86_64" command: - meta-node - "--listen-addr" From 414c6ec1c05e3d13aa4842d30624c13b83b9692c Mon Sep 17 00:00:00 2001 From: William Wen <44139337+wenym1@users.noreply.github.com> Date: Mon, 8 Jan 2024 18:42:14 +0800 Subject: [PATCH 20/20] refactor: own global barrier manager in worker loop (#14410) --- src/meta/node/src/server.rs | 8 +- src/meta/src/barrier/command.rs | 78 ++-- src/meta/src/barrier/mod.rs | 598 +++++++++++++------------- src/meta/src/barrier/progress.rs | 2 +- src/meta/src/barrier/recovery.rs | 27 +- src/meta/src/stream/stream_manager.rs | 4 +- 6 files changed, 369 insertions(+), 348 deletions(-) diff --git a/src/meta/node/src/server.rs b/src/meta/node/src/server.rs index d7bd0208b3873..d28c4e8e5b93e 100644 --- a/src/meta/node/src/server.rs +++ b/src/meta/node/src/server.rs @@ -535,7 +535,7 @@ pub async fn start_service_as_election_leader( let (sink_manager, shutdown_handle) = SinkCoordinatorManager::start_worker(); let mut sub_tasks = vec![shutdown_handle]; - let barrier_manager = Arc::new(GlobalBarrierManager::new( + let barrier_manager = GlobalBarrierManager::new( scheduled_barriers, env.clone(), metadata_manager.clone(), @@ -543,7 +543,7 @@ pub async fn start_service_as_election_leader( source_manager.clone(), sink_manager.clone(), meta_metrics.clone(), - )); + ); { let source_manager = source_manager.clone(); @@ -611,7 +611,7 @@ pub async fn start_service_as_election_leader( metadata_manager.clone(), stream_manager.clone(), source_manager.clone(), - barrier_manager.clone(), + barrier_manager.context().clone(), sink_manager.clone(), ) .await; @@ -622,7 +622,7 @@ pub async fn start_service_as_election_leader( metadata_manager.clone(), source_manager, stream_manager.clone(), - barrier_manager.clone(), + barrier_manager.context().clone(), ); let cluster_srv = ClusterServiceImpl::new(metadata_manager.clone()); diff --git a/src/meta/src/barrier/command.rs b/src/meta/src/barrier/command.rs index 96eed8cba6846..39bc3ced0023a 100644 --- a/src/meta/src/barrier/command.rs +++ b/src/meta/src/barrier/command.rs @@ -33,19 +33,14 @@ use risingwave_pb::stream_plan::{ UpdateMutation, }; use risingwave_pb::stream_service::{DropActorsRequest, WaitEpochCommitRequest}; -use risingwave_rpc_client::StreamClientPoolRef; use uuid::Uuid; use super::info::BarrierActorInfo; use super::trace::TracedEpoch; -use crate::barrier::CommandChanges; -use crate::hummock::HummockManagerRef; +use crate::barrier::{CommandChanges, GlobalBarrierManagerContext}; use crate::manager::{DdlType, MetadataManager, WorkerId}; use crate::model::{ActorId, DispatcherId, FragmentId, TableFragments, TableParallelism}; -use crate::stream::{ - build_actor_connector_splits, ScaleControllerRef, SourceManagerRef, SplitAssignment, - ThrottleConfig, -}; +use crate::stream::{build_actor_connector_splits, SplitAssignment, ThrottleConfig}; use crate::MetaResult; /// [`Reschedule`] is for the [`Command::RescheduleFragment`], which is used for rescheduling actors @@ -266,12 +261,6 @@ impl Command { /// [`CommandContext`] is used for generating barrier and doing post stuffs according to the given /// [`Command`]. pub struct CommandContext { - pub metadata_manager: MetadataManager, - - hummock_manager: HummockManagerRef, - - client_pool: StreamClientPoolRef, - /// Resolved info in this barrier loop. // TODO: this could be stale when we are calling `post_collect`, check if it matters pub info: Arc, @@ -285,9 +274,7 @@ pub struct CommandContext { pub kind: BarrierKind, - source_manager: SourceManagerRef, - - scale_controller: Option, + barrier_manager_context: GlobalBarrierManagerContext, /// The tracing span of this command. /// @@ -300,34 +287,30 @@ pub struct CommandContext { impl CommandContext { #[allow(clippy::too_many_arguments)] pub(super) fn new( - metadata_manager: MetadataManager, - hummock_manager: HummockManagerRef, - client_pool: StreamClientPoolRef, info: BarrierActorInfo, prev_epoch: TracedEpoch, curr_epoch: TracedEpoch, current_paused_reason: Option, command: Command, kind: BarrierKind, - source_manager: SourceManagerRef, - scale_controller: Option, + barrier_manager_context: GlobalBarrierManagerContext, span: tracing::Span, ) -> Self { Self { - metadata_manager, - hummock_manager, - client_pool, info: Arc::new(info), prev_epoch, curr_epoch, current_paused_reason, command, kind, - source_manager, - scale_controller, + barrier_manager_context, span, } } + + pub fn metadata_manager(&self) -> &MetadataManager { + &self.barrier_manager_context.metadata_manager + } } impl CommandContext { @@ -382,7 +365,8 @@ impl CommandContext { } Command::DropStreamingJobs(table_ids) => { - let MetadataManager::V1(mgr) = &self.metadata_manager else { + let MetadataManager::V1(mgr) = &self.barrier_manager_context.metadata_manager + else { unreachable!("only available in v1"); }; @@ -477,7 +461,8 @@ impl CommandContext { ), Command::RescheduleFragment { reschedules, .. } => { - let MetadataManager::V1(mgr) = &self.metadata_manager else { + let MetadataManager::V1(mgr) = &self.barrier_manager_context.metadata_manager + else { unimplemented!("implement scale functions in v2"); }; let mut dispatcher_update = HashMap::new(); @@ -736,7 +721,12 @@ impl CommandContext { let request_id = Uuid::new_v4().to_string(); async move { - let client = self.client_pool.get(node).await?; + let client = self + .barrier_manager_context + .env + .stream_client_pool() + .get(node) + .await?; let request = DropActorsRequest { request_id, actor_ids: actors.to_owned(), @@ -751,7 +741,12 @@ impl CommandContext { pub async fn wait_epoch_commit(&self, epoch: HummockEpoch) -> MetaResult<()> { let futures = self.info.node_map.values().map(|worker_node| async { - let client = self.client_pool.get(worker_node).await?; + let client = self + .barrier_manager_context + .env + .stream_client_pool() + .get(worker_node) + .await?; let request = WaitEpochCommitRequest { epoch }; client.wait_epoch_commit(request).await }); @@ -782,19 +777,22 @@ impl CommandContext { Command::Resume(_) => {} Command::SourceSplitAssignment(split_assignment) => { - let MetadataManager::V1(mgr) = &self.metadata_manager else { + let MetadataManager::V1(mgr) = &self.barrier_manager_context.metadata_manager + else { unimplemented!("implement config change funcs in v2"); }; mgr.fragment_manager .update_actor_splits_by_split_assignment(split_assignment) .await?; - self.source_manager + self.barrier_manager_context + .source_manager .apply_source_change(None, Some(split_assignment.clone()), None) .await; } Command::DropStreamingJobs(table_ids) => { - let MetadataManager::V1(mgr) = &self.metadata_manager else { + let MetadataManager::V1(mgr) = &self.barrier_manager_context.metadata_manager + else { unreachable!("only available in v1"); }; // Tell compute nodes to drop actors. @@ -834,11 +832,12 @@ impl CommandContext { let table_id = table_fragments.table_id().table_id; let mut table_ids = table_fragments.internal_table_ids(); table_ids.push(table_id); - self.hummock_manager + self.barrier_manager_context + .hummock_manager .unregister_table_ids_fail_fast(&table_ids) .await; - match &self.metadata_manager { + match &self.barrier_manager_context.metadata_manager { MetadataManager::V1(mgr) => { // NOTE(kwannoel): At this point, catalog manager has persisted the tables already. // We need to cleanup the table state. So we can do it here. @@ -889,7 +888,7 @@ impl CommandContext { replace_table, .. } => { - match &self.metadata_manager { + match &self.barrier_manager_context.metadata_manager { MetadataManager::V1(mgr) => { let mut dependent_table_actors = Vec::with_capacity(upstream_mview_actors.len()); @@ -944,7 +943,8 @@ impl CommandContext { // Extract the fragments that include source operators. let source_fragments = table_fragments.stream_source_fragments(); - self.source_manager + self.barrier_manager_context + .source_manager .apply_source_change( Some(source_fragments), Some(init_split_assignment.clone()), @@ -958,6 +958,7 @@ impl CommandContext { table_parallelism, } => { let node_dropped_actors = self + .barrier_manager_context .scale_controller .as_ref() .unwrap() @@ -973,7 +974,8 @@ impl CommandContext { dispatchers, init_split_assignment, }) => { - let MetadataManager::V1(mgr) = &self.metadata_manager else { + let MetadataManager::V1(mgr) = &self.barrier_manager_context.metadata_manager + else { unimplemented!("implement replace funcs in v2"); }; let table_ids = HashSet::from_iter(std::iter::once(old_table_fragments.table_id())); diff --git a/src/meta/src/barrier/mod.rs b/src/meta/src/barrier/mod.rs index 5f5227e8090d1..dc7e6f8d8fb01 100644 --- a/src/meta/src/barrier/mod.rs +++ b/src/meta/src/barrier/mod.rs @@ -27,6 +27,7 @@ use prometheus::HistogramTimer; use risingwave_common::bail; use risingwave_common::catalog::TableId; use risingwave_common::system_param::PAUSE_ON_NEXT_BOOTSTRAP_KEY; +use risingwave_common::util::epoch::{Epoch, INVALID_EPOCH}; use risingwave_common::util::tracing::TracingContext; use risingwave_hummock_sdk::table_watermark::{ merge_multiple_new_table_watermarks, TableWatermarks, @@ -162,6 +163,28 @@ pub enum CommandChanges { /// No changes. None, } + +#[derive(Clone)] +pub struct GlobalBarrierManagerContext { + status: Arc>, + + tracker: Arc>, + + metadata_manager: MetadataManager, + + hummock_manager: HummockManagerRef, + + source_manager: SourceManagerRef, + + scale_controller: Option, + + sink_manager: SinkCoordinatorManager, + + metrics: Arc, + + env: MetaSrvEnv, +} + /// [`crate::barrier::GlobalBarrierManager`] sends barriers to all registered compute nodes and /// collect them, with monotonic increasing epoch numbers. On compute nodes, `LocalBarrierManager` /// in `risingwave_stream` crate will serve these requests and dispatch them to source actors. @@ -175,29 +198,19 @@ pub struct GlobalBarrierManager { /// Enable recovery or not when failover. enable_recovery: bool, - status: Mutex, - /// The queue of scheduled barriers. scheduled_barriers: schedule::ScheduledBarriers, /// The max barrier nums in flight in_flight_barrier_nums: usize, - metadata_manager: MetadataManager, + context: GlobalBarrierManagerContext, - hummock_manager: HummockManagerRef, + env: MetaSrvEnv, - source_manager: SourceManagerRef, - - scale_controller: Option, - - sink_manager: SinkCoordinatorManager, - - metrics: Arc, + state: BarrierManagerState, - pub env: MetaSrvEnv, - - tracker: Mutex, + checkpoint_control: CheckpointControl, } /// Controls the concurrent execution of commands. @@ -566,6 +579,10 @@ impl GlobalBarrierManager { let enable_recovery = env.opts.enable_recovery; let in_flight_barrier_nums = env.opts.in_flight_barrier_nums; + let initial_invalid_state = + BarrierManagerState::new(TracedEpoch::new(Epoch(INVALID_EPOCH)), None); + let checkpoint_control = CheckpointControl::new(metrics.clone()); + let tracker = CreateMviewProgressTracker::new(); let scale_controller = match &metadata_manager { @@ -576,23 +593,34 @@ impl GlobalBarrierManager { ))), MetadataManager::V2(_) => None, }; - Self { - enable_recovery, - status: Mutex::new(BarrierManagerStatus::Starting), - scheduled_barriers, - in_flight_barrier_nums, + let context = GlobalBarrierManagerContext { + status: Arc::new(Mutex::new(BarrierManagerStatus::Starting)), metadata_manager, hummock_manager, source_manager, scale_controller, sink_manager, metrics, + tracker: Arc::new(Mutex::new(tracker)), + env: env.clone(), + }; + + Self { + enable_recovery, + scheduled_barriers, + in_flight_barrier_nums, + context, env, - tracker: Mutex::new(tracker), + state: initial_invalid_state, + checkpoint_control, } } - pub fn start(barrier_manager: BarrierManagerRef) -> (JoinHandle<()>, Sender<()>) { + pub fn context(&self) -> &GlobalBarrierManagerContext { + &self.context + } + + pub fn start(barrier_manager: GlobalBarrierManager) -> (JoinHandle<()>, Sender<()>) { let (shutdown_tx, shutdown_rx) = tokio::sync::oneshot::channel(); let join_handle = tokio::spawn(async move { barrier_manager.run(shutdown_rx).await; @@ -601,27 +629,6 @@ impl GlobalBarrierManager { (join_handle, shutdown_tx) } - /// Check the status of barrier manager, return error if it is not `Running`. - pub async fn check_status_running(&self) -> MetaResult<()> { - let status = self.status.lock().await; - match &*status { - BarrierManagerStatus::Starting - | BarrierManagerStatus::Recovering(RecoveryReason::Bootstrap) => { - bail!("The cluster is bootstrapping") - } - BarrierManagerStatus::Recovering(RecoveryReason::Failover(e)) => { - Err(anyhow::anyhow!(e.clone()).context("The cluster is recovering"))? - } - BarrierManagerStatus::Running => Ok(()), - } - } - - /// Set barrier manager status. - async fn set_status(&self, new_status: BarrierManagerStatus) { - let mut status = self.status.lock().await; - *status = new_status; - } - /// Check whether we should pause on bootstrap from the system parameter and reset it. async fn take_pause_on_bootstrap(&self) -> MetaResult { let paused = self @@ -651,7 +658,7 @@ impl GlobalBarrierManager { } /// Start an infinite loop to take scheduled barriers and send them. - async fn run(&self, mut shutdown_rx: Receiver<()>) { + async fn run(mut self, mut shutdown_rx: Receiver<()>) { // Initialize the barrier manager. let interval = Duration::from_millis( self.env.system_params_reader().await.barrier_interval_ms() as u64, @@ -664,7 +671,7 @@ impl GlobalBarrierManager { ); if !self.enable_recovery { - let job_exist = match &self.metadata_manager { + let job_exist = match &self.context.metadata_manager { MetadataManager::V1(mgr) => mgr.fragment_manager.has_any_table_fragments().await, MetadataManager::V2(mgr) => mgr .catalog_controller @@ -680,8 +687,8 @@ impl GlobalBarrierManager { } } - let mut state = { - let latest_snapshot = self.hummock_manager.latest_snapshot(); + self.state = { + let latest_snapshot = self.context.hummock_manager.latest_snapshot(); assert_eq!( latest_snapshot.committed_epoch, latest_snapshot.current_epoch, "persisted snapshot must be from a checkpoint barrier" @@ -692,24 +699,25 @@ impl GlobalBarrierManager { // consistency. // Even if there's no actor to recover, we still go through the recovery process to // inject the first `Initial` barrier. - self.set_status(BarrierManagerStatus::Recovering(RecoveryReason::Bootstrap)) + self.context + .set_status(BarrierManagerStatus::Recovering(RecoveryReason::Bootstrap)) .await; let span = tracing::info_span!("bootstrap_recovery", prev_epoch = prev_epoch.value().0); let paused = self.take_pause_on_bootstrap().await.unwrap_or(false); let paused_reason = paused.then_some(PausedReason::Manual); - self.recovery(prev_epoch, paused_reason) + self.context + .recovery(prev_epoch, paused_reason, &self.scheduled_barriers) .instrument(span) .await }; - self.set_status(BarrierManagerStatus::Running).await; + self.context.set_status(BarrierManagerStatus::Running).await; let mut min_interval = tokio::time::interval(interval); min_interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay); let (barrier_complete_tx, mut barrier_complete_rx) = tokio::sync::mpsc::unbounded_channel(); - let mut checkpoint_control = CheckpointControl::new(self.metrics.clone()); let (local_notification_tx, mut local_notification_rx) = tokio::sync::mpsc::unbounded_channel(); self.env @@ -745,34 +753,32 @@ impl GlobalBarrierManager { completion = barrier_complete_rx.recv() => { self.handle_barrier_complete( completion.unwrap(), - &mut state, - &mut checkpoint_control, ) .await; } // There's barrier scheduled. - _ = self.scheduled_barriers.wait_one(), if checkpoint_control.can_inject_barrier(self.in_flight_barrier_nums) => { + _ = self.scheduled_barriers.wait_one(), if self.checkpoint_control.can_inject_barrier(self.in_flight_barrier_nums) => { min_interval.reset(); // Reset the interval as we have a new barrier. - self.handle_new_barrier(&barrier_complete_tx, &mut state, &mut checkpoint_control).await; + self.handle_new_barrier(&barrier_complete_tx).await; } // Minimum interval reached. - _ = min_interval.tick(), if checkpoint_control.can_inject_barrier(self.in_flight_barrier_nums) => { - self.handle_new_barrier(&barrier_complete_tx, &mut state, &mut checkpoint_control).await; + _ = min_interval.tick(), if self.checkpoint_control.can_inject_barrier(self.in_flight_barrier_nums) => { + self.handle_new_barrier(&barrier_complete_tx).await; } } - checkpoint_control.update_barrier_nums_metrics(); + self.checkpoint_control.update_barrier_nums_metrics(); } } /// Handle the new barrier from the scheduled queue and inject it. async fn handle_new_barrier( - &self, + &mut self, barrier_complete_tx: &UnboundedSender, - state: &mut BarrierManagerState, - checkpoint_control: &mut CheckpointControl, ) { - assert!(checkpoint_control.can_inject_barrier(self.in_flight_barrier_nums)); + assert!(self + .checkpoint_control + .can_inject_barrier(self.in_flight_barrier_nums)); let Scheduled { command, @@ -781,9 +787,17 @@ impl GlobalBarrierManager { checkpoint, span, } = self.scheduled_barriers.pop_or_default().await; - let info = self.resolve_actor_info(checkpoint_control, &command).await; + self.checkpoint_control.pre_resolve(&command); + let info = self + .context + .resolve_actor_info(|s: ActorState, table_id: TableId, actor_id: ActorId| { + self.checkpoint_control + .can_actor_send_or_collect(s, table_id, actor_id) + }) + .await; + self.checkpoint_control.post_resolve(&command); - let (prev_epoch, curr_epoch) = state.next_epoch_pair(); + let (prev_epoch, curr_epoch) = self.state.next_epoch_pair(); let kind = if checkpoint { BarrierKind::Checkpoint } else { @@ -795,28 +809,25 @@ impl GlobalBarrierManager { span.record("epoch", curr_epoch.value().0); let command_ctx = Arc::new(CommandContext::new( - self.metadata_manager.clone(), - self.hummock_manager.clone(), - self.env.stream_client_pool_ref(), info, prev_epoch.clone(), curr_epoch.clone(), - state.paused_reason(), + self.state.paused_reason(), command, kind, - self.source_manager.clone(), - self.scale_controller.clone(), + self.context.clone(), span.clone(), )); send_latency_timer.observe_duration(); - self.inject_barrier(command_ctx.clone(), barrier_complete_tx) + self.context + .inject_barrier(command_ctx.clone(), barrier_complete_tx) .instrument(span) .await; // Notify about the injection. - let prev_paused_reason = state.paused_reason(); + let prev_paused_reason = self.state.paused_reason(); let curr_paused_reason = command_ctx.next_paused_reason(); let info = BarrierInfo { @@ -828,182 +839,22 @@ impl GlobalBarrierManager { notifiers.iter_mut().for_each(|n| n.notify_injected(info)); // Update the paused state after the barrier is injected. - state.set_paused_reason(curr_paused_reason); + self.state.set_paused_reason(curr_paused_reason); // Record the in-flight barrier. - checkpoint_control.enqueue_command(command_ctx.clone(), notifiers); - } - - /// Inject a barrier to all CNs and spawn a task to collect it - async fn inject_barrier( - &self, - command_context: Arc, - barrier_complete_tx: &UnboundedSender, - ) { - let prev_epoch = command_context.prev_epoch.value().0; - let result = self.inject_barrier_inner(command_context.clone()).await; - match result { - Ok(node_need_collect) => { - // todo: the collect handler should be abort when recovery. - tokio::spawn(Self::collect_barrier( - self.env.clone(), - node_need_collect, - self.env.stream_client_pool_ref(), - command_context, - barrier_complete_tx.clone(), - )); - } - Err(e) => { - let _ = barrier_complete_tx.send(BarrierCompletion { - prev_epoch, - result: Err(e), - }); - } - } - } - - /// Send inject-barrier-rpc to stream service and wait for its response before returns. - async fn inject_barrier_inner( - &self, - command_context: Arc, - ) -> MetaResult> { - fail_point!("inject_barrier_err", |_| bail!("inject_barrier_err")); - let mutation = command_context.to_mutation().await?; - let info = command_context.info.clone(); - let mut node_need_collect = HashMap::new(); - let inject_futures = info.node_map.iter().filter_map(|(node_id, node)| { - let actor_ids_to_send = info.actor_ids_to_send(node_id).collect_vec(); - let actor_ids_to_collect = info.actor_ids_to_collect(node_id).collect_vec(); - if actor_ids_to_collect.is_empty() { - // No need to send or collect barrier for this node. - assert!(actor_ids_to_send.is_empty()); - node_need_collect.insert(*node_id, false); - None - } else { - node_need_collect.insert(*node_id, true); - let mutation = mutation.clone(); - let request_id = Uuid::new_v4().to_string(); - let barrier = Barrier { - epoch: Some(risingwave_pb::data::Epoch { - curr: command_context.curr_epoch.value().0, - prev: command_context.prev_epoch.value().0, - }), - mutation: mutation.clone().map(|_| BarrierMutation { mutation }), - tracing_context: TracingContext::from_span(command_context.curr_epoch.span()) - .to_protobuf(), - kind: command_context.kind as i32, - passed_actors: vec![], - }; - async move { - let client = self.env.stream_client_pool().get(node).await?; - - let request = InjectBarrierRequest { - request_id, - barrier: Some(barrier), - actor_ids_to_send, - actor_ids_to_collect, - }; - tracing::debug!( - target: "events::meta::barrier::inject_barrier", - ?request, "inject barrier request" - ); - - // This RPC returns only if this worker node has injected this barrier. - client.inject_barrier(request).await - } - .into() - } - }); - try_join_all(inject_futures).await.inspect_err(|e| { - // Record failure in event log. - use risingwave_pb::meta::event_log; - use thiserror_ext::AsReport; - let event = event_log::EventInjectBarrierFail { - prev_epoch: command_context.prev_epoch.value().0, - cur_epoch: command_context.curr_epoch.value().0, - error: e.to_report_string(), - }; - self.env - .event_log_manager_ref() - .add_event_logs(vec![event_log::Event::InjectBarrierFail(event)]); - })?; - Ok(node_need_collect) - } - - /// Send barrier-complete-rpc and wait for responses from all CNs - async fn collect_barrier( - env: MetaSrvEnv, - node_need_collect: HashMap, - client_pool_ref: StreamClientPoolRef, - command_context: Arc, - barrier_complete_tx: UnboundedSender, - ) { - let prev_epoch = command_context.prev_epoch.value().0; - let tracing_context = - TracingContext::from_span(command_context.prev_epoch.span()).to_protobuf(); - - let info = command_context.info.clone(); - let client_pool = client_pool_ref.deref(); - let collect_futures = info.node_map.iter().filter_map(|(node_id, node)| { - if !*node_need_collect.get(node_id).unwrap() { - // No need to send or collect barrier for this node. - None - } else { - let request_id = Uuid::new_v4().to_string(); - let tracing_context = tracing_context.clone(); - async move { - let client = client_pool.get(node).await?; - let request = BarrierCompleteRequest { - request_id, - prev_epoch, - tracing_context, - }; - tracing::debug!( - target: "events::meta::barrier::barrier_complete", - ?request, "barrier complete" - ); - - // This RPC returns only if this worker node has collected this barrier. - client.barrier_complete(request).await - } - .into() - } - }); - - let result = try_join_all(collect_futures) - .await - .inspect_err(|e| { - // Record failure in event log. - use risingwave_pb::meta::event_log; - use thiserror_ext::AsReport; - let event = event_log::EventCollectBarrierFail { - prev_epoch: command_context.prev_epoch.value().0, - cur_epoch: command_context.curr_epoch.value().0, - error: e.to_report_string(), - }; - env.event_log_manager_ref() - .add_event_logs(vec![event_log::Event::CollectBarrierFail(event)]); - }) - .map_err(Into::into); - let _ = barrier_complete_tx - .send(BarrierCompletion { prev_epoch, result }) - .inspect_err(|_| tracing::warn!(prev_epoch, "failed to notify barrier completion")); + self.checkpoint_control + .enqueue_command(command_ctx.clone(), notifiers); } /// Changes the state to `Complete`, and try to commit all epoch that state is `Complete` in /// order. If commit is err, all nodes will be handled. - async fn handle_barrier_complete( - &self, - completion: BarrierCompletion, - state: &mut BarrierManagerState, - checkpoint_control: &mut CheckpointControl, - ) { + async fn handle_barrier_complete(&mut self, completion: BarrierCompletion) { let BarrierCompletion { prev_epoch, result } = completion; // Received barrier complete responses with an epoch that is not managed by checkpoint // control, which means a recovery has been triggered. We should ignore it because // trying to complete and commit the epoch is not necessary and could cause // meaningless recovery again. - if !checkpoint_control.contains_epoch(prev_epoch) { + if !self.checkpoint_control.contains_epoch(prev_epoch) { tracing::warn!( "received barrier complete response for an unknown epoch: {}", prev_epoch @@ -1015,24 +866,21 @@ impl GlobalBarrierManager { // FIXME: If it is a connector source error occurred in the init barrier, we should pass // back to frontend fail_point!("inject_barrier_err_success"); - let fail_node = checkpoint_control.barrier_failed(); + let fail_node = self.checkpoint_control.barrier_failed(); tracing::warn!("Failed to complete epoch {}: {:?}", prev_epoch, err); - self.failure_recovery(err, fail_node, state, checkpoint_control) - .await; + self.failure_recovery(err, fail_node).await; return; } // change the state to Complete - let mut complete_nodes = checkpoint_control.barrier_completed(prev_epoch, result.unwrap()); + let mut complete_nodes = self + .checkpoint_control + .barrier_completed(prev_epoch, result.unwrap()); // try commit complete nodes let (mut index, mut err_msg) = (0, None); for (i, node) in complete_nodes.iter_mut().enumerate() { assert!(matches!(node.state, Completed(_))); let span = node.command_ctx.span.clone(); - if let Err(err) = self - .complete_barrier(node, checkpoint_control) - .instrument(span) - .await - { + if let Err(err) = self.complete_barrier(node).instrument(span).await { index = i; err_msg = Some(err); break; @@ -1042,21 +890,19 @@ impl GlobalBarrierManager { if let Some(err) = err_msg { let fail_nodes = complete_nodes .drain(index..) - .chain(checkpoint_control.barrier_failed().into_iter()); + .chain(self.checkpoint_control.barrier_failed().into_iter()) + .collect_vec(); tracing::warn!("Failed to commit epoch {}: {:?}", prev_epoch, err); - self.failure_recovery(err, fail_nodes, state, checkpoint_control) - .await; + self.failure_recovery(err, fail_nodes).await; } } async fn failure_recovery( - &self, + &mut self, err: MetaError, fail_nodes: impl IntoIterator, - state: &mut BarrierManagerState, - checkpoint_control: &mut CheckpointControl, ) { - checkpoint_control.clear_changes(); + self.checkpoint_control.clear_changes(); for node in fail_nodes { if let Some(timer) = node.timer { @@ -1071,11 +917,12 @@ impl GlobalBarrierManager { } if self.enable_recovery { - self.set_status(BarrierManagerStatus::Recovering(RecoveryReason::Failover( - err.clone(), - ))) - .await; - let latest_snapshot = self.hummock_manager.latest_snapshot(); + self.context + .set_status(BarrierManagerStatus::Recovering(RecoveryReason::Failover( + err.clone(), + ))) + .await; + let latest_snapshot = self.context.hummock_manager.latest_snapshot(); let prev_epoch = TracedEpoch::new(latest_snapshot.committed_epoch.into()); // we can only recovery from the committed epoch let span = tracing::info_span!( "failure_recovery", @@ -1085,19 +932,19 @@ impl GlobalBarrierManager { // No need to clean dirty tables for barrier recovery, // The foreground stream job should cleanup their own tables. - *state = self.recovery(prev_epoch, None).instrument(span).await; - self.set_status(BarrierManagerStatus::Running).await; + self.state = self + .context + .recovery(prev_epoch, None, &self.scheduled_barriers) + .instrument(span) + .await; + self.context.set_status(BarrierManagerStatus::Running).await; } else { panic!("failed to execute barrier: {:?}", err); } } /// Try to commit this node. If err, returns - async fn complete_barrier( - &self, - node: &mut EpochNode, - checkpoint_control: &mut CheckpointControl, - ) -> MetaResult<()> { + async fn complete_barrier(&mut self, node: &mut EpochNode) -> MetaResult<()> { let prev_epoch = node.command_ctx.prev_epoch.value().0; match &mut node.state { Completed(resps) => { @@ -1118,12 +965,17 @@ impl GlobalBarrierManager { ), BarrierKind::Checkpoint => { new_snapshot = self + .context .hummock_manager .commit_epoch(node.command_ctx.prev_epoch.value().0, commit_info) .await?; } BarrierKind::Barrier => { - new_snapshot = Some(self.hummock_manager.update_current_epoch(prev_epoch)); + new_snapshot = Some( + self.context + .hummock_manager + .update_current_epoch(prev_epoch), + ); // if we collect a barrier(checkpoint = false), // we need to ensure that command is Plain and the notifier's checkpoint is // false @@ -1152,7 +1004,7 @@ impl GlobalBarrierManager { // Save `cancelled_command` for Create MVs. let actors_to_cancel = node.command_ctx.actors_to_cancel(); let cancelled_command = if !actors_to_cancel.is_empty() { - let mut tracker = self.tracker.lock().await; + let mut tracker = self.context.tracker.lock().await; tracker.find_cancelled_command(actors_to_cancel) } else { None @@ -1161,8 +1013,8 @@ impl GlobalBarrierManager { // Save `finished_commands` for Create MVs. let finished_commands = { let mut commands = vec![]; - let version_stats = self.hummock_manager.get_version_stats().await; - let mut tracker = self.tracker.lock().await; + let version_stats = self.context.hummock_manager.get_version_stats().await; + let mut tracker = self.context.tracker.lock().await; // Add the command to tracker. if let Some(command) = tracker.add( TrackingCommand { @@ -1188,18 +1040,21 @@ impl GlobalBarrierManager { }; for command in finished_commands { - checkpoint_control.stash_command_to_finish(command); + self.checkpoint_control.stash_command_to_finish(command); } if let Some(command) = cancelled_command { - checkpoint_control.cancel_command(command); + self.checkpoint_control.cancel_command(command); } else if let Some(table_id) = node.command_ctx.table_to_cancel() { // the cancelled command is possibly stashed in `finished_commands` and waiting // for checkpoint, we should also clear it. - checkpoint_control.cancel_stashed_command(table_id); + self.checkpoint_control.cancel_stashed_command(table_id); } - let remaining = checkpoint_control.finish_jobs(kind.is_checkpoint()).await?; + let remaining = self + .checkpoint_control + .finish_jobs(kind.is_checkpoint()) + .await?; // If there are remaining commands (that requires checkpoint to finish), we force // the next barrier to be a checkpoint. if remaining { @@ -1230,37 +1085,206 @@ impl GlobalBarrierManager { InFlight => unreachable!(), } } +} + +impl GlobalBarrierManagerContext { + /// Check the status of barrier manager, return error if it is not `Running`. + pub async fn check_status_running(&self) -> MetaResult<()> { + let status = self.status.lock().await; + match &*status { + BarrierManagerStatus::Starting + | BarrierManagerStatus::Recovering(RecoveryReason::Bootstrap) => { + bail!("The cluster is bootstrapping") + } + BarrierManagerStatus::Recovering(RecoveryReason::Failover(e)) => { + Err(anyhow::anyhow!(e.clone()).context("The cluster is recovering"))? + } + BarrierManagerStatus::Running => Ok(()), + } + } + + /// Set barrier manager status. + async fn set_status(&self, new_status: BarrierManagerStatus) { + let mut status = self.status.lock().await; + *status = new_status; + } + + /// Inject a barrier to all CNs and spawn a task to collect it + async fn inject_barrier( + &self, + command_context: Arc, + barrier_complete_tx: &UnboundedSender, + ) { + let prev_epoch = command_context.prev_epoch.value().0; + let result = self.inject_barrier_inner(command_context.clone()).await; + match result { + Ok(node_need_collect) => { + // todo: the collect handler should be abort when recovery. + tokio::spawn(Self::collect_barrier( + self.env.clone(), + node_need_collect, + self.env.stream_client_pool_ref(), + command_context, + barrier_complete_tx.clone(), + )); + } + Err(e) => { + let _ = barrier_complete_tx.send(BarrierCompletion { + prev_epoch, + result: Err(e), + }); + } + } + } + + /// Send inject-barrier-rpc to stream service and wait for its response before returns. + async fn inject_barrier_inner( + &self, + command_context: Arc, + ) -> MetaResult> { + fail_point!("inject_barrier_err", |_| bail!("inject_barrier_err")); + let mutation = command_context.to_mutation().await?; + let info = command_context.info.clone(); + let mut node_need_collect = HashMap::new(); + let inject_futures = info.node_map.iter().filter_map(|(node_id, node)| { + let actor_ids_to_send = info.actor_ids_to_send(node_id).collect_vec(); + let actor_ids_to_collect = info.actor_ids_to_collect(node_id).collect_vec(); + if actor_ids_to_collect.is_empty() { + // No need to send or collect barrier for this node. + assert!(actor_ids_to_send.is_empty()); + node_need_collect.insert(*node_id, false); + None + } else { + node_need_collect.insert(*node_id, true); + let mutation = mutation.clone(); + let request_id = Uuid::new_v4().to_string(); + let barrier = Barrier { + epoch: Some(risingwave_pb::data::Epoch { + curr: command_context.curr_epoch.value().0, + prev: command_context.prev_epoch.value().0, + }), + mutation: mutation.clone().map(|_| BarrierMutation { mutation }), + tracing_context: TracingContext::from_span(command_context.curr_epoch.span()) + .to_protobuf(), + kind: command_context.kind as i32, + passed_actors: vec![], + }; + async move { + let client = self.env.stream_client_pool().get(node).await?; + + let request = InjectBarrierRequest { + request_id, + barrier: Some(barrier), + actor_ids_to_send, + actor_ids_to_collect, + }; + tracing::debug!( + target: "events::meta::barrier::inject_barrier", + ?request, "inject barrier request" + ); + + // This RPC returns only if this worker node has injected this barrier. + client.inject_barrier(request).await + } + .into() + } + }); + try_join_all(inject_futures).await.inspect_err(|e| { + // Record failure in event log. + use risingwave_pb::meta::event_log; + use thiserror_ext::AsReport; + let event = event_log::EventInjectBarrierFail { + prev_epoch: command_context.prev_epoch.value().0, + cur_epoch: command_context.curr_epoch.value().0, + error: e.to_report_string(), + }; + self.env + .event_log_manager_ref() + .add_event_logs(vec![event_log::Event::InjectBarrierFail(event)]); + })?; + Ok(node_need_collect) + } + + /// Send barrier-complete-rpc and wait for responses from all CNs + async fn collect_barrier( + env: MetaSrvEnv, + node_need_collect: HashMap, + client_pool_ref: StreamClientPoolRef, + command_context: Arc, + barrier_complete_tx: UnboundedSender, + ) { + let prev_epoch = command_context.prev_epoch.value().0; + let tracing_context = + TracingContext::from_span(command_context.prev_epoch.span()).to_protobuf(); + + let info = command_context.info.clone(); + let client_pool = client_pool_ref.deref(); + let collect_futures = info.node_map.iter().filter_map(|(node_id, node)| { + if !*node_need_collect.get(node_id).unwrap() { + // No need to send or collect barrier for this node. + None + } else { + let request_id = Uuid::new_v4().to_string(); + let tracing_context = tracing_context.clone(); + async move { + let client = client_pool.get(node).await?; + let request = BarrierCompleteRequest { + request_id, + prev_epoch, + tracing_context, + }; + tracing::debug!( + target: "events::meta::barrier::barrier_complete", + ?request, "barrier complete" + ); + + // This RPC returns only if this worker node has collected this barrier. + client.barrier_complete(request).await + } + .into() + } + }); + + let result = try_join_all(collect_futures) + .await + .inspect_err(|e| { + // Record failure in event log. + use risingwave_pb::meta::event_log; + use thiserror_ext::AsReport; + let event = event_log::EventCollectBarrierFail { + prev_epoch: command_context.prev_epoch.value().0, + cur_epoch: command_context.curr_epoch.value().0, + error: e.to_report_string(), + }; + env.event_log_manager_ref() + .add_event_logs(vec![event_log::Event::CollectBarrierFail(event)]); + }) + .map_err(Into::into); + let _ = barrier_complete_tx + .send(BarrierCompletion { prev_epoch, result }) + .inspect_err(|_| tracing::warn!(prev_epoch, "failed to notify barrier completion")); + } /// Resolve actor information from cluster, fragment manager and `ChangedTableId`. /// We use `changed_table_id` to modify the actors to be sent or collected. Because these actor /// will create or drop before this barrier flow through them. async fn resolve_actor_info( &self, - checkpoint_control: &mut CheckpointControl, - command: &Command, + check_state: impl Fn(ActorState, TableId, ActorId) -> bool, ) -> BarrierActorInfo { - checkpoint_control.pre_resolve(command); - let info = match &self.metadata_manager { MetadataManager::V1(mgr) => { - let check_state = |s: ActorState, table_id: TableId, actor_id: ActorId| { - checkpoint_control.can_actor_send_or_collect(s, table_id, actor_id) - }; let all_nodes = mgr .cluster_manager .list_active_streaming_compute_nodes() .await; - let all_actor_infos = mgr.fragment_manager.load_all_actors(check_state).await; + let all_actor_infos = mgr.fragment_manager.load_all_actors(&check_state).await; BarrierActorInfo::resolve(all_nodes, all_actor_infos) } MetadataManager::V2(mgr) => { let check_state = |s: ActorState, table_id: ObjectId, actor_id: i32| { - checkpoint_control.can_actor_send_or_collect( - s, - TableId::new(table_id as _), - actor_id as _, - ) + check_state(s, TableId::new(table_id as _), actor_id as _) }; let all_nodes = mgr .cluster_controller @@ -1281,8 +1305,6 @@ impl GlobalBarrierManager { } }; - checkpoint_control.post_resolve(command); - info } @@ -1327,7 +1349,7 @@ impl GlobalBarrierManager { } } -pub type BarrierManagerRef = Arc; +pub type BarrierManagerRef = GlobalBarrierManagerContext; fn collect_commit_epoch_info(resps: &mut [BarrierCompleteResponse]) -> CommitEpochInfo { let mut sst_to_worker: HashMap = HashMap::new(); diff --git a/src/meta/src/barrier/progress.rs b/src/meta/src/barrier/progress.rs index ba1e11c9c6fa3..0c753a3c3f025 100644 --- a/src/meta/src/barrier/progress.rs +++ b/src/meta/src/barrier/progress.rs @@ -161,7 +161,7 @@ pub enum TrackingJob { impl TrackingJob { fn metadata_manager(&self) -> &MetadataManager { match self { - TrackingJob::New(command) => &command.context.metadata_manager, + TrackingJob::New(command) => command.context.metadata_manager(), TrackingJob::Recovered(recovered) => &recovered.metadata_manager, } } diff --git a/src/meta/src/barrier/recovery.rs b/src/meta/src/barrier/recovery.rs index 4a5b13c03caee..be9658d57b74e 100644 --- a/src/meta/src/barrier/recovery.rs +++ b/src/meta/src/barrier/recovery.rs @@ -41,14 +41,15 @@ use crate::barrier::command::CommandContext; use crate::barrier::info::BarrierActorInfo; use crate::barrier::notifier::Notifier; use crate::barrier::progress::CreateMviewProgressTracker; -use crate::barrier::{CheckpointControl, Command, GlobalBarrierManager}; +use crate::barrier::schedule::ScheduledBarriers; +use crate::barrier::{CheckpointControl, Command, GlobalBarrierManagerContext}; use crate::controller::catalog::ReleaseContext; use crate::manager::{MetadataManager, WorkerId}; use crate::model::{BarrierManagerState, MetadataModel, MigrationPlan, TableFragments}; use crate::stream::{build_actor_connector_splits, RescheduleOptions, TableResizePolicy}; use crate::MetaResult; -impl GlobalBarrierManager { +impl GlobalBarrierManagerContext { // Retry base interval in milliseconds. const RECOVERY_RETRY_BASE_INTERVAL: u64 = 20; // Retry max interval. @@ -63,10 +64,10 @@ impl GlobalBarrierManager { } async fn resolve_actor_info_for_recovery(&self) -> BarrierActorInfo { - self.resolve_actor_info( - &mut CheckpointControl::new(self.metrics.clone()), - &Command::barrier(), - ) + let default_checkpoint_control = CheckpointControl::new(self.metrics.clone()); + self.resolve_actor_info(|s, table_id, actor_id| { + default_checkpoint_control.can_actor_send_or_collect(s, table_id, actor_id) + }) .await } @@ -326,9 +327,10 @@ impl GlobalBarrierManager { &self, prev_epoch: TracedEpoch, paused_reason: Option, + scheduled_barriers: &ScheduledBarriers, ) -> BarrierManagerState { // Mark blocked and abort buffered schedules, they might be dirty already. - self.scheduled_barriers + scheduled_barriers .abort_and_mark_blocked("cluster is under recovering") .await; @@ -358,8 +360,7 @@ impl GlobalBarrierManager { // some table fragments might have been cleaned as dirty, but it's fine since the drop // interface is idempotent. if let MetadataManager::V1(mgr) = &self.metadata_manager { - let to_drop_tables = - self.scheduled_barriers.pre_apply_drop_scheduled().await; + let to_drop_tables = scheduled_barriers.pre_apply_drop_scheduled().await; mgr.fragment_manager .drop_table_fragments_vec(&to_drop_tables) .await?; @@ -412,17 +413,13 @@ impl GlobalBarrierManager { // Inject the `Initial` barrier to initialize all executors. let command_ctx = Arc::new(CommandContext::new( - self.metadata_manager.clone(), - self.hummock_manager.clone(), - self.env.stream_client_pool_ref(), info, prev_epoch.clone(), new_epoch.clone(), paused_reason, command, BarrierKind::Initial, - self.source_manager.clone(), - self.scale_controller.clone(), + self.clone(), tracing::Span::current(), // recovery span )); @@ -470,7 +467,7 @@ impl GlobalBarrierManager { .expect("Retry until recovery success."); recovery_timer.observe_duration(); - self.scheduled_barriers.mark_ready().await; + scheduled_barriers.mark_ready().await; tracing::info!( epoch = state.in_flight_prev_epoch().value().0, diff --git a/src/meta/src/stream/stream_manager.rs b/src/meta/src/stream/stream_manager.rs index ad3da73d1115f..f10738d3f5cab 100644 --- a/src/meta/src/stream/stream_manager.rs +++ b/src/meta/src/stream/stream_manager.rs @@ -996,7 +996,7 @@ mod tests { let (sink_manager, _) = SinkCoordinatorManager::start_worker(); - let barrier_manager = Arc::new(GlobalBarrierManager::new( + let barrier_manager = GlobalBarrierManager::new( scheduled_barriers, env.clone(), metadata_manager.clone(), @@ -1004,7 +1004,7 @@ mod tests { source_manager.clone(), sink_manager, meta_metrics.clone(), - )); + ); let stream_manager = GlobalStreamManager::new( env.clone(),