diff --git a/ci/scripts/cron-e2e-test.sh b/ci/scripts/cron-e2e-test.sh
index 434fc6dea2985..7a3b5484f8201 100755
--- a/ci/scripts/cron-e2e-test.sh
+++ b/ci/scripts/cron-e2e-test.sh
@@ -5,5 +5,4 @@
 set -euo pipefail
 source ci/scripts/common.sh
 export RUN_COMPACTION=0;
-export RUN_META_BACKUP=1;
 source ci/scripts/run-e2e-test.sh
diff --git a/ci/scripts/pr.env.sh b/ci/scripts/pr.env.sh
index 84fe372d5d574..df722539cbde4 100755
--- a/ci/scripts/pr.env.sh
+++ b/ci/scripts/pr.env.sh
@@ -4,6 +4,4 @@
 set -euo pipefail
 
 # Don't run e2e compaction test in PR build
-export RUN_COMPACTION=0;
-# Don't run meta store backup/recovery test
-export RUN_META_BACKUP=0;
\ No newline at end of file
+export RUN_COMPACTION=0;
\ No newline at end of file
diff --git a/ci/scripts/run-e2e-test.sh b/ci/scripts/run-e2e-test.sh
index ef074af606531..cdf389cd9f65b 100755
--- a/ci/scripts/run-e2e-test.sh
+++ b/ci/scripts/run-e2e-test.sh
@@ -192,20 +192,6 @@ if [[ "$RUN_COMPACTION" -eq "1" ]]; then
   cluster_stop
 fi
 
-if [[ "$RUN_META_BACKUP" -eq "1" ]]; then
-  echo "--- e2e, ci-meta-backup-test"
-  test_root="src/storage/backup/integration_tests"
-  BACKUP_TEST_MCLI=".risingwave/bin/mcli" \
-  BACKUP_TEST_MCLI_CONFIG=".risingwave/config/mcli" \
-  BACKUP_TEST_RW_ALL_IN_ONE="target/debug/risingwave" \
-  RW_HUMMOCK_URL="hummock+minio://hummockadmin:hummockadmin@127.0.0.1:9301/hummock001" \
-  RW_META_ADDR="http://127.0.0.1:5690" \
-  RUST_LOG="info,risingwave_stream=info,risingwave_batch=info,risingwave_storage=info" \
-  bash "${test_root}/run_all.sh"
-  echo "--- Kill cluster"
-  cargo make kill
-fi
-
 if [[ "$mode" == "standalone" ]]; then
   run_sql() {
     psql -h localhost -p 4566 -d dev -U root -c "$@"
diff --git a/ci/scripts/run-meta-backup-test.sh b/ci/scripts/run-meta-backup-test.sh
new file mode 100755
index 0000000000000..90db9c3f136a5
--- /dev/null
+++ b/ci/scripts/run-meta-backup-test.sh
@@ -0,0 +1,80 @@
+#!/usr/bin/env bash
+
+# Exits as soon as any line fails.
+set -euo pipefail
+
+source ci/scripts/common.sh
+
+while getopts 'p:m:' opt; do
+    case ${opt} in
+        p )
+            profile=$OPTARG
+            ;;
+        m )
+            mode=$OPTARG
+            ;;
+        \? )
+            echo "Invalid Option: -$OPTARG" 1>&2
+            exit 1
+            ;;
+        : )
+            echo "Invalid option: $OPTARG requires an argument" 1>&2
+            ;;
+    esac
+done
+shift $((OPTIND -1))
+
+if [[ $mode == "standalone" ]]; then
+  source ci/scripts/standalone-utils.sh
+fi
+
+if [[ $mode == "single-node" ]]; then
+  source ci/scripts/single-node-utils.sh
+fi
+
+cluster_start() {
+  if [[ $mode == "standalone" ]]; then
+    mkdir -p "$PREFIX_LOG"
+    cargo make clean-data
+    cargo make pre-start-dev
+    start_standalone "$PREFIX_LOG"/standalone.log &
+    cargo make dev standalone-minio-etcd
+  elif [[ $mode == "single-node" ]]; then
+    mkdir -p "$PREFIX_LOG"
+    cargo make clean-data
+    cargo make pre-start-dev
+    start_single_node "$PREFIX_LOG"/single-node.log &
+    # Give it a while to make sure the single-node is ready.
+    sleep 3
+  else
+    cargo make ci-start "$mode"
+  fi
+}
+
+cluster_stop() {
+  if [[ $mode == "standalone" ]]
+  then
+    stop_standalone
+    # Don't check standalone logs, they will exceed the limit.
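+    # `stop_standalone` only stops the standalone RisingWave process itself;
+    # the `cargo make kill` below presumably tears down the auxiliary services
+    # (MinIO, etcd) that `cargo make dev` started.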
+    cargo make kill
+  elif [[ $mode == "single-node" ]]
+  then
+    stop_single_node
+  else
+    cargo make ci-kill
+  fi
+}
+
+download_and_prepare_rw "$profile" common
+
+echo "--- e2e, ci-meta-backup-test"
+test_root="src/storage/backup/integration_tests"
+BACKUP_TEST_MCLI=".risingwave/bin/mcli" \
+BACKUP_TEST_MCLI_CONFIG=".risingwave/config/mcli" \
+BACKUP_TEST_RW_ALL_IN_ONE="target/debug/risingwave" \
+RW_HUMMOCK_URL="hummock+minio://hummockadmin:hummockadmin@127.0.0.1:9301/hummock001" \
+RW_META_ADDR="http://127.0.0.1:5690" \
+RUST_LOG="info,risingwave_stream=info,risingwave_batch=info,risingwave_storage=info" \
+bash "${test_root}/run_all.sh"
+echo "--- Kill cluster"
+cargo make kill
\ No newline at end of file
diff --git a/ci/workflows/main-cron.yml b/ci/workflows/main-cron.yml
index 7e4df5da09d41..6197df011a51f 100644
--- a/ci/workflows/main-cron.yml
+++ b/ci/workflows/main-cron.yml
@@ -87,7 +87,27 @@ steps:
         config: ci/docker-compose.yml
         mount-buildkite-agent: true
       - ./ci/plugins/upload-failure-logs
-    timeout_in_minutes: 65
+    timeout_in_minutes: 15
+    retry: *auto-retry
+
+  - label: "meta backup test (release)"
+    key: "e2e-meta-backup-test-release"
+    command: "ci/scripts/run-meta-backup-test.sh -p ci-release -m ci-3streaming-2serving-3fe"
+    if: |
+      !(build.pull_request.labels includes "ci/main-cron/skip-ci") && build.env("CI_STEPS") == null
+      || build.pull_request.labels includes "ci/run-e2e-meta-backup-test"
+      || build.env("CI_STEPS") =~ /(^|,)e2e-tests?(,|$$)/
+    depends_on:
+      - "build"
+      - "build-other"
+      - "docslt"
+    plugins:
+      - docker-compose#v5.1.0:
+          run: rw-build-env
+          config: ci/docker-compose.yml
+          mount-buildkite-agent: true
+      - ./ci/plugins/upload-failure-logs
+    timeout_in_minutes: 45
     retry: *auto-retry
 
   - label: "end-to-end test (parallel) (release)"
diff --git a/docker/docker-compose-distributed.yml b/docker/docker-compose-distributed.yml
index c1ca626a824e6..9d0167ffc2438 100644
--- a/docker/docker-compose-distributed.yml
+++ b/docker/docker-compose-distributed.yml
@@ -1,7 +1,7 @@
 ---
 version: "3"
 x-image: &image
-  image: ${RW_IMAGE:-risingwavelabs/risingwave:v1.6.1}
+  image: ${RW_IMAGE:-risingwavelabs/risingwave:v1.7.1}
 services:
   compactor-0:
     <<: *image
diff --git a/docker/docker-compose-with-azblob.yml b/docker/docker-compose-with-azblob.yml
index e43d28a96ffe5..d93c8079706bb 100644
--- a/docker/docker-compose-with-azblob.yml
+++ b/docker/docker-compose-with-azblob.yml
@@ -1,7 +1,7 @@
 ---
 version: "3"
 x-image: &image
-  image: ${RW_IMAGE:-risingwavelabs/risingwave:v1.6.1}
+  image: ${RW_IMAGE:-risingwavelabs/risingwave:v1.7.1}
 services:
   risingwave-standalone:
     <<: *image
diff --git a/docker/docker-compose-with-gcs.yml b/docker/docker-compose-with-gcs.yml
index 5300c6418581d..34768d87223d1 100644
--- a/docker/docker-compose-with-gcs.yml
+++ b/docker/docker-compose-with-gcs.yml
@@ -1,7 +1,7 @@
 ---
 version: "3"
 x-image: &image
-  image: ${RW_IMAGE:-risingwavelabs/risingwave:v1.6.1}
+  image: ${RW_IMAGE:-risingwavelabs/risingwave:v1.7.1}
 services:
   risingwave-standalone:
     <<: *image
diff --git a/docker/docker-compose-with-hdfs.yml b/docker/docker-compose-with-hdfs.yml
index cf2b45078bac5..73a22eab4580a 100644
--- a/docker/docker-compose-with-hdfs.yml
+++ b/docker/docker-compose-with-hdfs.yml
@@ -42,7 +42,7 @@ services:
         reservations:
           memory: 1G
   compute-node-0:
-    image: "ghcr.io/risingwavelabs/risingwave:RisingWave_1.6.1_HDFS_2.7-x86_64"
+    image: "ghcr.io/risingwavelabs/risingwave:RisingWave_1.7.1_HDFS_2.7-x86_64"
     command:
       - compute-node
       - "--listen-addr"
@@ -132,7 +132,7 @@ services:
       retries: 5
     restart: always
   frontend-node-0:
-    image: "ghcr.io/risingwavelabs/risingwave:RisingWave_1.6.1_HDFS_2.7-x86_64"
+    image: "ghcr.io/risingwavelabs/risingwave:RisingWave_1.7.1_HDFS_2.7-x86_64"
     command:
      - frontend-node
      - "--listen-addr"
@@ -195,7 +195,7 @@ services:
       retries: 5
     restart: always
   meta-node-0:
-    image: "ghcr.io/risingwavelabs/risingwave:RisingWave_1.6.1_HDFS_2.7-x86_64"
+    image: "ghcr.io/risingwavelabs/risingwave:RisingWave_1.7.1_HDFS_2.7-x86_64"
     command:
      - meta-node
      - "--listen-addr"
diff --git a/docker/docker-compose-with-obs.yml b/docker/docker-compose-with-obs.yml
index 29d1c1a7452b9..f34460448b874 100644
--- a/docker/docker-compose-with-obs.yml
+++ b/docker/docker-compose-with-obs.yml
@@ -1,7 +1,7 @@
 ---
 version: "3"
 x-image: &image
-  image: ${RW_IMAGE:-risingwavelabs/risingwave:v1.6.1}
+  image: ${RW_IMAGE:-risingwavelabs/risingwave:v1.7.1}
 services:
   risingwave-standalone:
     <<: *image
diff --git a/docker/docker-compose-with-oss.yml b/docker/docker-compose-with-oss.yml
index b759d16a93d24..5b1531caad100 100644
--- a/docker/docker-compose-with-oss.yml
+++ b/docker/docker-compose-with-oss.yml
@@ -1,7 +1,7 @@
 ---
 version: "3"
 x-image: &image
-  image: ${RW_IMAGE:-risingwavelabs/risingwave:v1.6.1}
+  image: ${RW_IMAGE:-risingwavelabs/risingwave:v1.7.1}
 services:
   risingwave-standalone:
     <<: *image
diff --git a/docker/docker-compose-with-s3.yml b/docker/docker-compose-with-s3.yml
index a3070dd8048d2..34ba2a29d0d67 100644
--- a/docker/docker-compose-with-s3.yml
+++ b/docker/docker-compose-with-s3.yml
@@ -1,7 +1,7 @@
 ---
 version: "3"
 x-image: &image
-  image: ${RW_IMAGE:-risingwavelabs/risingwave:v1.6.1}
+  image: ${RW_IMAGE:-risingwavelabs/risingwave:v1.7.1}
 services:
   risingwave-standalone:
     <<: *image
diff --git a/docker/docker-compose.yml b/docker/docker-compose.yml
index a9d4cc0f58f7b..e748a09a8c792 100644
--- a/docker/docker-compose.yml
+++ b/docker/docker-compose.yml
@@ -1,7 +1,7 @@
 ---
 version: "3"
 x-image: &image
-  image: ${RW_IMAGE:-risingwavelabs/risingwave:v1.6.1}
+  image: ${RW_IMAGE:-risingwavelabs/risingwave:v1.7.1}
 services:
   risingwave-standalone:
     <<: *image
diff --git a/proto/stream_service.proto b/proto/stream_service.proto
index e8c5d94a20ac3..5990fe1e2cbcf 100644
--- a/proto/stream_service.proto
+++ b/proto/stream_service.proto
@@ -44,16 +44,6 @@ message DropActorsResponse {
   common.Status status = 2;
 }
 
-message ForceStopActorsRequest {
-  string request_id = 1;
-  uint64 prev_epoch = 2;
-}
-
-message ForceStopActorsResponse {
-  string request_id = 1;
-  common.Status status = 2;
-}
-
 message InjectBarrierRequest {
   string request_id = 1;
   stream_plan.Barrier barrier = 2;
@@ -61,16 +51,6 @@ message InjectBarrierRequest {
   repeated uint32 actor_ids_to_collect = 4;
 }
 
-message InjectBarrierResponse {
-  string request_id = 1;
-  common.Status status = 2;
-}
-
-message BarrierCompleteRequest {
-  string request_id = 1;
-  uint64 prev_epoch = 2;
-  map<string, string> tracing_context = 3;
-}
-
 message BarrierCompleteResponse {
   message CreateMviewProgress {
     uint32 backfill_actor_id = 1;
@@ -104,15 +84,33 @@ message WaitEpochCommitResponse {
   common.Status status = 1;
 }
 
+message StreamingControlStreamRequest {
+  message InitRequest {
+    uint64 prev_epoch = 2;
+  }
+
+  oneof request {
+    InitRequest init = 1;
+    InjectBarrierRequest inject_barrier = 2;
+  }
+}
+
+message StreamingControlStreamResponse {
+  message InitResponse {}
+
+  oneof response {
+    InitResponse init = 1;
+    BarrierCompleteResponse complete_barrier = 2;
+  }
+}
+
 service StreamService {
   rpc UpdateActors(UpdateActorsRequest) returns (UpdateActorsResponse);
   rpc BuildActors(BuildActorsRequest) returns (BuildActorsResponse);
   rpc BroadcastActorInfoTable(BroadcastActorInfoTableRequest) returns (BroadcastActorInfoTableResponse);
   rpc DropActors(DropActorsRequest) returns (DropActorsResponse);
-  rpc ForceStopActors(ForceStopActorsRequest) returns (ForceStopActorsResponse);
-  rpc InjectBarrier(InjectBarrierRequest) returns (InjectBarrierResponse);
-  rpc BarrierComplete(BarrierCompleteRequest) returns (BarrierCompleteResponse);
   rpc WaitEpochCommit(WaitEpochCommitRequest) returns (WaitEpochCommitResponse);
+  rpc StreamingControlStream(stream StreamingControlStreamRequest) returns (stream StreamingControlStreamResponse);
 }
 
 // TODO: Lifecycle management for actors.
diff --git a/src/batch/src/executor/row_seq_scan.rs b/src/batch/src/executor/row_seq_scan.rs
index bf2fb9613b7eb..4c1261363a72b 100644
--- a/src/batch/src/executor/row_seq_scan.rs
+++ b/src/batch/src/executor/row_seq_scan.rs
@@ -33,7 +33,6 @@ use risingwave_storage::store::PrefetchOptions;
 use risingwave_storage::table::batch_table::storage_table::StorageTable;
 use risingwave_storage::table::{collect_data_chunk, TableDistribution};
 use risingwave_storage::{dispatch_state_store, StateStore};
-use rw_futures_util::select_all;
 
 use crate::error::{BatchError, Result};
 use crate::executor::{
@@ -319,28 +318,28 @@ impl<S: StateStore> RowSeqScanExecutor<S> {
         }
 
         // Range Scan
-        let range_scans = select_all(range_scans.into_iter().map(|range_scan| {
-            let table = table.clone();
-            let histogram = histogram.clone();
-            Box::pin(Self::execute_range(
-                table,
-                range_scan,
+        // WARN: DO NOT use `select` to execute range scans concurrently
+        //       it can consume too much memory if there're too many ranges.
+        for range in range_scans {
+            let stream = Self::execute_range(
+                table.clone(),
+                range,
                 ordered,
                 epoch.clone(),
                 chunk_size,
                 limit,
-                histogram,
-            ))
-        }));
-        #[for_await]
-        for chunk in range_scans {
-            let chunk = chunk?;
-            returned += chunk.cardinality() as u64;
-            yield chunk;
-            if let Some(limit) = &limit
-                && returned >= *limit
-            {
-                return Ok(());
+                histogram.clone(),
+            );
+            #[for_await]
+            for chunk in stream {
+                let chunk = chunk?;
+                returned += chunk.cardinality() as u64;
+                yield chunk;
+                if let Some(limit) = &limit
+                    && returned >= *limit
+                {
+                    return Ok(());
+                }
             }
         }
     }
diff --git a/src/compute/src/rpc/service/stream_service.rs b/src/compute/src/rpc/service/stream_service.rs
index 6e96406743f29..18b77ff1804bc 100644
--- a/src/compute/src/rpc/service/stream_service.rs
+++ b/src/compute/src/rpc/service/stream_service.rs
@@ -13,18 +13,16 @@
 // limitations under the License.
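+
+// Barrier control messages are now carried over a single bidirectional
+// `StreamingControlStream` RPC instead of the former unary `ForceStopActors`,
+// `InjectBarrier` and `BarrierComplete` calls (see `proto/stream_service.proto`).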
 
 use await_tree::InstrumentAwait;
-use itertools::Itertools;
-use risingwave_hummock_sdk::table_stats::to_prost_table_stats_map;
-use risingwave_hummock_sdk::LocalSstableInfo;
-use risingwave_pb::stream_service::barrier_complete_response::GroupedSstableInfo;
+use futures::{Stream, StreamExt, TryStreamExt};
 use risingwave_pb::stream_service::stream_service_server::StreamService;
 use risingwave_pb::stream_service::*;
 use risingwave_storage::dispatch_state_store;
 use risingwave_stream::error::StreamError;
-use risingwave_stream::executor::Barrier;
-use risingwave_stream::task::{BarrierCompleteResult, LocalStreamManager, StreamEnvironment};
+use risingwave_stream::task::{LocalStreamManager, StreamEnvironment};
 use thiserror_ext::AsReport;
-use tonic::{Request, Response, Status};
+use tokio::sync::mpsc::unbounded_channel;
+use tokio_stream::wrappers::UnboundedReceiverStream;
+use tonic::{Request, Response, Status, Streaming};
 
 #[derive(Clone)]
 pub struct StreamServiceImpl {
@@ -40,6 +38,9 @@ impl StreamServiceImpl {
 
 #[async_trait::async_trait]
 impl StreamService for StreamServiceImpl {
+    type StreamingControlStreamStream =
+        impl Stream<Item = std::result::Result<StreamingControlStreamResponse, Status>>;
+
     #[cfg_attr(coverage, coverage(off))]
     async fn update_actors(
         &self,
@@ -110,86 +111,6 @@
         }))
     }
 
-    #[cfg_attr(coverage, coverage(off))]
-    async fn force_stop_actors(
-        &self,
-        request: Request<ForceStopActorsRequest>,
-    ) -> std::result::Result<Response<ForceStopActorsResponse>, Status> {
-        let req = request.into_inner();
-        self.mgr.reset(req.prev_epoch).await;
-        Ok(Response::new(ForceStopActorsResponse {
-            request_id: req.request_id,
-            status: None,
-        }))
-    }
-
-    #[cfg_attr(coverage, coverage(off))]
-    async fn inject_barrier(
-        &self,
-        request: Request<InjectBarrierRequest>,
-    ) -> Result<Response<InjectBarrierResponse>, Status> {
-        let req = request.into_inner();
-        let barrier =
-            Barrier::from_protobuf(req.get_barrier().unwrap()).map_err(StreamError::from)?;
-
-        self.mgr
-            .send_barrier(barrier, req.actor_ids_to_send, req.actor_ids_to_collect)
-            .await?;
-
-        Ok(Response::new(InjectBarrierResponse {
-            request_id: req.request_id,
-            status: None,
-        }))
-    }
-
-    #[cfg_attr(coverage, coverage(off))]
-    async fn barrier_complete(
-        &self,
-        request: Request<BarrierCompleteRequest>,
-    ) -> Result<Response<BarrierCompleteResponse>, Status> {
-        let req = request.into_inner();
-        let BarrierCompleteResult {
-            create_mview_progress,
-            sync_result,
-        } = self
-            .mgr
-            .collect_barrier(req.prev_epoch)
-            .instrument_await(format!("collect_barrier (epoch {})", req.prev_epoch))
-            .await
-            .inspect_err(
-                |err| tracing::error!(error = %err.as_report(), "failed to collect barrier"),
-            )?;
-
-        let (synced_sstables, table_watermarks) = sync_result
-            .map(|sync_result| (sync_result.uncommitted_ssts, sync_result.table_watermarks))
-            .unwrap_or_default();
-
-        Ok(Response::new(BarrierCompleteResponse {
-            request_id: req.request_id,
-            status: None,
-            create_mview_progress,
-            synced_sstables: synced_sstables
-                .into_iter()
-                .map(
-                    |LocalSstableInfo {
-                         compaction_group_id,
-                         sst_info,
-                         table_stats,
-                     }| GroupedSstableInfo {
-                        compaction_group_id,
-                        sst: Some(sst_info),
-                        table_stats_map: to_prost_table_stats_map(table_stats),
-                    },
-                )
-                .collect_vec(),
-            worker_id: self.env.worker_id(),
-            table_watermarks: table_watermarks
-                .into_iter()
-                .map(|(key, value)| (key.table_id, value.to_protobuf()))
-                .collect(),
-        }))
-    }
-
     #[cfg_attr(coverage, coverage(off))]
     async fn wait_epoch_commit(
         &self,
@@ -210,4 +131,24 @@
 
         Ok(Response::new(WaitEpochCommitResponse { status: None }))
     }
+
+    async fn streaming_control_stream(
+        &self,
+        request: Request<Streaming<StreamingControlStreamRequest>>,
+    ) -> Result<Response<Self::StreamingControlStreamStream>, Status> {
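+        // The client must open the stream with an `Init` request carrying
+        // `prev_epoch`; every later message is expected to be an `InjectBarrier`.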
+        let mut stream = request.into_inner().boxed();
+        let first_request = stream.try_next().await?;
+        let Some(StreamingControlStreamRequest {
+            request: Some(streaming_control_stream_request::Request::Init(init_request)),
+        }) = first_request
+        else {
+            return Err(Status::invalid_argument(format!(
+                "unexpected first request: {:?}",
+                first_request
+            )));
+        };
+        let (tx, rx) = unbounded_channel();
+        self.mgr.handle_new_control_stream(tx, stream, init_request);
+        Ok(Response::new(UnboundedReceiverStream::new(rx)))
+    }
 }
diff --git a/src/connector/src/lib.rs b/src/connector/src/lib.rs
index bbbf8fa5c3b69..3ebf2475f7866 100644
--- a/src/connector/src/lib.rs
+++ b/src/connector/src/lib.rs
@@ -57,6 +57,7 @@ pub mod mqtt_common;
 pub use paste::paste;
 
 mod with_options;
+pub use with_options::WithPropertiesExt;
 
 #[cfg(test)]
 mod with_options_test;
diff --git a/src/connector/src/source/base.rs b/src/connector/src/source/base.rs
index c0ef0b65f8982..52724b1707660 100644
--- a/src/connector/src/source/base.rs
+++ b/src/connector/src/source/base.rs
@@ -12,7 +12,7 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-use std::collections::{BTreeMap, HashMap};
+use std::collections::HashMap;
 use std::sync::Arc;
 
 use anyhow::anyhow;
@@ -331,17 +331,6 @@ impl Default for ConnectorProperties {
 }
 
 impl ConnectorProperties {
-    pub fn is_new_fs_connector_b_tree_map(with_properties: &BTreeMap<String, String>) -> bool {
-        with_properties
-            .get(UPSTREAM_SOURCE_KEY)
-            .map(|s| {
-                s.eq_ignore_ascii_case(OPENDAL_S3_CONNECTOR)
-                    || s.eq_ignore_ascii_case(POSIX_FS_CONNECTOR)
-                    || s.eq_ignore_ascii_case(GCS_CONNECTOR)
-            })
-            .unwrap_or(false)
-    }
-
     pub fn is_new_fs_connector_hash_map(with_properties: &HashMap<String, String>) -> bool {
         with_properties
             .get(UPSTREAM_SOURCE_KEY)
diff --git a/src/connector/src/source/kafka/private_link.rs b/src/connector/src/source/kafka/private_link.rs
index 3eebacca09f93..d2e3d6877d169 100644
--- a/src/connector/src/source/kafka/private_link.rs
+++ b/src/connector/src/source/kafka/private_link.rs
@@ -33,7 +33,6 @@ use crate::common::{
 use crate::error::ConnectorResult;
 use crate::source::kafka::stats::RdKafkaStats;
 use crate::source::kafka::{KAFKA_PROPS_BROKER_KEY, KAFKA_PROPS_BROKER_KEY_ALIAS};
-use crate::source::KAFKA_CONNECTOR;
 
 pub const PRIVATELINK_ENDPOINT_KEY: &str = "privatelink.endpoint";
 pub const CONNECTION_NAME_KEY: &str = "connection.name";
@@ -205,16 +204,6 @@ fn get_property_required(
         .map_err(Into::into)
 }
 
-#[inline(always)]
-fn is_kafka_connector(with_properties: &BTreeMap<String, String>) -> bool {
-    const UPSTREAM_SOURCE_KEY: &str = "connector";
-    with_properties
-        .get(UPSTREAM_SOURCE_KEY)
-        .unwrap_or(&"".to_string())
-        .to_lowercase()
-        .eq_ignore_ascii_case(KAFKA_CONNECTOR)
-}
-
 pub fn insert_privatelink_broker_rewrite_map(
     with_options: &mut BTreeMap<String, String>,
     svc: Option<&PrivateLinkService>,
diff --git a/src/connector/src/with_options.rs b/src/connector/src/with_options.rs
index a5c810834727a..941eaadd459e9 100644
--- a/src/connector/src/with_options.rs
+++ b/src/connector/src/with_options.rs
@@ -12,7 +12,12 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-use std::collections::HashMap;
+use std::collections::{BTreeMap, HashMap};
+
+use crate::source::iceberg::ICEBERG_CONNECTOR;
+use crate::source::{
+    GCS_CONNECTOR, KAFKA_CONNECTOR, OPENDAL_S3_CONNECTOR, POSIX_FS_CONNECTOR, UPSTREAM_SOURCE_KEY,
+};
 
 /// Marker trait for `WITH` options. Only for `#[derive(WithOptions)]`, should not be used manually.
///
@@ -56,3 +61,68 @@ impl WithOptions for crate::mqtt_common::QualityOfService {}
 impl WithOptions for crate::sink::kafka::CompressionCodec {}
 impl WithOptions for nexmark::config::RateShape {}
 impl WithOptions for nexmark::event::EventType {}
+
+pub trait Get {
+    fn get(&self, key: &str) -> Option<&String>;
+}
+
+impl Get for HashMap<String, String> {
+    fn get(&self, key: &str) -> Option<&String> {
+        self.get(key)
+    }
+}
+
+impl Get for BTreeMap<String, String> {
+    fn get(&self, key: &str) -> Option<&String> {
+        self.get(key)
+    }
+}
+
+/// Utility methods for `WITH` properties (`HashMap` and `BTreeMap`).
+pub trait WithPropertiesExt: Get {
+    #[inline(always)]
+    fn get_connector(&self) -> Option<String> {
+        self.get(UPSTREAM_SOURCE_KEY).map(|s| s.to_lowercase())
+    }
+
+    #[inline(always)]
+    fn is_kafka_connector(&self) -> bool {
+        let Some(connector) = self.get_connector() else {
+            return false;
+        };
+        connector == KAFKA_CONNECTOR
+    }
+
+    #[inline(always)]
+    fn is_cdc_connector(&self) -> bool {
+        let Some(connector) = self.get_connector() else {
+            return false;
+        };
+        connector.contains("-cdc")
+    }
+
+    #[inline(always)]
+    fn is_iceberg_connector(&self) -> bool {
+        let Some(connector) = self.get_connector() else {
+            return false;
+        };
+        connector == ICEBERG_CONNECTOR
+    }
+
+    fn connector_need_pk(&self) -> bool {
+        // Currently only iceberg connector doesn't need primary key
+        !self.is_iceberg_connector()
+    }
+
+    fn is_new_fs_connector(&self) -> bool {
+        self.get(UPSTREAM_SOURCE_KEY)
+            .map(|s| {
+                s.eq_ignore_ascii_case(OPENDAL_S3_CONNECTOR)
+                    || s.eq_ignore_ascii_case(POSIX_FS_CONNECTOR)
+                    || s.eq_ignore_ascii_case(GCS_CONNECTOR)
+            })
+            .unwrap_or(false)
+    }
+}
+
+impl<T: Get> WithPropertiesExt for T {}
diff --git a/src/ctl/src/cmd_impl/hummock/list_kv.rs b/src/ctl/src/cmd_impl/hummock/list_kv.rs
index 676c0b013163e..2eb54362b413c 100644
--- a/src/ctl/src/cmd_impl/hummock/list_kv.rs
+++ b/src/ctl/src/cmd_impl/hummock/list_kv.rs
@@ -14,11 +14,10 @@
 
 use core::ops::Bound::Unbounded;
 
-use futures::StreamExt;
 use risingwave_common::catalog::TableId;
 use risingwave_common::util::epoch::is_max_epoch;
 use risingwave_storage::hummock::CachePolicy;
-use risingwave_storage::store::{PrefetchOptions, ReadOptions, StateStoreRead};
+use risingwave_storage::store::{PrefetchOptions, ReadOptions, StateStoreIter, StateStoreRead};
 
 use crate::common::HummockServiceOpts;
 use crate::CtlContext;
@@ -36,22 +35,20 @@ pub async fn list_kv(
         tracing::info!("using MAX EPOCH as epoch");
     }
     let range = (Unbounded, Unbounded);
-    let mut scan_result = Box::pin(
-        hummock
-            .iter(
-                range,
-                epoch,
-                ReadOptions {
-                    table_id: TableId { table_id },
-                    prefetch_options: PrefetchOptions::prefetch_for_large_range_scan(),
-                    cache_policy: CachePolicy::NotFill,
-                    ..Default::default()
-                },
-            )
-            .await?,
-    );
-    while let Some(item) = scan_result.next().await {
-        let (k, v) = item?;
+    let mut scan_result = hummock
+        .iter(
+            range,
+            epoch,
+            ReadOptions {
+                table_id: TableId { table_id },
+                prefetch_options: PrefetchOptions::prefetch_for_large_range_scan(),
+                cache_policy: CachePolicy::NotFill,
+                ..Default::default()
+            },
+        )
+        .await?;
+    while let Some(item) = scan_result.try_next().await? {
+        let (k, v) = item;
         let print_string = format!("[t{}]", k.user_key.table_id.table_id());
         println!("{} {:?} => {:?}", print_string, k, v)
     }
diff --git a/src/frontend/src/binder/expr/function.rs b/src/frontend/src/binder/expr/function.rs
index d7631f11bf626..cc7d5df602e5d 100644
--- a/src/frontend/src/binder/expr/function.rs
+++ b/src/frontend/src/binder/expr/function.rs
@@ -1300,6 +1300,7 @@ impl Binder {
                 ("pg_get_partkeydef", raw_literal(ExprImpl::literal_null(DataType::Varchar))),
                 ("pg_encoding_to_char", raw_literal(ExprImpl::literal_varchar("UTF8".into()))),
                 ("has_database_privilege", raw_literal(ExprImpl::literal_bool(true))),
+                ("pg_stat_get_numscans", raw_literal(ExprImpl::literal_bigint(0))),
                 ("pg_backend_pid", raw(|binder, _inputs| {
                     // FIXME: the session id is not global unique in multi-frontend env.
                     Ok(ExprImpl::literal_int(binder.session_id.0))
diff --git a/src/frontend/src/catalog/connection_catalog.rs b/src/frontend/src/catalog/connection_catalog.rs
index 58595dfbdfd62..54e1210979fe8 100644
--- a/src/frontend/src/catalog/connection_catalog.rs
+++ b/src/frontend/src/catalog/connection_catalog.rs
@@ -17,7 +17,7 @@ use std::sync::Arc;
 
 use anyhow::anyhow;
 use risingwave_connector::source::kafka::private_link::insert_privatelink_broker_rewrite_map;
-use risingwave_connector::source::KAFKA_CONNECTOR;
+use risingwave_connector::WithPropertiesExt;
 use risingwave_pb::catalog::connection::private_link_service::PrivateLinkProvider;
 use risingwave_pb::catalog::connection::Info;
 use risingwave_pb::catalog::{connection, PbConnection};
@@ -65,23 +65,13 @@ impl OwnedByUserCatalog for ConnectionCatalog {
     }
 }
 
-#[inline(always)]
-fn is_kafka_connector(with_properties: &BTreeMap<String, String>) -> bool {
-    const UPSTREAM_SOURCE_KEY: &str = "connector";
-    with_properties
-        .get(UPSTREAM_SOURCE_KEY)
-        .unwrap_or(&"".to_string())
-        .to_lowercase()
-        .eq_ignore_ascii_case(KAFKA_CONNECTOR)
-}
-
 pub(crate) fn resolve_private_link_connection(
     connection: &Arc<ConnectionCatalog>,
     properties: &mut BTreeMap<String, String>,
 ) -> Result<()> {
     #[allow(irrefutable_let_patterns)]
     if let connection::Info::PrivateLinkService(svc) = &connection.info {
-        if !is_kafka_connector(properties) {
+        if !properties.is_kafka_connector() {
             return Err(RwError::from(anyhow!(
                 "Private link is only supported for Kafka connector"
             )));
diff --git a/src/frontend/src/handler/alter_source_with_sr.rs b/src/frontend/src/handler/alter_source_with_sr.rs
index fc35552270a2e..1718432b70aa5 100644
--- a/src/frontend/src/handler/alter_source_with_sr.rs
+++ b/src/frontend/src/handler/alter_source_with_sr.rs
@@ -18,6 +18,7 @@ use itertools::Itertools;
 use pgwire::pg_response::StatementType;
 use risingwave_common::bail_not_implemented;
 use risingwave_common::catalog::ColumnCatalog;
+use risingwave_connector::WithPropertiesExt;
 use risingwave_pb::catalog::StreamSourceInfo;
 use risingwave_pb::plan_common::{EncodeType, FormatType};
 use risingwave_sqlparser::ast::{
@@ -28,7 +29,6 @@ use risingwave_sqlparser::parser::Parser;
 
 use super::alter_table_column::schema_has_schema_registry;
 use super::create_source::{bind_columns_from_source, validate_compatibility};
-use super::util::is_cdc_connector;
 use super::{HandlerArgs, RwPgResponse};
 use crate::catalog::root_catalog::SchemaPath;
 use crate::catalog::source_catalog::SourceCatalog;
@@ -152,7 +152,7 @@ pub async fn refresh_sr_and_get_columns_diff(
         .collect();
     validate_compatibility(connector_schema, &mut with_properties)?;
 
-    if is_cdc_connector(&with_properties) {
+    if with_properties.is_cdc_connector() {
         bail_not_implemented!("altering a cdc source is not supported");
     }
diff --git a/src/frontend/src/handler/create_source.rs b/src/frontend/src/handler/create_source.rs
index f0e6f075b261e..3585d58290c54 100644
--- a/src/frontend/src/handler/create_source.rs
+++ b/src/frontend/src/handler/create_source.rs
@@ -52,6 +52,7 @@ use risingwave_connector::source::{
     KINESIS_CONNECTOR, MQTT_CONNECTOR, NATS_CONNECTOR, NEXMARK_CONNECTOR, OPENDAL_S3_CONNECTOR,
     POSIX_FS_CONNECTOR, PULSAR_CONNECTOR, S3_CONNECTOR,
 };
+use risingwave_connector::WithPropertiesExt;
 use risingwave_pb::catalog::{
     PbSchemaRegistryNameStrategy, PbSource, StreamSourceInfo, WatermarkDesc,
 };
@@ -75,10 +76,7 @@ use crate::handler::create_table::{
     bind_pk_on_relation, bind_sql_column_constraints, bind_sql_columns, bind_sql_pk_names,
     ensure_table_constraints_supported, ColumnIdGenerator,
 };
-use crate::handler::util::{
-    connector_need_pk, get_connector, is_cdc_connector, is_iceberg_connector, is_kafka_connector,
-    SourceSchemaCompatExt,
-};
+use crate::handler::util::SourceSchemaCompatExt;
 use crate::handler::HandlerArgs;
 use crate::optimizer::plan_node::generic::SourceNodeKind;
 use crate::optimizer::plan_node::{LogicalSource, ToStream, ToStreamContext};
@@ -298,7 +296,7 @@ pub(crate) async fn bind_columns_from_source(
     const KEY_MESSAGE_NAME_KEY: &str = "key.message";
     const NAME_STRATEGY_KEY: &str = "schema.registry.name.strategy";
 
-    let is_kafka: bool = is_kafka_connector(with_properties);
+    let is_kafka: bool = with_properties.is_kafka_connector();
     let format_encode_options = WithOptions::try_from(source_schema.row_options())?.into_inner();
     let mut format_encode_options_to_consume = format_encode_options.clone();
@@ -447,7 +445,7 @@
             .await?
         }
         (Format::None, Encode::None) => {
-            if is_iceberg_connector(with_properties) {
+            if with_properties.is_iceberg_connector() {
                 Some(
                     extract_iceberg_columns(with_properties)
                         .await
@@ -533,7 +531,7 @@ pub fn handle_addition_columns(
     mut additional_columns: IncludeOption,
     columns: &mut Vec<ColumnCatalog>,
 ) -> Result<()> {
-    let connector_name = get_connector(with_properties).unwrap(); // there must be a connector in source
+    let connector_name = with_properties.get_connector().unwrap(); // there must be a connector in source
 
     if COMPATIBLE_ADDITIONAL_COLUMNS
         .get(connector_name.as_str())
@@ -878,7 +876,7 @@ fn check_and_add_timestamp_column(
     with_properties: &HashMap<String, String>,
     columns: &mut Vec<ColumnCatalog>,
 ) {
-    if is_kafka_connector(with_properties) {
+    if with_properties.is_kafka_connector() {
         if columns.iter().any(|col| {
             matches!(
                 col.column_desc.additional_column.column_type,
@@ -1027,7 +1025,8 @@ pub fn validate_compatibility(
     source_schema: &ConnectorSchema,
     props: &mut HashMap<String, String>,
 ) -> Result<()> {
-    let connector = get_connector(props)
+    let connector = props
+        .get_connector()
         .ok_or_else(|| RwError::from(ProtocolError("missing field 'connector'".to_string())))?;
 
     let compatible_formats = CONNECTORS_COMPATIBLE_FORMATS
@@ -1105,7 +1104,7 @@ pub(super) async fn check_source_schema(
     row_id_index: Option<usize>,
     columns: &[ColumnCatalog],
 ) -> Result<()> {
-    let Some(connector) = get_connector(props) else {
+    let Some(connector) = props.get_connector() else {
         return Ok(());
     };
@@ -1309,7 +1308,7 @@ pub async fn handle_create_source(
     let sql_pk_names = bind_sql_pk_names(&stmt.columns, &stmt.constraints)?;
 
     // gated the feature with a session variable
-    let create_cdc_source_job = if is_cdc_connector(&with_properties) {
+    let create_cdc_source_job = if with_properties.is_cdc_connector() {
         CdcTableType::from_properties(&with_properties).can_backfill()
     } else {
         false
@@ -1365,7 +1364,7 @@
             .into());
         }
         let (mut columns, pk_column_ids, row_id_index) =
-            bind_pk_on_relation(columns, pk_names, connector_need_pk(&with_properties))?;
+            bind_pk_on_relation(columns, pk_names, with_properties.connector_need_pk())?;
 
         debug_assert!(is_column_ids_dedup(&columns));
diff --git a/src/frontend/src/handler/create_table.rs b/src/frontend/src/handler/create_table.rs
index 493cfe967f3d3..3eb6b5279293d 100644
--- a/src/frontend/src/handler/create_table.rs
+++ b/src/frontend/src/handler/create_table.rs
@@ -29,11 +29,11 @@ use risingwave_common::catalog::{
 };
 use risingwave_common::util::sort_util::{ColumnOrder, OrderType};
 use risingwave_common::util::value_encoding::DatumToProtoExt;
-use risingwave_connector::source;
 use risingwave_connector::source::cdc::external::{
     DATABASE_NAME_KEY, SCHEMA_NAME_KEY, TABLE_NAME_KEY,
 };
 use risingwave_connector::source::cdc::CDC_BACKFILL_ENABLE_KEY;
+use risingwave_connector::{source, WithPropertiesExt};
 use risingwave_pb::catalog::source::OptionalAssociatedTableId;
 use risingwave_pb::catalog::{PbSource, PbTable, StreamSourceInfo, Table, WatermarkDesc};
 use risingwave_pb::ddl_service::TableJobType;
@@ -61,7 +61,6 @@ use crate::handler::create_source::{
     bind_all_columns, bind_columns_from_source, bind_source_pk, bind_source_watermark,
     check_source_schema, handle_addition_columns, validate_compatibility, UPSTREAM_SOURCE_KEY,
 };
-use crate::handler::util::is_iceberg_connector;
 use crate::handler::HandlerArgs;
 use crate::optimizer::plan_node::generic::SourceNodeKind;
 use crate::optimizer::plan_node::{LogicalCdcScan, LogicalSource};
@@ -514,7 +513,7 @@ pub(crate) async fn gen_create_table_plan_with_source(
         c.column_desc.column_id = col_id_gen.generate(c.name())
     }
 
-    if is_iceberg_connector(&with_properties) {
+    if with_properties.is_iceberg_connector() {
         return Err(
             ErrorCode::BindError("can't create table with iceberg connector".to_string()).into(),
         );
diff --git a/src/frontend/src/handler/util.rs b/src/frontend/src/handler/util.rs
index 6c9a9bb45f2ac..d3ccb55e6a6ab 100644
--- a/src/frontend/src/handler/util.rs
+++ b/src/frontend/src/handler/util.rs
@@ -12,7 +12,6 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-use std::collections::HashMap;
 use std::pin::Pin;
 use std::sync::Arc;
 use std::task::{Context, Poll};
@@ -31,12 +30,9 @@ use risingwave_common::catalog::Field;
 use risingwave_common::row::Row as _;
 use risingwave_common::types::{write_date_time_tz, DataType, ScalarRefImpl, Timestamptz};
 use risingwave_common::util::iter_util::ZipEqFast;
-use risingwave_connector::source::iceberg::ICEBERG_CONNECTOR;
-use risingwave_connector::source::KAFKA_CONNECTOR;
 use risingwave_sqlparser::ast::{CompatibleSourceSchema, ConnectorSchema};
 
 use crate::error::{ErrorCode, Result as RwResult};
-use crate::handler::create_source::UPSTREAM_SOURCE_KEY;
 use crate::session::{current, SessionImpl};
 
 pin_project! {
@@ -180,43 +176,6 @@ pub fn to_pg_field(f: &Field) -> PgFieldDescriptor {
     )
 }
 
-pub fn connector_need_pk(with_properties: &HashMap<String, String>) -> bool {
-    // Currently only iceberg connector doesn't need primary key
-    !is_iceberg_connector(with_properties)
-}
-
-#[inline(always)]
-pub fn get_connector(with_properties: &HashMap<String, String>) -> Option<String> {
-    with_properties
-        .get(UPSTREAM_SOURCE_KEY)
-        .map(|s| s.to_lowercase())
-}
-
-#[inline(always)]
-pub fn is_kafka_connector(with_properties: &HashMap<String, String>) -> bool {
-    let Some(connector) = get_connector(with_properties) else {
-        return false;
-    };
-
-    connector == KAFKA_CONNECTOR
-}
-
-#[inline(always)]
-pub fn is_cdc_connector(with_properties: &HashMap<String, String>) -> bool {
-    let Some(connector) = get_connector(with_properties) else {
-        return false;
-    };
-    connector.contains("-cdc")
-}
-
-#[inline(always)]
-pub fn is_iceberg_connector(with_properties: &HashMap<String, String>) -> bool {
-    let Some(connector) = get_connector(with_properties) else {
-        return false;
-    };
-    connector == ICEBERG_CONNECTOR
-}
-
 #[easy_ext::ext(SourceSchemaCompatExt)]
 impl CompatibleSourceSchema {
     /// Convert `self` to [`ConnectorSchema`] and warn the user if the syntax is deprecated.
diff --git a/src/frontend/src/optimizer/plan_node/generic/source.rs b/src/frontend/src/optimizer/plan_node/generic/source.rs
index 04e63a246a7ae..406a3654def24 100644
--- a/src/frontend/src/optimizer/plan_node/generic/source.rs
+++ b/src/frontend/src/optimizer/plan_node/generic/source.rs
@@ -20,7 +20,7 @@ use educe::Educe;
 use risingwave_common::catalog::{ColumnCatalog, Field, Schema};
 use risingwave_common::types::DataType;
 use risingwave_common::util::sort_util::OrderType;
-use risingwave_connector::source::ConnectorProperties;
+use risingwave_connector::WithPropertiesExt;
 
 use super::super::utils::TableCatalogBuilder;
 use super::GenericPlanNode;
@@ -99,9 +99,9 @@ impl GenericPlanNode for Source {
 
 impl Source {
     pub fn is_new_fs_connector(&self) -> bool {
-        self.catalog.as_ref().is_some_and(|catalog| {
-            ConnectorProperties::is_new_fs_connector_b_tree_map(&catalog.with_properties)
-        })
+        self.catalog
+            .as_ref()
+            .is_some_and(|catalog| catalog.with_properties.is_new_fs_connector())
     }
 
     /// The columns in stream/batch source node indicate the actual columns it will produce,
diff --git a/src/frontend/src/utils/with_options.rs b/src/frontend/src/utils/with_options.rs
index 633bcf29354f5..574cab7d64979 100644
--- a/src/frontend/src/utils/with_options.rs
+++ b/src/frontend/src/utils/with_options.rs
@@ -19,7 +19,7 @@ use std::num::NonZeroU32;
 use risingwave_connector::source::kafka::{
     insert_privatelink_broker_rewrite_map, CONNECTION_NAME_KEY, PRIVATELINK_ENDPOINT_KEY,
 };
-use risingwave_connector::source::KAFKA_CONNECTOR;
+use risingwave_connector::WithPropertiesExt;
 use risingwave_sqlparser::ast::{
     CreateConnectionStatement, CreateSinkStatement, CreateSourceStatement, SqlOption, Statement,
     Value,
@@ -28,7 +28,6 @@ use risingwave_sqlparser::ast::{
 use crate::catalog::connection_catalog::resolve_private_link_connection;
 use crate::catalog::ConnectionId;
 use crate::error::{ErrorCode, Result as RwResult, RwError};
-use crate::handler::create_source::UPSTREAM_SOURCE_KEY;
 use crate::session::SessionImpl;
 
 mod options {
@@ -113,24 +112,12 @@ impl WithOptions {
     }
 }
 
-#[inline(always)]
-fn is_kafka_connector(with_options: &WithOptions) -> bool {
-    let Some(connector) = with_options
-        .inner()
-        .get(UPSTREAM_SOURCE_KEY)
-        .map(|s| s.to_lowercase())
-    else {
-        return false;
-    };
-    connector == KAFKA_CONNECTOR
-}
-
 pub(crate) fn resolve_privatelink_in_with_option(
     with_options: &mut WithOptions,
     schema_name: &Option<ObjectName>,
     session: &SessionImpl,
 ) -> RwResult<Option<ConnectionId>> {
-    let is_kafka = is_kafka_connector(with_options);
+    let is_kafka = with_options.is_kafka_connector();
     let privatelink_endpoint = with_options.remove(PRIVATELINK_ENDPOINT_KEY);
 
     // if `privatelink.endpoint` is provided in WITH, use it to rewrite broker address directly
diff --git a/src/jni_core/src/hummock_iterator.rs b/src/jni_core/src/hummock_iterator.rs
index ee2084b6ecf81..69845ff0f459e 100644
--- a/src/jni_core/src/hummock_iterator.rs
+++ b/src/jni_core/src/hummock_iterator.rs
@@ -15,7 +15,7 @@
 use std::sync::Arc;
 
 use bytes::Bytes;
-use futures::TryStreamExt;
+use futures::{Stream, TryStreamExt};
 use risingwave_common::catalog::ColumnDesc;
 use risingwave_common::config::{MetricLevel, ObjectStoreConfig};
 use risingwave_common::hash::VirtualNode;
@@ -37,20 +37,31 @@ use risingwave_storage::hummock::{
 };
 use risingwave_storage::monitor::{global_hummock_state_store_metrics, HummockStateStoreMetrics};
 use risingwave_storage::row_serde::value_serde::ValueRowSerdeNew;
-use risingwave_storage::store::{ReadOptions, StateStoreReadIterStream, StreamTypeOfIter};
+use risingwave_storage::store::{ReadOptions, StateStoreIterExt};
+use risingwave_storage::table::KeyedRow;
 use rw_futures_util::select_all;
 use tokio::sync::mpsc::unbounded_channel;
 
-type SelectAllIterStream = impl StateStoreReadIterStream + Unpin;
+type SelectAllIterStream = impl Stream<Item = StorageResult<KeyedRow<Bytes>>> + Unpin;
+type SingleIterStream = impl Stream<Item = StorageResult<KeyedRow<Bytes>>>;
 
-fn select_all_vnode_stream(
-    streams: Vec<StreamTypeOfIter<HummockStorageIterator>>,
-) -> SelectAllIterStream {
+fn select_all_vnode_stream(streams: Vec<SingleIterStream>) -> SelectAllIterStream {
     select_all(streams.into_iter().map(Box::pin))
 }
 
+fn to_deserialized_stream(
+    iter: HummockStorageIterator,
+    row_serde: EitherSerde,
+) -> SingleIterStream {
+    iter.into_stream(move |(key, value)| {
+        Ok(KeyedRow::new(
+            key.user_key.table_key.copy_into(),
+            row_serde.deserialize(value).map(OwnedRow::new)?,
+        ))
+    })
+}
+
 pub struct HummockJavaBindingIterator {
-    row_serde: EitherSerde,
     stream: SelectAllIterStream,
 }
@@ -87,6 +98,28 @@ impl HummockJavaBindingIterator {
             0,
         );
 
+        let table = read_plan.table_catalog.unwrap();
+        let versioned = table.version.is_some();
+        let table_columns = table
+            .columns
+            .into_iter()
+            .map(|c| ColumnDesc::from(c.column_desc.unwrap()));
+
+        // Decide which serializer to use based on whether the table is versioned or not.
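+        // Versioned tables (i.e. tables that have undergone schema changes) need
+        // the column-aware serializer, which can decode rows written under older
+        // schema versions; unversioned tables can use the cheaper basic serde.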
+        let row_serde: EitherSerde = if versioned {
+            ColumnAwareSerde::new(
+                Arc::from_iter(0..table_columns.len()),
+                Arc::from_iter(table_columns),
+            )
+            .into()
+        } else {
+            BasicSerde::new(
+                Arc::from_iter(0..table_columns.len()),
+                Arc::from_iter(table_columns),
+            )
+            .into()
+        };
+
         let mut streams = Vec::with_capacity(read_plan.vnode_ids.len());
         let key_range = read_plan.key_range.unwrap();
         let pin_version = PinnedVersion::new(
@@ -104,7 +137,7 @@
                 key_range,
                 read_plan.epoch,
             );
-            let stream = reader
+            let iter = reader
                 .iter(
                     key_range,
                     read_plan.epoch,
@@ -116,45 +149,16 @@
                     read_version_tuple,
                 )
                 .await?;
-            streams.push(stream);
+            streams.push(to_deserialized_stream(iter, row_serde.clone()));
         }
 
         let stream = select_all_vnode_stream(streams);
 
-        let table = read_plan.table_catalog.unwrap();
-        let versioned = table.version.is_some();
-        let table_columns = table
-            .columns
-            .into_iter()
-            .map(|c| ColumnDesc::from(c.column_desc.unwrap()));
-
-        // Decide which serializer to use based on whether the table is versioned or not.
-        let row_serde = if versioned {
-            ColumnAwareSerde::new(
-                Arc::from_iter(0..table_columns.len()),
-                Arc::from_iter(table_columns),
-            )
-            .into()
-        } else {
-            BasicSerde::new(
-                Arc::from_iter(0..table_columns.len()),
-                Arc::from_iter(table_columns),
-            )
-            .into()
-        };
-
-        Ok(Self { row_serde, stream })
+        Ok(Self { stream })
     }
 
-    pub async fn next(&mut self) -> StorageResult<Option<(Bytes, OwnedRow)>> {
-        let item = self.stream.try_next().await?;
-        Ok(match item {
-            Some((key, value)) => Some((
-                key.user_key.table_key.0,
-                OwnedRow::new(self.row_serde.deserialize(&value)?),
-            )),
-            None => None,
-        })
+    pub async fn next(&mut self) -> StorageResult<Option<KeyedRow<Bytes>>> {
+        self.stream.try_next().await
     }
 }
diff --git a/src/jni_core/src/lib.rs b/src/jni_core/src/lib.rs
index 3b877261dfe7b..ac5192700fae5 100644
--- a/src/jni_core/src/lib.rs
+++ b/src/jni_core/src/lib.rs
@@ -360,10 +360,11 @@ extern "system" fn Java_com_risingwave_java_binding_Binding_iteratorNext<'a>(
                 iter.cursor = None;
                 Ok(JNI_FALSE)
             }
-            Some((key, row)) => {
+            Some(keyed_row) => {
+                let (key, row) = keyed_row.into_parts();
                 iter.cursor = Some(RowCursor {
                     row,
-                    extra: RowExtra::Key(key),
+                    extra: RowExtra::Key(key.0),
                 });
                 Ok(JNI_TRUE)
             }
diff --git a/src/meta/src/barrier/command.rs b/src/meta/src/barrier/command.rs
index 006de45d522e7..22311a2b43911 100644
--- a/src/meta/src/barrier/command.rs
+++ b/src/meta/src/barrier/command.rs
@@ -175,6 +175,8 @@ pub enum Command {
     RescheduleFragment {
         reschedules: HashMap<FragmentId, Reschedule>,
         table_parallelism: HashMap<TableId, TableParallelism>,
+        // Should contain the actor ids of the upstream and downstream fragments of `reschedules`.
+        fragment_actors: HashMap<FragmentId, HashSet<ActorId>>,
     },
 
     /// `ReplaceTable` command generates a `Update` barrier with the given `merge_updates`. This is
@@ -351,7 +353,7 @@ impl CommandContext {
 
 impl CommandContext {
     /// Generate a mutation for the given command.
-    pub async fn to_mutation(&self) -> MetaResult<Option<Mutation>> {
+    pub fn to_mutation(&self) -> Option<Mutation> {
         let mutation = match &self.command {
             Command::Plain(mutation) => mutation.clone(),
 
@@ -479,21 +481,23 @@ impl CommandContext {
                 init_split_assignment,
             ),
 
-            Command::RescheduleFragment { reschedules, .. } => {
-                let metadata_manager = &self.barrier_manager_context.metadata_manager;
-
+            Command::RescheduleFragment {
+                reschedules,
+                fragment_actors,
+                ..
+            } => {
                 let mut dispatcher_update = HashMap::new();
                 for reschedule in reschedules.values() {
                     for &(upstream_fragment_id, dispatcher_id) in
                         &reschedule.upstream_fragment_dispatcher_ids
                     {
                         // Find the actors of the upstream fragment.
-                        let upstream_actor_ids = metadata_manager
-                            .get_running_actors_of_fragment(upstream_fragment_id)
-                            .await?;
+                        let upstream_actor_ids = fragment_actors
+                            .get(&upstream_fragment_id)
+                            .expect("should contain");
 
                         // Record updates for all actors.
-                        for actor_id in upstream_actor_ids {
+                        for &actor_id in upstream_actor_ids {
                             // Index with the dispatcher id to check duplicates.
                             dispatcher_update
                                 .try_insert(
@@ -526,9 +530,9 @@ impl CommandContext {
                 for (&fragment_id, reschedule) in reschedules {
                     for &downstream_fragment_id in &reschedule.downstream_fragment_ids {
                         // Find the actors of the downstream fragment.
-                        let downstream_actor_ids = metadata_manager
-                            .get_running_actors_of_fragment(downstream_fragment_id)
-                            .await?;
+                        let downstream_actor_ids = fragment_actors
+                            .get(&downstream_fragment_id)
+                            .expect("should contain");
 
                         // Downstream removed actors should be skipped
                         // Newly created actors of the current fragment will not dispatch Update
@@ -545,7 +549,7 @@ impl CommandContext {
                             .unwrap_or_default();
 
                         // Record updates for all actors.
-                        for actor_id in downstream_actor_ids {
+                        for &actor_id in downstream_actor_ids {
                             if downstream_removed_actors.contains(&actor_id) {
                                 continue;
                             }
@@ -620,7 +624,7 @@ impl CommandContext {
             }
         };
 
-        Ok(mutation)
+        mutation
     }
 
     fn generate_update_mutation_for_replace_table(
@@ -962,6 +966,7 @@ impl CommandContext {
             Command::RescheduleFragment {
                 reschedules,
                 table_parallelism,
+                ..
             } => {
                 let removed_actors = reschedules
                     .values()
diff --git a/src/meta/src/barrier/mod.rs b/src/meta/src/barrier/mod.rs
index 4d867d266270b..652a4b51d9264 100644
--- a/src/meta/src/barrier/mod.rs
+++ b/src/meta/src/barrier/mod.rs
@@ -25,11 +25,11 @@ use arc_swap::ArcSwap;
 use fail::fail_point;
 use itertools::Itertools;
 use prometheus::HistogramTimer;
+use risingwave_common::bail;
 use risingwave_common::catalog::TableId;
 use risingwave_common::system_param::reader::SystemParamsRead;
 use risingwave_common::system_param::PAUSE_ON_NEXT_BOOTSTRAP_KEY;
 use risingwave_common::util::epoch::{Epoch, INVALID_EPOCH};
-use risingwave_common::{bail, must_match};
 use risingwave_hummock_sdk::table_watermark::{
     merge_multiple_new_table_watermarks, TableWatermarks,
 };
@@ -41,7 +41,9 @@ use risingwave_pb::meta::subscribe_response::{Info, Operation};
 use risingwave_pb::meta::PausedReason;
 use risingwave_pb::stream_plan::barrier::BarrierKind;
 use risingwave_pb::stream_service::barrier_complete_response::CreateMviewProgress;
-use risingwave_pb::stream_service::BarrierCompleteResponse;
+use risingwave_pb::stream_service::{
+    streaming_control_stream_response, BarrierCompleteResponse, StreamingControlStreamResponse,
+};
 use thiserror_ext::AsReport;
 use tokio::sync::oneshot::{Receiver, Sender};
 use tokio::sync::Mutex;
@@ -54,12 +56,13 @@ use self::progress::TrackingCommand;
 use crate::barrier::info::InflightActorInfo;
 use crate::barrier::notifier::BarrierInfo;
 use crate::barrier::progress::CreateMviewProgressTracker;
-use crate::barrier::rpc::BarrierRpcManager;
+use crate::barrier::rpc::ControlStreamManager;
 use crate::barrier::state::BarrierManagerState;
 use crate::hummock::{CommitEpochInfo, HummockManagerRef};
 use crate::manager::sink_coordination::SinkCoordinatorManager;
 use crate::manager::{
-    ActiveStreamingWorkerNodes, LocalNotification, MetaSrvEnv, MetadataManager, WorkerId,
+    ActiveStreamingWorkerChange, ActiveStreamingWorkerNodes, LocalNotification, MetaSrvEnv,
+    MetadataManager, WorkerId,
 };
 use crate::model::{ActorId, TableFragments};
 use crate::rpc::metrics::MetaMetrics;
@@ -188,9 +191,9 @@ pub struct GlobalBarrierManager {
 
     checkpoint_control: CheckpointControl,
 
-    rpc_manager: BarrierRpcManager,
-
     active_streaming_nodes: ActiveStreamingWorkerNodes,
+
+    control_stream_manager: ControlStreamManager,
 }
 
 /// Controls the concurrent execution of commands.
@@ -228,7 +231,7 @@ impl CheckpointControl {
         self.context.metrics.in_flight_barrier_nums.set(
             self.command_ctx_queue
                 .values()
-                .filter(|x| matches!(x.state, BarrierEpochState::InFlight))
+                .filter(|x| x.state.is_inflight())
                 .count() as i64,
         );
         self.context
@@ -238,7 +241,12 @@ impl CheckpointControl {
     }
 
     /// Enqueue a barrier command, and init its state to `InFlight`.
-    fn enqueue_command(&mut self, command_ctx: Arc<CommandContext>, notifiers: Vec<Notifier>) {
+    fn enqueue_command(
+        &mut self,
+        command_ctx: Arc<CommandContext>,
+        notifiers: Vec<Notifier>,
+        node_to_collect: HashSet<WorkerId>,
+    ) {
         let timer = self.context.metrics.barrier_latency.start_timer();
 
         if let Some((_, node)) = self.command_ctx_queue.last_key_value() {
@@ -251,7 +259,10 @@ impl CheckpointControl {
             command_ctx.prev_epoch.value().0,
             EpochNode {
                 enqueue_time: timer,
-                state: BarrierEpochState::InFlight,
+                state: BarrierEpochState {
+                    node_to_collect,
+                    resps: vec![],
+                },
                 command_ctx,
                 notifiers,
             },
@@ -260,14 +271,19 @@ impl CheckpointControl {
 
     /// Change the state of this `prev_epoch` to `Completed`. Return continuous nodes
     /// with `Completed` starting from first node [`Completed`..`InFlight`) and remove them.
-    fn barrier_collected(&mut self, prev_epoch: u64, result: Vec<BarrierCompleteResponse>) {
+    fn barrier_collected(
+        &mut self,
+        worker_id: WorkerId,
+        prev_epoch: u64,
+        resp: BarrierCompleteResponse,
+    ) {
         if let Some(node) = self.command_ctx_queue.get_mut(&prev_epoch) {
-            assert!(matches!(node.state, BarrierEpochState::InFlight));
-            node.state = BarrierEpochState::Collected(result);
+            assert!(node.state.node_to_collect.remove(&worker_id));
+            node.state.resps.push(resp);
         } else {
             panic!(
-                "received barrier complete response for an unknown epoch: {}",
-                prev_epoch
+                "collect barrier on non-existing barrier: {}, {}",
+                prev_epoch, worker_id
             );
         }
     }
@@ -277,7 +293,7 @@ impl CheckpointControl {
         let in_flight_not_full = self
             .command_ctx_queue
             .values()
-            .filter(|x| matches!(x.state, BarrierEpochState::InFlight))
+            .filter(|x| x.state.is_inflight())
            .count()
            < in_flight_barrier_nums;
 
@@ -340,13 +356,8 @@ impl CheckpointControl {
        };
        if !is_err {
            // continue to finish the pending collected barrier.
-            while let Some((
-                _,
-                EpochNode {
-                    state: BarrierEpochState::Collected(_),
-                    ..
-                },
-            )) = self.command_ctx_queue.first_key_value()
+            while let Some((_, EpochNode { state, .. })) = self.command_ctx_queue.first_key_value()
+                && !state.is_inflight()
            {
                let (_, node) = self.command_ctx_queue.pop_first().expect("non-empty");
                let command_ctx = node.command_ctx.clone();
@@ -390,12 +401,16 @@ pub struct EpochNode {
 }
 
 /// The state of barrier.
-enum BarrierEpochState {
-    /// This barrier is current in-flight on the stream graph of compute nodes.
-    InFlight,
+struct BarrierEpochState {
+    node_to_collect: HashSet<WorkerId>,
 
-    /// This barrier is collected.
-    Collected(Vec<BarrierCompleteResponse>),
+    resps: Vec<BarrierCompleteResponse>,
+}
+
+impl BarrierEpochState {
+    fn is_inflight(&self) -> bool {
+        !self.node_to_collect.is_empty()
+    }
 }
 
 enum CompletingCommand {
@@ -411,13 +426,6 @@ enum CompletingCommand {
     Err(MetaError),
 }
 
-/// The result of barrier collect.
-#[derive(Debug)]
-struct BarrierCollectResult {
-    prev_epoch: u64,
-    result: MetaResult<Vec<BarrierCompleteResponse>>,
-}
-
 impl GlobalBarrierManager {
     /// Create a new [`crate::barrier::GlobalBarrierManager`].
     #[allow(clippy::too_many_arguments)]
@@ -458,10 +466,9 @@ impl GlobalBarrierManager {
             env: env.clone(),
         };
 
+        let control_stream_manager = ControlStreamManager::new(context.clone());
         let checkpoint_control = CheckpointControl::new(context.clone());
 
-        let rpc_manager = BarrierRpcManager::new(context.clone());
-
         Self {
             enable_recovery,
             scheduled_barriers,
@@ -470,8 +477,8 @@ impl GlobalBarrierManager {
             env,
             state: initial_invalid_state,
             checkpoint_control,
-            rpc_manager,
             active_streaming_nodes,
+            control_stream_manager,
         }
     }
 
@@ -489,7 +496,7 @@ impl GlobalBarrierManager {
     }
 
     /// Check whether we should pause on bootstrap from the system parameter and reset it.
-    async fn take_pause_on_bootstrap(&self) -> MetaResult<bool> {
+    async fn take_pause_on_bootstrap(&mut self) -> MetaResult<bool> {
         let paused = self
             .env
             .system_params_reader()
@@ -640,6 +647,9 @@ impl GlobalBarrierManager {
 
                     self.state
                         .resolve_worker_nodes(self.active_streaming_nodes.current().values().cloned());
+                    if let ActiveStreamingWorkerChange::Add(node) | ActiveStreamingWorkerChange::Update(node) = changed_worker {
+                        self.control_stream_manager.add_worker(node).await;
+                    }
                 }
 
                 // Checkpoint frequency changes.
@@ -652,14 +662,19 @@ impl GlobalBarrierManager {
                             .set_checkpoint_frequency(p.checkpoint_frequency() as usize)
                     }
                 }
-                // Barrier completes.
-                collect_result = self.rpc_manager.next_collected_barrier() => {
-                    match collect_result.result {
-                        Ok(resps) => {
-                            self.checkpoint_control.barrier_collected(collect_result.prev_epoch, resps);
-                        },
+                resp_result = self.control_stream_manager.next_response() => {
+                    match resp_result {
+                        Ok((worker_id, prev_epoch, resp)) => {
+                            let resp: StreamingControlStreamResponse = resp;
+                            match resp.response {
+                                Some(streaming_control_stream_response::Response::CompleteBarrier(resp)) => {
+                                    self.checkpoint_control.barrier_collected(worker_id, prev_epoch, resp);
+                                },
+                                resp => unreachable!("invalid response: {:?}", resp),
+                            }
+
+                        }
                         Err(e) => {
-                            fail_point!("inject_barrier_err_success");
                            self.failure_recovery(e).await;
                        }
                    }
                }
@@ -683,7 +698,9 @@ impl GlobalBarrierManager {
                 if self
                     .checkpoint_control
                     .can_inject_barrier(self.in_flight_barrier_nums) => {
-                    self.handle_new_barrier(scheduled);
+                    if let Err(e) = self.handle_new_barrier(scheduled) {
+                        self.failure_recovery(e).await;
+                    }
                }
            }
            self.checkpoint_control.update_barrier_nums_metrics();
@@ -691,7 +708,7 @@ impl GlobalBarrierManager {
     }
 
     /// Handle the new barrier from the scheduled queue and inject it.
-    fn handle_new_barrier(&mut self, scheduled: Scheduled) {
+    fn handle_new_barrier(&mut self, scheduled: Scheduled) -> MetaResult<()> {
         let Scheduled {
             command,
             mut notifiers,
@@ -728,7 +745,12 @@ impl GlobalBarrierManager {
 
         send_latency_timer.observe_duration();
 
-        self.rpc_manager.inject_barrier(command_ctx.clone());
+        let node_to_collect = self
+            .control_stream_manager
+            .inject_barrier(command_ctx.clone())
+            .inspect_err(|_| {
+                fail_point!("inject_barrier_err_success");
+            })?;
 
         // Notify about the injection.
         let prev_paused_reason = self.state.paused_reason();
@@ -746,12 +768,12 @@ impl GlobalBarrierManager {
         self.state.set_paused_reason(curr_paused_reason);
 
         // Record the in-flight barrier.
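+        // `node_to_collect` is the set of workers that still have to report barrier
+        // completion for this epoch; the barrier stays in-flight until it drains.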
         self.checkpoint_control
-            .enqueue_command(command_ctx.clone(), notifiers);
+            .enqueue_command(command_ctx.clone(), notifiers, node_to_collect);
+        Ok(())
     }
 
     async fn failure_recovery(&mut self, err: MetaError) {
         self.context.tracker.lock().await.abort_all(&err);
-        self.rpc_manager.clear();
         self.checkpoint_control.clear_on_err(&err).await;
 
         if self.enable_recovery {
@@ -787,7 +809,8 @@ impl GlobalBarrierManagerContext {
             state,
             ..
         } = node;
-        let resps = must_match!(state, BarrierEpochState::Collected(resps) => resps);
+        assert!(state.node_to_collect.is_empty());
+        let resps = state.resps;
         let wait_commit_timer = self.metrics.barrier_wait_commit_latency.start_timer();
         let (commit_info, create_mview_progress) = collect_commit_epoch_info(resps);
         if let Err(e) = self.update_snapshot(&command_ctx, commit_info).await {
@@ -954,13 +977,8 @@ impl CheckpointControl {
         if matches!(&self.completing_command, CompletingCommand::None) {
             // If there is no completing barrier, try to start completing the earliest barrier if
             // it has been collected.
-            if let Some((
-                _,
-                EpochNode {
-                    state: BarrierEpochState::Collected(_),
-                    ..
-                },
-            )) = self.command_ctx_queue.first_key_value()
+            if let Some((_, EpochNode { state, .. })) = self.command_ctx_queue.first_key_value()
+                && !state.is_inflight()
             {
                 let (_, node) = self.command_ctx_queue.pop_first().expect("non-empty");
                 let command_ctx = node.command_ctx.clone();
diff --git a/src/meta/src/barrier/recovery.rs b/src/meta/src/barrier/recovery.rs
index cd71f9eea707e..a7ea3ae51665a 100644
--- a/src/meta/src/barrier/recovery.rs
+++ b/src/meta/src/barrier/recovery.rs
@@ -12,6 +12,7 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
+use std::assert_matches::assert_matches;
 use std::collections::{BTreeSet, HashMap, HashSet};
 use std::sync::Arc;
 use std::time::{Duration, Instant};
@@ -28,6 +29,9 @@ use risingwave_pb::meta::PausedReason;
 use risingwave_pb::stream_plan::barrier::BarrierKind;
 use risingwave_pb::stream_plan::barrier_mutation::Mutation;
 use risingwave_pb::stream_plan::AddMutation;
+use risingwave_pb::stream_service::{
+    streaming_control_stream_response, StreamingControlStreamResponse,
+};
 use thiserror_ext::AsReport;
 use tokio::sync::oneshot;
 use tokio_retry::strategy::{jitter, ExponentialBackoff};
@@ -38,6 +42,8 @@ use crate::barrier::command::CommandContext;
 use crate::barrier::info::InflightActorInfo;
 use crate::barrier::notifier::Notifier;
 use crate::barrier::progress::CreateMviewProgressTracker;
+use crate::barrier::rpc::ControlStreamManager;
+use crate::barrier::schedule::ScheduledBarriers;
 use crate::barrier::state::BarrierManagerState;
 use crate::barrier::{Command, GlobalBarrierManager, GlobalBarrierManagerContext};
 use crate::controller::catalog::ReleaseContext;
@@ -302,15 +308,16 @@ impl GlobalBarrierManagerContext {
 
         Ok(())
     }
-}
 
-impl GlobalBarrierManager {
     /// Apply the pre-buffered drop and cancel commands; return `true` if any were applied.
-    async fn pre_apply_drop_cancel(&self) -> MetaResult<bool> {
-        let (dropped_actors, cancelled) = self.scheduled_barriers.pre_apply_drop_cancel_scheduled();
+    async fn pre_apply_drop_cancel(
+        &self,
+        scheduled_barriers: &ScheduledBarriers,
+    ) -> MetaResult<bool> {
+        let (dropped_actors, cancelled) = scheduled_barriers.pre_apply_drop_cancel_scheduled();
         let applied = !dropped_actors.is_empty() || !cancelled.is_empty();
         if !cancelled.is_empty() {
-            match &self.context.metadata_manager {
+            match &self.metadata_manager {
                 MetadataManager::V1(mgr) => {
                     let unregister_table_ids = mgr
                         .fragment_manager
@@ -334,7 +341,9 @@
         }
         Ok(applied)
     }
+}
 
+impl GlobalBarrierManager {
     /// Recover the whole cluster from the latest epoch.
     ///
     /// If `paused_reason` is `Some`, all data sources (including connectors and DMLs) will be
@@ -375,11 +384,14 @@
         // get recovered.
         let recovery_timer = self.context.metrics.recovery_latency.start_timer();
 
-        let (state, active_streaming_nodes) = tokio_retry::Retry::spawn(retry_strategy, || {
+        let new_state = tokio_retry::Retry::spawn(retry_strategy, || {
             async {
                 let recovery_result: MetaResult<_> = try {
                     // This is a quick path to accelerate the process of dropping and canceling streaming jobs.
-                    let _ = self.pre_apply_drop_cancel().await?;
+                    let _ = self
+                        .context
+                        .pre_apply_drop_cancel(&self.scheduled_barriers)
+                        .await?;
 
                     let active_streaming_nodes = ActiveStreamingWorkerNodes::new_snapshot(
                         self.context.metadata_manager.clone(),
@@ -427,14 +439,21 @@
                         })?
                     };
 
-                    // Reset all compute nodes, stop and drop existing actors.
-                    self.reset_compute_nodes(&info, prev_epoch.value().0)
+                    let mut control_stream_manager =
+                        ControlStreamManager::new(self.context.clone());
+
+                    control_stream_manager
+                        .reset(prev_epoch.value().0, active_streaming_nodes.current())
                         .await
                         .inspect_err(|err| {
                             warn!(error = %err.as_report(), "reset compute nodes failed");
                         })?;
 
-                    if self.pre_apply_drop_cancel().await? {
+                    if self
+                        .context
+                        .pre_apply_drop_cancel(&self.scheduled_barriers)
+                        .await?
+                    {
                         info = self
                             .context
                             .resolve_actor_info(all_nodes.clone())
@@ -445,10 +464,10 @@
                     }
 
                     // update and build all actors.
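+                    // Push the latest actor infos to all compute nodes and build the
+                    // actors there before the initial barrier is injected below.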
- self.update_actors(&info).await.inspect_err(|err| { + self.context.update_actors(&info).await.inspect_err(|err| { warn!(error = %err.as_report(), "update actors failed"); })?; - self.build_actors(&info).await.inspect_err(|err| { + self.context.build_actors(&info).await.inspect_err(|err| { warn!(error = %err.as_report(), "build_actors failed"); })?; @@ -478,30 +497,25 @@ impl GlobalBarrierManager { tracing::Span::current(), // recovery span )); - let res = match self - .context - .inject_barrier(command_ctx.clone(), None, None) - .await - .result - { - Ok(response) => { - if let Err(err) = command_ctx.post_collect().await { - warn!(error = %err.as_report(), "post_collect failed"); - Err(err) - } else { - Ok((new_epoch.clone(), response)) + let mut node_to_collect = + control_stream_manager.inject_barrier(command_ctx.clone())?; + while !node_to_collect.is_empty() { + let (worker_id, _, resp) = control_stream_manager.next_response().await?; + assert_matches!( + resp, + StreamingControlStreamResponse { + response: Some( + streaming_control_stream_response::Response::CompleteBarrier(_) + ) } - } - Err(err) => { - warn!(error = %err.as_report(), "inject_barrier failed"); - Err(err) - } - }; - let (new_epoch, _) = res?; + ); + assert!(node_to_collect.remove(&worker_id)); + } ( BarrierManagerState::new(new_epoch, info, command_ctx.next_paused_reason()), active_streaming_nodes, + control_stream_manager, ) }; if recovery_result.is_err() { @@ -517,14 +531,17 @@ impl GlobalBarrierManager { recovery_timer.observe_duration(); self.scheduled_barriers.mark_ready(); + ( + self.state, + self.active_streaming_nodes, + self.control_stream_manager, + ) = new_state; + tracing::info!( - epoch = state.in_flight_prev_epoch().value().0, - paused = ?state.paused_reason(), + epoch = self.state.in_flight_prev_epoch().value().0, + paused = ?self.state.paused_reason(), "recovery success" ); - - self.state = state; - self.active_streaming_nodes = active_streaming_nodes; } } @@ -1013,9 +1030,7 @@ impl GlobalBarrierManagerContext { new_plan.insert(self.env.meta_store_checked()).await?; Ok(new_plan) } -} -impl GlobalBarrierManager { /// Update all actors in compute nodes. async fn update_actors(&self, info: &InflightActorInfo) -> MetaResult<()> { if info.actor_map.is_empty() { @@ -1041,7 +1056,7 @@ impl GlobalBarrierManager { .flatten_ok() .try_collect()?; - let mut all_node_actors = self.context.metadata_manager.all_node_actors(false).await?; + let mut all_node_actors = self.metadata_manager.all_node_actors(false).await?; // Check if any actors were dropped after info resolved. if all_node_actors.iter().any(|(node_id, node_actors)| { @@ -1055,8 +1070,7 @@ impl GlobalBarrierManager { return Err(anyhow!("actors dropped during update").into()); } - self.context - .stream_rpc_manager + self.stream_rpc_manager .broadcast_update_actor_info( &info.node_map, info.actor_map.keys().cloned(), @@ -1080,8 +1094,7 @@ impl GlobalBarrierManager { return Ok(()); } - self.context - .stream_rpc_manager + self.stream_rpc_manager .build_actors( &info.node_map, info.actor_map.iter().map(|(node_id, actors)| { @@ -1093,23 +1106,6 @@ impl GlobalBarrierManager { Ok(()) } - - /// Reset all compute nodes by calling `force_stop_actors`. 
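What replaces `reset_compute_nodes` (a sketch, based only on the pieces visible in this diff): resetting a worker is now the `Init` handshake of the bidirectional control stream rather than a dedicated `force_stop_actors` RPC. The client sends `InitRequest { prev_epoch }` as the very first message and waits for an `InitResponse` before the stream is considered usable; the compute node performs the old force-stop cleanup while handling `Init`, as the fake stream service in the test module later in this patch illustrates.

    // Simplified: ControlStreamManager::reset() does this for every worker node.
    let handle = stream_client.start_streaming_control(prev_epoch).await?;
    // From here on, barriers are injected by sending on handle.request_sender.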
- async fn reset_compute_nodes( - &self, - info: &InflightActorInfo, - prev_epoch: u64, - ) -> MetaResult<()> { - debug!(prev_epoch, worker = ?info.node_map.keys().collect_vec(), "force stop actors"); - self.context - .stream_rpc_manager - .force_stop_actors(info.node_map.values(), prev_epoch) - .await?; - - debug!(prev_epoch, "all compute nodes have been reset."); - - Ok(()) - } } #[cfg(test)] diff --git a/src/meta/src/barrier/rpc.rs b/src/meta/src/barrier/rpc.rs index dfe9ada44a47e..aa35d606c4bbf 100644 --- a/src/meta/src/barrier/rpc.rs +++ b/src/meta/src/barrier/rpc.rs @@ -12,14 +12,15 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::collections::HashMap; +use std::collections::{HashMap, HashSet, VecDeque}; use std::future::Future; use std::sync::Arc; use std::time::Duration; use anyhow::anyhow; use fail::fail_point; -use futures::stream::FuturesUnordered; +use futures::future::try_join_all; +use futures::stream::{BoxStream, FuturesUnordered}; use futures::{pin_mut, FutureExt, StreamExt}; use itertools::Itertools; use risingwave_common::bail; @@ -28,141 +29,193 @@ use risingwave_common::util::tracing::TracingContext; use risingwave_pb::common::{ActorInfo, WorkerNode}; use risingwave_pb::stream_plan::{Barrier, BarrierMutation, StreamActor}; use risingwave_pb::stream_service::{ - BarrierCompleteRequest, BroadcastActorInfoTableRequest, BuildActorsRequest, DropActorsRequest, - ForceStopActorsRequest, InjectBarrierRequest, UpdateActorsRequest, + streaming_control_stream_request, streaming_control_stream_response, + BroadcastActorInfoTableRequest, BuildActorsRequest, DropActorsRequest, InjectBarrierRequest, + StreamingControlStreamRequest, StreamingControlStreamResponse, UpdateActorsRequest, }; use risingwave_rpc_client::error::RpcError; use risingwave_rpc_client::StreamClient; use rw_futures_util::pending_on_none; -use tokio::sync::oneshot; +use thiserror_ext::AsReport; +use tokio::sync::mpsc::UnboundedSender; use tokio::time::timeout; -use tracing::Instrument; +use tracing::{error, info, warn}; use uuid::Uuid; use super::command::CommandContext; -use super::{BarrierCollectResult, GlobalBarrierManagerContext}; +use super::GlobalBarrierManagerContext; use crate::manager::{MetaSrvEnv, WorkerId}; use crate::{MetaError, MetaResult}; -pub(super) struct BarrierRpcManager { - context: GlobalBarrierManagerContext, +struct ControlStreamNode { + worker: WorkerNode, + sender: UnboundedSender, + // earlier epoch at the front + inflight_barriers: VecDeque>, +} + +fn into_future( + worker_id: WorkerId, + stream: BoxStream< + 'static, + risingwave_rpc_client::error::Result, + >, +) -> ResponseStreamFuture { + stream.into_future().map(move |(opt, stream)| { + ( + worker_id, + stream, + opt.ok_or_else(|| anyhow!("end of stream").into()) + .and_then(|result| result.map_err(|e| e.into())), + ) + }) +} - /// Futures that await on the completion of barrier. 
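The `into_future` helper above drives all per-worker response streams through one `FuturesUnordered`, re-arming a stream after each yielded item so every worker keeps being polled concurrently. A self-contained sketch of the same pattern, with `i32` standing in for the real response type:

    use futures::future::FutureExt;
    use futures::stream::{BoxStream, FuturesUnordered, StreamExt};

    fn next_item(
        worker_id: u32,
        stream: BoxStream<'static, i32>,
    ) -> impl std::future::Future<Output = (u32, Option<i32>, BoxStream<'static, i32>)> {
        // Resolve to the stream's next item plus the remainder of the stream.
        stream
            .into_future()
            .map(move |(item, rest)| (worker_id, item, rest))
    }

    async fn poll_workers(streams: Vec<(u32, BoxStream<'static, i32>)>) {
        let mut pending = FuturesUnordered::new();
        for (worker_id, stream) in streams {
            pending.push(next_item(worker_id, stream));
        }
        while let Some((worker_id, item, rest)) = pending.next().await {
            match item {
                // Re-arm: push the remainder back so this worker is polled again.
                Some(_resp) => pending.push(next_item(worker_id, rest)),
                // `None` means this worker's stream ended; it is simply dropped.
                None => {}
            }
        }
    }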
- injected_in_progress_barrier: FuturesUnordered, +type ResponseStreamFuture = impl Future< + Output = ( + WorkerId, + BoxStream< + 'static, + risingwave_rpc_client::error::Result, + >, + MetaResult, + ), + > + 'static; - prev_injecting_barrier: Option>, +pub(super) struct ControlStreamManager { + context: GlobalBarrierManagerContext, + nodes: HashMap, + response_streams: FuturesUnordered, } -impl BarrierRpcManager { +impl ControlStreamManager { pub(super) fn new(context: GlobalBarrierManagerContext) -> Self { Self { context, - injected_in_progress_barrier: FuturesUnordered::new(), - prev_injecting_barrier: None, + nodes: Default::default(), + response_streams: FuturesUnordered::new(), } } - pub(super) fn clear(&mut self) { - self.injected_in_progress_barrier = FuturesUnordered::new(); - self.prev_injecting_barrier = None; + pub(super) async fn add_worker(&mut self, node: WorkerNode) { + if self.nodes.contains_key(&node.id) { + warn!(id = node.id, host = ?node.host, "node already exists"); + return; + } + let prev_epoch = self + .context + .hummock_manager + .latest_snapshot() + .committed_epoch; + let node_id = node.id; + let node_host = node.host.clone().unwrap(); + match self.context.new_control_stream_node(node, prev_epoch).await { + Ok((stream_node, response_stream)) => { + let _ = self.nodes.insert(node_id, stream_node); + self.response_streams + .push(into_future(node_id, response_stream)); + info!(?node_host, "add control stream worker"); + } + Err(e) => { + error!(err = %e.as_report(), ?node_host, "fail to start control stream with worker node"); + } + } } - pub(super) fn inject_barrier(&mut self, command_context: Arc) { - // this is to notify that the barrier has been injected so that the next - // barrier can be injected to avoid out of order barrier injection. - // TODO: can be removed when bidi-stream control in implemented. - let (inject_tx, inject_rx) = oneshot::channel(); - let prev_inject_rx = self.prev_injecting_barrier.replace(inject_rx); - let await_complete_future = - self.context - .inject_barrier(command_context, Some(inject_tx), prev_inject_rx); - self.injected_in_progress_barrier - .push(await_complete_future); - } + pub(super) async fn reset( + &mut self, + prev_epoch: u64, + nodes: &HashMap, + ) -> MetaResult<()> { + let nodes = try_join_all(nodes.iter().map(|(worker_id, node)| async { + let node = self + .context + .new_control_stream_node(node.clone(), prev_epoch) + .await?; + Result::<_, MetaError>::Ok((*worker_id, node)) + })) + .await?; + self.nodes.clear(); + self.response_streams.clear(); + for (worker_id, (node, response_stream)) in nodes { + self.nodes.insert(worker_id, node); + self.response_streams + .push(into_future(worker_id, response_stream)); + } - pub(super) async fn next_collected_barrier(&mut self) -> BarrierCollectResult { - pending_on_none(self.injected_in_progress_barrier.next()).await + Ok(()) } -} - -pub(super) type BarrierCollectFuture = impl Future + Send + 'static; -impl GlobalBarrierManagerContext { - /// Inject a barrier to all CNs and spawn a task to collect it - pub(super) fn inject_barrier( - &self, - command_context: Arc, - inject_tx: Option>, - prev_inject_rx: Option>, - ) -> BarrierCollectFuture { - let (tx, rx) = oneshot::channel(); - let prev_epoch = command_context.prev_epoch.value().0; - let stream_rpc_manager = self.stream_rpc_manager.clone(); - // todo: the collect handler should be abort when recovery. 
- let _join_handle = tokio::spawn(async move { - let span = command_context.span.clone(); - if let Some(prev_inject_rx) = prev_inject_rx { - if prev_inject_rx.await.is_err() { - let _ = tx.send(BarrierCollectResult { - prev_epoch, - result: Err(anyhow!("prev barrier failed to be injected").into()), - }); - return; - } - } - let result = stream_rpc_manager - .inject_barrier(command_context.clone()) - .instrument(span.clone()) - .await; + pub(super) async fn next_response( + &mut self, + ) -> MetaResult<(WorkerId, u64, StreamingControlStreamResponse)> { + loop { + let (worker_id, response_stream, result) = + pending_on_none(self.response_streams.next()).await; match result { - Ok(node_need_collect) => { - if let Some(inject_tx) = inject_tx { - let _ = inject_tx.send(()); + Ok(resp) => match &resp.response { + Some(streaming_control_stream_response::Response::CompleteBarrier(_)) => { + self.response_streams + .push(into_future(worker_id, response_stream)); + let node = self + .nodes + .get_mut(&worker_id) + .expect("should exist when get collect resp"); + let command = node + .inflight_barriers + .pop_front() + .expect("should exist when get collect resp"); + break Ok((worker_id, command.prev_epoch.value().0, resp)); + } + resp => { + break Err(anyhow!("get unexpected resp: {:?}", resp).into()); + } + }, + Err(err) => { + let mut node = self + .nodes + .remove(&worker_id) + .expect("should exist when get collect resp"); + warn!(node = ?node.worker, err = ?err.as_report(), "get error from response stream"); + if let Some(command) = node.inflight_barriers.pop_front() { + self.context.report_collect_failure(&command, &err); + break Err(err); + } else { + // for node with no inflight barrier, simply ignore the error + continue; } - stream_rpc_manager - .collect_barrier(node_need_collect, command_context, tx) - .instrument(span.clone()) - .await; - } - Err(e) => { - let _ = tx.send(BarrierCollectResult { - prev_epoch, - result: Err(e), - }); } } - }); - rx.map(move |result| match result { - Ok(completion) => completion, - Err(_e) => BarrierCollectResult { - prev_epoch, - result: Err(anyhow!("failed to receive barrier completion result").into()), - }, - }) + } } } -impl StreamRpcManager { +impl ControlStreamManager { /// Send inject-barrier-rpc to stream service and wait for its response before returns. - async fn inject_barrier( - &self, + pub(super) fn inject_barrier( + &mut self, command_context: Arc, - ) -> MetaResult> { + ) -> MetaResult> { fail_point!("inject_barrier_err", |_| bail!("inject_barrier_err")); - let mutation = command_context.to_mutation().await?; + let mutation = command_context.to_mutation(); let info = command_context.info.clone(); - let mut node_need_collect = HashMap::new(); - self.make_request( - info.node_map.iter().filter_map(|(node_id, node)| { + let mut node_need_collect = HashSet::new(); + + info.node_map + .iter() + .map(|(node_id, worker_node)| { let actor_ids_to_send = info.actor_ids_to_send(node_id).collect_vec(); let actor_ids_to_collect = info.actor_ids_to_collect(node_id).collect_vec(); if actor_ids_to_collect.is_empty() { // No need to send or collect barrier for this node. 
assert!(actor_ids_to_send.is_empty()); - node_need_collect.insert(*node_id, false); - None + Ok(()) } else { - node_need_collect.insert(*node_id, true); + let Some(node) = self.nodes.get_mut(node_id) else { + return Err( + anyhow!("unconnected worker node: {:?}", worker_node.host).into() + ); + }; let mutation = mutation.clone(); let barrier = Barrier { epoch: Some(risingwave_pb::data::Epoch { @@ -177,104 +230,89 @@ impl StreamRpcManager { kind: command_context.kind as i32, passed_actors: vec![], }; - Some(( - node, - InjectBarrierRequest { - request_id: Self::new_request_id(), - barrier: Some(barrier), - actor_ids_to_send, - actor_ids_to_collect, - }, - )) - } - }), - |client, request| { - async move { - tracing::debug!( - target: "events::meta::barrier::inject_barrier", - ?request, "inject barrier request" - ); - // This RPC returns only if this worker node has injected this barrier. - client.inject_barrier(request).await - } - }, - ) - .await - .inspect_err(|e| { - // Record failure in event log. - use risingwave_pb::meta::event_log; - use thiserror_ext::AsReport; - let event = event_log::EventInjectBarrierFail { - prev_epoch: command_context.prev_epoch.value().0, - cur_epoch: command_context.curr_epoch.value().0, - error: e.to_report_string(), - }; - self.env - .event_log_manager_ref() - .add_event_logs(vec![event_log::Event::InjectBarrierFail(event)]); - })?; - Ok(node_need_collect) - } - - /// Send barrier-complete-rpc and wait for responses from all CNs - async fn collect_barrier( - &self, - node_need_collect: HashMap, - command_context: Arc, - barrier_collect_tx: oneshot::Sender, - ) { - let prev_epoch = command_context.prev_epoch.value().0; - let tracing_context = - TracingContext::from_span(command_context.prev_epoch.span()).to_protobuf(); - - let info = command_context.info.clone(); - let result = self - .broadcast( - info.node_map.iter().filter_map(|(node_id, node)| { - if !*node_need_collect.get(node_id).unwrap() { - // No need to send or collect barrier for this node. - None - } else { - Some(node) - } - }), - |client| { - let tracing_context = tracing_context.clone(); - async move { - let request = BarrierCompleteRequest { - request_id: Self::new_request_id(), - prev_epoch, - tracing_context, - }; - tracing::debug!( - target: "events::meta::barrier::barrier_complete", - ?request, "barrier complete" - ); + node.sender + .send(StreamingControlStreamRequest { + request: Some( + streaming_control_stream_request::Request::InjectBarrier( + InjectBarrierRequest { + request_id: StreamRpcManager::new_request_id(), + barrier: Some(barrier), + actor_ids_to_send, + actor_ids_to_collect, + }, + ), + ), + }) + .map_err(|_| { + MetaError::from(anyhow!( + "failed to send request to {} {:?}", + node.worker.id, + node.worker.host + )) + })?; - // This RPC returns only if this worker node has collected this barrier. - client.barrier_complete(request).await - } - }, - ) - .await + node.inflight_barriers.push_back(command_context.clone()); + node_need_collect.insert(*node_id); + Result::<_, MetaError>::Ok(()) + } + }) + .try_collect() .inspect_err(|e| { // Record failure in event log. 
                use risingwave_pb::meta::event_log;
-                use thiserror_ext::AsReport;
-                let event = event_log::EventCollectBarrierFail {
+                let event = event_log::EventInjectBarrierFail {
                     prev_epoch: command_context.prev_epoch.value().0,
                     cur_epoch: command_context.curr_epoch.value().0,
                     error: e.to_report_string(),
                 };
-                self.env
+                self.context
+                    .env
                     .event_log_manager_ref()
-                    .add_event_logs(vec![event_log::Event::CollectBarrierFail(event)]);
-            })
-            .map_err(Into::into);
-        let _ = barrier_collect_tx
-            .send(BarrierCollectResult { prev_epoch, result })
-            .inspect_err(|_| tracing::warn!(prev_epoch, "failed to notify barrier completion"));
+                    .add_event_logs(vec![event_log::Event::InjectBarrierFail(event)]);
+            })?;
+        Ok(node_need_collect)
+    }
+}
+
+impl GlobalBarrierManagerContext {
+    async fn new_control_stream_node(
+        &self,
+        node: WorkerNode,
+        prev_epoch: u64,
+    ) -> MetaResult<(
+        ControlStreamNode,
+        BoxStream<'static, risingwave_rpc_client::error::Result<StreamingControlStreamResponse>>,
+    )> {
+        let handle = self
+            .env
+            .stream_client_pool()
+            .get(&node)
+            .await?
+            .start_streaming_control(prev_epoch)
+            .await?;
+        Ok((
+            ControlStreamNode {
+                worker: node.clone(),
+                sender: handle.request_sender,
+                inflight_barriers: VecDeque::new(),
+            },
+            handle.response_stream,
+        ))
+    }
+
+    /// Record a barrier-collection failure in the event log. No RPC is sent
+    /// here; collection itself happens on the control stream.
+    fn report_collect_failure(&self, command_context: &CommandContext, error: &MetaError) {
+        // Record failure in event log.
+        use risingwave_pb::meta::event_log;
+        let event = event_log::EventCollectBarrierFail {
+            prev_epoch: command_context.prev_epoch.value().0,
+            cur_epoch: command_context.curr_epoch.value().0,
+            error: error.to_report_string(),
+        };
+        self.env
+            .event_log_manager_ref()
+            .add_event_logs(vec![event_log::Event::CollectBarrierFail(event)]);
     }
 }
 
@@ -303,15 +341,6 @@ impl StreamRpcManager {
         result.map_err(|results_err| merge_node_rpc_errors("merged RPC Error", results_err))
     }
 
-    async fn broadcast<RSP, Fut: Future<Output = Result<RSP, RpcError>> + 'static>(
-        &self,
-        nodes: impl Iterator<Item = &WorkerNode>,
-        f: impl Fn(StreamClient) -> Fut,
-    ) -> MetaResult<Vec<RSP>> {
-        self.make_request(nodes.map(|node| (node, ())), |client, ()| f(client))
-            .await
-    }
-
     fn new_request_id() -> String {
         Uuid::new_v4().to_string()
     }
@@ -403,23 +432,6 @@ impl StreamRpcManager {
         .await?;
         Ok(())
     }
-
-    pub async fn force_stop_actors(
-        &self,
-        nodes: impl Iterator<Item = &WorkerNode>,
-        prev_epoch: u64,
-    ) -> MetaResult<()> {
-        self.broadcast(nodes, |client| async move {
-            client
-                .force_stop_actors(ForceStopActorsRequest {
-                    request_id: Self::new_request_id(),
-                    prev_epoch,
-                })
-                .await
-        })
-        .await?;
-        Ok(())
-    }
 }
 
 /// This function is similar to `try_join_all`, but it attempts to collect as many errors as possible within `error_timeout`.
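Taken together, barrier injection is now a two-phase send/collect protocol owned by the barrier manager rather than a pair of unary RPCs. A simplified driver mirroring the recovery path shown earlier, with error handling and tracing elided:

    async fn inject_and_collect(
        mgr: &mut ControlStreamManager,
        command_ctx: Arc<CommandContext>,
    ) -> MetaResult<()> {
        // Phase 1: push an InjectBarrier request to every worker that hosts
        // actors for this barrier; returns the set of workers to wait for.
        let mut node_to_collect = mgr.inject_barrier(command_ctx)?;
        // Phase 2: drain CompleteBarrier responses until all workers reported.
        while !node_to_collect.is_empty() {
            let (worker_id, _prev_epoch, _resp) = mgr.next_response().await?;
            assert!(node_to_collect.remove(&worker_id));
        }
        Ok(())
    }

In the steady-state path, collection is interleaved with other work through `CheckpointControl` (see the first hunks of this section) rather than awaited inline, but the message flow is the same.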
@@ -466,8 +478,6 @@ fn merge_node_rpc_errors( ) -> MetaError { use std::fmt::Write; - use thiserror_ext::AsReport; - let concat: String = errors .into_iter() .fold(format!("{message}:"), |mut s, (w, e)| { diff --git a/src/meta/src/stream/scale.rs b/src/meta/src/stream/scale.rs index 41ed041879f0c..99ae32d26bb92 100644 --- a/src/meta/src/stream/scale.rs +++ b/src/meta/src/stream/scale.rs @@ -21,7 +21,7 @@ use std::sync::Arc; use std::time::Duration; use anyhow::{anyhow, Context}; -use futures::future::BoxFuture; +use futures::future::{try_join_all, BoxFuture}; use itertools::Itertools; use num_integer::Integer; use num_traits::abs; @@ -2651,9 +2651,33 @@ impl GlobalStreamManager { tracing::debug!("reschedule plan: {:?}", reschedule_fragment); + let up_down_stream_fragment: HashSet<_> = reschedule_fragment + .iter() + .flat_map(|(_, reschedule)| { + reschedule + .upstream_fragment_dispatcher_ids + .iter() + .map(|(fragment_id, _)| *fragment_id) + .chain(reschedule.downstream_fragment_ids.iter().cloned()) + }) + .collect(); + + let fragment_actors = + try_join_all(up_down_stream_fragment.iter().map(|fragment_id| async { + let actor_ids = self + .metadata_manager + .get_running_actors_of_fragment(*fragment_id) + .await?; + Result::<_, MetaError>::Ok((*fragment_id, actor_ids)) + })) + .await? + .into_iter() + .collect(); + let command = Command::RescheduleFragment { reschedules: reschedule_fragment, table_parallelism: table_parallelism.unwrap_or_default(), + fragment_actors, }; match &self.metadata_manager { diff --git a/src/meta/src/stream/stream_manager.rs b/src/meta/src/stream/stream_manager.rs index d6ef8944725e4..fa16b039236b6 100644 --- a/src/meta/src/stream/stream_manager.rs +++ b/src/meta/src/stream/stream_manager.rs @@ -756,6 +756,7 @@ mod tests { use std::sync::{Arc, Mutex}; use std::time::Duration; + use futures::{Stream, TryStreamExt}; use risingwave_common::catalog::TableId; use risingwave_common::hash::ParallelUnitMapping; use risingwave_common::system_param::reader::SystemParamsRead; @@ -768,16 +769,20 @@ mod tests { use risingwave_pb::stream_service::stream_service_server::{ StreamService, StreamServiceServer, }; + use risingwave_pb::stream_service::streaming_control_stream_response::InitResponse; use risingwave_pb::stream_service::{ BroadcastActorInfoTableResponse, BuildActorsResponse, DropActorsRequest, - DropActorsResponse, InjectBarrierRequest, InjectBarrierResponse, UpdateActorsResponse, *, + DropActorsResponse, UpdateActorsResponse, *, }; + use tokio::spawn; + use tokio::sync::mpsc::unbounded_channel; use tokio::sync::oneshot::Sender; #[cfg(feature = "failpoints")] use tokio::sync::Notify; use tokio::task::JoinHandle; use tokio::time::sleep; - use tonic::{Request, Response, Status}; + use tokio_stream::wrappers::UnboundedReceiverStream; + use tonic::{Request, Response, Status, Streaming}; use super::*; use crate::barrier::{GlobalBarrierManager, StreamRpcManager}; @@ -805,6 +810,9 @@ mod tests { #[async_trait::async_trait] impl StreamService for FakeStreamService { + type StreamingControlStreamStream = + impl Stream>; + async fn update_actors( &self, request: Request, @@ -856,29 +864,46 @@ mod tests { Ok(Response::new(DropActorsResponse::default())) } - async fn force_stop_actors( + async fn streaming_control_stream( &self, - _request: Request, - ) -> std::result::Result, Status> { - self.inner.actor_streams.lock().unwrap().clear(); - self.inner.actor_ids.lock().unwrap().clear(); - self.inner.actor_infos.lock().unwrap().clear(); - - 
Ok(Response::new(ForceStopActorsResponse::default())) - } - - async fn inject_barrier( - &self, - _request: Request, - ) -> std::result::Result, Status> { - Ok(Response::new(InjectBarrierResponse::default())) - } - - async fn barrier_complete( - &self, - _request: Request, - ) -> std::result::Result, Status> { - Ok(Response::new(BarrierCompleteResponse::default())) + request: Request>, + ) -> Result, Status> { + let (tx, rx) = unbounded_channel(); + let mut request_stream = request.into_inner(); + let inner = self.inner.clone(); + let _join_handle = spawn(async move { + while let Ok(Some(request)) = request_stream.try_next().await { + match request.request.unwrap() { + streaming_control_stream_request::Request::Init(_) => { + inner.actor_streams.lock().unwrap().clear(); + inner.actor_ids.lock().unwrap().clear(); + inner.actor_infos.lock().unwrap().clear(); + let _ = tx.send(Ok(StreamingControlStreamResponse { + response: Some(streaming_control_stream_response::Response::Init( + InitResponse {}, + )), + })); + } + streaming_control_stream_request::Request::InjectBarrier(_) => { + let _ = tx.send(Ok(StreamingControlStreamResponse { + response: Some( + streaming_control_stream_response::Response::CompleteBarrier( + BarrierCompleteResponse { + request_id: "".to_string(), + status: None, + create_mview_progress: vec![], + synced_sstables: vec![], + worker_id: 0, + table_watermarks: Default::default(), + }, + ), + ), + })); + } + } + } + }); + Ok(Response::new(UnboundedReceiverStream::new(rx))) } async fn wait_epoch_commit( diff --git a/src/rpc_client/src/lib.rs b/src/rpc_client/src/lib.rs index 0485465499f5a..fabd1dabeca01 100644 --- a/src/rpc_client/src/lib.rs +++ b/src/rpc_client/src/lib.rs @@ -43,7 +43,9 @@ use rand::prelude::SliceRandom; use risingwave_common::util::addr::HostAddr; use risingwave_pb::common::WorkerNode; use risingwave_pb::meta::heartbeat_request::extra_info; -use tokio::sync::mpsc::{channel, Receiver, Sender}; +use tokio::sync::mpsc::{ + channel, unbounded_channel, Receiver, Sender, UnboundedReceiver, UnboundedSender, +}; pub mod error; use error::Result; @@ -63,7 +65,9 @@ pub use hummock_meta_client::{CompactionEventItem, HummockMetaClient}; pub use meta_client::{MetaClient, SinkCoordinationRpcClient}; use rw_futures_util::await_future_with_monitor_error_stream; pub use sink_coordinate_client::CoordinatorStreamHandle; -pub use stream_client::{StreamClient, StreamClientPool, StreamClientPoolRef}; +pub use stream_client::{ + StreamClient, StreamClientPool, StreamClientPoolRef, StreamingControlHandle, +}; #[async_trait] pub trait RpcClient: Send + Sync + 'static + Clone { @@ -274,3 +278,63 @@ impl BidiStreamHandle { } } } + +/// The handle of a bidi-stream started from the rpc client. It is similar to the `BidiStreamHandle` +/// except that its sender is unbounded. 
+pub struct UnboundedBidiStreamHandle<REQ, RSP> {
+    pub request_sender: UnboundedSender<REQ>,
+    pub response_stream: BoxStream<'static, Result<RSP>>,
+}
+
+impl<REQ, RSP> Debug for UnboundedBidiStreamHandle<REQ, RSP> {
+    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
+        f.write_str(type_name::<Self>())
+    }
+}
+
+impl<REQ, RSP> UnboundedBidiStreamHandle<REQ, RSP> {
+    pub async fn initialize<
+        F: FnOnce(UnboundedReceiver<REQ>) -> Fut,
+        St: Stream<Item = Result<RSP>> + Send + Unpin + 'static,
+        Fut: Future<Output = Result<St>> + Send,
+        R: Into<REQ>,
+    >(
+        first_request: R,
+        init_stream_fn: F,
+    ) -> Result<(Self, RSP)> {
+        let (request_sender, request_receiver) = unbounded_channel();
+
+        // Send the initial request before awaiting stream creation: the server
+        // may block on its first receive while the streaming call is set up.
+        request_sender
+            .send(first_request.into())
+            .map_err(|_err| anyhow!("unable to send first request of {}", type_name::<REQ>()))?;
+
+        let mut response_stream = init_stream_fn(request_receiver).await?;
+
+        let first_response = response_stream
+            .next()
+            .await
+            .context("get empty response from first request")??;
+
+        Ok((
+            Self {
+                request_sender,
+                response_stream: response_stream.boxed(),
+            },
+            first_response,
+        ))
+    }
+
+    pub async fn next_response(&mut self) -> Result<RSP> {
+        self.response_stream
+            .next()
+            .await
+            .ok_or_else(|| anyhow!("end of response stream"))?
+    }
+
+    pub fn send_request(&mut self, request: REQ) -> Result<()> {
+        self.request_sender
+            .send(request)
+            .map_err(|_| anyhow!("unable to send request {}", type_name::<REQ>()).into())
+    }
+}
diff --git a/src/rpc_client/src/stream_client.rs b/src/rpc_client/src/stream_client.rs
index 3a271b5660bbd..ae5af65f28220 100644
--- a/src/rpc_client/src/stream_client.rs
+++ b/src/rpc_client/src/stream_client.rs
@@ -15,17 +15,22 @@
 use std::sync::Arc;
 use std::time::Duration;
 
+use anyhow::anyhow;
 use async_trait::async_trait;
+use futures::TryStreamExt;
 use risingwave_common::config::MAX_CONNECTION_WINDOW_SIZE;
 use risingwave_common::monitor::connection::{EndpointExt, TcpConfig};
 use risingwave_common::util::addr::HostAddr;
 use risingwave_pb::stream_service::stream_service_client::StreamServiceClient;
+use risingwave_pb::stream_service::streaming_control_stream_request::InitRequest;
+use risingwave_pb::stream_service::streaming_control_stream_response::InitResponse;
 use risingwave_pb::stream_service::*;
+use tokio_stream::wrappers::UnboundedReceiverStream;
 use tonic::transport::Endpoint;
 
-use crate::error::Result;
+use crate::error::{Result, RpcError};
 use crate::tracing::{Channel, TracingInjectedChannelExt};
-use crate::{rpc_client_method_impl, RpcClient, RpcClientPool};
+use crate::{rpc_client_method_impl, RpcClient, RpcClientPool, UnboundedBidiStreamHandle};
 
 #[derive(Clone)]
 pub struct StreamClient(StreamServiceClient<Channel>);
@@ -68,9 +73,6 @@ macro_rules! for_all_stream_rpc {
         ,{ 0, build_actors, BuildActorsRequest, BuildActorsResponse }
         ,{ 0, broadcast_actor_info_table, BroadcastActorInfoTableRequest, BroadcastActorInfoTableResponse }
         ,{ 0, drop_actors, DropActorsRequest, DropActorsResponse }
-        ,{ 0, force_stop_actors, ForceStopActorsRequest, ForceStopActorsResponse}
-        ,{ 0, inject_barrier, InjectBarrierRequest, InjectBarrierResponse }
-        ,{ 0, barrier_complete, BarrierCompleteRequest, BarrierCompleteResponse }
         ,{ 0, wait_epoch_commit, WaitEpochCommitRequest, WaitEpochCommitResponse }
     }
 };
@@ -79,3 +81,35 @@ macro_rules! for_all_stream_rpc {
 
 impl StreamClient {
     for_all_stream_rpc!
{ rpc_client_method_impl } } + +pub type StreamingControlHandle = + UnboundedBidiStreamHandle; + +impl StreamClient { + pub async fn start_streaming_control(&self, prev_epoch: u64) -> Result { + let first_request = StreamingControlStreamRequest { + request: Some(streaming_control_stream_request::Request::Init( + InitRequest { prev_epoch }, + )), + }; + let mut client = self.0.to_owned(); + let (handle, first_rsp) = + UnboundedBidiStreamHandle::initialize(first_request, |rx| async move { + client + .streaming_control_stream(UnboundedReceiverStream::new(rx)) + .await + .map(|response| response.into_inner().map_err(RpcError::from)) + .map_err(RpcError::from) + }) + .await?; + match first_rsp { + StreamingControlStreamResponse { + response: Some(streaming_control_stream_response::Response::Init(InitResponse {})), + } => {} + other => { + return Err(anyhow!("expect InitResponse but get {:?}", other).into()); + } + }; + Ok(handle) + } +} diff --git a/src/storage/hummock_sdk/src/key.rs b/src/storage/hummock_sdk/src/key.rs index c4a4761c58cb7..a12783a19b415 100644 --- a/src/storage/hummock_sdk/src/key.rs +++ b/src/storage/hummock_sdk/src/key.rs @@ -402,6 +402,23 @@ pub fn prefixed_range_with_vnode>( map_table_key_range((start, end)) } +pub trait SetSlice + ?Sized> { + fn set(&mut self, value: &S); +} + +impl + ?Sized> SetSlice for Vec { + fn set(&mut self, value: &S) { + self.clear(); + self.extend_from_slice(value.as_ref()); + } +} + +impl SetSlice for Bytes { + fn set(&mut self, value: &Bytes) { + *self = value.clone() + } +} + pub trait CopyFromSlice { fn copy_from_slice(slice: &[u8]) -> Self; } @@ -484,6 +501,12 @@ impl EstimateSize for TableKey { } } +impl<'a> TableKey<&'a [u8]> { + pub fn copy_into>(&self) -> TableKey { + TableKey(T::copy_from_slice(self.as_ref())) + } +} + #[inline] pub fn map_table_key_range(range: (Bound, Bound)) -> TableKeyRange { (range.0.map(TableKey), range.1.map(TableKey)) @@ -624,21 +647,22 @@ impl UserKey> { buf.advance(len); UserKey::new(TableId::new(table_id), TableKey(data)) } +} - pub fn extend_from_other(&mut self, other: &UserKey<&[u8]>) { - self.table_id = other.table_id; - self.table_key.0.clear(); - self.table_key.0.extend_from_slice(other.table_key.as_ref()); - } - +impl> UserKey { /// Use this method to override an old `UserKey>` with a `UserKey<&[u8]>` to own the /// table key without reallocating a new `UserKey` object. - pub fn set(&mut self, other: UserKey<&[u8]>) { + pub fn set(&mut self, other: UserKey) + where + T: SetSlice, + F: AsRef<[u8]>, + { self.table_id = other.table_id; - self.table_key.clear(); - self.table_key.extend_from_slice(other.table_key.as_ref()); + self.table_key.0.set(&other.table_key.0); } +} +impl UserKey> { pub fn into_bytes(self) -> UserKey { UserKey { table_id: self.table_id, @@ -811,10 +835,14 @@ impl> FullKey { } } -impl FullKey> { +impl> FullKey { /// Use this method to override an old `FullKey>` with a `FullKey<&[u8]>` to own the /// table key without reallocating a new `FullKey` object. 
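Usage sketch for the new `SetSlice` bound introduced in key.rs above: it lets `UserKey::set` / `FullKey::set` overwrite a key in place whether the backing storage is an owned `Vec<u8>` (clear + extend, reusing the allocation) or a `Bytes` (a cheap refcounted clone). The import path in the comment is where the trait lands in this patch:

    use bytes::Bytes;
    // assumes: use risingwave_hummock_sdk::key::SetSlice;

    fn demo() {
        let mut vec_key: Vec<u8> = b"old-key".to_vec();
        vec_key.set(&b"new-key"[..]); // in place: clear + extend_from_slice
        assert_eq!(vec_key.as_slice(), b"new-key");

        let mut bytes_key = Bytes::from_static(b"old");
        bytes_key.set(&Bytes::from_static(b"new")); // clones the Bytes handle
        assert_eq!(bytes_key.as_ref(), b"new");
    }

This is what lets `FullKeyTracker` (used in the read path below) keep an owned copy of the latest observed key without reallocating on every observation.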
- pub fn set(&mut self, other: FullKey<&[u8]>) { + pub fn set(&mut self, other: FullKey) + where + T: SetSlice, + F: AsRef<[u8]>, + { self.user_key.set(other.user_key); self.epoch_with_gap = other.epoch_with_gap; } @@ -835,15 +863,6 @@ impl + Ord + Eq> PartialOrd for FullKey { } } -impl<'a, T> From> for UserKey -where - T: AsRef<[u8]> + CopyFromSlice, -{ - fn from(value: UserKey<&'a [u8]>) -> Self { - value.copy_into() - } -} - #[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)] pub struct PointRange> { // When comparing `PointRange`, we first compare `left_user_key`, then @@ -977,20 +996,20 @@ impl + Ord + Eq, const SKIP_DEDUP: bool> FullKeyTracker(&mut self, key: FullKey) -> Option> + /// - If the provided `key` contains a new user key, return true. + /// - Otherwise: return false + pub fn observe(&mut self, key: FullKey) -> bool where - UserKey: Into>, + T: SetSlice, F: AsRef<[u8]>, { self.observe_multi_version(key.user_key, once(key.epoch_with_gap)) @@ -1001,9 +1020,9 @@ impl + Ord + Eq, const SKIP_DEDUP: bool> FullKeyTracker, mut epochs: impl Iterator, - ) -> Option> + ) -> bool where - UserKey: Into>, + T: SetSlice, F: AsRef<[u8]>, { let max_epoch_with_gap = epochs.next().expect("non-empty"); @@ -1033,16 +1052,11 @@ impl + Ord + Eq, const SKIP_DEDUP: bool> FullKeyTracker { if max_epoch_with_gap > self.last_observed_epoch_with_gap @@ -1055,7 +1069,7 @@ impl + Ord + Eq, const SKIP_DEDUP: bool> FullKeyTracker { // User key should be monotonically increasing diff --git a/src/storage/hummock_test/benches/bench_hummock_iter.rs b/src/storage/hummock_test/benches/bench_hummock_iter.rs index e6638e8bc95b4..059853433a2fc 100644 --- a/src/storage/hummock_test/benches/bench_hummock_iter.rs +++ b/src/storage/hummock_test/benches/bench_hummock_iter.rs @@ -18,7 +18,7 @@ use std::sync::Arc; use bytes::Bytes; use criterion::{criterion_group, criterion_main, Criterion}; use foyer::memory::CacheContext; -use futures::{pin_mut, TryStreamExt}; +use futures::pin_mut; use risingwave_common::util::epoch::test_epoch; use risingwave_hummock_sdk::key::TableKey; use risingwave_hummock_sdk::HummockEpoch; diff --git a/src/storage/hummock_test/src/bin/replay/replay_impl.rs b/src/storage/hummock_test/src/bin/replay/replay_impl.rs index c3dedd6dbae46..43899fa7e256c 100644 --- a/src/storage/hummock_test/src/bin/replay/replay_impl.rs +++ b/src/storage/hummock_test/src/bin/replay/replay_impl.rs @@ -33,28 +33,28 @@ use risingwave_pb::meta::{SubscribeResponse, SubscribeType}; use risingwave_storage::hummock::store::LocalHummockStorage; use risingwave_storage::hummock::HummockStorage; use risingwave_storage::store::{ - LocalStateStore, StateStoreIterItemStream, StateStoreRead, SyncResult, + to_owned_item, LocalStateStore, StateStoreIterExt, StateStoreRead, SyncResult, }; -use risingwave_storage::{StateStore, StateStoreReadIterStream}; +use risingwave_storage::{StateStore, StateStoreIter, StateStoreReadIter}; use tokio::sync::mpsc::{unbounded_channel, UnboundedReceiver}; pub(crate) struct GlobalReplayIter where - S: StateStoreReadIterStream, + S: StateStoreReadIter, { inner: S, } impl GlobalReplayIter where - S: StateStoreReadIterStream, + S: StateStoreReadIter, { pub(crate) fn new(inner: S) -> Self { Self { inner } } pub(crate) fn into_stream(self) -> impl Stream> { - self.inner.map(|item_res| { + self.inner.into_stream(to_owned_item).map(|item_res| { item_res .map(|(key, value)| (key.user_key.table_key.0.into(), value.into())) .map_err(|_| TraceError::IterFailed("iter failed to retrieve item".to_string())) @@ -67,8 
+67,9 @@ pub(crate) struct LocalReplayIter { } impl LocalReplayIter { - pub(crate) async fn new(stream: impl StateStoreIterItemStream) -> Self { - let inner = stream + pub(crate) async fn new(iter: impl StateStoreIter) -> Self { + let inner = iter + .into_stream(to_owned_item) .map_ok(|value| (value.0.user_key.table_key.0.into(), value.1.into())) .try_collect::>() .await @@ -115,7 +116,6 @@ impl ReplayRead for GlobalReplayImpl { .iter(key_range, epoch, read_options.into()) .await .unwrap(); - let iter = iter.boxed(); let stream = GlobalReplayIter::new(iter).into_stream().boxed(); Ok(stream) } @@ -241,7 +241,6 @@ impl LocalReplayRead for LocalReplayImpl { .await .unwrap(); - let iter = iter.boxed(); let stream = LocalReplayIter::new(iter).await.into_stream().boxed(); Ok(stream) } diff --git a/src/storage/hummock_test/src/compactor_tests.rs b/src/storage/hummock_test/src/compactor_tests.rs index e33f0250190c6..3718b06f00fe5 100644 --- a/src/storage/hummock_test/src/compactor_tests.rs +++ b/src/storage/hummock_test/src/compactor_tests.rs @@ -1457,7 +1457,8 @@ pub(crate) mod tests { fast_iter.key().user_key.table_id.table_id, ); assert_eq!(normal_iter.value(), fast_iter.value()); - let key_ref = fast_iter.key().user_key.as_ref(); + let key = fast_iter.key(); + let key_ref = key.user_key.as_ref(); assert!(normal_tables.iter().any(|table| { table.may_match_hash(&(Bound::Included(key_ref), Bound::Included(key_ref)), hash) })); diff --git a/src/storage/hummock_test/src/hummock_storage_tests.rs b/src/storage/hummock_test/src/hummock_storage_tests.rs index 31af1d0b8fd9a..e7b03fa61c504 100644 --- a/src/storage/hummock_test/src/hummock_storage_tests.rs +++ b/src/storage/hummock_test/src/hummock_storage_tests.rs @@ -258,7 +258,8 @@ async fn test_storage_basic() { }, ) .await - .unwrap(); + .unwrap() + .into_stream(to_owned_item); futures::pin_mut!(iter); assert_eq!( Some(( @@ -335,7 +336,8 @@ async fn test_storage_basic() { }, ) .await - .unwrap(); + .unwrap() + .into_stream(to_owned_item); futures::pin_mut!(iter); assert_eq!( Some(( @@ -388,7 +390,8 @@ async fn test_storage_basic() { }, ) .await - .unwrap(); + .unwrap() + .into_stream(to_owned_item); futures::pin_mut!(iter); assert_eq!( Some(( @@ -627,7 +630,8 @@ async fn test_state_store_sync() { }, ) .await - .unwrap(); + .unwrap() + .into_stream(to_owned_item); futures::pin_mut!(iter); let kv_map_batch_1 = [ @@ -680,7 +684,8 @@ async fn test_state_store_sync() { }, ) .await - .unwrap(); + .unwrap() + .into_stream(to_owned_item); futures::pin_mut!(iter); @@ -1039,7 +1044,8 @@ async fn test_iter_with_min_epoch() { }, ) .await - .unwrap(); + .unwrap() + .into_stream(to_owned_item); futures::pin_mut!(iter); @@ -1064,7 +1070,8 @@ async fn test_iter_with_min_epoch() { }, ) .await - .unwrap(); + .unwrap() + .into_stream(to_owned_item); let result: Vec<_> = iter.try_collect().await.unwrap(); assert_eq!(20, result.len()); @@ -1088,7 +1095,8 @@ async fn test_iter_with_min_epoch() { }, ) .await - .unwrap(); + .unwrap() + .into_stream(to_owned_item); futures::pin_mut!(iter); @@ -1131,7 +1139,8 @@ async fn test_iter_with_min_epoch() { }, ) .await - .unwrap(); + .unwrap() + .into_stream(to_owned_item); futures::pin_mut!(iter); @@ -1156,7 +1165,8 @@ async fn test_iter_with_min_epoch() { }, ) .await - .unwrap(); + .unwrap() + .into_stream(to_owned_item); futures::pin_mut!(iter); @@ -1182,7 +1192,8 @@ async fn test_iter_with_min_epoch() { }, ) .await - .unwrap(); + .unwrap() + .into_stream(to_owned_item); futures::pin_mut!(iter); @@ -1294,7 +1305,8 @@ async fn 
test_hummock_version_reader() { read_snapshot, ) .await - .unwrap(); + .unwrap() + .into_stream(to_owned_item); let result: Vec<_> = iter.try_collect().await.unwrap(); assert_eq!(10, result.len()); @@ -1328,7 +1340,8 @@ async fn test_hummock_version_reader() { read_snapshot, ) .await - .unwrap(); + .unwrap() + .into_stream(to_owned_item); let result: Vec<_> = iter.try_collect().await.unwrap(); assert_eq!(20, result.len()); @@ -1363,7 +1376,8 @@ async fn test_hummock_version_reader() { read_snapshot, ) .await - .unwrap(); + .unwrap() + .into_stream(to_owned_item); let result: Vec<_> = iter.try_collect().await.unwrap(); assert_eq!(10, result.len()); @@ -1422,7 +1436,8 @@ async fn test_hummock_version_reader() { read_snapshot, ) .await - .unwrap(); + .unwrap() + .into_stream(to_owned_item); let result: Vec<_> = iter.try_collect().await.unwrap(); assert_eq!(10, result.len()); @@ -1465,7 +1480,8 @@ async fn test_hummock_version_reader() { read_snapshot, ) .await - .unwrap(); + .unwrap() + .into_stream(to_owned_item); let result: Vec<_> = iter.try_collect().await.unwrap(); assert_eq!(20, result.len()); @@ -1500,7 +1516,8 @@ async fn test_hummock_version_reader() { read_snapshot, ) .await - .unwrap(); + .unwrap() + .into_stream(to_owned_item); let result: Vec<_> = iter.try_collect().await.unwrap(); assert_eq!(10, result.len()); @@ -1534,7 +1551,8 @@ async fn test_hummock_version_reader() { read_snapshot, ) .await - .unwrap(); + .unwrap() + .into_stream(to_owned_item); let result: Vec<_> = iter.try_collect().await.unwrap(); assert_eq!(30, result.len()); @@ -1571,7 +1589,8 @@ async fn test_hummock_version_reader() { read_snapshot, ) .await - .unwrap(); + .unwrap() + .into_stream(to_owned_item); let result: Vec<_> = iter.try_collect().await.unwrap(); assert_eq!(8, result.len()); @@ -1602,7 +1621,8 @@ async fn test_hummock_version_reader() { read_snapshot, ) .await - .unwrap(); + .unwrap() + .into_stream(to_owned_item); let result: Vec<_> = iter.try_collect().await.unwrap(); assert_eq!(18, result.len()); @@ -1948,6 +1968,7 @@ async fn test_table_watermark() { ) .await .unwrap() + .into_stream(to_owned_item) .map_ok(|(full_key, value)| (full_key.user_key, value)) .try_collect::>() .await @@ -2017,6 +2038,7 @@ async fn test_table_watermark() { ) .await .unwrap() + .into_stream(to_owned_item) .map_ok(|(full_key, value)| (full_key.user_key, value)) .try_collect::>() .await @@ -2048,6 +2070,7 @@ async fn test_table_watermark() { ) .await .unwrap() + .into_stream(to_owned_item) .try_collect::>() .await .unwrap(); @@ -2116,6 +2139,7 @@ async fn test_table_watermark() { ) .await .unwrap() + .into_stream(to_owned_item) .map_ok(|(full_key, value)| (full_key.user_key, value)) .try_collect::>() .await @@ -2147,6 +2171,7 @@ async fn test_table_watermark() { ) .await .unwrap() + .into_stream(to_owned_item) .try_collect::>() .await .unwrap(); @@ -2218,6 +2243,7 @@ async fn test_table_watermark() { ) .await .unwrap() + .into_stream(to_owned_item) .map_ok(|(full_key, value)| (full_key.user_key, value)) .try_collect::>() .await @@ -2250,6 +2276,7 @@ async fn test_table_watermark() { ) .await .unwrap() + .into_stream(to_owned_item) .try_collect::>() .await .unwrap(); diff --git a/src/storage/hummock_test/src/snapshot_tests.rs b/src/storage/hummock_test/src/snapshot_tests.rs index 495ad7fc49cae..c019f12a0268b 100644 --- a/src/storage/hummock_test/src/snapshot_tests.rs +++ b/src/storage/hummock_test/src/snapshot_tests.rs @@ -16,7 +16,6 @@ use std::sync::Arc; use bytes::Bytes; use foyer::memory::CacheContext; -use 
futures::TryStreamExt; use risingwave_common::hash::VirtualNode; use risingwave_common::util::epoch::{test_epoch, EpochExt}; use risingwave_hummock_sdk::key::prefixed_range_with_vnode; @@ -38,6 +37,8 @@ use crate::test_utils::{ macro_rules! assert_count_range_scan { ($storage:expr, $vnode:expr, $range:expr, $expect_count:expr, $epoch:expr) => {{ use std::ops::RangeBounds; + + use risingwave_storage::StateStoreIter; let range = $range; let bounds: (Bound, Bound) = ( range.start_bound().map(|x: &Bytes| x.clone()), @@ -45,7 +46,7 @@ macro_rules! assert_count_range_scan { ); let vnode = $vnode; let table_key_range = prefixed_range_with_vnode(bounds, vnode); - let it = $storage + let mut it = $storage .iter( table_key_range, $epoch, @@ -57,7 +58,6 @@ macro_rules! assert_count_range_scan { ) .await .unwrap(); - futures::pin_mut!(it); let mut count = 0; loop { match it.try_next().await.unwrap() { diff --git a/src/storage/hummock_test/src/state_store_tests.rs b/src/storage/hummock_test/src/state_store_tests.rs index 22cc09233dfae..ff6385e35ab1e 100644 --- a/src/storage/hummock_test/src/state_store_tests.rs +++ b/src/storage/hummock_test/src/state_store_tests.rs @@ -18,7 +18,7 @@ use std::sync::Arc; use bytes::Bytes; use expect_test::expect; use foyer::memory::CacheContext; -use futures::{pin_mut, StreamExt, TryStreamExt}; +use futures::{pin_mut, StreamExt}; use risingwave_common::buffer::Bitmap; use risingwave_common::catalog::{TableId, TableOption}; use risingwave_common::hash::VirtualNode; @@ -1442,6 +1442,7 @@ async fn test_replicated_local_hummock_storage() { ) .await .unwrap() + .into_stream(to_owned_item) .collect::>() .await; @@ -1509,6 +1510,7 @@ async fn test_replicated_local_hummock_storage() { ) .await .unwrap() + .into_stream(to_owned_item) .collect::>() .await; @@ -1544,6 +1546,7 @@ async fn test_replicated_local_hummock_storage() { ) .await .unwrap() + .into_stream(to_owned_item) .collect::>() .await; diff --git a/src/storage/src/hummock/compactor/compactor_runner.rs b/src/storage/src/hummock/compactor/compactor_runner.rs index 9566d57cc691e..38c3cea98136b 100644 --- a/src/storage/src/hummock/compactor/compactor_runner.rs +++ b/src/storage/src/hummock/compactor/compactor_runner.rs @@ -778,7 +778,7 @@ where let mut iter_key = iter.key(); compaction_statistics.iter_total_key_counts += 1; - let mut is_new_user_key = full_key_tracker.observe(iter.key()).is_some(); + let mut is_new_user_key = full_key_tracker.observe(iter.key()); let mut drop = false; // CRITICAL WARN: Because of memtable spill, there may be several versions of the same user-key share the same `pure_epoch`. Do not change this code unless necessary. diff --git a/src/storage/src/hummock/compactor/shared_buffer_compact.rs b/src/storage/src/hummock/compactor/shared_buffer_compact.rs index 075ee32ba9c46..dea4f017f4d48 100644 --- a/src/storage/src/hummock/compactor/shared_buffer_compact.rs +++ b/src/storage/src/hummock/compactor/shared_buffer_compact.rs @@ -336,16 +336,13 @@ pub async fn merge_imms_in_memory( table_id, table_key: key_entry.key.clone(), }; - if full_key_tracker - .observe_multi_version( - user_key, - key_entry - .new_values - .iter() - .map(|(epoch_with_gap, _)| *epoch_with_gap), - ) - .is_some() - { + if full_key_tracker.observe_multi_version( + user_key, + key_entry + .new_values + .iter() + .map(|(epoch_with_gap, _)| *epoch_with_gap), + ) { let last_entry = merged_entries.last_mut().expect("non-empty"); if last_entry.value_offset == values.len() { warn!(key = ?last_entry.key, "key has no value in imm compact. 
skipped"); @@ -423,7 +420,7 @@ fn generate_splits( if existing_table_ids.len() > 1 { if parallelism > 1 && compact_data_size > sstable_size { let mut last_buffer_size = 0; - let mut last_user_key = UserKey::default(); + let mut last_user_key: UserKey> = UserKey::default(); for (data_size, user_key) in size_and_start_user_keys { if last_buffer_size >= sub_compaction_data_size && last_user_key.as_ref() != user_key diff --git a/src/storage/src/hummock/iterator/forward_user.rs b/src/storage/src/hummock/iterator/forward_user.rs index 079ed59c5a8da..c3f94695d72c7 100644 --- a/src/storage/src/hummock/iterator/forward_user.rs +++ b/src/storage/src/hummock/iterator/forward_user.rs @@ -14,7 +14,7 @@ use std::ops::Bound::*; -use bytes::Bytes; +use risingwave_common::must_match; use risingwave_common::util::epoch::MAX_SPILL_TIMES; use risingwave_hummock_sdk::key::{FullKey, FullKeyTracker, UserKey, UserKeyRange}; use risingwave_hummock_sdk::{EpochWithGap, HummockEpoch}; @@ -32,10 +32,7 @@ pub struct UserIterator> { iterator: I, // Track the last seen full key - full_key_tracker: FullKeyTracker, - - /// Last user value - latest_val: Bytes, + full_key_tracker: FullKeyTracker, true>, /// Start and end bounds of user key. key_range: UserKeyRange, @@ -71,7 +68,6 @@ impl> UserIterator { Self { iterator, key_range, - latest_val: Bytes::new(), read_epoch, min_epoch, stats: StoreLocalStatistic::default(), @@ -119,17 +115,17 @@ impl> UserIterator { /// `rewind` or `seek` methods are called. /// /// Note: before call the function you need to ensure that the iterator is valid. - pub fn key(&self) -> &FullKey { + pub fn key(&self) -> FullKey<&[u8]> { assert!(self.is_valid()); - &self.full_key_tracker.latest_full_key + self.full_key_tracker.latest_full_key.to_ref() } /// The returned value is in the form of user value. /// /// Note: before call the function you need to ensure that the iterator is valid. - pub fn value(&self) -> &Bytes { + pub fn value(&self) -> &[u8] { assert!(self.is_valid()); - &self.latest_val + must_match!(self.iterator.value(), HummockValue::Put(val) => val) } /// Resets the iterating position to the beginning. 
@@ -245,7 +241,7 @@ impl> UserIterator { } // Skip older version entry for the same user key - if self.full_key_tracker.observe(full_key).is_none() { + if !self.full_key_tracker.observe(full_key) { self.stats.skip_multi_version_key_count += 1; self.iterator.next().await?; continue; @@ -265,12 +261,11 @@ impl> UserIterator { // Handle delete operation match self.iterator.value() { - HummockValue::Put(val) => { + HummockValue::Put(_val) => { self.delete_range_iter.next_until(full_key.user_key).await?; if self.delete_range_iter.current_epoch() >= epoch { self.stats.skip_delete_key_count += 1; } else { - self.latest_val = Bytes::copy_from_slice(val); self.stats.processed_key_count += 1; self.is_current_pos_valid = true; return Ok(()); @@ -325,6 +320,7 @@ mod tests { use std::ops::Bound::*; use std::sync::Arc; + use bytes::Bytes; use risingwave_common::util::epoch::test_epoch; use super::*; @@ -385,7 +381,7 @@ mod tests { while ui.is_valid() { let key = ui.key(); let val = ui.value(); - assert_eq!(key, &iterator_test_bytes_key_of(i)); + assert_eq!(key, iterator_test_bytes_key_of(i).to_ref()); assert_eq!(val, iterator_test_value_of(i).as_slice()); i += 1; ui.next().await.unwrap(); @@ -447,7 +443,7 @@ mod tests { let k = ui.key(); let v = ui.value(); assert_eq!(v, iterator_test_value_of(TEST_KEYS_COUNT + 5).as_slice()); - assert_eq!(k, &iterator_test_bytes_key_of(TEST_KEYS_COUNT + 5)); + assert_eq!(k, iterator_test_bytes_key_of(TEST_KEYS_COUNT + 5).to_ref()); ui.seek(iterator_test_bytes_user_key_of(2 * TEST_KEYS_COUNT + 5).as_ref()) .await .unwrap(); @@ -457,7 +453,10 @@ mod tests { v, iterator_test_value_of(2 * TEST_KEYS_COUNT + 5).as_slice() ); - assert_eq!(k, &iterator_test_bytes_key_of(2 * TEST_KEYS_COUNT + 5)); + assert_eq!( + k, + iterator_test_bytes_key_of(2 * TEST_KEYS_COUNT + 5).to_ref() + ); // left edge case ui.seek(iterator_test_bytes_user_key_of(0).as_ref()) @@ -466,7 +465,7 @@ mod tests { let k = ui.key(); let v = ui.value(); assert_eq!(v, iterator_test_value_of(0).as_slice()); - assert_eq!(k, &iterator_test_bytes_key_of(0)); + assert_eq!(k, iterator_test_bytes_key_of(0).to_ref()); } #[tokio::test] @@ -501,7 +500,7 @@ mod tests { // verify let k = ui.key(); let v = ui.value(); - assert_eq!(k, &iterator_test_bytes_key_of_epoch(2, 400)); + assert_eq!(k, iterator_test_bytes_key_of_epoch(2, 400).to_ref()); assert_eq!(v, &Bytes::from(iterator_test_value_of(2))); // only one valid kv pair @@ -559,11 +558,11 @@ mod tests { // ----- basic iterate ----- ui.rewind().await.unwrap(); - assert_eq!(ui.key(), &iterator_test_bytes_key_of_epoch(2, 300)); + assert_eq!(ui.key(), iterator_test_bytes_key_of_epoch(2, 300).to_ref()); ui.next().await.unwrap(); - assert_eq!(ui.key(), &iterator_test_bytes_key_of_epoch(3, 100)); + assert_eq!(ui.key(), iterator_test_bytes_key_of_epoch(3, 100).to_ref()); ui.next().await.unwrap(); - assert_eq!(ui.key(), &iterator_test_bytes_key_of_epoch(6, 100)); + assert_eq!(ui.key(), iterator_test_bytes_key_of_epoch(6, 100).to_ref()); ui.next().await.unwrap(); assert!(!ui.is_valid()); @@ -571,11 +570,11 @@ mod tests { ui.seek(iterator_test_bytes_user_key_of(1).as_ref()) .await .unwrap(); - assert_eq!(ui.key(), &iterator_test_bytes_key_of_epoch(2, 300)); + assert_eq!(ui.key(), iterator_test_bytes_key_of_epoch(2, 300).to_ref()); ui.next().await.unwrap(); - assert_eq!(ui.key(), &iterator_test_bytes_key_of_epoch(3, 100)); + assert_eq!(ui.key(), iterator_test_bytes_key_of_epoch(3, 100).to_ref()); ui.next().await.unwrap(); - assert_eq!(ui.key(), &iterator_test_bytes_key_of_epoch(6, 
100)); + assert_eq!(ui.key(), iterator_test_bytes_key_of_epoch(6, 100).to_ref()); ui.next().await.unwrap(); assert!(!ui.is_valid()); @@ -583,11 +582,11 @@ mod tests { ui.seek(iterator_test_bytes_user_key_of(2).as_ref()) .await .unwrap(); - assert_eq!(ui.key(), &iterator_test_bytes_key_of_epoch(2, 300)); + assert_eq!(ui.key(), iterator_test_bytes_key_of_epoch(2, 300).to_ref()); ui.next().await.unwrap(); - assert_eq!(ui.key(), &iterator_test_bytes_key_of_epoch(3, 100)); + assert_eq!(ui.key(), iterator_test_bytes_key_of_epoch(3, 100).to_ref()); ui.next().await.unwrap(); - assert_eq!(ui.key(), &iterator_test_bytes_key_of_epoch(6, 100)); + assert_eq!(ui.key(), iterator_test_bytes_key_of_epoch(6, 100).to_ref()); ui.next().await.unwrap(); assert!(!ui.is_valid()); @@ -637,11 +636,11 @@ mod tests { // ----- basic iterate ----- ui.rewind().await.unwrap(); - assert_eq!(ui.key(), &iterator_test_bytes_key_of_epoch(2, 300)); + assert_eq!(ui.key(), iterator_test_bytes_key_of_epoch(2, 300).to_ref()); ui.next().await.unwrap(); - assert_eq!(ui.key(), &iterator_test_bytes_key_of_epoch(3, 100)); + assert_eq!(ui.key(), iterator_test_bytes_key_of_epoch(3, 100).to_ref()); ui.next().await.unwrap(); - assert_eq!(ui.key(), &iterator_test_bytes_key_of_epoch(6, 100)); + assert_eq!(ui.key(), iterator_test_bytes_key_of_epoch(6, 100).to_ref()); ui.next().await.unwrap(); assert!(!ui.is_valid()); @@ -649,11 +648,11 @@ mod tests { ui.seek(iterator_test_bytes_user_key_of(1).as_ref()) .await .unwrap(); - assert_eq!(ui.key(), &iterator_test_bytes_key_of_epoch(2, 300)); + assert_eq!(ui.key(), iterator_test_bytes_key_of_epoch(2, 300).to_ref()); ui.next().await.unwrap(); - assert_eq!(ui.key(), &iterator_test_bytes_key_of_epoch(3, 100)); + assert_eq!(ui.key(), iterator_test_bytes_key_of_epoch(3, 100).to_ref()); ui.next().await.unwrap(); - assert_eq!(ui.key(), &iterator_test_bytes_key_of_epoch(6, 100)); + assert_eq!(ui.key(), iterator_test_bytes_key_of_epoch(6, 100).to_ref()); ui.next().await.unwrap(); assert!(!ui.is_valid()); @@ -661,11 +660,11 @@ mod tests { ui.seek(iterator_test_bytes_user_key_of(2).as_ref()) .await .unwrap(); - assert_eq!(ui.key(), &iterator_test_bytes_key_of_epoch(2, 300)); + assert_eq!(ui.key(), iterator_test_bytes_key_of_epoch(2, 300).to_ref()); ui.next().await.unwrap(); - assert_eq!(ui.key(), &iterator_test_bytes_key_of_epoch(3, 100)); + assert_eq!(ui.key(), iterator_test_bytes_key_of_epoch(3, 100).to_ref()); ui.next().await.unwrap(); - assert_eq!(ui.key(), &iterator_test_bytes_key_of_epoch(6, 100)); + assert_eq!(ui.key(), iterator_test_bytes_key_of_epoch(6, 100).to_ref()); ui.next().await.unwrap(); assert!(!ui.is_valid()); @@ -698,13 +697,13 @@ mod tests { // ----- basic iterate ----- ui.rewind().await.unwrap(); - assert_eq!(ui.key(), &iterator_test_bytes_key_of_epoch(1, 200)); + assert_eq!(ui.key(), iterator_test_bytes_key_of_epoch(1, 200).to_ref()); ui.next().await.unwrap(); - assert_eq!(ui.key(), &iterator_test_bytes_key_of_epoch(2, 300)); + assert_eq!(ui.key(), iterator_test_bytes_key_of_epoch(2, 300).to_ref()); ui.next().await.unwrap(); - assert_eq!(ui.key(), &iterator_test_bytes_key_of_epoch(3, 100)); + assert_eq!(ui.key(), iterator_test_bytes_key_of_epoch(3, 100).to_ref()); ui.next().await.unwrap(); - assert_eq!(ui.key(), &iterator_test_bytes_key_of_epoch(6, 100)); + assert_eq!(ui.key(), iterator_test_bytes_key_of_epoch(6, 100).to_ref()); ui.next().await.unwrap(); assert!(!ui.is_valid()); @@ -712,13 +711,13 @@ mod tests { ui.seek(iterator_test_bytes_user_key_of(0).as_ref()) .await .unwrap(); - 
assert_eq!(ui.key(), &iterator_test_bytes_key_of_epoch(1, 200)); + assert_eq!(ui.key(), iterator_test_bytes_key_of_epoch(1, 200).to_ref()); ui.next().await.unwrap(); - assert_eq!(ui.key(), &iterator_test_bytes_key_of_epoch(2, 300)); + assert_eq!(ui.key(), iterator_test_bytes_key_of_epoch(2, 300).to_ref()); ui.next().await.unwrap(); - assert_eq!(ui.key(), &iterator_test_bytes_key_of_epoch(3, 100)); + assert_eq!(ui.key(), iterator_test_bytes_key_of_epoch(3, 100).to_ref()); ui.next().await.unwrap(); - assert_eq!(ui.key(), &iterator_test_bytes_key_of_epoch(6, 100)); + assert_eq!(ui.key(), iterator_test_bytes_key_of_epoch(6, 100).to_ref()); ui.next().await.unwrap(); assert!(!ui.is_valid()); @@ -726,11 +725,11 @@ mod tests { ui.seek(iterator_test_bytes_user_key_of(2).as_ref()) .await .unwrap(); - assert_eq!(ui.key(), &iterator_test_bytes_key_of_epoch(2, 300)); + assert_eq!(ui.key(), iterator_test_bytes_key_of_epoch(2, 300).to_ref()); ui.next().await.unwrap(); - assert_eq!(ui.key(), &iterator_test_bytes_key_of_epoch(3, 100)); + assert_eq!(ui.key(), iterator_test_bytes_key_of_epoch(3, 100).to_ref()); ui.next().await.unwrap(); - assert_eq!(ui.key(), &iterator_test_bytes_key_of_epoch(6, 100)); + assert_eq!(ui.key(), iterator_test_bytes_key_of_epoch(6, 100).to_ref()); ui.next().await.unwrap(); assert!(!ui.is_valid()); @@ -762,13 +761,13 @@ mod tests { // ----- basic iterate ----- ui.rewind().await.unwrap(); - assert_eq!(ui.key(), &iterator_test_bytes_key_of_epoch(2, 300)); + assert_eq!(ui.key(), iterator_test_bytes_key_of_epoch(2, 300).to_ref()); ui.next().await.unwrap(); - assert_eq!(ui.key(), &iterator_test_bytes_key_of_epoch(3, 100)); + assert_eq!(ui.key(), iterator_test_bytes_key_of_epoch(3, 100).to_ref()); ui.next().await.unwrap(); - assert_eq!(ui.key(), &iterator_test_bytes_key_of_epoch(6, 100)); + assert_eq!(ui.key(), iterator_test_bytes_key_of_epoch(6, 100).to_ref()); ui.next().await.unwrap(); - assert_eq!(ui.key(), &iterator_test_bytes_key_of_epoch(8, 100)); + assert_eq!(ui.key(), iterator_test_bytes_key_of_epoch(8, 100).to_ref()); ui.next().await.unwrap(); assert!(!ui.is_valid()); @@ -776,13 +775,13 @@ mod tests { ui.seek(iterator_test_bytes_user_key_of(1).as_ref()) .await .unwrap(); - assert_eq!(ui.key(), &iterator_test_bytes_key_of_epoch(2, 300)); + assert_eq!(ui.key(), iterator_test_bytes_key_of_epoch(2, 300).to_ref()); ui.next().await.unwrap(); - assert_eq!(ui.key(), &iterator_test_bytes_key_of_epoch(3, 100)); + assert_eq!(ui.key(), iterator_test_bytes_key_of_epoch(3, 100).to_ref()); ui.next().await.unwrap(); - assert_eq!(ui.key(), &iterator_test_bytes_key_of_epoch(6, 100)); + assert_eq!(ui.key(), iterator_test_bytes_key_of_epoch(6, 100).to_ref()); ui.next().await.unwrap(); - assert_eq!(ui.key(), &iterator_test_bytes_key_of_epoch(8, 100)); + assert_eq!(ui.key(), iterator_test_bytes_key_of_epoch(8, 100).to_ref()); ui.next().await.unwrap(); assert!(!ui.is_valid()); @@ -790,13 +789,13 @@ mod tests { ui.seek(iterator_test_bytes_user_key_of(2).as_ref()) .await .unwrap(); - assert_eq!(ui.key(), &iterator_test_bytes_key_of_epoch(2, 300)); + assert_eq!(ui.key(), iterator_test_bytes_key_of_epoch(2, 300).to_ref()); ui.next().await.unwrap(); - assert_eq!(ui.key(), &iterator_test_bytes_key_of_epoch(3, 100)); + assert_eq!(ui.key(), iterator_test_bytes_key_of_epoch(3, 100).to_ref()); ui.next().await.unwrap(); - assert_eq!(ui.key(), &iterator_test_bytes_key_of_epoch(6, 100)); + assert_eq!(ui.key(), iterator_test_bytes_key_of_epoch(6, 100).to_ref()); ui.next().await.unwrap(); - assert_eq!(ui.key(), 
&iterator_test_bytes_key_of_epoch(8, 100)); + assert_eq!(ui.key(), iterator_test_bytes_key_of_epoch(8, 100).to_ref()); ui.next().await.unwrap(); assert!(!ui.is_valid()); @@ -804,7 +803,7 @@ mod tests { ui.seek(iterator_test_bytes_user_key_of(8).as_ref()) .await .unwrap(); - assert_eq!(ui.key(), &iterator_test_bytes_key_of_epoch(8, 100)); + assert_eq!(ui.key(), iterator_test_bytes_key_of_epoch(8, 100).to_ref()); ui.next().await.unwrap(); assert!(!ui.is_valid()); @@ -886,9 +885,15 @@ mod tests { // ----- basic iterate ----- ui.rewind().await.unwrap(); assert!(ui.is_valid()); - assert_eq!(ui.key().user_key, iterator_test_bytes_user_key_of(0)); + assert_eq!( + ui.key().user_key, + iterator_test_bytes_user_key_of(0).as_ref() + ); ui.next().await.unwrap(); - assert_eq!(ui.key().user_key, iterator_test_bytes_user_key_of(8)); + assert_eq!( + ui.key().user_key, + iterator_test_bytes_user_key_of(8).as_ref() + ); ui.next().await.unwrap(); assert!(!ui.is_valid()); @@ -896,7 +901,10 @@ mod tests { ui.seek(iterator_test_bytes_user_key_of(1).as_ref()) .await .unwrap(); - assert_eq!(ui.key().user_key, iterator_test_bytes_user_key_of(8)); + assert_eq!( + ui.key().user_key, + iterator_test_bytes_user_key_of(8).as_ref() + ); ui.next().await.unwrap(); assert!(!ui.is_valid()); @@ -919,9 +927,15 @@ mod tests { ); ui.rewind().await.unwrap(); assert!(ui.is_valid()); - assert_eq!(ui.key().user_key, iterator_test_bytes_user_key_of(2)); + assert_eq!( + ui.key().user_key, + iterator_test_bytes_user_key_of(2).as_ref() + ); ui.next().await.unwrap(); - assert_eq!(ui.key().user_key, iterator_test_bytes_user_key_of(8)); + assert_eq!( + ui.key().user_key, + iterator_test_bytes_user_key_of(8).as_ref() + ); ui.next().await.unwrap(); assert!(!ui.is_valid()); } diff --git a/src/storage/src/hummock/shared_buffer/shared_buffer_batch.rs b/src/storage/src/hummock/shared_buffer/shared_buffer_batch.rs index 94d1bad371a68..5857c5d2f8bd2 100644 --- a/src/storage/src/hummock/shared_buffer/shared_buffer_batch.rs +++ b/src/storage/src/hummock/shared_buffer/shared_buffer_batch.rs @@ -23,7 +23,6 @@ use std::sync::atomic::Ordering::Relaxed; use std::sync::{Arc, LazyLock}; use bytes::Bytes; -use itertools::Itertools; use risingwave_common::catalog::TableId; use risingwave_common::hash::VirtualNode; use risingwave_hummock_sdk::key::{FullKey, PointRange, TableKey, TableKeyRange, UserKey}; @@ -674,6 +673,7 @@ impl SharedBufferDeleteRangeIterator { table_id: TableId, delete_ranges: Vec<(Bound, Bound)>, ) -> Self { + use itertools::Itertools; let point_range_pairs = delete_ranges .into_iter() .map(|(left_bound, right_bound)| { diff --git a/src/storage/src/hummock/store/hummock_storage.rs b/src/storage/src/hummock/store/hummock_storage.rs index 88c123ca5d5bc..46c6ba993e9be 100644 --- a/src/storage/src/hummock/store/hummock_storage.rs +++ b/src/storage/src/hummock/store/hummock_storage.rs @@ -260,7 +260,7 @@ impl HummockStorage { key_range: TableKeyRange, epoch: u64, read_options: ReadOptions, - ) -> StorageResult> { + ) -> StorageResult { let (key_range, read_version_tuple) = if read_options.read_version_from_backup { self.build_read_version_tuple_from_backup(epoch, read_options.table_id, key_range) .await? 
@@ -442,7 +442,7 @@ impl HummockStorage { } impl StateStoreRead for HummockStorage { - type IterStream = StreamTypeOfIter; + type Iter = HummockStorageIterator; fn get( &self, @@ -458,7 +458,7 @@ impl StateStoreRead for HummockStorage { key_range: TableKeyRange, epoch: u64, read_options: ReadOptions, - ) -> impl Future> + '_ { + ) -> impl Future> + '_ { let (l_vnode_inclusive, r_vnode_exclusive) = vnode_range(&key_range); assert_eq!( r_vnode_exclusive - l_vnode_inclusive, diff --git a/src/storage/src/hummock/store/local_hummock_storage.rs b/src/storage/src/hummock/store/local_hummock_storage.rs index 604a242c66754..38a653cfda375 100644 --- a/src/storage/src/hummock/store/local_hummock_storage.rs +++ b/src/storage/src/hummock/store/local_hummock_storage.rs @@ -135,7 +135,7 @@ impl LocalHummockStorage { table_key_range: TableKeyRange, epoch: u64, read_options: ReadOptions, - ) -> StorageResult> { + ) -> StorageResult { let (table_key_range, read_snapshot) = read_filter_for_version( epoch, read_options.table_id, @@ -163,7 +163,7 @@ impl LocalHummockStorage { table_key_range: TableKeyRange, epoch: u64, read_options: ReadOptions, - ) -> StorageResult>> { + ) -> StorageResult> { let (table_key_range, read_snapshot) = read_filter_for_version( epoch, read_options.table_id, @@ -205,7 +205,7 @@ impl LocalHummockStorage { } impl StateStoreRead for LocalHummockStorage { - type IterStream = StreamTypeOfIter; + type Iter = HummockStorageIterator; fn get( &self, @@ -222,7 +222,7 @@ impl StateStoreRead for LocalHummockStorage { key_range: TableKeyRange, epoch: u64, read_options: ReadOptions, - ) -> impl Future> + '_ { + ) -> impl Future> + '_ { assert!(epoch <= self.epoch()); self.iter_flushed(key_range, epoch, read_options) .instrument(tracing::trace_span!("hummock_iter")) @@ -230,7 +230,7 @@ impl StateStoreRead for LocalHummockStorage { } impl LocalStateStore for LocalHummockStorage { - type IterStream<'a> = StreamTypeOfIter>; + type Iter<'a> = LocalHummockStorageIterator<'a>; fn may_exist( &self, @@ -258,7 +258,7 @@ impl LocalStateStore for LocalHummockStorage { &self, key_range: TableKeyRange, read_options: ReadOptions, - ) -> StorageResult> { + ) -> StorageResult> { let (l_vnode_inclusive, r_vnode_exclusive) = vnode_range(&key_range); assert_eq!( r_vnode_exclusive - l_vnode_inclusive, @@ -593,19 +593,21 @@ pub type LocalHummockStorageIterator<'a> = HummockStorageIteratorInner<'a>; pub struct HummockStorageIteratorInner<'a> { inner: UserIterator>, + initial_read: bool, stats_guard: IterLocalMetricsGuard, } impl<'a> StateStoreIter for HummockStorageIteratorInner<'a> { - type Item = StateStoreIterItem; - - async fn next(&mut self) -> StorageResult> { + async fn try_next<'b>(&'b mut self) -> StorageResult>> { let iter = &mut self.inner; + if !self.initial_read { + self.initial_read = true; + } else { + iter.next().await?; + } if iter.is_valid() { - let kv = (iter.key().clone(), iter.value().clone()); - iter.next().await?; - Ok(Some(kv)) + Ok(Some((iter.key(), iter.value()))) } else { Ok(None) } @@ -621,6 +623,7 @@ impl<'a> HummockStorageIteratorInner<'a> { ) -> Self { Self { inner, + initial_read: false, stats_guard: IterLocalMetricsGuard::new(metrics, table_id, local_stats), } } diff --git a/src/storage/src/hummock/store/version.rs b/src/storage/src/hummock/store/version.rs index 80e0e94130d10..c32892613bbb9 100644 --- a/src/storage/src/hummock/store/version.rs +++ b/src/storage/src/hummock/store/version.rs @@ -60,7 +60,7 @@ use crate::mem_table::{ImmId, ImmutableMemtable, MemTableHummockIterator}; use 
crate::monitor::{ GetLocalMetricsGuard, HummockStateStoreMetrics, MayExistLocalMetricsGuard, StoreLocalStatistic, }; -use crate::store::{gen_min_epoch, ReadOptions, StateStoreIterExt, StreamTypeOfIter}; +use crate::store::{gen_min_epoch, ReadOptions}; pub type CommittedVersion = PinnedVersion; @@ -739,7 +739,7 @@ impl HummockVersionReader { epoch: u64, read_options: ReadOptions, read_version_tuple: ReadVersionTuple, - ) -> StorageResult> { + ) -> StorageResult { self.iter_inner( table_key_range, epoch, @@ -757,7 +757,7 @@ impl HummockVersionReader { read_options: ReadOptions, read_version_tuple: (Vec, Vec, CommittedVersion), memtable_iter: MemTableHummockIterator<'a>, - ) -> StorageResult>> { + ) -> StorageResult> { self.iter_inner( table_key_range, epoch, @@ -775,7 +775,7 @@ impl HummockVersionReader { read_options: ReadOptions, read_version_tuple: ReadVersionTuple, mem_table: Option>, - ) -> StorageResult>> { + ) -> StorageResult> { let (imms, uncommitted_ssts, committed) = read_version_tuple; let mut local_stats = StoreLocalStatistic::default(); @@ -1009,8 +1009,7 @@ impl HummockVersionReader { self.state_store_metrics.clone(), read_options.table_id, local_stats, - ) - .into_stream()) + )) } // Note: this method will not check the kv tombstones and delete range tombstones diff --git a/src/storage/src/hummock/test_utils.rs b/src/storage/src/hummock/test_utils.rs index 85c549671ea77..b6c658e0fa63f 100644 --- a/src/storage/src/hummock/test_utils.rs +++ b/src/storage/src/hummock/test_utils.rs @@ -18,7 +18,6 @@ use std::sync::Arc; use bytes::Bytes; use foyer::memory::CacheContext; -use futures::{Stream, TryStreamExt}; use itertools::Itertools; use risingwave_common::catalog::TableId; use risingwave_common::hash::VirtualNode; @@ -32,7 +31,6 @@ use super::{ HummockResult, InMemWriter, MonotonicDeleteEvent, SstableMeta, SstableWriterOptions, DEFAULT_RESTART_INTERVAL, }; -use crate::error::StorageResult; use crate::filter_key_extractor::{FilterKeyExtractorImpl, FullKeyFilterKeyExtractor}; use crate::hummock::iterator::ForwardMergeRangeIterator; use crate::hummock::shared_buffer::shared_buffer_batch::SharedBufferBatch; @@ -45,6 +43,7 @@ use crate::hummock::{ use crate::monitor::StoreLocalStatistic; use crate::opts::StorageOpts; use crate::storage_value::StorageValue; +use crate::StateStoreIter; pub fn default_opts_for_test() -> StorageOpts { StorageOpts { @@ -376,10 +375,9 @@ pub async fn gen_default_test_sstable( .await } -pub async fn count_stream(s: impl Stream> + Send) -> usize { - futures::pin_mut!(s); +pub async fn count_stream(mut i: impl StateStoreIter + Send) -> usize { let mut c: usize = 0; - while s.try_next().await.unwrap().is_some() { + while i.try_next().await.unwrap().is_some() { c += 1 } c diff --git a/src/storage/src/lib.rs b/src/storage/src/lib.rs index 8e6efa63e3545..505eec276fbf4 100644 --- a/src/storage/src/lib.rs +++ b/src/storage/src/lib.rs @@ -62,10 +62,5 @@ pub mod mem_table; #[cfg(feature = "failpoints")] mod storage_failpoints; -pub use store::{StateStore, StateStoreIter, StateStoreReadIterStream}; +pub use store::{StateStore, StateStoreIter, StateStoreReadIter}; pub use store_impl::StateStoreImpl; - -pub enum TableScanOptions { - SequentialScan, - SparseIndexScan, -} diff --git a/src/storage/src/mem_table.rs b/src/storage/src/mem_table.rs index b03f24b901580..99f02758623e2 100644 --- a/src/storage/src/mem_table.rs +++ b/src/storage/src/mem_table.rs @@ -21,7 +21,7 @@ use std::ops::RangeBounds; use std::sync::Arc; use bytes::Bytes; -use futures::{pin_mut, StreamExt};
+use futures::{pin_mut, Stream, StreamExt}; use futures_async_stream::try_stream; use itertools::Itertools; use risingwave_common::buffer::Bitmap; @@ -350,7 +350,7 @@ impl KeyOp { #[try_stream(ok = StateStoreIterItem, error = StorageError)] pub(crate) async fn merge_stream<'a>( mem_table_iter: impl Iterator, &'a KeyOp)> + 'a, - inner_stream: impl StateStoreReadIterStream, + inner_stream: impl Stream> + 'static, table_id: TableId, epoch: u64, ) { @@ -458,7 +458,7 @@ impl MemtableLocalStateStore { } impl LocalStateStore for MemtableLocalStateStore { - type IterStream<'a> = impl StateStoreIterItemStream + 'a; + type Iter<'a> = impl StateStoreIter + 'a; #[allow(clippy::unused_async)] async fn may_exist( @@ -488,18 +488,18 @@ impl LocalStateStore for MemtableLocalState &self, key_range: TableKeyRange, read_options: ReadOptions, - ) -> impl Future>> + Send + '_ { + ) -> impl Future>> + Send + '_ { async move { - let stream = self + let iter = self .inner .iter(key_range.clone(), self.epoch(), read_options) .await?; - Ok(merge_stream( + Ok(FromStreamStateStoreIter::new(Box::pin(merge_stream( self.mem_table.iter(key_range), - stream, + iter.into_stream(to_owned_item), self.table_id, self.epoch(), - )) + )))) } } diff --git a/src/storage/src/memory.rs b/src/storage/src/memory.rs index 4984fbf5de0a2..096020ff569f6 100644 --- a/src/storage/src/memory.rs +++ b/src/storage/src/memory.rs @@ -537,7 +537,7 @@ impl RangeKvStateStore { } impl StateStoreRead for RangeKvStateStore { - type IterStream = StreamTypeOfIter>; + type Iter = RangeKvStateStoreIter; #[allow(clippy::unused_async)] async fn get( @@ -563,15 +563,14 @@ impl StateStoreRead for RangeKvStateStore { key_range: TableKeyRange, epoch: u64, read_options: ReadOptions, - ) -> StorageResult { + ) -> StorageResult { Ok(RangeKvStateStoreIter::new( batched_iter::Iter::new( self.inner.clone(), to_full_key_range(read_options.table_id, key_range), ), epoch, - ) - .into_stream()) + )) } } @@ -636,8 +635,10 @@ impl StateStore for RangeKvStateStore { fn seal_epoch(&self, _epoch: u64, _is_checkpoint: bool) {} #[allow(clippy::unused_async)] - async fn clear_shared_buffer(&self, _prev_epoch: u64) { - unimplemented!("recovery not supported") + async fn clear_shared_buffer(&self, prev_epoch: u64) { + for (key, _) in self.inner.range((Unbounded, Unbounded), None).unwrap() { + assert!(key.epoch_with_gap.pure_epoch() <= prev_epoch); + } } #[allow(clippy::unused_async)] @@ -657,8 +658,7 @@ pub struct RangeKvStateStoreIter { last_key: Option>, - /// For supporting semantic of `Fuse` - stopped: bool, + item_buffer: Option, } impl RangeKvStateStoreIter { @@ -667,29 +667,21 @@ impl RangeKvStateStoreIter { inner, epoch, last_key: None, - stopped: false, + item_buffer: None, } } } impl StateStoreIter for RangeKvStateStoreIter { - type Item = StateStoreIterItem; - #[allow(clippy::unused_async)] - async fn next(&mut self) -> StorageResult> { - if self.stopped { - Ok(None) - } else { - let ret = self.next_inner(); - match &ret { - Err(_) | Ok(None) => { - self.stopped = true; - } - _ => {} - } - - ret - } + async fn try_next(&mut self) -> StorageResult>> { + let ret = self.next_inner(); + let item = ret?; + self.item_buffer = item; + Ok(self + .item_buffer + .as_ref() + .map(|(key, value)| (key.to_ref(), value.as_ref()))) } } diff --git a/src/storage/src/monitor/monitored_store.rs b/src/storage/src/monitor/monitored_store.rs index 239f2ce9df7a0..ad8d831aa84ee 100644 --- a/src/storage/src/monitor/monitored_store.rs +++ b/src/storage/src/monitor/monitored_store.rs @@ -16,20 
+16,19 @@ use std::sync::Arc; use await_tree::InstrumentAwait; use bytes::Bytes; -use futures::{Future, TryFutureExt, TryStreamExt}; -use futures_async_stream::try_stream; use risingwave_common::buffer::Bitmap; use risingwave_common::catalog::TableId; use risingwave_hummock_sdk::key::{TableKey, TableKeyRange}; use risingwave_hummock_sdk::HummockReadEpoch; use thiserror_ext::AsReport; use tokio::time::Instant; -use tracing::error; +use tracing::{error, Instrument}; #[cfg(all(not(madsim), feature = "hm-trace"))] use super::traced_store::TracedStateStore; use super::{MonitoredStateStoreGetStats, MonitoredStateStoreIterStats, MonitoredStorageMetrics}; -use crate::error::{StorageError, StorageResult}; +use crate::error::StorageResult; use crate::hummock::sstable_store::SstableStoreRef; use crate::hummock::{HummockStorage, SstableObjectIdManagerRef}; use crate::store::*; @@ -76,22 +75,20 @@ impl MonitoredStateStore { } /// A util function to break the type connection between two opaque return types defined by `impl`. -pub(crate) fn identity(input: impl StateStoreIterItemStream) -> impl StateStoreIterItemStream { +pub(crate) fn identity(input: impl StateStoreIter) -> impl StateStoreIter { input } -pub type MonitoredStateStoreIterStream = impl StateStoreIterItemStream; - -// Note: it is important to define the `MonitoredStateStoreIterStream` type alias, as it marks that +// Note: it is important to define the `MonitoredStateStoreIter` type alias, as it marks that // the return type of `monitored_iter` only captures the lifetime `'s` and has nothing to do with -// `'a`. If we simply use `impl StateStoreIterItemStream + 's`, the rust compiler will also capture +// `'a`. If we simply use `impl StateStoreIter + 's`, the rust compiler will also capture // the lifetime `'a` defined in the outer scope. impl MonitoredStateStore { - async fn monitored_iter<'a, St: StateStoreIterItemStream + 'a>( + async fn monitored_iter<'a, St: StateStoreIter + 'a>( &'a self, table_id: TableId, iter_stream_future: impl Future> + 'a, - ) -> StorageResult> { + ) -> StorageResult> { // start time takes iterator build time into account // wait for iterator creation (e.g. 
seek) let start_time = Instant::now(); @@ -109,7 +106,7 @@ impl MonitoredStateStore { self.storage_metrics.clone(), ), }; - Ok(monitored.into_stream()) + Ok(monitored) } pub fn inner(&self) -> &S { @@ -127,8 +124,6 @@ impl MonitoredStateStore { table_id: TableId, key_len: usize, ) -> StorageResult> { - use tracing::Instrument; - let mut stats = MonitoredStateStoreGetStats::new(table_id.table_id, self.storage_metrics.clone()); @@ -149,7 +144,7 @@ impl MonitoredStateStore { } impl StateStoreRead for MonitoredStateStore { - type IterStream = impl StateStoreReadIterStream; + type Iter = impl StateStoreReadIter; fn get( &self, @@ -167,7 +162,7 @@ impl StateStoreRead for MonitoredStateStore { key_range: TableKeyRange, epoch: u64, read_options: ReadOptions, - ) -> impl Future> + '_ { + ) -> impl Future> + '_ { self.monitored_iter( read_options.table_id, self.inner.iter(key_range, epoch, read_options), @@ -177,7 +172,7 @@ impl StateStoreRead for MonitoredStateStore { } impl LocalStateStore for MonitoredStateStore { - type IterStream<'a> = impl StateStoreIterItemStream + 'a; + type Iter<'a> = impl StateStoreIter + 'a; async fn may_exist( &self, @@ -214,7 +209,7 @@ impl LocalStateStore for MonitoredStateStore { &self, key_range: TableKeyRange, read_options: ReadOptions, - ) -> impl Future>> + Send + '_ { + ) -> impl Future>> + Send + '_ { let table_id = read_options.table_id; // TODO: may collect the metrics as local self.monitored_iter(table_id, self.inner.iter(key_range, read_options)) @@ -348,26 +343,20 @@ pub struct MonitoredStateStoreIter { stats: MonitoredStateStoreIterStats, } -impl MonitoredStateStoreIter { - #[try_stream(ok = StateStoreIterItem, error = StorageError)] - async fn into_stream_inner(self) { - let inner = self.inner; - - let mut stats = self.stats; - futures::pin_mut!(inner); - while let Some((key, value)) = inner +impl StateStoreIter for MonitoredStateStoreIter { + async fn try_next(&mut self) -> StorageResult>> { + if let Some((key, value)) = self + .inner .try_next() + .instrument(tracing::trace_span!("store_iter_try_next")) .await .inspect_err(|e| error!(error = %e.as_report(), "Failed in next"))? 
{ - stats.total_items += 1; - stats.total_size += key.encoded_len() + value.len(); - yield (key, value); + self.stats.total_items += 1; + self.stats.total_size += key.encoded_len() + value.len(); + Ok(Some((key, value))) + } else { + Ok(None) } - drop(stats); - } - - fn into_stream(self) -> MonitoredStateStoreIterStream { - Self::into_stream_inner(self) } } diff --git a/src/storage/src/monitor/traced_store.rs b/src/storage/src/monitor/traced_store.rs index de55143dd7d73..8fe9a6705cc09 100644 --- a/src/storage/src/monitor/traced_store.rs +++ b/src/storage/src/monitor/traced_store.rs @@ -14,8 +14,7 @@ use std::sync::Arc; use bytes::Bytes; -use futures::{Future, TryFutureExt, TryStreamExt}; -use futures_async_stream::try_stream; +use futures::{Future, TryFutureExt}; use risingwave_common::buffer::Bitmap; use risingwave_hummock_sdk::key::{TableKey, TableKeyRange}; use risingwave_hummock_sdk::HummockReadEpoch; @@ -26,7 +25,7 @@ use risingwave_hummock_trace::{ use thiserror_ext::AsReport; use super::identity; -use crate::error::{StorageError, StorageResult}; +use crate::error::StorageResult; use crate::hummock::sstable_store::SstableStoreRef; use crate::hummock::{HummockStorage, SstableObjectIdManagerRef}; use crate::store::*; @@ -67,11 +66,11 @@ impl TracedStateStore { } } - async fn traced_iter<'a, St: StateStoreIterItemStream>( + async fn traced_iter<'a, St: StateStoreIter>( &'a self, iter_stream_future: impl Future> + 'a, span: MayTraceSpan, - ) -> StorageResult> { + ) -> StorageResult> { let res = iter_stream_future.await; if res.is_ok() { span.may_send_result(OperationResult::Iter(TraceResult::Ok(()))); @@ -79,7 +78,7 @@ impl TracedStateStore { span.may_send_result(OperationResult::Iter(TraceResult::Err)); } let traced = TracedStateStoreIter::new(res?, span); - Ok(traced.into_stream()) + Ok(traced) } async fn traced_get( @@ -106,10 +105,8 @@ impl TracedStateStore { } } -type TracedStateStoreIterStream = impl StateStoreIterItemStream; - impl LocalStateStore for TracedStateStore { - type IterStream<'a> = impl StateStoreIterItemStream + 'a; + type Iter<'a> = impl StateStoreIter + 'a; fn may_exist( &self, @@ -136,7 +133,7 @@ impl LocalStateStore for TracedStateStore { &self, key_range: TableKeyRange, read_options: ReadOptions, - ) -> impl Future>> + Send + '_ { + ) -> impl Future>> + Send + '_ { let (l, r) = key_range.clone(); let bytes_key_range = (l.map(|l| l.0), r.map(|r| r.0)); let span = TraceSpan::new_iter_span( @@ -277,7 +274,7 @@ impl StateStore for TracedStateStore { } impl StateStoreRead for TracedStateStore { - type IterStream = impl StateStoreReadIterStream; + type Iter = impl StateStoreReadIter; fn get( &self, @@ -298,7 +295,7 @@ impl StateStoreRead for TracedStateStore { key_range: TableKeyRange, epoch: u64, read_options: ReadOptions, - ) -> impl Future> + '_ { + ) -> impl Future> + '_ { let (l, r) = key_range.clone(); let bytes_key_range = (l.map(|l| l.0), r.map(|r| r.0)); let span = TraceSpan::new_iter_span( @@ -347,13 +344,10 @@ impl TracedStateStoreIter { } } -impl TracedStateStoreIter { - #[try_stream(ok = StateStoreIterItem, error = StorageError)] - async fn into_stream_inner(self) { - let inner = self.inner; - futures::pin_mut!(inner); - - while let Some((key, value)) = inner +impl StateStoreIter for TracedStateStoreIter { + async fn try_next(&mut self) -> StorageResult>> { + if let Some((key, value)) = self + .inner .try_next() .await .inspect_err(|e| tracing::error!(error = %e.as_report(), "Failed in next"))? 
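The monitored iterator above and the traced iterator it parallels now share one decorator shape: hold the inner iterator, delegate `try_next`, and observe each borrowed item on the way through instead of re-yielding it from a generated stream. A minimal sketch of that shape, assuming the `StateStoreIter` trait and `StateStoreIterItemRef` alias this patch introduces in `src/storage/src/store.rs` (`CountingIter` is illustrative, not part of the patch):

```rust
use risingwave_storage::error::StorageResult;
use risingwave_storage::store::{StateStoreIter, StateStoreIterItemRef};

/// Hypothetical decorator: counts items while delegating to the inner iterator.
struct CountingIter<I> {
    inner: I,
    items: usize,
}

impl<I: StateStoreIter> StateStoreIter for CountingIter<I> {
    async fn try_next(&mut self) -> StorageResult<Option<StateStoreIterItemRef<'_>>> {
        // The returned key/value borrow from `self.inner`, so they can only be
        // inspected here, not stored past this call.
        let item = self.inner.try_next().await?;
        if item.is_some() {
            self.items += 1;
        }
        Ok(item)
    }
}
```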
@@ -362,15 +356,13 @@ impl TracedStateStoreIter { self.span .may_send_result(OperationResult::IterNext(TraceResult::Ok(Some(( TracedBytes::from(key.user_key.table_key.to_vec()), - TracedBytes::from(value.clone()), + TracedBytes::from(Bytes::copy_from_slice(value)), ))))); - yield (key, value); + Ok(Some((key, value))) + } else { + Ok(None) } } - - fn into_stream(self) -> TracedStateStoreIterStream { - Self::into_stream_inner(self) - } } pub fn get_concurrent_id() -> ConcurrentId { diff --git a/src/storage/src/panic_store.rs b/src/storage/src/panic_store.rs index 7e5985bb6aefe..6002acd9e1057 100644 --- a/src/storage/src/panic_store.rs +++ b/src/storage/src/panic_store.rs @@ -13,12 +13,9 @@ // limitations under the License. use std::ops::Bound; -use std::pin::Pin; use std::sync::Arc; -use std::task::{Context, Poll}; use bytes::Bytes; -use futures::Stream; use risingwave_common::buffer::Bitmap; use risingwave_hummock_sdk::key::{TableKey, TableKeyRange}; use risingwave_hummock_sdk::HummockReadEpoch; @@ -33,7 +30,7 @@ use crate::store::*; pub struct PanicStateStore; impl StateStoreRead for PanicStateStore { - type IterStream = PanicStateStoreStream; + type Iter = PanicStateStoreStream; #[allow(clippy::unused_async)] async fn get( @@ -51,7 +48,7 @@ impl StateStoreRead for PanicStateStore { _key_range: TableKeyRange, _epoch: u64, _read_options: ReadOptions, - ) -> StorageResult { + ) -> StorageResult { panic!("should not read from the state store!"); } } @@ -68,7 +65,7 @@ impl StateStoreWrite for PanicStateStore { } impl LocalStateStore for PanicStateStore { - type IterStream<'a> = PanicStateStoreStream; + type Iter<'a> = PanicStateStoreStream; #[allow(clippy::unused_async)] async fn may_exist( @@ -93,7 +90,7 @@ impl LocalStateStore for PanicStateStore { &self, _key_range: TableKeyRange, _read_options: ReadOptions, - ) -> StorageResult> { + ) -> StorageResult> { panic!("should not operate on the panic state store!"); } @@ -174,12 +171,10 @@ impl StateStore for PanicStateStore { } } -pub struct PanicStateStoreStream {} +pub struct PanicStateStoreStream; -impl Stream for PanicStateStoreStream { - type Item = StorageResult; - - fn poll_next(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { +impl StateStoreIter for PanicStateStoreStream { + async fn try_next(&mut self) -> StorageResult>> { panic!("should not call next on panic state store stream") } } diff --git a/src/storage/src/storage_failpoints/test_iterator.rs b/src/storage/src/storage_failpoints/test_iterator.rs index 7b1aa31c808cd..463c20ed469de 100644 --- a/src/storage/src/storage_failpoints/test_iterator.rs +++ b/src/storage/src/storage_failpoints/test_iterator.rs @@ -288,7 +288,7 @@ async fn test_failpoints_user_read_err() { while ui.is_valid() { let key = ui.key(); let val = ui.value(); - assert_eq!(key, &iterator_test_bytes_key_of(i)); + assert_eq!(key, iterator_test_bytes_key_of(i).to_ref()); assert_eq!(val, iterator_test_value_of(i).as_slice()); i += 1; let result = ui.next().await; diff --git a/src/storage/src/store.rs b/src/storage/src/store.rs index 96838e1ef25d1..2a70002c42af8 100644 --- a/src/storage/src/store.rs +++ b/src/storage/src/store.rs @@ -12,15 +12,17 @@ // See the License for the specific language governing permissions and // limitations under the License. 
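The test updates in this patch (the failpoints test just above and the user-iterator tests earlier) follow one mechanical rule: iterator accessors now return a borrowed `FullKey<&[u8]>`, so owned expected keys are compared via `.to_ref()` (and owned user keys via `.as_ref()`) rather than by reference. A sketch of the pattern, assuming `FullKey` from `risingwave_hummock_sdk` (the `compare` helper is illustrative):

```rust
use bytes::Bytes;
use risingwave_hummock_sdk::key::FullKey;

/// Hypothetical helper showing the comparison direction used in the tests:
/// reborrow the owned key so both sides have the same borrowed type.
fn compare(owned: &FullKey<Bytes>, borrowed: FullKey<&[u8]>) -> bool {
    owned.to_ref() == borrowed
}
```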
+use std::cmp::min; use std::collections::HashMap; use std::default::Default; use std::fmt::{Debug, Formatter}; use std::future::Future; +use std::marker::PhantomData; use std::ops::Bound; use std::sync::{Arc, LazyLock}; use bytes::Bytes; -use futures::{Stream, StreamExt, TryStreamExt}; +use futures::{Stream, TryStreamExt}; use futures_async_stream::try_stream; use prost::Message; use risingwave_common::buffer::Bitmap; @@ -44,40 +46,126 @@ use crate::storage_value::StorageValue; pub trait StaticSendSync = Send + Sync + 'static; -pub trait StateStoreIter: Send + Sync { - type Item: Send; +pub trait IterItem: Send + 'static { + type ItemRef<'a>: Send + 'a; +} + +impl IterItem for StateStoreIterItem { + type ItemRef<'a> = StateStoreIterItemRef<'a>; +} + +pub trait StateStoreIter: Send { + fn try_next( + &mut self, + ) -> impl Future>>> + Send + '_; +} - fn next(&mut self) -> impl Future>> + Send + '_; +pub fn to_owned_item((key, value): StateStoreIterItemRef<'_>) -> StorageResult { + Ok((key.copy_into(), Bytes::copy_from_slice(value))) } -pub trait StateStoreIterExt: StateStoreIter { - type ItemStream: Stream::Item>> + Send; +pub trait StateStoreIterExt: StateStoreIter + Sized { + type ItemStream: Stream> + Send; - fn into_stream(self) -> Self::ItemStream; + fn into_stream Fn(T::ItemRef<'a>) -> StorageResult + Send>( + self, + f: F, + ) -> Self::ItemStream; + + fn fused(self) -> FusedStateStoreIter { + FusedStateStoreIter::new(self) + } } -#[try_stream(ok = I::Item, error = StorageError)] -async fn into_stream_inner(mut iter: I) { - while let Some(item) = iter.next().await? { - yield item; +#[try_stream(ok = O, error = StorageError)] +async fn into_stream_inner< + T: IterItem, + I: StateStoreIter, + O: Send, + F: for<'a> Fn(T::ItemRef<'a>) -> StorageResult + Send, +>( + iter: I, + f: F, +) { + let mut iter = iter.fused(); + while let Some(item) = iter.try_next().await? 
{ + yield f(item)?; } } -pub type StreamTypeOfIter = ::ItemStream; -impl StateStoreIterExt for I { - type ItemStream = impl Stream::Item>>; +pub struct FromStreamStateStoreIter { + inner: S, + item_buffer: Option, +} + +impl FromStreamStateStoreIter { + pub fn new(inner: S) -> Self { + Self { + inner, + item_buffer: None, + } + } +} - fn into_stream(self) -> Self::ItemStream { - into_stream_inner(self) +impl> + Unpin + Send> StateStoreIter + for FromStreamStateStoreIter +{ + async fn try_next(&mut self) -> StorageResult>> { + self.item_buffer = self.inner.try_next().await?; + Ok(self + .item_buffer + .as_ref() + .map(|(key, value)| (key.to_ref(), value.as_ref()))) } } +pub struct FusedStateStoreIter { + inner: I, + finished: bool, + _phantom: PhantomData, +} + +impl FusedStateStoreIter { + fn new(inner: I) -> Self { + Self { + inner, + finished: false, + _phantom: PhantomData, + } + } +} + +impl> FusedStateStoreIter { + async fn try_next(&mut self) -> StorageResult>> { + assert!(!self.finished, "call try_next after finish"); + let result = self.inner.try_next().await; + match &result { + Ok(Some(_)) => {} + Ok(None) | Err(_) => { + self.finished = true; + } + } + result + } +} + +impl> StateStoreIterExt for I { + type ItemStream = impl Stream> + Send; + + fn into_stream Fn(T::ItemRef<'a>) -> StorageResult + Send>( + self, + f: F, + ) -> Self::ItemStream { + into_stream_inner(self, f) + } +} + +pub type StateStoreIterItemRef<'a> = (FullKey<&'a [u8]>, &'a [u8]); pub type StateStoreIterItem = (FullKey, Bytes); -pub trait StateStoreIterItemStream = Stream> + Send; -pub trait StateStoreReadIterStream = StateStoreIterItemStream + 'static; +pub trait StateStoreReadIter = StateStoreIter + 'static; pub trait StateStoreRead: StaticSendSync { - type IterStream: StateStoreReadIterStream; + type Iter: StateStoreReadIter; /// Point gets a value from the state store. /// The result is based on a snapshot corresponding to the given `epoch`. @@ -98,7 +186,7 @@ pub trait StateStoreRead: StaticSendSync { key_range: TableKeyRange, epoch: u64, read_options: ReadOptions, - ) -> impl Future> + Send + '_; + ) -> impl Future> + Send + '_; } pub trait StateStoreReadExt: StaticSendSync { @@ -129,12 +217,14 @@ impl StateStoreReadExt for S { if limit.is_some() { read_options.prefetch_options.prefetch = false; } + const MAX_INITIAL_CAP: usize = 1024; let limit = limit.unwrap_or(usize::MAX); - self.iter(key_range, epoch, read_options) - .await? - .take(limit) - .try_collect() - .await + let mut ret = Vec::with_capacity(min(limit, MAX_INITIAL_CAP)); + let mut iter = self.iter(key_range, epoch, read_options).await?; + while let Some((key, value)) = iter.try_next().await? { + ret.push((key.copy_into(), Bytes::copy_from_slice(value))) + } + Ok(ret) } } @@ -205,7 +295,7 @@ pub trait StateStore: StateStoreRead + StaticSendSync + Clone { /// written by itself. Each local state store is not `Clone`, and is owned by a streaming state /// table. pub trait LocalStateStore: StaticSendSync { - type IterStream<'a>: StateStoreIterItemStream + 'a; + type Iter<'a>: StateStoreIter + 'a; /// Point gets a value from the state store. /// The result is based on the latest written snapshot. @@ -224,7 +314,7 @@ pub trait LocalStateStore: StaticSendSync { &self, key_range: TableKeyRange, read_options: ReadOptions, - ) -> impl Future>> + Send + '_; + ) -> impl Future>> + Send + '_; /// Inserts a key-value entry associated with a given `epoch` into the state store. 
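Taken together, the machinery above replaces stream-first iteration with a pull-first design: `try_next` hands out borrowed `(FullKey<&[u8]>, &[u8])` pairs, `into_stream` plus a mapping closure (such as `to_owned_item`) recovers an owned-item `Stream` where one is still needed, and `FromStreamStateStoreIter` adapts in the opposite direction. A hedged sketch of a caller under those definitions (`collect_all` and `collect_via_stream` are illustrative, not part of the patch):

```rust
use bytes::Bytes;
use futures::TryStreamExt;
use risingwave_storage::error::StorageResult;
use risingwave_storage::store::{
    to_owned_item, StateStoreIter, StateStoreIterExt, StateStoreIterItem,
};

/// Hypothetical helper: drain an iterator into owned items, mirroring what the
/// reworked `StateStoreReadExt::scan` above does with a limit.
async fn collect_all(mut iter: impl StateStoreIter) -> StorageResult<Vec<StateStoreIterItem>> {
    let mut out = Vec::new();
    // Borrowed items must be copied out before the next `try_next` call.
    while let Some((key, value)) = iter.try_next().await? {
        out.push((key.copy_into(), Bytes::copy_from_slice(value)));
    }
    Ok(out)
}

/// The same drain expressed through the stream adapter.
async fn collect_via_stream(iter: impl StateStoreIter) -> StorageResult<Vec<StateStoreIterItem>> {
    iter.into_stream(to_owned_item).try_collect().await
}
```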
fn insert( diff --git a/src/storage/src/store_impl.rs b/src/storage/src/store_impl.rs index 50fe81d53ed54..5d57a8c4ba955 100644 --- a/src/storage/src/store_impl.rs +++ b/src/storage/src/store_impl.rs @@ -206,14 +206,12 @@ pub mod verify { use std::sync::Arc; use bytes::Bytes; - use futures::{pin_mut, TryStreamExt}; - use futures_async_stream::try_stream; use risingwave_common::buffer::Bitmap; use risingwave_hummock_sdk::key::{TableKey, TableKeyRange}; use risingwave_hummock_sdk::HummockReadEpoch; use tracing::log::warn; - use crate::error::{StorageError, StorageResult}; + use crate::error::StorageResult; use crate::hummock::HummockStorage; use crate::storage_value::StorageValue; use crate::store::*; @@ -251,7 +249,7 @@ pub mod verify { } impl StateStoreRead for VerifyStateStore { - type IterStream = impl StateStoreReadIterStream; + type Iter = impl StateStoreReadIter; async fn get( &self, @@ -278,7 +276,7 @@ pub mod verify { key_range: TableKeyRange, epoch: u64, read_options: ReadOptions, - ) -> impl Future> + '_ { + ) -> impl Future> + '_ { async move { let actual = self .actual @@ -290,34 +288,29 @@ pub mod verify { None }; - Ok(verify_stream(actual, expected)) + Ok(verify_iter(actual, expected)) } } } - #[try_stream(ok = StateStoreIterItem, error = StorageError)] - async fn verify_stream( - actual: impl StateStoreIterItemStream, - expected: Option, - ) { - pin_mut!(actual); - pin_mut!(expected); - let mut expected = expected.as_pin_mut(); - - loop { - let actual = actual.try_next().await?; - if let Some(expected) = expected.as_mut() { + impl StateStoreIter for VerifyStateStore { + async fn try_next(&mut self) -> StorageResult>> { + let actual = self.actual.try_next().await?; + if let Some(expected) = self.expected.as_mut() { let expected = expected.try_next().await?; assert_eq!(actual, expected); } - if let Some(actual) = actual { - yield actual; - } else { - break; - } + Ok(actual) } } + fn verify_iter( + actual: impl StateStoreIter, + expected: Option, + ) -> impl StateStoreIter { + VerifyStateStore { actual, expected } + } + impl StateStoreWrite for VerifyStateStore { fn ingest_batch( &self, @@ -348,7 +341,7 @@ pub mod verify { } impl LocalStateStore for VerifyStateStore { - type IterStream<'a> = impl StateStoreIterItemStream + 'a; + type Iter<'a> = impl StateStoreIter + 'a; // We don't verify `may_exist` across different state stores because // the return value of `may_exist` is implementation specific and may not @@ -379,7 +372,7 @@ pub mod verify { &self, key_range: TableKeyRange, read_options: ReadOptions, - ) -> impl Future>> + Send + '_ { + ) -> impl Future>> + Send + '_ { async move { let actual = self .actual @@ -391,7 +384,7 @@ pub mod verify { None }; - Ok(verify_stream(actual, expected)) + Ok(verify_iter(actual, expected)) } } @@ -697,8 +690,6 @@ pub mod boxed_state_store { use bytes::Bytes; use dyn_clone::{clone_trait_object, DynClone}; - use futures::stream::BoxStream; - use futures::StreamExt; use risingwave_common::buffer::Bitmap; use risingwave_hummock_sdk::key::{TableKey, TableKeyRange}; use risingwave_hummock_sdk::HummockReadEpoch; @@ -709,9 +700,31 @@ pub mod boxed_state_store { use crate::store_impl::AsHummock; use crate::StateStore; + #[async_trait::async_trait] + pub trait DynamicDispatchedStateStoreIter: Send { + async fn try_next(&mut self) -> StorageResult>>; + } + + #[async_trait::async_trait] + impl DynamicDispatchedStateStoreIter for I { + async fn try_next(&mut self) -> StorageResult>> { + self.try_next().await + } + } + + pub type BoxStateStoreIter<'a> = 
Box; + impl<'a> StateStoreIter for BoxStateStoreIter<'a> { + fn try_next( + &mut self, + ) -> impl Future>>> + Send + '_ + { + self.deref_mut().try_next() + } + } + // For StateStoreRead - pub type BoxStateStoreReadIterStream = BoxStream<'static, StorageResult>; + pub type BoxStateStoreReadIter = BoxStateStoreIter<'static>; #[async_trait::async_trait] pub trait DynamicDispatchedStateStoreRead: StaticSendSync { @@ -727,7 +740,7 @@ pub mod boxed_state_store { key_range: TableKeyRange, epoch: u64, read_options: ReadOptions, - ) -> StorageResult; + ) -> StorageResult; } #[async_trait::async_trait] @@ -746,13 +759,13 @@ pub mod boxed_state_store { key_range: TableKeyRange, epoch: u64, read_options: ReadOptions, - ) -> StorageResult { - Ok(self.iter(key_range, epoch, read_options).await?.boxed()) + ) -> StorageResult { + Ok(Box::new(self.iter(key_range, epoch, read_options).await?)) } } // For LocalStateStore - pub type BoxLocalStateStoreIterStream<'a> = BoxStream<'a, StorageResult>; + pub type BoxLocalStateStoreIterStream<'a> = BoxStateStoreIter<'a>; #[async_trait::async_trait] pub trait DynamicDispatchedLocalStateStore: StaticSendSync { async fn may_exist( @@ -820,7 +833,7 @@ pub mod boxed_state_store { key_range: TableKeyRange, read_options: ReadOptions, ) -> StorageResult> { - Ok(self.iter(key_range, read_options).await?.boxed()) + Ok(Box::new(self.iter(key_range, read_options).await?)) } fn insert( @@ -868,7 +881,7 @@ pub mod boxed_state_store { pub type BoxDynamicDispatchedLocalStateStore = Box; impl LocalStateStore for BoxDynamicDispatchedLocalStateStore { - type IterStream<'a> = BoxLocalStateStoreIterStream<'a>; + type Iter<'a> = BoxLocalStateStoreIterStream<'a>; fn may_exist( &self, @@ -890,7 +903,7 @@ pub mod boxed_state_store { &self, key_range: TableKeyRange, read_options: ReadOptions, - ) -> impl Future>> + Send + '_ { + ) -> impl Future>> + Send + '_ { self.deref().iter(key_range, read_options) } @@ -986,7 +999,7 @@ pub mod boxed_state_store { pub type BoxDynamicDispatchedStateStore = Box; impl StateStoreRead for BoxDynamicDispatchedStateStore { - type IterStream = BoxStateStoreReadIterStream; + type Iter = BoxStateStoreReadIter; fn get( &self, @@ -1002,7 +1015,7 @@ pub mod boxed_state_store { key_range: TableKeyRange, epoch: u64, read_options: ReadOptions, - ) -> impl Future> + '_ { + ) -> impl Future> + '_ { self.deref().iter(key_range, epoch, read_options) } } diff --git a/src/storage/src/table/batch_table/storage_table.rs b/src/storage/src/table/batch_table/storage_table.rs index b036180a4ab8b..cfed20ddb3ba2 100644 --- a/src/storage/src/table/batch_table/storage_table.rs +++ b/src/storage/src/table/batch_table/storage_table.rs @@ -45,7 +45,7 @@ use crate::hummock::CachePolicy; use crate::row_serde::row_serde_util::{serialize_pk, serialize_pk_with_vnode}; use crate::row_serde::value_serde::{ValueRowSerde, ValueRowSerdeNew}; use crate::row_serde::{find_columns_by_ids, ColumnMapping}; -use crate::store::{PrefetchOptions, ReadOptions}; +use crate::store::{PrefetchOptions, ReadOptions, StateStoreIter}; use crate::table::merge_sort::merge_sort; use crate::table::{KeyedRow, TableDistribution, TableIter}; use crate::StateStore; @@ -665,7 +665,7 @@ impl StorageTableInner { /// [`StorageTableInnerIterInner`] iterates on the storage table. struct StorageTableInnerIterInner { /// An iterator that returns raw bytes from storage. - iter: S::IterStream, + iter: S::Iter, mapping: Arc, @@ -725,18 +725,15 @@ impl StorageTableInnerIterInner { /// Yield a row with its primary key. 
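The `#[async_trait]` bridge above exists because the trait's `impl Future` return type is not object-safe; `DynamicDispatchedStateStoreIter` boxes the future, and the blanket impl plus the `StateStoreIter` impl for the box let any concrete iterator round-trip through dynamic dispatch. A sketch of what that buys callers, under the same assumed definitions (module path as in this patch; `erase` and `drain_boxed` are illustrative):

```rust
use risingwave_storage::error::StorageResult;
use risingwave_storage::store::StateStoreIter;
use risingwave_storage::store_impl::boxed_state_store::BoxStateStoreIter;

/// Hypothetical: erase a concrete iterator type so differently-typed
/// iterators can live behind the same field or collection.
fn erase<'a>(iter: impl StateStoreIter + 'a) -> BoxStateStoreIter<'a> {
    Box::new(iter)
}

/// The boxed form still satisfies `StateStoreIter`, so generic consumers work.
async fn drain_boxed(mut iter: BoxStateStoreIter<'_>) -> StorageResult<usize> {
    let mut n = 0;
    while iter.try_next().await?.is_some() {
        n += 1;
    }
    Ok(n)
}
```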
#[try_stream(ok = KeyedRow, error = StorageError)] - async fn into_stream(self) { - use futures::TryStreamExt; - - // No need for table id and epoch. - let iter = self.iter.map_ok(|(k, v)| (k.user_key.table_key, v)); - futures::pin_mut!(iter); - while let Some((table_key, value)) = iter + async fn into_stream(mut self) { + while let Some((k, v)) = self + .iter .try_next() .verbose_instrument_await("storage_table_iter_next") .await? { - let full_row = self.row_deserializer.deserialize(&value)?; + let (table_key, value) = (k.user_key.table_key, v); + let full_row = self.row_deserializer.deserialize(value)?; let result_row_in_value = self .mapping .project(OwnedRow::new(full_row)) @@ -774,14 +771,15 @@ impl StorageTableInnerIterInner { } let row = OwnedRow::new(result_row_vec); + // TODO: may optimize the key clone yield KeyedRow { - vnode_prefixed_key: table_key, + vnode_prefixed_key: table_key.copy_into(), row, } } None => { yield KeyedRow { - vnode_prefixed_key: table_key, + vnode_prefixed_key: table_key.copy_into(), row: result_row_in_value, } } diff --git a/src/stream/src/common/log_store_impl/kv_log_store/reader.rs b/src/stream/src/common/log_store_impl/kv_log_store/reader.rs index 4b7c0bbec721e..3e9e9ad4833e8 100644 --- a/src/stream/src/common/log_store_impl/kv_log_store/reader.rs +++ b/src/stream/src/common/log_store_impl/kv_log_store/reader.rs @@ -20,7 +20,6 @@ use std::time::Duration; use anyhow::anyhow; use foyer::memory::CacheContext; use futures::future::{try_join_all, BoxFuture}; -use futures::stream::select_all; use futures::{FutureExt, TryFutureExt}; use risingwave_common::array::StreamChunk; use risingwave_common::buffer::Bitmap; @@ -113,7 +112,7 @@ pub struct KvLogStoreReader { first_write_epoch: Option, /// `Some` means consuming historical log data - state_store_stream: Option>>>, + state_store_stream: Option>>>, /// Store the future that attempts to read a flushed stream chunk. /// This is for cancellation safety. 
Since it is possible that the future of `next_item` @@ -184,7 +183,7 @@ impl KvLogStoreReader { fn read_persisted_log_store( &self, last_persisted_epoch: Option, - ) -> impl Future>>>> + Send + ) -> impl Future>>>> + Send { let range_start = if let Some(last_persisted_epoch) = last_persisted_epoch { // start from the next epoch of last_persisted_epoch @@ -341,7 +340,7 @@ impl LogReader for KvLogStoreReader { let table_id = self.table_id; let read_metrics = self.metrics.flushed_buffer_read_metrics.clone(); async move { - let streams = try_join_all(vnode_bitmap.iter_vnodes().map(|vnode| { + let iters = try_join_all(vnode_bitmap.iter_vnodes().map(|vnode| { let range_start = serde.serialize_log_store_pk(vnode, item_epoch, Some(start_seq_id)); let range_end = @@ -351,7 +350,7 @@ impl LogReader for KvLogStoreReader { // Use MAX EPOCH here because the epoch to consume may be below the safe // epoch async move { - Ok::<_, anyhow::Error>(Box::pin( + Ok::<_, anyhow::Error>( state_store .iter( (Included(range_start), Included(range_end)), @@ -365,15 +364,14 @@ impl LogReader for KvLogStoreReader { }, ) .await?, - )) + ) } })) .await?; - let combined_stream = select_all(streams); let chunk = serde .deserialize_stream_chunk( - combined_stream, + iters, start_seq_id, end_seq_id, item_epoch, diff --git a/src/stream/src/common/log_store_impl/kv_log_store/serde.rs b/src/stream/src/common/log_store_impl/kv_log_store/serde.rs index 38bb51c79b75c..67167f466a50b 100644 --- a/src/stream/src/common/log_store_impl/kv_log_store/serde.rs +++ b/src/stream/src/common/log_store_impl/kv_log_store/serde.rs @@ -25,12 +25,12 @@ use itertools::Itertools; use risingwave_common::array::{Op, StreamChunk}; use risingwave_common::buffer::Bitmap; use risingwave_common::catalog::ColumnDesc; -use risingwave_common::estimate_size::EstimateSize; use risingwave_common::hash::VirtualNode; use risingwave_common::row::{OwnedRow, Row, RowExt}; use risingwave_common::types::{DataType, ScalarImpl}; use risingwave_common::util::chunk_coalesce::DataChunkBuilder; use risingwave_common::util::row_serde::OrderedRowSerde; +use risingwave_common::util::value_encoding; use risingwave_common::util::value_encoding::{ BasicSerde, ValueRowDeserializer, ValueRowSerializer, }; @@ -38,11 +38,12 @@ use risingwave_connector::sink::log_store::LogStoreResult; use risingwave_hummock_sdk::key::{next_key, TableKey}; use risingwave_hummock_sdk::HummockEpoch; use risingwave_pb::catalog::Table; -use risingwave_storage::error::StorageError; +use risingwave_storage::error::StorageResult; use risingwave_storage::row_serde::row_serde_util::{serialize_pk, serialize_pk_with_vnode}; use risingwave_storage::row_serde::value_serde::ValueRowSerdeNew; -use risingwave_storage::store::StateStoreReadIterStream; +use risingwave_storage::store::{StateStoreIterExt, StateStoreReadIter}; use risingwave_storage::table::{compute_vnode, TableDistribution, SINGLETON_VNODE}; +use rw_futures_util::select_all; use crate::common::log_store_impl::kv_log_store::{ KvLogStorePkInfo, KvLogStoreReadMetrics, ReaderTruncationOffsetType, RowOpCodeType, SeqIdType, @@ -305,8 +306,8 @@ impl LogStoreRowSerde { } impl LogStoreRowSerde { - fn deserialize(&self, value_bytes: Bytes) -> LogStoreResult<(u64, LogStoreRowOp)> { - let row_data = self.row_serde.deserialize(&value_bytes)?; + fn deserialize(&self, value_bytes: &[u8]) -> value_encoding::Result<(u64, LogStoreRowOp)> { + let row_data = self.row_serde.deserialize(value_bytes)?; let payload_row = 
OwnedRow::new(row_data[self.pk_info.predefined_column_len()..].to_vec()); let epoch = Self::decode_epoch( @@ -354,24 +355,31 @@ impl LogStoreRowSerde { Ok((epoch, op)) } - pub(crate) async fn deserialize_stream_chunk( + pub(crate) async fn deserialize_stream_chunk( &self, - stream: impl StateStoreReadIterStream, + iters: impl IntoIterator, start_seq_id: SeqIdType, end_seq_id: SeqIdType, expected_epoch: u64, metrics: &KvLogStoreReadMetrics, ) -> LogStoreResult { - pin_mut!(stream); let size_bound = (end_seq_id - start_seq_id + 1) as usize; let mut data_chunk_builder = DataChunkBuilder::new(self.payload_schema.clone(), size_bound + 1); let mut ops = Vec::with_capacity(size_bound); let mut read_info = ReadInfo::new(); - while let Some((key, value)) = stream.try_next().await? { - read_info - .read_one_row(key.user_key.table_key.estimated_size() + value.estimated_size()); - match self.deserialize(value)? { + let stream = select_all(iters.into_iter().map(|iter| { + iter.into_stream(move |(key, value)| { + let row_size = key.user_key.table_key.len() + value.len(); + let output = self.deserialize(value)?; + Ok((row_size, output)) + }) + .boxed() + })); + pin_mut!(stream); + while let Some((row_size, output)) = stream.try_next().await? { + read_info.read_one_row(row_size); + match output { (epoch, LogStoreRowOp::Row { op, row }) => { if epoch != expected_epoch { return Err(anyhow!( @@ -435,7 +443,7 @@ pub(crate) enum KvLogStoreItem { type BoxPeekableLogStoreItemStream = Pin>>>; -struct LogStoreRowOpStream { +struct LogStoreRowOpStream { serde: LogStoreRowSerde, /// Streams that have not reached a barrier @@ -451,16 +459,16 @@ struct LogStoreRowOpStream { metrics: KvLogStoreReadMetrics, } -impl LogStoreRowOpStream { +impl LogStoreRowOpStream { pub(crate) fn new( - streams: Vec, + iters: Vec, serde: LogStoreRowSerde, metrics: KvLogStoreReadMetrics, ) -> Self { - assert!(!streams.is_empty()); + assert!(!iters.is_empty()); Self { serde: serde.clone(), - barrier_streams: streams + barrier_streams: iters .into_iter() .map(|s| Box::pin(deserialize_stream(s, serde.clone()).peekable())) .collect(), @@ -538,37 +546,35 @@ impl LogStoreRowOpStream { pub(crate) type LogStoreItemMergeStream = impl Stream>; -pub(crate) fn merge_log_store_item_stream( - streams: Vec, +pub(crate) fn merge_log_store_item_stream( + iters: Vec, serde: LogStoreRowSerde, chunk_size: usize, metrics: KvLogStoreReadMetrics, ) -> LogStoreItemMergeStream { - LogStoreRowOpStream::new(streams, serde, metrics).into_log_store_item_stream(chunk_size) + LogStoreRowOpStream::new(iters, serde, metrics).into_log_store_item_stream(chunk_size) } -type LogStoreItemStream = +type LogStoreItemStream = impl Stream> + Send; -fn deserialize_stream( - stream: S, +fn deserialize_stream( + iter: S, serde: LogStoreRowSerde, ) -> LogStoreItemStream { - stream.map( - move |result: Result<_, StorageError>| -> LogStoreResult<(u64, LogStoreRowOp, usize)> { - match result { - Ok((key, value)) => { - let read_size = - key.user_key.table_key.estimated_size() + value.estimated_size(); - let (epoch, op) = serde.deserialize(value)?; - Ok((epoch, op, read_size)) - } - Err(e) => Err(e.into()), - } + iter.into_stream( + move |(key, value)| -> StorageResult<(u64, LogStoreRowOp, usize)> { + let read_size = key.user_key.table_key.len() + value.len(); + let (epoch, op) = serde.deserialize(value)?; + Ok((epoch, op, read_size)) }, ) + .map_err(Into::into) + .boxed() + // The `boxed` call was unnecessary in usual build. 
But when doing cargo doc, + // rustc will panic in auto_trait.rs. May remove it when using future version of tool chain. } -impl LogStoreRowOpStream { +impl LogStoreRowOpStream { // Return Ok(false) means all streams have reach the end. async fn init(&mut self) -> LogStoreResult { match &self.stream_state { @@ -753,12 +759,13 @@ impl LogStoreRowOpStream { mod tests { use std::cmp::min; use std::future::poll_fn; + use std::iter::once; use std::sync::Arc; use std::task::Poll; use bytes::Bytes; use futures::stream::empty; - use futures::{pin_mut, stream, StreamExt, TryStreamExt}; + use futures::{pin_mut, stream, Stream, StreamExt, TryStreamExt}; use itertools::Itertools; use rand::prelude::SliceRandom; use rand::thread_rng; @@ -770,7 +777,10 @@ mod tests { use risingwave_common::util::chunk_coalesce::DataChunkBuilder; use risingwave_common::util::epoch::{test_epoch, EpochExt}; use risingwave_hummock_sdk::key::FullKey; - use risingwave_storage::store::StateStoreReadIterStream; + use risingwave_storage::error::StorageResult; + use risingwave_storage::store::{ + FromStreamStateStoreIter, StateStoreIterItem, StateStoreReadIter, + }; use risingwave_storage::table::DEFAULT_VNODE; use tokio::sync::oneshot; use tokio::sync::oneshot::Sender; @@ -834,7 +844,7 @@ mod tests { let key = remove_vnode_prefix(&key.0); assert!(key < delete_range_right1); serialized_keys.push(key); - let (decoded_epoch, row_op) = serde.deserialize(value).unwrap(); + let (decoded_epoch, row_op) = serde.deserialize(&value).unwrap(); assert_eq!(decoded_epoch, epoch); match row_op { LogStoreRowOp::Row { @@ -851,7 +861,7 @@ mod tests { let (key, encoded_barrier) = serde.serialize_barrier(epoch, DEFAULT_VNODE, false); let key = remove_vnode_prefix(&key.0); - match serde.deserialize(encoded_barrier).unwrap() { + match serde.deserialize(&encoded_barrier).unwrap() { (decoded_epoch, LogStoreRowOp::Barrier { is_checkpoint }) => { assert!(!is_checkpoint); assert_eq!(decoded_epoch, epoch); @@ -872,7 +882,7 @@ mod tests { assert!(key >= delete_range_right1); assert!(key < delete_range_right2); serialized_keys.push(key); - let (decoded_epoch, row_op) = serde.deserialize(value).unwrap(); + let (decoded_epoch, row_op) = serde.deserialize(&value).unwrap(); assert_eq!(decoded_epoch, epoch); match row_op { LogStoreRowOp::Row { @@ -889,7 +899,7 @@ mod tests { let (key, encoded_checkpoint_barrier) = serde.serialize_barrier(epoch, DEFAULT_VNODE, true); let key = remove_vnode_prefix(&key.0); - match serde.deserialize(encoded_checkpoint_barrier).unwrap() { + match serde.deserialize(&encoded_checkpoint_barrier).unwrap() { (decoded_epoch, LogStoreRowOp::Barrier { is_checkpoint }) => { assert_eq!(decoded_epoch, epoch); assert!(is_checkpoint); @@ -968,7 +978,7 @@ mod tests { tx.send(()).unwrap(); let chunk = serde .deserialize_stream_chunk( - stream, + once(FromStreamStateStoreIter::new(stream.boxed())), start_seq_id, end_seq_id, EPOCH1, @@ -988,7 +998,10 @@ mod tests { rows: Vec, epoch: u64, seq_id: &mut SeqIdType, - ) -> (impl StateStoreReadIterStream, Sender<()>) { + ) -> ( + impl Stream>, + Sender<()>, + ) { let (tx, rx) = oneshot::channel(); let row_data = ops .into_iter() @@ -1014,7 +1027,7 @@ mod tests { seq_id: &mut SeqIdType, base: i64, ) -> ( - impl StateStoreReadIterStream, + impl Stream>, oneshot::Sender<()>, oneshot::Sender<()>, Vec, @@ -1052,7 +1065,7 @@ mod tests { serde: LogStoreRowSerde, size: usize, ) -> ( - LogStoreRowOpStream, + LogStoreRowOpStream, Vec>>, Vec>>, Vec>, @@ -1067,6 +1080,7 @@ mod tests { for i in 0..size { let (s, t1, 
t2, op_list, row_list) = gen_single_test_stream(serde.clone(), &mut seq_id, (100 * i) as _); + let s = FromStreamStateStoreIter::new(s.boxed()); streams.push(s); tx1.push(Some(t1)); tx2.push(Some(t2)); @@ -1219,6 +1233,7 @@ mod tests { let mut seq_id = 1; let (stream, tx1, tx2, ops, rows) = gen_single_test_stream(serde.clone(), &mut seq_id, 0); + let stream = FromStreamStateStoreIter::new(stream.boxed()); const CHUNK_SIZE: usize = 3; @@ -1329,7 +1344,10 @@ mod tests { const CHUNK_SIZE: usize = 3; let stream = merge_log_store_item_stream( - vec![empty(), empty()], + vec![ + FromStreamStateStoreIter::new(empty()), + FromStreamStateStoreIter::new(empty()), + ], serde, CHUNK_SIZE, KvLogStoreReadMetrics::for_test(), diff --git a/src/stream/src/common/table/state_table.rs b/src/stream/src/common/table/state_table.rs index 0f1489adc8ff8..730e94c5345da 100644 --- a/src/stream/src/common/table/state_table.rs +++ b/src/stream/src/common/table/state_table.rs @@ -21,7 +21,7 @@ use std::sync::Arc; use bytes::{BufMut, Bytes, BytesMut}; use either::Either; use foyer::memory::CacheContext; -use futures::{pin_mut, FutureExt, Stream, StreamExt}; +use futures::{pin_mut, FutureExt, Stream, StreamExt, TryStreamExt}; use futures_async_stream::for_await; use itertools::{izip, Itertools}; use risingwave_common::array::stream_record::Record; @@ -55,7 +55,7 @@ use risingwave_storage::row_serde::row_serde_util::{ use risingwave_storage::row_serde::value_serde::ValueRowSerde; use risingwave_storage::store::{ InitOptions, LocalStateStore, NewLocalOptions, OpConsistencyLevel, PrefetchOptions, - ReadOptions, SealCurrentEpochOptions, StateStoreIterItemStream, + ReadOptions, SealCurrentEpochOptions, StateStoreIter, StateStoreIterExt, }; use risingwave_storage::table::merge_sort::merge_sort; use risingwave_storage::table::{KeyedRow, TableDistribution}; @@ -1309,7 +1309,7 @@ where table_key_range: TableKeyRange, prefix_hint: Option, prefetch_options: PrefetchOptions, - ) -> StreamExecutorResult<::IterStream<'_>> { + ) -> StreamExecutorResult<::Iter<'_>> { let read_options = ReadOptions { prefix_hint, retention_seconds: self.table_option.retention_seconds, @@ -1394,7 +1394,7 @@ where // iterate over each vnode that the `StateTableInner` owns. 
vnode: VirtualNode, prefetch_options: PrefetchOptions, - ) -> StreamExecutorResult<::IterStream<'_>> { + ) -> StreamExecutorResult<::Iter<'_>> { let memcomparable_range = prefix_range_to_memcomparable(&self.pk_serde, pk_range); let memcomparable_range_with_vnode = prefixed_range_with_vnode(memcomparable_range, vnode); @@ -1459,19 +1459,17 @@ pub type KeyedRowStream<'a, S: StateStore, SD: ValueRowSerde + 'a> = impl Stream>> + 'a; fn deserialize_keyed_row_stream<'a>( - stream: impl StateStoreIterItemStream + 'a, + iter: impl StateStoreIter + 'a, deserializer: &'a impl ValueRowSerde, ) -> impl Stream>> + 'a { - stream.map(move |result| { - result - .map_err(StreamExecutorError::from) - .and_then(|(key, value)| { - Ok(KeyedRow::new( - key.user_key.table_key, - deserializer.deserialize(&value).map(OwnedRow::new)?, - )) - }) + iter.into_stream(move |(key, value)| { + Ok(KeyedRow::new( + // TODO: may avoid cloning the key when the key is not needed + key.user_key.table_key.copy_into(), + deserializer.deserialize(value).map(OwnedRow::new)?, + )) }) + .map_err(Into::into) } pub fn prefix_range_to_memcomparable( diff --git a/src/stream/src/executor/source/fs_source_executor.rs b/src/stream/src/executor/source/fs_source_executor.rs index 959c39f9724d5..94576d6a4c459 100644 --- a/src/stream/src/executor/source/fs_source_executor.rs +++ b/src/stream/src/executor/source/fs_source_executor.rs @@ -235,12 +235,12 @@ impl FsSourceExecutor { .collect_vec(); if !incompleted.is_empty() { - tracing::debug!(actor_id = self.actor_ctx.id, incompleted = ?incompleted, "take snapshot"); + tracing::debug!(incompleted = ?incompleted, "take snapshot"); core.split_state_store.set_states(incompleted).await? } if !completed.is_empty() { - tracing::debug!(actor_id = self.actor_ctx.id, completed = ?completed, "take snapshot"); + tracing::debug!(completed = ?completed, "take snapshot"); core.split_state_store.set_all_complete(completed).await?
} // commit anyway, even if no message saved @@ -335,7 +335,7 @@ impl FsSourceExecutor { // init in-memory split states with persisted state if any self.stream_source_core.init_split_state(boot_state.clone()); let recover_state: ConnectorState = (!boot_state.is_empty()).then_some(boot_state); - tracing::debug!(actor_id = self.actor_ctx.id, state = ?recover_state, "start with state"); + tracing::debug!(state = ?recover_state, "start with state"); let source_chunk_reader = self .build_stream_source_reader(&source_desc, recover_state) diff --git a/src/stream/src/executor/source/source_executor.rs b/src/stream/src/executor/source/source_executor.rs index be9372efa70fc..36358bdcd372e 100644 --- a/src/stream/src/executor/source/source_executor.rs +++ b/src/stream/src/executor/source/source_executor.rs @@ -315,7 +315,7 @@ impl SourceExecutor { .collect_vec(); if !cache.is_empty() { - tracing::debug!(actor_id = self.actor_ctx.id, state = ?cache, "take snapshot"); + tracing::debug!(state = ?cache, "take snapshot"); core.split_state_store.set_states(cache).await?; } @@ -406,7 +406,7 @@ impl SourceExecutor { self.stream_source_core = Some(core); let recover_state: ConnectorState = (!boot_state.is_empty()).then_some(boot_state); - tracing::debug!(actor_id = self.actor_ctx.id, state = ?recover_state, "start with state"); + tracing::debug!(state = ?recover_state, "start with state"); let source_chunk_reader = self .build_stream_source_reader(&source_desc, recover_state) .instrument_await("source_build_reader") diff --git a/src/stream/src/executor/wrapper.rs b/src/stream/src/executor/wrapper.rs index dddd94da5ab73..74923928eaf6d 100644 --- a/src/stream/src/executor/wrapper.rs +++ b/src/stream/src/executor/wrapper.rs @@ -66,7 +66,7 @@ impl WrapperExecutor { // -- Shared wrappers -- // Await tree - let stream = trace::instrument_await_tree(info.clone(), actor_ctx.id, stream); + let stream = trace::instrument_await_tree(info.clone(), stream); // Schema check let stream = schema_check::schema_check(info.clone(), stream); diff --git a/src/stream/src/executor/wrapper/trace.rs b/src/stream/src/executor/wrapper/trace.rs index df594194966be..c95809f534728 100644 --- a/src/stream/src/executor/wrapper/trace.rs +++ b/src/stream/src/executor/wrapper/trace.rs @@ -21,7 +21,6 @@ use tracing::{Instrument, Span}; use crate::executor::error::StreamExecutorError; use crate::executor::{ActorContextRef, ExecutorInfo, Message, MessageStream}; -use crate::task::ActorId; /// Streams wrapped by `trace` will be traced with `tracing` spans and reported to `opentelemetry`. #[try_stream(ok = Message, error = StreamExecutorError)] @@ -34,12 +33,10 @@ pub async fn trace( let actor_id_str = actor_ctx.id.to_string(); let fragment_id_str = actor_ctx.fragment_id.to_string(); - let span_name = pretty_identity(&info.identity, actor_ctx.id); - let new_span = || { tracing::info_span!( "executor", - "otel.name" = span_name, + "otel.name" = info.identity, "message" = tracing::field::Empty, // record later "chunk_size" = tracing::field::Empty, // record later ) @@ -104,21 +101,13 @@ pub async fn trace( } } -fn pretty_identity(identity: &str, actor_id: ActorId) -> String { - format!("{} (actor {})", identity, actor_id) -} - /// Streams wrapped by `instrument_await_tree` will be able to print the spans of the /// executors in the stack trace through `await-tree`. 
#[try_stream(ok = Message, error = StreamExecutorError)] -pub async fn instrument_await_tree( - info: Arc<ExecutorInfo>, - actor_id: ActorId, - input: impl MessageStream, -) { +pub async fn instrument_await_tree(info: Arc<ExecutorInfo>, input: impl MessageStream) { pin_mut!(input); - let span: await_tree::Span = pretty_identity(&info.identity, actor_id).into(); + let span: await_tree::Span = info.identity.clone().into(); while let Some(message) = input .next() diff --git a/src/stream/src/from_proto/source/trad_source.rs b/src/stream/src/from_proto/source/trad_source.rs index 02fb68b1446c9..53661d87b20cc 100644 --- a/src/stream/src/from_proto/source/trad_source.rs +++ b/src/stream/src/from_proto/source/trad_source.rs @@ -19,8 +19,9 @@ use risingwave_common::catalog::{ }; use risingwave_connector::source::reader::desc::SourceDescBuilder; use risingwave_connector::source::{ - should_copy_to_format_encode_options, ConnectorProperties, SourceCtrlOpts, UPSTREAM_SOURCE_KEY, + should_copy_to_format_encode_options, SourceCtrlOpts, UPSTREAM_SOURCE_KEY, }; +use risingwave_connector::WithPropertiesExt; use risingwave_pb::catalog::PbStreamSourceInfo; use risingwave_pb::data::data_type::TypeName as PbTypeName; use risingwave_pb::plan_common::additional_column::ColumnType as AdditionalColumnType; @@ -208,8 +209,7 @@ impl ExecutorBuilder for SourceExecutorBuilder { .map(|c| c.to_ascii_lowercase()) .unwrap_or_default(); let is_fs_connector = FS_CONNECTORS.contains(&connector.as_str()); - let is_fs_v2_connector = - ConnectorProperties::is_new_fs_connector_hash_map(&source.with_properties); + let is_fs_v2_connector = source.with_properties.is_new_fs_connector(); if is_fs_connector { #[expect(deprecated)] diff --git a/src/stream/src/task/barrier_manager.rs b/src/stream/src/task/barrier_manager.rs index edbd660690049..6fef59b6740d1 100644 --- a/src/stream/src/task/barrier_manager.rs +++ b/src/stream/src/task/barrier_manager.rs @@ -13,20 +13,25 @@ // limitations under the License.
use std::collections::{HashMap, HashSet}; +use std::future::pending; use std::sync::Arc; use std::time::Duration; use anyhow::{anyhow, Context}; -use futures::stream::FuturesUnordered; +use futures::stream::{BoxStream, FuturesUnordered}; use futures::StreamExt; +use itertools::Itertools; use parking_lot::Mutex; -use risingwave_pb::stream_service::barrier_complete_response::PbCreateMviewProgress; +use risingwave_pb::stream_service::barrier_complete_response::{ + GroupedSstableInfo, PbCreateMviewProgress, +}; use rw_futures_util::{pending_on_none, AttachedFuture}; use thiserror_ext::AsReport; use tokio::select; use tokio::sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender}; use tokio::sync::oneshot; use tokio::task::JoinHandle; +use tonic::Status; use self::managed_state::ManagedBarrierState; use crate::error::{IntoUnexpectedExit, StreamError, StreamResult}; @@ -41,9 +46,17 @@ mod tests; pub use progress::CreateMviewProgress; use risingwave_common::util::runtime::BackgroundShutdownRuntime; +use risingwave_hummock_sdk::table_stats::to_prost_table_stats_map; +use risingwave_hummock_sdk::LocalSstableInfo; use risingwave_pb::common::ActorInfo; use risingwave_pb::stream_plan; use risingwave_pb::stream_plan::barrier::BarrierKind; +use risingwave_pb::stream_service::streaming_control_stream_request::{InitRequest, Request}; +use risingwave_pb::stream_service::streaming_control_stream_response::InitResponse; +use risingwave_pb::stream_service::{ + streaming_control_stream_response, BarrierCompleteResponse, StreamingControlStreamRequest, + StreamingControlStreamResponse, +}; use risingwave_storage::store::SyncResult; use crate::executor::exchange::permit::Receiver; @@ -65,6 +78,71 @@ pub struct BarrierCompleteResult { pub create_mview_progress: Vec<PbCreateMviewProgress>, } +pub(super) struct ControlStreamHandle { + #[expect(clippy::type_complexity)] + pair: Option<( + UnboundedSender<Result<StreamingControlStreamResponse, Status>>, + BoxStream<'static, Result<StreamingControlStreamRequest, Status>>, + )>, +} + +impl ControlStreamHandle { + fn empty() -> Self { + Self { pair: None } + } + + pub(super) fn new( + sender: UnboundedSender<Result<StreamingControlStreamResponse, Status>>, + request_stream: BoxStream<'static, Result<StreamingControlStreamRequest, Status>>, + ) -> Self { + Self { + pair: Some((sender, request_stream)), + } + } + + fn reset_stream_with_err(&mut self, err: Status) { + if let Some((sender, _)) = self.pair.take() { + warn!("control stream reset with: {:?}", err.as_report()); + if sender.send(Err(err)).is_err() { + warn!("failed to notify finish of control stream"); + } + } + } + + fn inspect_result(&mut self, result: StreamResult<()>) { + if let Err(e) = result { + self.reset_stream_with_err(Status::internal(format!("get error: {:?}", e.as_report()))); + } + } + + fn send_response(&mut self, response: StreamingControlStreamResponse) { + if let Some((sender, _)) = self.pair.as_ref() { + if sender.send(Ok(response)).is_err() { + self.pair = None; + warn!("fail to send response. control stream reset"); + } + } else { + debug!(?response, "control stream has been reset.
ignore response"); + } + } + + async fn next_request(&mut self) -> StreamingControlStreamRequest { + if let Some((_, stream)) = &mut self.pair { + match stream.next().await { + Some(Ok(request)) => { + return request; + } + Some(Err(e)) => self.reset_stream_with_err(Status::internal(format!( + "failed to get request: {:?}", + e.as_report() + ))), + None => self.reset_stream_with_err(Status::internal("end of stream")), + } + } + pending().await + } +} + pub(super) enum LocalBarrierEvent { RegisterSender { actor_id: ActorId, @@ -84,19 +162,9 @@ pub(super) enum LocalBarrierEvent { } pub(super) enum LocalActorOperation { - InjectBarrier { - barrier: Barrier, - actor_ids_to_send: HashSet, - actor_ids_to_collect: HashSet, - result_sender: oneshot::Sender>, - }, - Reset { - prev_epoch: u64, - result_sender: oneshot::Sender<()>, - }, - AwaitEpochCompleted { - epoch: u64, - result_sender: oneshot::Sender>, + NewControlStream { + handle: ControlStreamHandle, + init_request: InitRequest, }, DropActors { actors: Vec, @@ -194,7 +262,7 @@ pub(super) struct LocalBarrierWorker { /// Record all unexpected exited actors. failure_actors: HashMap, - epoch_result_sender: HashMap>>, + control_stream_handle: ControlStreamHandle, pub(super) actor_manager: Arc, @@ -228,7 +296,7 @@ impl LocalBarrierWorker { actor_manager.env.state_store(), actor_manager.streaming_metrics.clone(), ), - epoch_result_sender: HashMap::default(), + control_stream_handle: ControlStreamHandle::empty(), actor_manager, actor_manager_state: StreamActorManagerState::new(), current_shared_context: shared_context, @@ -246,7 +314,8 @@ impl LocalBarrierWorker { self.handle_actor_created(sender, create_actors_result); } completed_epoch = self.state.next_completed_epoch() => { - self.on_epoch_completed(completed_epoch); + let result = self.on_epoch_completed(completed_epoch); + self.control_stream_handle.inspect_result(result); }, // Note: it's important to select in a biased way to ensure that // barrier event is handled before actor_op, because we must ensure @@ -261,10 +330,13 @@ impl LocalBarrierWorker { actor_op = actor_op_rx.recv() => { if let Some(actor_op) = actor_op { match actor_op { - LocalActorOperation::Reset { - result_sender, prev_epoch} => { - self.reset(prev_epoch).await; - let _ = result_sender.send(()); + LocalActorOperation::NewControlStream { handle, init_request } => { + self.control_stream_handle.reset_stream_with_err(Status::internal("control stream has been reset to a new one")); + self.reset(init_request.prev_epoch).await; + self.control_stream_handle = handle; + self.control_stream_handle.send_response(StreamingControlStreamResponse { + response: Some(streaming_control_stream_response::Response::Init(InitResponse {})) + }); } actor_op => { self.handle_actor_op(actor_op); @@ -274,7 +346,11 @@ impl LocalBarrierWorker { else { break; } - } + }, + request = self.control_stream_handle.next_request() => { + let result = self.handle_streaming_control_request(request); + self.control_stream_handle.inspect_result(result); + }, } } } @@ -291,6 +367,26 @@ impl LocalBarrierWorker { let _ = sender.send(result); } + fn handle_streaming_control_request( + &mut self, + request: StreamingControlStreamRequest, + ) -> StreamResult<()> { + match request.request.expect("should not be empty") { + Request::InjectBarrier(req) => { + let barrier = Barrier::from_protobuf(req.get_barrier().unwrap())?; + self.send_barrier( + &barrier, + req.actor_ids_to_send.into_iter().collect(), + req.actor_ids_to_collect.into_iter().collect(), + )?; + Ok(()) + } 
+ Request::Init(_) => { + unreachable!() + } + } + } + fn handle_barrier_event(&mut self, event: LocalBarrierEvent) { match event { LocalBarrierEvent::RegisterSender { actor_id, sender } => { @@ -313,26 +409,8 @@ fn handle_actor_op(&mut self, actor_op: LocalActorOperation) { match actor_op { - LocalActorOperation::InjectBarrier { - barrier, - actor_ids_to_send, - actor_ids_to_collect, - result_sender, - } => { - let result = self.send_barrier(&barrier, actor_ids_to_send, actor_ids_to_collect); - let _ = result_sender.send(result).inspect_err(|e| { - warn!(err=?e, "fail to send inject barrier result"); - }); - } - LocalActorOperation::Reset { .. } => { - unreachable!("Reset event should be handled separately in async context") - } - - LocalActorOperation::AwaitEpochCompleted { - epoch, - result_sender, - } => { - self.await_epoch_completed(epoch, result_sender); + LocalActorOperation::NewControlStream { .. } => { + unreachable!("NewControlStream event should be handled separately in async context") } LocalActorOperation::DropActors { actors, @@ -371,17 +449,55 @@ // event handler impl LocalBarrierWorker { - fn on_epoch_completed(&mut self, epoch: u64) { - if let Some(sender) = self.epoch_result_sender.remove(&epoch) { - let result = self - .state - .pop_completed_epoch(epoch) - .expect("should exist") - .expect("should have completed"); - if sender.send(result).is_err() { - warn!(epoch, "fail to send epoch complete result"); - } - } + fn on_epoch_completed(&mut self, epoch: u64) -> StreamResult<()> { + let result = self + .state + .pop_completed_epoch(epoch) + .expect("should exist") + .expect("should have completed")?; + + let BarrierCompleteResult { + create_mview_progress, + sync_result, + } = result; + + let (synced_sstables, table_watermarks) = sync_result + .map(|sync_result| (sync_result.uncommitted_ssts, sync_result.table_watermarks)) + .unwrap_or_default(); + + let result = StreamingControlStreamResponse { + response: Some( + streaming_control_stream_response::Response::CompleteBarrier( + BarrierCompleteResponse { + request_id: "todo".to_string(), + status: None, + create_mview_progress, + synced_sstables: synced_sstables + .into_iter() + .map( + |LocalSstableInfo { + compaction_group_id, + sst_info, + table_stats, + }| GroupedSstableInfo { + compaction_group_id, + sst: Some(sst_info), + table_stats_map: to_prost_table_stats_map(table_stats), + }, + ) + .collect_vec(), + worker_id: self.actor_manager.env.worker_id(), + table_watermarks: table_watermarks + .into_iter() + .map(|(key, value)| (key.table_id, value.to_protobuf())) + .collect(), + }, + ), + ), + }; + + self.control_stream_handle.send_response(result); + Ok(()) } /// Register sender for source actors, used to send barriers. @@ -407,7 +523,6 @@ ) -> StreamResult<()> { #[cfg(not(test))] { - use itertools::Itertools; // The barrier might be outdated and have been injected after recovery in certain extreme // scenarios. So some newly creating actors in the barrier are possibly not rebuilt during // recovery. Check it here and return an error if some actors are not found to @@ -492,36 +607,6 @@ Ok(()) } - /// Use `prev_epoch` to remove collect rx and return rx.
- fn await_epoch_completed( - &mut self, - prev_epoch: u64, - result_sender: oneshot::Sender<StreamResult<BarrierCompleteResult>>, - ) { - match self.state.pop_completed_epoch(prev_epoch) { - Err(e) => { - let _ = result_sender.send(Err(e)); - } - Ok(Some(result)) => { - if result_sender.send(result).is_err() { - warn!(prev_epoch, "failed to send completed epoch result"); - } - } - Ok(None) => { - if let Some(prev_sender) = - self.epoch_result_sender.insert(prev_epoch, result_sender) - { - warn!(?prev_epoch, "duplicate await_collect_barrier on epoch"); - let _ = prev_sender.send(Err(anyhow!( - "duplicate await_collect_barrier on epoch {}", - prev_epoch - ) - .into())); - } - } - } - } - /// Reset all internal states. pub(super) fn reset_state(&mut self) { *self = Self::new(self.actor_manager.clone()); } @@ -538,12 +623,14 @@ async fn notify_failure(&mut self, actor_id: ActorId, err: StreamError) { self.add_failure(actor_id, err.clone()); let root_err = self.try_find_root_failure(err).await; - for fail_epoch in self.state.epochs_await_on_actor(actor_id) { - if let Some(result_sender) = self.epoch_result_sender.remove(&fail_epoch) { - if result_sender.send(Err(root_err.clone())).is_err() { - warn!(fail_epoch, actor_id, err = %root_err.as_report(), "fail to notify actor failure"); - } - } + let failed_epochs = self.state.epochs_await_on_actor(actor_id).collect_vec(); + if !failed_epochs.is_empty() { + self.control_stream_handle + .reset_stream_with_err(Status::internal(format!( + "failed to collect barrier. epoch: {:?}, err: {:?}", + failed_epochs, + root_err.as_report() + ))); } } @@ -648,40 +735,7 @@ impl LocalBarrierManager { pub fn register_sender(&self, actor_id: ActorId, sender: UnboundedSender<Barrier>) { self.send_event(LocalBarrierEvent::RegisterSender { actor_id, sender }); } -} -impl EventSender<LocalActorOperation> { - /// Broadcast a barrier to all senders. Save a receiver which will get notified when this - /// barrier is finished, in managed mode. - pub(super) async fn send_barrier( - &self, - barrier: Barrier, - actor_ids_to_send: impl IntoIterator<Item = ActorId>, - actor_ids_to_collect: impl IntoIterator<Item = ActorId>, - ) -> StreamResult<()> { - self.send_and_await(move |result_sender| LocalActorOperation::InjectBarrier { - barrier, - actor_ids_to_send: actor_ids_to_send.into_iter().collect(), - actor_ids_to_collect: actor_ids_to_collect.into_iter().collect(), - result_sender, - }) - .await? - } - - /// Use `prev_epoch` to remove collect rx and return rx. - pub(super) async fn await_epoch_completed( - &self, - prev_epoch: u64, - ) -> StreamResult<BarrierCompleteResult> { - self.send_and_await(|result_sender| LocalActorOperation::AwaitEpochCompleted { - epoch: prev_epoch, - result_sender, - }) - .await? - } -} - -impl LocalBarrierManager { /// When a [`crate::executor::StreamConsumer`] (typically [`crate::executor::DispatchExecutor`]) gets a barrier, it should report /// and collect this barrier with its own `actor_id` using this function.
pub fn collect(&self, actor_id: ActorId, barrier: &Barrier) { @@ -727,7 +781,7 @@ pub fn try_find_root_actor_failure<'a>( #[cfg(test)] impl LocalBarrierManager { - pub(super) async fn spawn_for_test() -> (EventSender<LocalActorOperation>, Self) { + pub(super) fn spawn_for_test() -> EventSender<LocalActorOperation> { use std::sync::atomic::AtomicU64; let (tx, rx) = unbounded_channel(); let _join_handle = LocalBarrierWorker::spawn( @@ -737,13 +791,7 @@ Arc::new(AtomicU64::new(0)), rx, ); - let sender = EventSender(tx); - let context = sender - .send_and_await(LocalActorOperation::GetCurrentSharedContext) - .await - .unwrap(); - - (sender, context.local_barrier_manager.clone()) + EventSender(tx) } pub fn for_test() -> Self { diff --git a/src/stream/src/task/barrier_manager/tests.rs b/src/stream/src/task/barrier_manager/tests.rs index ae248facc0e5f..2ec421661a14a 100644 --- a/src/stream/src/task/barrier_manager/tests.rs +++ b/src/stream/src/task/barrier_manager/tests.rs @@ -17,15 +17,43 @@ use std::iter::once; use std::pin::pin; use std::task::Poll; +use assert_matches::assert_matches; +use futures::future::join_all; +use futures::FutureExt; use itertools::Itertools; use risingwave_common::util::epoch::test_epoch; +use risingwave_pb::stream_service::{streaming_control_stream_request, InjectBarrierRequest}; use tokio::sync::mpsc::unbounded_channel; +use tokio_stream::wrappers::UnboundedReceiverStream; use super::*; #[tokio::test] async fn test_managed_barrier_collection() -> StreamResult<()> { - let (actor_op_tx, manager) = LocalBarrierManager::spawn_for_test().await; + let actor_op_tx = LocalBarrierManager::spawn_for_test(); + + let (request_tx, request_rx) = unbounded_channel(); + let (response_tx, mut response_rx) = unbounded_channel(); + + actor_op_tx.send_event(LocalActorOperation::NewControlStream { + handle: ControlStreamHandle::new( + response_tx, + UnboundedReceiverStream::new(request_rx).boxed(), + ), + init_request: InitRequest { prev_epoch: 0 }, + }); + + assert_matches!( + response_rx.recv().await.unwrap().unwrap().response.unwrap(), + streaming_control_stream_response::Response::Init(_) + ); + + let context = actor_op_tx + .send_and_await(LocalActorOperation::GetCurrentSharedContext) + .await + .unwrap(); + + let manager = &context.local_barrier_manager; let register_sender = |actor_id: u32| { let (barrier_tx, barrier_rx) = unbounded_channel(); @@ -47,21 +75,35 @@ async fn test_managed_barrier_collection() -> StreamResult<()> { let barrier = Barrier::new_test_barrier(curr_epoch); let epoch = barrier.epoch.prev; - actor_op_tx - .send_barrier(barrier.clone(), actor_ids.clone(), actor_ids) - .await + request_tx + .send(Ok(StreamingControlStreamRequest { + request: Some(streaming_control_stream_request::Request::InjectBarrier( + InjectBarrierRequest { + request_id: "".to_string(), + barrier: Some(barrier.to_protobuf()), + actor_ids_to_send: actor_ids.clone(), + actor_ids_to_collect: actor_ids, + }, + )), + })) .unwrap(); - // Collect barriers from actors - let collected_barriers = rxs - .iter_mut() - .map(|(actor_id, rx)| { - let barrier = rx.try_recv().unwrap(); - assert_eq!(barrier.epoch.prev, epoch); - (*actor_id, barrier) - }) - .collect_vec(); - let mut await_epoch_future = pin!(actor_op_tx.await_epoch_completed(epoch)); + // Collect barriers from actors + let collected_barriers = join_all(rxs.iter_mut().map(|(actor_id, rx)| async move { + let barrier = rx.recv().await.unwrap(); + assert_eq!(barrier.epoch.prev, epoch); + (*actor_id, barrier) + })) + .await; + + let mut await_epoch_future =
pin!(response_rx.recv().map(|result| { + let resp: StreamingControlStreamResponse = result.unwrap().unwrap(); + let resp = resp.response.unwrap(); + match resp { + streaming_control_stream_response::Response::CompleteBarrier(_complete_barrier) => {} + _ => unreachable!(), + } + })); // Report to local barrier manager for (i, (actor_id, barrier)) in collected_barriers.into_iter().enumerate() { @@ -77,7 +119,30 @@ async fn test_managed_barrier_collection() -> StreamResult<()> { #[tokio::test] async fn test_managed_barrier_collection_before_send_request() -> StreamResult<()> { - let (actor_op_tx, manager) = LocalBarrierManager::spawn_for_test().await; + let actor_op_tx = LocalBarrierManager::spawn_for_test(); + + let (request_tx, request_rx) = unbounded_channel(); + let (response_tx, mut response_rx) = unbounded_channel(); + + actor_op_tx.send_event(LocalActorOperation::NewControlStream { + handle: ControlStreamHandle::new( + response_tx, + UnboundedReceiverStream::new(request_rx).boxed(), + ), + init_request: InitRequest { prev_epoch: 0 }, + }); + + assert_matches!( + response_rx.recv().await.unwrap().unwrap().response.unwrap(), + streaming_control_stream_response::Response::Init(_) + ); + + let context = actor_op_tx + .send_and_await(LocalActorOperation::GetCurrentSharedContext) + .await + .unwrap(); + + let manager = &context.local_barrier_manager; let register_sender = |actor_id: u32| { let (barrier_tx, barrier_rx) = unbounded_channel(); @@ -109,23 +174,35 @@ async fn test_managed_barrier_collection_before_send_request() -> StreamResult<( // Collect a barrier before sending manager.collect(extra_actor_id, &barrier); - // Send the barrier to all actors - actor_op_tx - .send_barrier(barrier.clone(), actor_ids_to_send, actor_ids_to_collect) - .await + request_tx + .send(Ok(StreamingControlStreamRequest { + request: Some(streaming_control_stream_request::Request::InjectBarrier( + InjectBarrierRequest { + request_id: "".to_string(), + barrier: Some(barrier.to_protobuf()), + actor_ids_to_send, + actor_ids_to_collect, + }, + )), + })) .unwrap(); // Collect barriers from actors - let collected_barriers = rxs - .iter_mut() - .map(|(actor_id, rx)| { - let barrier = rx.try_recv().unwrap(); - assert_eq!(barrier.epoch.prev, epoch); - (*actor_id, barrier) - }) - .collect_vec(); - - let mut await_epoch_future = pin!(actor_op_tx.await_epoch_completed(epoch)); + let collected_barriers = join_all(rxs.iter_mut().map(|(actor_id, rx)| async move { + let barrier = rx.recv().await.unwrap(); + assert_eq!(barrier.epoch.prev, epoch); + (*actor_id, barrier) + })) + .await; + + let mut await_epoch_future = pin!(response_rx.recv().map(|result| { + let resp: StreamingControlStreamResponse = result.unwrap().unwrap(); + let resp = resp.response.unwrap(); + match resp { + streaming_control_stream_response::Response::CompleteBarrier(_complete_barrier) => {} + _ => unreachable!(), + } + })); // Report to local barrier manager for (i, (actor_id, barrier)) in collected_barriers.into_iter().enumerate() { diff --git a/src/stream/src/task/stream_manager.rs b/src/stream/src/task/stream_manager.rs index 602a6f49e2385..9045adc1263ce 100644 --- a/src/stream/src/task/stream_manager.rs +++ b/src/stream/src/task/stream_manager.rs @@ -22,6 +22,7 @@ use std::sync::Arc; use anyhow::anyhow; use async_recursion::async_recursion; +use futures::stream::BoxStream; use futures::FutureExt; use itertools::Itertools; use parking_lot::Mutex; @@ -33,25 +34,32 @@ use risingwave_pb::common::ActorInfo; use risingwave_pb::stream_plan; use 
risingwave_pb::stream_plan::stream_node::NodeBody; use risingwave_pb::stream_plan::{StreamActor, StreamNode}; +use risingwave_pb::stream_service::streaming_control_stream_request::InitRequest; +use risingwave_pb::stream_service::{ + StreamingControlStreamRequest, StreamingControlStreamResponse, +}; use risingwave_storage::monitor::HummockTraceFutureExt; use risingwave_storage::{dispatch_state_store, StateStore}; use rw_futures_util::AttachedFuture; use thiserror_ext::AsReport; -use tokio::sync::mpsc::unbounded_channel; +use tokio::sync::mpsc::{unbounded_channel, UnboundedSender}; use tokio::sync::oneshot; use tokio::task::JoinHandle; +use tonic::Status; -use super::{unique_executor_id, unique_operator_id, BarrierCompleteResult}; +use super::{unique_executor_id, unique_operator_id}; use crate::error::StreamResult; use crate::executor::exchange::permit::Receiver; use crate::executor::monitor::StreamingMetrics; use crate::executor::subtask::SubtaskHandle; use crate::executor::{ - Actor, ActorContext, ActorContextRef, Barrier, DispatchExecutor, DispatcherImpl, Executor, - ExecutorInfo, WrapperExecutor, + Actor, ActorContext, ActorContextRef, DispatchExecutor, DispatcherImpl, Executor, ExecutorInfo, + WrapperExecutor, }; use crate::from_proto::create_executor; -use crate::task::barrier_manager::{EventSender, LocalActorOperation, LocalBarrierWorker}; +use crate::task::barrier_manager::{ + ControlStreamHandle, EventSender, LocalActorOperation, LocalBarrierWorker, +}; use crate::task::{ ActorId, FragmentId, LocalBarrierManager, SharedContext, StreamActorManager, StreamActorManagerState, StreamEnvironment, UpDownActorIds, @@ -199,22 +207,19 @@ impl LocalStreamManager { } } - /// Broadcast a barrier to all senders. Save a receiver in barrier manager - pub async fn send_barrier( + /// Receive a new control stream request from meta. Notify the barrier worker to reset the CN and use the new control stream + /// to receive control messages from meta. + pub fn handle_new_control_stream( &self, - barrier: Barrier, - actor_ids_to_send: impl IntoIterator<Item = ActorId>, - actor_ids_to_collect: impl IntoIterator<Item = ActorId>, - ) -> StreamResult<()> { + sender: UnboundedSender<Result<StreamingControlStreamResponse, Status>>, + request_stream: BoxStream<'static, Result<StreamingControlStreamRequest, Status>>, + init_request: InitRequest, + ) { self.actor_op_tx - .send_barrier(barrier, actor_ids_to_send, actor_ids_to_collect) - .await - } - - /// Use `epoch` to find collect rx. And wait for all actors to be collected before - /// returning. - pub async fn collect_barrier(&self, prev_epoch: u64) -> StreamResult<BarrierCompleteResult> { - self.actor_op_tx.await_epoch_completed(prev_epoch).await + .send_event(LocalActorOperation::NewControlStream { + handle: ControlStreamHandle::new(sender, request_stream), + init_request, + }) } /// Drop the resources of the given actors. @@ -227,17 +232,6 @@ impl LocalStreamManager { .await } - /// Force stop all actors on this worker, and then drop their resources.
- pub async fn reset(&self, prev_epoch: u64) { - self.actor_op_tx - .send_and_await(|result_sender| LocalActorOperation::Reset { - result_sender, - prev_epoch, - }) - .await - .expect("should receive reset") - } - pub async fn update_actors(&self, actors: Vec<StreamActor>) -> StreamResult<()> { self.actor_op_tx .send_and_await(|result_sender| LocalActorOperation::UpdateActors { diff --git a/src/tests/compaction_test/src/compaction_test_runner.rs b/src/tests/compaction_test/src/compaction_test_runner.rs index 71d3f5e5a80af..9132c0e0735ab 100644 --- a/src/tests/compaction_test/src/compaction_test_runner.rs +++ b/src/tests/compaction_test/src/compaction_test_runner.rs @@ -24,7 +24,6 @@ use anyhow::anyhow; use bytes::{BufMut, Bytes, BytesMut}; use clap::Parser; use foyer::memory::CacheContext; -use futures::TryStreamExt; use risingwave_common::catalog::TableId; use risingwave_common::config::{ extract_storage_memory_config, load_config, MetaConfig, NoOverride, @@ -44,7 +43,7 @@ use risingwave_storage::monitor::{ }; use risingwave_storage::opts::StorageOpts; use risingwave_storage::store::{ReadOptions, StateStoreRead}; -use risingwave_storage::{StateStore, StateStoreImpl}; +use risingwave_storage::{StateStore, StateStoreImpl, StateStoreIter}; const SST_ID_SHIFT_COUNT: u32 = 1000000; const CHECKPOINT_FREQ_FOR_REPLAY: u64 = 99999999; @@ -603,8 +602,7 @@ async fn poll_compaction_tasks_status( (compaction_ok, cur_version) } -type StateStoreIterType = - Pin<Box<<MonitoredStateStore<HummockStorage> as StateStoreRead>::IterStream>>; +type StateStoreIterType = Pin<Box<<MonitoredStateStore<HummockStorage> as StateStoreRead>::Iter>>; async fn open_hummock_iters( hummock: &MonitoredStateStore<HummockStorage>, @@ -661,8 +659,6 @@ pub async fn check_compaction_results( let mut expect_cnt = 0; let mut actual_cnt = 0; - futures::pin_mut!(expect_iter); - futures::pin_mut!(actual_iter); while let Some(kv_expect) = expect_iter.try_next().await?
{ expect_cnt += 1; let ret = actual_iter.try_next().await?; diff --git a/src/tests/compaction_test/src/delete_range_runner.rs b/src/tests/compaction_test/src/delete_range_runner.rs index 6986300699437..46052015ac688 100644 --- a/src/tests/compaction_test/src/delete_range_runner.rs +++ b/src/tests/compaction_test/src/delete_range_runner.rs @@ -21,7 +21,6 @@ use std::time::{Duration, SystemTime}; use bytes::Bytes; use foyer::memory::CacheContext; -use futures::StreamExt; use rand::rngs::StdRng; use rand::{RngCore, SeedableRng}; use risingwave_common::catalog::TableId; @@ -61,7 +60,7 @@ use risingwave_storage::opts::StorageOpts; use risingwave_storage::store::{ LocalStateStore, NewLocalOptions, PrefetchOptions, ReadOptions, SealCurrentEpochOptions, }; -use risingwave_storage::StateStore; +use risingwave_storage::{StateStore, StateStoreIter}; use crate::CompactionTestOpts; pub fn start_delete_range(opts: CompactionTestOpts) -> Pin<Box<dyn Future<Output = ()> + Send>> { @@ -473,10 +472,10 @@ impl NormalState { .await .unwrap(),); let mut ret = vec![]; - while let Some(item) = iter.next().await { - let (full_key, val) = item.unwrap(); - let tkey = full_key.user_key.table_key.0.clone(); - ret.push((tkey, val)); + while let Some(item) = iter.try_next().await.unwrap() { + let (full_key, val) = item; + let tkey = Bytes::copy_from_slice(full_key.user_key.table_key.0); + ret.push((tkey, Bytes::copy_from_slice(val))); } ret } @@ -485,29 +484,31 @@ impl NormalState { #[async_trait::async_trait] impl CheckState for NormalState { async fn delete_range(&mut self, left: &[u8], right: &[u8]) { - let mut iter = Box::pin( - self.storage - .iter( - ( - Bound::Included(Bytes::copy_from_slice(left)).map(TableKey), - Bound::Excluded(Bytes::copy_from_slice(right)).map(TableKey), - ), - ReadOptions { - ignore_range_tombstone: true, - table_id: self.table_id, - read_version_from_backup: false, - prefetch_options: PrefetchOptions::default(), - cache_policy: CachePolicy::Fill(CacheContext::Default), - ..Default::default() - }, - ) - .await - .unwrap(), - ); + let mut iter = self + .storage + .iter( + ( + Bound::Included(Bytes::copy_from_slice(left)).map(TableKey), + Bound::Excluded(Bytes::copy_from_slice(right)).map(TableKey), + ), + ReadOptions { + ignore_range_tombstone: true, + table_id: self.table_id, + read_version_from_backup: false, + prefetch_options: PrefetchOptions::default(), + cache_policy: CachePolicy::Fill(CacheContext::Default), + ..Default::default() + }, + ) + .await + .unwrap(); let mut delete_item = Vec::new(); - while let Some(item) = iter.next().await { - let (full_key, value) = item.unwrap(); - delete_item.push((full_key.user_key.table_key, value)); + while let Some(item) = iter.try_next().await.unwrap() { + let (full_key, value) = item; + delete_item.push(( + full_key.user_key.table_key.copy_into(), + Bytes::copy_from_slice(value), + )); } drop(iter); for (key, value) in delete_item { diff --git a/src/tests/simulation/src/slt.rs b/src/tests/simulation/src/slt.rs index d44d69219bbaf..847c7c60c7cd2 100644 --- a/src/tests/simulation/src/slt.rs +++ b/src/tests/simulation/src/slt.rs @@ -26,6 +26,9 @@ use crate::client::RisingWave; use crate::cluster::{Cluster, KillOpts}; use crate::utils::TimedExt; +// retry a maximum number of times until it succeeds +const MAX_RETRY: usize = 5; + fn is_create_table_as(sql: &str) -> bool { let parts: Vec<String> = sql.split_whitespace().map(|s| s.to_lowercase()).collect(); @@ -271,10 +274,13 @@ pub async fn run_slt_task( }) .await { + let err_string = err.to_string(); // cluster could still be recovering if
killed before, retry if // meets `no reader for dml in table with id {}`. - let should_retry = - err.to_string().contains("no reader for dml in table") && i < 5; + let should_retry = (err_string.contains("no reader for dml in table") + || err_string + .contains("error reading a body from connection: broken pipe") + || err_string.contains("failed to inject barrier")) + && i < MAX_RETRY; if !should_retry { panic!("{}", err); } @@ -302,8 +308,6 @@ pub async fn run_slt_task( None }; - // retry up to 5 times until it succeed - let max_retry = 5; for i in 0usize.. { tracing::debug!(iteration = i, "retry count"); let delay = Duration::from_secs(1 << i); @@ -348,7 +352,7 @@ pub async fn run_slt_task( ?err, "failed to wait for background mv to finish creating" ); - if i >= max_retry { + if i >= MAX_RETRY { panic!("failed to run test after retry {i} times, error={err:#?}"); } continue; @@ -379,8 +383,8 @@ pub async fn run_slt_task( break } - // Keep i >= max_retry for other errors. Since these errors indicate that the MV might not yet be created. - _ if i >= MAX_RETRY => { + // Keep the `i >= MAX_RETRY` guard for other errors, since these errors indicate that the MV might not yet be created. + _ if i >= MAX_RETRY => { panic!("failed to run test after retry {i} times: {e}") } SqlCmd::CreateMaterializedView { ref name } @@ -404,7 +408,7 @@ pub async fn run_slt_task( ?err, "failed to wait for background mv to finish creating" ); - if i >= max_retry { + if i >= MAX_RETRY { panic!("failed to run test after retry {i} times, error={err:#?}"); } continue;
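Editor's note (not part of the patch): the slt.rs change above replaces the local `max_retry` with a shared `MAX_RETRY` constant and widens the set of retryable errors. The underlying pattern is bounded retry with exponential backoff: after attempt `i` it sleeps `1 << i` seconds and gives up once `MAX_RETRY` attempts have failed. A minimal, self-contained Rust sketch of that pattern follows, assuming only the tokio runtime; `flaky_op` is a hypothetical stand-in for executing one SLT command.

use std::time::Duration;

const MAX_RETRY: usize = 5;

/// Run `flaky_op`, retrying with exponential backoff (1s, 2s, 4s, ...)
/// and giving up after `MAX_RETRY` failed attempts.
async fn run_with_retry<F, Fut, T, E>(mut flaky_op: F) -> Result<T, E>
where
    F: FnMut() -> Fut,
    Fut: std::future::Future<Output = Result<T, E>>,
{
    let mut attempt = 0usize;
    loop {
        match flaky_op().await {
            Ok(v) => return Ok(v),
            // Retry budget exhausted: surface the last error, analogous to
            // the `i >= MAX_RETRY` arms in run_slt_task that panic.
            Err(e) if attempt >= MAX_RETRY => return Err(e),
            Err(_) => {
                // Exponential backoff, mirroring `Duration::from_secs(1 << i)` above.
                tokio::time::sleep(Duration::from_secs(1 << attempt)).await;
                attempt += 1;
            }
        }
    }
}

As in run_slt_task, only errors known to be transient (cluster still recovering, broken connections, failed barrier injection) should be retried this way; any other error is surfaced immediately.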