diff --git a/src/frontend/src/catalog/table_catalog.rs b/src/frontend/src/catalog/table_catalog.rs
index 10ab05a951757..2254d4a96a077 100644
--- a/src/frontend/src/catalog/table_catalog.rs
+++ b/src/frontend/src/catalog/table_catalog.rs
@@ -464,8 +464,8 @@ impl TableCatalog {
         }
     }
 
-    pub fn default_column_exprs(&self) -> Vec<ExprImpl> {
-        self.columns
+    pub fn default_column_exprs(columns: &[ColumnCatalog]) -> Vec<ExprImpl> {
+        columns
             .iter()
             .map(|c| {
                 if let Some(GeneratedOrDefaultColumn::DefaultColumn(DefaultColumnDesc {
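For context: `default_column_exprs` becomes an associated function over `&[ColumnCatalog]`, so `alter_table_column` can evaluate defaults for freshly bound target columns before any `TableCatalog` exists. A minimal, self-contained sketch of the same default-or-NULL fallback; `Column` and `Expr` here are hypothetical stand-ins for `ColumnCatalog` and `ExprImpl`:

```rust
// Hypothetical stand-ins for ColumnCatalog / ExprImpl, just to show the shape.
#[derive(Clone, Debug)]
enum Expr {
    Literal(Option<i64>), // Literal(None) plays the role of ExprImpl::literal_null(..)
}

#[allow(dead_code)]
struct Column {
    name: &'static str,
    default: Option<Expr>, // models GeneratedOrDefaultColumn::DefaultColumn
}

// Associated-function shape: it takes a column slice rather than `&self`,
// so it also works on freshly bound columns that have no catalog yet.
fn default_column_exprs(columns: &[Column]) -> Vec<Expr> {
    columns
        .iter()
        .map(|c| c.default.clone().unwrap_or(Expr::Literal(None)))
        .collect()
}

fn main() {
    let cols = [
        Column { name: "id", default: None },
        Column { name: "flag", default: Some(Expr::Literal(Some(1))) },
    ];
    // Prints [Literal(None), Literal(Some(1))]
    println!("{:?}", default_column_exprs(&cols));
}
```

This is why the call sites below switch from `table_catalog.default_column_exprs()` to `TableCatalog::default_column_exprs(...)` with an explicit column slice.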
diff --git a/src/frontend/src/handler/alter_table_column.rs b/src/frontend/src/handler/alter_table_column.rs
index 5cdad63ffbc00..7893060f944a9 100644
--- a/src/frontend/src/handler/alter_table_column.rs
+++ b/src/frontend/src/handler/alter_table_column.rs
@@ -17,12 +17,16 @@ use std::rc::Rc;
 use std::sync::Arc;
 
 use anyhow::Context;
+use create_sink::derive_default_column_project_for_sink;
 use itertools::Itertools;
 use pgwire::pg_response::{PgResponse, StatementType};
 use risingwave_common::bail_not_implemented;
+use risingwave_common::catalog::ColumnCatalog;
 use risingwave_common::util::column_index_mapping::ColIndexMapping;
+use risingwave_connector::sink::catalog::SinkCatalog;
 use risingwave_pb::plan_common::column_desc::GeneratedOrDefaultColumn;
 use risingwave_pb::plan_common::DefaultColumnDesc;
+use risingwave_pb::stream_plan::StreamFragmentGraph;
 use risingwave_sqlparser::ast::{
     AlterTableOperation, ColumnOption, ConnectorSchema, Encode, ObjectName, Statement,
 };
@@ -31,14 +35,16 @@ use risingwave_sqlparser::parser::Parser;
 
 use super::create_source::get_json_schema_location;
 use super::create_table::{bind_sql_columns, generate_stream_graph_for_table, ColumnIdGenerator};
 use super::util::SourceSchemaCompatExt;
-use super::{HandlerArgs, RwPgResponse};
+use super::{create_sink, HandlerArgs, RwPgResponse};
 use crate::catalog::root_catalog::SchemaPath;
 use crate::catalog::table_catalog::TableType;
 use crate::error::{ErrorCode, Result, RwError};
 use crate::expr::ExprImpl;
-use crate::handler::create_sink::insert_merger_to_union_with_project;
+use crate::handler::create_sink::{fetch_incoming_sinks, insert_merger_to_union_with_project};
 use crate::optimizer::plan_node::generic::SourceNodeKind;
-use crate::optimizer::plan_node::{LogicalSource, StreamProject, ToStream, ToStreamContext};
+use crate::optimizer::plan_node::{
+    generic, LogicalSource, StreamProject, ToStream, ToStreamContext,
+};
 use crate::session::SessionImpl;
 use crate::{Binder, OptimizerContext, TableCatalog, WithOptions};
 
@@ -98,51 +104,51 @@ pub async fn replace_table_with_definition(
     );
 
     let incoming_sink_ids: HashSet<_> = original_catalog.incoming_sinks.iter().copied().collect();
-
-    let incoming_sinks = {
-        let reader = session.env().catalog_reader().read_guard();
-        let mut sinks = HashMap::new();
-        let db_name = session.database();
-        for schema in reader.iter_schemas(db_name)? {
-            for sink in schema.iter_sink() {
-                if incoming_sink_ids.contains(&sink.id.sink_id) {
-                    sinks.insert(sink.id.sink_id, sink.clone());
-                }
-            }
-        }
-
-        sinks
-    };
-
+    let incoming_sinks = fetch_incoming_sinks(session, &incoming_sink_ids)?;
     let target_columns = bind_sql_columns(&columns)?;
+    let default_columns: Vec<ExprImpl> = TableCatalog::default_column_exprs(&target_columns);
 
-    let default_x: Vec<ExprImpl> = target_columns
-        .iter()
-        .map(|c| {
-            if let Some(GeneratedOrDefaultColumn::DefaultColumn(DefaultColumnDesc {
-                expr, ..
-            })) = c.column_desc.generated_or_default_column.as_ref()
-            {
-                ExprImpl::from_expr_proto(expr.as_ref().unwrap())
-                    .expect("expr in default columns corrupted")
-            } else {
-                ExprImpl::literal_null(c.data_type().clone())
-            }
-        })
-        .collect();
-
-    for (_, sink) in incoming_sinks {
-        let exprs = crate::handler::create_sink::derive_default_column_project_for_sink(
-            sink.as_ref(),
-            &sink.full_schema(),
+    for sink in incoming_sinks {
+        let context = Rc::new(OptimizerContext::from_handler_args(handler_args.clone()));
+        hijack_merger_for_target_table(
+            &mut graph,
             &target_columns,
-            &default_x,
-            false, // todo
+            &default_columns,
+            &sink,
+            context,
         )?;
+    }
 
-        let context = Rc::new(OptimizerContext::from_handler_args(handler_args.clone()));
+    table.incoming_sinks = incoming_sink_ids.iter().copied().collect();
+
+    tracing::debug!("table incoming sinks: {:?}", table.incoming_sinks);
+
+    let catalog_writer = session.catalog_writer()?;
 
-        let dummy_source_node = LogicalSource::new(
+    catalog_writer
+        .replace_table(source, table, graph, col_index_mapping)
+        .await?;
+    Ok(())
+}
+
+pub(crate) fn hijack_merger_for_target_table(
+    graph: &mut StreamFragmentGraph,
+    target_columns: &[ColumnCatalog],
+    default_columns: &[ExprImpl],
+    sink: &SinkCatalog,
+    context: Rc<OptimizerContext>,
+) -> Result<()> {
+    let exprs = derive_default_column_project_for_sink(
+        sink,
+        &sink.full_schema(),
+        target_columns,
+        default_columns,
+        false, // todo (corresponds to `user_specified_columns` in gen_sink_plan)
+    )?;
+
+    let pb_project = StreamProject::new(generic::Project::new(
+        exprs,
+        LogicalSource::new(
             None,
             sink.full_columns().to_vec(),
             None,
@@ -150,33 +156,16 @@
             context,
             None,
         )
-        .and_then(|s| s.to_stream(&mut ToStreamContext::new(false)))?;
-
-        let logical_project =
-            crate::optimizer::plan_node::generic::Project::new(exprs, dummy_source_node);
-
-        let input: crate::PlanRef = StreamProject::new(logical_project).into();
-
-        let x = input.as_stream_project().unwrap();
-
-        let pb_project = x.to_stream_prost_body_inner();
+        .and_then(|s| s.to_stream(&mut ToStreamContext::new(false)))?,
+    ))
+    .to_stream_prost_body_inner();
 
-        for fragment in graph.fragments.values_mut() {
-            if let Some(node) = &mut fragment.node {
-                insert_merger_to_union_with_project(node, &pb_project);
-            }
+    for fragment in graph.fragments.values_mut() {
+        if let Some(node) = &mut fragment.node {
+            insert_merger_to_union_with_project(node, &pb_project);
        }
    }
 
-    table.incoming_sinks = incoming_sink_ids.iter().copied().collect();
-
-    println!("fe table incoming {:?}", table.incoming_sinks);
-
-    let catalog_writer = session.catalog_writer()?;
-
-    catalog_writer
-        .replace_table(source, table, graph, col_index_mapping)
-        .await?;
     Ok(())
 }
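The extracted `hijack_merger_for_target_table` derives, per incoming sink, a projection from the sink's full schema onto the (possibly altered) target-table columns, and splices it above the merger that gets inserted under every `Union` in the fragment graph. A rough, self-contained analogue of that traversal; the `Node` enum and `insert_merger_with_project` are illustrative toys, not the real stream plan-node types:

```rust
// Toy plan tree: a Union of inputs, Projects, and leaf Scans.
#[derive(Debug)]
enum Node {
    Union(Vec<Node>),
    Project { exprs: Vec<&'static str>, input: Box<Node> },
    Scan(&'static str),
}

// Walk the plan; at each Union, add a new merger input wrapped in the
// per-sink projection, mirroring insert_merger_to_union_with_project.
fn insert_merger_with_project(node: &mut Node, exprs: &[&'static str]) {
    match node {
        Node::Union(inputs) => {
            inputs.push(Node::Project {
                exprs: exprs.to_vec(),
                input: Box::new(Node::Scan("sink-merger")),
            });
        }
        Node::Project { input, .. } => insert_merger_with_project(input, exprs),
        Node::Scan(_) => {}
    }
}

fn main() {
    let mut plan = Node::Union(vec![Node::Scan("dml")]);
    // Missing columns map to NULL or the column default, as in the real projection.
    insert_merger_with_project(&mut plan, &["v1", "NULL", "default(v3)"]);
    println!("{plan:?}");
}
```

The real helper builds the projection as a `StreamProject` over a dummy `LogicalSource` carrying the sink's columns, then serializes only the project body (`to_stream_prost_body_inner`) into the fragment graph.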
diff --git a/src/frontend/src/handler/create_sink.rs b/src/frontend/src/handler/create_sink.rs
index 2f2a42653369c..1337482b76e03 100644
--- a/src/frontend/src/handler/create_sink.rs
+++ b/src/frontend/src/handler/create_sink.rs
@@ -12,7 +12,7 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-use std::collections::{BTreeMap, BTreeSet, HashMap};
+use std::collections::{BTreeMap, BTreeSet, HashMap, HashSet};
 use std::rc::Rc;
 use std::sync::{Arc, LazyLock};
 
@@ -51,11 +51,12 @@
 use crate::binder::Binder;
 use crate::catalog::catalog_service::CatalogReadGuard;
 use crate::catalog::source_catalog::SourceCatalog;
 use crate::catalog::view_catalog::ViewCatalog;
+use crate::catalog::SinkId;
 use crate::error::{ErrorCode, Result, RwError};
 use crate::expr::{ExprImpl, InputRef};
 use crate::handler::alter_table_column::fetch_table_catalog_for_alter;
 use crate::handler::create_mv::parse_column_names;
 use crate::handler::create_table::{generate_stream_graph_for_table, ColumnIdGenerator};
 use crate::handler::privilege::resolve_query_privileges;
 use crate::handler::util::SourceSchemaCompatExt;
 use crate::handler::HandlerArgs;
@@ -284,7 +287,7 @@
         &sink_catalog,
         sink_plan.schema(),
         table_catalog.columns(),
-        table_catalog.default_column_exprs().as_ref(),
+        TableCatalog::default_column_exprs(table_catalog.columns()).as_ref(),
         user_specified_columns,
     )?;
 
@@ -416,7 +419,7 @@
     let partition_info = get_partition_compute_info(&handle_args.with_options).await?;
 
     let (sink, graph, target_table_catalog) = {
-        let context = Rc::new(OptimizerContext::from_handler_args(handle_args));
+        let context = Rc::new(OptimizerContext::from_handler_args(handle_args.clone()));
 
         let SinkPlanContext {
             query,
@@ -457,12 +460,23 @@
             .incoming_sinks
             .clone_from(&table_catalog.incoming_sinks);
 
-        for _ in 0..(table_catalog.incoming_sinks.len() + 1) {
-            for fragment in graph.fragments.values_mut() {
-                if let Some(node) = &mut fragment.node {
-                    insert_merger_to_union(node);
-                }
-            }
+        let default_columns: Vec<ExprImpl> =
+            TableCatalog::default_column_exprs(table_catalog.columns());
+
+        let incoming_sink_ids: HashSet<_> = table_catalog.incoming_sinks.iter().copied().collect();
+
+        let mut incoming_sinks = fetch_incoming_sinks(&session, &incoming_sink_ids)?;
+        incoming_sinks.push(Arc::new(sink.clone()));
+
+        let context = Rc::new(OptimizerContext::from_handler_args(handle_args.clone()));
+        for sink in incoming_sinks {
+            crate::handler::alter_table_column::hijack_merger_for_target_table(
+                &mut graph,
+                table_catalog.columns(),
+                &default_columns,
+                &sink,
+                context.clone(),
+            )?;
         }
 
         target_table_replace_plan = Some(ReplaceTablePlan {
@@ -492,6 +510,24 @@
     Ok(PgResponse::empty_result(StatementType::CREATE_SINK))
 }
 
+pub fn fetch_incoming_sinks(
+    session: &Arc<SessionImpl>,
+    incoming_sink_ids: &HashSet<SinkId>,
+) -> Result<Vec<Arc<SinkCatalog>>> {
+    let reader = session.env().catalog_reader().read_guard();
+    let mut sinks = Vec::with_capacity(incoming_sink_ids.len());
+    let db_name = session.database();
+    for schema in reader.iter_schemas(db_name)? {
+        for sink in schema.iter_sink() {
+            if incoming_sink_ids.contains(&sink.id.sink_id) {
+                sinks.push(sink.clone());
+            }
+        }
+    }
+
+    Ok(sinks)
+}
+
 fn check_cycle_for_sink(
     session: &SessionImpl,
     sink_catalog: SinkCatalog,
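`fetch_incoming_sinks` centralizes the catalog scan that ALTER TABLE, CREATE SINK ... INTO, and DROP SINK now share: collect the sinks whose ids appear in the table's incoming set. A self-contained sketch of its shape, with a hypothetical `Sink` type standing in for `SinkCatalog` and a plain `Vec` of schemas standing in for the catalog reader:

```rust
use std::collections::HashSet;
use std::sync::Arc;

// Hypothetical stand-in for SinkCatalog.
#[derive(Debug)]
struct Sink {
    id: u32,
    name: &'static str,
}

// Same shape as fetch_incoming_sinks: scan every schema's sinks and keep
// the ones whose id is in the table's incoming set.
fn fetch_incoming_sinks(schemas: &[Vec<Arc<Sink>>], incoming: &HashSet<u32>) -> Vec<Arc<Sink>> {
    let mut out = Vec::with_capacity(incoming.len());
    for schema in schemas {
        for sink in schema {
            if incoming.contains(&sink.id) {
                out.push(Arc::clone(sink));
            }
        }
    }
    out
}

fn main() {
    let schemas = vec![vec![
        Arc::new(Sink { id: 1, name: "s1" }),
        Arc::new(Sink { id: 2, name: "s2" }),
    ]];
    let incoming: HashSet<u32> = [2].into();
    // Prints only s2: the single sink registered as incoming.
    println!("{:?}", fetch_incoming_sinks(&schemas, &incoming));
}
```

Returning `Vec<Arc<SinkCatalog>>` instead of the previous ad-hoc `HashMap` fits the call sites, which only ever iterate the result (and, in CREATE SINK, push the sink being created onto it).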
diff --git a/src/frontend/src/handler/drop_sink.rs b/src/frontend/src/handler/drop_sink.rs
index a42605ad1c856..f072adddbd0e9 100644
--- a/src/frontend/src/handler/drop_sink.rs
+++ b/src/frontend/src/handler/drop_sink.rs
@@ -12,6 +12,10 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
+use std::collections::HashSet;
+use std::rc::Rc;
+use std::sync::Arc;
+
 use pgwire::pg_response::{PgResponse, StatementType};
 use risingwave_pb::ddl_service::ReplaceTablePlan;
 use risingwave_sqlparser::ast::ObjectName;
@@ -20,8 +24,12 @@ use super::RwPgResponse;
 use crate::binder::Binder;
 use crate::catalog::root_catalog::SchemaPath;
 use crate::error::Result;
-use crate::handler::create_sink::{insert_merger_to_union, reparse_table_for_sink};
+use crate::expr::ExprImpl;
+use crate::handler::create_sink::{
+    fetch_incoming_sinks, insert_merger_to_union, reparse_table_for_sink,
+};
 use crate::handler::HandlerArgs;
+use crate::{OptimizerContext, TableCatalog};
 
 pub async fn handle_drop_sink(
     handler_args: HandlerArgs,
@@ -29,7 +37,7 @@
     if_exists: bool,
     cascade: bool,
 ) -> Result<RwPgResponse> {
-    let session = handler_args.session;
+    let session = handler_args.session.clone();
     let db_name = session.database();
     let (schema_name, sink_name) = Binder::resolve_schema_qualified_name(db_name, sink_name)?;
     let search_path = session.config().search_path();
@@ -76,12 +84,25 @@
             .incoming_sinks
             .clone_from(&table_catalog.incoming_sinks);
 
-        for _ in 0..(table_catalog.incoming_sinks.len() - 1) {
-            for fragment in graph.fragments.values_mut() {
-                if let Some(node) = &mut fragment.node {
-                    insert_merger_to_union(node);
-                }
-            }
+        let default_columns: Vec<ExprImpl> =
+            TableCatalog::default_column_exprs(table_catalog.columns());
+
+        let mut incoming_sink_ids: HashSet<_> =
+            table_catalog.incoming_sinks.iter().copied().collect();
+
+        assert!(incoming_sink_ids.remove(&sink_id.sink_id));
+
+        let incoming_sinks = fetch_incoming_sinks(&session, &incoming_sink_ids)?;
+
+        for sink in incoming_sinks {
+            let context = Rc::new(OptimizerContext::from_handler_args(handler_args.clone()));
+            crate::handler::alter_table_column::hijack_merger_for_target_table(
+                &mut graph,
+                table_catalog.columns(),
+                &default_columns,
+                &sink,
+                context,
+            )?;
         }
 
         affected_table_change = Some(ReplaceTablePlan {
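On DROP SINK, the dropped sink's id is removed from the table's incoming set first (the `assert!` guards against catalog inconsistency), and mergers are re-derived only for the sinks that remain, replacing the old count-based loop. A toy illustration of that bookkeeping, with made-up ids:

```rust
use std::collections::HashSet;

fn main() {
    // Incoming sinks currently registered on the target table (illustrative ids).
    let table_incoming: Vec<u32> = vec![10, 11, 12];
    let dropped = 11u32;

    let mut incoming: HashSet<u32> = table_incoming.iter().copied().collect();
    // Mirrors `assert!(incoming_sink_ids.remove(&sink_id.sink_id))` in drop_sink.rs:
    // the sink being dropped must have been registered on the table.
    assert!(incoming.remove(&dropped), "dropped sink not registered on table");

    // Each remaining sink would get hijack_merger_for_target_table(..) applied
    // against the replacement fragment graph.
    for id in &incoming {
        println!("re-derive merger for sink {id}");
    }
}
```

Deriving the merger set from the surviving sink ids, rather than looping `len() - 1` times with a bare `insert_merger_to_union`, keeps the per-sink projection attached to the right sink.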