From b63d48024a5bbde6d85149420f379bb7474ad354 Mon Sep 17 00:00:00 2001 From: Yufan Song <33971064+yufansong@users.noreply.github.com> Date: Thu, 11 Jan 2024 23:00:04 -0800 Subject: [PATCH 01/71] refactor(frontend): Move some pg wire value into common and make they const (#14519) --- src/common/src/lib.rs | 15 ++++++++++++++- src/common/src/session_config/mod.rs | 11 ++++++----- src/utils/pgwire/src/pg_protocol.rs | 7 ++++--- 3 files changed, 24 insertions(+), 9 deletions(-) diff --git a/src/common/src/lib.rs b/src/common/src/lib.rs index 21b3f393f7c25..20428599b1039 100644 --- a/src/common/src/lib.rs +++ b/src/common/src/lib.rs @@ -91,6 +91,16 @@ pub const RW_VERSION: &str = env!("CARGO_PKG_VERSION"); /// Placeholder for unknown git sha. pub const UNKNOWN_GIT_SHA: &str = "unknown"; +// The single source of truth of the pg parameters, Used in ConfigMap and current_cluster_version. +// The version of PostgreSQL that Risingwave claims to be. +pub const PG_VERSION: &str = "9.5.0"; +/// The version of PostgreSQL that Risingwave claims to be. +pub const SERVER_VERSION_NUM: i32 = 90500; +/// Shows the server-side character set encoding. At present, this parameter can be shown but not set, because the encoding is determined at database creation time. It is also the default value of `client_encoding`. +pub const SERVER_ENCODING: &str = "UTF8"; +/// see +pub const STANDARD_CONFORMING_STRINGS: &str = "on"; + #[macro_export] macro_rules! git_sha { ($env:literal) => { @@ -107,5 +117,8 @@ macro_rules! git_sha { pub const GIT_SHA: &str = git_sha!("GIT_SHA"); pub fn current_cluster_version() -> String { - format!("PostgreSQL 9.5-RisingWave-{} ({})", RW_VERSION, GIT_SHA) + format!( + "PostgreSQL {}-RisingWave-{} ({})", + PG_VERSION, RW_VERSION, GIT_SHA + ) } diff --git a/src/common/src/session_config/mod.rs b/src/common/src/session_config/mod.rs index 40315da85252f..daef5faf1e240 100644 --- a/src/common/src/session_config/mod.rs +++ b/src/common/src/session_config/mod.rs @@ -31,6 +31,7 @@ use self::non_zero64::ConfigNonZeroU64; use crate::session_config::sink_decouple::SinkDecouple; use crate::session_config::transaction_isolation_level::IsolationLevel; pub use crate::session_config::visibility_mode::VisibilityMode; +use crate::{PG_VERSION, SERVER_ENCODING, SERVER_VERSION_NUM, STANDARD_CONFORMING_STRINGS}; pub const SESSION_CONFIG_LIST_SEP: &str = ", "; @@ -175,11 +176,11 @@ pub struct ConfigMap { batch_parallelism: ConfigNonZeroU64, /// The version of PostgreSQL that Risingwave claims to be. - #[parameter(default = "9.5.0")] + #[parameter(default = PG_VERSION)] server_version: String, /// The version of PostgreSQL that Risingwave claims to be. - #[parameter(default = 90500)] + #[parameter(default = SERVER_VERSION_NUM)] server_version_num: i32, /// see @@ -187,7 +188,7 @@ pub struct ConfigMap { client_min_messages: String, /// see - #[parameter(default = "UTF8", check_hook = check_client_encoding)] + #[parameter(default = SERVER_ENCODING, check_hook = check_client_encoding)] client_encoding: String, /// Enable decoupling sink and internal streaming graph or not @@ -217,7 +218,7 @@ pub struct ConfigMap { row_security: bool, /// see - #[parameter(default = "on")] + #[parameter(default = STANDARD_CONFORMING_STRINGS)] standard_conforming_strings: String, /// Set streaming rate limit (rows per second) for each parallelism for mv backfilling @@ -234,7 +235,7 @@ pub struct ConfigMap { background_ddl: bool, /// Shows the server-side character set encoding. 
At present, this parameter can be shown but not set, because the encoding is determined at database creation time. - #[parameter(default = "UTF8")] + #[parameter(default = SERVER_ENCODING)] server_encoding: String, #[parameter(default = "hex", check_hook = check_bytea_output)] diff --git a/src/utils/pgwire/src/pg_protocol.rs b/src/utils/pgwire/src/pg_protocol.rs index 8f0d1fc5a5498..2f7c3572ee80a 100644 --- a/src/utils/pgwire/src/pg_protocol.rs +++ b/src/utils/pgwire/src/pg_protocol.rs @@ -28,6 +28,7 @@ use openssl::ssl::{SslAcceptor, SslContext, SslContextRef, SslMethod}; use risingwave_common::types::DataType; use risingwave_common::util::panic::FutureCatchUnwindExt; use risingwave_common::util::query_log::*; +use risingwave_common::{PG_VERSION, SERVER_ENCODING, STANDARD_CONFORMING_STRINGS}; use risingwave_sqlparser::ast::Statement; use risingwave_sqlparser::parser::Parser; use thiserror_ext::AsReport; @@ -973,13 +974,13 @@ where fn write_parameter_status_msg_no_flush(&mut self, status: &ParameterStatus) -> io::Result<()> { self.write_no_flush(&BeMessage::ParameterStatus( - BeParameterStatusMessage::ClientEncoding("UTF8"), + BeParameterStatusMessage::ClientEncoding(SERVER_ENCODING), ))?; self.write_no_flush(&BeMessage::ParameterStatus( - BeParameterStatusMessage::StandardConformingString("on"), + BeParameterStatusMessage::StandardConformingString(STANDARD_CONFORMING_STRINGS), ))?; self.write_no_flush(&BeMessage::ParameterStatus( - BeParameterStatusMessage::ServerVersion("9.5.0"), + BeParameterStatusMessage::ServerVersion(PG_VERSION), ))?; if let Some(application_name) = &status.application_name { self.write_no_flush(&BeMessage::ParameterStatus( From 9cd7f6486e659af72857aef7ac27247305e880c0 Mon Sep 17 00:00:00 2001 From: Yuhao Su <31772373+yuhao-su@users.noreply.github.com> Date: Fri, 12 Jan 2024 15:21:57 +0800 Subject: [PATCH 02/71] refactor(streaming): improve hash join error message (#14515) --- src/storage/src/table/mod.rs | 4 ++ src/stream/src/executor/hash_join.rs | 45 +++++++------ .../managed_state/join/join_entry_state.rs | 26 ++++++-- .../src/executor/managed_state/join/mod.rs | 66 +++++++++++++++---- 4 files changed, 104 insertions(+), 37 deletions(-) diff --git a/src/storage/src/table/mod.rs b/src/storage/src/table/mod.rs index 884915bdbbb3f..c43256503fd33 100644 --- a/src/storage/src/table/mod.rs +++ b/src/storage/src/table/mod.rs @@ -237,6 +237,10 @@ impl> KeyedRow { self.vnode_prefixed_key.key_part() } + pub fn row(&self) -> &OwnedRow { + &self.row + } + pub fn into_parts(self) -> (TableKey, OwnedRow) { (self.vnode_prefixed_key, self.row) } diff --git a/src/stream/src/executor/hash_join.rs b/src/stream/src/executor/hash_join.rs index 838d4cf5e1fda..b3ad2f2471cd6 100644 --- a/src/stream/src/executor/hash_join.rs +++ b/src/stream/src/executor/hash_join.rs @@ -480,29 +480,32 @@ impl HashJoinExecutor HashJoinExecutor HashJoinExecutor HashJoinExecutor HashJoinExecutor HashJoinExecutor HashJoinExecutor Result<&mut StateValueType, JoinEntryError> { self.kv_heap_size.add(&key, &value); - self.cached.try_insert(key, value).unwrap(); + self.cached + .try_insert(key, value) + .map_err(|_| JoinEntryError::OccupiedError) } /// Delete from the cache. 
- pub fn remove(&mut self, pk: PkType) { + pub fn remove(&mut self, pk: PkType) -> Result<(), JoinEntryError> { if let Some(value) = self.cached.remove(&pk) { self.kv_heap_size.sub(&pk, &value); + Ok(()) } else { - panic!("pk {:?} should be in the cache", pk); + Err(JoinEntryError::RemoveError) } } @@ -98,7 +114,7 @@ mod tests { // Pk is only a `i64` here, so encoding method does not matter. let pk = OwnedRow::new(pk).project(&value_indices).value_serialize(); let join_row = JoinRow { row, degree: 0 }; - managed_state.insert(pk, join_row.encode()); + managed_state.insert(pk, join_row.encode()).unwrap(); } } diff --git a/src/stream/src/executor/managed_state/join/mod.rs b/src/stream/src/executor/managed_state/join/mod.rs index 90839df2e3792..ff20e5346fc34 100644 --- a/src/stream/src/executor/managed_state/join/mod.rs +++ b/src/stream/src/executor/managed_state/join/mod.rs @@ -19,6 +19,7 @@ use std::alloc::Global; use std::ops::{Bound, Deref, DerefMut}; use std::sync::Arc; +use anyhow::Context; use futures::future::try_join; use futures::StreamExt; use futures_async_stream::for_await; @@ -267,7 +268,10 @@ pub struct JoinHashMap { } struct TableInner { + /// Indices of the (cache) pk in a state row pk_indices: Vec, + /// Indices of the join key in a state row + join_key_indices: Vec, // This should be identical to the pk in state table. order_key_indices: Vec, // This should be identical to the data types in table schema. @@ -276,15 +280,31 @@ struct TableInner { pub(crate) table: StateTable, } +impl TableInner { + fn error_context(&self, row: &impl Row) -> String { + let pk = row.project(&self.pk_indices); + let jk = row.project(&self.join_key_indices); + format!( + "join key: {}, pk: {}, row: {}, state_table_id: {}", + jk.display(), + pk.display(), + row.display(), + self.table.table_id() + ) + } +} + impl JoinHashMap { /// Create a [`JoinHashMap`] with the given LRU capacity. 
#[allow(clippy::too_many_arguments)] pub fn new( watermark_epoch: AtomicU64Ref, join_key_data_types: Vec, + state_join_key_indices: Vec, state_all_data_types: Vec, state_table: StateTable, state_pk_indices: Vec, + degree_join_key_indices: Vec, degree_all_data_types: Vec, degree_table: StateTable, degree_pk_indices: Vec, @@ -311,6 +331,7 @@ impl JoinHashMap { let degree_table_id = degree_table.table_id(); let state = TableInner { pk_indices: state_pk_indices, + join_key_indices: state_join_key_indices, order_key_indices: state_table.pk_indices().to_vec(), all_data_types: state_all_data_types, table: state_table, @@ -318,6 +339,7 @@ impl JoinHashMap { let degree_state = TableInner { pk_indices: degree_pk_indices, + join_key_indices: degree_join_key_indices, order_key_indices: degree_table.pk_indices().to_vec(), all_data_types: degree_all_data_types, table: degree_table, @@ -445,10 +467,12 @@ impl JoinHashMap { let degree_i64 = degree_row .datum_at(degree_row.len() - 1) .expect("degree should not be NULL"); - entry_state.insert( - pk, - JoinRow::new(row.into_owned_row(), degree_i64.into_int64() as u64).encode(), - ); + entry_state + .insert( + pk, + JoinRow::new(row.row(), degree_i64.into_int64() as u64).encode(), + ) + .with_context(|| self.state.error_context(row.row()))?; } } else { let sub_range: &(Bound, Bound) = @@ -466,7 +490,9 @@ impl JoinHashMap { .as_ref() .project(&self.state.pk_indices) .memcmp_serialize(&self.pk_serializer); - entry_state.insert(pk, JoinRow::new(row.into_owned_row(), 0).encode()); + entry_state + .insert(pk, JoinRow::new(row.row(), 0).encode()) + .with_context(|| self.state.error_context(row.row()))?; } }; @@ -498,12 +524,16 @@ impl JoinHashMap { if self.inner.contains(key) { // Update cache let mut entry = self.inner.get_mut(key).unwrap(); - entry.insert(pk, value.encode()); + entry + .insert(pk, value.encode()) + .with_context(|| self.state.error_context(&value.row))?; } else if self.pk_contained_in_jk { // Refill cache when the join key exist in neither cache or storage. self.metrics.insert_cache_miss_count += 1; let mut state = JoinEntryState::default(); - state.insert(pk, value.encode()); + state + .insert(pk, value.encode()) + .with_context(|| self.state.error_context(&value.row))?; self.update_state(key, state.into()); } @@ -528,12 +558,16 @@ impl JoinHashMap { if self.inner.contains(key) { // Update cache let mut entry = self.inner.get_mut(key).unwrap(); - entry.insert(pk, join_row.encode()); + entry + .insert(pk, join_row.encode()) + .with_context(|| self.state.error_context(&value))?; } else if self.pk_contained_in_jk { // Refill cache when the join key exist in neither cache or storage. self.metrics.insert_cache_miss_count += 1; let mut state = JoinEntryState::default(); - state.insert(pk, join_row.encode()); + state + .insert(pk, join_row.encode()) + .with_context(|| self.state.error_context(&value))?; self.update_state(key, state.into()); } @@ -543,32 +577,38 @@ impl JoinHashMap { } /// Delete a join row - pub fn delete(&mut self, key: &K, value: JoinRow) { + pub fn delete(&mut self, key: &K, value: JoinRow) -> StreamExecutorResult<()> { if let Some(mut entry) = self.inner.get_mut(key) { let pk = (&value.row) .project(&self.state.pk_indices) .memcmp_serialize(&self.pk_serializer); - entry.remove(pk); + entry + .remove(pk) + .with_context(|| self.state.error_context(&value.row))?; } // If no cache maintained, only update the state table. 
let (row, degree) = value.to_table_rows(&self.state.order_key_indices); self.state.table.delete(row); self.degree_state.table.delete(degree); + Ok(()) } /// Delete a row /// Used when the side does not need to update degree. - pub fn delete_row(&mut self, key: &K, value: impl Row) { + pub fn delete_row(&mut self, key: &K, value: impl Row) -> StreamExecutorResult<()> { if let Some(mut entry) = self.inner.get_mut(key) { let pk = (&value) .project(&self.state.pk_indices) .memcmp_serialize(&self.pk_serializer); - entry.remove(pk); + entry + .remove(pk) + .with_context(|| self.state.error_context(&value))?; } // If no cache maintained, only update the state table. self.state.table.delete(value); + Ok(()) } /// Update a [`JoinEntryState`] into the hash table. From 681c22686fabb06e7597c95402afcafcc25ba18e Mon Sep 17 00:00:00 2001 From: xxchan Date: Fri, 12 Jan 2024 15:44:10 +0800 Subject: [PATCH 03/71] refactor(meta): refactor how upstream fragment is handled when creating stream job (#14510) --- proto/ddl_service.proto | 9 ++++- proto/stream_plan.proto | 3 +- src/meta/src/controller/fragment.rs | 24 ++++++------- src/meta/src/manager/catalog/fragment.rs | 21 ++++------- src/meta/src/manager/metadata.rs | 19 +++++++--- src/meta/src/manager/streaming_job.rs | 33 +++++++++++------ src/meta/src/rpc/ddl_controller.rs | 12 +++---- src/meta/src/stream/stream_graph/fragment.rs | 38 ++++++++++++-------- 8 files changed, 94 insertions(+), 65 deletions(-) diff --git a/proto/ddl_service.proto b/proto/ddl_service.proto index db910930b5bee..1b584a7df78e1 100644 --- a/proto/ddl_service.proto +++ b/proto/ddl_service.proto @@ -137,10 +137,17 @@ message DropViewResponse { uint64 version = 2; } -// An enum to distinguish different types of the Table streaming job. +// An enum to distinguish different types of the `Table` streaming job. // - GENERAL: Table streaming jobs w/ or w/o a connector // - SHARED_CDC_SOURCE: The table streaming job is created based on a shared CDC source job (risingwavelabs/rfcs#73). +// // And one may add other types to support Table jobs that based on other backfill-able sources (risingwavelabs/rfcs#72). +// +// Currently, it's usages include: +// - When creating the streaming actor graph, different table jobs may need different treatment. +// - Some adhoc validation when creating the streaming job. e.g., `validate_cdc_table`. +// +// It's not included in `catalog.Table`, and thus not persisted. It's only used in the `CreateTableRequest`. enum TableJobType { TABLE_JOB_TYPE_UNSPECIFIED = 0; // table streaming jobs excepts the `SHARED_CDC_SOURCE` type diff --git a/proto/stream_plan.proto b/proto/stream_plan.proto index a168ea163f5b5..e69a712c9e3d8 100644 --- a/proto/stream_plan.proto +++ b/proto/stream_plan.proto @@ -839,6 +839,7 @@ message StreamActor { plan_common.ExprContext expr_context = 10; } +// Indicates whether the fragment contains some special kind of nodes. enum FragmentTypeFlag { FRAGMENT_TYPE_FLAG_FRAGMENT_UNSPECIFIED = 0; FRAGMENT_TYPE_FLAG_SOURCE = 1; @@ -864,7 +865,7 @@ message StreamFragmentGraph { uint32 fragment_id = 1; // root stream node in this fragment. StreamNode node = 2; - // Bitwise-OR of FragmentTypeFlags + // Bitwise-OR of `FragmentTypeFlag`s uint32 fragment_type_mask = 3; // Mark whether this fragment requires exactly one actor. // Note: if this is `false`, the fragment may still be a singleton according to the scheduler. 
diff --git a/src/meta/src/controller/fragment.rs b/src/meta/src/controller/fragment.rs index d0c39694692b1..bec61ff0f2166 100644 --- a/src/meta/src/controller/fragment.rs +++ b/src/meta/src/controller/fragment.rs @@ -27,7 +27,6 @@ use risingwave_meta_model_v2::{ StreamNode, TableId, VnodeBitmap, WorkerId, }; use risingwave_pb::common::PbParallelUnit; -use risingwave_pb::ddl_service::PbTableJobType; use risingwave_pb::meta::subscribe_response::{ Info as NotificationInfo, Operation as NotificationOperation, }; @@ -1037,31 +1036,30 @@ impl CatalogController { Ok(actors) } - /// Get and filter the upstream `Materialize` or `Source` fragments of the specified relations. pub async fn get_upstream_root_fragments( &self, upstream_job_ids: Vec, - job_type: Option, ) -> MetaResult> { let inner = self.inner.read().await; - let mut fragments = Fragment::find() + let all_upstream_fragments = Fragment::find() .filter(fragment::Column::JobId.is_in(upstream_job_ids)) .all(&inner.db) .await?; - fragments.retain(|f| match job_type { - Some(PbTableJobType::SharedCdcSource) => { - f.fragment_type_mask & PbFragmentTypeFlag::Source as i32 != 0 - } - // MV on MV, and other kinds of table job - None | Some(PbTableJobType::General) | Some(PbTableJobType::Unspecified) => { - f.fragment_type_mask & PbFragmentTypeFlag::Mview as i32 != 0 + // job_id -> fragment + let mut fragments = HashMap::::new(); + for fragment in all_upstream_fragments { + if fragment.fragment_type_mask & PbFragmentTypeFlag::Mview as i32 != 0 { + _ = fragments.insert(fragment.job_id, fragment); + } else if fragment.fragment_type_mask & PbFragmentTypeFlag::Source as i32 != 0 { + // look for Source fragment if there's no MView fragment + _ = fragments.try_insert(fragment.job_id, fragment); } - }); + } let parallel_units_map = get_parallel_unit_mapping(&inner.db).await?; let mut root_fragments = HashMap::new(); - for fragment in fragments { + for (_, fragment) in fragments { let actors = fragment.find_related(Actor).all(&inner.db).await?; let actor_dispatchers = get_actor_dispatchers( &inner.db, diff --git a/src/meta/src/manager/catalog/fragment.rs b/src/meta/src/manager/catalog/fragment.rs index d359c3fa453c9..89ea2de7148a7 100644 --- a/src/meta/src/manager/catalog/fragment.rs +++ b/src/meta/src/manager/catalog/fragment.rs @@ -24,7 +24,6 @@ use risingwave_common::hash::{ActorMapping, ParallelUnitId, ParallelUnitMapping} use risingwave_common::util::stream_graph_visitor::{visit_stream_node, visit_stream_node_cont}; use risingwave_connector::source::SplitImpl; use risingwave_meta_model_v2::SourceId; -use risingwave_pb::ddl_service::TableJobType; use risingwave_pb::meta::subscribe_response::{Info, Operation}; use risingwave_pb::meta::table_fragments::actor_status::ActorState; use risingwave_pb::meta::table_fragments::{ActorStatus, Fragment, State}; @@ -1398,11 +1397,9 @@ impl FragmentManager { .mview_actor_ids()) } - /// Get and filter the upstream `Materialize` or `Source` fragments of the specified relations. 
pub async fn get_upstream_root_fragments( &self, upstream_table_ids: &HashSet, - table_job_type: Option, ) -> MetaResult> { let map = &self.core.read().await.table_fragments; let mut fragments = HashMap::new(); @@ -1411,18 +1408,12 @@ impl FragmentManager { let table_fragments = map .get(&table_id) .with_context(|| format!("table_fragment not exist: id={}", table_id))?; - match table_job_type.as_ref() { - Some(TableJobType::SharedCdcSource) => { - if let Some(fragment) = table_fragments.source_fragment() { - fragments.insert(table_id, fragment); - } - } - // MV on MV, and other kinds of table job - None | Some(TableJobType::General) | Some(TableJobType::Unspecified) => { - if let Some(fragment) = table_fragments.mview_fragment() { - fragments.insert(table_id, fragment); - } - } + + if let Some(fragment) = table_fragments.mview_fragment() { + fragments.insert(table_id, fragment); + } else if let Some(fragment) = table_fragments.source_fragment() { + // look for Source fragment if there's no MView fragment + fragments.insert(table_id, fragment); } } diff --git a/src/meta/src/manager/metadata.rs b/src/meta/src/manager/metadata.rs index 0d50f7e1dc8c4..450a920c379d0 100644 --- a/src/meta/src/manager/metadata.rs +++ b/src/meta/src/manager/metadata.rs @@ -19,7 +19,6 @@ use risingwave_meta_model_v2::SourceId; use risingwave_pb::catalog::PbSource; use risingwave_pb::common::worker_node::{PbResource, State}; use risingwave_pb::common::{HostAddress, PbWorkerNode, PbWorkerType, WorkerType}; -use risingwave_pb::ddl_service::TableJobType; use risingwave_pb::meta::add_worker_node_request::Property as AddNodeProperty; use risingwave_pb::meta::table_fragments::Fragment; use risingwave_pb::stream_plan::PbStreamActor; @@ -175,15 +174,28 @@ impl MetadataManager { } } + /// Get and filter the "**root**" fragments of the specified relations. + /// The root fragment is the bottom-most fragment of its fragment graph, and can be a `MView` or a `Source`. + /// + /// ## What can be the root fragment + /// - For MV, it should have one `MView` fragment. + /// - For table, it should have one `MView` fragment and one or two `Source` fragments. `MView` should be the root. + /// - For source, it should have one `Source` fragment. + /// + /// In other words, it's the `MView` fragment if it exists, otherwise it's the `Source` fragment. + /// + /// ## What do we expect to get for different creating streaming job + /// - MV/Sink/Index should have MV upstream fragments for upstream MV/Tables, and Source upstream fragments for upstream backfill-able sources. + /// - CDC Table has a Source upstream fragment. + /// - Sources and other Tables shouldn't have an upstream fragment. pub async fn get_upstream_root_fragments( &self, upstream_table_ids: &HashSet, - table_job_type: Option, ) -> MetaResult> { match self { MetadataManager::V1(mgr) => { mgr.fragment_manager - .get_upstream_root_fragments(upstream_table_ids, table_job_type) + .get_upstream_root_fragments(upstream_table_ids) .await } MetadataManager::V2(mgr) => { @@ -194,7 +206,6 @@ impl MetadataManager { .iter() .map(|id| id.table_id as _) .collect(), - table_job_type, ) .await?; Ok(upstream_root_fragments diff --git a/src/meta/src/manager/streaming_job.rs b/src/meta/src/manager/streaming_job.rs index b5d63256ccb6e..54dfdb9a22abb 100644 --- a/src/meta/src/manager/streaming_job.rs +++ b/src/meta/src/manager/streaming_job.rs @@ -26,8 +26,6 @@ use crate::model::FragmentId; // This enum is used in order to re-use code in `DdlServiceImpl` for creating MaterializedView and // Sink. 
#[derive(Debug, Clone, EnumDiscriminants)] -#[strum_discriminants(name(DdlType))] -#[strum_discriminants(vis(pub))] pub enum StreamingJob { MaterializedView(Table), Sink(Sink, Option<(Table, Option)>), @@ -36,13 +34,34 @@ pub enum StreamingJob { Source(PbSource), } +#[derive(Debug, Clone, Copy, PartialEq)] +pub enum DdlType { + MaterializedView, + Sink, + Table(TableJobType), + Index, + Source, +} + +impl From<&StreamingJob> for DdlType { + fn from(job: &StreamingJob) -> Self { + match job { + StreamingJob::MaterializedView(_) => DdlType::MaterializedView, + StreamingJob::Sink(_, _) => DdlType::Sink, + StreamingJob::Table(_, _, ty) => DdlType::Table(*ty), + StreamingJob::Index(_, _) => DdlType::Index, + StreamingJob::Source(_) => DdlType::Source, + } + } +} + #[cfg(test)] #[allow(clippy::derivable_impls)] impl Default for DdlType { fn default() -> Self { // This should not be used by mock services, // so we can just pick an arbitrary default variant. - DdlType::Table + DdlType::MaterializedView } } @@ -259,14 +278,6 @@ impl StreamingJob { } } - pub fn table_job_type(&self) -> Option { - if let Self::Table(.., sub_type) = self { - Some(*sub_type) - } else { - None - } - } - // TODO: record all objects instead. pub fn dependent_relations(&self) -> Vec { match self { diff --git a/src/meta/src/rpc/ddl_controller.rs b/src/meta/src/rpc/ddl_controller.rs index 3937944e66d9b..1ab3aad2ddf3d 100644 --- a/src/meta/src/rpc/ddl_controller.rs +++ b/src/meta/src/rpc/ddl_controller.rs @@ -1259,6 +1259,7 @@ impl DdlController { } /// Builds the actor graph: + /// - Add the upstream fragments to the fragment graph /// - Schedule the fragments based on their distribution /// - Expand each fragment into one or several actors pub(crate) async fn build_stream_job( @@ -1278,10 +1279,7 @@ impl DdlController { let upstream_root_fragments = self .metadata_manager - .get_upstream_root_fragments( - fragment_graph.dependent_table_ids(), - stream_job.table_job_type(), - ) + .get_upstream_root_fragments(fragment_graph.dependent_table_ids()) .await?; let upstream_actors: HashMap<_, _> = upstream_root_fragments @@ -1297,7 +1295,7 @@ impl DdlController { let complete_graph = CompleteStreamFragmentGraph::with_upstreams( fragment_graph, upstream_root_fragments, - stream_job.table_job_type(), + stream_job.into(), )?; // 2. Build the actor graph. @@ -1717,6 +1715,7 @@ impl DdlController { fragment_graph, original_table_fragment.fragment_id, downstream_fragments, + stream_job.into(), )?; // 2. Build the actor graph. @@ -1979,7 +1978,8 @@ impl DdlController { } } -/// Fill in necessary information for table stream graph. +/// Fill in necessary information for `Table` stream graph. +/// e.g., fill source id for table with connector, fill external table id for CDC table. 
pub fn fill_table_stream_graph_info( source: &mut Option, table: &mut PbTable, diff --git a/src/meta/src/stream/stream_graph/fragment.rs b/src/meta/src/stream/stream_graph/fragment.rs index 4edd3743a1cee..925b01c8cbdcf 100644 --- a/src/meta/src/stream/stream_graph/fragment.rs +++ b/src/meta/src/stream/stream_graph/fragment.rs @@ -38,7 +38,7 @@ use risingwave_pb::stream_plan::{ StreamFragmentGraph as StreamFragmentGraphProto, StreamNode, StreamScanType, }; -use crate::manager::{MetaSrvEnv, StreamingJob}; +use crate::manager::{DdlType, MetaSrvEnv, StreamingJob}; use crate::model::FragmentId; use crate::stream::stream_graph::id::{GlobalFragmentId, GlobalFragmentIdGen, GlobalTableIdGen}; use crate::stream::stream_graph::schedule::Distribution; @@ -283,6 +283,10 @@ impl StreamFragmentEdge { /// In-memory representation of a **Fragment** Graph, built from the [`StreamFragmentGraphProto`] /// from the frontend. +/// +/// This only includes nodes and edges of the current job itself. It will be converted to [`CompleteStreamFragmentGraph`] later, +/// that contains the additional information of pre-existing +/// fragments, which are connected to the graph's top-most or bottom-most fragments. #[derive(Default)] pub struct StreamFragmentGraph { /// stores all the fragments in the graph. @@ -514,8 +518,8 @@ pub(super) enum EitherFragment { Existing(Fragment), } -/// A wrapper of [`StreamFragmentGraph`] that contains the additional information of existing -/// fragments, which is connected to the graph's top-most or bottom-most fragments. +/// A wrapper of [`StreamFragmentGraph`] that contains the additional information of pre-existing +/// fragments, which are connected to the graph's top-most or bottom-most fragments. /// /// For example, /// - if we're going to build a mview on an existing mview, the upstream fragment containing the @@ -560,12 +564,12 @@ impl CompleteStreamFragmentGraph { } } - /// Create a new [`CompleteStreamFragmentGraph`] for MV on MV or Table on CDC Source, with the upstream existing + /// Create a new [`CompleteStreamFragmentGraph`] for MV on MV and CDC/Source Table with the upstream existing /// `Materialize` or `Source` fragments. pub fn with_upstreams( graph: StreamFragmentGraph, upstream_root_fragments: HashMap, - table_job_type: Option, + ddl_type: DdlType, ) -> MetaResult { Self::build_helper( graph, @@ -573,7 +577,7 @@ impl CompleteStreamFragmentGraph { upstream_root_fragments, }), None, - table_job_type, + ddl_type, ) } @@ -583,6 +587,7 @@ impl CompleteStreamFragmentGraph { graph: StreamFragmentGraph, original_table_fragment_id: FragmentId, downstream_fragments: Vec<(DispatchStrategy, Fragment)>, + ddl_type: DdlType, ) -> MetaResult { Self::build_helper( graph, @@ -591,15 +596,16 @@ impl CompleteStreamFragmentGraph { original_table_fragment_id, downstream_fragments, }), - None, + ddl_type, ) } + /// The core logic of building a [`CompleteStreamFragmentGraph`], i.e., adding extra upstream/downstream fragments. fn build_helper( mut graph: StreamFragmentGraph, upstream_ctx: Option, downstream_ctx: Option, - table_job_type: Option, + ddl_type: DdlType, ) -> MetaResult { let mut extra_downstreams = HashMap::new(); let mut extra_upstreams = HashMap::new(); @@ -609,13 +615,11 @@ impl CompleteStreamFragmentGraph { upstream_root_fragments, }) = upstream_ctx { - // Build the extra edges between the upstream `Materialize` and the downstream `StreamScan` - // of the new materialized view. 
for (&id, fragment) in &mut graph.fragments { let uses_arrangement_backfill = fragment.has_arrangement_backfill(); for (&upstream_table_id, output_columns) in &fragment.upstream_table_columns { - let (up_fragment_id, edge) = match table_job_type.as_ref() { - Some(TableJobType::SharedCdcSource) => { + let (up_fragment_id, edge) = match ddl_type { + DdlType::Table(TableJobType::SharedCdcSource) => { let source_fragment = upstream_root_fragments .get(&upstream_table_id) .context("upstream source fragment not found")?; @@ -651,8 +655,11 @@ impl CompleteStreamFragmentGraph { (source_job_id, edge) } - _ => { - // handle other kinds of streaming graph, normally MV on MV + DdlType::MaterializedView | DdlType::Sink | DdlType::Index => { + // handle MV on MV + + // Build the extra edges between the upstream `Materialize` and the downstream `StreamScan` + // of the new materialized view. let mview_fragment = upstream_root_fragments .get(&upstream_table_id) .context("upstream materialized view fragment not found")?; @@ -724,6 +731,9 @@ impl CompleteStreamFragmentGraph { (mview_id, edge) } + DdlType::Source | DdlType::Table(_) => { + bail!("the streaming job shouldn't have an upstream fragment, ddl_type: {:?}", ddl_type) + } }; // put the edge into the extra edges From 3dc4d08892c34dc2352486a7975cfda936a45c6c Mon Sep 17 00:00:00 2001 From: lmatz Date: Fri, 12 Jan 2024 15:57:28 +0800 Subject: [PATCH 04/71] chore: upgrade localstack in integration tests to `3.0` (#14529) --- integration_tests/kinesis-s3-source/docker-compose.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/integration_tests/kinesis-s3-source/docker-compose.yml b/integration_tests/kinesis-s3-source/docker-compose.yml index 9deae94041992..9108537309bd7 100644 --- a/integration_tests/kinesis-s3-source/docker-compose.yml +++ b/integration_tests/kinesis-s3-source/docker-compose.yml @@ -23,7 +23,7 @@ services: service: prometheus-0 localstack: container_name: localstack - image: localstack/localstack:2.2 + image: localstack/localstack:3.0 networks: default: aliases: From 69d8b6ecf1b98133361c88946f3687c7c95bdfda Mon Sep 17 00:00:00 2001 From: Richard Chien Date: Fri, 12 Jan 2024 16:03:21 +0800 Subject: [PATCH 05/71] refactor(over window): add some short-circuit logic in `find_affected_ranges` and some other refactoring (#14385) Signed-off-by: Richard Chien --- src/expr/core/src/window_function/call.rs | 48 +++--- .../core/src/window_function/state/buffer.rs | 15 +- src/frontend/src/binder/expr/function.rs | 4 +- .../src/executor/over_window/general.rs | 7 +- .../executor/over_window/over_partition.rs | 146 ++++++++++-------- 5 files changed, 127 insertions(+), 93 deletions(-) diff --git a/src/expr/core/src/window_function/call.rs b/src/expr/core/src/window_function/call.rs index 663479584ddd1..1bb4dfa85f2bb 100644 --- a/src/expr/core/src/window_function/call.rs +++ b/src/expr/core/src/window_function/call.rs @@ -63,7 +63,7 @@ impl Display for Frame { impl Frame { pub fn rows(start: FrameBound, end: FrameBound) -> Self { Self { - bounds: FrameBounds::Rows(start, end), + bounds: FrameBounds::Rows(RowsFrameBounds { start, end }), exclusion: FrameExclusion::default(), } } @@ -74,14 +74,10 @@ impl Frame { exclusion: FrameExclusion, ) -> Self { Self { - bounds: FrameBounds::Rows(start, end), + bounds: FrameBounds::Rows(RowsFrameBounds { start, end }), exclusion, } } - - pub fn is_unbounded(&self) -> bool { - self.bounds.is_unbounded() - } } impl Frame { @@ -92,7 +88,7 @@ impl Frame { PbType::Rows => { let start = 
FrameBound::from_protobuf(frame.get_start()?)?; let end = FrameBound::from_protobuf(frame.get_end()?)?; - FrameBounds::Rows(start, end) + FrameBounds::Rows(RowsFrameBounds { start, end }) } }; let exclusion = FrameExclusion::from_protobuf(frame.get_exclusion()?)?; @@ -103,7 +99,7 @@ impl Frame { use risingwave_pb::expr::window_frame::PbType; let exclusion = self.exclusion.to_protobuf() as _; match &self.bounds { - FrameBounds::Rows(start, end) => PbWindowFrame { + FrameBounds::Rows(RowsFrameBounds { start, end }) => PbWindowFrame { r#type: PbType::Rows as _, start: Some(start.to_protobuf()), end: Some(end.to_protobuf()), @@ -116,19 +112,19 @@ impl Frame { impl FrameBounds { pub fn validate(&self) -> Result<()> { match self { - Self::Rows(start, end) => FrameBound::validate_bounds(start, end), + Self::Rows(bounds) => bounds.validate(), } } pub fn start_is_unbounded(&self) -> bool { match self { - Self::Rows(start, _) => matches!(start, FrameBound::UnboundedPreceding), + Self::Rows(RowsFrameBounds { start, .. }) => start.is_unbounded_preceding(), } } pub fn end_is_unbounded(&self) -> bool { match self { - Self::Rows(_, end) => matches!(end, FrameBound::UnboundedFollowing), + Self::Rows(RowsFrameBounds { end, .. }) => end.is_unbounded_following(), } } @@ -140,22 +136,38 @@ impl FrameBounds { impl Display for FrameBounds { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { - Self::Rows(start, end) => { - write!(f, "ROWS BETWEEN {} AND {}", start, end)?; - } + Self::Rows(bounds) => bounds.fmt(f), } - Ok(()) } } #[derive(Debug, Clone, Eq, PartialEq, Hash)] pub enum FrameBounds { - Rows(FrameBound, FrameBound), - // Groups(FrameBound, FrameBound), - // Range(FrameBound, FrameBound), + Rows(RowsFrameBounds), + // Groups(GroupsFrameBounds), + // Range(RangeFrameBounds), } #[derive(Debug, Clone, Eq, PartialEq, Hash)] +pub struct RowsFrameBounds { + pub start: FrameBound, + pub end: FrameBound, +} + +impl Display for RowsFrameBounds { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "ROWS BETWEEN {} AND {}", self.start, self.end)?; + Ok(()) + } +} + +impl RowsFrameBounds { + fn validate(&self) -> Result<()> { + FrameBound::validate_bounds(&self.start, &self.end) + } +} + +#[derive(Debug, Clone, Eq, PartialEq, Hash, EnumAsInner)] pub enum FrameBound { UnboundedPreceding, Preceding(T), diff --git a/src/expr/core/src/window_function/state/buffer.rs b/src/expr/core/src/window_function/state/buffer.rs index 54227084f20f5..3edb6d7adc164 100644 --- a/src/expr/core/src/window_function/state/buffer.rs +++ b/src/expr/core/src/window_function/state/buffer.rs @@ -26,7 +26,6 @@ struct Entry { value: V, } -// TODO(rc): May be a good idea to extract this into a separate crate. /// A common sliding window buffer. 
pub struct WindowBuffer { frame: Frame, @@ -68,8 +67,8 @@ impl WindowBuffer { fn preceding_saturated(&self) -> bool { self.curr_key().is_some() && match &self.frame.bounds { - FrameBounds::Rows(start, _) => { - let start_off = start.to_offset(); + FrameBounds::Rows(bounds) => { + let start_off = bounds.start.to_offset(); if let Some(start_off) = start_off { if start_off >= 0 { true // pure following frame, always preceding-saturated @@ -92,8 +91,8 @@ impl WindowBuffer { fn following_saturated(&self) -> bool { self.curr_key().is_some() && match &self.frame.bounds { - FrameBounds::Rows(_, end) => { - let end_off = end.to_offset(); + FrameBounds::Rows(bounds) => { + let end_off = bounds.end.to_offset(); if let Some(end_off) = end_off { if end_off <= 0 { true // pure preceding frame, always following-saturated @@ -178,9 +177,9 @@ impl WindowBuffer { } match &self.frame.bounds { - FrameBounds::Rows(start, end) => { - let start_off = start.to_offset(); - let end_off = end.to_offset(); + FrameBounds::Rows(bounds) => { + let start_off = bounds.start.to_offset(); + let end_off = bounds.end.to_offset(); if let Some(start_off) = start_off { let logical_left_idx = self.curr_idx as isize + start_off; if logical_left_idx >= 0 { diff --git a/src/frontend/src/binder/expr/function.rs b/src/frontend/src/binder/expr/function.rs index 8c9532f50a9a7..b92b7e832f81e 100644 --- a/src/frontend/src/binder/expr/function.rs +++ b/src/frontend/src/binder/expr/function.rs @@ -27,7 +27,7 @@ use risingwave_common::types::{DataType, ScalarImpl, Timestamptz}; use risingwave_common::{bail_not_implemented, current_cluster_version, no_function}; use risingwave_expr::aggregate::{agg_kinds, AggKind}; use risingwave_expr::window_function::{ - Frame, FrameBound, FrameBounds, FrameExclusion, WindowFuncKind, + Frame, FrameBound, FrameBounds, FrameExclusion, RowsFrameBounds, WindowFuncKind, }; use risingwave_sqlparser::ast::{ self, Expr as AstExpr, Function, FunctionArg, FunctionArgExpr, Ident, SelectItem, SetExpr, @@ -663,7 +663,7 @@ impl Binder { } else { FrameBound::CurrentRow }; - FrameBounds::Rows(start, end) + FrameBounds::Rows(RowsFrameBounds { start, end }) } WindowFrameUnits::Range | WindowFrameUnits::Groups => { bail_not_implemented!( diff --git a/src/stream/src/executor/over_window/general.rs b/src/stream/src/executor/over_window/general.rs index 4c37e31f4e1d7..0ba4808b93624 100644 --- a/src/stream/src/executor/over_window/general.rs +++ b/src/stream/src/executor/over_window/general.rs @@ -165,7 +165,10 @@ impl OverWindowExecutor { let input_info = args.input.info(); let input_schema = &input_info.schema; - let has_unbounded_frame = args.calls.iter().any(|call| call.frame.is_unbounded()); + let has_unbounded_frame = args + .calls + .iter() + .any(|call| call.frame.bounds.is_unbounded()); let cache_policy = if has_unbounded_frame { // For unbounded frames, we finally need all entries of the partition in the cache, // so for simplicity we just use full cache policy for these cases. @@ -454,6 +457,8 @@ impl OverWindowExecutor { // Find affected ranges, this also ensures that all rows in the affected ranges are loaded // into the cache. 
+ // TODO(rc): maybe we can find affected ranges for each window function call (each frame) to simplify + // the implementation of `find_affected_ranges` let (part_with_delta, affected_ranges) = partition .find_affected_ranges(&this.state_table, &delta) .await?; diff --git a/src/stream/src/executor/over_window/over_partition.rs b/src/stream/src/executor/over_window/over_partition.rs index 3a1b91380f78c..7a395821f6030 100644 --- a/src/stream/src/executor/over_window/over_partition.rs +++ b/src/stream/src/executor/over_window/over_partition.rs @@ -795,20 +795,21 @@ fn find_affected_ranges<'cache>( // support `RANGE` and `GROUPS` frames later. May introduce a return value variant to clearly // tell the caller that there exists at least one affected range that touches the sentinel. - let delta = part_with_delta.delta(); - if part_with_delta.first_key().is_none() { // all keys are deleted in the delta return vec![]; } + let delta_first_key = part_with_delta.delta().first_key_value().unwrap().0; + let delta_last_key = part_with_delta.delta().last_key_value().unwrap().0; + if part_with_delta.snapshot().is_empty() { // all existing keys are inserted in the delta return vec![( - delta.first_key_value().unwrap().0, - delta.first_key_value().unwrap().0, - delta.last_key_value().unwrap().0, - delta.last_key_value().unwrap().0, + delta_first_key, + delta_first_key, + delta_last_key, + delta_last_key, )]; } @@ -822,18 +823,23 @@ fn find_affected_ranges<'cache>( .iter() .any(|call| call.frame.bounds.end_is_unbounded()); - let first_curr_key = if end_is_unbounded { - // If the frame end is unbounded, the frame corresponding to the first key is always - // affected. + // NOTE: Don't be too clever! Here we must calculate `first_frame_start` after calculating + // `first_curr_key`, because the correct calculation of `first_frame_start` depends on + // `first_curr_key` which is the MINIMUM of all `first_curr_key`s of all frames of all window + // function calls. + + let first_curr_key = if end_is_unbounded || delta_first_key == first_key { + // If the frame end is unbounded, or, the first key is in delta, then the frame corresponding + // to the first key is always affected. first_key } else { - calls - .iter() - .map(|call| match &call.frame.bounds { - FrameBounds::Rows(_start, end) => { - let mut cursor = part_with_delta - .lower_bound(Bound::Included(delta.first_key_value().unwrap().0)); - for _ in 0..end.n_following_rows().unwrap() { + let mut min_first_curr_key = &Sentinelled::Largest; + + for call in calls { + let key = match &call.frame.bounds { + FrameBounds::Rows(bounds) => { + let mut cursor = part_with_delta.lower_bound(Bound::Included(delta_first_key)); + for _ in 0..bounds.end.n_following_rows().unwrap() { // Note that we have to move before check, to handle situation where the // cursor is at ghost position at first. cursor.move_prev(); @@ -843,22 +849,29 @@ fn find_affected_ranges<'cache>( } cursor.key().unwrap_or(first_key) } - }) - .min() - .expect("# of window function calls > 0") + }; + min_first_curr_key = min_first_curr_key.min(key); + if min_first_curr_key == first_key { + // if we already pushed the affected curr key to the first key, no more pushing is needed + break; + } + } + + min_first_curr_key }; - let first_frame_start = if start_is_unbounded { - // If the frame start is unbounded, the first key always need to be included in the affected - // range. 
+ let first_frame_start = if start_is_unbounded || first_curr_key == first_key { + // If the frame start is unbounded, or, the first curr key is the first key, then the first key + // always need to be included in the affected range. first_key } else { - calls - .iter() - .map(|call| match &call.frame.bounds { - FrameBounds::Rows(start, _end) => { + let mut min_frame_start = &Sentinelled::Largest; + + for call in calls { + let key = match &call.frame.bounds { + FrameBounds::Rows(bounds) => { let mut cursor = part_with_delta.find(first_curr_key).unwrap(); - for _ in 0..start.n_preceding_rows().unwrap() { + for _ in 0..bounds.start.n_preceding_rows().unwrap() { cursor.move_prev(); if cursor.position().is_ghost() { break; @@ -866,21 +879,27 @@ fn find_affected_ranges<'cache>( } cursor.key().unwrap_or(first_key) } - }) - .min() - .expect("# of window function calls > 0") + }; + min_frame_start = min_frame_start.min(key); + if min_frame_start == first_key { + // if we already pushed the affected frame start to the first key, no more pushing is needed + break; + } + } + + min_frame_start }; - let last_curr_key = if start_is_unbounded { + let last_curr_key = if start_is_unbounded || delta_last_key == last_key { last_key } else { - calls - .iter() - .map(|call| match &call.frame.bounds { - FrameBounds::Rows(start, _end) => { - let mut cursor = part_with_delta - .upper_bound(Bound::Included(delta.last_key_value().unwrap().0)); - for _ in 0..start.n_preceding_rows().unwrap() { + let mut max_last_curr_key = &Sentinelled::Smallest; + + for call in calls { + let key = match &call.frame.bounds { + FrameBounds::Rows(bounds) => { + let mut cursor = part_with_delta.upper_bound(Bound::Included(delta_last_key)); + for _ in 0..bounds.start.n_preceding_rows().unwrap() { cursor.move_next(); if cursor.position().is_ghost() { break; @@ -888,20 +907,27 @@ fn find_affected_ranges<'cache>( } cursor.key().unwrap_or(last_key) } - }) - .max() - .expect("# of window function calls > 0") + }; + max_last_curr_key = max_last_curr_key.max(key); + if max_last_curr_key == last_key { + // if we already pushed the affected curr key to the last key, no more pushing is needed + break; + } + } + + max_last_curr_key }; - let last_frame_end = if end_is_unbounded { + let last_frame_end = if end_is_unbounded || last_curr_key == last_key { last_key } else { - calls - .iter() - .map(|call| match &call.frame.bounds { - FrameBounds::Rows(_start, end) => { + let mut max_frame_end = &Sentinelled::Smallest; + + for call in calls { + let key = match &call.frame.bounds { + FrameBounds::Rows(bounds) => { let mut cursor = part_with_delta.find(last_curr_key).unwrap(); - for _ in 0..end.n_following_rows().unwrap() { + for _ in 0..bounds.end.n_following_rows().unwrap() { cursor.move_next(); if cursor.position().is_ghost() { break; @@ -909,9 +935,15 @@ fn find_affected_ranges<'cache>( } cursor.key().unwrap_or(last_key) } - }) - .max() - .expect("# of window function calls > 0") + }; + max_frame_end = max_frame_end.max(key); + if max_frame_end == last_key { + // if we already pushed the affected frame end to the last key, no more pushing is needed + break; + } + } + + max_frame_end }; if first_curr_key > last_curr_key { @@ -1052,20 +1084,6 @@ mod find_affected_ranges_tests { }) } - #[test] - fn test_all_empty() { - let cache = create_cache!(); - let delta = create_delta!(); - let calls = vec![create_call(Frame::rows( - FrameBound::Preceding(2), - FrameBound::Preceding(1), - ))]; - assert_ranges_eq( - find_affected_ranges(&calls, 
DeltaBTreeMap::new(&cache, &delta)), - [], - ); - } - #[test] fn test_insert_delta_only() { let cache = create_cache!(); From 608c89e7d648baead660f5569841f7f9fafbc3c7 Mon Sep 17 00:00:00 2001 From: Richard Chien Date: Fri, 12 Jan 2024 17:05:49 +0800 Subject: [PATCH 06/71] refactor(over window): use `parse_display` to impl `Display` for window frame types (#14452) Signed-off-by: Richard Chien --- src/expr/core/src/window_function/call.rs | 75 +++++++---------------- 1 file changed, 22 insertions(+), 53 deletions(-) diff --git a/src/expr/core/src/window_function/call.rs b/src/expr/core/src/window_function/call.rs index 1bb4dfa85f2bb..43545cc2a107a 100644 --- a/src/expr/core/src/window_function/call.rs +++ b/src/expr/core/src/window_function/call.rs @@ -15,10 +15,12 @@ use std::fmt::Display; use enum_as_inner::EnumAsInner; +use parse_display::Display; use risingwave_common::bail; use risingwave_common::types::DataType; use risingwave_pb::expr::window_frame::{PbBound, PbExclusion}; use risingwave_pb::expr::{PbWindowFrame, PbWindowFunction}; +use FrameBound::{CurrentRow, Following, Preceding, UnboundedFollowing, UnboundedPreceding}; use super::WindowFuncKind; use crate::aggregate::AggArgs; @@ -109,6 +111,14 @@ impl Frame { } } +#[derive(Display, Debug, Clone, Eq, PartialEq, Hash)] +#[display("{0}")] +pub enum FrameBounds { + Rows(RowsFrameBounds), + // Groups(GroupsFrameBounds), + // Range(RangeFrameBounds), +} + impl FrameBounds { pub fn validate(&self) -> Result<()> { match self { @@ -133,52 +143,33 @@ impl FrameBounds { } } -impl Display for FrameBounds { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - Self::Rows(bounds) => bounds.fmt(f), - } - } -} - -#[derive(Debug, Clone, Eq, PartialEq, Hash)] -pub enum FrameBounds { - Rows(RowsFrameBounds), - // Groups(GroupsFrameBounds), - // Range(RangeFrameBounds), -} - -#[derive(Debug, Clone, Eq, PartialEq, Hash)] +#[derive(Display, Debug, Clone, Eq, PartialEq, Hash)] +#[display("ROWS BETWEEN {start} AND {end}")] pub struct RowsFrameBounds { pub start: FrameBound, pub end: FrameBound, } -impl Display for RowsFrameBounds { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "ROWS BETWEEN {} AND {}", self.start, self.end)?; - Ok(()) - } -} - impl RowsFrameBounds { fn validate(&self) -> Result<()> { FrameBound::validate_bounds(&self.start, &self.end) } } -#[derive(Debug, Clone, Eq, PartialEq, Hash, EnumAsInner)] +#[derive(Display, Debug, Clone, Eq, PartialEq, Hash, EnumAsInner)] +#[display(style = "TITLE CASE")] pub enum FrameBound { UnboundedPreceding, + #[display("{0} PRECEDING")] Preceding(T), CurrentRow, + #[display("{0} FOLLOWING")] Following(T), UnboundedFollowing, } impl FrameBound { fn validate_bounds(start: &Self, end: &Self) -> Result<()> { - use FrameBound::*; match (start, end) { (_, UnboundedPreceding) => bail!("frame end cannot be UNBOUNDED PRECEDING"), (UnboundedFollowing, _) => bail!("frame start cannot be UNBOUNDED FOLLOWING"), @@ -232,27 +223,14 @@ impl FrameBound { } } -impl Display for FrameBound { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - FrameBound::UnboundedPreceding => write!(f, "UNBOUNDED PRECEDING")?, - FrameBound::Preceding(n) => write!(f, "{} PRECEDING", n)?, - FrameBound::CurrentRow => write!(f, "CURRENT ROW")?, - FrameBound::Following(n) => write!(f, "{} FOLLOWING", n)?, - FrameBound::UnboundedFollowing => write!(f, "UNBOUNDED FOLLOWING")?, - } - Ok(()) - } -} - impl FrameBound { /// Convert the bound to 
sized offset from current row. `None` if the bound is unbounded. pub fn to_offset(&self) -> Option { match self { - FrameBound::UnboundedPreceding | FrameBound::UnboundedFollowing => None, - FrameBound::CurrentRow => Some(0), - FrameBound::Preceding(n) => Some(-(*n as isize)), - FrameBound::Following(n) => Some(*n as isize), + UnboundedPreceding | UnboundedFollowing => None, + CurrentRow => Some(0), + Preceding(n) => Some(-(*n as isize)), + Following(n) => Some(*n as isize), } } @@ -267,7 +245,8 @@ impl FrameBound { } } -#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash, Default, EnumAsInner)] +#[derive(Display, Debug, Copy, Clone, Eq, PartialEq, Hash, Default, EnumAsInner)] +#[display("EXCLUDE {}", style = "TITLE CASE")] pub enum FrameExclusion { CurrentRow, // Group, @@ -276,16 +255,6 @@ pub enum FrameExclusion { NoOthers, } -impl Display for FrameExclusion { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - FrameExclusion::CurrentRow => write!(f, "EXCLUDE CURRENT ROW")?, - FrameExclusion::NoOthers => write!(f, "EXCLUDE NO OTHERS")?, - } - Ok(()) - } -} - impl FrameExclusion { pub fn from_protobuf(exclusion: PbExclusion) -> Result { let excl = match exclusion { From 11c1c6b038df50df5b7b90f8e4cd2caede8e8205 Mon Sep 17 00:00:00 2001 From: xxchan Date: Fri, 12 Jan 2024 17:46:21 +0800 Subject: [PATCH 07/71] chore: increase log size limit in ci (#14532) --- Makefile.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Makefile.toml b/Makefile.toml index 347f2234e5fda..c0135dc84c048 100644 --- a/Makefile.toml +++ b/Makefile.toml @@ -245,7 +245,7 @@ do fi done -if (( "$(du -sk ${PREFIX_LOG} | cut -f1)" > 2000 )) ; then +if (( "$(du -sk ${PREFIX_LOG} | cut -f1)" > 3000 )) ; then echo "$(tput setaf 1)ERROR: log size is significantly large ($(du -sh ${PREFIX_LOG} | cut -f1)).$(tput sgr0) Please disable unnecessary logs." 
exit 1 fi From 47af81e4c75e9e35fce8dad545e005936ec25abb Mon Sep 17 00:00:00 2001 From: lmatz Date: Fri, 12 Jan 2024 18:37:18 +0800 Subject: [PATCH 08/71] chore: upgrade docker image version (#14535) Co-authored-by: lmatz --- Cargo.lock | 102 +++++++++++++------------- Cargo.toml | 2 +- docker/docker-compose-distributed.yml | 2 +- docker/docker-compose-with-azblob.yml | 2 +- docker/docker-compose-with-gcs.yml | 2 +- docker/docker-compose-with-obs.yml | 2 +- docker/docker-compose-with-oss.yml | 2 +- docker/docker-compose-with-s3.yml | 2 +- docker/docker-compose.yml | 2 +- 9 files changed, 59 insertions(+), 59 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index ff20f999220dc..8950c51de77c1 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3214,7 +3214,7 @@ dependencies = [ [[package]] name = "delta_btree_map" -version = "1.5.0-alpha" +version = "1.7.0-alpha" dependencies = [ "enum-as-inner", ] @@ -5690,7 +5690,7 @@ checksum = "c4cd1a83af159aa67994778be9070f0ae1bd732942279cabb14f86f986a21456" [[package]] name = "local_stats_alloc" -version = "1.5.0-alpha" +version = "1.7.0-alpha" dependencies = [ "workspace-hack", ] @@ -7251,7 +7251,7 @@ dependencies = [ [[package]] name = "pgwire" -version = "1.5.0-alpha" +version = "1.7.0-alpha" dependencies = [ "anyhow", "auto_enums", @@ -8469,7 +8469,7 @@ dependencies = [ [[package]] name = "risedev" -version = "1.5.0-alpha" +version = "1.7.0-alpha" dependencies = [ "anyhow", "chrono", @@ -8498,7 +8498,7 @@ dependencies = [ [[package]] name = "risedev-config" -version = "1.5.0-alpha" +version = "1.7.0-alpha" dependencies = [ "anyhow", "clap", @@ -8511,7 +8511,7 @@ dependencies = [ [[package]] name = "risingwave-fields-derive" -version = "1.5.0-alpha" +version = "1.7.0-alpha" dependencies = [ "expect-test", "indoc", @@ -8523,7 +8523,7 @@ dependencies = [ [[package]] name = "risingwave_backup" -version = "1.5.0-alpha" +version = "1.7.0-alpha" dependencies = [ "anyhow", "async-trait", @@ -8545,7 +8545,7 @@ dependencies = [ [[package]] name = "risingwave_batch" -version = "1.5.0-alpha" +version = "1.7.0-alpha" dependencies = [ "anyhow", "assert_matches", @@ -8591,7 +8591,7 @@ dependencies = [ [[package]] name = "risingwave_bench" -version = "1.5.0-alpha" +version = "1.7.0-alpha" dependencies = [ "async-trait", "aws-config", @@ -8625,7 +8625,7 @@ dependencies = [ [[package]] name = "risingwave_cmd" -version = "1.5.0-alpha" +version = "1.7.0-alpha" dependencies = [ "clap", "madsim-tokio", @@ -8646,7 +8646,7 @@ dependencies = [ [[package]] name = "risingwave_cmd_all" -version = "1.5.0-alpha" +version = "1.7.0-alpha" dependencies = [ "anyhow", "clap", @@ -8678,7 +8678,7 @@ dependencies = [ [[package]] name = "risingwave_common" -version = "1.5.0-alpha" +version = "1.7.0-alpha" dependencies = [ "anyhow", "arc-swap", @@ -8784,7 +8784,7 @@ dependencies = [ [[package]] name = "risingwave_common_heap_profiling" -version = "1.5.0-alpha" +version = "1.7.0-alpha" dependencies = [ "anyhow", "chrono", @@ -8799,7 +8799,7 @@ dependencies = [ [[package]] name = "risingwave_common_proc_macro" -version = "1.5.0-alpha" +version = "1.7.0-alpha" dependencies = [ "bae", "proc-macro-error 1.0.4", @@ -8810,7 +8810,7 @@ dependencies = [ [[package]] name = "risingwave_common_service" -version = "1.5.0-alpha" +version = "1.7.0-alpha" dependencies = [ "async-trait", "futures", @@ -8831,7 +8831,7 @@ dependencies = [ [[package]] name = "risingwave_compaction_test" -version = "1.5.0-alpha" +version = "1.7.0-alpha" dependencies = [ "anyhow", "async-trait", @@ -8858,7 +8858,7 @@ dependencies = 
[ [[package]] name = "risingwave_compactor" -version = "1.5.0-alpha" +version = "1.7.0-alpha" dependencies = [ "async-trait", "await-tree", @@ -8880,7 +8880,7 @@ dependencies = [ [[package]] name = "risingwave_compute" -version = "1.5.0-alpha" +version = "1.7.0-alpha" dependencies = [ "anyhow", "async-trait", @@ -8923,7 +8923,7 @@ dependencies = [ [[package]] name = "risingwave_connector" -version = "1.5.0-alpha" +version = "1.7.0-alpha" dependencies = [ "anyhow", "apache-avro 0.16.0", @@ -9027,7 +9027,7 @@ dependencies = [ [[package]] name = "risingwave_ctl" -version = "1.5.0-alpha" +version = "1.7.0-alpha" dependencies = [ "anyhow", "bytes", @@ -9062,7 +9062,7 @@ dependencies = [ [[package]] name = "risingwave_e2e_extended_mode_test" -version = "1.5.0-alpha" +version = "1.7.0-alpha" dependencies = [ "anyhow", "chrono", @@ -9077,7 +9077,7 @@ dependencies = [ [[package]] name = "risingwave_error" -version = "1.5.0-alpha" +version = "1.7.0-alpha" dependencies = [ "bincode 1.3.3", "bytes", @@ -9092,7 +9092,7 @@ dependencies = [ [[package]] name = "risingwave_expr" -version = "1.5.0-alpha" +version = "1.7.0-alpha" dependencies = [ "anyhow", "arrow-array 49.0.0", @@ -9132,7 +9132,7 @@ dependencies = [ [[package]] name = "risingwave_expr_impl" -version = "1.5.0-alpha" +version = "1.7.0-alpha" dependencies = [ "aho-corasick", "anyhow", @@ -9179,7 +9179,7 @@ dependencies = [ [[package]] name = "risingwave_frontend" -version = "1.5.0-alpha" +version = "1.7.0-alpha" dependencies = [ "anyhow", "arc-swap", @@ -9252,7 +9252,7 @@ dependencies = [ [[package]] name = "risingwave_hummock_sdk" -version = "1.5.0-alpha" +version = "1.7.0-alpha" dependencies = [ "bytes", "easy-ext", @@ -9268,7 +9268,7 @@ dependencies = [ [[package]] name = "risingwave_hummock_test" -version = "1.5.0-alpha" +version = "1.7.0-alpha" dependencies = [ "async-trait", "bytes", @@ -9300,7 +9300,7 @@ dependencies = [ [[package]] name = "risingwave_hummock_trace" -version = "1.5.0-alpha" +version = "1.7.0-alpha" dependencies = [ "async-trait", "bincode 2.0.0-rc.3", @@ -9364,7 +9364,7 @@ dependencies = [ [[package]] name = "risingwave_mem_table_spill_test" -version = "1.5.0-alpha" +version = "1.7.0-alpha" dependencies = [ "async-trait", "bytes", @@ -9380,7 +9380,7 @@ dependencies = [ [[package]] name = "risingwave_meta" -version = "1.5.0-alpha" +version = "1.7.0-alpha" dependencies = [ "anyhow", "arc-swap", @@ -9452,7 +9452,7 @@ dependencies = [ [[package]] name = "risingwave_meta_model_migration" -version = "1.5.0-alpha" +version = "1.7.0-alpha" dependencies = [ "async-std", "sea-orm-migration", @@ -9461,7 +9461,7 @@ dependencies = [ [[package]] name = "risingwave_meta_model_v2" -version = "1.5.0-alpha" +version = "1.7.0-alpha" dependencies = [ "risingwave_hummock_sdk", "risingwave_pb", @@ -9472,7 +9472,7 @@ dependencies = [ [[package]] name = "risingwave_meta_node" -version = "1.5.0-alpha" +version = "1.7.0-alpha" dependencies = [ "anyhow", "clap", @@ -9503,7 +9503,7 @@ dependencies = [ [[package]] name = "risingwave_meta_service" -version = "1.5.0-alpha" +version = "1.7.0-alpha" dependencies = [ "anyhow", "async-trait", @@ -9529,7 +9529,7 @@ dependencies = [ [[package]] name = "risingwave_object_store" -version = "1.5.0-alpha" +version = "1.7.0-alpha" dependencies = [ "async-trait", "await-tree", @@ -9562,7 +9562,7 @@ dependencies = [ [[package]] name = "risingwave_pb" -version = "1.5.0-alpha" +version = "1.7.0-alpha" dependencies = [ "enum-as-inner", "fs-err", @@ -9582,7 +9582,7 @@ dependencies = [ [[package]] name = 
"risingwave_planner_test" -version = "1.5.0-alpha" +version = "1.7.0-alpha" dependencies = [ "anyhow", "expect-test", @@ -9604,7 +9604,7 @@ dependencies = [ [[package]] name = "risingwave_regress_test" -version = "1.5.0-alpha" +version = "1.7.0-alpha" dependencies = [ "anyhow", "clap", @@ -9618,7 +9618,7 @@ dependencies = [ [[package]] name = "risingwave_rpc_client" -version = "1.5.0-alpha" +version = "1.7.0-alpha" dependencies = [ "anyhow", "async-trait", @@ -9650,7 +9650,7 @@ dependencies = [ [[package]] name = "risingwave_rt" -version = "1.5.0-alpha" +version = "1.7.0-alpha" dependencies = [ "await-tree", "console", @@ -9729,7 +9729,7 @@ dependencies = [ [[package]] name = "risingwave_source" -version = "1.5.0-alpha" +version = "1.7.0-alpha" dependencies = [ "anyhow", "assert_matches", @@ -9751,7 +9751,7 @@ dependencies = [ [[package]] name = "risingwave_sqlparser" -version = "1.5.0-alpha" +version = "1.7.0-alpha" dependencies = [ "itertools 0.12.0", "matches", @@ -9778,7 +9778,7 @@ dependencies = [ [[package]] name = "risingwave_sqlsmith" -version = "1.5.0-alpha" +version = "1.7.0-alpha" dependencies = [ "anyhow", "chrono", @@ -9805,7 +9805,7 @@ dependencies = [ [[package]] name = "risingwave_state_cleaning_test" -version = "1.5.0-alpha" +version = "1.7.0-alpha" dependencies = [ "anyhow", "clap", @@ -9825,7 +9825,7 @@ dependencies = [ [[package]] name = "risingwave_storage" -version = "1.5.0-alpha" +version = "1.7.0-alpha" dependencies = [ "anyhow", "arc-swap", @@ -9891,7 +9891,7 @@ dependencies = [ [[package]] name = "risingwave_stream" -version = "1.5.0-alpha" +version = "1.7.0-alpha" dependencies = [ "anyhow", "assert_matches", @@ -9953,7 +9953,7 @@ dependencies = [ [[package]] name = "risingwave_test_runner" -version = "1.5.0-alpha" +version = "1.7.0-alpha" dependencies = [ "fail", "sync-point", @@ -9982,7 +9982,7 @@ dependencies = [ [[package]] name = "risingwave_variables" -version = "1.5.0-alpha" +version = "1.7.0-alpha" dependencies = [ "chrono", "workspace-hack", @@ -13546,7 +13546,7 @@ dependencies = [ [[package]] name = "with_options" -version = "1.5.0-alpha" +version = "1.7.0-alpha" dependencies = [ "proc-macro2", "quote", @@ -13567,7 +13567,7 @@ dependencies = [ [[package]] name = "workspace-config" -version = "1.5.0-alpha" +version = "1.7.0-alpha" dependencies = [ "log", "openssl-sys", @@ -13578,7 +13578,7 @@ dependencies = [ [[package]] name = "workspace-hack" -version = "1.5.0-alpha" +version = "1.7.0-alpha" dependencies = [ "ahash 0.8.6", "allocator-api2", diff --git a/Cargo.toml b/Cargo.toml index b09dd5def90b2..945091f384371 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -63,7 +63,7 @@ exclude = ["e2e_test/udf/wasm", "lints"] resolver = "2" [workspace.package] -version = "1.5.0-alpha" +version = "1.7.0-alpha" edition = "2021" homepage = "https://github.com/risingwavelabs/risingwave" keywords = ["sql", "database", "streaming"] diff --git a/docker/docker-compose-distributed.yml b/docker/docker-compose-distributed.yml index ae6412d51dbf8..8cb1e87651325 100644 --- a/docker/docker-compose-distributed.yml +++ b/docker/docker-compose-distributed.yml @@ -1,7 +1,7 @@ --- version: "3" x-image: &image - image: ${RW_IMAGE:-risingwavelabs/risingwave:v1.5.0} + image: ${RW_IMAGE:-risingwavelabs/risingwave:v1.6.0} services: compactor-0: <<: *image diff --git a/docker/docker-compose-with-azblob.yml b/docker/docker-compose-with-azblob.yml index 1188110ee4f22..e7149664ad23b 100644 --- a/docker/docker-compose-with-azblob.yml +++ b/docker/docker-compose-with-azblob.yml @@ -1,7 +1,7 @@ --- 
version: "3" x-image: &image - image: ${RW_IMAGE:-risingwavelabs/risingwave:v1.5.0} + image: ${RW_IMAGE:-risingwavelabs/risingwave:v1.6.0} services: risingwave-standalone: <<: *image diff --git a/docker/docker-compose-with-gcs.yml b/docker/docker-compose-with-gcs.yml index 773ae7d01b066..45a5b3d17dce1 100644 --- a/docker/docker-compose-with-gcs.yml +++ b/docker/docker-compose-with-gcs.yml @@ -1,7 +1,7 @@ --- version: "3" x-image: &image - image: ${RW_IMAGE:-risingwavelabs/risingwave:v1.5.0} + image: ${RW_IMAGE:-risingwavelabs/risingwave:v1.6.0} services: risingwave-standalone: <<: *image diff --git a/docker/docker-compose-with-obs.yml b/docker/docker-compose-with-obs.yml index 9842d929de612..c7c397c8b1234 100644 --- a/docker/docker-compose-with-obs.yml +++ b/docker/docker-compose-with-obs.yml @@ -1,7 +1,7 @@ --- version: "3" x-image: &image - image: ${RW_IMAGE:-risingwavelabs/risingwave:v1.5.0} + image: ${RW_IMAGE:-risingwavelabs/risingwave:v1.6.0} services: risingwave-standalone: <<: *image diff --git a/docker/docker-compose-with-oss.yml b/docker/docker-compose-with-oss.yml index f7d506dcab801..fc05c05dec207 100644 --- a/docker/docker-compose-with-oss.yml +++ b/docker/docker-compose-with-oss.yml @@ -1,7 +1,7 @@ --- version: "3" x-image: &image - image: ${RW_IMAGE:-risingwavelabs/risingwave:v1.5.0} + image: ${RW_IMAGE:-risingwavelabs/risingwave:v1.6.0} services: risingwave-standalone: <<: *image diff --git a/docker/docker-compose-with-s3.yml b/docker/docker-compose-with-s3.yml index 1d4c9600bf1b1..e62955455bd88 100644 --- a/docker/docker-compose-with-s3.yml +++ b/docker/docker-compose-with-s3.yml @@ -1,7 +1,7 @@ --- version: "3" x-image: &image - image: ${RW_IMAGE:-risingwavelabs/risingwave:v1.5.0} + image: ${RW_IMAGE:-risingwavelabs/risingwave:v1.6.0} services: risingwave-standalone: <<: *image diff --git a/docker/docker-compose.yml b/docker/docker-compose.yml index e0e6e6fa0b6ce..f7beb587e387e 100644 --- a/docker/docker-compose.yml +++ b/docker/docker-compose.yml @@ -1,7 +1,7 @@ --- version: "3" x-image: &image - image: ${RW_IMAGE:-risingwavelabs/risingwave:v1.5.0} + image: ${RW_IMAGE:-risingwavelabs/risingwave:v1.6.0} services: risingwave-standalone: <<: *image From 7aabd3b82760a2ca7f08f0de546bf3051e1ad047 Mon Sep 17 00:00:00 2001 From: Shanicky Chen Date: Fri, 12 Jan 2024 19:30:06 +0800 Subject: [PATCH 09/71] fix: only check the parallelism of the target table in e2e alter parallelism test (#14542) Signed-off-by: Shanicky Chen --- e2e_test/ddl/alter_parallelism.slt | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/e2e_test/ddl/alter_parallelism.slt b/e2e_test/ddl/alter_parallelism.slt index 723dd8e08b75d..e030b44273575 100644 --- a/e2e_test/ddl/alter_parallelism.slt +++ b/e2e_test/ddl/alter_parallelism.slt @@ -13,6 +13,9 @@ create view mview_parallelism as select m.name, tf.parallelism from rw_materiali statement ok create view sink_parallelism as select s.name, tf.parallelism from rw_sinks s, rw_table_fragments tf where s.id = tf.table_id; +statement ok +create view fragment_parallelism as select t.name as table_name, f.fragment_id, f.parallelism from rw_fragments f, rw_tables t where f.table_id = t.id; + statement ok create table t (v int); @@ -30,7 +33,7 @@ select parallelism from table_parallelism where name = 't'; FIXED(2) query I -select parallelism from rw_fragments; +select parallelism from fragment_parallelism where table_name = 't'; ---- 2 2 @@ -152,4 +155,7 @@ statement ok drop view mview_parallelism; statement ok -drop view sink_parallelism; \ No 
newline at end of file +drop view sink_parallelism; + +statement ok +drop view fragment_parallelism; \ No newline at end of file From 1afb0eccaff05b8a03dcb15693e464d12a0b9359 Mon Sep 17 00:00:00 2001 From: Zihao Xu Date: Fri, 12 Jan 2024 12:14:48 -0500 Subject: [PATCH 10/71] feat(sql-udf): deep calling stack (recursion) prevention for sql udf (#14392) --- e2e_test/udf/sql_udf.slt | 65 ++++++++++++++++++- src/frontend/src/binder/expr/column.rs | 2 +- src/frontend/src/binder/expr/function.rs | 24 ++++++- src/frontend/src/binder/expr/mod.rs | 2 +- src/frontend/src/binder/mod.rs | 50 +++++++++++++- .../src/handler/create_sql_function.rs | 9 --- 6 files changed, 134 insertions(+), 18 deletions(-) diff --git a/e2e_test/udf/sql_udf.slt b/e2e_test/udf/sql_udf.slt index e1100834c9bbd..02fc23b2d7f02 100644 --- a/e2e_test/udf/sql_udf.slt +++ b/e2e_test/udf/sql_udf.slt @@ -28,14 +28,42 @@ create function add_return(INT, INT) returns int language sql return $1 + $2; statement ok create function add_return_binding() returns int language sql return add_return(1, 1) + add_return(1, 1); -# Recursive definition is forbidden -statement error recursive definition is forbidden, please recheck your function syntax +# Recursive definition can be accepted, but will be eventually rejected during runtime +statement ok create function recursive(INT, INT) returns int language sql as 'select recursive($1, $2) + recursive($1, $2)'; +# Complex but error-prone definition, recursive & normal sql udfs interleaving +statement ok +create function recursive_non_recursive(INT, INT) returns int language sql as 'select recursive($1, $2) + sub($1, $2)'; + +# Recursive corner case +statement ok +create function foo(INT) returns varchar language sql as $$select 'foo(INT)'$$; + # Create a wrapper function for `add` & `sub` statement ok create function add_sub_wrapper(INT, INT) returns int language sql as 'select add($1, $2) + sub($1, $2) + 114512'; +# Create a valid recursive function +# Please note we do NOT support actual running the recursive sql udf at present +statement ok +create function fib(INT) returns int + language sql as 'select case + when $1 = 0 then 0 + when $1 = 1 then 1 + when $1 = 2 then 1 + when $1 = 3 then 2 + else fib($1 - 1) + fib($1 - 2) + end;'; + +# The execution will eventually exceed the pre-defined max stack depth +statement error function fib calling stack depth limit exceeded +select fib(100); + +# Currently create a materialized view with a recursive sql udf will be rejected +statement error function fib calling stack depth limit exceeded +create materialized view foo_mv as select fib(100); + # Call the defined sql udf query I select add(1, -1); @@ -72,6 +100,19 @@ select call_regexp_replace(); ---- 💩💩💩💩💩foo🤔️bar亲爱的😭这是🥵爱情❤️‍🔥 +query T +select foo(114514); +---- +foo(INT) + +# Rejected deep calling stack +statement error function recursive calling stack depth limit exceeded +select recursive(1, 1); + +# Same as above +statement error function recursive calling stack depth limit exceeded +select recursive_non_recursive(1, 1); + query I select add_sub_wrapper(1, 1); ---- @@ -103,6 +144,14 @@ select c1, c2, add_return(c1, c2) from t1 order by c1 asc; 4 4 8 5 5 10 +# Recursive sql udf with normal table +statement error function fib calling stack depth limit exceeded +select fib(c1) from t1; + +# Recursive sql udf with materialized view +statement error function fib calling stack depth limit exceeded +create materialized view bar_mv as select fib(c1) from t1; + # Invalid function body syntax statement error 
Expected an expression:, found: EOF at the end create function add_error(INT, INT) returns int language sql as $$select $1 + $2 +$$; @@ -187,9 +236,21 @@ drop function call_regexp_replace; statement ok drop function add_sub_wrapper; +statement ok +drop function recursive; + +statement ok +drop function foo; + +statement ok +drop function recursive_non_recursive; + statement ok drop function add_sub_types; +statement ok +drop function fib; + # Drop the mock table statement ok drop table t1; diff --git a/src/frontend/src/binder/expr/column.rs b/src/frontend/src/binder/expr/column.rs index 2f2a8d9335256..cac4f7eccd62e 100644 --- a/src/frontend/src/binder/expr/column.rs +++ b/src/frontend/src/binder/expr/column.rs @@ -45,7 +45,7 @@ impl Binder { // to the name of the defined sql udf parameters stored in `udf_context`. // If so, we will treat this bind as an special bind, the actual expression // stored in `udf_context` will then be bound instead of binding the non-existing column. - if let Some(expr) = self.udf_context.get(&column_name) { + if let Some(expr) = self.udf_context.get_expr(&column_name) { return self.bind_expr(expr.clone()); } diff --git a/src/frontend/src/binder/expr/function.rs b/src/frontend/src/binder/expr/function.rs index b92b7e832f81e..de4f0f4aa6331 100644 --- a/src/frontend/src/binder/expr/function.rs +++ b/src/frontend/src/binder/expr/function.rs @@ -55,6 +55,12 @@ pub const SYS_FUNCTION_WITHOUT_ARGS: &[&str] = &[ "current_timestamp", ]; +/// The global max calling depth for the global counter in `udf_context` +/// To reduce the chance that the current running rw thread +/// be killed by os, the current allowance depth of calling +/// stack is set to `16`. +const SQL_UDF_MAX_CALLING_DEPTH: u32 = 16; + impl Binder { pub(in crate::binder) fn bind_function(&mut self, f: Function) -> Result { let function_name = match f.name.0.as_slice() { @@ -235,6 +241,7 @@ impl Binder { ) .into()); } + // This represents the current user defined function is `language sql` let parse_result = risingwave_sqlparser::parser::Parser::parse_sql( func.body.as_ref().unwrap().as_str(), @@ -245,6 +252,7 @@ impl Binder { // Here we just return the original parse error message return Err(ErrorCode::InvalidInputSyntax(err).into()); } + debug_assert!(parse_result.is_ok()); // We can safely unwrap here @@ -263,7 +271,7 @@ impl Binder { if self.udf_context.is_empty() { // The actual inline logic for sql udf if let Ok(context) = create_udf_context(&args, &Arc::clone(func)) { - self.udf_context = context; + self.udf_context.update_context(context); } else { return Err(ErrorCode::InvalidInputSyntax( "failed to create the `udf_context`, please recheck your function definition and syntax".to_string() @@ -277,9 +285,21 @@ impl Binder { clean_flag = false; } + // Check for potential recursive calling + if self.udf_context.global_count() >= SQL_UDF_MAX_CALLING_DEPTH { + return Err(ErrorCode::BindError(format!( + "function {} calling stack depth limit exceeded", + &function_name + )) + .into()); + } else { + // Update the status for the global counter + self.udf_context.incr_global_count(); + } + if let Ok(expr) = extract_udf_expression(ast) { let bind_result = self.bind_expr(expr); - // Clean the `udf_context` after inlining, + // Clean the `udf_context` & `udf_recursive_context` after inlining, // which makes sure the subsequent binding will not be affected if clean_flag { self.udf_context.clear(); diff --git a/src/frontend/src/binder/expr/mod.rs b/src/frontend/src/binder/expr/mod.rs index 
cacd2d80dcfe4..1b3dcd5dd051c 100644 --- a/src/frontend/src/binder/expr/mod.rs +++ b/src/frontend/src/binder/expr/mod.rs @@ -382,7 +382,7 @@ impl Binder { // Note: This is specific to anonymous sql udf, since the // parameters will be parsed and treated as `Parameter`. // For detailed explanation, consider checking `bind_column`. - if let Some(expr) = self.udf_context.get(&format!("${index}")) { + if let Some(expr) = self.udf_context.get_expr(&format!("${index}")) { return self.bind_expr(expr.clone()); } diff --git a/src/frontend/src/binder/mod.rs b/src/frontend/src/binder/mod.rs index 6ba891aa6b513..51b53d23a2e3f 100644 --- a/src/frontend/src/binder/mod.rs +++ b/src/frontend/src/binder/mod.rs @@ -116,9 +116,53 @@ pub struct Binder { param_types: ParameterTypes, - /// The mapping from sql udf parameters to ast expressions + /// The sql udf context that will be used during binding phase + udf_context: UdfContext, +} + +#[derive(Clone, Debug, Default)] +pub struct UdfContext { + /// The mapping from `sql udf parameters` to `ast expressions` /// Note: The expressions are constructed during runtime, correspond to the actual users' input - udf_context: HashMap, + udf_param_context: HashMap, + + /// The global counter that records the calling stack depth + /// of the current binding sql udf chain + udf_global_counter: u32, +} + +impl UdfContext { + pub fn new() -> Self { + Self { + udf_param_context: HashMap::new(), + udf_global_counter: 0, + } + } + + pub fn global_count(&self) -> u32 { + self.udf_global_counter + } + + pub fn incr_global_count(&mut self) { + self.udf_global_counter += 1; + } + + pub fn is_empty(&self) -> bool { + self.udf_param_context.is_empty() + } + + pub fn update_context(&mut self, context: HashMap) { + self.udf_param_context = context; + } + + pub fn clear(&mut self) { + self.udf_global_counter = 0; + self.udf_param_context.clear(); + } + + pub fn get_expr(&self, name: &str) -> Option<&AstExpr> { + self.udf_param_context.get(name) + } } /// `ParameterTypes` is used to record the types of the parameters during binding. 
It works @@ -220,7 +264,7 @@ impl Binder { shared_views: HashMap::new(), included_relations: HashSet::new(), param_types: ParameterTypes::new(param_types), - udf_context: HashMap::new(), + udf_context: UdfContext::new(), } } diff --git a/src/frontend/src/handler/create_sql_function.rs b/src/frontend/src/handler/create_sql_function.rs index 834e0bec3135d..bbe504d779bfd 100644 --- a/src/frontend/src/handler/create_sql_function.rs +++ b/src/frontend/src/handler/create_sql_function.rs @@ -72,15 +72,6 @@ pub async fn handle_create_sql_function( } }; - // We do NOT allow recursive calling inside sql udf - // Since there does not exist the base case for this definition - if body.contains(format!("{}(", name.real_value()).as_str()) { - return Err(ErrorCode::InvalidInputSyntax( - "recursive definition is forbidden, please recheck your function syntax".to_string(), - ) - .into()); - } - // Sanity check for link, this must be none with sql udf function if let Some(CreateFunctionUsing::Link(_)) = params.using { return Err(ErrorCode::InvalidParameterValue( From 2bec2eda37df0a92d0a46a737168657b2523d020 Mon Sep 17 00:00:00 2001 From: Yufan Song <33971064+yufansong@users.noreply.github.com> Date: Fri, 12 Jan 2024 12:12:45 -0800 Subject: [PATCH 11/71] fix(connector): add additional check in nats list_splits (#14546) --- src/connector/src/source/nats/enumerator/mod.rs | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/src/connector/src/source/nats/enumerator/mod.rs b/src/connector/src/source/nats/enumerator/mod.rs index e1d4f96197716..c5059fdc8186c 100644 --- a/src/connector/src/source/nats/enumerator/mod.rs +++ b/src/connector/src/source/nats/enumerator/mod.rs @@ -21,10 +21,11 @@ use super::source::{NatsOffset, NatsSplit}; use super::NatsProperties; use crate::source::{SourceEnumeratorContextRef, SplitEnumerator, SplitId}; -#[derive(Debug, Clone, Eq, PartialEq)] +#[derive(Debug, Clone)] pub struct NatsSplitEnumerator { subject: String, split_id: SplitId, + client: async_nats::Client, } #[async_trait] @@ -36,13 +37,23 @@ impl SplitEnumerator for NatsSplitEnumerator { properties: Self::Properties, _context: SourceEnumeratorContextRef, ) -> anyhow::Result { + let client = properties.common.build_client().await?; Ok(Self { subject: properties.common.subject, split_id: Arc::from("0"), + client, }) } async fn list_splits(&mut self) -> anyhow::Result> { + // Nats currently does not support list_splits API, if we simple return the default 0 without checking the client status, will result executor crash + let state = self.client.connection_state(); + if state != async_nats::connection::State::Connected { + return Err(anyhow::anyhow!( + "Nats connection status is not connected, current status is {:?}", + state + )); + } // TODO: to simplify the logic, return 1 split for first version let nats_split = NatsSplit { subject: self.subject.clone(), From 7c3edb1ff2d64f9c0e50eacfca09ea5ce112ca2d Mon Sep 17 00:00:00 2001 From: Tao Wu Date: Sat, 13 Jan 2024 16:24:31 +0200 Subject: [PATCH 12/71] test: include sqlalchemy 1.4 test in python client testing (#14491) --- .../client-library/python/client.py | 7 ++- .../client-library/python/crud.py | 46 +++++--------- .../client-library/python/init.py | 0 .../client-library/python/materializeview.py | 39 +++++------- .../client-library/python/requirements.txt | 4 +- .../client-library/python/test_database.py | 41 +++++++----- .../client-library/python/test_sqlalchemy.py | 62 +++++++++++++++++++ 7 files changed, 129 insertions(+), 70 deletions(-) delete 
mode 100644 integration_tests/client-library/python/init.py create mode 100644 integration_tests/client-library/python/test_sqlalchemy.py diff --git a/integration_tests/client-library/python/client.py b/integration_tests/client-library/python/client.py index dbd2ad9b4581c..ff98bfb1638f0 100644 --- a/integration_tests/client-library/python/client.py +++ b/integration_tests/client-library/python/client.py @@ -1,9 +1,10 @@ import psycopg2 -class client: - def __init__(self, host, port,database, user, password): + +class Client: + def __init__(self, host, port, database, user, password): self.host = host - self.port=port + self.port = port self.database = database self.user = user self.password = password diff --git a/integration_tests/client-library/python/crud.py b/integration_tests/client-library/python/crud.py index f90233c24e336..a4f182e488856 100644 --- a/integration_tests/client-library/python/crud.py +++ b/integration_tests/client-library/python/crud.py @@ -1,14 +1,12 @@ import psycopg2 -from client import client +from client import Client + + +class SampleTableCrud: + # Represents the table `sample_table_py`. -class crud: def __init__(self, host, port, database, user, password): - self.host = host - self.database = database - self.user = user - self.password = password - self.connection = None - self.port=port + self.client = Client(host, port, database, user, password) def create_table(self): create_table_query = """ @@ -19,10 +17,9 @@ def create_table(self): ); """ try: - databaseconnection = client(self.host, self.port,self.database, self.user, self.password) - cursor=databaseconnection.connect() + cursor = self.client.connect() cursor.execute(create_table_query) - databaseconnection.connection.commit() + self.client.connection.commit() print("Table created successfully.") except psycopg2.Error as e: print("Table creation failed: ", str(e)) @@ -33,10 +30,9 @@ def insert_data(self, name, age, salary): VALUES (%s, %s,%s); """ try: - databaseconnection = client(self.host, self.port,self.database, self.user, self.password) - cursor=databaseconnection.connect() + cursor = self.client.connect() cursor.execute(insert_data_query, (name, age, salary)) - databaseconnection.connection.commit() + self.client.connection.commit() print("Data inserted successfully.") except psycopg2.Error as e: print("Data insertion failed: ", str(e)) @@ -48,10 +44,9 @@ def update_data(self, name, salary): WHERE name=%s; """ try: - databaseconnection = client(self.host, self.port,self.database, self.user, self.password) - cursor=databaseconnection.connect() + cursor = self.client.connect() cursor.execute(update_data_query, (salary, name)) - databaseconnection.connection.commit() + self.client.connection.commit() print("Data updated successfully.") except psycopg2.Error as e: print("Data updation failed: ", str(e)) @@ -61,10 +56,9 @@ def delete_data(self, name): DELETE FROM sample_table_py WHERE name='%s'; """ try: - databaseconnection = client(self.host, self.port,self.database, self.user, self.password) - cursor=databaseconnection.connect() + cursor = self.client.connect() cursor.execute(insert_data_query, (name,)) - databaseconnection.connection.commit() + self.client.connection.commit() print("Data deletion successfully.") except psycopg2.Error as e: print("Data deletion failed: ", str(e)) @@ -74,17 +68,9 @@ def table_drop(self): DROP TABLE sample_table_py; """ try: - databaseconnection = client(self.host, self.port,self.database, self.user, self.password) - cursor=databaseconnection.connect() + cursor = 
self.client.connect() cursor.execute(reset_query) - databaseconnection.connection.commit() + self.client.connection.commit() print("Table Dropped successfully") except psycopg2.Error as e: print("Table Drop Failed: ", str(e)) - -crud_ins=crud(host="risingwave-standalone", - port="4566", - database="dev", - user="root", - password="") -crud_ins.create_table() diff --git a/integration_tests/client-library/python/init.py b/integration_tests/client-library/python/init.py deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/integration_tests/client-library/python/materializeview.py b/integration_tests/client-library/python/materializeview.py index a5cf7fbac5f68..fd92a4d850392 100644 --- a/integration_tests/client-library/python/materializeview.py +++ b/integration_tests/client-library/python/materializeview.py @@ -1,35 +1,31 @@ import psycopg2 -from client import client -from crud import crud +from client import Client +from crud import SampleTableCrud + +# Represents the materialized view `average_salary_view_py`. class MaterializeView: def __init__(self, host, port, database, user, password): - self.host = host - self.database = database - self.user = user - self.password = password - self.connection = None - self.port=port + self.client = Client(host, port, database, user, password) + self.crud = SampleTableCrud(host, port, database, user, password) def create_mv(self): - crud_ins = crud(self.host, self.port, self.database, self.user, self.password) - crud_ins.create_table() - crud_ins.insert_data("John",25,10000) - crud_ins.insert_data("Shaun",25,11000) - crud_ins.insert_data("Caul",25,14000) - crud_ins.insert_data("Mantis",28,18000) - crud_ins.insert_data("Tony",28,19000) - mv_query=""" + self.crud.create_table() + self.crud.insert_data("John", 25, 10000) + self.crud.insert_data("Shaun", 25, 11000) + self.crud.insert_data("Caul", 25, 14000) + self.crud.insert_data("Mantis", 28, 18000) + self.crud.insert_data("Tony", 28, 19000) + mv_query = """ CREATE MATERIALIZED VIEW average_salary_view_py AS SELECT age, AVG(salary) AS average_salary FROM sample_table_py GROUP BY age; """ try: - databaseconnection = client(self.host, self.port,self.database, self.user, self.password) - cursor=databaseconnection.connect() + cursor = self.client.connect() cursor.execute(mv_query) - databaseconnection.connection.commit() + self.client.connection.commit() print("MV created successfully.") except psycopg2.Error as e: print("MV creation failed: ", str(e)) @@ -37,10 +33,9 @@ def create_mv(self): def drop_mv(self): mv_drop_query = "DROP materialized view average_salary_view_py;" try: - databaseconnection = client(self.host, self.port,self.database, self.user, self.password) - cursor=databaseconnection.connect() + cursor = self.client.connect() cursor.execute(mv_drop_query) - databaseconnection.connection.commit() + self.client.connection.commit() print("MV dropped successfully.") except psycopg2.Error as e: print("MV drop failed: ", str(e)) diff --git a/integration_tests/client-library/python/requirements.txt b/integration_tests/client-library/python/requirements.txt index 8391ca9c83332..96ffe07408f9b 100644 --- a/integration_tests/client-library/python/requirements.txt +++ b/integration_tests/client-library/python/requirements.txt @@ -1,2 +1,4 @@ psycopg2-binary -pytest \ No newline at end of file +pytest +sqlalchemy-risingwave +SQLAlchemy==1.4.51 diff --git a/integration_tests/client-library/python/test_database.py b/integration_tests/client-library/python/test_database.py index 
78ceea33f8373..e874bf39ebeac 100644 --- a/integration_tests/client-library/python/test_database.py +++ b/integration_tests/client-library/python/test_database.py @@ -1,12 +1,12 @@ import pytest -from client import client -from crud import crud +from client import Client +from crud import SampleTableCrud from materializeview import MaterializeView @pytest.fixture def db_connection(): - db = client( + db = Client( host="risingwave-standalone", port="4566", database="dev", @@ -19,7 +19,7 @@ def db_connection(): @pytest.fixture def crud_instance(): - return crud( + return SampleTableCrud( host="risingwave-standalone", port="4566", database="dev", @@ -48,19 +48,22 @@ def test_disconnect(db_connection): db_connection.disconnect() assert db_connection.connection is None + def test_table_creation(crud_instance, db_connection): cursor = db_connection.connect() cursor.execute("SET TRANSACTION READ WRITE;") crud_instance.create_table() cursor.execute("FLUSH;") - cursor.execute("SELECT table_name FROM information_schema.tables WHERE table_name = 'sample_table_py';") + cursor.execute( + "SELECT table_name FROM information_schema.tables WHERE table_name = 'sample_table_py';") result = cursor.fetchone()[0] cursor.close() assert result == 'sample_table_py' + def test_data_insertion(crud_instance, db_connection): - crud_instance.insert_data("John Doe", 25,10000) + crud_instance.insert_data("John Doe", 25, 10000) cursor = db_connection.connect() cursor.execute("FLUSH;") @@ -71,6 +74,7 @@ def test_data_insertion(crud_instance, db_connection): assert result == 1 + def test_data_updation(crud_instance, db_connection): crud_instance.update_data("John Doe", 12000) @@ -82,56 +86,65 @@ def test_data_updation(crud_instance, db_connection): cursor.close() assert result == 12000 + def test_data_deletion(crud_instance, db_connection): crud_instance.delete_data("John Doe") cursor = db_connection.connect() cursor.execute("FLUSH;") - cursor.execute("SELECT EXISTS (SELECT 1 FROM sample_table_py WHERE name = 'John Doe');") + cursor.execute( + "SELECT EXISTS (SELECT 1 FROM sample_table_py WHERE name = 'John Doe');") result = cursor.fetchone() result = result[0] cursor.close() assert result == True + def test_table_drop(crud_instance, db_connection): crud_instance.table_drop() cursor = db_connection.connect() cursor.execute("FLUSH;") - cursor.execute("SELECT EXISTS (SELECT 1 FROM information_schema.tables WHERE table_name = 'sample_table_py');") + cursor.execute( + "SELECT EXISTS (SELECT 1 FROM information_schema.tables WHERE table_name = 'sample_table_py');") result = cursor.fetchone() result = result[0] cursor.close() assert result is False -def test_mv_creation(mv_instance,db_connection): + +def test_mv_creation(mv_instance, db_connection): mv_instance.create_mv() cursor = db_connection.connect() cursor.execute("FLUSH;") - cursor.execute("SELECT EXISTS (SELECT 1 FROM pg_matviews WHERE matviewname = 'average_salary_view_py');") + cursor.execute( + "SELECT EXISTS (SELECT 1 FROM pg_matviews WHERE matviewname = 'average_salary_view_py');") result = cursor.fetchone()[0] cursor.close() assert result is True -def test_mv_updation(db_connection,crud_instance): + +def test_mv_updation(db_connection, crud_instance): crud_instance.insert_data("Stark", 25, 13000) cursor = db_connection.connect() cursor.execute("FLUSH;") - cursor.execute("SELECT average_salary FROM average_salary_view_py WHERE age=25;") + cursor.execute( + "SELECT average_salary FROM average_salary_view_py WHERE age=25;") result = cursor.fetchone()[0] cursor.close() # 
assert result == 11250 assert result == 12000 -def test_mv_drop(crud_instance,mv_instance,db_connection): +def test_mv_drop(crud_instance, mv_instance, db_connection): mv_instance.drop_mv() crud_instance.table_drop() cursor = db_connection.connect() cursor.execute("FLUSH;") - cursor.execute("SELECT EXISTS (SELECT 1 FROM pg_matviews WHERE matviewname = 'average_salary_view_py');") + cursor.execute( + "SELECT EXISTS (SELECT 1 FROM pg_matviews WHERE matviewname = 'average_salary_view_py');") result = cursor.fetchone() result = result[0] cursor.close() diff --git a/integration_tests/client-library/python/test_sqlalchemy.py b/integration_tests/client-library/python/test_sqlalchemy.py new file mode 100644 index 0000000000000..d71c8530f54b0 --- /dev/null +++ b/integration_tests/client-library/python/test_sqlalchemy.py @@ -0,0 +1,62 @@ +from sqlalchemy import Column, BigInteger, Integer, String, create_engine +from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy.orm import sessionmaker +import pytest +# Create a base class for declarative class definitions +Base = declarative_base() + + +class User(Base): + # Define a simple User class as an example + __tablename__ = 'users' + + id = Column('id', BigInteger, primary_key=True) + name = Column('name', String) + age = Column('age', Integer) + + +# Pytest fixture to create and destroy the database session +@pytest.fixture +def db_session(): + DB_URI = 'risingwave+psycopg2://root@risingwave-standalone:4566/dev' + # Create an SQLAlchemy engine to manage connections to the database + engine = create_engine(DB_URI) + + # The automatically created table is incorrect. The BigInteger will be translated into BIGSERIAL somehow, which is not supported. + create_table = """ + CREATE TABLE IF NOT EXISTS users ( + id BIGINT PRIMARY KEY, + name VARCHAR, + age INTEGER + ) + """ + with engine.connect() as conn: + conn.execute(create_table) + conn.execute('SET RW_IMPLICIT_FLUSH=true') + + Session = sessionmaker(autocommit=False, autoflush=False, bind=engine) + session = Session() + yield session + session.close() + + Base.metadata.drop_all(engine) + + +# Pytest test functions to perform CRUD operations +def test_create_user(db_session): + new_user = User(id=1, name='John Doe', age=30) + db_session.add(new_user) + db_session.commit() + assert new_user.id is not None + + all_users = db_session.query(User).all() + assert len(all_users) > 0 + + +def test_delete_user(db_session): + user_to_delete = db_session.query(User).filter_by(name='John Doe').first() + if user_to_delete: + db_session.delete(user_to_delete) + db_session.commit() + deleted_user = db_session.query(User).get(user_to_delete.id) + assert deleted_user is None From 240416f43bae35c11d1d8a6d9e081fdd42fcc6bb Mon Sep 17 00:00:00 2001 From: Bugen Zhao Date: Mon, 15 Jan 2024 12:00:17 +0800 Subject: [PATCH 13/71] feat: enable or disable tracing with system params (#14528) Signed-off-by: Bugen Zhao Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- proto/meta.proto | 1 + src/common/src/config.rs | 7 ++- src/common/src/system_param/common.rs | 48 +++++++++++++++++++ src/common/src/system_param/local_manager.rs | 19 +++++++- src/common/src/system_param/mod.rs | 3 ++ src/common/src/system_param/reader.rs | 12 ++++- src/common/src/util/tracing.rs | 2 + src/common/src/util/tracing/layer.rs | 32 +++++++++++++ src/config/example.toml | 1 + src/meta/src/controller/system_param.rs | 11 ++++- src/meta/src/manager/system_param/mod.rs | 11 ++++- 
src/utils/runtime/src/logger.rs | 50 ++++++++++++++++---- 12 files changed, 180 insertions(+), 17 deletions(-) create mode 100644 src/common/src/system_param/common.rs create mode 100644 src/common/src/util/tracing/layer.rs diff --git a/proto/meta.proto b/proto/meta.proto index 0ce5fc887fe9d..3e41032b33bfa 100644 --- a/proto/meta.proto +++ b/proto/meta.proto @@ -552,6 +552,7 @@ message SystemParams { optional uint32 max_concurrent_creating_streaming_jobs = 12; optional bool pause_on_next_bootstrap = 13; optional string wasm_storage_url = 14; + optional bool enable_tracing = 15; } message GetSystemParamsRequest {} diff --git a/src/common/src/config.rs b/src/common/src/config.rs index 8655c5fcbdf6d..78cb7146370de 100644 --- a/src/common/src/config.rs +++ b/src/common/src/config.rs @@ -902,6 +902,10 @@ pub struct SystemConfig { #[serde(default = "default::system::wasm_storage_url")] pub wasm_storage_url: Option, + + /// Whether to enable distributed tracing. + #[serde(default = "default::system::enable_tracing")] + pub enable_tracing: Option, } /// The subsections `[storage.object_store]`. @@ -955,8 +959,9 @@ impl SystemConfig { backup_storage_directory: self.backup_storage_directory, max_concurrent_creating_streaming_jobs: self.max_concurrent_creating_streaming_jobs, pause_on_next_bootstrap: self.pause_on_next_bootstrap, - telemetry_enabled: None, // deprecated wasm_storage_url: self.wasm_storage_url, + enable_tracing: self.enable_tracing, + telemetry_enabled: None, // deprecated } } } diff --git a/src/common/src/system_param/common.rs b/src/common/src/system_param/common.rs new file mode 100644 index 0000000000000..b8dcf825e2dda --- /dev/null +++ b/src/common/src/system_param/common.rs @@ -0,0 +1,48 @@ +// Copyright 2024 RisingWave Labs +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::sync::Mutex; + +use super::reader::SystemParamsReader; +use crate::util::tracing::layer::toggle_otel_layer; + +/// Node-independent handler for system parameter changes. +/// +/// Currently, it is only used to enable or disable the distributed tracing layer. +pub struct CommonHandler { + last_params: Mutex>, +} + +impl CommonHandler { + /// Create a new handler with the initial parameters. + pub fn new(initial: SystemParamsReader) -> Self { + let this = Self { + last_params: None.into(), + }; + this.handle_change(initial); + this + } + + /// Handle the change of system parameters. + // TODO: directly call this method with the difference of old and new params. 
+ pub fn handle_change(&self, new_params: SystemParamsReader) { + let mut last_params = self.last_params.lock().unwrap(); + + if last_params.as_ref().map(|p| p.enable_tracing()) != Some(new_params.enable_tracing()) { + toggle_otel_layer(new_params.enable_tracing()); + } + + last_params.replace(new_params); + } +} diff --git a/src/common/src/system_param/local_manager.rs b/src/common/src/system_param/local_manager.rs index 7103ed6737104..312c5577a0f81 100644 --- a/src/common/src/system_param/local_manager.rs +++ b/src/common/src/system_param/local_manager.rs @@ -19,6 +19,7 @@ use arc_swap::ArcSwap; use risingwave_pb::meta::SystemParams; use tokio::sync::watch::{channel, Receiver, Sender}; +use super::common::CommonHandler; use super::reader::SystemParamsReader; use super::system_params_for_test; @@ -40,9 +41,23 @@ pub struct LocalSystemParamsManager { } impl LocalSystemParamsManager { - pub fn new(params: SystemParamsReader) -> Self { - let params = Arc::new(ArcSwap::from_pointee(params)); + pub fn new(initial_params: SystemParamsReader) -> Self { + let params = Arc::new(ArcSwap::from_pointee(initial_params.clone())); let (tx, _) = channel(params.clone()); + + // Spawn a task to run the common handler. + tokio::spawn({ + let mut rx = tx.subscribe(); + async move { + let handler = CommonHandler::new(initial_params); + + while rx.changed().await.is_ok() { + let new_params = (**rx.borrow_and_update().load()).clone(); + handler.handle_change(new_params); + } + } + }); + Self { params, tx } } diff --git a/src/common/src/system_param/mod.rs b/src/common/src/system_param/mod.rs index 366cc61d2dd53..cffa7a4564a5f 100644 --- a/src/common/src/system_param/mod.rs +++ b/src/common/src/system_param/mod.rs @@ -20,6 +20,7 @@ //! - Add a new entry to `for_all_undeprecated_params` in this file. //! - Add a new method to [`reader::SystemParamsReader`]. +pub mod common; pub mod local_manager; pub mod reader; @@ -56,6 +57,7 @@ macro_rules! for_all_params { { max_concurrent_creating_streaming_jobs, u32, Some(1_u32), true }, { pause_on_next_bootstrap, bool, Some(false), true }, { wasm_storage_url, String, Some("fs://.risingwave/data".to_string()), false }, + { enable_tracing, bool, Some(false), true }, } }; } @@ -359,6 +361,7 @@ mod tests { (MAX_CONCURRENT_CREATING_STREAMING_JOBS_KEY, "1"), (PAUSE_ON_NEXT_BOOTSTRAP_KEY, "false"), (WASM_STORAGE_URL_KEY, "a"), + (ENABLE_TRACING_KEY, "true"), ("a_deprecated_param", "foo"), ]; diff --git a/src/common/src/system_param/reader.rs b/src/common/src/system_param/reader.rs index 0059974203c6d..24ecd83f5f061 100644 --- a/src/common/src/system_param/reader.rs +++ b/src/common/src/system_param/reader.rs @@ -14,7 +14,7 @@ use risingwave_pb::meta::PbSystemParams; -use super::system_params_to_kv; +use super::{default, system_params_to_kv}; /// A wrapper for [`risingwave_pb::meta::SystemParams`] for 2 purposes: /// - Avoid misuse of deprecated fields by hiding their getters. 
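
A minimal, self-contained sketch of the subscribe-and-react pattern that the `LocalSystemParamsManager` hunk above wires up: a `tokio::sync::watch` sender publishes parameter snapshots and a spawned task reacts to each change. The `Params` struct and the `println!` handler below are placeholders for `SystemParamsReader` and `CommonHandler::handle_change`; they are not part of this patch.

```rust
use tokio::sync::watch;

#[derive(Clone, Debug)]
struct Params {
    enable_tracing: bool,
}

#[tokio::main]
async fn main() {
    // The sender holds the current snapshot; receivers observe updates.
    let (tx, mut rx) = watch::channel(Params { enable_tracing: false });

    let worker = tokio::spawn(async move {
        // `changed()` resolves whenever the sender publishes a new value.
        while rx.changed().await.is_ok() {
            let new_params = rx.borrow_and_update().clone();
            println!("system params changed: {:?}", new_params);
        }
    });

    tx.send(Params { enable_tracing: true }).unwrap();
    drop(tx); // closing the sender lets the worker loop exit
    worker.await.unwrap();
}
```

In the patch itself, the reaction step is `CommonHandler::handle_change`, which compares the previous `enable_tracing` value with the new one before toggling the tracing layer.
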
@@ -77,7 +77,15 @@ impl SystemParamsReader { } pub fn pause_on_next_bootstrap(&self) -> bool { - self.prost.pause_on_next_bootstrap.unwrap_or(false) + self.prost + .pause_on_next_bootstrap + .unwrap_or_else(|| default::pause_on_next_bootstrap().unwrap()) + } + + pub fn enable_tracing(&self) -> bool { + self.prost + .enable_tracing + .unwrap_or_else(|| default::enable_tracing().unwrap()) } pub fn wasm_storage_url(&self) -> &str { diff --git a/src/common/src/util/tracing.rs b/src/common/src/util/tracing.rs index f87a5efd4baef..e7da6e8e7d580 100644 --- a/src/common/src/util/tracing.rs +++ b/src/common/src/util/tracing.rs @@ -12,6 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. +pub mod layer; + use std::collections::HashMap; use std::pin::Pin; use std::task::{Context, Poll}; diff --git a/src/common/src/util/tracing/layer.rs b/src/common/src/util/tracing/layer.rs new file mode 100644 index 0000000000000..a5268a55dc90e --- /dev/null +++ b/src/common/src/util/tracing/layer.rs @@ -0,0 +1,32 @@ +// Copyright 2024 RisingWave Labs +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::sync::OnceLock; + +static TOGGLE_OTEL_LAYER: OnceLock> = OnceLock::new(); + +/// Set the function to toggle the opentelemetry tracing layer. Panics if called twice. +pub fn set_toggle_otel_layer_fn(f: impl Fn(bool) + Sync + Send + 'static) { + TOGGLE_OTEL_LAYER + .set(Box::new(f)) + .ok() + .expect("toggle otel layer fn set twice"); +} + +/// Toggle the opentelemetry tracing layer. +pub fn toggle_otel_layer(enabled: bool) { + if let Some(f) = TOGGLE_OTEL_LAYER.get() { + f(enabled); + } +} diff --git a/src/config/example.toml b/src/config/example.toml index 15738219ae08c..b2eef323c2d00 100644 --- a/src/config/example.toml +++ b/src/config/example.toml @@ -189,3 +189,4 @@ backup_storage_directory = "backup" max_concurrent_creating_streaming_jobs = 1 pause_on_next_bootstrap = false wasm_storage_url = "fs://.risingwave/data" +enable_tracing = false diff --git a/src/meta/src/controller/system_param.rs b/src/meta/src/controller/system_param.rs index c37ce5a626f13..fbcce97a97d9b 100644 --- a/src/meta/src/controller/system_param.rs +++ b/src/meta/src/controller/system_param.rs @@ -16,6 +16,7 @@ use std::sync::Arc; use std::time::Duration; use anyhow::anyhow; +use risingwave_common::system_param::common::CommonHandler; use risingwave_common::system_param::reader::SystemParamsReader; use risingwave_common::system_param::{ check_missing_params, derive_missing_fields, set_system_param, @@ -44,6 +45,8 @@ pub struct SystemParamsController { notification_manager: NotificationManagerRef, // Cached parameters. params: RwLock, + /// Common handler for system params. + common_handler: CommonHandler, } /// Derive system params from db models. 
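
The toggle registered through `set_toggle_otel_layer_fn` only has an effect because the logger (see the `logger.rs` hunk later in this patch) wraps its OpenTelemetry filter in `tracing_subscriber::reload`, so the filter can be swapped at runtime without rebuilding the subscriber. A minimal sketch of that filter-swapping mechanism follows; the targets and levels are illustrative and not taken from the patch.

```rust
use tracing::level_filters::LevelFilter;
use tracing_subscriber::{filter, fmt, prelude::*, reload};

fn main() {
    // Start with an empty `Targets` filter (nothing enabled), mirroring
    // the `disabled_filter()` helper introduced in logger.rs.
    let (reload_filter, reload_handle) = reload::Layer::new(filter::Targets::new());

    tracing_subscriber::registry()
        .with(fmt::layer().with_filter(reload_filter))
        .init();

    tracing::info!("dropped: the empty filter lets nothing through");

    // Later, e.g. when the `enable_tracing` system param flips to true,
    // swap the filter in place through the reload handle.
    reload_handle
        .modify(|f| *f = filter::Targets::new().with_default(LevelFilter::INFO))
        .expect("failed to reload filter");

    tracing::info!("emitted: the filter now allows INFO and above");
}
```

Swapping only the filter keeps the already-installed exporter pipeline intact, which is why enabling or disabling tracing via a system parameter does not require restarting the node.
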
@@ -146,7 +149,8 @@ impl SystemParamsController { let ctl = Self { db, notification_manager, - params: RwLock::new(params), + params: RwLock::new(params.clone()), + common_handler: CommonHandler::new(params.into()), }; // flush to db. ctl.flush_params().await?; @@ -196,6 +200,11 @@ impl SystemParamsController { param.update(&self.db).await?; *params_guard = params.clone(); + // TODO: check if the parameter is actually changed. + + // Run common handler. + self.common_handler.handle_change(params.clone().into()); + // Sync params to other managers on the meta node only once, since it's infallible. self.notification_manager .notify_local_subscribers(LocalNotification::SystemParamsChange(params.clone().into())) diff --git a/src/meta/src/manager/system_param/mod.rs b/src/meta/src/manager/system_param/mod.rs index 9d5e574efa8b6..14d0e311a2d89 100644 --- a/src/meta/src/manager/system_param/mod.rs +++ b/src/meta/src/manager/system_param/mod.rs @@ -19,6 +19,7 @@ use std::sync::Arc; use std::time::Duration; use anyhow::anyhow; +use risingwave_common::system_param::common::CommonHandler; use risingwave_common::system_param::reader::SystemParamsReader; use risingwave_common::system_param::{check_missing_params, set_system_param}; use risingwave_common::{for_all_params, key_of}; @@ -43,6 +44,8 @@ pub struct SystemParamsManager { notification_manager: NotificationManagerRef, // Cached parameters. params: RwLock, + /// Common handler for system params. + common_handler: CommonHandler, } impl SystemParamsManager { @@ -69,7 +72,8 @@ impl SystemParamsManager { Ok(Self { meta_store, notification_manager, - params: RwLock::new(params), + params: RwLock::new(params.clone()), + common_handler: CommonHandler::new(params.into()), }) } @@ -94,6 +98,11 @@ impl SystemParamsManager { mem_txn.commit(); + // TODO: check if the parameter is actually changed. + + // Run common handler. + self.common_handler.handle_change(params.clone().into()); + // Sync params to other managers on the meta node only once, since it's infallible. self.notification_manager .notify_local_subscribers(super::LocalNotification::SystemParamsChange( diff --git a/src/utils/runtime/src/logger.rs b/src/utils/runtime/src/logger.rs index e636c3a72de51..cb27840d7530f 100644 --- a/src/utils/runtime/src/logger.rs +++ b/src/utils/runtime/src/logger.rs @@ -18,8 +18,8 @@ use std::path::PathBuf; use either::Either; use risingwave_common::metrics::MetricsLayer; use risingwave_common::util::deployment::Deployment; -use risingwave_common::util::env_var::env_var_is_true; use risingwave_common::util::query_log::*; +use risingwave_common::util::tracing::layer::set_toggle_otel_layer_fn; use thiserror_ext::AsReport; use tracing::level_filters::LevelFilter as Level; use tracing_subscriber::filter::{FilterFn, Targets}; @@ -28,7 +28,7 @@ use tracing_subscriber::fmt::time::OffsetTime; use tracing_subscriber::fmt::FormatFields; use tracing_subscriber::layer::SubscriberExt; use tracing_subscriber::prelude::*; -use tracing_subscriber::{filter, EnvFilter}; +use tracing_subscriber::{filter, reload, EnvFilter}; pub struct LoggerSettings { /// The name of the service. Used to identify the service in distributed tracing. @@ -60,15 +60,12 @@ impl LoggerSettings { /// /// If env var `RW_TRACING_ENDPOINT` is not set, the meta address will be used /// as the default tracing endpoint, which means that the embedded tracing - /// collector will be used. This can be disabled by setting env var - /// `RW_DISABLE_EMBEDDED_TRACING` to `true`. + /// collector will be used. 
pub fn from_opts(opts: &O) -> Self { let mut settings = Self::new(O::name()); if settings.tracing_endpoint.is_none() // no explicit endpoint - && !env_var_is_true("RW_DISABLE_EMBEDDED_TRACING") // not disabled by env var - && let Some(addr) = opts.meta_addr().exactly_one() // meta address is valid - && !Deployment::current().is_ci() - // not in CI + && let Some(addr) = opts.meta_addr().exactly_one() + // meta address is valid { // Use embedded collector in the meta service. // TODO: when there's multiple meta nodes for high availability, we may send @@ -133,6 +130,11 @@ impl LoggerSettings { } } +/// Create a filter that disables all events or spans. +fn disabled_filter() -> filter::Targets { + filter::Targets::new() +} + /// Init logger for RisingWave binaries. /// /// ## Environment variables to configure logger dynamically @@ -388,7 +390,7 @@ pub fn init_risingwave_logger(settings: LoggerSettings) { // Tracing layer #[cfg(not(madsim))] if let Some(endpoint) = settings.tracing_endpoint { - println!("tracing enabled, exported to `{endpoint}`"); + println!("opentelemetry tracing will be exported to `{endpoint}` if enabled"); use opentelemetry::{sdk, KeyValue}; use opentelemetry_otlp::WithExportConfig; @@ -437,9 +439,37 @@ pub fn init_risingwave_logger(settings: LoggerSettings) { .unwrap() }; + // Disable by filtering out all events or spans by default. + // + // It'll be enabled with `toggle_otel_layer` based on the system parameter `enable_tracing` later. + let (reload_filter, reload_handle) = reload::Layer::new(disabled_filter()); + + set_toggle_otel_layer_fn(move |enabled: bool| { + let result = reload_handle.modify(|f| { + *f = if enabled { + default_filter.clone() + } else { + disabled_filter() + } + }); + + match result { + Ok(_) => tracing::info!( + "opentelemetry tracing {}", + if enabled { "enabled" } else { "disabled" }, + ), + + Err(error) => tracing::error!( + error = %error.as_report(), + "failed to {} opentelemetry tracing", + if enabled { "enable" } else { "disable" }, + ), + } + }); + let layer = tracing_opentelemetry::layer() .with_tracer(otel_tracer) - .with_filter(default_filter); + .with_filter(reload_filter); layers.push(layer.boxed()); } From d65c151c50e88220b6bf2d9ed01411d8da6175ec Mon Sep 17 00:00:00 2001 From: Bugen Zhao Date: Mon, 15 Jan 2024 12:11:34 +0800 Subject: [PATCH 14/71] feat(dashboard): improvements on the relation dep graph (#14505) Signed-off-by: Bugen Zhao --- dashboard/components/CatalogModal.tsx | 87 +++++++++++ .../components/FragmentDependencyGraph.tsx | 84 +++++----- dashboard/components/FragmentGraph.tsx | 13 +- .../components/RelationDependencyGraph.tsx | 128 +++++++++++----- dashboard/components/Relations.tsx | 49 +----- dashboard/lib/layout.ts | 145 ++++++++---------- dashboard/pages/api/streaming.ts | 20 ++- dashboard/pages/dependency_graph.tsx | 32 ++-- dashboard/pages/fragment_graph.tsx | 4 +- 9 files changed, 341 insertions(+), 221 deletions(-) create mode 100644 dashboard/components/CatalogModal.tsx diff --git a/dashboard/components/CatalogModal.tsx b/dashboard/components/CatalogModal.tsx new file mode 100644 index 0000000000000..cf6a2f8cc9e0d --- /dev/null +++ b/dashboard/components/CatalogModal.tsx @@ -0,0 +1,87 @@ +/* + * Copyright 2024 RisingWave Labs + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +import { + Button, + Modal, + ModalBody, + ModalCloseButton, + ModalContent, + ModalFooter, + ModalHeader, + ModalOverlay, +} from "@chakra-ui/react" + +import Link from "next/link" +import { parseAsInteger, useQueryState } from "nuqs" +import { + Relation, + relationIsStreamingJob, + relationTypeTitleCase, +} from "../pages/api/streaming" +import { ReactJson } from "./Relations" + +export function useCatalogModal(relationList: Relation[] | undefined) { + const [modalId, setModalId] = useQueryState("modalId", parseAsInteger) + const modalData = relationList?.find((r) => r.id === modalId) + + return [modalData, setModalId] as const +} + +export function CatalogModal({ + modalData, + onClose, +}: { + modalData: Relation | undefined + onClose: () => void +}) { + return ( + + + + + Catalog of {modalData && relationTypeTitleCase(modalData)}{" "} + {modalData?.id} - {modalData?.name} + + + + {modalData && ( + + )} + + + + {modalData && relationIsStreamingJob(modalData) && ( + + )} + + + + + ) +} diff --git a/dashboard/components/FragmentDependencyGraph.tsx b/dashboard/components/FragmentDependencyGraph.tsx index 553c40ec53f92..7b3417507efcf 100644 --- a/dashboard/components/FragmentDependencyGraph.tsx +++ b/dashboard/components/FragmentDependencyGraph.tsx @@ -3,18 +3,18 @@ import * as d3 from "d3" import { Dag, DagLink, DagNode, zherebko } from "d3-dag" import { cloneDeep } from "lodash" import { useCallback, useEffect, useRef, useState } from "react" -import { Position } from "../lib/layout" +import { Enter, FragmentBox, Position } from "../lib/layout" const nodeRadius = 5 const edgeRadius = 12 export default function FragmentDependencyGraph({ - mvDependency, + fragmentDependency, svgWidth, selectedId, onSelectedIdChange, }: { - mvDependency: Dag + fragmentDependency: Dag svgWidth: number selectedId: string | undefined onSelectedIdChange: (id: string) => void | undefined @@ -24,21 +24,21 @@ export default function FragmentDependencyGraph({ const MARGIN_X = 10 const MARGIN_Y = 2 - const mvDependencyDagCallback = useCallback(() => { + const fragmentDependencyDagCallback = useCallback(() => { const layout = zherebko().nodeSize([ nodeRadius * 2, (nodeRadius + edgeRadius) * 2, nodeRadius, ]) - const dag = cloneDeep(mvDependency) + const dag = cloneDeep(fragmentDependency) const { width, height } = layout(dag) return { width, height, dag } - }, [mvDependency]) + }, [fragmentDependency]) - const mvDependencyDag = mvDependencyDagCallback() + const fragmentDependencyDag = fragmentDependencyDagCallback() useEffect(() => { - const { width, height, dag } = mvDependencyDag + const { width, height, dag } = fragmentDependencyDag // This code only handles rendering @@ -53,25 +53,27 @@ export default function FragmentDependencyGraph({ .x(({ x }) => x + MARGIN_X) .y(({ y }) => y) - const isSelected = (d: any) => d.data.id === selectedId + const isSelected = (d: DagNode) => d.data.id === selectedId const edgeSelection = svgSelection .select(".edges") - .selectAll(".edge") + .selectAll(".edge") .data(dag.links()) - const applyEdge = (sel: any) => + type EdgeSelection = typeof 
edgeSelection + + const applyEdge = (sel: EdgeSelection) => sel .attr("d", ({ points }: DagLink) => line(points)) .attr("fill", "none") - .attr("stroke-width", (d: any) => + .attr("stroke-width", (d) => isSelected(d.source) || isSelected(d.target) ? 2 : 1 ) - .attr("stroke", (d: any) => + .attr("stroke", (d) => isSelected(d.source) || isSelected(d.target) ? theme.colors.blue["500"] : theme.colors.gray["300"] ) - const createEdge = (sel: any) => + const createEdge = (sel: Enter) => sel.append("path").attr("class", "edge").call(applyEdge) edgeSelection.exit().remove() edgeSelection.enter().call(createEdge) @@ -80,19 +82,18 @@ export default function FragmentDependencyGraph({ // Select nodes const nodeSelection = svgSelection .select(".nodes") - .selectAll(".node") + .selectAll(".node") .data(dag.descendants()) - const applyNode = (sel: any) => + type NodeSelection = typeof nodeSelection + + const applyNode = (sel: NodeSelection) => sel - .attr( - "transform", - ({ x, y }: Position) => `translate(${x + MARGIN_X}, ${y})` - ) - .attr("fill", (d: any) => + .attr("transform", (d) => `translate(${d.x! + MARGIN_X}, ${d.y})`) + .attr("fill", (d) => isSelected(d) ? theme.colors.blue["500"] : theme.colors.gray["500"] ) - const createNode = (sel: any) => + const createNode = (sel: Enter) => sel .append("circle") .attr("class", "node") @@ -105,22 +106,23 @@ export default function FragmentDependencyGraph({ // Add text to nodes const labelSelection = svgSelection .select(".labels") - .selectAll(".label") + .selectAll(".label") .data(dag.descendants()) + type LabelSelection = typeof labelSelection - const applyLabel = (sel: any) => + const applyLabel = (sel: LabelSelection) => sel - .text((d: any) => d.data.name) + .text((d) => d.data.name) .attr("x", svgWidth - MARGIN_X) .attr("font-family", "inherit") .attr("text-anchor", "end") .attr("alignment-baseline", "middle") - .attr("y", (d: any) => d.y) - .attr("fill", (d: any) => + .attr("y", (d) => d.y!) + .attr("fill", (d) => isSelected(d) ? theme.colors.black["500"] : theme.colors.gray["500"] ) .attr("font-weight", "600") - const createLabel = (sel: any) => + const createLabel = (sel: Enter) => sel.append("text").attr("class", "label").call(applyLabel) labelSelection.exit().remove() labelSelection.enter().call(createLabel) @@ -129,11 +131,12 @@ export default function FragmentDependencyGraph({ // Add overlays const overlaySelection = svgSelection .select(".overlays") - .selectAll(".overlay") + .selectAll(".overlay") .data(dag.descendants()) + type OverlaySelection = typeof overlaySelection const STROKE_WIDTH = 3 - const applyOverlay = (sel: any) => + const applyOverlay = (sel: OverlaySelection) => sel .attr("x", STROKE_WIDTH) .attr( @@ -143,20 +146,13 @@ export default function FragmentDependencyGraph({ .attr("width", svgWidth - STROKE_WIDTH * 2) .attr( "y", - (d: any) => d.y - nodeRadius - edgeRadius + MARGIN_Y + STROKE_WIDTH + (d) => d.y! 
- nodeRadius - edgeRadius + MARGIN_Y + STROKE_WIDTH ) .attr("rx", 5) .attr("fill", theme.colors.gray["500"]) .attr("opacity", 0) .style("cursor", "pointer") - const createOverlay = ( - sel: d3.Selection< - d3.EnterElement, - DagNode, - d3.BaseType, - unknown - > - ) => + const createOverlay = (sel: Enter) => sel .append("rect") .attr("class", "overlay") @@ -187,7 +183,7 @@ export default function FragmentDependencyGraph({ }) .on("click", function (d, i) { if (onSelectedIdChange) { - onSelectedIdChange((i.data as any).id) + onSelectedIdChange(i.data.id) } }) @@ -196,7 +192,13 @@ export default function FragmentDependencyGraph({ overlaySelection.call(applyOverlay) setSvgHeight(`${height}px`) - }, [mvDependency, selectedId, svgWidth, onSelectedIdChange, mvDependencyDag]) + }, [ + fragmentDependency, + selectedId, + svgWidth, + onSelectedIdChange, + fragmentDependencyDag, + ]) return ( diff --git a/dashboard/components/FragmentGraph.tsx b/dashboard/components/FragmentGraph.tsx index 875d92baa2b6b..72184d1b2a8bc 100644 --- a/dashboard/components/FragmentGraph.tsx +++ b/dashboard/components/FragmentGraph.tsx @@ -17,11 +17,12 @@ import { cloneDeep } from "lodash" import { Fragment, useCallback, useEffect, useRef, useState } from "react" import { Edge, + Enter, FragmentBox, FragmentBoxPosition, Position, - generateBoxEdges, - layout, + generateFragmentEdges, + layoutItem, } from "../lib/layout" import { PlanNodeDatum } from "../pages/fragment_graph" import { StreamNode } from "../proto/gen/stream_plan" @@ -36,10 +37,6 @@ type FragmentLayout = { actorIds: string[] } & Position -type Enter = Type extends d3.Selection - ? d3.Selection - : never - function treeLayoutFlip( root: d3.HierarchyNode, { dx, dy }: { dx: number; dy: number } @@ -145,7 +142,7 @@ export default function FragmentGraph({ includedFragmentIds.add(fragmentId) } - const fragmentLayout = layout( + const fragmentLayout = layoutItem( fragmentDependencyDag.map(({ width: _1, height: _2, id, ...data }) => { const { width, height } = layoutFragmentResult.get(id)! 
return { width, height, id, ...data } @@ -170,7 +167,7 @@ export default function FragmentGraph({ svgHeight = Math.max(svgHeight, y + height + 50) svgWidth = Math.max(svgWidth, x + width) }) - const edges = generateBoxEdges(fragmentLayout) + const edges = generateFragmentEdges(fragmentLayout) return { layoutResult, diff --git a/dashboard/components/RelationDependencyGraph.tsx b/dashboard/components/RelationDependencyGraph.tsx index 99d40ca2615fd..0f677101cce17 100644 --- a/dashboard/components/RelationDependencyGraph.tsx +++ b/dashboard/components/RelationDependencyGraph.tsx @@ -19,15 +19,23 @@ import { theme } from "@chakra-ui/react" import * as d3 from "d3" import { useCallback, useEffect, useRef } from "react" import { - FragmentPoint, - FragmentPointPosition, + Enter, Position, - flipLayoutPoint, - generatePointEdges, + RelationPoint, + RelationPointPosition, + flipLayoutRelation, + generateRelationEdges, } from "../lib/layout" +import { + Relation, + relationIsStreamingJob, + relationType, + relationTypeTitleCase, +} from "../pages/api/streaming" +import { CatalogModal, useCatalogModal } from "./CatalogModal" function boundBox( - fragmentPosition: FragmentPointPosition[], + relationPosition: RelationPointPosition[], nodeRadius: number ): { width: number @@ -35,7 +43,7 @@ function boundBox( } { let width = 0 let height = 0 - for (const { x, y } of fragmentPosition) { + for (const { x, y } of relationPosition) { width = Math.max(width, x + nodeRadius) height = Math.max(height, y + nodeRadius) } @@ -43,21 +51,25 @@ function boundBox( } const layerMargin = 50 -const rowMargin = 200 -const nodeRadius = 10 -const layoutMargin = 100 +const rowMargin = 50 +export const nodeRadius = 12 +const layoutMargin = 50 export default function RelationDependencyGraph({ nodes, selectedId, + setSelectedId, }: { - nodes: FragmentPoint[] - selectedId?: string + nodes: RelationPoint[] + selectedId: string | undefined + setSelectedId: (id: string) => void }) { - const svgRef = useRef() + const [modalData, setModalId] = useCatalogModal(nodes.map((n) => n.relation)) + + const svgRef = useRef(null) const layoutMapCallback = useCallback(() => { - const layoutMap = flipLayoutPoint( + const layoutMap = flipLayoutRelation( nodes, layerMargin, rowMargin, @@ -68,9 +80,9 @@ export default function RelationDependencyGraph({ x: x + layoutMargin, y: y + layoutMargin, ...data, - } as FragmentPointPosition) + } as RelationPointPosition) ) - const links = generatePointEdges(layoutMap) + const links = generateRelationEdges(layoutMap) const { width, height } = boundBox(layoutMap, nodeRadius) return { layoutMap, @@ -96,29 +108,30 @@ export default function RelationDependencyGraph({ const edgeSelection = svgSelection .select(".edges") - .selectAll(".edge") + .selectAll(".edge") .data(links) + type EdgeSelection = typeof edgeSelection const isSelected = (id: string) => id === selectedId - const applyEdge = (sel: any) => + const applyEdge = (sel: EdgeSelection) => sel - .attr("d", ({ points }: any) => line(points)) + .attr("d", ({ points }) => line(points)) .attr("fill", "none") .attr("stroke-width", 1) - .attr("stroke-width", (d: any) => - isSelected(d.source) || isSelected(d.target) ? 2 : 1 + .attr("stroke-width", (d) => + isSelected(d.source) || isSelected(d.target) ? 4 : 2 ) - .attr("opacity", (d: any) => + .attr("opacity", (d) => isSelected(d.source) || isSelected(d.target) ? 1 : 0.5 ) - .attr("stroke", (d: any) => + .attr("stroke", (d) => isSelected(d.source) || isSelected(d.target) ? 
theme.colors.blue["500"] : theme.colors.gray["300"] ) - const createEdge = (sel: any) => + const createEdge = (sel: Enter) => sel.append("path").attr("class", "edge").call(applyEdge) edgeSelection.exit().remove() edgeSelection.enter().call(createEdge) @@ -127,21 +140,23 @@ export default function RelationDependencyGraph({ const applyNode = (g: NodeSelection) => { g.attr("transform", ({ x, y }) => `translate(${x},${y})`) + // Circle let circle = g.select("circle") if (circle.empty()) { circle = g.append("circle") } - circle - .attr("r", nodeRadius) - .style("cursor", "pointer") - .attr("fill", ({ id }) => - isSelected(id) ? theme.colors.blue["500"] : theme.colors.gray["500"] - ) + circle.attr("r", nodeRadius).attr("fill", ({ id, relation }) => { + const weight = relationIsStreamingJob(relation) ? "500" : "400" + return isSelected(id) + ? theme.colors.blue[weight] + : theme.colors.gray[weight] + }) - let text = g.select("text") + // Relation name + let text = g.select(".text") if (text.empty()) { - text = g.append("text") + text = g.append("text").attr("class", "text") } text @@ -150,24 +165,66 @@ export default function RelationDependencyGraph({ .attr("font-family", "inherit") .attr("text-anchor", "middle") .attr("dy", nodeRadius * 2) - .attr("fill", "black") .attr("font-size", 12) .attr("transform", "rotate(-8)") + // Relation type + let typeText = g.select(".type") + if (typeText.empty()) { + typeText = g.append("text").attr("class", "type") + } + + const relationTypeAbbr = (relation: Relation) => { + const type = relationType(relation) + if (type === "SINK") { + return "K" + } else { + return type.charAt(0) + } + } + + typeText + .attr("fill", "white") + .text(({ relation }) => `${relationTypeAbbr(relation)}`) + .attr("font-family", "inherit") + .attr("text-anchor", "middle") + .attr("dy", nodeRadius * 0.5) + .attr("font-size", 16) + .attr("font-weight", "bold") + + // Relation type tooltip + let typeTooltip = g.select("title") + if (typeTooltip.empty()) { + typeTooltip = g.append("title") + } + + typeTooltip.text( + ({ relation }) => + `${relation.name} (${relationTypeTitleCase(relation)})` + ) + + // Relation modal + g.style("cursor", "pointer").on("click", (_, { relation, id }) => { + setSelectedId(id) + setModalId(relation.id) + }) + return g } - const createNode = (sel: any) => + const createNode = (sel: Enter) => sel.append("g").attr("class", "node").call(applyNode) const g = svgSelection.select(".boxes") - const nodeSelection = g.selectAll(".node").data(layoutMap) + const nodeSelection = g + .selectAll(".node") + .data(layoutMap) type NodeSelection = typeof nodeSelection nodeSelection.enter().call(createNode) nodeSelection.call(applyNode) nodeSelection.exit().remove() - }, [layoutMap, links, selectedId]) + }, [layoutMap, links, selectedId, setModalId, setSelectedId]) return ( <> @@ -175,6 +232,7 @@ export default function RelationDependencyGraph({ + setModalId(null)} /> ) } diff --git a/dashboard/components/Relations.tsx b/dashboard/components/Relations.tsx index c16a70e8c6fa2..0422eaa2531fa 100644 --- a/dashboard/components/Relations.tsx +++ b/dashboard/components/Relations.tsx @@ -18,13 +18,6 @@ import { Box, Button, - Modal, - ModalBody, - ModalCloseButton, - ModalContent, - ModalFooter, - ModalHeader, - ModalOverlay, Table, TableContainer, Tbody, @@ -37,7 +30,6 @@ import loadable from "@loadable/component" import Head from "next/head" import Link from "next/link" -import { parseAsInteger, useQueryState } from "nuqs" import { Fragment } from "react" import Title from 
"../components/Title" import extractColumnInfo from "../lib/extractInfo" @@ -48,8 +40,9 @@ import { Source as RwSource, Table as RwTable, } from "../proto/gen/catalog" +import { CatalogModal, useCatalogModal } from "./CatalogModal" -const ReactJson = loadable(() => import("react-json-view")) +export const ReactJson = loadable(() => import("react-json-view")) export type Column = { name: string @@ -122,40 +115,10 @@ export function Relations( extraColumns: Column[] ) { const { response: relationList } = useFetch(getRelations) + const [modalData, setModalId] = useCatalogModal(relationList) - const [modalId, setModalId] = useQueryState("id", parseAsInteger) - const modalData = relationList?.find((r) => r.id === modalId) - - const catalogModal = ( - setModalId(null)} - size="3xl" - > - - - - Catalog of {modalData?.id} - {modalData?.name} - - - - {modalData && ( - - )} - - - - - - - + const modal = ( + setModalId(null)} /> ) const table = ( @@ -214,7 +177,7 @@ export function Relations( {title} - {catalogModal} + {modal} {table} ) diff --git a/dashboard/lib/layout.ts b/dashboard/lib/layout.ts index 1182976dfe8cb..924374341daa8 100644 --- a/dashboard/lib/layout.ts +++ b/dashboard/lib/layout.ts @@ -15,10 +15,20 @@ * */ -import { cloneDeep, max } from "lodash" +import { max } from "lodash" +import { Relation } from "../pages/api/streaming" import { TableFragments_Fragment } from "../proto/gen/meta" import { GraphNode } from "./algo" +export type Enter = Type extends d3.Selection< + any, + infer B, + infer C, + infer D +> + ? d3.Selection + : never + interface DagNode { node: GraphNode temp: boolean @@ -210,16 +220,16 @@ function dagLayout(nodes: GraphNode[]) { } /** - * @param fragments - * @returns Layer and row of the fragment + * @param items + * @returns Layer and row of the item */ -function gridLayout( - fragments: Array -): Map { - // turn FragmentBox to GraphNode - let idToBox = new Map() - for (let fragment of fragments) { - idToBox.set(fragment.id, fragment) +function gridLayout( + items: Array +): Map { + // turn item to GraphNode + let idToItem = new Map() + for (let item of items) { + idToItem.set(item.id, item) } let nodeToId = new Map() @@ -232,23 +242,23 @@ function gridLayout( let newNode = { nextNodes: new Array(), } - let ab = idToBox.get(id) - if (ab === undefined) { + let item = idToItem.get(id) + if (item === undefined) { throw Error(`no such id ${id}`) } - for (let id of ab.parentIds) { + for (let id of item.parentIds) { getNode(id).nextNodes.push(newNode) } idToNode.set(id, newNode) nodeToId.set(newNode, id) return newNode } - for (let fragment of fragments) { - getNode(fragment.id) + for (let item of items) { + getNode(item.id) } // run daglayout on GraphNode - let rtn = new Map() + let rtn = new Map() let allNodes = new Array() for (let _n of nodeToId.keys()) { allNodes.push(_n) @@ -257,33 +267,34 @@ function gridLayout( for (let item of resultMap) { let id = nodeToId.get(item[0]) if (!id) { - throw Error(`no corresponding fragment id of node ${item[0]}`) + throw Error(`no corresponding item of node ${item[0]}`) } - let fb = idToBox.get(id) + let fb = idToItem.get(id) if (!fb) { - throw Error(`fragment id ${id} is not present in idToBox`) + throw Error(`item id ${id} is not present in idToBox`) } rtn.set(fb, item[1]) } return rtn } -export interface FragmentBox { +export interface LayoutItemBase { id: string - name: string - order: number // preference order, fragment box with larger order will be placed at right + order: number // preference order, item with larger order 
will be placed at right or down width: number height: number parentIds: string[] +} + +export type FragmentBox = LayoutItemBase & { + name: string externalParentIds: string[] fragment?: TableFragments_Fragment } -export interface FragmentPoint { - id: string +export type RelationPoint = LayoutItemBase & { name: string - order: number // preference order, fragment box with larger order will be placed at right - parentIds: string[] + relation: Relation } export interface Position { @@ -292,7 +303,7 @@ export interface Position { } export type FragmentBoxPosition = FragmentBox & Position -export type FragmentPointPosition = FragmentPoint & Position +export type RelationPointPosition = RelationPoint & Position export interface Edge { points: Array @@ -301,15 +312,15 @@ export interface Edge { } /** - * @param fragments + * @param items * @returns the coordination of the top-left corner of the fragment box */ -export function layout( - fragments: Array, +export function layoutItem( + items: Array, layerMargin: number, rowMargin: number -): FragmentBoxPosition[] { - let layoutMap = gridLayout(fragments) +): (I & Position)[] { + let layoutMap = gridLayout(items) let layerRequiredWidth = new Map() let rowRequiredHeight = new Map() let maxLayer = 0, @@ -373,7 +384,7 @@ export function layout( getCumulativeMargin(i, rowMargin, rowCumulativeHeight, rowRequiredHeight) } - let rtn: Array = [] + let rtn: Array = [] for (let [data, [layer, row]] of layoutMap) { let x = layerCumulativeWidth.get(layer) @@ -391,39 +402,13 @@ export function layout( return rtn } -export function flipLayout( - fragments: Array, - layerMargin: number, - rowMargin: number -): FragmentBoxPosition[] { - const fragments_ = cloneDeep(fragments) - for (let fragment of fragments_) { - ;[fragment.width, fragment.height] = [fragment.height, fragment.width] - } - const fragmentPosition = layout(fragments_, rowMargin, layerMargin) - return fragmentPosition.map(({ x, y, ...data }) => ({ - x: y, - y: x, - ...data, - })) -} - -export function layoutPoint( - fragments: Array, +function layoutRelation( + relations: Array, layerMargin: number, rowMargin: number, nodeRadius: number -): FragmentPointPosition[] { - const fragmentBoxes: Array = [] - for (let { ...others } of fragments) { - fragmentBoxes.push({ - width: nodeRadius * 2, - height: nodeRadius * 2, - externalParentIds: [], // we don't care about external parent for point layout - ...others, - }) - } - const result = layout(fragmentBoxes, layerMargin, rowMargin) +): RelationPointPosition[] { + const result = layoutItem(relations, layerMargin, rowMargin) return result.map(({ x, y, ...data }) => ({ x: x + nodeRadius, y: y + nodeRadius, @@ -431,14 +416,14 @@ export function layoutPoint( })) } -export function flipLayoutPoint( - fragments: Array, +export function flipLayoutRelation( + relations: Array, layerMargin: number, rowMargin: number, nodeRadius: number -): FragmentPointPosition[] { - const fragmentPosition = layoutPoint( - fragments, +): RelationPointPosition[] { + const fragmentPosition = layoutRelation( + relations, rowMargin, layerMargin, nodeRadius @@ -450,21 +435,23 @@ export function flipLayoutPoint( })) } -export function generatePointEdges(layoutMap: FragmentPointPosition[]): Edge[] { +export function generateRelationEdges( + layoutMap: RelationPointPosition[] +): Edge[] { const links = [] - const fragmentMap = new Map() + const relationMap = new Map() for (const x of layoutMap) { - fragmentMap.set(x.id, x) + relationMap.set(x.id, x) } - for (const fragment of layoutMap) { - 
for (const parentId of fragment.parentIds) { - const parentFragment = fragmentMap.get(parentId)! + for (const relation of layoutMap) { + for (const parentId of relation.parentIds) { + const parentRelation = relationMap.get(parentId)! links.push({ points: [ - { x: fragment.x, y: fragment.y }, - { x: parentFragment.x, y: parentFragment.y }, + { x: relation.x, y: relation.y }, + { x: parentRelation.x, y: parentRelation.y }, ], - source: fragment.id, + source: relation.id, target: parentId, }) } @@ -472,7 +459,9 @@ export function generatePointEdges(layoutMap: FragmentPointPosition[]): Edge[] { return links } -export function generateBoxEdges(layoutMap: FragmentBoxPosition[]): Edge[] { +export function generateFragmentEdges( + layoutMap: FragmentBoxPosition[] +): Edge[] { const links = [] const fragmentMap = new Map() for (const x of layoutMap) { diff --git a/dashboard/pages/api/streaming.ts b/dashboard/pages/api/streaming.ts index a77a165357b9f..13fa8716f821a 100644 --- a/dashboard/pages/api/streaming.ts +++ b/dashboard/pages/api/streaming.ts @@ -45,8 +45,26 @@ export interface StreamingJob extends Relation { dependentRelations: number[] } +export function relationType(x: Relation) { + if ((x as Table).tableType !== undefined) { + return (x as Table).tableType + } else if ((x as Sink).sinkFromName !== undefined) { + return "SINK" + } else if ((x as Source).info !== undefined) { + return "SOURCE" + } else { + return "UNKNOWN" + } +} +export type RelationType = ReturnType + +export function relationTypeTitleCase(x: Relation) { + return _.startCase(_.toLower(relationType(x))) +} + export function relationIsStreamingJob(x: Relation): x is StreamingJob { - return (x as StreamingJob).dependentRelations !== undefined + const type = relationType(x) + return type !== "UNKNOWN" && type !== "SOURCE" && type !== "INTERNAL" } export async function getStreamingJobs() { diff --git a/dashboard/pages/dependency_graph.tsx b/dashboard/pages/dependency_graph.tsx index fb29f57b11bb5..a4c13a94df169 100644 --- a/dashboard/pages/dependency_graph.tsx +++ b/dashboard/pages/dependency_graph.tsx @@ -20,15 +20,17 @@ import { reverse, sortBy } from "lodash" import Head from "next/head" import { parseAsInteger, useQueryState } from "nuqs" import { Fragment, useCallback } from "react" -import RelationDependencyGraph from "../components/RelationDependencyGraph" +import RelationDependencyGraph, { + nodeRadius, +} from "../components/RelationDependencyGraph" import Title from "../components/Title" -import { FragmentPoint } from "../lib/layout" +import { RelationPoint } from "../lib/layout" import useFetch from "./api/fetch" import { Relation, getRelations, relationIsStreamingJob } from "./api/streaming" const SIDEBAR_WIDTH = "200px" -function buildDependencyAsEdges(list: Relation[]): FragmentPoint[] { +function buildDependencyAsEdges(list: Relation[]): RelationPoint[] { const edges = [] const relationSet = new Set(list.map((r) => r.id)) for (const r of reverse(sortBy(list, "id"))) { @@ -41,24 +43,27 @@ function buildDependencyAsEdges(list: Relation[]): FragmentPoint[] { .map((r) => r.toString()) : [], order: r.id, + width: nodeRadius * 2, + height: nodeRadius * 2, + relation: r, }) } return edges } export default function StreamingGraph() { - const { response: streamingJobList } = useFetch(getRelations) + const { response: relationList } = useFetch(getRelations) const [selectedId, setSelectedId] = useQueryState("id", parseAsInteger) - const mvDependencyCallback = useCallback(() => { - if (streamingJobList) { - return 
buildDependencyAsEdges(streamingJobList) + const relationDependencyCallback = useCallback(() => { + if (relationList) { + return buildDependencyAsEdges(relationList) } else { return undefined } - }, [streamingJobList]) + }, [relationList]) - const mvDependency = mvDependencyCallback() + const relationDependency = relationDependencyCallback() const retVal = ( @@ -77,7 +82,7 @@ export default function StreamingGraph() { - {streamingJobList?.map((r) => { + {relationList?.map((r) => { const match = selectedId === r.id return (