From 471aa2b623877ce9ec5cfabc43d737baf0b0aee8 Mon Sep 17 00:00:00 2001 From: William Wen <44139337+wenym1@users.noreply.github.com> Date: Thu, 7 Sep 2023 22:31:44 +0800 Subject: [PATCH] refactor(source): specify cdc generic type parameter for different cdc source (#12109) --- Cargo.lock | 1 + src/connector/Cargo.toml | 1 + src/connector/src/macros.rs | 110 ++++++++++++++- src/connector/src/source/base.rs | 104 ++++---------- .../src/source/cdc/enumerator/mod.rs | 132 +++++++++++------- src/connector/src/source/cdc/mod.rs | 40 ++++-- src/connector/src/source/cdc/source/reader.rs | 34 ++--- src/connector/src/source/cdc/split.rs | 13 +- .../src/executor/backfill/cdc_backfill.rs | 2 +- 9 files changed, 272 insertions(+), 165 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index c666315abe84..8bfd22508c5e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -6907,6 +6907,7 @@ dependencies = [ "num-bigint", "opendal", "parking_lot 0.12.1", + "paste", "prometheus", "prost", "prost-reflect", diff --git a/src/connector/Cargo.toml b/src/connector/Cargo.toml index 418829161431..de0f680453b1 100644 --- a/src/connector/Cargo.toml +++ b/src/connector/Cargo.toml @@ -62,6 +62,7 @@ nexmark = { version = "0.2", features = ["serde"] } num-bigint = "0.4" opendal = "0.39" parking_lot = "0.12" +paste = "1" prometheus = { version = "0.13", features = ["process"] } prost = { version = "0.11.9", features = ["no-recursion-limit"] } prost-reflect = "0.11.5" diff --git a/src/connector/src/macros.rs b/src/connector/src/macros.rs index ea42c63aa6c5..a30ad58ce07c 100644 --- a/src/connector/src/macros.rs +++ b/src/connector/src/macros.rs @@ -37,7 +37,7 @@ macro_rules! impl_split_enumerator { .map(SplitImpl::$variant_name) .collect_vec() }) - .map_err(|e| ErrorCode::ConnectorError(e.into()).into()), + .map_err(|e| risingwave_common::error::ErrorCode::ConnectorError(e.into()).into()), )* } } @@ -55,6 +55,25 @@ macro_rules! impl_split { } } } + $( + impl TryFrom for $split { + type Error = anyhow::Error; + + fn try_from(split: SplitImpl) -> std::result::Result { + match split { + SplitImpl::$variant_name(inner) => Ok(inner), + other => Err(anyhow::anyhow!("expect {} but get {:?}", stringify!($split), other)) + } + } + } + + impl From<$split> for SplitImpl { + fn from(split: $split) -> SplitImpl { + SplitImpl::$variant_name(split) + } + } + + )* impl TryFrom<&ConnectorSplit> for SplitImpl { type Error = anyhow::Error; @@ -161,7 +180,8 @@ macro_rules! impl_connector_properties { pub fn extract(mut props: HashMap) -> Result { const UPSTREAM_SOURCE_KEY: &str = "connector"; let connector = props.remove(UPSTREAM_SOURCE_KEY).ok_or_else(|| anyhow!("Must specify 'connector' in WITH clause"))?; - if connector.ends_with("cdc") { + use $crate::source::cdc::CDC_CONNECTOR_NAME_SUFFIX; + if connector.ends_with(CDC_CONNECTOR_NAME_SUFFIX) { ConnectorProperties::new_cdc_properties(&connector, props) } else { let json_value = serde_json::to_value(props).map_err(|e| anyhow!(e))?; @@ -180,3 +200,89 @@ macro_rules! impl_connector_properties { } } } + +#[macro_export] +macro_rules! impl_cdc_source_type { + ($({$source_type:ident, $name:expr }),*) => { + $( + paste!{ + #[derive(Clone, Debug, Default, PartialEq, Eq, Hash)] + pub struct $source_type; + impl CdcSourceTypeTrait for $source_type { + const CDC_CONNECTOR_NAME: &'static str = concat!($name, "-cdc"); + fn source_type() -> CdcSourceType { + CdcSourceType::$source_type + } + } + + pub type [< $source_type DebeziumSplitEnumerator >] = DebeziumSplitEnumerator<$source_type>; + } + )* + + pub enum CdcSourceType { + $( + $source_type, + )* + } + + impl From for CdcSourceType { + fn from(value: PbSourceType) -> Self { + match value { + PbSourceType::Unspecified => unreachable!(), + $( + PbSourceType::$source_type => CdcSourceType::$source_type, + )* + } + } + } + + impl From for PbSourceType { + fn from(this: CdcSourceType) -> PbSourceType { + match this { + $( + CdcSourceType::$source_type => PbSourceType::$source_type, + )* + } + } + } + + impl ConnectorProperties { + pub(crate) fn new_cdc_properties( + connector_name: &str, + properties: HashMap, + ) -> std::result::Result { + match connector_name { + $( + $source_type::CDC_CONNECTOR_NAME => paste! { + Ok(Self::[< $source_type Cdc >](Box::new(CdcProperties::<$source_type> { + props: properties, + ..Default::default() + }))) + }, + )* + _ => Err(anyhow::anyhow!("unexpected cdc connector '{}'", connector_name,)), + } + } + + pub fn init_cdc_properties(&mut self, table_schema: PbTableSchema) { + match self { + $( + paste! {ConnectorProperties:: [< $source_type Cdc >](c)} => { + c.table_schema = table_schema; + } + )* + _ => {} + } + } + + pub fn is_cdc_connector(&self) -> bool { + match self { + $( + paste! {ConnectorProperties:: [< $source_type Cdc >](_)} => true, + )* + _ => false, + } + } + } + } +} diff --git a/src/connector/src/source/base.rs b/src/connector/src/source/base.rs index 1a856ef38fa3..5a94b36a6b8e 100644 --- a/src/connector/src/source/base.rs +++ b/src/connector/src/source/base.rs @@ -25,15 +25,13 @@ use itertools::Itertools; use parking_lot::Mutex; use risingwave_common::array::StreamChunk; use risingwave_common::catalog::TableId; -use risingwave_common::error::{ErrorCode, ErrorSuppressor, RwError}; +use risingwave_common::error::{ErrorSuppressor, RwError}; use risingwave_common::types::{JsonbVal, Scalar}; -use risingwave_pb::connector_service::PbTableSchema; use risingwave_pb::source::ConnectorSplit; use risingwave_rpc_client::ConnectorClient; -use serde::{Deserialize, Serialize}; use super::datagen::DatagenMeta; -use super::filesystem::{FsSplit, S3FileReader, S3Properties, S3SplitEnumerator, S3_CONNECTOR}; +use super::filesystem::{FsSplit, S3Properties, S3_CONNECTOR}; use super::google_pubsub::GooglePubsubMeta; use super::kafka::KafkaMeta; use super::monitor::SourceMetrics; @@ -42,13 +40,16 @@ use super::nats::source::NatsSplitReader; use super::nexmark::source::message::NexmarkMeta; use crate::parser::ParserConfig; use crate::source::cdc::{ - CdcProperties, CdcSplitReader, DebeziumCdcSplit, DebeziumSplitEnumerator, CITUS_CDC_CONNECTOR, - MYSQL_CDC_CONNECTOR, POSTGRES_CDC_CONNECTOR, + CdcProperties, CdcSplitReader, Citus, CitusDebeziumSplitEnumerator, DebeziumCdcSplit, + DebeziumSplitEnumerator, Mysql, MysqlDebeziumSplitEnumerator, Postgres, + PostgresDebeziumSplitEnumerator, CITUS_CDC_CONNECTOR, MYSQL_CDC_CONNECTOR, + POSTGRES_CDC_CONNECTOR, }; use crate::source::datagen::{ DatagenProperties, DatagenSplit, DatagenSplitEnumerator, DatagenSplitReader, DATAGEN_CONNECTOR, }; use crate::source::dummy_connector::DummySplitReader; +use crate::source::filesystem::{S3FileReader, S3SplitEnumerator}; use crate::source::google_pubsub::{ PubsubProperties, PubsubSplit, PubsubSplitEnumerator, PubsubSplitReader, GOOGLE_PUBSUB_CONNECTOR, @@ -289,7 +290,7 @@ pub trait SplitReader: Sized { fn into_stream(self) -> BoxSourceWithStateStream; } -#[derive(Clone, Debug, Deserialize)] +#[derive(Clone, Debug)] pub enum ConnectorProperties { Kafka(Box), Pulsar(Box), @@ -297,65 +298,21 @@ pub enum ConnectorProperties { Nexmark(Box), Datagen(Box), S3(Box), - MySqlCdc(Box), - PostgresCdc(Box), - CitusCdc(Box), + MysqlCdc(Box>), + PostgresCdc(Box>), + CitusCdc(Box>), GooglePubsub(Box), Nats(Box), Dummy(Box<()>), } impl ConnectorProperties { - fn new_cdc_properties( - connector_name: &str, - properties: HashMap, - ) -> Result { - match connector_name { - MYSQL_CDC_CONNECTOR => Ok(Self::MySqlCdc(Box::new(CdcProperties { - props: properties, - source_type: "mysql".to_string(), - ..Default::default() - }))), - POSTGRES_CDC_CONNECTOR => Ok(Self::PostgresCdc(Box::new(CdcProperties { - props: properties, - source_type: "postgres".to_string(), - ..Default::default() - }))), - CITUS_CDC_CONNECTOR => Ok(Self::CitusCdc(Box::new(CdcProperties { - props: properties, - source_type: "citus".to_string(), - ..Default::default() - }))), - _ => Err(anyhow!("unexpected cdc connector '{}'", connector_name,)), - } - } - - pub fn init_cdc_properties(&mut self, table_schema: PbTableSchema) { - match self { - ConnectorProperties::MySqlCdc(c) - | ConnectorProperties::PostgresCdc(c) - | ConnectorProperties::CitusCdc(c) => { - c.table_schema = table_schema; - } - _ => {} - } - } - - pub fn is_cdc_connector(&self) -> bool { - matches!( - self, - ConnectorProperties::MySqlCdc(_) - | ConnectorProperties::PostgresCdc(_) - | ConnectorProperties::CitusCdc(_) - ) - } - pub fn support_multiple_splits(&self) -> bool { matches!(self, ConnectorProperties::Kafka(_)) } } -#[derive(Debug, Clone, Serialize, Deserialize, EnumAsInner, PartialEq, Hash)] +#[derive(Debug, Clone, EnumAsInner, PartialEq, Hash)] pub enum SplitImpl { Kafka(KafkaSplit), Pulsar(PulsarSplit), @@ -363,9 +320,9 @@ pub enum SplitImpl { Nexmark(NexmarkSplit), Datagen(DatagenSplit), GooglePubsub(PubsubSplit), - MySqlCdc(DebeziumCdcSplit), - PostgresCdc(DebeziumCdcSplit), - CitusCdc(DebeziumCdcSplit), + MysqlCdc(DebeziumCdcSplit), + PostgresCdc(DebeziumCdcSplit), + CitusCdc(DebeziumCdcSplit), Nats(NatsSplit), S3(FsSplit), } @@ -396,9 +353,9 @@ pub enum SplitReaderImpl { Nexmark(Box), Pulsar(Box), Datagen(Box), - MySqlCdc(Box), - PostgresCdc(Box), - CitusCdc(Box), + MysqlCdc(Box>), + PostgresCdc(Box>), + CitusCdc(Box>), GooglePubsub(Box), Nats(Box), } @@ -409,9 +366,9 @@ pub enum SplitEnumeratorImpl { Kinesis(KinesisSplitEnumerator), Nexmark(NexmarkSplitEnumerator), Datagen(DatagenSplitEnumerator), - MySqlCdc(DebeziumSplitEnumerator), - PostgresCdc(DebeziumSplitEnumerator), - CitusCdc(DebeziumSplitEnumerator), + MysqlCdc(MysqlDebeziumSplitEnumerator), + PostgresCdc(PostgresDebeziumSplitEnumerator), + CitusCdc(CitusDebeziumSplitEnumerator), GooglePubsub(PubsubSplitEnumerator), S3(S3SplitEnumerator), Nats(NatsSplitEnumerator), @@ -424,9 +381,6 @@ impl_connector_properties! { { Nexmark, NEXMARK_CONNECTOR }, { Datagen, DATAGEN_CONNECTOR }, { S3, S3_CONNECTOR }, - { MySqlCdc, MYSQL_CDC_CONNECTOR }, - { PostgresCdc, POSTGRES_CDC_CONNECTOR }, - { CitusCdc, CITUS_CDC_CONNECTOR }, { GooglePubsub, GOOGLE_PUBSUB_CONNECTOR}, { Nats, NATS_CONNECTOR } } @@ -437,7 +391,7 @@ impl_split_enumerator! { { Kinesis, KinesisSplitEnumerator }, { Nexmark, NexmarkSplitEnumerator }, { Datagen, DatagenSplitEnumerator }, - { MySqlCdc, DebeziumSplitEnumerator }, + { MysqlCdc, DebeziumSplitEnumerator }, { PostgresCdc, DebeziumSplitEnumerator }, { CitusCdc, DebeziumSplitEnumerator }, { GooglePubsub, PubsubSplitEnumerator}, @@ -452,9 +406,9 @@ impl_split! { { Nexmark, NEXMARK_CONNECTOR, NexmarkSplit }, { Datagen, DATAGEN_CONNECTOR, DatagenSplit }, { GooglePubsub, GOOGLE_PUBSUB_CONNECTOR, PubsubSplit }, - { MySqlCdc, MYSQL_CDC_CONNECTOR, DebeziumCdcSplit }, - { PostgresCdc, POSTGRES_CDC_CONNECTOR, DebeziumCdcSplit }, - { CitusCdc, CITUS_CDC_CONNECTOR, DebeziumCdcSplit }, + { MysqlCdc, MYSQL_CDC_CONNECTOR, DebeziumCdcSplit }, + { PostgresCdc, POSTGRES_CDC_CONNECTOR, DebeziumCdcSplit }, + { CitusCdc, CITUS_CDC_CONNECTOR, DebeziumCdcSplit }, { S3, S3_CONNECTOR, FsSplit }, { Nats, NATS_CONNECTOR, NatsSplit } } @@ -466,7 +420,7 @@ impl_split_reader! { { Kinesis, KinesisSplitReader }, { Nexmark, NexmarkSplitReader }, { Datagen, DatagenSplitReader }, - { MySqlCdc, CdcSplitReader}, + { MysqlCdc, CdcSplitReader}, { PostgresCdc, CdcSplitReader}, { CitusCdc, CdcSplitReader }, { GooglePubsub, PubsubSplitReader }, @@ -565,7 +519,7 @@ mod tests { let offset_str = "{\"sourcePartition\":{\"server\":\"RW_CDC_mydb.products\"},\"sourceOffset\":{\"transaction_id\":null,\"ts_sec\":1670407377,\"file\":\"binlog.000001\",\"pos\":98587,\"row\":2,\"server_id\":1,\"event\":2}}"; let mysql_split = MySqlCdcSplit::new(1001, offset_str.to_string()); let split = DebeziumCdcSplit::new(Some(mysql_split), None); - let split_impl = SplitImpl::MySqlCdc(split); + let split_impl = SplitImpl::MysqlCdc(split); let encoded_split = split_impl.encode_to_bytes(); let restored_split_impl = SplitImpl::restore_from_bytes(encoded_split.as_ref())?; assert_eq!( @@ -653,8 +607,7 @@ mod tests { )); let conn_props = ConnectorProperties::extract(user_props_mysql).unwrap(); - if let ConnectorProperties::MySqlCdc(c) = conn_props { - assert_eq!(c.source_type, "mysql"); + if let ConnectorProperties::MysqlCdc(c) = conn_props { assert_eq!(c.props.get("connector_node_addr").unwrap(), "localhost"); assert_eq!(c.props.get("database.hostname").unwrap(), "127.0.0.1"); assert_eq!(c.props.get("database.port").unwrap(), "3306"); @@ -668,7 +621,6 @@ mod tests { let conn_props = ConnectorProperties::extract(user_props_postgres).unwrap(); if let ConnectorProperties::PostgresCdc(c) = conn_props { - assert_eq!(c.source_type, "postgres"); assert_eq!(c.props.get("connector_node_addr").unwrap(), "localhost"); assert_eq!(c.props.get("database.hostname").unwrap(), "127.0.0.1"); assert_eq!(c.props.get("database.port").unwrap(), "5432"); diff --git a/src/connector/src/source/cdc/enumerator/mod.rs b/src/connector/src/source/cdc/enumerator/mod.rs index 75173e413a98..1c689026e568 100644 --- a/src/connector/src/source/cdc/enumerator/mod.rs +++ b/src/connector/src/source/cdc/enumerator/mod.rs @@ -12,38 +12,43 @@ // See the License for the specific language governing permissions and // limitations under the License. +use std::marker::PhantomData; use std::str::FromStr; use anyhow::anyhow; use async_trait::async_trait; use itertools::Itertools; use risingwave_common::util::addr::HostAddr; -use risingwave_pb::connector_service::SourceType as PbSourceType; +use risingwave_pb::connector_service::SourceType; use crate::source::cdc::{ - CdcProperties, CdcSplitBase, DebeziumCdcSplit, MySqlCdcSplit, PostgresCdcSplit, + CdcProperties, CdcSourceTypeTrait, CdcSplitBase, Citus, DebeziumCdcSplit, MySqlCdcSplit, Mysql, + Postgres, PostgresCdcSplit, }; use crate::source::{SourceEnumeratorContextRef, SplitEnumerator}; pub const DATABASE_SERVERS_KEY: &str = "database.servers"; #[derive(Debug)] -pub struct DebeziumSplitEnumerator { +pub struct DebeziumSplitEnumerator { /// The source_id in the catalog source_id: u32, - source_type: PbSourceType, worker_node_addrs: Vec, + _phantom: PhantomData, } #[async_trait] -impl SplitEnumerator for DebeziumSplitEnumerator { - type Properties = CdcProperties; - type Split = DebeziumCdcSplit; +impl SplitEnumerator for DebeziumSplitEnumerator +where + Self: ListCdcSplits, +{ + type Properties = CdcProperties; + type Split = DebeziumCdcSplit; async fn new( - props: CdcProperties, + props: CdcProperties, context: SourceEnumeratorContextRef, - ) -> anyhow::Result { + ) -> anyhow::Result { let connector_client = context.connector_client.clone().ok_or_else(|| { anyhow!("connector node endpoint not specified or unable to connect to connector node") })?; @@ -59,12 +64,16 @@ impl SplitEnumerator for DebeziumSplitEnumerator { .transpose()? .unwrap_or_default(); - let source_type = props.get_source_type_pb()?; + assert_eq!( + props.get_source_type_pb(), + SourceType::from(T::source_type()) + ); + // validate connector properties connector_client .validate_source_properties( context.info.source_id as u64, - props.get_source_type_pb()?, + props.get_source_type_pb(), props.props, Some(props.table_schema), ) @@ -73,54 +82,73 @@ impl SplitEnumerator for DebeziumSplitEnumerator { tracing::debug!("validate cdc source properties success"); Ok(Self { source_id: context.info.source_id, - source_type, worker_node_addrs: server_addrs, + _phantom: PhantomData, }) } - async fn list_splits(&mut self) -> anyhow::Result> { - match self.source_type { - PbSourceType::Mysql => { - // CDC source only supports single split - let split = MySqlCdcSplit { - inner: CdcSplitBase::new(self.source_id, None), - }; - let dbz_split = DebeziumCdcSplit { - mysql_split: Some(split), - pg_split: None, - }; - Ok(vec![dbz_split]) - } - PbSourceType::Postgres => { + async fn list_splits(&mut self) -> anyhow::Result>> { + Ok(self.list_cdc_splits()) + } +} + +pub trait ListCdcSplits { + type CdcSourceType: CdcSourceTypeTrait; + fn list_cdc_splits(&mut self) -> Vec>; +} + +impl ListCdcSplits for DebeziumSplitEnumerator { + type CdcSourceType = Mysql; + + fn list_cdc_splits(&mut self) -> Vec> { + // CDC source only supports single split + let split = MySqlCdcSplit { + inner: CdcSplitBase::new(self.source_id, None), + }; + let dbz_split = DebeziumCdcSplit { + mysql_split: Some(split), + pg_split: None, + _phantom: PhantomData, + }; + vec![dbz_split] + } +} + +impl ListCdcSplits for DebeziumSplitEnumerator { + type CdcSourceType = Postgres; + + fn list_cdc_splits(&mut self) -> Vec> { + let split = PostgresCdcSplit { + inner: CdcSplitBase::new(self.source_id, None), + server_addr: None, + }; + let dbz_split = DebeziumCdcSplit { + mysql_split: None, + pg_split: Some(split), + _phantom: Default::default(), + }; + vec![dbz_split] + } +} + +impl ListCdcSplits for DebeziumSplitEnumerator { + type CdcSourceType = Citus; + + fn list_cdc_splits(&mut self) -> Vec> { + self.worker_node_addrs + .iter() + .enumerate() + .map(|(id, addr)| { let split = PostgresCdcSplit { - inner: CdcSplitBase::new(self.source_id, None), - server_addr: None, + inner: CdcSplitBase::new(id as u32, None), + server_addr: Some(addr.to_string()), }; - let dbz_split = DebeziumCdcSplit { + DebeziumCdcSplit { mysql_split: None, pg_split: Some(split), - }; - Ok(vec![dbz_split]) - } - PbSourceType::Citus => { - let splits = self - .worker_node_addrs - .iter() - .enumerate() - .map(|(id, addr)| { - let split = PostgresCdcSplit { - inner: CdcSplitBase::new(id as u32, None), - server_addr: Some(addr.to_string()), - }; - DebeziumCdcSplit { - mysql_split: None, - pg_split: Some(split), - } - }) - .collect_vec(); - Ok(splits) - } - _ => Err(anyhow!("unexpected source type")), - } + _phantom: Default::default(), + } + }) + .collect_vec() } } diff --git a/src/connector/src/source/cdc/mod.rs b/src/connector/src/source/cdc/mod.rs index 52fdd2d38e59..86a8b16adec0 100644 --- a/src/connector/src/source/cdc/mod.rs +++ b/src/connector/src/source/cdc/mod.rs @@ -15,36 +15,46 @@ pub mod enumerator; pub mod source; pub mod split; - use std::collections::HashMap; +use std::marker::PhantomData; -use anyhow::anyhow; pub use enumerator::*; +use paste::paste; use risingwave_common::catalog::{ColumnDesc, Field, Schema}; -use risingwave_pb::connector_service::{SourceType, TableSchema}; -use serde::Deserialize; +use risingwave_pb::connector_service::{PbSourceType, PbTableSchema, SourceType, TableSchema}; pub use source::*; pub use split::*; -pub const MYSQL_CDC_CONNECTOR: &str = "mysql-cdc"; -pub const POSTGRES_CDC_CONNECTOR: &str = "postgres-cdc"; -pub const CITUS_CDC_CONNECTOR: &str = "citus-cdc"; +use crate::impl_cdc_source_type; +use crate::source::ConnectorProperties; + +pub const CDC_CONNECTOR_NAME_SUFFIX: &str = "-cdc"; -#[derive(Clone, Debug, Deserialize, Default)] -pub struct CdcProperties { - /// Type of the cdc source, e.g. mysql, postgres - pub source_type: String, +pub const MYSQL_CDC_CONNECTOR: &str = Mysql::CDC_CONNECTOR_NAME; +pub const POSTGRES_CDC_CONNECTOR: &str = Postgres::CDC_CONNECTOR_NAME; +pub const CITUS_CDC_CONNECTOR: &str = Citus::CDC_CONNECTOR_NAME; + +pub trait CdcSourceTypeTrait: Send + Sync + Clone + 'static { + const CDC_CONNECTOR_NAME: &'static str; + fn source_type() -> CdcSourceType; +} + +impl_cdc_source_type!({ Mysql, "mysql" }, { Postgres, "postgres" }, { Citus, "citus" }); + +#[derive(Clone, Debug, Default)] +pub struct CdcProperties { /// Properties specified in the WITH clause by user pub props: HashMap, /// Schema of the source specified by users pub table_schema: TableSchema, + + pub _phantom: PhantomData, } -impl CdcProperties { - pub fn get_source_type_pb(&self) -> anyhow::Result { - SourceType::from_str_name(&self.source_type.to_ascii_uppercase()) - .ok_or_else(|| anyhow!("unknown source type: {}", self.source_type)) +impl CdcProperties { + pub fn get_source_type_pb(&self) -> SourceType { + SourceType::from(T::source_type()) } pub fn schema(&self) -> Schema { diff --git a/src/connector/src/source/cdc/source/reader.rs b/src/connector/src/source/cdc/source/reader.rs index 32e0c27ca285..974ec8877d2f 100644 --- a/src/connector/src/source/cdc/source/reader.rs +++ b/src/connector/src/source/cdc/source/reader.rs @@ -23,19 +23,19 @@ use risingwave_pb::connector_service::GetEventStreamResponse; use crate::parser::ParserConfig; use crate::source::base::SourceMessage; -use crate::source::cdc::CdcProperties; +use crate::source::cdc::{CdcProperties, CdcSourceType, CdcSourceTypeTrait, DebeziumCdcSplit}; use crate::source::common::{into_chunk_stream, CommonSplitReader}; use crate::source::{ BoxSourceWithStateStream, Column, SourceContextRef, SplitId, SplitImpl, SplitMetaData, SplitReader, }; -pub struct CdcSplitReader { +pub struct CdcSplitReader { source_id: u64, start_offset: Option, // host address of worker node for a Citus cluster server_addr: Option, - conn_props: CdcProperties, + conn_props: CdcProperties, split_id: SplitId, // whether the full snapshot phase is done @@ -45,22 +45,25 @@ pub struct CdcSplitReader { } #[async_trait] -impl SplitReader for CdcSplitReader { - type Properties = CdcProperties; +impl SplitReader for CdcSplitReader +where + DebeziumCdcSplit: TryFrom, +{ + type Properties = CdcProperties; #[allow(clippy::unused_async)] async fn new( - conn_props: CdcProperties, + conn_props: CdcProperties, splits: Vec, parser_config: ParserConfig, source_ctx: SourceContextRef, _columns: Option>, ) -> Result { assert_eq!(splits.len(), 1); - let split = splits.into_iter().next().unwrap(); + let split = DebeziumCdcSplit::::try_from(splits.into_iter().next().unwrap())?; let split_id = split.id(); - match split { - SplitImpl::MySqlCdc(split) | SplitImpl::PostgresCdc(split) => Ok(Self { + match T::source_type() { + CdcSourceType::Mysql | CdcSourceType::Postgres => Ok(Self { source_id: split.split_id() as u64, start_offset: split.start_offset().clone(), server_addr: None, @@ -70,7 +73,7 @@ impl SplitReader for CdcSplitReader { parser_config, source_ctx, }), - SplitImpl::CitusCdc(split) => Ok(Self { + CdcSourceType::Citus => Ok(Self { source_id: split.split_id() as u64, start_offset: split.start_offset().clone(), server_addr: split.server_addr().clone(), @@ -80,10 +83,6 @@ impl SplitReader for CdcSplitReader { parser_config, source_ctx, }), - - _ => Err(anyhow!( - "failed to create cdc split reader: invalid splis info" - )), } } @@ -94,7 +93,10 @@ impl SplitReader for CdcSplitReader { } } -impl CommonSplitReader for CdcSplitReader { +impl CommonSplitReader for CdcSplitReader +where + Self: SplitReader, +{ #[try_stream(ok = Vec, error = anyhow::Error)] async fn into_data_stream(self) { let cdc_client = self.source_ctx.connector_client.clone().ok_or_else(|| { @@ -122,7 +124,7 @@ impl CommonSplitReader for CdcSplitReader { let cdc_stream = cdc_client .start_source_stream( self.source_id, - self.conn_props.get_source_type_pb()?, + self.conn_props.get_source_type_pb(), self.start_offset, properties, self.snapshot_done, diff --git a/src/connector/src/source/cdc/split.rs b/src/connector/src/source/cdc/split.rs index bee23f97857d..9d13a3c5a6ea 100644 --- a/src/connector/src/source/cdc/split.rs +++ b/src/connector/src/source/cdc/split.rs @@ -12,10 +12,13 @@ // See the License for the specific language governing permissions and // limitations under the License. +use std::marker::PhantomData; + use anyhow::anyhow; use risingwave_common::types::JsonbVal; use serde::{Deserialize, Serialize}; +use crate::source::cdc::CdcSourceTypeTrait; use crate::source::external::DebeziumOffset; use crate::source::{SplitId, SplitMetaData}; @@ -119,12 +122,15 @@ impl PostgresCdcSplit { } #[derive(Clone, Serialize, Deserialize, Debug, PartialEq, Hash)] -pub struct DebeziumCdcSplit { +pub struct DebeziumCdcSplit { pub mysql_split: Option, pub pg_split: Option, + + #[serde(skip)] + pub _phantom: PhantomData, } -impl SplitMetaData for DebeziumCdcSplit { +impl SplitMetaData for DebeziumCdcSplit { fn id(&self) -> SplitId { assert!(self.mysql_split.is_some() || self.pg_split.is_some()); if let Some(split) = &self.mysql_split { @@ -145,11 +151,12 @@ impl SplitMetaData for DebeziumCdcSplit { } } -impl DebeziumCdcSplit { +impl DebeziumCdcSplit { pub fn new(mysql_split: Option, pg_split: Option) -> Self { Self { mysql_split, pg_split, + _phantom: PhantomData, } } diff --git a/src/stream/src/executor/backfill/cdc_backfill.rs b/src/stream/src/executor/backfill/cdc_backfill.rs index 55742348bbf0..2f522ae8eeb0 100644 --- a/src/stream/src/executor/backfill/cdc_backfill.rs +++ b/src/stream/src/executor/backfill/cdc_backfill.rs @@ -492,7 +492,7 @@ impl CdcBackfillExecutor { .set(key.into(), JsonbVal::from(Value::Bool(true))) .await?; - if let Some(SplitImpl::MySqlCdc(split)) = cdc_split.as_mut() + if let Some(SplitImpl::MysqlCdc(split)) = cdc_split.as_mut() && let Some(s) = split.mysql_split.as_mut() { let start_offset = last_binlog_offset.as_ref().map(|cdc_offset| {