diff --git a/src/query/src/tests/time_range_filter_test.rs b/src/query/src/tests/time_range_filter_test.rs index dfbcc3a13a73..1f5a926d79c5 100644 --- a/src/query/src/tests/time_range_filter_test.rs +++ b/src/query/src/tests/time_range_filter_test.rs @@ -149,7 +149,7 @@ async fn test_range_filter() { tester .check( "select * from m where ts > 1000;", - TimestampRange::from_start(Timestamp::new(1000, TimeUnit::Millisecond)), + TimestampRange::from_start(Timestamp::new(1001, TimeUnit::Millisecond)), ) .await; @@ -163,7 +163,7 @@ async fn test_range_filter() { tester .check( "select * from m where ts > 1000 and ts < 2000;", - TimestampRange::with_unit(1000, 2000, TimeUnit::Millisecond).unwrap(), + TimestampRange::with_unit(1001, 2000, TimeUnit::Millisecond).unwrap(), ) .await; diff --git a/src/table/src/predicate.rs b/src/table/src/predicate.rs index e11117f8066b..b828313778fa 100644 --- a/src/table/src/predicate.rs +++ b/src/table/src/predicate.rs @@ -181,24 +181,62 @@ impl<'a> TimeRangePredicateBuilder<'a> { match op { Operator::Eq => self .get_timestamp_filter(left, right) - .and_then(|ts| ts.convert_to(self.ts_col_unit)) + .and_then(|(ts, _)| ts.convert_to(self.ts_col_unit)) .map(TimestampRange::single), - Operator::Lt => self - .get_timestamp_filter(left, right) - .and_then(|ts| ts.convert_to_ceil(self.ts_col_unit)) - .map(|ts| TimestampRange::until_end(ts, false)), - Operator::LtEq => self - .get_timestamp_filter(left, right) - .and_then(|ts| ts.convert_to_ceil(self.ts_col_unit)) - .map(|ts| TimestampRange::until_end(ts, true)), - Operator::Gt => self - .get_timestamp_filter(left, right) - .and_then(|ts| ts.convert_to(self.ts_col_unit)) - .map(TimestampRange::from_start), - Operator::GtEq => self - .get_timestamp_filter(left, right) - .and_then(|ts| ts.convert_to(self.ts_col_unit)) - .map(TimestampRange::from_start), + Operator::Lt => { + let (ts, reverse) = self.get_timestamp_filter(left, right)?; + if reverse { + // [lit] < ts_col + let ts_val = ts.convert_to(self.ts_col_unit)?.value(); + Some(TimestampRange::from_start(Timestamp::new( + ts_val + 1, + self.ts_col_unit, + ))) + } else { + // ts_col < [lit] + ts.convert_to_ceil(self.ts_col_unit) + .map(|ts| TimestampRange::until_end(ts, false)) + } + } + Operator::LtEq => { + let (ts, reverse) = self.get_timestamp_filter(left, right)?; + if reverse { + // [lit] <= ts_col + ts.convert_to_ceil(self.ts_col_unit) + .map(TimestampRange::from_start) + } else { + // ts_col <= [lit] + ts.convert_to(self.ts_col_unit) + .map(|ts| TimestampRange::until_end(ts, true)) + } + } + Operator::Gt => { + let (ts, reverse) = self.get_timestamp_filter(left, right)?; + if reverse { + // [lit] > ts_col + ts.convert_to_ceil(self.ts_col_unit) + .map(|t| TimestampRange::until_end(t, false)) + } else { + // ts_col > [lit] + let ts_val = ts.convert_to(self.ts_col_unit)?.value(); + Some(TimestampRange::from_start(Timestamp::new( + ts_val + 1, + self.ts_col_unit, + ))) + } + } + Operator::GtEq => { + let (ts, reverse) = self.get_timestamp_filter(left, right)?; + if reverse { + // [lit] >= ts_col + ts.convert_to(self.ts_col_unit) + .map(|t| TimestampRange::until_end(t, true)) + } else { + // ts_col >= [lit] + ts.convert_to_ceil(self.ts_col_unit) + .map(TimestampRange::from_start) + } + } Operator::And => { // instead of return none when failed to extract time range from left/right, we unwrap the none into // `TimestampRange::min_to_max`. @@ -236,10 +274,10 @@ impl<'a> TimeRangePredicateBuilder<'a> { } } - fn get_timestamp_filter(&self, left: &DfExpr, right: &DfExpr) -> Option { - let (col, lit) = match (left, right) { - (DfExpr::Column(column), DfExpr::Literal(scalar)) => (column, scalar), - (DfExpr::Literal(scalar), DfExpr::Column(column)) => (column, scalar), + fn get_timestamp_filter(&self, left: &DfExpr, right: &DfExpr) -> Option<(Timestamp, bool)> { + let (col, lit, reverse) = match (left, right) { + (DfExpr::Column(column), DfExpr::Literal(scalar)) => (column, scalar, false), + (DfExpr::Literal(scalar), DfExpr::Column(column)) => (column, scalar, true), _ => { return None; } @@ -247,7 +285,7 @@ impl<'a> TimeRangePredicateBuilder<'a> { if col.name != self.ts_col_name { return None; } - scalar_value_to_timestamp(lit) + scalar_value_to_timestamp(lit).map(|t| (t, reverse)) } fn extract_from_between_expr( @@ -324,7 +362,7 @@ mod tests { use datafusion::parquet::arrow::ArrowWriter; pub use datafusion::parquet::schema::types::BasicTypeInfo; use datafusion_common::{Column, ScalarValue}; - use datafusion_expr::{BinaryExpr, Expr, Literal, Operator}; + use datafusion_expr::{col, lit, BinaryExpr, Literal, Operator}; use datatypes::arrow::array::Int32Array; use datatypes::arrow::datatypes::{DataType, Field, Schema}; use datatypes::arrow::record_batch::RecordBatch; @@ -334,6 +372,169 @@ mod tests { use super::*; + fn check_build_predicate(expr: DfExpr, expect: TimestampRange) { + assert_eq!( + expect, + TimeRangePredicateBuilder::new("ts", TimeUnit::Millisecond, &[Expr::from(expr)]) + .build() + ); + } + + #[test] + fn test_gt() { + // ts > 1ms + check_build_predicate( + col("ts").gt(lit(ScalarValue::TimestampMillisecond(Some(1), None))), + TimestampRange::from_start(Timestamp::new_millisecond(2)), + ); + + // 1ms > ts + check_build_predicate( + lit(ScalarValue::TimestampMillisecond(Some(1), None)).gt(col("ts")), + TimestampRange::until_end(Timestamp::new_millisecond(1), false), + ); + + // 1001us > ts + check_build_predicate( + lit(ScalarValue::TimestampMicrosecond(Some(1001), None)).gt(col("ts")), + TimestampRange::until_end(Timestamp::new_millisecond(1), true), + ); + + // ts > 1001us + check_build_predicate( + col("ts").gt(lit(ScalarValue::TimestampMicrosecond(Some(1001), None))), + TimestampRange::from_start(Timestamp::new_millisecond(2)), + ); + + // 1s > ts + check_build_predicate( + lit(ScalarValue::TimestampSecond(Some(1), None)).gt(col("ts")), + TimestampRange::until_end(Timestamp::new_millisecond(1000), false), + ); + + // ts > 1s + check_build_predicate( + col("ts").gt(lit(ScalarValue::TimestampSecond(Some(1), None))), + TimestampRange::from_start(Timestamp::new_millisecond(1001)), + ); + } + + #[test] + fn test_gt_eq() { + // ts >= 1ms + check_build_predicate( + col("ts").gt_eq(lit(ScalarValue::TimestampMillisecond(Some(1), None))), + TimestampRange::from_start(Timestamp::new_millisecond(1)), + ); + + // 1ms >= ts + check_build_predicate( + lit(ScalarValue::TimestampMillisecond(Some(1), None)).gt_eq(col("ts")), + TimestampRange::until_end(Timestamp::new_millisecond(1), true), + ); + + // 1001us >= ts + check_build_predicate( + lit(ScalarValue::TimestampMicrosecond(Some(1001), None)).gt_eq(col("ts")), + TimestampRange::until_end(Timestamp::new_millisecond(1), true), + ); + + // ts >= 1001us + check_build_predicate( + col("ts").gt_eq(lit(ScalarValue::TimestampMicrosecond(Some(1001), None))), + TimestampRange::from_start(Timestamp::new_millisecond(2)), + ); + + // 1s >= ts + check_build_predicate( + lit(ScalarValue::TimestampSecond(Some(1), None)).gt_eq(col("ts")), + TimestampRange::until_end(Timestamp::new_millisecond(1000), true), + ); + + // ts >= 1s + check_build_predicate( + col("ts").gt_eq(lit(ScalarValue::TimestampSecond(Some(1), None))), + TimestampRange::from_start(Timestamp::new_millisecond(1000)), + ); + } + + #[test] + fn test_lt() { + // ts < 1ms + check_build_predicate( + col("ts").lt(lit(ScalarValue::TimestampMillisecond(Some(1), None))), + TimestampRange::until_end(Timestamp::new_millisecond(1), false), + ); + + // 1ms < ts + check_build_predicate( + lit(ScalarValue::TimestampMillisecond(Some(1), None)).lt(col("ts")), + TimestampRange::from_start(Timestamp::new_millisecond(2)), + ); + + // 1001us < ts + check_build_predicate( + lit(ScalarValue::TimestampMicrosecond(Some(1001), None)).lt(col("ts")), + TimestampRange::from_start(Timestamp::new_millisecond(2)), + ); + + // ts < 1001us + check_build_predicate( + col("ts").lt(lit(ScalarValue::TimestampMicrosecond(Some(1001), None))), + TimestampRange::until_end(Timestamp::new_millisecond(1), true), + ); + + // 1s < ts + check_build_predicate( + lit(ScalarValue::TimestampSecond(Some(1), None)).lt(col("ts")), + TimestampRange::from_start(Timestamp::new_millisecond(1001)), + ); + + // ts < 1s + check_build_predicate( + col("ts").lt(lit(ScalarValue::TimestampSecond(Some(1), None))), + TimestampRange::until_end(Timestamp::new_millisecond(1000), false), + ); + } + #[test] + fn test_lt_eq() { + // ts <= 1ms + check_build_predicate( + col("ts").lt_eq(lit(ScalarValue::TimestampMillisecond(Some(1), None))), + TimestampRange::until_end(Timestamp::new_millisecond(1), true), + ); + + // 1ms <= ts + check_build_predicate( + lit(ScalarValue::TimestampMillisecond(Some(1), None)).lt_eq(col("ts")), + TimestampRange::from_start(Timestamp::new_millisecond(1)), + ); + + // 1001us <= ts + check_build_predicate( + lit(ScalarValue::TimestampMicrosecond(Some(1001), None)).lt_eq(col("ts")), + TimestampRange::from_start(Timestamp::new_millisecond(2)), + ); + + // ts <= 1001us + check_build_predicate( + col("ts").lt_eq(lit(ScalarValue::TimestampMicrosecond(Some(1001), None))), + TimestampRange::until_end(Timestamp::new_millisecond(1), true), + ); + + // 1s <= ts + check_build_predicate( + lit(ScalarValue::TimestampSecond(Some(1), None)).lt_eq(col("ts")), + TimestampRange::from_start(Timestamp::new_millisecond(1000)), + ); + + // ts <= 1s + check_build_predicate( + col("ts").lt_eq(lit(ScalarValue::TimestampSecond(Some(1), None))), + TimestampRange::until_end(Timestamp::new_millisecond(1000), true), + ); + } + async fn gen_test_parquet_file(dir: &TempDir, cnt: usize) -> (String, Arc) { let path = dir .path() @@ -397,13 +598,15 @@ mod tests { } fn gen_predicate(max_val: i32, op: Operator) -> Vec { - vec![common_query::logical_plan::Expr::from(Expr::BinaryExpr( - BinaryExpr { - left: Box::new(Expr::Column(Column::from_name("cnt"))), + vec![common_query::logical_plan::Expr::from( + datafusion_expr::Expr::BinaryExpr(BinaryExpr { + left: Box::new(datafusion_expr::Expr::Column(Column::from_name("cnt"))), op, - right: Box::new(Expr::Literal(ScalarValue::Int32(Some(max_val)))), - }, - ))] + right: Box::new(datafusion_expr::Expr::Literal(ScalarValue::Int32(Some( + max_val, + )))), + }), + )] } #[tokio::test] @@ -469,9 +672,9 @@ mod tests { #[tokio::test] async fn test_or() { // cnt > 30 or cnt < 20 - let e = Expr::Column(Column::from_name("cnt")) + let e = datafusion_expr::Expr::Column(Column::from_name("cnt")) .gt(30.lit()) - .or(Expr::Column(Column::from_name("cnt")).lt(20.lit())); + .or(datafusion_expr::Expr::Column(Column::from_name("cnt")).lt(20.lit())); assert_prune(40, vec![e.into()], vec![true, true, false, true]).await; } }