refactor(frontend): extract filling fields in fragment graph (#18466)
Signed-off-by: Bugen Zhao <[email protected]>
BugenZhao authored Sep 10, 2024
1 parent 29d2e1e commit 8c7a364
Showing 8 changed files with 38 additions and 78 deletions.
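In short, this refactor moves the filling of session-derived fields (streaming parallelism and the stream context / timezone) out of the individual DDL handlers and into build_graph itself, so each call site shrinks to a single let graph = build_graph(plan)?;. Below is a minimal, self-contained sketch of the pattern; the types and fields here (Graph, Session) are hypothetical stand-ins, not RisingWave's actual StreamFragmentGraphProto or SessionImpl.

// Hypothetical stand-ins for the real proto/session types.
#[derive(Debug, Default)]
struct Graph {
    parallelism: Option<u32>,
    timezone: Option<String>,
}

struct Session {
    streaming_parallelism: Option<u32>,
    timezone: String,
}

// Before: the builder returned a graph with these fields unset,
// and every handler patched them afterwards.
fn build_graph_before() -> Graph {
    Graph::default()
}

// After: the builder reads the session context and fills the fields itself,
// roughly what `build_graph` in `stream_fragmenter/mod.rs` now does.
fn build_graph_after(session: &Session) -> Graph {
    let mut graph = build_graph_before();
    graph.parallelism = session.streaming_parallelism;
    graph.timezone = Some(session.timezone.clone());
    graph
}

fn main() {
    let session = Session {
        streaming_parallelism: Some(4),
        timezone: "UTC".to_owned(),
    };
    // Handlers previously set `graph.parallelism` (and the timezone) themselves;
    // after the refactor a single builder call remains at each call site.
    let graph = build_graph_after(&session);
    println!("{graph:?}");
}

This is only a sketch of the shape of the change; the real field filling happens on the protobuf StreamFragmentGraph, as shown in the stream_fragmenter/mod.rs hunk at the end of the diff.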
11 changes: 2 additions & 9 deletions src/frontend/src/handler/create_index.rs
@@ -25,7 +25,6 @@ use risingwave_common::acl::AclMode;
use risingwave_common::catalog::{IndexId, TableDesc, TableId};
use risingwave_common::util::sort_util::{ColumnOrder, OrderType};
use risingwave_pb::catalog::{PbIndex, PbIndexColumnProperties, PbStreamJobStatus, PbTable};
use risingwave_pb::stream_plan::stream_fragment_graph::Parallelism;
use risingwave_pb::user::grant_privilege::Object;
use risingwave_sqlparser::ast;
use risingwave_sqlparser::ast::{Ident, ObjectName, OrderByExpr};
@@ -448,14 +447,8 @@ pub async fn handle_create_index(
include,
distributed_by,
)?;
let mut graph = build_graph(plan)?;
graph.parallelism =
session
.config()
.streaming_parallelism()
.map(|parallelism| Parallelism {
parallelism: parallelism.get(),
});
let graph = build_graph(plan)?;

(graph, index_table, index)
};

14 changes: 1 addition & 13 deletions src/frontend/src/handler/create_mv.rs
@@ -20,7 +20,6 @@ use pgwire::pg_response::{PgResponse, StatementType};
use risingwave_common::acl::AclMode;
use risingwave_common::catalog::TableId;
use risingwave_pb::catalog::PbTable;
use risingwave_pb::stream_plan::stream_fragment_graph::Parallelism;
use risingwave_sqlparser::ast::{EmitMode, Ident, ObjectName, Query};

use super::privilege::resolve_relation_privileges;
@@ -243,18 +242,7 @@ It only indicates the physical clustering of the data, which may improve the per
emit_mode,
)?;

let context = plan.plan_base().ctx().clone();
let mut graph = build_graph(plan)?;
graph.parallelism =
session
.config()
.streaming_parallelism()
.map(|parallelism| Parallelism {
parallelism: parallelism.get(),
});
// Set the timezone for the stream context
let ctx = graph.ctx.as_mut().unwrap();
ctx.timezone = context.get_session_timezone();
let graph = build_graph(plan)?;

(table, graph)
};
11 changes: 1 addition & 10 deletions src/frontend/src/handler/create_sink.rs
@@ -35,7 +35,6 @@ use risingwave_connector::sink::{
};
use risingwave_pb::catalog::{PbSink, PbSource, Table};
use risingwave_pb::ddl_service::{ReplaceTablePlan, TableJobType};
use risingwave_pb::stream_plan::stream_fragment_graph::Parallelism;
use risingwave_pb::stream_plan::stream_node::{NodeBody, PbNodeBody};
use risingwave_pb::stream_plan::{MergeNode, StreamFragmentGraph, StreamNode};
use risingwave_sqlparser::ast::{
@@ -445,15 +444,7 @@ pub async fn handle_create_sink(
);
}

let mut graph = build_graph(plan)?;

graph.parallelism =
session
.config()
.streaming_parallelism()
.map(|parallelism| Parallelism {
parallelism: parallelism.get(),
});
let graph = build_graph(plan)?;

(sink, graph, target_table_catalog)
};
11 changes: 1 addition & 10 deletions src/frontend/src/handler/create_source.rs
@@ -62,7 +62,6 @@ use risingwave_connector::WithPropertiesExt;
use risingwave_pb::catalog::{PbSchemaRegistryNameStrategy, StreamSourceInfo, WatermarkDesc};
use risingwave_pb::plan_common::additional_column::ColumnType as AdditionalColumnType;
use risingwave_pb::plan_common::{EncodeType, FormatType};
use risingwave_pb::stream_plan::stream_fragment_graph::Parallelism;
use risingwave_sqlparser::ast::{
get_delimiter, AstString, ColumnDef, ConnectorSchema, CreateSourceStatement, Encode, Format,
ObjectName, ProtobufSchema, SourceWatermark, TableConstraint,
@@ -1697,15 +1696,7 @@ pub async fn handle_create_source(
)?;

let stream_plan = source_node.to_stream(&mut ToStreamContext::new(false))?;
let mut graph = build_graph(stream_plan)?;
graph.parallelism =
session
.config()
.streaming_parallelism()
.map(|parallelism| Parallelism {
parallelism: parallelism.get(),
});
graph
build_graph(stream_plan)?
};
catalog_writer
.create_source_with_graph(source, graph)
23 changes: 4 additions & 19 deletions src/frontend/src/handler/create_table.rs
@@ -41,7 +41,6 @@ use risingwave_pb::plan_common::column_desc::GeneratedOrDefaultColumn;
use risingwave_pb::plan_common::{
AdditionalColumn, ColumnDescVersion, DefaultColumnDesc, GeneratedColumnDesc,
};
use risingwave_pb::stream_plan::stream_fragment_graph::Parallelism;
use risingwave_pb::stream_plan::StreamFragmentGraph;
use risingwave_sqlparser::ast::{
CdcTableInfo, ColumnDef, ColumnOption, ConnectorSchema, DataType as AstDataType,
@@ -1263,14 +1262,8 @@ pub async fn handle_create_table(
)
.await?;

let mut graph = build_graph(plan)?;
graph.parallelism =
session
.config()
.streaming_parallelism()
.map(|parallelism| Parallelism {
parallelism: parallelism.get(),
});
let graph = build_graph(plan)?;

(graph, source, table, job_type)
};

@@ -1315,7 +1308,7 @@ pub fn check_create_table_with_source(

#[allow(clippy::too_many_arguments)]
pub async fn generate_stream_graph_for_table(
session: &Arc<SessionImpl>,
_session: &Arc<SessionImpl>,
table_name: ObjectName,
original_catalog: &Arc<TableCatalog>,
source_schema: Option<ConnectorSchema>,
@@ -1430,15 +1423,7 @@ pub async fn generate_stream_graph_for_table(
))?
}

let graph = StreamFragmentGraph {
parallelism: session
.config()
.streaming_parallelism()
.map(|parallelism| Parallelism {
parallelism: parallelism.get(),
}),
..build_graph(plan)?
};
let graph = build_graph(plan)?;

// Fill the original table ID.
let table = Table {
11 changes: 2 additions & 9 deletions src/frontend/src/handler/create_table_as.rs
@@ -16,7 +16,6 @@ use either::Either;
use pgwire::pg_response::StatementType;
use risingwave_common::catalog::{ColumnCatalog, ColumnDesc};
use risingwave_pb::ddl_service::TableJobType;
use risingwave_pb::stream_plan::stream_fragment_graph::Parallelism;
use risingwave_sqlparser::ast::{ColumnDef, ObjectName, OnConflict, Query, Statement};

use super::{HandlerArgs, RwPgResponse};
@@ -110,14 +109,8 @@ pub async fn handle_create_as(
with_version_column,
Some(col_id_gen.into_version()),
)?;
let mut graph = build_graph(plan)?;
graph.parallelism =
session
.config()
.streaming_parallelism()
.map(|parallelism| Parallelism {
parallelism: parallelism.get(),
});
let graph = build_graph(plan)?;

(graph, None, table)
};

11 changes: 4 additions & 7 deletions src/frontend/src/stream_fragmenter/graph/fragment_graph.rs
@@ -19,8 +19,7 @@ use risingwave_pb::stream_plan::stream_fragment_graph::{
StreamFragment as StreamFragmentProto, StreamFragmentEdge as StreamFragmentEdgeProto,
};
use risingwave_pb::stream_plan::{
DispatchStrategy, FragmentTypeFlag, StreamContext,
StreamFragmentGraph as StreamFragmentGraphProto, StreamNode,
DispatchStrategy, FragmentTypeFlag, StreamFragmentGraph as StreamFragmentGraphProto, StreamNode,
};
use thiserror_ext::AsReport;

@@ -92,9 +91,6 @@ pub struct StreamFragmentGraph {

/// stores edges between fragments: (upstream, downstream) => edge.
edges: HashMap<(LocalFragmentId, LocalFragmentId), StreamFragmentEdgeProto>,

/// Stores the streaming context for the streaming plan
ctx: StreamContext,
}

impl StreamFragmentGraph {
@@ -106,8 +102,9 @@ impl StreamFragmentGraph {
.map(|(k, v)| (*k, v.to_protobuf()))
.collect(),
edges: self.edges.values().cloned().collect(),
ctx: Some(self.ctx.clone()),
// To be filled later

// Following fields will be filled later in `build_graph` based on session context.
ctx: None,
dependent_table_ids: vec![],
table_ids_cnt: 0,
parallelism: None,
24 changes: 23 additions & 1 deletion src/frontend/src/stream_fragmenter/mod.rs
@@ -16,6 +16,7 @@ mod graph;
use graph::*;
use risingwave_common::util::recursive::{self, Recurse as _};
use risingwave_connector::WithPropertiesExt;
use risingwave_pb::stream_plan::stream_fragment_graph::Parallelism;
use risingwave_pb::stream_plan::stream_node::NodeBody;
mod rewrite;

@@ -26,12 +27,13 @@ use educe::Educe;
use risingwave_common::catalog::TableId;
use risingwave_pb::plan_common::JoinType;
use risingwave_pb::stream_plan::{
DispatchStrategy, DispatcherType, ExchangeNode, FragmentTypeFlag, NoOpNode,
DispatchStrategy, DispatcherType, ExchangeNode, FragmentTypeFlag, NoOpNode, StreamContext,
StreamFragmentGraph as StreamFragmentGraphProto, StreamNode, StreamScanType,
};

use self::rewrite::build_delta_join_without_arrange;
use crate::error::Result;
use crate::optimizer::plan_node::generic::GenericPlanRef;
use crate::optimizer::plan_node::reorganize_elements_id;
use crate::optimizer::PlanRef;
use crate::scheduler::SchedulerResult;
@@ -116,18 +118,38 @@ impl BuildFragmentGraphState {
}

pub fn build_graph(plan_node: PlanRef) -> SchedulerResult<StreamFragmentGraphProto> {
let ctx = plan_node.plan_base().ctx();
let plan_node = reorganize_elements_id(plan_node);

let mut state = BuildFragmentGraphState::default();
let stream_node = plan_node.to_stream_prost(&mut state)?;
generate_fragment_graph(&mut state, stream_node).unwrap();
let mut fragment_graph = state.fragment_graph.to_protobuf();

// Set table ids.
fragment_graph.dependent_table_ids = state
.dependent_table_ids
.into_iter()
.map(|id| id.table_id)
.collect();
fragment_graph.table_ids_cnt = state.next_table_id;

// Set parallelism.
{
let config = ctx.session_ctx().config();
fragment_graph.parallelism =
config
.streaming_parallelism()
.map(|parallelism| Parallelism {
parallelism: parallelism.get(),
});
}

// Set timezone.
fragment_graph.ctx = Some(StreamContext {
timezone: ctx.get_session_timezone(),
});

Ok(fragment_graph)
}
