diff --git a/Cargo.lock b/Cargo.lock index fb7a1d3799621..44e04444fa03d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -10117,6 +10117,7 @@ dependencies = [ "assert_matches", "async-recursion", "async-trait", + "bytes", "criterion", "either", "foyer", @@ -10130,9 +10131,11 @@ dependencies = [ "madsim-tokio", "madsim-tonic", "memcomparable", + "opendal", "parking_lot 0.12.1", "paste", "prometheus", + "prost 0.12.1", "rand", "risingwave_common", "risingwave_common_estimate_size", @@ -10155,6 +10158,8 @@ dependencies = [ "tokio-stream", "tokio-util", "tracing", + "twox-hash", + "uuid", "workspace-hack", ] diff --git a/ci/scripts/run-backfill-tests.sh b/ci/scripts/run-backfill-tests.sh index b641065853337..ac552cfcdcdd0 100755 --- a/ci/scripts/run-backfill-tests.sh +++ b/ci/scripts/run-backfill-tests.sh @@ -23,7 +23,7 @@ TEST_DIR=$PWD/e2e_test BACKGROUND_DDL_DIR=$TEST_DIR/background_ddl COMMON_DIR=$BACKGROUND_DDL_DIR/common -CLUSTER_PROFILE='ci-1cn-1fe-kafka-with-recovery' +CLUSTER_PROFILE='ci-1cn-1fe-user-kafka-with-recovery' echo "--- Configuring cluster profiles" if [[ -n "${BUILDKITE:-}" ]]; then echo "Running in buildkite" @@ -187,14 +187,14 @@ test_sink_backfill_recovery() { # Restart restart_cluster - sleep 3 + sleep 5 # Sink back into rw run_sql "CREATE TABLE table_kafka (v1 int primary key) WITH ( connector = 'kafka', topic = 's_kafka', - properties.bootstrap.server = 'localhost:29092', + properties.bootstrap.server = 'message_queue:29092', ) FORMAT DEBEZIUM ENCODE JSON;" sleep 10 diff --git a/ci/workflows/main-cron.yml b/ci/workflows/main-cron.yml index e097d2d587994..14f3a23161c80 100644 --- a/ci/workflows/main-cron.yml +++ b/ci/workflows/main-cron.yml @@ -704,7 +704,7 @@ steps: - "build" plugins: - docker-compose#v5.1.0: - run: rw-build-env + run: source-test-env config: ci/docker-compose.yml mount-buildkite-agent: true - ./ci/plugins/upload-failure-logs diff --git a/ci/workflows/pull-request.yml b/ci/workflows/pull-request.yml index 9ced38891a75b..3a4e87307231d 100644 --- a/ci/workflows/pull-request.yml +++ b/ci/workflows/pull-request.yml @@ -700,7 +700,7 @@ steps: - "build" plugins: - docker-compose#v5.1.0: - run: rw-build-env + run: source-test-env config: ci/docker-compose.yml mount-buildkite-agent: true - ./ci/plugins/upload-failure-logs diff --git a/e2e_test/backfill/sink/create_sink.slt b/e2e_test/backfill/sink/create_sink.slt index bc9fba04da5c8..017eb8e693de2 100644 --- a/e2e_test/backfill/sink/create_sink.slt +++ b/e2e_test/backfill/sink/create_sink.slt @@ -20,7 +20,7 @@ from t x join t y on x.v1 = y.v1 with ( connector='kafka', - properties.bootstrap.server='localhost:29092', + properties.bootstrap.server='message_queue:29092', topic='s_kafka', primary_key='v1', allow.auto.create.topics=true, diff --git a/e2e_test/error_ui/extended/main.slt b/e2e_test/error_ui/extended/main.slt index 67db6ccf3393f..eb6669ce89d71 100644 --- a/e2e_test/error_ui/extended/main.slt +++ b/e2e_test/error_ui/extended/main.slt @@ -4,8 +4,9 @@ selet 1; db error: ERROR: Failed to prepare the statement Caused by: - sql parser error: Expected an SQL statement, found: selet at line:1, column:6 -Near "selet" + sql parser error: expected an SQL statement, found: selet at line 1, column 1 +LINE 1: selet 1; + ^ query error diff --git a/e2e_test/error_ui/simple/main.slt b/e2e_test/error_ui/simple/main.slt index f77aa3dd9dd6d..c569560af631a 100644 --- a/e2e_test/error_ui/simple/main.slt +++ b/e2e_test/error_ui/simple/main.slt @@ -4,8 +4,9 @@ selet 1; db error: ERROR: Failed to run the query Caused by: - sql parser error: Expected an SQL statement, found: selet at line:1, column:6 -Near "selet" + sql parser error: expected an SQL statement, found: selet at line 1, column 1 +LINE 1: selet 1; + ^ statement error diff --git a/e2e_test/source/basic/datagen.slt b/e2e_test/source/basic/datagen.slt index 91c51f624a1ea..d89850538a3fd 100644 --- a/e2e_test/source/basic/datagen.slt +++ b/e2e_test/source/basic/datagen.slt @@ -186,9 +186,9 @@ statement ok drop table s1; # Do NOT allow With clause to contain a comma only. -statement error Expected identifier.* +statement error expected identifier.* create table s1 (v1 int) with (,) FORMAT PLAIN ENCODE JSON; # Do NOT allow an empty With clause. -statement error Expected identifier.* +statement error expected identifier.* create table s1 (v1 int) with () FORMAT PLAIN ENCODE JSON; diff --git a/e2e_test/source/basic/ddl.slt b/e2e_test/source/basic/ddl.slt index 402cf129b86ba..33b79dfda9b67 100644 --- a/e2e_test/source/basic/ddl.slt +++ b/e2e_test/source/basic/ddl.slt @@ -4,8 +4,9 @@ create source s; db error: ERROR: Failed to run the query Caused by: - sql parser error: Expected description of the format, found: ; at line:1, column:17 -Near "create source s" + sql parser error: expected description of the format, found: ; at line 1, column 16 +LINE 1: create source s; + ^ statement error missing WITH clause diff --git a/e2e_test/source/basic/old_row_format_syntax/datagen.slt b/e2e_test/source/basic/old_row_format_syntax/datagen.slt index 267ae8eff4c66..6467b624d0dbb 100644 --- a/e2e_test/source/basic/old_row_format_syntax/datagen.slt +++ b/e2e_test/source/basic/old_row_format_syntax/datagen.slt @@ -182,9 +182,9 @@ statement ok drop table s1; # Do NOT allow With clause to contain a comma only. -statement error Expected identifier.* +statement error expected identifier.* create table s1 (v1 int) with (,) ROW FORMAT JSON; # Do NOT allow an empty With clause. -statement error Expected identifier.* +statement error expected identifier.* create table s1 (v1 int) with () ROW FORMAT JSON; diff --git a/e2e_test/source/cdc/cdc.share_stream.slt b/e2e_test/source/cdc/cdc.share_stream.slt index e07a0c1d773ef..d30d9c53dc6fe 100644 --- a/e2e_test/source/cdc/cdc.share_stream.slt +++ b/e2e_test/source/cdc/cdc.share_stream.slt @@ -27,10 +27,9 @@ statement error Should not create MATERIALIZED VIEW or SELECT directly on shared create materialized view mv as select * from mysql_mytest; statement error The upstream table name must contain database name prefix* -create table products_test ( id INT, +create table products_test ( id INT PRIMARY KEY, name STRING, - description STRING, - PRIMARY KEY (id) + description STRING ) from mysql_mytest table 'products'; statement ok @@ -233,12 +232,11 @@ CREATE TABLE IF NOT EXISTS postgres_all_types( statement error The upstream table name must contain schema name prefix* CREATE TABLE person_new ( - id int, + id int PRIMARY KEY, name varchar, email_address varchar, credit_card varchar, - city varchar, - PRIMARY KEY (id) + city varchar ) FROM pg_source TABLE 'person'; statement ok diff --git a/e2e_test/udf/sql_udf.slt b/e2e_test/udf/sql_udf.slt index 866f27abd52ce..5f32bb6f7e024 100644 --- a/e2e_test/udf/sql_udf.slt +++ b/e2e_test/udf/sql_udf.slt @@ -360,7 +360,7 @@ In SQL UDF definition: `select a + b + c + not_be_displayed(c)` ^ -statement error Expected end of statement, found: 💩 +statement error expected end of statement, found: 💩 create function call_regexp_replace() returns varchar language sql as 'select regexp_replace('💩💩💩💩💩foo🤔️bar亲爱的😭baz这不是爱情❤️‍🔥', 'baz(...)', '这是🥵', 'ic')'; # Recursive definition can NOT be accepted at present due to semantic check @@ -401,7 +401,7 @@ statement error return type mismatch detected create function type_mismatch(INT) returns varchar language sql as 'select $1 + 114514 + $1'; # Invalid function body syntax -statement error Expected an expression:, found: EOF at the end +statement error expected an expression:, found: EOF at the end create function add_error(INT, INT) returns int language sql as $$select $1 + $2 +$$; ###################################################################### diff --git a/risedev.yml b/risedev.yml index 65f84882c682c..df69da7cb2457 100644 --- a/risedev.yml +++ b/risedev.yml @@ -950,7 +950,7 @@ profile: - use: frontend - use: compactor - ci-1cn-1fe-kafka-with-recovery: + ci-1cn-1fe-user-kafka-with-recovery: config-path: src/config/ci-recovery.toml steps: - use: minio @@ -962,7 +962,9 @@ profile: - use: frontend - use: compactor - use: kafka - persist-data: true + user-managed: true + address: message_queue + port: 29092 ci-meta-backup-test-etcd: config-path: src/config/ci-meta-backup-test.toml diff --git a/src/batch/Cargo.toml b/src/batch/Cargo.toml index 019c33253466b..2ca8ed1be4e77 100644 --- a/src/batch/Cargo.toml +++ b/src/batch/Cargo.toml @@ -20,6 +20,7 @@ arrow-schema = { workspace = true } assert_matches = "1" async-recursion = "1" async-trait = "0.1" +bytes = "1" either = "1" foyer = { workspace = true } futures = { version = "0.3", default-features = false, features = ["alloc"] } @@ -30,9 +31,11 @@ hytra = "0.1.2" icelake = { workspace = true } itertools = { workspace = true } memcomparable = "0.2" +opendal = "0.45.1" parking_lot = { workspace = true } paste = "1" prometheus = { version = "0.13", features = ["process"] } +prost = "0.12" rand = { workspace = true } risingwave_common = { workspace = true } risingwave_common_estimate_size = { workspace = true } @@ -62,6 +65,8 @@ tokio-stream = "0.1" tokio-util = { workspace = true } tonic = { workspace = true } tracing = "0.1" +twox-hash = "1" +uuid = { version = "1", features = ["v4"] } [target.'cfg(not(madsim))'.dependencies] workspace-hack = { path = "../workspace-hack" } diff --git a/src/batch/benches/hash_agg.rs b/src/batch/benches/hash_agg.rs index f37261f7563e4..e91564692dc95 100644 --- a/src/batch/benches/hash_agg.rs +++ b/src/batch/benches/hash_agg.rs @@ -13,6 +13,8 @@ // limitations under the License. pub mod utils; +use std::sync::Arc; + use criterion::{criterion_group, criterion_main, BatchSize, BenchmarkId, Criterion}; use itertools::Itertools; use risingwave_batch::executor::aggregation::build as build_agg; @@ -96,7 +98,7 @@ fn create_hash_agg_executor( let schema = Schema { fields }; Box::new(HashAggExecutor::::new( - agg_init_states, + Arc::new(agg_init_states), group_key_columns, group_key_types, schema, @@ -104,6 +106,7 @@ fn create_hash_agg_executor( "HashAggExecutor".to_string(), CHUNK_SIZE, MemoryContext::none(), + false, ShutdownToken::empty(), )) } diff --git a/src/batch/src/error.rs b/src/batch/src/error.rs index 8033ddfb3479b..27f355aed48b3 100644 --- a/src/batch/src/error.rs +++ b/src/batch/src/error.rs @@ -139,6 +139,13 @@ pub enum BatchError { #[error("Not enough memory to run this query, batch memory limit is {0} bytes")] OutOfMemory(u64), + + #[error("Failed to spill out to disk")] + Spill( + #[from] + #[backtrace] + opendal::Error, + ), } // Serialize/deserialize error. diff --git a/src/batch/src/executor/hash_agg.rs b/src/batch/src/executor/hash_agg.rs index a0e06c958fc59..cb4adcecdc8c7 100644 --- a/src/batch/src/executor/hash_agg.rs +++ b/src/batch/src/executor/hash_agg.rs @@ -12,27 +12,40 @@ // See the License for the specific language governing permissions and // limitations under the License. +use std::hash::BuildHasher; use std::marker::PhantomData; +use std::sync::Arc; +use anyhow::anyhow; +use bytes::Bytes; use futures_async_stream::try_stream; +use futures_util::AsyncReadExt; use hashbrown::hash_map::Entry; use itertools::Itertools; +use prost::Message; use risingwave_common::array::{DataChunk, StreamChunk}; +use risingwave_common::buffer::Bitmap; use risingwave_common::catalog::{Field, Schema}; use risingwave_common::hash::{HashKey, HashKeyDispatcher, PrecomputedBuildHasher}; use risingwave_common::memory::MemoryContext; -use risingwave_common::types::DataType; +use risingwave_common::row::{OwnedRow, Row, RowExt}; +use risingwave_common::types::{DataType, ToOwnedDatum}; +use risingwave_common::util::chunk_coalesce::DataChunkBuilder; use risingwave_common::util::iter_util::ZipEqFast; use risingwave_common_estimate_size::EstimateSize; use risingwave_expr::aggregate::{AggCall, AggregateState, BoxedAggregateFunction}; use risingwave_pb::batch_plan::plan_node::NodeBody; use risingwave_pb::batch_plan::HashAggNode; +use risingwave_pb::data::DataChunk as PbDataChunk; +use twox_hash::XxHash64; use crate::error::{BatchError, Result}; use crate::executor::aggregation::build as build_agg; use crate::executor::{ BoxedDataChunkStream, BoxedExecutor, BoxedExecutorBuilder, Executor, ExecutorBuilder, + WrapStreamExecutor, }; +use crate::spill::spill_op::{SpillOp, DEFAULT_SPILL_PARTITION_NUM}; use crate::task::{BatchTaskContext, ShutdownToken, TaskId}; type AggHashMap = hashbrown::HashMap, PrecomputedBuildHasher, A>; @@ -43,7 +56,7 @@ impl HashKeyDispatcher for HashAggExecutorBuilder { fn dispatch_impl(self) -> Self::Output { Box::new(HashAggExecutor::::new( - self.aggs, + Arc::new(self.aggs), self.group_key_columns, self.group_key_types, self.schema, @@ -51,6 +64,7 @@ impl HashKeyDispatcher for HashAggExecutorBuilder { self.identity, self.chunk_size, self.mem_context, + self.enable_spill, self.shutdown_rx, )) } @@ -70,6 +84,7 @@ pub struct HashAggExecutorBuilder { identity: String, chunk_size: usize, mem_context: MemoryContext, + enable_spill: bool, shutdown_rx: ShutdownToken, } @@ -81,6 +96,7 @@ impl HashAggExecutorBuilder { identity: String, chunk_size: usize, mem_context: MemoryContext, + enable_spill: bool, shutdown_rx: ShutdownToken, ) -> Result { let aggs: Vec<_> = hash_agg_node @@ -119,6 +135,7 @@ impl HashAggExecutorBuilder { identity, chunk_size, mem_context, + enable_spill, shutdown_rx, }; @@ -148,6 +165,7 @@ impl BoxedExecutorBuilder for HashAggExecutorBuilder { identity.clone(), source.context.get_config().developer.chunk_size, source.context.create_executor_mem_context(identity), + source.context.get_config().enable_spill, source.shutdown_rx.clone(), ) } @@ -156,7 +174,7 @@ impl BoxedExecutorBuilder for HashAggExecutorBuilder { /// `HashAggExecutor` implements the hash aggregate algorithm. pub struct HashAggExecutor { /// Aggregate functions. - aggs: Vec, + aggs: Arc>, /// Column indexes that specify a group group_key_columns: Vec, /// Data types of group key columns @@ -164,16 +182,19 @@ pub struct HashAggExecutor { /// Output schema schema: Schema, child: BoxedExecutor, + /// Used to initialize the state of the aggregation from the spilled files. + init_agg_state_executor: Option, identity: String, chunk_size: usize, mem_context: MemoryContext, + enable_spill: bool, shutdown_rx: ShutdownToken, _phantom: PhantomData, } impl HashAggExecutor { pub fn new( - aggs: Vec, + aggs: Arc>, group_key_columns: Vec, group_key_types: Vec, schema: Schema, @@ -181,6 +202,36 @@ impl HashAggExecutor { identity: String, chunk_size: usize, mem_context: MemoryContext, + enable_spill: bool, + shutdown_rx: ShutdownToken, + ) -> Self { + Self::new_with_init_agg_state( + aggs, + group_key_columns, + group_key_types, + schema, + child, + None, + identity, + chunk_size, + mem_context, + enable_spill, + shutdown_rx, + ) + } + + #[allow(clippy::too_many_arguments)] + fn new_with_init_agg_state( + aggs: Arc>, + group_key_columns: Vec, + group_key_types: Vec, + schema: Schema, + child: BoxedExecutor, + init_agg_state_executor: Option, + identity: String, + chunk_size: usize, + mem_context: MemoryContext, + enable_spill: bool, shutdown_rx: ShutdownToken, ) -> Self { HashAggExecutor { @@ -189,9 +240,11 @@ impl HashAggExecutor { group_key_types, schema, child, + init_agg_state_executor, identity, chunk_size, mem_context, + enable_spill, shutdown_rx, _phantom: PhantomData, } @@ -212,18 +265,259 @@ impl Executor for HashAggExecutor { } } +#[derive(Default, Clone, Copy)] +pub struct SpillBuildHasher(u64); + +impl BuildHasher for SpillBuildHasher { + type Hasher = XxHash64; + + fn build_hasher(&self) -> Self::Hasher { + XxHash64::with_seed(self.0) + } +} + +/// `AggSpillManager` is used to manage how to write spill data file and read them back. +/// The spill data first need to be partitioned. Each partition contains 2 files: `agg_state_file` and `input_chunks_file`. +/// The spill file consume a data chunk and serialize the chunk into a protobuf bytes. +/// Finally, spill file content will look like the below. +/// The file write pattern is append-only and the read pattern is sequential scan. +/// This can maximize the disk IO performance. +/// +/// ```text +/// [proto_len] +/// [proto_bytes] +/// ... +/// [proto_len] +/// [proto_bytes] +/// ``` +pub struct AggSpillManager { + op: SpillOp, + partition_num: usize, + agg_state_writers: Vec, + agg_state_readers: Vec, + agg_state_chunk_builder: Vec, + input_writers: Vec, + input_readers: Vec, + input_chunk_builders: Vec, + spill_build_hasher: SpillBuildHasher, + group_key_types: Vec, + child_data_types: Vec, + agg_data_types: Vec, + spill_chunk_size: usize, +} + +impl AggSpillManager { + fn new( + agg_identity: &String, + partition_num: usize, + group_key_types: Vec, + agg_data_types: Vec, + child_data_types: Vec, + spill_chunk_size: usize, + ) -> Result { + let suffix_uuid = uuid::Uuid::new_v4(); + let dir = format!("/{}-{}/", agg_identity, suffix_uuid); + let op = SpillOp::create(dir)?; + let agg_state_writers = Vec::with_capacity(partition_num); + let agg_state_readers = Vec::with_capacity(partition_num); + let agg_state_chunk_builder = Vec::with_capacity(partition_num); + let input_writers = Vec::with_capacity(partition_num); + let input_readers = Vec::with_capacity(partition_num); + let input_chunk_builders = Vec::with_capacity(partition_num); + // Use uuid to generate an unique hasher so that when recursive spilling happens they would use a different hasher to avoid data skew. + let spill_build_hasher = SpillBuildHasher(suffix_uuid.as_u64_pair().1); + Ok(Self { + op, + partition_num, + agg_state_writers, + agg_state_readers, + agg_state_chunk_builder, + input_writers, + input_readers, + input_chunk_builders, + spill_build_hasher, + group_key_types, + child_data_types, + agg_data_types, + spill_chunk_size, + }) + } + + async fn init_writers(&mut self) -> Result<()> { + for i in 0..self.partition_num { + let agg_state_partition_file_name = format!("agg-state-p{}", i); + let w = self.op.writer_with(&agg_state_partition_file_name).await?; + self.agg_state_writers.push(w); + + let partition_file_name = format!("input-chunks-p{}", i); + let w = self.op.writer_with(&partition_file_name).await?; + self.input_writers.push(w); + self.input_chunk_builders.push(DataChunkBuilder::new( + self.child_data_types.clone(), + self.spill_chunk_size, + )); + self.agg_state_chunk_builder.push(DataChunkBuilder::new( + self.group_key_types + .iter() + .cloned() + .chain(self.agg_data_types.iter().cloned()) + .collect(), + self.spill_chunk_size, + )); + } + Ok(()) + } + + async fn write_agg_state_row(&mut self, row: impl Row, hash_code: u64) -> Result<()> { + let partition = hash_code as usize % self.partition_num; + if let Some(output_chunk) = self.agg_state_chunk_builder[partition].append_one_row(row) { + let chunk_pb: PbDataChunk = output_chunk.to_protobuf(); + let buf = Message::encode_to_vec(&chunk_pb); + let len_bytes = Bytes::copy_from_slice(&(buf.len() as u32).to_le_bytes()); + self.agg_state_writers[partition].write(len_bytes).await?; + self.agg_state_writers[partition].write(buf).await?; + } + Ok(()) + } + + async fn write_input_chunk(&mut self, chunk: DataChunk, hash_codes: Vec) -> Result<()> { + let (columns, vis) = chunk.into_parts_v2(); + for partition in 0..self.partition_num { + let new_vis = vis.clone() + & Bitmap::from_iter( + hash_codes + .iter() + .map(|hash_code| (*hash_code as usize % self.partition_num) == partition), + ); + let new_chunk = DataChunk::from_parts(columns.clone(), new_vis); + for output_chunk in self.input_chunk_builders[partition].append_chunk(new_chunk) { + let chunk_pb: PbDataChunk = output_chunk.to_protobuf(); + let buf = Message::encode_to_vec(&chunk_pb); + let len_bytes = Bytes::copy_from_slice(&(buf.len() as u32).to_le_bytes()); + self.input_writers[partition].write(len_bytes).await?; + self.input_writers[partition].write(buf).await?; + } + } + Ok(()) + } + + async fn close_writers(&mut self) -> Result<()> { + for partition in 0..self.partition_num { + if let Some(output_chunk) = self.agg_state_chunk_builder[partition].consume_all() { + let chunk_pb: PbDataChunk = output_chunk.to_protobuf(); + let buf = Message::encode_to_vec(&chunk_pb); + let len_bytes = Bytes::copy_from_slice(&(buf.len() as u32).to_le_bytes()); + self.agg_state_writers[partition].write(len_bytes).await?; + self.agg_state_writers[partition].write(buf).await?; + } + + if let Some(output_chunk) = self.input_chunk_builders[partition].consume_all() { + let chunk_pb: PbDataChunk = output_chunk.to_protobuf(); + let buf = Message::encode_to_vec(&chunk_pb); + let len_bytes = Bytes::copy_from_slice(&(buf.len() as u32).to_le_bytes()); + self.input_writers[partition].write(len_bytes).await?; + self.input_writers[partition].write(buf).await?; + } + } + + for mut w in self.agg_state_writers.drain(..) { + w.close().await?; + } + for mut w in self.input_writers.drain(..) { + w.close().await?; + } + Ok(()) + } + + #[try_stream(boxed, ok = DataChunk, error = BatchError)] + async fn read_stream(mut reader: opendal::Reader) { + let mut buf = [0u8; 4]; + loop { + if let Err(err) = reader.read_exact(&mut buf).await { + if err.kind() == std::io::ErrorKind::UnexpectedEof { + break; + } else { + return Err(anyhow!(err).into()); + } + } + let len = u32::from_le_bytes(buf) as usize; + let mut buf = vec![0u8; len]; + reader.read_exact(&mut buf).await.map_err(|e| anyhow!(e))?; + let chunk_pb: PbDataChunk = Message::decode(buf.as_slice()).map_err(|e| anyhow!(e))?; + let chunk = DataChunk::from_protobuf(&chunk_pb)?; + yield chunk; + } + } + + async fn read_agg_state_partition(&mut self, partition: usize) -> Result { + let agg_state_partition_file_name = format!("agg-state-p{}", partition); + let r = self.op.reader_with(&agg_state_partition_file_name).await?; + Ok(Self::read_stream(r)) + } + + async fn read_input_partition(&mut self, partition: usize) -> Result { + let input_partition_file_name = format!("input-chunks-p{}", partition); + let r = self.op.reader_with(&input_partition_file_name).await?; + Ok(Self::read_stream(r)) + } + + async fn clear_partition(&mut self, partition: usize) -> Result<()> { + let agg_state_partition_file_name = format!("agg-state-p{}", partition); + self.op.delete(&agg_state_partition_file_name).await?; + let input_partition_file_name = format!("input-chunks-p{}", partition); + self.op.delete(&input_partition_file_name).await?; + Ok(()) + } +} + impl HashAggExecutor { #[try_stream(boxed, ok = DataChunk, error = BatchError)] async fn do_execute(self: Box) { + let child_schema = self.child.schema().clone(); + let mut need_to_spill = false; + // hash map for each agg groups let mut groups = AggHashMap::::with_hasher_in( PrecomputedBuildHasher, self.mem_context.global_allocator(), ); + if let Some(init_agg_state_executor) = self.init_agg_state_executor { + // `init_agg_state_executor` exists which means this is a sub `HashAggExecutor` used to consume spilling data. + // The spilled agg states by its parent executor need to be recovered first. + let mut init_agg_state_stream = init_agg_state_executor.execute(); + #[for_await] + for chunk in &mut init_agg_state_stream { + let chunk = chunk?; + let group_key_indices = (0..self.group_key_columns.len()).collect_vec(); + let keys = K::build_many(&group_key_indices, &chunk); + let mut memory_usage_diff = 0; + for (row_id, key) in keys.into_iter().enumerate() { + let mut agg_states = vec![]; + for i in 0..self.aggs.len() { + let agg = &self.aggs[i]; + let datum = chunk + .row_at(row_id) + .0 + .datum_at(self.group_key_columns.len() + i) + .to_owned_datum(); + let agg_state = agg.decode_state(datum)?; + memory_usage_diff += agg_state.estimated_size() as i64; + agg_states.push(agg_state); + } + groups.try_insert(key, agg_states).unwrap(); + } + + if !self.mem_context.add(memory_usage_diff) { + warn!("not enough memory to load one partition agg state after spill which is not a normal case, so keep going"); + } + } + } + + let mut input_stream = self.child.execute(); // consume all chunks to compute the agg result #[for_await] - for chunk in self.child.execute() { + for chunk in &mut input_stream { let chunk = StreamChunk::from(chunk?); let keys = K::build_many(self.group_key_columns.as_slice(), &chunk); let mut memory_usage_diff = 0; @@ -260,53 +554,158 @@ impl HashAggExecutor { } // update memory usage if !self.mem_context.add(memory_usage_diff) { - Err(BatchError::OutOfMemory(self.mem_context.mem_limit()))?; + if self.enable_spill { + need_to_spill = true; + break; + } else { + Err(BatchError::OutOfMemory(self.mem_context.mem_limit()))?; + } } } - // Don't use `into_iter` here, it may cause memory leak. - let mut result = groups.iter_mut(); - let cardinality = self.chunk_size; - loop { - let mut group_builders: Vec<_> = self - .group_key_types - .iter() - .map(|datatype| datatype.create_array_builder(cardinality)) - .collect(); - - let mut agg_builders: Vec<_> = self - .aggs - .iter() - .map(|agg| agg.return_type().create_array_builder(cardinality)) - .collect(); - - let mut has_next = false; - let mut array_len = 0; - for (key, states) in result.by_ref().take(cardinality) { - self.shutdown_rx.check()?; - has_next = true; - array_len += 1; - key.deserialize_to_builders(&mut group_builders[..], &self.group_key_types)?; - for ((agg, state), builder) in (self.aggs.iter()) - .zip_eq_fast(states) - .zip_eq_fast(&mut agg_builders) - { - let result = agg.get_result(state).await?; - builder.append(result); + if need_to_spill { + // A spilling version of aggregation based on the RFC: Spill Hash Aggregation https://github.com/risingwavelabs/rfcs/pull/89 + // When HashAggExecutor told memory is insufficient, AggSpillManager will start to partition the hash table and spill to disk. + // After spilling the hash table, AggSpillManager will consume all chunks from the input executor, + // partition and spill to disk with the same hash function as the hash table spilling. + // Finally, we would get e.g. 20 partitions. Each partition should contain a portion of the original hash table and input data. + // A sub HashAggExecutor would be used to consume each partition one by one. + // If memory is still not enough in the sub HashAggExecutor, it will partition its hash table and input recursively. + let mut agg_spill_manager = AggSpillManager::new( + &self.identity, + DEFAULT_SPILL_PARTITION_NUM, + self.group_key_types.clone(), + self.aggs.iter().map(|agg| agg.return_type()).collect(), + child_schema.data_types(), + self.chunk_size, + )?; + agg_spill_manager.init_writers().await?; + + let mut memory_usage_diff = 0; + // Spill agg states. + for (key, states) in groups { + let key_row = key.deserialize(&self.group_key_types)?; + let mut agg_datums = vec![]; + for (agg, state) in self.aggs.iter().zip_eq_fast(states) { + let encode_state = agg.encode_state(&state)?; + memory_usage_diff -= state.estimated_size() as i64; + agg_datums.push(encode_state); } + let agg_state_row = OwnedRow::from_iter(agg_datums.into_iter()); + let hash_code = agg_spill_manager.spill_build_hasher.hash_one(key); + agg_spill_manager + .write_agg_state_row(key_row.chain(agg_state_row), hash_code) + .await?; } - if !has_next { - break; // exit loop + + // Release memory occupied by agg hash map. + self.mem_context.add(memory_usage_diff); + + // Spill input chunks. + #[for_await] + for chunk in input_stream { + let chunk: DataChunk = chunk?; + let hash_codes = chunk.get_hash_values( + self.group_key_columns.as_slice(), + agg_spill_manager.spill_build_hasher, + ); + agg_spill_manager + .write_input_chunk( + chunk, + hash_codes + .into_iter() + .map(|hash_code| hash_code.value()) + .collect(), + ) + .await?; } - let columns = group_builders - .into_iter() - .chain(agg_builders) - .map(|b| b.finish().into()) - .collect::>(); + agg_spill_manager.close_writers().await?; + + // Process each partition one by one. + for i in 0..agg_spill_manager.partition_num { + let agg_state_stream = agg_spill_manager.read_agg_state_partition(i).await?; + let input_stream = agg_spill_manager.read_input_partition(i).await?; + + let sub_hash_agg_executor: HashAggExecutor = + HashAggExecutor::new_with_init_agg_state( + self.aggs.clone(), + self.group_key_columns.clone(), + self.group_key_types.clone(), + self.schema.clone(), + Box::new(WrapStreamExecutor::new(child_schema.clone(), input_stream)), + Some(Box::new(WrapStreamExecutor::new( + self.schema.clone(), + agg_state_stream, + ))), + format!("{}-sub{}", self.identity.clone(), i), + self.chunk_size, + self.mem_context.clone(), + self.enable_spill, + self.shutdown_rx.clone(), + ); + + debug!( + "create sub_hash_agg {} for hash_agg {} to spill", + sub_hash_agg_executor.identity, self.identity + ); + + let sub_hash_agg_stream = Box::new(sub_hash_agg_executor).execute(); + + #[for_await] + for chunk in sub_hash_agg_stream { + let chunk = chunk?; + yield chunk; + } - let output = DataChunk::new(columns, array_len); - yield output; + // Clear files of the current partition. + agg_spill_manager.clear_partition(i).await?; + } + } else { + // Don't use `into_iter` here, it may cause memory leak. + let mut result = groups.iter_mut(); + let cardinality = self.chunk_size; + loop { + let mut group_builders: Vec<_> = self + .group_key_types + .iter() + .map(|datatype| datatype.create_array_builder(cardinality)) + .collect(); + + let mut agg_builders: Vec<_> = self + .aggs + .iter() + .map(|agg| agg.return_type().create_array_builder(cardinality)) + .collect(); + + let mut has_next = false; + let mut array_len = 0; + for (key, states) in result.by_ref().take(cardinality) { + self.shutdown_rx.check()?; + has_next = true; + array_len += 1; + key.deserialize_to_builders(&mut group_builders[..], &self.group_key_types)?; + for ((agg, state), builder) in (self.aggs.iter()) + .zip_eq_fast(states) + .zip_eq_fast(&mut agg_builders) + { + let result = agg.get_result(state).await?; + builder.append(result); + } + } + if !has_next { + break; // exit loop + } + + let columns = group_builders + .into_iter() + .chain(agg_builders) + .map(|b| b.finish().into()) + .collect::>(); + + let output = DataChunk::new(columns, array_len); + yield output; + } } } } @@ -316,7 +715,6 @@ mod tests { use std::alloc::{AllocError, Allocator, Global, Layout}; use std::ptr::NonNull; use std::sync::atomic::{AtomicBool, Ordering}; - use std::sync::Arc; use futures_async_stream::for_await; use risingwave_common::metrics::LabelGuardedIntGauge; @@ -390,6 +788,7 @@ mod tests { "HashAggExecutor".to_string(), CHUNK_SIZE, mem_context.clone(), + false, ShutdownToken::empty(), ) .unwrap(); @@ -462,6 +861,7 @@ mod tests { "HashAggExecutor".to_string(), CHUNK_SIZE, MemoryContext::none(), + false, ShutdownToken::empty(), ) .unwrap(); @@ -577,6 +977,7 @@ mod tests { "HashAggExecutor".to_string(), CHUNK_SIZE, MemoryContext::none(), + false, shutdown_rx, ) .unwrap(); diff --git a/src/batch/src/executor/utils.rs b/src/batch/src/executor/utils.rs index 9c6f162f02268..4f724ec5416c8 100644 --- a/src/batch/src/executor/utils.rs +++ b/src/batch/src/executor/utils.rs @@ -99,3 +99,28 @@ impl DummyExecutor { #[try_stream(boxed, ok = DataChunk, error = BatchError)] async fn do_nothing() {} } + +pub struct WrapStreamExecutor { + schema: Schema, + stream: BoxedDataChunkStream, +} + +impl WrapStreamExecutor { + pub fn new(schema: Schema, stream: BoxedDataChunkStream) -> Self { + Self { schema, stream } + } +} + +impl Executor for WrapStreamExecutor { + fn schema(&self) -> &Schema { + &self.schema + } + + fn identity(&self) -> &str { + "WrapStreamExecutor" + } + + fn execute(self: Box) -> BoxedDataChunkStream { + self.stream + } +} diff --git a/src/batch/src/lib.rs b/src/batch/src/lib.rs index b8e6df1ac9538..2c072319e1c8a 100644 --- a/src/batch/src/lib.rs +++ b/src/batch/src/lib.rs @@ -39,6 +39,7 @@ pub mod execution; pub mod executor; pub mod monitor; pub mod rpc; +mod spill; pub mod task; pub mod worker_manager; diff --git a/src/batch/src/spill/mod.rs b/src/batch/src/spill/mod.rs new file mode 100644 index 0000000000000..6af1eae7429c6 --- /dev/null +++ b/src/batch/src/spill/mod.rs @@ -0,0 +1,15 @@ +// Copyright 2024 RisingWave Labs +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +pub mod spill_op; diff --git a/src/batch/src/spill/spill_op.rs b/src/batch/src/spill/spill_op.rs new file mode 100644 index 0000000000000..115a0c2d430e1 --- /dev/null +++ b/src/batch/src/spill/spill_op.rs @@ -0,0 +1,98 @@ +// Copyright 2024 RisingWave Labs +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::ops::{Deref, DerefMut}; + +use opendal::layers::RetryLayer; +use opendal::services::Fs; +use opendal::Operator; +use thiserror_ext::AsReport; + +use crate::error::Result; + +const RW_BATCH_SPILL_DIR_ENV: &str = "RW_BATCH_SPILL_DIR"; +pub const DEFAULT_SPILL_PARTITION_NUM: usize = 20; +const DEFAULT_SPILL_DIR: &str = "/tmp/"; +const RW_MANAGED_SPILL_DIR: &str = "/rw_batch_spill/"; +const DEFAULT_IO_BUFFER_SIZE: usize = 256 * 1024; +const DEFAULT_IO_CONCURRENT_TASK: usize = 8; + +/// `SpillOp` is used to manage the spill directory of the spilling executor and it will drop the directory with a RAII style. +pub struct SpillOp { + pub op: Operator, +} + +impl SpillOp { + pub fn create(path: String) -> Result { + assert!(path.ends_with('/')); + + let spill_dir = + std::env::var(RW_BATCH_SPILL_DIR_ENV).unwrap_or_else(|_| DEFAULT_SPILL_DIR.to_string()); + let root = format!("/{}/{}/{}/", spill_dir, RW_MANAGED_SPILL_DIR, path); + + let mut builder = Fs::default(); + builder.root(&root); + + let op: Operator = Operator::new(builder)? + .layer(RetryLayer::default()) + .finish(); + Ok(SpillOp { op }) + } + + pub async fn writer_with(&self, name: &str) -> Result { + Ok(self + .op + .writer_with(name) + .buffer(DEFAULT_IO_BUFFER_SIZE) + .concurrent(DEFAULT_IO_CONCURRENT_TASK) + .await?) + } + + pub async fn reader_with(&self, name: &str) -> Result { + Ok(self + .op + .reader_with(name) + .buffer(DEFAULT_IO_BUFFER_SIZE) + .await?) + } +} + +impl Drop for SpillOp { + fn drop(&mut self) { + let op = self.op.clone(); + tokio::task::spawn(async move { + let result = op.remove_all("/").await; + if let Err(error) = result { + error!( + error = %error.as_report(), + "Failed to remove spill directory" + ); + } + }); + } +} + +impl DerefMut for SpillOp { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.op + } +} + +impl Deref for SpillOp { + type Target = Operator; + + fn deref(&self) -> &Self::Target { + &self.op + } +} diff --git a/src/common/src/config.rs b/src/common/src/config.rs index 8ae14702d3261..c8da0f6dce5e9 100644 --- a/src/common/src/config.rs +++ b/src/common/src/config.rs @@ -535,6 +535,10 @@ pub struct BatchConfig { /// A SQL option with a name containing any of these keywords will be redacted. #[serde(default = "default::batch::redact_sql_option_keywords")] pub redact_sql_option_keywords: Vec, + + /// Enable the spill out to disk feature for batch queries. + #[serde(default = "default::batch::enable_spill")] + pub enable_spill: bool, } /// The section `[streaming]` in `risingwave.toml`. @@ -1759,6 +1763,10 @@ pub mod default { false } + pub fn enable_spill() -> bool { + true + } + pub fn statement_timeout_in_sec() -> u32 { // 1 hour 60 * 60 diff --git a/src/common/src/hash/key.rs b/src/common/src/hash/key.rs index 4911e041370a9..c7e57173a3e74 100644 --- a/src/common/src/hash/key.rs +++ b/src/common/src/hash/key.rs @@ -236,7 +236,7 @@ impl From for HashCode { } impl HashCode { - pub fn value(self) -> u64 { + pub fn value(&self) -> u64 { self.value } } diff --git a/src/config/docs.md b/src/config/docs.md index 2f8c4ce2812b1..018c9dd41087c 100644 --- a/src/config/docs.md +++ b/src/config/docs.md @@ -8,6 +8,7 @@ This page is automatically generated by `./risedev generate-example-config` |--------|-------------|---------| | distributed_query_limit | This is the max number of queries per sql session. | | | enable_barrier_read | | false | +| enable_spill | Enable the spill out to disk feature for batch queries. | true | | frontend_compute_runtime_worker_threads | frontend compute runtime worker threads | 4 | | mask_worker_temporary_secs | This is the secs used to mask a worker unavailable temporarily. | 30 | | max_batch_queries_per_frontend_node | This is the max number of batch queries per frontend node. | | diff --git a/src/config/example.toml b/src/config/example.toml index fb2243535d6a4..00b1ef759e5f9 100644 --- a/src/config/example.toml +++ b/src/config/example.toml @@ -87,6 +87,7 @@ statement_timeout_in_sec = 3600 frontend_compute_runtime_worker_threads = 4 mask_worker_temporary_secs = 30 redact_sql_option_keywords = ["credential", "key", "password", "private", "secret", "token"] +enable_spill = true [batch.developer] batch_connector_message_buffer_size = 16 diff --git a/src/expr/core/src/aggregate/mod.rs b/src/expr/core/src/aggregate/mod.rs index 8ccd71b0e3e37..2a1119d6fe301 100644 --- a/src/expr/core/src/aggregate/mod.rs +++ b/src/expr/core/src/aggregate/mod.rs @@ -15,6 +15,7 @@ use std::fmt::Debug; use std::ops::Range; +use anyhow::anyhow; use downcast_rs::{impl_downcast, Downcast}; use itertools::Itertools; use risingwave_common::array::StreamChunk; @@ -60,7 +61,7 @@ pub trait AggregateFunction: Send + Sync + 'static { fn encode_state(&self, state: &AggregateState) -> Result { match state { AggregateState::Datum(d) => Ok(d.clone()), - _ => panic!("cannot encode state"), + AggregateState::Any(_) => Err(ExprError::Internal(anyhow!("cannot encode state"))), } } diff --git a/src/frontend/src/handler/create_table.rs b/src/frontend/src/handler/create_table.rs index 4f3b81a20e630..38829a16be11e 100644 --- a/src/frontend/src/handler/create_table.rs +++ b/src/frontend/src/handler/create_table.rs @@ -743,6 +743,7 @@ pub(crate) fn gen_create_table_plan_for_cdc_source( with_version_column: Option, include_column_options: IncludeOption, ) -> Result<(PlanRef, PbTable)> { + // cdc table must have primary key constraint or primary key column if !constraints.iter().any(|c| { matches!( c, @@ -751,6 +752,10 @@ pub(crate) fn gen_create_table_plan_for_cdc_source( .. } ) + }) && !column_defs.iter().any(|col| { + col.options + .iter() + .any(|opt| matches!(opt.option, ColumnOption::Unique { is_primary: true })) }) { return Err(ErrorCode::NotSupported( "CDC table without primary key constraint is not supported".to_owned(), diff --git a/src/sqlparser/src/parser.rs b/src/sqlparser/src/parser.rs index 243d9e695ff3a..71a2099e0e6fd 100644 --- a/src/sqlparser/src/parser.rs +++ b/src/sqlparser/src/parser.rs @@ -226,22 +226,54 @@ impl Parser { let mut tokenizer = Tokenizer::new(sql); let tokens = tokenizer.tokenize_with_location()?; let mut parser = Parser::new(tokens); + let ast = parser.parse_statements().map_err(|e| { + // append SQL context to the error message, e.g.: + // LINE 1: SELECT 1::int(2); + // ^ + // XXX: the cursor location is not accurate + // it may be offset one token forward because the error token has been consumed + let loc = match parser.tokens.get(parser.index) { + Some(token) => token.location.clone(), + None => { + // get location of EOF + Location { + line: sql.lines().count() as u64, + column: sql.lines().last().map_or(0, |l| l.len() as u64) + 1, + } + } + }; + let prefix = format!("LINE {}: ", loc.line); + let sql_line = sql.split('\n').nth(loc.line as usize - 1).unwrap(); + let cursor = " ".repeat(prefix.len() + loc.column as usize - 1); + ParserError::ParserError(format!( + "{}\n{}{}\n{}^", + e.inner_msg(), + prefix, + sql_line, + cursor + )) + })?; + Ok(ast) + } + + /// Parse a list of semicolon-separated SQL statements. + pub fn parse_statements(&mut self) -> Result, ParserError> { let mut stmts = Vec::new(); let mut expecting_statement_delimiter = false; loop { // ignore empty statements (between successive statement delimiters) - while parser.consume_token(&Token::SemiColon) { + while self.consume_token(&Token::SemiColon) { expecting_statement_delimiter = false; } - if parser.peek_token() == Token::EOF { + if self.peek_token() == Token::EOF { break; } if expecting_statement_delimiter { - return parser.expected("end of statement", parser.peek_token()); + return self.expected("end of statement", self.peek_token()); } - let statement = parser.parse_statement()?; + let statement = self.parse_statement()?; stmts.push(statement); expecting_statement_delimiter = true; } @@ -1958,24 +1990,7 @@ impl Parser { /// Report unexpected token pub fn expected(&self, expected: &str, found: TokenWithLocation) -> Result { - let start_off = self.index.saturating_sub(10); - let end_off = self.index.min(self.tokens.len()); - let near_tokens = &self.tokens[start_off..end_off]; - struct TokensDisplay<'a>(&'a [TokenWithLocation]); - impl<'a> fmt::Display for TokensDisplay<'a> { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - for token in self.0 { - write!(f, "{}", token.token)?; - } - Ok(()) - } - } - parser_err!(format!( - "Expected {}, found: {}\nNear \"{}\"", - expected, - found, - TokensDisplay(near_tokens), - )) + parser_err!(format!("expected {}, found: {}", expected, found)) } /// Look for an expected keyword and consume it if it exists diff --git a/src/sqlparser/src/tokenizer.rs b/src/sqlparser/src/tokenizer.rs index 03b794b06c62b..9ca4948b282ea 100644 --- a/src/sqlparser/src/tokenizer.rs +++ b/src/sqlparser/src/tokenizer.rs @@ -392,7 +392,7 @@ impl fmt::Display for TokenWithLocation { } else { write!( f, - "{} at line:{}, column:{}", + "{} at line {}, column {}", self.token, self.location.line, self.location.column ) } @@ -405,14 +405,15 @@ pub struct TokenizerError { pub message: String, pub line: u64, pub col: u64, + pub context: String, } impl fmt::Display for TokenizerError { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!( f, - "{} at Line: {}, Column {}", - self.message, self.line, self.col + "{} at line {}, column {}\n{}", + self.message, self.line, self.col, self.context ) } } @@ -422,7 +423,8 @@ impl std::error::Error for TokenizerError {} /// SQL Tokenizer pub struct Tokenizer<'a> { - query: &'a str, + sql: &'a str, + chars: Peekable>, line: u64, col: u64, } @@ -431,36 +433,39 @@ impl<'a> Tokenizer<'a> { /// Create a new SQL tokenizer for the specified SQL statement pub fn new(query: &'a str) -> Self { Self { - query, + sql: query, + chars: query.chars().peekable(), line: 1, col: 1, } } - /// Tokenize the statement and produce a vector of tokens with locations. - pub fn tokenize_with_location(&mut self) -> Result, TokenizerError> { - let mut peekable = self.query.chars().peekable(); - - let mut tokens: Vec = vec![]; - - while let Some(token) = self.next_token(&mut peekable)? { - match &token { - Token::Whitespace(Whitespace::Newline) => { + /// Consume the next character. + fn next(&mut self) -> Option { + let ch = self.chars.next(); + if let Some(ch) = ch { + match ch { + '\n' => { self.line += 1; self.col = 1; } - - Token::Whitespace(Whitespace::Tab) => self.col += 4, - Token::Word(w) if w.quote_style.is_none() => self.col += w.value.len() as u64, - Token::Word(w) if w.quote_style.is_some() => self.col += w.value.len() as u64 + 2, - Token::Number(s) => self.col += s.len() as u64, - Token::SingleQuotedString(s) => self.col += s.len() as u64, + '\t' => self.col += 4, _ => self.col += 1, } + } + ch + } - let token_with_location = TokenWithLocation::new(token, self.line, self.col); + /// Return the next character without consuming it. + fn peek(&mut self) -> Option { + self.chars.peek().cloned() + } - tokens.push(token_with_location); + /// Tokenize the statement and produce a vector of tokens with locations. + pub fn tokenize_with_location(&mut self) -> Result, TokenizerError> { + let mut tokens = Vec::new(); + while let Some(token) = self.next_token_with_location()? { + tokens.push(token); } Ok(tokens) } @@ -468,51 +473,61 @@ impl<'a> Tokenizer<'a> { /// Tokenize the statement and produce a vector of tokens without locations. #[allow(dead_code)] fn tokenize(&mut self) -> Result, TokenizerError> { - self.tokenize_with_location() - .map(|v| v.into_iter().map(|t| t.token).collect()) + let tokens = self.tokenize_with_location()?; + Ok(tokens.into_iter().map(|t| t.token).collect()) } /// Get the next token or return None - fn next_token(&self, chars: &mut Peekable>) -> Result, TokenizerError> { - match chars.peek() { - Some(&ch) => match ch { - ' ' => self.consume_and_return(chars, Token::Whitespace(Whitespace::Space)), - '\t' => self.consume_and_return(chars, Token::Whitespace(Whitespace::Tab)), - '\n' => self.consume_and_return(chars, Token::Whitespace(Whitespace::Newline)), + fn next_token_with_location(&mut self) -> Result, TokenizerError> { + let loc = Location { + line: self.line, + column: self.col, + }; + self.next_token() + .map(|t| t.map(|token| token.with_location(loc))) + } + + /// Get the next token or return None + fn next_token(&mut self) -> Result, TokenizerError> { + match self.peek() { + Some(ch) => match ch { + ' ' => self.consume_and_return(Token::Whitespace(Whitespace::Space)), + '\t' => self.consume_and_return(Token::Whitespace(Whitespace::Tab)), + '\n' => self.consume_and_return(Token::Whitespace(Whitespace::Newline)), '\r' => { // Emit a single Whitespace::Newline token for \r and \r\n - chars.next(); - if let Some('\n') = chars.peek() { - chars.next(); + self.next(); + if let Some('\n') = self.peek() { + self.next(); } Ok(Some(Token::Whitespace(Whitespace::Newline))) } 'N' => { - chars.next(); // consume, to check the next char - match chars.peek() { + self.next(); // consume, to check the next char + match self.peek() { Some('\'') => { // N'...' - a - let s = self.tokenize_single_quoted_string(chars)?; + let s = self.tokenize_single_quoted_string()?; Ok(Some(Token::NationalStringLiteral(s))) } _ => { // regular identifier starting with an "N" - let s = self.tokenize_word('N', chars); + let s = self.tokenize_word('N'); Ok(Some(Token::make_word(&s, None))) } } } x @ 'e' | x @ 'E' => { - chars.next(); // consume, to check the next char - match chars.peek() { + self.next(); // consume, to check the next char + match self.peek() { Some('\'') => { // E'...' - a - let s = self.tokenize_single_quoted_string_with_escape(chars)?; + let s = self.tokenize_single_quoted_string_with_escape()?; Ok(Some(Token::CstyleEscapesString(s))) } _ => { // regular identifier starting with an "E" - let s = self.tokenize_word(x, chars); + let s = self.tokenize_word(x); Ok(Some(Token::make_word(&s, None))) } } @@ -520,42 +535,42 @@ impl<'a> Tokenizer<'a> { // The spec only allows an uppercase 'X' to introduce a hex // string, but PostgreSQL, at least, allows a lowercase 'x' too. x @ 'x' | x @ 'X' => { - chars.next(); // consume, to check the next char - match chars.peek() { + self.next(); // consume, to check the next char + match self.peek() { Some('\'') => { // X'...' - a - let s = self.tokenize_single_quoted_string(chars)?; + let s = self.tokenize_single_quoted_string()?; Ok(Some(Token::HexStringLiteral(s))) } _ => { // regular identifier starting with an "X" - let s = self.tokenize_word(x, chars); + let s = self.tokenize_word(x); Ok(Some(Token::make_word(&s, None))) } } } // identifier or keyword ch if is_identifier_start(ch) => { - chars.next(); // consume the first char - let s = self.tokenize_word(ch, chars); + self.next(); // consume the first char + let s = self.tokenize_word(ch); Ok(Some(Token::make_word(&s, None))) } // string '\'' => { - let s = self.tokenize_single_quoted_string(chars)?; + let s = self.tokenize_single_quoted_string()?; Ok(Some(Token::SingleQuotedString(s))) } // delimited (quoted) identifier quote_start if is_delimited_identifier_start(quote_start) => { - chars.next(); // consume the opening quote + self.next(); // consume the opening quote let quote_end = Word::matching_end_quote(quote_start); - let s = peeking_take_while(chars, |ch| ch != quote_end); - if chars.next() == Some(quote_end) { + let s = self.peeking_take_while(|ch| ch != quote_end); + if self.next() == Some(quote_end) { Ok(Some(Token::make_word(&s, Some(quote_start)))) } else { - self.tokenizer_error(format!( + self.error(format!( "Expected close delimiter '{}' before EOF.", quote_end )) @@ -563,14 +578,14 @@ impl<'a> Tokenizer<'a> { } // numbers and period '0'..='9' | '.' => { - let mut s = peeking_take_while(chars, |ch| ch.is_ascii_digit()); + let mut s = self.peeking_take_while(|ch| ch.is_ascii_digit()); // match binary literal that starts with 0x if s == "0" - && let Some(&radix) = chars.peek() + && let Some(radix) = self.peek() && "xob".contains(radix.to_ascii_lowercase()) { - chars.next(); + self.next(); let radix = radix.to_ascii_lowercase(); let base = match radix { 'x' => 16, @@ -578,67 +593,67 @@ impl<'a> Tokenizer<'a> { 'b' => 2, _ => unreachable!(), }; - let s2 = peeking_take_while(chars, |ch| ch.is_digit(base)); + let s2 = self.peeking_take_while(|ch| ch.is_digit(base)); if s2.is_empty() { - return self.tokenizer_error("incomplete integer literal"); + return self.error("incomplete integer literal"); } - self.reject_number_junk(chars)?; + self.reject_number_junk()?; return Ok(Some(Token::Number(format!("0{radix}{s2}")))); } // match one period - if let Some('.') = chars.peek() { + if let Some('.') = self.peek() { s.push('.'); - chars.next(); + self.next(); } - s += &peeking_take_while(chars, |ch| ch.is_ascii_digit()); + s += &self.peeking_take_while(|ch| ch.is_ascii_digit()); // No number -> Token::Period if s == "." { return Ok(Some(Token::Period)); } - match chars.peek() { + match self.peek() { // Number is a scientific number (1e6) Some('e') | Some('E') => { s.push('e'); - chars.next(); + self.next(); - if let Some('-') = chars.peek() { + if let Some('-') = self.peek() { s.push('-'); - chars.next(); + self.next(); } - s += &peeking_take_while(chars, |ch| ch.is_ascii_digit()); - self.reject_number_junk(chars)?; + s += &self.peeking_take_while(|ch| ch.is_ascii_digit()); + self.reject_number_junk()?; return Ok(Some(Token::Number(s))); } // Not a scientific number _ => {} }; - self.reject_number_junk(chars)?; + self.reject_number_junk()?; Ok(Some(Token::Number(s))) } // punctuation - '(' => self.consume_and_return(chars, Token::LParen), - ')' => self.consume_and_return(chars, Token::RParen), - ',' => self.consume_and_return(chars, Token::Comma), + '(' => self.consume_and_return(Token::LParen), + ')' => self.consume_and_return(Token::RParen), + ',' => self.consume_and_return(Token::Comma), // operators '-' => { - chars.next(); // consume the '-' - match chars.peek() { + self.next(); // consume the '-' + match self.peek() { Some('-') => { - chars.next(); // consume the second '-', starting a single-line comment - let comment = self.tokenize_single_line_comment(chars); + self.next(); // consume the second '-', starting a single-line comment + let comment = self.tokenize_single_line_comment(); Ok(Some(Token::Whitespace(Whitespace::SingleLineComment { prefix: "--".to_owned(), comment, }))) } Some('>') => { - chars.next(); // consume first '>' - match chars.peek() { + self.next(); // consume first '>' + match self.peek() { Some('>') => { - chars.next(); // consume second '>' + self.next(); // consume second '>' Ok(Some(Token::LongArrow)) } _ => Ok(Some(Token::Arrow)), @@ -649,27 +664,27 @@ impl<'a> Tokenizer<'a> { } } '/' => { - chars.next(); // consume the '/' - match chars.peek() { + self.next(); // consume the '/' + match self.peek() { Some('*') => { - chars.next(); // consume the '*', starting a multi-line comment - self.tokenize_multiline_comment(chars) + self.next(); // consume the '*', starting a multi-line comment + self.tokenize_multiline_comment() } // a regular '/' operator _ => Ok(Some(Token::Div)), } } - '+' => self.consume_and_return(chars, Token::Plus), - '*' => self.consume_and_return(chars, Token::Mul), - '%' => self.consume_and_return(chars, Token::Mod), + '+' => self.consume_and_return(Token::Plus), + '*' => self.consume_and_return(Token::Mul), + '%' => self.consume_and_return(Token::Mod), '|' => { - chars.next(); // consume the '|' - match chars.peek() { - Some('/') => self.consume_and_return(chars, Token::PGSquareRoot), + self.next(); // consume the '|' + match self.peek() { + Some('/') => self.consume_and_return(Token::PGSquareRoot), Some('|') => { - chars.next(); // consume the second '|' - match chars.peek() { - Some('/') => self.consume_and_return(chars, Token::PGCubeRoot), + self.next(); // consume the second '|' + match self.peek() { + Some('/') => self.consume_and_return(Token::PGCubeRoot), _ => Ok(Some(Token::Concat)), } } @@ -678,32 +693,32 @@ impl<'a> Tokenizer<'a> { } } '=' => { - chars.next(); // consume - match chars.peek() { - Some('>') => self.consume_and_return(chars, Token::RArrow), + self.next(); // consume + match self.peek() { + Some('>') => self.consume_and_return(Token::RArrow), _ => Ok(Some(Token::Eq)), } } '!' => { - chars.next(); // consume - match chars.peek() { - Some('=') => self.consume_and_return(chars, Token::Neq), - Some('!') => self.consume_and_return(chars, Token::DoubleExclamationMark), + self.next(); // consume + match self.peek() { + Some('=') => self.consume_and_return(Token::Neq), + Some('!') => self.consume_and_return(Token::DoubleExclamationMark), Some('~') => { - chars.next(); - match chars.peek() { + self.next(); + match self.peek() { Some('~') => { - chars.next(); - match chars.peek() { + self.next(); + match self.peek() { Some('*') => self.consume_and_return( - chars, Token::ExclamationMarkDoubleTildeAsterisk, ), _ => Ok(Some(Token::ExclamationMarkDoubleTilde)), } } - Some('*') => self - .consume_and_return(chars, Token::ExclamationMarkTildeAsterisk), + Some('*') => { + self.consume_and_return(Token::ExclamationMarkTildeAsterisk) + } _ => Ok(Some(Token::ExclamationMarkTilde)), } } @@ -711,76 +726,74 @@ impl<'a> Tokenizer<'a> { } } '<' => { - chars.next(); // consume - match chars.peek() { + self.next(); // consume + match self.peek() { Some('=') => { - chars.next(); - match chars.peek() { - Some('>') => self.consume_and_return(chars, Token::Spaceship), + self.next(); + match self.peek() { + Some('>') => self.consume_and_return(Token::Spaceship), _ => Ok(Some(Token::LtEq)), } } - Some('>') => self.consume_and_return(chars, Token::Neq), - Some('<') => self.consume_and_return(chars, Token::ShiftLeft), - Some('@') => self.consume_and_return(chars, Token::ArrowAt), + Some('>') => self.consume_and_return(Token::Neq), + Some('<') => self.consume_and_return(Token::ShiftLeft), + Some('@') => self.consume_and_return(Token::ArrowAt), _ => Ok(Some(Token::Lt)), } } '>' => { - chars.next(); // consume - match chars.peek() { - Some('=') => self.consume_and_return(chars, Token::GtEq), - Some('>') => self.consume_and_return(chars, Token::ShiftRight), + self.next(); // consume + match self.peek() { + Some('=') => self.consume_and_return(Token::GtEq), + Some('>') => self.consume_and_return(Token::ShiftRight), _ => Ok(Some(Token::Gt)), } } ':' => { - chars.next(); - match chars.peek() { - Some(':') => self.consume_and_return(chars, Token::DoubleColon), + self.next(); + match self.peek() { + Some(':') => self.consume_and_return(Token::DoubleColon), _ => Ok(Some(Token::Colon)), } } - '$' => Ok(Some(self.tokenize_dollar_preceded_value(chars)?)), - ';' => self.consume_and_return(chars, Token::SemiColon), - '\\' => self.consume_and_return(chars, Token::Backslash), - '[' => self.consume_and_return(chars, Token::LBracket), - ']' => self.consume_and_return(chars, Token::RBracket), - '&' => self.consume_and_return(chars, Token::Ampersand), + '$' => Ok(Some(self.tokenize_dollar_preceded_value()?)), + ';' => self.consume_and_return(Token::SemiColon), + '\\' => self.consume_and_return(Token::Backslash), + '[' => self.consume_and_return(Token::LBracket), + ']' => self.consume_and_return(Token::RBracket), + '&' => self.consume_and_return(Token::Ampersand), '^' => { - chars.next(); - match chars.peek() { - Some('@') => self.consume_and_return(chars, Token::Prefix), + self.next(); + match self.peek() { + Some('@') => self.consume_and_return(Token::Prefix), _ => Ok(Some(Token::Caret)), } } - '{' => self.consume_and_return(chars, Token::LBrace), - '}' => self.consume_and_return(chars, Token::RBrace), + '{' => self.consume_and_return(Token::LBrace), + '}' => self.consume_and_return(Token::RBrace), '~' => { - chars.next(); // consume - match chars.peek() { + self.next(); // consume + match self.peek() { Some('~') => { - chars.next(); - match chars.peek() { - Some('*') => { - self.consume_and_return(chars, Token::DoubleTildeAsterisk) - } + self.next(); + match self.peek() { + Some('*') => self.consume_and_return(Token::DoubleTildeAsterisk), _ => Ok(Some(Token::DoubleTilde)), } } - Some('*') => self.consume_and_return(chars, Token::TildeAsterisk), + Some('*') => self.consume_and_return(Token::TildeAsterisk), _ => Ok(Some(Token::Tilde)), } } '#' => { - chars.next(); // consume the '#' - match chars.peek() { - Some('-') => self.consume_and_return(chars, Token::HashMinus), + self.next(); // consume the '#' + match self.peek() { + Some('-') => self.consume_and_return(Token::HashMinus), Some('>') => { - chars.next(); // consume first '>' - match chars.peek() { + self.next(); // consume first '>' + match self.peek() { Some('>') => { - chars.next(); // consume second '>' + self.next(); // consume second '>' Ok(Some(Token::HashLongArrow)) } _ => Ok(Some(Token::HashArrow)), @@ -791,50 +804,47 @@ impl<'a> Tokenizer<'a> { } } '@' => { - chars.next(); // consume the '@' - match chars.peek() { - Some('>') => self.consume_and_return(chars, Token::AtArrow), - Some('?') => self.consume_and_return(chars, Token::AtQuestionMark), - Some('@') => self.consume_and_return(chars, Token::AtAt), + self.next(); // consume the '@' + match self.peek() { + Some('>') => self.consume_and_return(Token::AtArrow), + Some('?') => self.consume_and_return(Token::AtQuestionMark), + Some('@') => self.consume_and_return(Token::AtAt), // a regular '@' operator _ => Ok(Some(Token::AtSign)), } } '?' => { - chars.next(); // consume the '?' - match chars.peek() { - Some('|') => self.consume_and_return(chars, Token::QuestionMarkPipe), - Some('&') => self.consume_and_return(chars, Token::QuestionMarkAmpersand), + self.next(); // consume the '?' + match self.peek() { + Some('|') => self.consume_and_return(Token::QuestionMarkPipe), + Some('&') => self.consume_and_return(Token::QuestionMarkAmpersand), // a regular '?' operator _ => Ok(Some(Token::QuestionMark)), } } - other => self.consume_and_return(chars, Token::Char(other)), + other => self.consume_and_return(Token::Char(other)), }, None => Ok(None), } } /// Tokenize dollar preceded value (i.e: a string/placeholder) - fn tokenize_dollar_preceded_value( - &self, - chars: &mut Peekable>, - ) -> Result { + fn tokenize_dollar_preceded_value(&mut self) -> Result { let mut s = String::new(); let mut value = String::new(); - chars.next(); + self.next(); - if let Some('$') = chars.peek() { - chars.next(); + if let Some('$') = self.peek() { + self.next(); let mut is_terminated = false; let mut prev: Option = None; - while let Some(&ch) = chars.peek() { + while let Some(ch) = self.peek() { if prev == Some('$') { if ch == '$' { - chars.next(); + self.next(); is_terminated = true; break; } else { @@ -846,11 +856,11 @@ impl<'a> Tokenizer<'a> { } prev = Some(ch); - chars.next(); + self.next(); } - return if chars.peek().is_none() && !is_terminated { - self.tokenizer_error("Unterminated dollar-quoted string") + return if self.peek().is_none() && !is_terminated { + self.error("Unterminated dollar-quoted string") } else { Ok(Token::DollarQuotedString(DollarQuotedString { value: s, @@ -858,36 +868,33 @@ impl<'a> Tokenizer<'a> { })) }; } else { - value.push_str(&peeking_take_while(chars, |ch| { - ch.is_alphanumeric() || ch == '_' - })); + value.push_str(&self.peeking_take_while(|ch| ch.is_alphanumeric() || ch == '_')); - if let Some('$') = chars.peek() { - chars.next(); - s.push_str(&peeking_take_while(chars, |ch| ch != '$')); + if let Some('$') = self.peek() { + self.next(); + s.push_str(&self.peeking_take_while(|ch| ch != '$')); - match chars.peek() { + match self.peek() { Some('$') => { - chars.next(); + self.next(); for c in value.chars() { - let next_char = chars.next(); + let next_char = self.next(); if Some(c) != next_char { - return self.tokenizer_error(format!( + return self.error(format!( "Unterminated dollar-quoted string at or near \"{}\"", value )); } } - if let Some('$') = chars.peek() { - chars.next(); + if let Some('$') = self.peek() { + self.next(); } else { - return self - .tokenizer_error("Unterminated dollar-quoted string, expected $"); + return self.error("Unterminated dollar-quoted string, expected $"); } } _ => { - return self.tokenizer_error("Unterminated dollar-quoted, expected $"); + return self.error("Unterminated dollar-quoted, expected $"); } } } else { @@ -901,27 +908,32 @@ impl<'a> Tokenizer<'a> { })) } - fn tokenizer_error(&self, message: impl Into) -> Result { + fn error(&self, message: impl Into) -> Result { + let prefix = format!("LINE {}: ", self.line); + let sql_line = self.sql.split('\n').nth(self.line as usize - 1).unwrap(); + let cursor = " ".repeat(prefix.len() + self.col as usize - 1); + let context = format!("{}{}\n{}^", prefix, sql_line, cursor); Err(TokenizerError { message: message.into(), col: self.col, line: self.line, + context, }) } - fn reject_number_junk(&self, chars: &mut Peekable>) -> Result<(), TokenizerError> { - if let Some(ch) = chars.peek() - && is_identifier_start(*ch) + fn reject_number_junk(&mut self) -> Result<(), TokenizerError> { + if let Some(ch) = self.peek() + && is_identifier_start(ch) { - return self.tokenizer_error("trailing junk after numeric literal"); + return self.error("trailing junk after numeric literal"); } Ok(()) } // Consume characters until newline - fn tokenize_single_line_comment(&self, chars: &mut Peekable>) -> String { - let mut comment = peeking_take_while(chars, |ch| ch != '\n'); - if let Some(ch) = chars.next() { + fn tokenize_single_line_comment(&mut self) -> String { + let mut comment = self.peeking_take_while(|ch| ch != '\n'); + if let Some(ch) = self.next() { assert_eq!(ch, '\n'); comment.push(ch); } @@ -929,66 +941,62 @@ impl<'a> Tokenizer<'a> { } /// Tokenize an identifier or keyword, after the first char is already consumed. - fn tokenize_word(&self, first_char: char, chars: &mut Peekable>) -> String { + fn tokenize_word(&mut self, first_char: char) -> String { let mut s = first_char.to_string(); - s.push_str(&peeking_take_while(chars, is_identifier_part)); + s.push_str(&self.peeking_take_while(is_identifier_part)); s } /// Read a single quoted string, starting with the opening quote. - fn tokenize_single_quoted_string( - &self, - chars: &mut Peekable>, - ) -> Result { + fn tokenize_single_quoted_string(&mut self) -> Result { let mut s = String::new(); - chars.next(); // consume the opening quote + self.next(); // consume the opening quote // slash escaping is specific to MySQL dialect let mut is_escaped = false; - while let Some(&ch) = chars.peek() { + while let Some(ch) = self.peek() { match ch { '\'' => { - chars.next(); // consume + self.next(); // consume if is_escaped { s.push(ch); is_escaped = false; - } else if chars.peek().map(|c| *c == '\'').unwrap_or(false) { + } else if self.peek().map(|c| c == '\'').unwrap_or(false) { s.push(ch); - chars.next(); + self.next(); } else { return Ok(s); } } '\\' => { s.push(ch); - chars.next(); + self.next(); } _ => { - chars.next(); // consume + self.next(); // consume s.push(ch); } } } - self.tokenizer_error("Unterminated string literal") + self.error("Unterminated string literal") } /// Read a single qutoed string with escape fn tokenize_single_quoted_string_with_escape( - &self, - chars: &mut Peekable>, + &mut self, ) -> Result { let mut terminated = false; let mut s = String::new(); - chars.next(); // consume the opening quote + self.next(); // consume the opening quote - while let Some(&ch) = chars.peek() { + while let Some(ch) = self.peek() { match ch { '\'' => { - chars.next(); // consume - if chars.peek().map(|c| *c == '\'').unwrap_or(false) { + self.next(); // consume + if self.peek().map(|c| c == '\'').unwrap_or(false) { s.push('\\'); s.push(ch); - chars.next(); + self.next(); } else { terminated = true; break; @@ -996,29 +1004,25 @@ impl<'a> Tokenizer<'a> { } '\\' => { s.push(ch); - chars.next(); - if chars - .peek() - .map(|c| *c == '\'' || *c == '\\') - .unwrap_or(false) - { - s.push(chars.next().unwrap()); + self.next(); + if self.peek().map(|c| c == '\'' || c == '\\').unwrap_or(false) { + s.push(self.next().unwrap()); } } _ => { - chars.next(); // consume + self.next(); // consume s.push(ch); } } } if !terminated { - return self.tokenizer_error("Unterminated string literal"); + return self.error("Unterminated string literal"); } let unescaped = match Self::unescape_c_style(&s) { Ok(unescaped) => unescaped, - Err(e) => return self.tokenizer_error(e), + Err(e) => return self.error(e), }; Ok(CstyleEscapedString { @@ -1139,17 +1143,14 @@ impl<'a> Tokenizer<'a> { Ok(res) } - fn tokenize_multiline_comment( - &self, - chars: &mut Peekable>, - ) -> Result, TokenizerError> { + fn tokenize_multiline_comment(&mut self) -> Result, TokenizerError> { let mut s = String::new(); let mut nested = 1; let mut last_ch = ' '; loop { - match chars.next() { + match self.next() { Some(ch) => { if last_ch == '/' && ch == '*' { nested += 1; @@ -1163,39 +1164,32 @@ impl<'a> Tokenizer<'a> { s.push(ch); last_ch = ch; } - None => break self.tokenizer_error("Unexpected EOF while in a multi-line comment"), + None => break self.error("Unexpected EOF while in a multi-line comment"), } } } #[allow(clippy::unnecessary_wraps)] - fn consume_and_return( - &self, - chars: &mut Peekable>, - t: Token, - ) -> Result, TokenizerError> { - chars.next(); + fn consume_and_return(&mut self, t: Token) -> Result, TokenizerError> { + self.next(); Ok(Some(t)) } -} -/// Read from `chars` until `predicate` returns `false` or EOF is hit. -/// Return the characters read as String, and keep the first non-matching -/// char available as `chars.next()`. -fn peeking_take_while( - chars: &mut Peekable>, - mut predicate: impl FnMut(char) -> bool, -) -> String { - let mut s = String::new(); - while let Some(&ch) = chars.peek() { - if predicate(ch) { - chars.next(); // consume - s.push(ch); - } else { - break; + /// Read from `self` until `predicate` returns `false` or EOF is hit. + /// Return the characters read as String, and keep the first non-matching + /// char available as `self.next()`. + fn peeking_take_while(&mut self, mut predicate: impl FnMut(char) -> bool) -> String { + let mut s = String::new(); + while let Some(ch) = self.peek() { + if predicate(ch) { + self.next(); // consume + s.push(ch); + } else { + break; + } } + s } - s } /// Determine if a character starts a quoted identifier. The default @@ -1230,13 +1224,14 @@ mod tests { message: "test".into(), line: 1, col: 1, + context: "LINE 1:".to_string(), }; #[cfg(feature = "std")] { use std::error::Error; assert!(err.source().is_none()); } - assert_eq!(err.to_string(), "test at Line: 1, Column 1"); + assert_eq!(err.to_string(), "test at line 1, column 1\nLINE 1:"); } #[test] @@ -1522,7 +1517,8 @@ mod tests { Err(TokenizerError { message: "Unterminated string literal".to_string(), line: 1, - col: 8, + col: 12, + context: "LINE 1: select 'foo\n ^".to_string(), }) ); } @@ -1667,7 +1663,8 @@ mod tests { Err(TokenizerError { message: "Expected close delimiter '\"' before EOF.".to_string(), line: 1, - col: 1, + col: 5, + context: "LINE 1: \"foo\n ^".to_string(), }) ); } diff --git a/src/sqlparser/tests/sqlparser_common.rs b/src/sqlparser/tests/sqlparser_common.rs index ac6a1d310944a..c694aba3d1308 100644 --- a/src/sqlparser/tests/sqlparser_common.rs +++ b/src/sqlparser/tests/sqlparser_common.rs @@ -130,12 +130,12 @@ fn parse_update() { let sql = "UPDATE t WHERE 1"; let res = parse_sql_statements(sql); - assert!(format!("{}", res.unwrap_err()).contains("Expected SET, found: WHERE")); + assert!(format!("{}", res.unwrap_err()).contains("expected SET, found: WHERE")); let sql = "UPDATE t SET a = 1 extrabadstuff"; let res = parse_sql_statements(sql); assert!( - format!("{}", res.unwrap_err()).contains("Expected end of statement, found: extrabadstuff") + format!("{}", res.unwrap_err()).contains("expected end of statement, found: extrabadstuff") ); } @@ -252,8 +252,10 @@ fn parse_select_all() { #[test] fn parse_select_all_distinct() { let result = parse_sql_statements("SELECT ALL DISTINCT name FROM customer"); - assert!(format!("{}", result.unwrap_err()) - .contains("syntax error at or near DISTINCT at line:1, column:20")); + assert!(result + .unwrap_err() + .to_string() + .contains("syntax error at or near DISTINCT at line 1, column 12")); } #[test] @@ -284,7 +286,7 @@ fn parse_select_wildcard() { let sql = "SELECT * + * FROM foo;"; let result = parse_sql_statements(sql); - assert!(format!("{}", result.unwrap_err()).contains("Expected end of statement, found: +")); + assert!(format!("{}", result.unwrap_err()).contains("expected end of statement, found: +")); } #[test] @@ -321,7 +323,7 @@ fn parse_column_aliases() { assert_eq!(&Expr::Value(number("1")), right.as_ref()); assert_eq!(&Ident::new_unchecked("newname"), alias); } else { - panic!("Expected ExprWithAlias") + panic!("expected ExprWithAlias") } // alias without AS is parsed correctly: @@ -331,10 +333,10 @@ fn parse_column_aliases() { #[test] fn test_eof_after_as() { let res = parse_sql_statements("SELECT foo AS"); - assert!(format!("{}", res.unwrap_err()).contains("Expected an identifier after AS, found: EOF")); + assert!(format!("{}", res.unwrap_err()).contains("expected an identifier after AS, found: EOF")); let res = parse_sql_statements("SELECT 1 FROM foo AS"); - assert!(format!("{}", res.unwrap_err()).contains("Expected an identifier after AS, found: EOF")); + assert!(format!("{}", res.unwrap_err()).contains("expected an identifier after AS, found: EOF")); } #[test] @@ -390,7 +392,7 @@ fn parse_select_count_distinct() { #[test] fn parse_invalid_infix_not() { let res = parse_sql_statements("SELECT c FROM t WHERE c NOT ("); - assert!(format!("{}", res.unwrap_err(),).contains("Expected end of statement, found: NOT")); + assert!(format!("{}", res.unwrap_err(),).contains("expected end of statement, found: NOT")); } #[test] @@ -1268,7 +1270,7 @@ fn parse_extract() { verified_stmt("SELECT EXTRACT(SECOND FROM d)"); let res = parse_sql_statements("SELECT EXTRACT(0 FROM d)"); - assert!(format!("{}", res.unwrap_err()).contains("Expected date/time field, found: 0")); + assert!(format!("{}", res.unwrap_err()).contains("expected date/time field, found: 0")); } #[test] @@ -1434,13 +1436,13 @@ fn parse_create_table() { assert!(res .unwrap_err() .to_string() - .contains("Expected \',\' or \')\' after column definition, found: GARBAGE")); + .contains("expected \',\' or \')\' after column definition, found: GARBAGE")); let res = parse_sql_statements("CREATE TABLE t (a int NOT NULL CONSTRAINT foo)"); assert!(res .unwrap_err() .to_string() - .contains("Expected constraint details after CONSTRAINT ")); + .contains("expected constraint details after CONSTRAINT ")); } #[test] @@ -1781,12 +1783,12 @@ fn parse_alter_table_alter_column_type() { #[test] fn parse_bad_constraint() { let res = parse_sql_statements("ALTER TABLE tab ADD"); - assert!(format!("{}", res.unwrap_err()).contains("Expected identifier, found: EOF")); + assert!(format!("{}", res.unwrap_err()).contains("expected identifier, found: EOF")); let res = parse_sql_statements("CREATE TABLE tab (foo int,"); assert!(format!("{}", res.unwrap_err()) - .contains("Expected column name or constraint definition, found: EOF")); + .contains("expected column name or constraint definition, found: EOF")); } fn run_explain_analyze(query: &str, expected_analyze: bool, expected_options: ExplainOptions) { @@ -1885,7 +1887,7 @@ fn parse_explain_with_invalid_options() { assert!(res.is_err()); let res = parse_sql_statements("EXPLAIN (VERBOSE TRACE) SELECT sqrt(id) FROM foo"); - assert!(format!("{}", res.unwrap_err()).contains("Expected ), found: TRACE")); + assert!(format!("{}", res.unwrap_err()).contains("expected ), found: TRACE")); let res = parse_sql_statements("EXPLAIN () SELECT sqrt(id) FROM foo"); assert!(res.is_err()); @@ -1893,7 +1895,7 @@ fn parse_explain_with_invalid_options() { let res = parse_sql_statements("EXPLAIN (VERBOSE, ) SELECT sqrt(id) FROM foo"); let err_msg = - "Expected one of VERBOSE or TRACE or TYPE or LOGICAL or PHYSICAL or DISTSQL, found: )"; + "expected one of VERBOSE or TRACE or TYPE or LOGICAL or PHYSICAL or DISTSQL, found: )"; assert!(format!("{}", res.unwrap_err()).contains(err_msg)); } @@ -2207,10 +2209,10 @@ fn parse_literal_interval() { ); let result = parse_sql_statements("SELECT INTERVAL '1' SECOND TO SECOND"); - assert!(format!("{}", result.unwrap_err()).contains("Expected end of statement, found: SECOND")); + assert!(format!("{}", result.unwrap_err()).contains("expected end of statement, found: SECOND")); let result = parse_sql_statements("SELECT INTERVAL '10' HOUR (1) TO HOUR (2)"); - assert!(format!("{}", result.unwrap_err()).contains("Expected end of statement, found: (")); + assert!(format!("{}", result.unwrap_err()).contains("expected end of statement, found: (")); verified_only_select("SELECT INTERVAL '1' YEAR"); verified_only_select("SELECT INTERVAL '1' MONTH"); @@ -2293,7 +2295,7 @@ fn parse_delimited_identifiers() { ); assert_eq!(&Ident::with_quote_unchecked('"', "column alias"), alias); } - _ => panic!("Expected ExprWithAlias"), + _ => panic!("expected ExprWithAlias"), } verified_stmt(r#"CREATE TABLE "foo" ("bar" "int")"#); @@ -2612,7 +2614,7 @@ fn parse_natural_join() { let sql = "SELECT * FROM t1 natural"; assert!(format!("{}", parse_sql_statements(sql).unwrap_err(),) - .contains("Expected a join type after NATURAL, found: EOF")); + .contains("expected a join type after NATURAL, found: EOF")); } #[test] @@ -2676,7 +2678,7 @@ fn parse_join_syntax_variants() { ); let res = parse_sql_statements("SELECT * FROM a OUTER JOIN b ON 1"); - assert!(format!("{}", res.unwrap_err()).contains("Expected LEFT, RIGHT, or FULL, found: OUTER")); + assert!(format!("{}", res.unwrap_err()).contains("expected LEFT, RIGHT, or FULL, found: OUTER")); } #[test] @@ -2712,7 +2714,7 @@ fn parse_ctes() { Expr::Subquery(ref subquery) => { assert_ctes_in_select(&cte_sqls, subquery.as_ref()); } - _ => panic!("Expected subquery"), + _ => panic!("expected subquery"), } // CTE in a derived table let sql = &format!("SELECT * FROM ({})", with); @@ -2721,13 +2723,13 @@ fn parse_ctes() { TableFactor::Derived { subquery, .. } => { assert_ctes_in_select(&cte_sqls, subquery.as_ref()) } - _ => panic!("Expected derived table"), + _ => panic!("expected derived table"), } // CTE in a view let sql = &format!("CREATE VIEW v AS {}", with); match verified_stmt(sql) { Statement::CreateView { query, .. } => assert_ctes_in_select(&cte_sqls, &query), - _ => panic!("Expected CREATE VIEW"), + _ => panic!("expected CREATE VIEW"), } // CTE in a CTE... let sql = &format!("WITH outer_cte AS ({}) SELECT * FROM outer_cte", with); @@ -2866,7 +2868,7 @@ fn parse_multiple_statements() { one_statement_parses_to(&(sql1.to_owned() + ";"), sql1); // Check that forgetting the semicolon results in an error: let res = parse_sql_statements(&(sql1.to_owned() + " " + sql2_kw + sql2_rest)); - let err_msg = "Expected end of statement, found: "; + let err_msg = "expected end of statement, found: "; assert!(format!("{}", res.unwrap_err()).contains(err_msg)); } test_with("SELECT foo", "SELECT", " bar"); @@ -2930,18 +2932,18 @@ fn parse_overlay() { ); for (sql, err_msg) in [ - ("SELECT OVERLAY('abc', 'xyz')", "Expected PLACING, found: ,"), + ("SELECT OVERLAY('abc', 'xyz')", "expected PLACING, found: ,"), ( "SELECT OVERLAY('abc' PLACING 'xyz')", - "Expected FROM, found: )", + "expected FROM, found: )", ), ( "SELECT OVERLAY('abc' PLACING 'xyz' FOR 2)", - "Expected FROM, found: FOR", + "expected FROM, found: FOR", ), ( "SELECT OVERLAY('abc' PLACING 'xyz' FOR 2 FROM 1)", - "Expected FROM, found: FOR", + "expected FROM, found: FOR", ), ] { let res = parse_sql_statements(sql); @@ -2970,7 +2972,7 @@ fn parse_trim() { let res = parse_sql_statements("SELECT TRIM(FOO 'xyz' FROM 'xyzfooxyz')"); - let err_msg = "Expected ), found: 'xyz'"; + let err_msg = "expected ), found: 'xyz'"; assert!(format!("{}", res.unwrap_err()).contains(err_msg)); } @@ -2998,12 +3000,12 @@ fn parse_exists_subquery() { verified_stmt("SELECT EXISTS (SELECT 1)"); let res = parse_sql_statements("SELECT EXISTS ("); - let err_msg = "Expected SELECT, VALUES, or a subquery in the query body, found: EOF"; + let err_msg = "expected SELECT, VALUES, or a subquery in the query body, found: EOF"; assert!(format!("{}", res.unwrap_err()).contains(err_msg)); let res = parse_sql_statements("SELECT EXISTS (NULL)"); - let err_msg = "Expected SELECT, VALUES, or a subquery in the query body, found: NULL"; + let err_msg = "expected SELECT, VALUES, or a subquery in the query body, found: NULL"; assert!(format!("{}", res.unwrap_err()).contains(err_msg)); } @@ -3337,11 +3339,11 @@ fn parse_drop_table() { let sql = "DROP TABLE"; assert!(format!("{}", parse_sql_statements(sql).unwrap_err(),) - .contains("Expected identifier, found: EOF")); + .contains("expected identifier, found: EOF")); let sql = "DROP TABLE IF EXISTS foo CASCADE RESTRICT"; assert!(format!("{}", parse_sql_statements(sql).unwrap_err(),) - .contains("Expected end of statement, found: RESTRICT")); + .contains("expected end of statement, found: RESTRICT")); } #[test] @@ -3402,7 +3404,7 @@ fn parse_create_user() { #[test] fn parse_invalid_subquery_without_parens() { let res = parse_sql_statements("SELECT SELECT 1 FROM bar WHERE 1=1 FROM baz"); - assert!(format!("{}", res.unwrap_err()).contains("Expected end of statement, found: 1")); + assert!(format!("{}", res.unwrap_err()).contains("expected end of statement, found: 1")); } #[test] @@ -3564,12 +3566,12 @@ fn lateral_derived() { let sql = "SELECT * FROM customer LEFT JOIN LATERAL generate_series(1, customer.id)"; let res = parse_sql_statements(sql); assert!(format!("{}", res.unwrap_err()) - .contains("Expected subquery after LATERAL, found: generate_series")); + .contains("expected subquery after LATERAL, found: generate_series")); let sql = "SELECT * FROM a LEFT JOIN LATERAL (b CROSS JOIN c)"; let res = parse_sql_statements(sql); assert!(format!("{}", res.unwrap_err()) - .contains("Expected SELECT, VALUES, or a subquery in the query body, found: b")); + .contains("expected SELECT, VALUES, or a subquery in the query body, found: b")); } #[test] @@ -3622,13 +3624,13 @@ fn parse_start_transaction() { ); let res = parse_sql_statements("START TRANSACTION ISOLATION LEVEL BAD"); - assert!(format!("{}", res.unwrap_err()).contains("Expected isolation level, found: BAD")); + assert!(format!("{}", res.unwrap_err()).contains("expected isolation level, found: BAD")); let res = parse_sql_statements("START TRANSACTION BAD"); - assert!(format!("{}", res.unwrap_err()).contains("Expected end of statement, found: BAD")); + assert!(format!("{}", res.unwrap_err()).contains("expected end of statement, found: BAD")); let res = parse_sql_statements("START TRANSACTION READ ONLY,"); - assert!(format!("{}", res.unwrap_err()).contains("Expected transaction mode, found: EOF")); + assert!(format!("{}", res.unwrap_err()).contains("expected transaction mode, found: EOF")); } #[test] diff --git a/src/sqlparser/tests/sqlparser_postgres.rs b/src/sqlparser/tests/sqlparser_postgres.rs index 59d6c9d6d82a1..e205addd7102f 100644 --- a/src/sqlparser/tests/sqlparser_postgres.rs +++ b/src/sqlparser/tests/sqlparser_postgres.rs @@ -16,7 +16,6 @@ #[macro_use] mod test_utils; use risingwave_sqlparser::ast::*; -use risingwave_sqlparser::parser::ParserError; use test_utils::*; #[test] @@ -314,19 +313,19 @@ fn parse_bad_if_not_exists() { for (sql, err_msg) in [ ( "CREATE TABLE NOT EXISTS uk_cities ()", - "Expected end of statement, found: EXISTS", + "expected end of statement, found: EXISTS", ), ( "CREATE TABLE IF EXISTS uk_cities ()", - "Expected end of statement, found: EXISTS", + "expected end of statement, found: EXISTS", ), ( "CREATE TABLE IF uk_cities ()", - "Expected end of statement, found: uk_cities", + "expected end of statement, found: uk_cities", ), ( "CREATE TABLE IF NOT uk_cities ()", - "Expected end of statement, found: NOT", + "expected end of statement, found: NOT", ), ] { let res = parse_sql_statements(sql); @@ -440,9 +439,9 @@ fn parse_set() { one_statement_parses_to("SET a TO b", "SET a = b"); one_statement_parses_to("SET SESSION a = b", "SET a = b"); for (sql, err_msg) in [ - ("SET", "Expected identifier, found: EOF"), - ("SET a b", "Expected equals sign or TO, found: b"), - ("SET a =", "Expected parameter value, found: EOF"), + ("SET", "expected identifier, found: EOF"), + ("SET a b", "expected equals sign or TO, found: b"), + ("SET a =", "expected parameter value, found: EOF"), ] { let res = parse_sql_statements(sql); assert!(format!("{}", res.unwrap_err()).contains(err_msg)); @@ -1287,10 +1286,8 @@ fn parse_variadic_argument() { _ = verified_stmt(sql); let sql = "SELECT foo(VARIADIC a, b, VARIADIC c)"; - assert_eq!( - parse_sql_statements(sql), - Err(ParserError::ParserError( - "VARIADIC argument must be last".to_string() - )) - ); + assert!(parse_sql_statements(sql) + .unwrap_err() + .to_string() + .contains("VARIADIC argument must be last"),); } diff --git a/src/sqlparser/tests/testdata/array.yaml b/src/sqlparser/tests/testdata/array.yaml index 855e93fb24dd0..565431c0f4ee6 100644 --- a/src/sqlparser/tests/testdata/array.yaml +++ b/src/sqlparser/tests/testdata/array.yaml @@ -6,13 +6,20 @@ - input: CREATE TABLE t(a int[][][]); formatted_sql: CREATE TABLE t (a INT[][][]) - input: CREATE TABLE t(a int[); - error_msg: 'sql parser error: Unexpected ) at line:1, column:23' + error_msg: |- + sql parser error: Unexpected ) at line 1, column 22 + LINE 1: CREATE TABLE t(a int[); + ^ - input: CREATE TABLE t(a int[[]); - error_msg: 'sql parser error: Unexpected [ at line:1, column:23' + error_msg: |- + sql parser error: Unexpected [ at line 1, column 22 + LINE 1: CREATE TABLE t(a int[[]); + ^ - input: CREATE TABLE t(a int]); error_msg: |- - sql parser error: Expected ',' or ')' after column definition, found: ] at line:1, column:22 - Near "CREATE TABLE t(a int" + sql parser error: expected ',' or ')' after column definition, found: ] at line 1, column 21 + LINE 1: CREATE TABLE t(a int]); + ^ - input: SELECT foo[0] FROM foos formatted_sql: SELECT foo[0] FROM foos - input: SELECT foo[0][0] FROM foos @@ -29,33 +36,41 @@ formatted_sql: SELECT ARRAY[[], []] - input: SELECT ARRAY[ARRAY[],[]] error_msg: |- - sql parser error: Expected an expression:, found: [ at line:1, column:23 - Near "SELECT ARRAY[ARRAY[],[" + sql parser error: expected an expression:, found: [ at line 1, column 22 + LINE 1: SELECT ARRAY[ARRAY[],[]] + ^ - input: SELECT ARRAY[[],ARRAY[]] error_msg: |- - sql parser error: Expected [, found: ARRAY at line:1, column:22 - Near "SELECT ARRAY[[]," + sql parser error: expected [, found: ARRAY at line 1, column 17 + LINE 1: SELECT ARRAY[[],ARRAY[]] + ^ - input: SELECT ARRAY[[1,2],3] error_msg: |- - sql parser error: Expected [, found: 3 at line:1, column:21 - Near "SELECT ARRAY[[1,2]," + sql parser error: expected [, found: 3 at line 1, column 20 + LINE 1: SELECT ARRAY[[1,2],3] + ^ - input: SELECT ARRAY[1,[2,3]] error_msg: |- - sql parser error: Expected an expression:, found: [ at line:1, column:17 - Near "SELECT ARRAY[1,[" + sql parser error: expected an expression:, found: [ at line 1, column 16 + LINE 1: SELECT ARRAY[1,[2,3]] + ^ - input: SELECT ARRAY[ARRAY[1,2],[3,4]] error_msg: |- - sql parser error: Expected an expression:, found: [ at line:1, column:26 - Near "ARRAY[ARRAY[1,2],[" + sql parser error: expected an expression:, found: [ at line 1, column 25 + LINE 1: SELECT ARRAY[ARRAY[1,2],[3,4]] + ^ - input: SELECT ARRAY[[1,2],ARRAY[3,4]] error_msg: |- - sql parser error: Expected [, found: ARRAY at line:1, column:25 - Near "SELECT ARRAY[[1,2]," + sql parser error: expected [, found: ARRAY at line 1, column 20 + LINE 1: SELECT ARRAY[[1,2],ARRAY[3,4]] + ^ - input: SELECT ARRAY[[1,2],[3] || [4]] error_msg: |- - sql parser error: Expected ], found: || at line:1, column:25 - Near "[[1,2],[3]" + sql parser error: expected ], found: || at line 1, column 24 + LINE 1: SELECT ARRAY[[1,2],[3] || [4]] + ^ - input: SELECT [1,2] error_msg: |- - sql parser error: Expected an expression:, found: [ at line:1, column:9 - Near "SELECT [" + sql parser error: expected an expression:, found: [ at line 1, column 8 + LINE 1: SELECT [1,2] + ^ diff --git a/src/sqlparser/tests/testdata/create.yaml b/src/sqlparser/tests/testdata/create.yaml index 831886b9bdb36..ed4c16f77e312 100644 --- a/src/sqlparser/tests/testdata/create.yaml +++ b/src/sqlparser/tests/testdata/create.yaml @@ -17,12 +17,14 @@ formatted_sql: CREATE TABLE t (a INT, b INT) AS SELECT 1 AS b, 2 AS a - input: CREATE SOURCE src error_msg: |- - sql parser error: Expected description of the format, found: EOF at the end - Near "CREATE SOURCE src" + sql parser error: expected description of the format, found: EOF at the end + LINE 1: CREATE SOURCE src + ^ - input: CREATE SOURCE src-a FORMAT PLAIN ENCODE JSON error_msg: |- - sql parser error: Expected description of the format, found: - at line:1, column:19 - Near "CREATE SOURCE src" + sql parser error: expected description of the format, found: - at line 1, column 18 + LINE 1: CREATE SOURCE src-a FORMAT PLAIN ENCODE JSON + ^ - input: CREATE SOURCE src FORMAT PLAIN ENCODE JSON formatted_sql: CREATE SOURCE src FORMAT PLAIN ENCODE JSON - input: CREATE SOURCE mysql_src with ( connector = 'mysql-cdc', hostname = 'localhost', port = '3306', database.name = 'mytest', server.id = '5601' ) @@ -31,8 +33,9 @@ formatted_sql: CREATE TABLE sbtest10 (id INT PRIMARY KEY, k INT, c CHARACTER VARYING, pad CHARACTER VARYING) FROM sbtest TABLE 'mydb.sbtest10' - input: CREATE TABLE sbtest10 (id INT PRIMARY KEY, k INT, c CHARACTER VARYING, pad CHARACTER VARYING) FROM sbtest error_msg: |- - sql parser error: Expected TABLE, found: EOF at the end - Near "pad CHARACTER VARYING) FROM sbtest" + sql parser error: expected TABLE, found: EOF at the end + LINE 1: CREATE TABLE sbtest10 (id INT PRIMARY KEY, k INT, c CHARACTER VARYING, pad CHARACTER VARYING) FROM sbtest + ^ - input: CREATE SOURCE IF NOT EXISTS src WITH (kafka.topic = 'abc', kafka.servers = 'localhost:1001') FORMAT PLAIN ENCODE PROTOBUF (message = 'Foo', schema.location = 'file://') formatted_sql: CREATE SOURCE IF NOT EXISTS src WITH (kafka.topic = 'abc', kafka.servers = 'localhost:1001') FORMAT PLAIN ENCODE PROTOBUF (message = 'Foo', schema.location = 'file://') formatted_ast: 'CreateSource { stmt: CreateSourceStatement { if_not_exists: true, columns: [], wildcard_idx: None, constraints: [], source_name: ObjectName([Ident { value: "src", quote_style: None }]), with_properties: WithProperties([SqlOption { name: ObjectName([Ident { value: "kafka", quote_style: None }, Ident { value: "topic", quote_style: None }]), value: SingleQuotedString("abc") }, SqlOption { name: ObjectName([Ident { value: "kafka", quote_style: None }, Ident { value: "servers", quote_style: None }]), value: SingleQuotedString("localhost:1001") }]), source_schema: V2(ConnectorSchema { format: Plain, row_encode: Protobuf, row_options: [SqlOption { name: ObjectName([Ident { value: "message", quote_style: None }]), value: SingleQuotedString("Foo") }, SqlOption { name: ObjectName([Ident { value: "schema", quote_style: None }, Ident { value: "location", quote_style: None }]), value: SingleQuotedString("file://") }], key_encode: None }), source_watermarks: [], include_column_options: [] } }' @@ -49,7 +52,10 @@ - input: CREATE TABLE T (a STRUCT) formatted_sql: CREATE TABLE T (a STRUCT) - input: CREATE TABLE T (FULL INT) - error_msg: 'sql parser error: syntax error at or near FULL at line:1, column:21' + error_msg: |- + sql parser error: syntax error at or near FULL at line 1, column 17 + LINE 1: CREATE TABLE T (FULL INT) + ^ - input: CREATE TABLE T ("FULL" INT) formatted_sql: CREATE TABLE T ("FULL" INT) - input: CREATE TABLE T ("FULL" INT) ON CONFLICT OVERWRITE WITH VERSION COLUMN("FULL") @@ -58,22 +64,28 @@ formatted_sql: CREATE TABLE T ("FULL" INT) ON CONFLICT IGNORE - input: CREATE TABLE T ("FULL" INT) ON CONFLICT OVERWRITE WITH VERSION COLUMN error_msg: |- - sql parser error: Expected (, found: EOF at the end - Near " CONFLICT OVERWRITE WITH VERSION COLUMN" + sql parser error: expected (, found: EOF at the end + LINE 1: CREATE TABLE T ("FULL" INT) ON CONFLICT OVERWRITE WITH VERSION COLUMN + ^ - input: CREATE TABLE T ("FULL" INT) ON CONFLICT OVERWRITE WITH VERSION COLUMN(FULL - error_msg: 'sql parser error: syntax error at or near FULL at line:1, column:75' + error_msg: |- + sql parser error: syntax error at or near FULL at line 1, column 71 + LINE 1: CREATE TABLE T ("FULL" INT) ON CONFLICT OVERWRITE WITH VERSION COLUMN(FULL + ^ - input: CREATE TABLE T ("FULL" INT) ON CONFLICT OVERWRITE WITH VERSION COLUMNFULL) error_msg: |- - sql parser error: Expected (, found: VERSION at line:1, column:63 - Near "INT) ON CONFLICT OVERWRITE WITH" + sql parser error: expected (, found: VERSION at line 1, column 56 + LINE 1: CREATE TABLE T ("FULL" INT) ON CONFLICT OVERWRITE WITH VERSION COLUMNFULL) + ^ - input: CREATE TABLE T ("FULL" INT) ON CONFLICT DO UPDATE IF NOT NULL formatted_sql: CREATE TABLE T ("FULL" INT) ON CONFLICT DO UPDATE IF NOT NULL - input: CREATE USER user WITH SUPERUSER CREATEDB PASSWORD 'password' formatted_sql: CREATE USER user WITH SUPERUSER CREATEDB PASSWORD 'password' - input: CREATE SINK snk error_msg: |- - sql parser error: Expected FROM or AS after CREATE SINK sink_name, found: EOF at the end - Near "CREATE SINK snk" + sql parser error: expected FROM or AS after CREATE SINK sink_name, found: EOF at the end + LINE 1: CREATE SINK snk + ^ - input: CREATE SINK IF NOT EXISTS snk FROM mv WITH (connector = 'mysql', mysql.endpoint = '127.0.0.1:3306', mysql.table = '', mysql.database = '', mysql.user = '', mysql.password = '') formatted_sql: CREATE SINK IF NOT EXISTS snk FROM mv WITH (connector = 'mysql', mysql.endpoint = '127.0.0.1:3306', mysql.table = '', mysql.database = '', mysql.user = '', mysql.password = '') - input: CREATE SINK IF NOT EXISTS snk AS SELECT count(*) AS cnt FROM mv WITH (connector = 'mysql', mysql.endpoint = '127.0.0.1:3306', mysql.table = '', mysql.database = '', mysql.user = '', mysql.password = '') @@ -90,33 +102,53 @@ formatted_sql: CREATE SINK snk INTO t AS SELECT * FROM t - input: CREATE SINK snk FROM mv WITH (connector = 'kafka', properties.bootstrap.server = '127.0.0.1:9092', topic = 'test_topic') format; error_msg: |- - sql parser error: Expected identifier, found: ; at line:1, column:123 - Near " topic = 'test_topic') format;" + sql parser error: expected identifier, found: ; at line 1, column 128 + LINE 1: CREATE SINK snk FROM mv WITH (connector = 'kafka', properties.bootstrap.server = '127.0.0.1:9092', topic = 'test_topic') format; + ^ - input: create sink sk1 from tt where v1 % 10 = 0 with (connector='blackhole') error_msg: |- - sql parser error: Expected WITH, found: where at line:1, column:30 - Near "create sink sk1 from tt" + sql parser error: expected WITH, found: where at line 1, column 25 + LINE 1: create sink sk1 from tt where v1 % 10 = 0 with (connector='blackhole') + ^ - input: CREATE SINK snk FROM mv WITH (connector = 'kafka', properties.bootstrap.server = '127.0.0.1:9092', topic = 'test_topic') format debezium; error_msg: |- - sql parser error: Expected ENCODE, found: ; at line:1, column:132 - Near "topic = 'test_topic') format debezium" + sql parser error: expected ENCODE, found: ; at line 1, column 137 + LINE 1: CREATE SINK snk FROM mv WITH (connector = 'kafka', properties.bootstrap.server = '127.0.0.1:9092', topic = 'test_topic') format debezium; + ^ - input: CREATE SINK snk FROM mv WITH (connector = 'kafka', properties.bootstrap.server = '127.0.0.1:9092', topic = 'test_topic') format debezium encode; error_msg: |- - sql parser error: Expected identifier, found: ; at line:1, column:139 - Near " 'test_topic') format debezium encode;" + sql parser error: expected identifier, found: ; at line 1, column 144 + LINE 1: CREATE SINK snk FROM mv WITH (connector = 'kafka', properties.bootstrap.server = '127.0.0.1:9092', topic = 'test_topic') format debezium encode; + ^ - input: create user tmp createdb nocreatedb - error_msg: 'sql parser error: conflicting or redundant options' + error_msg: |- + sql parser error: conflicting or redundant options + LINE 1: create user tmp createdb nocreatedb + ^ - input: create user tmp createdb createdb - error_msg: 'sql parser error: conflicting or redundant options' + error_msg: |- + sql parser error: conflicting or redundant options + LINE 1: create user tmp createdb createdb + ^ - input: create user tmp with password '123' password null - error_msg: 'sql parser error: conflicting or redundant options' + error_msg: |- + sql parser error: conflicting or redundant options + LINE 1: create user tmp with password '123' password null + ^ - input: create user tmp with encrypted password '' password null - error_msg: 'sql parser error: conflicting or redundant options' + error_msg: |- + sql parser error: conflicting or redundant options + LINE 1: create user tmp with encrypted password '' password null + ^ - input: create user tmp with encrypted password null error_msg: |- - sql parser error: Expected literal string, found: null at line:1, column:45 - Near " tmp with encrypted password null" + sql parser error: expected literal string, found: null at line 1, column 41 + LINE 1: create user tmp with encrypted password null + ^ - input: CREATE SECRET secret1 WITH (backend = 'meta') AS 'demo-secret' formatted_sql: CREATE SECRET secret1 WITH (backend = 'meta') AS 'demo-secret' - input: CREATE SECRET IF NOT EXISTS secret2 WITH (backend = 'meta') AS 'demo-secret - error_msg: 'sql parser error: Unterminated string literal at Line: 1, Column 62' + error_msg: |- + sql parser error: Unterminated string literal at line 1, column 76 + LINE 1: CREATE SECRET IF NOT EXISTS secret2 WITH (backend = 'meta') AS 'demo-secret + ^ diff --git a/src/sqlparser/tests/testdata/insert.yaml b/src/sqlparser/tests/testdata/insert.yaml index 3b3661aa9f94c..f90924b3f9a02 100644 --- a/src/sqlparser/tests/testdata/insert.yaml +++ b/src/sqlparser/tests/testdata/insert.yaml @@ -1,7 +1,8 @@ # This file is automatically generated by `src/sqlparser/tests/parser_test.rs`. - input: INSERT public.customer (id, name, active) VALUES (1, 2, 3) error_msg: |- - sql parser error: Expected INTO, found: public at line:1, column:14 - Near "INSERT" + sql parser error: expected INTO, found: public at line 1, column 8 + LINE 1: INSERT public.customer (id, name, active) VALUES (1, 2, 3) + ^ - input: INSERT INTO t VALUES(1,3), (2,4) RETURNING *, a, a as aaa formatted_sql: INSERT INTO t VALUES (1, 3), (2, 4) RETURNING (*, a, a AS aaa) diff --git a/src/sqlparser/tests/testdata/select.yaml b/src/sqlparser/tests/testdata/select.yaml index 9b16f3fa9667a..d982596d05b67 100644 --- a/src/sqlparser/tests/testdata/select.yaml +++ b/src/sqlparser/tests/testdata/select.yaml @@ -48,37 +48,57 @@ formatted_ast: 'Query(Query { with: None, body: Select(Select { distinct: All, projection: [Wildcard(Some([Identifier(Ident { value: "v1", quote_style: None })])), QualifiedWildcard(ObjectName([Ident { value: "bar", quote_style: None }]), Some([CompoundIdentifier([Ident { value: "foo", quote_style: None }, Ident { value: "v2", quote_style: None }])]))], from: [TableWithJoins { relation: Table { name: ObjectName([Ident { value: "foo", quote_style: None }]), alias: None, as_of: None }, joins: [] }, TableWithJoins { relation: Table { name: ObjectName([Ident { value: "bar", quote_style: None }]), alias: None, as_of: None }, joins: [] }], lateral_views: [], selection: None, group_by: [], having: None }), order_by: [], limit: None, offset: None, fetch: None })' - input: SELECT v3 EXCEPT (v1, v2) FROM foo error_msg: |- - sql parser error: Expected SELECT, VALUES, or a subquery in the query body, found: v1 at line:1, column:21 - Near "SELECT v3 EXCEPT (" + sql parser error: expected SELECT, VALUES, or a subquery in the query body, found: v1 at line 1, column 19 + LINE 1: SELECT v3 EXCEPT (v1, v2) FROM foo + ^ - input: SELECT * EXCEPT (V1, ) FROM foo error_msg: |- - sql parser error: Expected an expression:, found: ) at line:1, column:23 - Near " * EXCEPT (V1, )" + sql parser error: expected an expression:, found: ) at line 1, column 22 + LINE 1: SELECT * EXCEPT (V1, ) FROM foo + ^ - input: SELECT * EXCEPT (v1 FROM foo error_msg: |- - sql parser error: Expected ( should be followed by ) after column names, found: FROM at line:1, column:25 - Near "SELECT * EXCEPT (v1" + sql parser error: expected ( should be followed by ) after column names, found: FROM at line 1, column 21 + LINE 1: SELECT * EXCEPT (v1 FROM foo + ^ - input: SELECT * FROM t LIMIT 1 FETCH FIRST ROWS ONLY - error_msg: 'sql parser error: Cannot specify both LIMIT and FETCH' + error_msg: |- + sql parser error: Cannot specify both LIMIT and FETCH + LINE 1: SELECT * FROM t LIMIT 1 FETCH FIRST ROWS ONLY + ^ - input: SELECT * FROM t FETCH FIRST ROWS WITH TIES - error_msg: 'sql parser error: WITH TIES cannot be specified without ORDER BY clause' + error_msg: |- + sql parser error: WITH TIES cannot be specified without ORDER BY clause + LINE 1: SELECT * FROM t FETCH FIRST ROWS WITH TIES + ^ - input: select * from (select 1 from 1); error_msg: |- - sql parser error: Expected identifier, found: 1 at line:1, column:31 - Near "from (select 1 from 1" + sql parser error: expected identifier, found: 1 at line 1, column 30 + LINE 1: select * from (select 1 from 1); + ^ - input: select * from (select * from tumble(t, x, interval '10' minutes)) error_msg: |- - sql parser error: Expected ), found: minutes at line:1, column:62 - Near "(t, x, interval '10'" + sql parser error: expected ), found: minutes at line 1, column 57 + LINE 1: select * from (select * from tumble(t, x, interval '10' minutes)) + ^ - input: SELECT 1, FROM t - error_msg: 'sql parser error: syntax error at or near FROM at line:1, column:15' + error_msg: |- + sql parser error: syntax error at or near FROM at line 1, column 11 + LINE 1: SELECT 1, FROM t + ^ - input: SELECT 1, WHERE true - error_msg: 'sql parser error: syntax error at or near WHERE at line:1, column:16' + error_msg: |- + sql parser error: syntax error at or near WHERE at line 1, column 11 + LINE 1: SELECT 1, WHERE true + ^ - input: SELECT timestamp with time zone '2022-10-01 12:00:00Z' AT TIME ZONE 'US/Pacific' formatted_sql: SELECT TIMESTAMP WITH TIME ZONE '2022-10-01 12:00:00Z' AT TIME ZONE 'US/Pacific' formatted_ast: 'Query(Query { with: None, body: Select(Select { distinct: All, projection: [UnnamedExpr(AtTimeZone { timestamp: TypedString { data_type: Timestamp(true), value: "2022-10-01 12:00:00Z" }, time_zone: "US/Pacific" })], from: [], lateral_views: [], selection: None, group_by: [], having: None }), order_by: [], limit: None, offset: None, fetch: None })' - input: SELECT 0c6 - error_msg: 'sql parser error: trailing junk after numeric literal at Line: 1, Column 8' + error_msg: |- + sql parser error: trailing junk after numeric literal at line 1, column 9 + LINE 1: SELECT 0c6 + ^ - input: SELECT 1e6 formatted_sql: SELECT 1e6 formatted_ast: 'Query(Query { with: None, body: Select(Select { distinct: All, projection: [UnnamedExpr(Value(Number("1e6")))], from: [], lateral_views: [], selection: None, group_by: [], having: None }), order_by: [], limit: None, offset: None, fetch: None })' @@ -111,22 +131,34 @@ formatted_ast: 'Query(Query { with: None, body: Select(Select { distinct: All, projection: [UnnamedExpr(Value(Number("-0o755")))], from: [], lateral_views: [], selection: None, group_by: [], having: None }), order_by: [], limit: None, offset: None, fetch: None })' - input: SELECT 0o129 error_msg: |- - sql parser error: Expected end of statement, found: 9 at line:1, column:13 - Near "SELECT 0o12" + sql parser error: expected end of statement, found: 9 at line 1, column 12 + LINE 1: SELECT 0o129 + ^ - input: SELECT 0o3.5 error_msg: |- - sql parser error: Expected end of statement, found: .5 at line:1, column:13 - Near "SELECT 0o3" + sql parser error: expected end of statement, found: .5 at line 1, column 11 + LINE 1: SELECT 0o3.5 + ^ - input: SELECT 0x - error_msg: 'sql parser error: incomplete integer literal at Line: 1, Column 8' + error_msg: |- + sql parser error: incomplete integer literal at line 1, column 10 + LINE 1: SELECT 0x + ^ - input: SELECT 1::float(0) - error_msg: 'sql parser error: Unexpected 0 at line:1, column:17: Precision must be in range 1..54' + error_msg: |- + sql parser error: Unexpected 0 at line 1, column 17: Precision must be in range 1..54 + LINE 1: SELECT 1::float(0) + ^ - input: SELECT 1::float(54) - error_msg: 'sql parser error: Unexpected 54 at line:1, column:18: Precision must be in range 1..54' + error_msg: |- + sql parser error: Unexpected 54 at line 1, column 17: Precision must be in range 1..54 + LINE 1: SELECT 1::float(54) + ^ - input: SELECT 1::int(2) error_msg: |- - sql parser error: Expected end of statement, found: ( at line:1, column:14 - Near "SELECT 1::int" + sql parser error: expected end of statement, found: ( at line 1, column 14 + LINE 1: SELECT 1::int(2) + ^ - input: select id1, a1, id2, a2 from stream as S join version FOR SYSTEM_TIME AS OF PROCTIME() AS V on id1= id2 formatted_sql: SELECT id1, a1, id2, a2 FROM stream AS S JOIN version FOR SYSTEM_TIME AS OF PROCTIME() AS V ON id1 = id2 formatted_ast: 'Query(Query { with: None, body: Select(Select { distinct: All, projection: [UnnamedExpr(Identifier(Ident { value: "id1", quote_style: None })), UnnamedExpr(Identifier(Ident { value: "a1", quote_style: None })), UnnamedExpr(Identifier(Ident { value: "id2", quote_style: None })), UnnamedExpr(Identifier(Ident { value: "a2", quote_style: None }))], from: [TableWithJoins { relation: Table { name: ObjectName([Ident { value: "stream", quote_style: None }]), alias: Some(TableAlias { name: Ident { value: "S", quote_style: None }, columns: [] }), as_of: None }, joins: [Join { relation: Table { name: ObjectName([Ident { value: "version", quote_style: None }]), alias: Some(TableAlias { name: Ident { value: "V", quote_style: None }, columns: [] }), as_of: Some(ProcessTime) }, join_operator: Inner(On(BinaryOp { left: Identifier(Ident { value: "id1", quote_style: None }), op: Eq, right: Identifier(Ident { value: "id2", quote_style: None }) })) }] }], lateral_views: [], selection: None, group_by: [], having: None }), order_by: [], limit: None, offset: None, fetch: None })' @@ -134,7 +166,10 @@ formatted_sql: SELECT percentile_cont(0.3) FROM unnest(ARRAY[1, 2, 4, 5, 10]) AS x formatted_ast: 'Query(Query { with: None, body: Select(Select { distinct: All, projection: [UnnamedExpr(Function(Function { name: ObjectName([Ident { value: "percentile_cont", quote_style: None }]), args: [Unnamed(Expr(Value(Number("0.3"))))], variadic: false, over: None, distinct: false, order_by: [], filter: None, within_group: Some(OrderByExpr { expr: Identifier(Ident { value: "x", quote_style: None }), asc: Some(false), nulls_first: None }) }))], from: [TableWithJoins { relation: TableFunction { name: ObjectName([Ident { value: "unnest", quote_style: None }]), alias: Some(TableAlias { name: Ident { value: "x", quote_style: None }, columns: [] }), args: [Unnamed(Expr(Array(Array { elem: [Value(Number("1")), Value(Number("2")), Value(Number("4")), Value(Number("5")), Value(Number("10"))], named: true })))], with_ordinality: false }, joins: [] }], lateral_views: [], selection: None, group_by: [], having: None }), order_by: [], limit: None, offset: None, fetch: None })' - input: select percentile_cont(0.3) within group (order by x, y desc) from t - error_msg: 'sql parser error: only one arg in order by is expected here' + error_msg: |- + sql parser error: only one arg in order by is expected here + LINE 1: select percentile_cont(0.3) within group (order by x, y desc) from t + ^ - input: select 'apple' ~~ 'app%' formatted_sql: SELECT 'apple' ~~ 'app%' formatted_ast: 'Query(Query { with: None, body: Select(Select { distinct: All, projection: [UnnamedExpr(BinaryOp { left: Value(SingleQuotedString("apple")), op: PGLikeMatch, right: Value(SingleQuotedString("app%")) })], from: [], lateral_views: [], selection: None, group_by: [], having: None }), order_by: [], limit: None, offset: None, fetch: None })' diff --git a/src/sqlparser/tests/testdata/set.yaml b/src/sqlparser/tests/testdata/set.yaml index fa059ccc1515e..e84f446d0f518 100644 --- a/src/sqlparser/tests/testdata/set.yaml +++ b/src/sqlparser/tests/testdata/set.yaml @@ -17,9 +17,11 @@ formatted_sql: SET search_path = 'default', 'my_path' - input: set search_path to default, 'my_path'; error_msg: |- - sql parser error: Expected end of statement, found: , at line:1, column:28 - Near "set search_path to default" + sql parser error: expected end of statement, found: , at line 1, column 27 + LINE 1: set search_path to default, 'my_path'; + ^ - input: set search_path to 'my_path', default; error_msg: |- - sql parser error: Expected parameter list value, found: default at line:1, column:36 - Near "set search_path to 'my_path', default" + sql parser error: expected parameter list value, found: default at line 1, column 31 + LINE 1: set search_path to 'my_path', default; + ^ diff --git a/src/sqlparser/tests/testdata/struct.yaml b/src/sqlparser/tests/testdata/struct.yaml index 49313402252e6..6f53b00aaa588 100644 --- a/src/sqlparser/tests/testdata/struct.yaml +++ b/src/sqlparser/tests/testdata/struct.yaml @@ -8,4 +8,7 @@ - input: create table st (v1 int, v2 struct>, v3 struct>) formatted_sql: CREATE TABLE st (v1 INT, v2 STRUCT>, v3 STRUCT>) - input: SELECT NULL::STRUCT> - error_msg: 'sql parser error: Unexpected EOF: Unconsumed `>>`' + error_msg: |- + sql parser error: Unexpected EOF: Unconsumed `>>` + LINE 1: SELECT NULL::STRUCT> + ^ diff --git a/src/sqlparser/tests/testdata/subquery.yaml b/src/sqlparser/tests/testdata/subquery.yaml index f5b454fc883bf..64c9f7bf26b25 100644 --- a/src/sqlparser/tests/testdata/subquery.yaml +++ b/src/sqlparser/tests/testdata/subquery.yaml @@ -7,12 +7,14 @@ formatted_ast: 'Query(Query { with: None, body: Select(Select { distinct: All, projection: [UnnamedExpr(Identifier(Ident { value: "a1", quote_style: None }))], from: [TableWithJoins { relation: Table { name: ObjectName([Ident { value: "a", quote_style: None }]), alias: None, as_of: None }, joins: [] }], lateral_views: [], selection: Some(InSubquery { expr: Identifier(Ident { value: "a1", quote_style: None }), subquery: Query { with: None, body: Select(Select { distinct: All, projection: [UnnamedExpr(Identifier(Ident { value: "b1", quote_style: None }))], from: [TableWithJoins { relation: Table { name: ObjectName([Ident { value: "b", quote_style: None }]), alias: None, as_of: None }, joins: [] }], lateral_views: [], selection: None, group_by: [], having: None }), order_by: [], limit: None, offset: None, fetch: None }, negated: true }), group_by: [], having: None }), order_by: [], limit: None, offset: None, fetch: None })' - input: select a1 from a where a1 < ALL (select b1 from b); error_msg: |- - sql parser error: Expected ), found: b1 at line:1, column:43 - Near "where a1 < ALL (select" + sql parser error: expected ), found: b1 at line 1, column 41 + LINE 1: select a1 from a where a1 < ALL (select b1 from b); + ^ - input: select a1 from a where a1 <> SOME (select b1 from b); error_msg: |- - sql parser error: Expected ), found: b1 at line:1, column:44 - Near "where a1 <> SOME (select" + sql parser error: expected ), found: b1 at line 1, column 43 + LINE 1: select a1 from a where a1 <> SOME (select b1 from b); + ^ - input: select 1 + (select 2); formatted_sql: SELECT 1 + (SELECT 2) formatted_ast: 'Query(Query { with: None, body: Select(Select { distinct: All, projection: [UnnamedExpr(BinaryOp { left: Value(Number("1")), op: Plus, right: Subquery(Query { with: None, body: Select(Select { distinct: All, projection: [UnnamedExpr(Value(Number("2")))], from: [], lateral_views: [], selection: None, group_by: [], having: None }), order_by: [], limit: None, offset: None, fetch: None }) })], from: [], lateral_views: [], selection: None, group_by: [], having: None }), order_by: [], limit: None, offset: None, fetch: None })' diff --git a/src/stream/src/common/log_store_impl/kv_log_store/buffer.rs b/src/stream/src/common/log_store_impl/kv_log_store/buffer.rs index cfe7404786ece..85926a82373da 100644 --- a/src/stream/src/common/log_store_impl/kv_log_store/buffer.rs +++ b/src/stream/src/common/log_store_impl/kv_log_store/buffer.rs @@ -199,6 +199,16 @@ impl LogStoreBufferInner { self.update_unconsumed_buffer_metrics(); } + fn add_truncate_offset(&mut self, (epoch, seq_id): ReaderTruncationOffsetType) { + if let Some((prev_epoch, ref mut prev_seq_id)) = self.truncation_list.back_mut() + && *prev_epoch == epoch + { + *prev_seq_id = seq_id; + } else { + self.truncation_list.push_back((epoch, seq_id)); + } + } + fn rewind(&mut self) { while let Some((epoch, item)) = self.consumed_queue.pop_front() { self.unconsumed_queue.push_back((epoch, item)); @@ -371,7 +381,7 @@ impl LogStoreBufferReceiver { } } - pub(crate) fn truncate(&mut self, offset: TruncateOffset) { + pub(crate) fn truncate_buffer(&mut self, offset: TruncateOffset) { let mut inner = self.buffer.inner(); let mut latest_offset: Option = None; while let Some((epoch, item)) = inner.consumed_queue.back() { @@ -431,17 +441,16 @@ impl LogStoreBufferReceiver { } } } - if let Some((epoch, seq_id)) = latest_offset { - if let Some((prev_epoch, ref mut prev_seq_id)) = inner.truncation_list.back_mut() - && *prev_epoch == epoch - { - *prev_seq_id = seq_id; - } else { - inner.truncation_list.push_back((epoch, seq_id)); - } + if let Some(offset) = latest_offset { + inner.add_truncate_offset(offset); } } + pub(crate) fn truncate_historical(&mut self, epoch: u64) { + let mut inner = self.buffer.inner(); + inner.add_truncate_offset((epoch, None)); + } + pub(crate) fn rewind(&self) { self.buffer.inner().rewind() } diff --git a/src/stream/src/common/log_store_impl/kv_log_store/mod.rs b/src/stream/src/common/log_store_impl/kv_log_store/mod.rs index d1574a71debc0..3db9723e2ba14 100644 --- a/src/stream/src/common/log_store_impl/kv_log_store/mod.rs +++ b/src/stream/src/common/log_store_impl/kv_log_store/mod.rs @@ -1555,4 +1555,229 @@ mod tests { let chunk_ids = check_reader_last_unsealed(&mut reader, empty()).await; assert!(chunk_ids.is_empty()); } + + async fn validate_reader( + reader: &mut impl LogReader, + expected: impl IntoIterator, + ) { + for (expected_epoch, expected_item) in expected { + let (epoch, item) = reader.next_item().await.unwrap(); + assert_eq!(expected_epoch, epoch); + match (expected_item, item) { + ( + LogStoreReadItem::StreamChunk { + chunk: expected_chunk, + .. + }, + LogStoreReadItem::StreamChunk { chunk, .. }, + ) => { + check_stream_chunk_eq(&expected_chunk, &chunk); + } + ( + LogStoreReadItem::Barrier { + is_checkpoint: expected_is_checkpoint, + }, + LogStoreReadItem::Barrier { is_checkpoint }, + ) => { + assert_eq!(expected_is_checkpoint, is_checkpoint); + } + _ => unreachable!(), + } + } + } + + #[tokio::test] + async fn test_truncate_historical() { + #[expect(deprecated)] + test_truncate_historical_inner( + 10, + &crate::common::log_store_impl::kv_log_store::v1::KV_LOG_STORE_V1_INFO, + ) + .await; + test_truncate_historical_inner(10, &KV_LOG_STORE_V2_INFO).await; + } + + async fn test_truncate_historical_inner( + max_row_count: usize, + pk_info: &'static KvLogStorePkInfo, + ) { + let gen_stream_chunk = |base| gen_stream_chunk_with_info(base, pk_info); + let test_env = prepare_hummock_test_env().await; + + let table = gen_test_log_store_table(pk_info); + + test_env.register_table(table.clone()).await; + + let stream_chunk1 = gen_stream_chunk(0); + let stream_chunk2 = gen_stream_chunk(10); + let bitmap = calculate_vnode_bitmap(stream_chunk1.rows().chain(stream_chunk2.rows())); + let bitmap = Arc::new(bitmap); + + let factory = KvLogStoreFactory::new( + test_env.storage.clone(), + table.clone(), + Some(bitmap.clone()), + max_row_count, + KvLogStoreMetrics::for_test(), + "test", + pk_info, + ); + let (mut reader, mut writer) = factory.build().await; + + let epoch1 = test_env + .storage + .get_pinned_version() + .version() + .max_committed_epoch + .next_epoch(); + writer + .init(EpochPair::new_test_epoch(epoch1), false) + .await + .unwrap(); + writer.write_chunk(stream_chunk1.clone()).await.unwrap(); + let epoch2 = epoch1.next_epoch(); + writer.flush_current_epoch(epoch2, false).await.unwrap(); + writer.write_chunk(stream_chunk2.clone()).await.unwrap(); + let epoch3 = epoch2.next_epoch(); + writer.flush_current_epoch(epoch3, true).await.unwrap(); + + test_env.storage.seal_epoch(epoch1, false); + test_env.commit_epoch(epoch2).await; + + reader.init().await.unwrap(); + validate_reader( + &mut reader, + [ + ( + epoch1, + LogStoreReadItem::StreamChunk { + chunk: stream_chunk1.clone(), + chunk_id: 0, + }, + ), + ( + epoch1, + LogStoreReadItem::Barrier { + is_checkpoint: false, + }, + ), + ( + epoch2, + LogStoreReadItem::StreamChunk { + chunk: stream_chunk2.clone(), + chunk_id: 0, + }, + ), + ( + epoch2, + LogStoreReadItem::Barrier { + is_checkpoint: true, + }, + ), + ], + ) + .await; + + drop(writer); + + // Recovery + test_env.storage.clear_shared_buffer(epoch2).await; + + // Rebuild log reader and writer in recovery + let factory = KvLogStoreFactory::new( + test_env.storage.clone(), + table.clone(), + Some(bitmap.clone()), + max_row_count, + KvLogStoreMetrics::for_test(), + "test", + pk_info, + ); + let (mut reader, mut writer) = factory.build().await; + writer + .init(EpochPair::new_test_epoch(epoch3), false) + .await + .unwrap(); + reader.init().await.unwrap(); + validate_reader( + &mut reader, + [ + ( + epoch1, + LogStoreReadItem::StreamChunk { + chunk: stream_chunk1.clone(), + chunk_id: 0, + }, + ), + ( + epoch1, + LogStoreReadItem::Barrier { + is_checkpoint: false, + }, + ), + ( + epoch2, + LogStoreReadItem::StreamChunk { + chunk: stream_chunk2.clone(), + chunk_id: 0, + }, + ), + ( + epoch2, + LogStoreReadItem::Barrier { + is_checkpoint: true, + }, + ), + ], + ) + .await; + // The truncate should take effect + reader + .truncate(TruncateOffset::Barrier { epoch: epoch1 }) + .unwrap(); + let epoch4 = epoch3.next_epoch(); + writer.flush_current_epoch(epoch4, true).await.unwrap(); + test_env.commit_epoch(epoch3).await; + + drop(writer); + + // Recovery + test_env.storage.clear_shared_buffer(epoch3).await; + + // Rebuild log reader and writer in recovery + let factory = KvLogStoreFactory::new( + test_env.storage.clone(), + table.clone(), + Some(bitmap), + max_row_count, + KvLogStoreMetrics::for_test(), + "test", + pk_info, + ); + let (mut reader, mut writer) = factory.build().await; + writer + .init(EpochPair::new_test_epoch(epoch4), false) + .await + .unwrap(); + reader.init().await.unwrap(); + validate_reader( + &mut reader, + [ + ( + epoch2, + LogStoreReadItem::StreamChunk { + chunk: stream_chunk2.clone(), + chunk_id: 0, + }, + ), + ( + epoch2, + LogStoreReadItem::Barrier { + is_checkpoint: true, + }, + ), + ], + ) + .await; + } } diff --git a/src/stream/src/common/log_store_impl/kv_log_store/reader.rs b/src/stream/src/common/log_store_impl/kv_log_store/reader.rs index 7cad46b5b5fc0..21ee99ec91d08 100644 --- a/src/stream/src/common/log_store_impl/kv_log_store/reader.rs +++ b/src/stream/src/common/log_store_impl/kv_log_store/reader.rs @@ -465,11 +465,11 @@ impl LogReader for KvLogStoreReader { } } if offset.epoch() >= self.first_write_epoch.expect("should have init") { - self.rx.truncate(offset); + self.rx.truncate_buffer(offset); } else { // For historical data, no need to truncate at seq id level. Only truncate at barrier. - if let TruncateOffset::Barrier { .. } = &offset { - self.rx.truncate(offset); + if let TruncateOffset::Barrier { epoch } = &offset { + self.rx.truncate_historical(*epoch); } } self.truncate_offset = Some(offset);