From 417561d3989374bf795c83687d38203fd717f113 Mon Sep 17 00:00:00 2001
From: William Wen <44139337+wenym1@users.noreply.github.com>
Date: Tue, 26 Sep 2023 14:12:47 +0800
Subject: [PATCH] feat(test): support customized source logic in deterministic
 test (#12456)

---
 Cargo.lock                                    |   1 +
 src/connector/src/macros.rs                   |   5 +-
 src/connector/src/source/mod.rs               |   2 +
 src/connector/src/source/test_source.rs       | 239 ++++++++++++++++++
 src/frontend/src/handler/create_source.rs     |   4 +
 .../src/executor/wrapper/schema_check.rs      |   2 +-
 src/tests/simulation/Cargo.toml               |   1 +
 .../tests/integration_tests/sink/basic.rs     | 114 +++++++--
 8 files changed, 345 insertions(+), 23 deletions(-)
 create mode 100644 src/connector/src/source/test_source.rs

diff --git a/Cargo.lock b/Cargo.lock
index 42a08e60f43d5..919c5a234dc00 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -7473,6 +7473,7 @@ dependencies = [
  "tempfile",
  "tikv-jemallocator",
  "tokio-postgres",
+ "tokio-stream",
  "tracing",
  "tracing-subscriber",
 ]
diff --git a/src/connector/src/macros.rs b/src/connector/src/macros.rs
index 792e2066abcca..62a3cfdcd9682 100644
--- a/src/connector/src/macros.rs
+++ b/src/connector/src/macros.rs
@@ -31,7 +31,8 @@ macro_rules! for_all_classified_sources {
             { Datagen, $crate::source::datagen::DatagenProperties, $crate::source::datagen::DatagenSplit },
             { GooglePubsub, $crate::source::google_pubsub::PubsubProperties, $crate::source::google_pubsub::PubsubSplit },
             { Nats, $crate::source::nats::NatsProperties, $crate::source::nats::split::NatsSplit },
-            { S3, $crate::source::filesystem::S3Properties, $crate::source::filesystem::FsSplit }
+            { S3, $crate::source::filesystem::S3Properties, $crate::source::filesystem::FsSplit },
+            { Test, $crate::source::test_source::TestSourceProperties, $crate::source::test_source::TestSourceSplit }
         }
         $(
             ,$extra_args
@@ -152,7 +153,7 @@ macro_rules! dispatch_split_impl {
 
 macro_rules! impl_split {
     ({$({ $variant_name:ident, $prop_name:ty, $split:ty}),*}) => {
-        #[derive(Debug, Clone, EnumAsInner, PartialEq, Hash)]
+        #[derive(Debug, Clone, EnumAsInner, PartialEq)]
         pub enum SplitImpl {
             $(
                 $variant_name($split),
diff --git a/src/connector/src/source/mod.rs b/src/connector/src/source/mod.rs
index 20a9f706e60b5..762af05cd0c96 100644
--- a/src/connector/src/source/mod.rs
+++ b/src/connector/src/source/mod.rs
@@ -34,6 +34,8 @@ mod common;
 pub mod external;
 mod manager;
 mod mock_external_table;
+pub mod test_source;
+
 pub use manager::{SourceColumnDesc, SourceColumnType};
 pub use mock_external_table::MockExternalTableReader;
diff --git a/src/connector/src/source/test_source.rs b/src/connector/src/source/test_source.rs
new file mode 100644
index 0000000000000..743ae3b179427
--- /dev/null
+++ b/src/connector/src/source/test_source.rs
@@ -0,0 +1,239 @@
+// Copyright 2023 RisingWave Labs
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
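+
+//! A `test` connector source with no built-in behavior of its own: a
+//! deterministic test injects the split-listing and stream-building logic
+//! through a process-global registry before creating a source with
+//! `connector = 'test'`.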
+
+use std::collections::HashMap;
+use std::sync::{Arc, OnceLock};
+
+use anyhow::anyhow;
+use async_trait::async_trait;
+use parking_lot::Mutex;
+use risingwave_common::types::JsonbVal;
+use serde_derive::{Deserialize, Serialize};
+
+use crate::parser::ParserConfig;
+use crate::source::{
+    BoxSourceWithStateStream, Column, SourceContextRef, SourceEnumeratorContextRef,
+    SourceProperties, SplitEnumerator, SplitId, SplitMetaData, SplitReader, TryFromHashmap,
+};
+
+pub type BoxListSplits = Box<
+    dyn FnMut(
+            TestSourceProperties,
+            SourceEnumeratorContextRef,
+        ) -> anyhow::Result<Vec<TestSourceSplit>>
+        + Send
+        + 'static,
+>;
+
+pub type BoxIntoSourceStream = Box<
+    dyn FnMut(
+            TestSourceProperties,
+            Vec<TestSourceSplit>,
+            ParserConfig,
+            SourceContextRef,
+            Option<Vec<Column>>,
+        ) -> BoxSourceWithStateStream
+        + Send
+        + 'static,
>;
+
+pub struct BoxSource {
+    list_split: BoxListSplits,
+    into_source_stream: BoxIntoSourceStream,
+}
+
+impl BoxSource {
+    pub fn new(
+        list_splits: impl FnMut(
+                TestSourceProperties,
+                SourceEnumeratorContextRef,
+            ) -> anyhow::Result<Vec<TestSourceSplit>>
+            + Send
+            + 'static,
+        into_source_stream: impl FnMut(
+                TestSourceProperties,
+                Vec<TestSourceSplit>,
+                ParserConfig,
+                SourceContextRef,
+                Option<Vec<Column>>,
+            ) -> BoxSourceWithStateStream
+            + Send
+            + 'static,
+    ) -> BoxSource {
+        BoxSource {
+            list_split: Box::new(list_splits),
+            into_source_stream: Box::new(into_source_stream),
+        }
+    }
+}
+
+struct TestSourceRegistry {
+    box_source: Arc<Mutex<Option<BoxSource>>>,
+}
+
+impl TestSourceRegistry {
+    fn new() -> Self {
+        TestSourceRegistry {
+            box_source: Arc::new(Mutex::new(None)),
+        }
+    }
+}
+
+fn get_registry() -> &'static TestSourceRegistry {
+    static GLOBAL_REGISTRY: OnceLock<TestSourceRegistry> = OnceLock::new();
+    GLOBAL_REGISTRY.get_or_init(TestSourceRegistry::new)
+}
+
+pub struct TestSourceRegistryGuard;
+
+impl Drop for TestSourceRegistryGuard {
+    fn drop(&mut self) {
+        assert!(get_registry().box_source.lock().take().is_some());
+    }
+}
+
+pub fn registry_test_source(box_source: BoxSource) -> TestSourceRegistryGuard {
+    assert!(get_registry()
+        .box_source
+        .lock()
+        .replace(box_source)
+        .is_none());
+    TestSourceRegistryGuard
+}
+
+pub const TEST_CONNECTOR: &str = "test";
+
+#[derive(Clone, Debug)]
+pub struct TestSourceProperties {
+    properties: HashMap<String, String>,
+}
+
+impl TryFromHashmap for TestSourceProperties {
+    fn try_from_hashmap(props: HashMap<String, String>) -> anyhow::Result<Self> {
+        if cfg!(any(madsim, test)) {
+            Ok(TestSourceProperties { properties: props })
+        } else {
+            Err(anyhow!("test source only available at test"))
+        }
+    }
+}
+
+#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
+pub struct TestSourceSplit {
+    pub id: SplitId,
+    pub properties: HashMap<String, String>,
+    pub offset: String,
+}
+
+impl SplitMetaData for TestSourceSplit {
+    fn id(&self) -> SplitId {
+        self.id.clone()
+    }
+
+    fn encode_to_json(&self) -> JsonbVal {
+        serde_json::to_value(self.clone()).unwrap().into()
+    }
+
+    fn restore_from_json(value: JsonbVal) -> anyhow::Result<Self> {
+        serde_json::from_value(value.take()).map_err(|e| anyhow!(e))
+    }
+
+    fn update_with_offset(&mut self, start_offset: String) -> anyhow::Result<()> {
+        self.offset = start_offset;
+        Ok(())
+    }
+}
+
+pub struct TestSourceSplitEnumerator {
+    properties: TestSourceProperties,
+    context: SourceEnumeratorContextRef,
+}
+
+#[async_trait]
+impl SplitEnumerator for TestSourceSplitEnumerator {
+    type Properties = TestSourceProperties;
+    type Split = TestSourceSplit;
+
+    async fn new(
+        properties: Self::Properties,
+        context: SourceEnumeratorContextRef,
+    ) -> anyhow::Result<Self> {
+        Ok(Self {
+            properties,
+            context,
+        })
+    }
+
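+    // Delegate to the split-listing closure that the running test registered
+    // via `registry_test_source`.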
+    async fn list_splits(&mut self) -> anyhow::Result<Vec<TestSourceSplit>> {
+        (get_registry()
+            .box_source
+            .lock()
+            .as_mut()
+            .expect("should have init")
+            .list_split)(self.properties.clone(), self.context.clone())
+    }
+}
+
+pub struct TestSourceSplitReader {
+    properties: TestSourceProperties,
+    state: Vec<TestSourceSplit>,
+    parser_config: ParserConfig,
+    source_ctx: SourceContextRef,
+    columns: Option<Vec<Column>>,
+}
+
+#[async_trait]
+impl SplitReader for TestSourceSplitReader {
+    type Properties = TestSourceProperties;
+    type Split = TestSourceSplit;
+
+    async fn new(
+        properties: Self::Properties,
+        state: Vec<TestSourceSplit>,
+        parser_config: ParserConfig,
+        source_ctx: SourceContextRef,
+        columns: Option<Vec<Column>>,
+    ) -> anyhow::Result<Self> {
+        Ok(Self {
+            properties,
+            state,
+            parser_config,
+            source_ctx,
+            columns,
+        })
+    }
+
+    fn into_stream(self) -> BoxSourceWithStateStream {
+        (get_registry()
+            .box_source
+            .lock()
+            .as_mut()
+            .expect("should have init")
+            .into_source_stream)(
+            self.properties,
+            self.state,
+            self.parser_config,
+            self.source_ctx,
+            self.columns,
+        )
+    }
+}
+
+impl SourceProperties for TestSourceProperties {
+    type Split = TestSourceSplit;
+    type SplitEnumerator = TestSourceSplitEnumerator;
+    type SplitReader = TestSourceSplitReader;
+
+    const SOURCE_NAME: &'static str = TEST_CONNECTOR;
+}
diff --git a/src/frontend/src/handler/create_source.rs b/src/frontend/src/handler/create_source.rs
index 448d2f49923f7..7479348c4b80f 100644
--- a/src/frontend/src/handler/create_source.rs
+++ b/src/frontend/src/handler/create_source.rs
@@ -36,6 +36,7 @@ use risingwave_connector::source::cdc::{
 use risingwave_connector::source::datagen::DATAGEN_CONNECTOR;
 use risingwave_connector::source::filesystem::S3_CONNECTOR;
 use risingwave_connector::source::nexmark::source::{get_event_data_types_with_names, EventType};
+use risingwave_connector::source::test_source::TEST_CONNECTOR;
 use risingwave_connector::source::{
     SourceEncode, SourceFormat, SourceStruct, GOOGLE_PUBSUB_CONNECTOR, KAFKA_CONNECTOR,
     KINESIS_CONNECTOR, NATS_CONNECTOR, NEXMARK_CONNECTOR, PULSAR_CONNECTOR,
@@ -907,6 +908,9 @@ static CONNECTORS_COMPATIBLE_FORMATS: LazyLock<HashMap<String, HashMap<Format,
         NATS_CONNECTOR => hashmap!(
             Format::Plain => vec![Encode::Json],
         ),
+        TEST_CONNECTOR => hashmap!(
+            Format::Plain => vec![Encode::Json],
+        )
     ))
 });
diff --git a/src/stream/src/executor/wrapper/schema_check.rs b/src/stream/src/executor/wrapper/schema_check.rs
index d23eca2b455c6..3e8738db8327a 100644
--- a/src/stream/src/executor/wrapper/schema_check.rs
+++ b/src/stream/src/executor/wrapper/schema_check.rs
@@ -45,7 +45,7 @@ pub async fn schema_check(info: Arc<ExecutorInfo>, input: impl MessageStream) {
         }
         Message::Barrier(_) => Ok(()),
     }
-    .unwrap_or_else(|e| panic!("schema check failed on {}: {}", info.identity, e));
+    .unwrap_or_else(|e| panic!("schema check failed on {:?}: {}", info, e));
 
     yield message;
 }
diff --git a/src/tests/simulation/Cargo.toml b/src/tests/simulation/Cargo.toml
index 26fce12ce37b4..1268b670471d5 100644
--- a/src/tests/simulation/Cargo.toml
+++ b/src/tests/simulation/Cargo.toml
@@ -48,6 +48,7 @@ sqllogictest = "0.17.0"
 tempfile = "3"
 tokio = { version = "0.2.23", package = "madsim-tokio" }
 tokio-postgres = "0.7"
+tokio-stream = "0.1"
 tracing = "0.1"
 tracing-subscriber = { version = "0.3", features = ["env-filter"] }
diff --git a/src/tests/simulation/tests/integration_tests/sink/basic.rs b/src/tests/simulation/tests/integration_tests/sink/basic.rs
index c0f9f7253f373..a5715a8471c44 100644
--- a/src/tests/simulation/tests/integration_tests/sink/basic.rs
+++ b/src/tests/simulation/tests/integration_tests/sink/basic.rs
@@ -13,6 +13,7 @@
 // limitations under the License.
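+
+// These tests drive both ends of a streaming job in-process: rows are pushed
+// into the `test` source connector and counted by the `test` sink connector.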
 
 use std::io::Write;
+use std::iter::once;
 use std::sync::atomic::AtomicUsize;
 use std::sync::atomic::Ordering::Relaxed;
 use std::sync::Arc;
@@ -20,15 +21,23 @@ use std::time::Duration;
 
 use anyhow::Result;
 use async_trait::async_trait;
+use futures::stream::select_all;
+use futures::StreamExt;
 use itertools::Itertools;
 use rand::prelude::SliceRandom;
-use risingwave_common::array::StreamChunk;
+use risingwave_common::array::{Op, StreamChunk};
 use risingwave_common::buffer::Bitmap;
+use risingwave_common::types::{DataType, ScalarImpl};
+use risingwave_common::util::chunk_coalesce::DataChunkBuilder;
 use risingwave_connector::sink::boxed::{BoxCoordinator, BoxWriter};
 use risingwave_connector::sink::test_sink::registry_build_sink;
 use risingwave_connector::sink::{Sink, SinkWriter, SinkWriterParam};
+use risingwave_connector::source::test_source::{registry_test_source, BoxSource, TestSourceSplit};
+use risingwave_connector::source::StreamChunkWithState;
 use risingwave_simulation::cluster::{Cluster, ConfigPath, Configuration};
+use tokio::sync::mpsc::{unbounded_channel, UnboundedReceiver};
 use tokio::time::sleep;
+use tokio_stream::wrappers::UnboundedReceiverStream;
 
 struct TestWriter {
     row_counter: Arc<AtomicUsize>,
@@ -91,6 +100,21 @@ impl Sink for TestSink {
     }
 }
 
+fn build_stream_chunk(row_iter: impl Iterator<Item = (i32, String)>) -> StreamChunk {
+    let mut builder = DataChunkBuilder::new(vec![DataType::Int32, DataType::Varchar], 100000);
+    for (id, name) in row_iter {
+        assert!(builder
+            .append_one_row([
+                Some(ScalarImpl::Int32(id)),
+                Some(ScalarImpl::Utf8(name.into())),
+            ])
+            .is_none());
+    }
+    let chunk = builder.consume_all().unwrap();
+    let ops = (0..chunk.cardinality()).map(|_| Op::Insert).collect_vec();
+    StreamChunk::from_parts(ops, chunk)
+}
+
 #[tokio::test]
 async fn test_sink_basic() -> Result<()> {
     let config_path = {
@@ -126,30 +150,55 @@ async fn test_sink_basic() -> Result<()> {
         }
     });
 
+    let source_parallelism = 12;
+    let mut txs = Vec::new();
+    let mut rxs = Vec::new();
+    for _ in 0..source_parallelism {
+        let (tx, rx): (_, UnboundedReceiver<StreamChunk>) = unbounded_channel();
+        txs.push(tx);
+        rxs.push(Some(rx));
+    }
+
+    let _source_guard = registry_test_source(BoxSource::new(
+        move |_, _| {
+            Ok((0..source_parallelism)
+                .map(|i: usize| TestSourceSplit {
+                    id: format!("{}", i).as_str().into(),
+                    properties: Default::default(),
+                    offset: "".to_string(),
+                })
+                .collect_vec())
+        },
+        move |_, splits, _, _, _| {
+            select_all(splits.into_iter().map(|split| {
+                let id: usize = split.id.parse().unwrap();
+                let rx = rxs[id].take().unwrap();
+                UnboundedReceiverStream::new(rx).map(|chunk| Ok(StreamChunkWithState::from(chunk)))
+            }))
+            .boxed()
+        },
+    ));
+
     let mut session = cluster.start_session();
 
     session.run("set streaming_parallelism = 6").await?;
     session.run("set sink_decouple = false").await?;
     session
-        .run("create table test_table (id int, name varchar)")
+        .run("create table test_table (id int primary key, name varchar) with (connector = 'test') FORMAT PLAIN ENCODE JSON")
         .await?;
     session
        .run("create sink test_sink from test_table with (connector = 'test')")
         .await?;
     let mut count = 0;
-    let mut id_list = (0..100000).collect_vec();
+    let mut id_list: Vec<usize> = (0..100000).collect_vec();
     id_list.shuffle(&mut rand::thread_rng());
     let flush_freq = 50;
-    for id in &id_list[0..1000] {
-        session
-            .run(format!(
-                "insert into test_table values ({}, 'name-{}')",
-                id, id
-            ))
-            .await?;
+    for id in &id_list[0..10000] {
+        let chunk = build_stream_chunk(once((*id as i32, format!("name-{}", id))));
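+        // Route the chunk to one of the test-source splits by id.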
+        txs[id % source_parallelism].send(chunk).unwrap();
         count += 1;
         if count % flush_freq == 0 {
-            session.run("flush").await?;
+            sleep(Duration::from_millis(10)).await;
         }
     }
 
@@ -198,12 +247,41 @@ async fn test_sink_decouple_basic() -> Result<()> {
         }
     });
 
+    let source_parallelism = 12;
+    let mut txs = Vec::new();
+    let mut rxs = Vec::new();
+    for _ in 0..source_parallelism {
+        let (tx, rx): (_, UnboundedReceiver<StreamChunk>) = unbounded_channel();
+        txs.push(tx);
+        rxs.push(Some(rx));
+    }
+
+    let _source_guard = registry_test_source(BoxSource::new(
+        move |_, _| {
+            Ok((0..source_parallelism)
+                .map(|i: usize| TestSourceSplit {
+                    id: format!("{}", i).as_str().into(),
+                    properties: Default::default(),
+                    offset: "".to_string(),
+                })
+                .collect_vec())
+        },
+        move |_, splits, _, _, _| {
+            select_all(splits.into_iter().map(|split| {
+                let id: usize = split.id.parse().unwrap();
+                let rx = rxs[id].take().unwrap();
+                UnboundedReceiverStream::new(rx).map(|chunk| Ok(StreamChunkWithState::from(chunk)))
+            }))
+            .boxed()
+        },
+    ));
+
     let mut session = cluster.start_session();
 
     session.run("set streaming_parallelism = 6").await?;
     session.run("set sink_decouple = true").await?;
     session
-        .run("create table test_table (id int, name varchar)")
+        .run("create table test_table (id int primary key, name varchar) with (connector = 'test') FORMAT PLAIN ENCODE JSON")
         .await?;
     session
         .run("create sink test_sink from test_table with (connector = 'test')")
         .await?;
@@ -214,16 +292,12 @@ async fn test_sink_decouple_basic() -> Result<()> {
     let mut count = 0;
-    let mut id_list = (0..100000).collect_vec();
+    let mut id_list: Vec<usize> = (0..100000).collect_vec();
     id_list.shuffle(&mut rand::thread_rng());
     let flush_freq = 50;
-    for id in &id_list[0..1000] {
-        session
-            .run(format!(
-                "insert into test_table values ({}, 'name-{}')",
-                id, id
-            ))
-            .await?;
+    for id in &id_list[0..10000] {
+        let chunk = build_stream_chunk(once((*id as i32, format!("name-{}", id))));
+        txs[id % source_parallelism].send(chunk).unwrap();
         count += 1;
         if count % flush_freq == 0 {
-            session.run("flush").await?;
+            sleep(Duration::from_millis(10)).await;
         }
     }
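
A minimal sketch of how a test can drive the `test` connector introduced in this
patch, using only the APIs shown above (`registry_test_source`, `BoxSource`,
`TestSourceSplit`, `build_stream_chunk`); the single-split setup and the row
values are illustrative, not part of the PR:

    use futures::StreamExt;
    use tokio_stream::wrappers::UnboundedReceiverStream;

    // Expose one split named "0" and replay whatever the test pushes into `tx`.
    let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
    let mut rx = Some(rx);
    let _guard = registry_test_source(BoxSource::new(
        move |_props, _ctx| {
            Ok(vec![TestSourceSplit {
                id: "0".into(),
                properties: Default::default(),
                offset: "".to_string(),
            }])
        },
        move |_props, _splits, _parser_config, _source_ctx, _columns| {
            // Only one split exists, so the receiver is taken exactly once.
            UnboundedReceiverStream::new(rx.take().unwrap())
                .map(|chunk| Ok(StreamChunkWithState::from(chunk)))
                .boxed()
        },
    ));
    tx.send(build_stream_chunk(std::iter::once((1, "name-1".to_string()))))
        .unwrap();
    // A `create table t (...) with (connector = 'test') FORMAT PLAIN ENCODE JSON`
    // statement now consumes this stream; dropping `_guard` unregisters the source.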