Skip to content

Commit

Permalink
fix: timestamp range filter
Browse files Browse the repository at this point in the history
  • Loading branch information
v0y4g3r committed Oct 7, 2023
1 parent fe783c7 commit 1829f13
Show file tree
Hide file tree
Showing 2 changed files with 244 additions and 33 deletions.
4 changes: 2 additions & 2 deletions src/query/src/tests/time_range_filter_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ async fn test_range_filter() {
tester
.check(
"select * from m where ts > 1000;",
TimestampRange::from_start(Timestamp::new(1000, TimeUnit::Millisecond)),
TimestampRange::from_start(Timestamp::new(1001, TimeUnit::Millisecond)),
)
.await;

Expand All @@ -163,7 +163,7 @@ async fn test_range_filter() {
tester
.check(
"select * from m where ts > 1000 and ts < 2000;",
TimestampRange::with_unit(1000, 2000, TimeUnit::Millisecond).unwrap(),
TimestampRange::with_unit(1001, 2000, TimeUnit::Millisecond).unwrap(),
)
.await;

Expand Down
273 changes: 242 additions & 31 deletions src/table/src/predicate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -181,24 +181,74 @@ impl<'a> TimeRangePredicateBuilder<'a> {
match op {
Operator::Eq => self
.get_timestamp_filter(left, right)
.and_then(|ts| ts.convert_to(self.ts_col_unit))
.and_then(|(ts, _)| ts.convert_to(self.ts_col_unit))
.map(TimestampRange::single),
Operator::Lt => self
.get_timestamp_filter(left, right)
.and_then(|ts| ts.convert_to_ceil(self.ts_col_unit))
.map(|ts| TimestampRange::until_end(ts, false)),
Operator::LtEq => self
.get_timestamp_filter(left, right)
.and_then(|ts| ts.convert_to_ceil(self.ts_col_unit))
.map(|ts| TimestampRange::until_end(ts, true)),
Operator::Gt => self
.get_timestamp_filter(left, right)
.and_then(|ts| ts.convert_to(self.ts_col_unit))
.map(TimestampRange::from_start),
Operator::GtEq => self
.get_timestamp_filter(left, right)
.and_then(|ts| ts.convert_to(self.ts_col_unit))
.map(TimestampRange::from_start),
Operator::Lt => {
let Some((ts, reverse)) = self
.get_timestamp_filter(left, right)else {
return None;
};
if reverse {
// [lit] < ts_col
let ts_val = ts.convert_to(self.ts_col_unit)?.value();
Some(TimestampRange::from_start(Timestamp::new(
ts_val + 1,
self.ts_col_unit,
)))
} else {
// ts_col < [lit]
ts.convert_to_ceil(self.ts_col_unit)
.map(|ts| TimestampRange::until_end(ts, false))
}
}
Operator::LtEq => {
let Some((ts, reverse)) = self
.get_timestamp_filter(left, right)else {
return None;
};
if reverse {
// [lit] <= ts_col
ts.convert_to_ceil(self.ts_col_unit)
.map(TimestampRange::from_start)
} else {
// ts_col <= [lit]
ts.convert_to(self.ts_col_unit)
.map(|ts| TimestampRange::until_end(ts, true))
}
}
Operator::Gt => {
let Some((ts, reverse)) = self
.get_timestamp_filter(left, right)else {
return None;
};
if reverse {
// [lit] > ts_col
ts.convert_to_ceil(self.ts_col_unit)
.map(|t| TimestampRange::until_end(t, false))
} else {
// ts_col > [lit]
let ts_val = ts.convert_to(self.ts_col_unit)?.value();
Some(TimestampRange::from_start(Timestamp::new(
ts_val + 1,
self.ts_col_unit,
)))
}
}
Operator::GtEq => {
let Some((ts, reverse)) = self
.get_timestamp_filter(left, right)else {
return None;
};
if reverse {
// [lit] >= ts_col
ts.convert_to(self.ts_col_unit)
.map(|t| TimestampRange::until_end(t, true))
} else {
// ts_col >= [lit]
ts.convert_to_ceil(self.ts_col_unit)
.map(TimestampRange::from_start)
}
}
Operator::And => {
// instead of return none when failed to extract time range from left/right, we unwrap the none into
// `TimestampRange::min_to_max`.
Expand Down Expand Up @@ -236,18 +286,18 @@ impl<'a> TimeRangePredicateBuilder<'a> {
}
}

fn get_timestamp_filter(&self, left: &DfExpr, right: &DfExpr) -> Option<Timestamp> {
let (col, lit) = match (left, right) {
(DfExpr::Column(column), DfExpr::Literal(scalar)) => (column, scalar),
(DfExpr::Literal(scalar), DfExpr::Column(column)) => (column, scalar),
fn get_timestamp_filter(&self, left: &DfExpr, right: &DfExpr) -> Option<(Timestamp, bool)> {
let (col, lit, reverse) = match (left, right) {
(DfExpr::Column(column), DfExpr::Literal(scalar)) => (column, scalar, false),
(DfExpr::Literal(scalar), DfExpr::Column(column)) => (column, scalar, true),
_ => {
return None;
}
};
if col.name != self.ts_col_name {
return None;
}
scalar_value_to_timestamp(lit)
scalar_value_to_timestamp(lit).map(|t| (t, reverse))
}

fn extract_from_between_expr(
Expand Down Expand Up @@ -324,7 +374,7 @@ mod tests {
use datafusion::parquet::arrow::ArrowWriter;
pub use datafusion::parquet::schema::types::BasicTypeInfo;
use datafusion_common::{Column, ScalarValue};
use datafusion_expr::{BinaryExpr, Expr, Literal, Operator};
use datafusion_expr::{col, lit, BinaryExpr, Literal, Operator};
use datatypes::arrow::array::Int32Array;
use datatypes::arrow::datatypes::{DataType, Field, Schema};
use datatypes::arrow::record_batch::RecordBatch;
Expand All @@ -334,6 +384,165 @@ mod tests {

use super::*;

fn check_build_predicate(expr: DfExpr, expect: TimestampRange) {
assert_eq!(
expect,
TimeRangePredicateBuilder::new("ts", TimeUnit::Millisecond, &[Expr::from(expr)])
.build()
);
}

#[test]
fn test_build_predicate() {
// operator GT
// ts > 1ms
check_build_predicate(
col("ts").gt(lit(ScalarValue::TimestampMillisecond(Some(1), None))),
TimestampRange::from_start(Timestamp::new_millisecond(2)),
);

// 1ms > ts
check_build_predicate(
lit(ScalarValue::TimestampMillisecond(Some(1), None)).gt(col("ts")),
TimestampRange::until_end(Timestamp::new_millisecond(1), false),
);

// 1001us > ts
check_build_predicate(
lit(ScalarValue::TimestampMicrosecond(Some(1001), None)).gt(col("ts")),
TimestampRange::until_end(Timestamp::new_millisecond(1), true),
);

// ts > 1001us
check_build_predicate(
col("ts").gt(lit(ScalarValue::TimestampMicrosecond(Some(1001), None))),
TimestampRange::from_start(Timestamp::new_millisecond(2)),
);

// 1s > ts
check_build_predicate(
lit(ScalarValue::TimestampSecond(Some(1), None)).gt(col("ts")),
TimestampRange::until_end(Timestamp::new_millisecond(1000), false),
);

// ts > 1s
check_build_predicate(
col("ts").gt(lit(ScalarValue::TimestampSecond(Some(1), None))),
TimestampRange::from_start(Timestamp::new_millisecond(1001)),
);

// operator GT_EQ
// ts >= 1ms
check_build_predicate(
col("ts").gt_eq(lit(ScalarValue::TimestampMillisecond(Some(1), None))),
TimestampRange::from_start(Timestamp::new_millisecond(1)),
);

// 1ms >= ts
check_build_predicate(
lit(ScalarValue::TimestampMillisecond(Some(1), None)).gt_eq(col("ts")),
TimestampRange::until_end(Timestamp::new_millisecond(1), true),
);

// 1001us >= ts
check_build_predicate(
lit(ScalarValue::TimestampMicrosecond(Some(1001), None)).gt_eq(col("ts")),
TimestampRange::until_end(Timestamp::new_millisecond(1), true),
);

// ts >= 1001us
check_build_predicate(
col("ts").gt_eq(lit(ScalarValue::TimestampMicrosecond(Some(1001), None))),
TimestampRange::from_start(Timestamp::new_millisecond(2)),
);

// 1s >= ts
check_build_predicate(
lit(ScalarValue::TimestampSecond(Some(1), None)).gt_eq(col("ts")),
TimestampRange::until_end(Timestamp::new_millisecond(1000), true),
);

// ts >= 1s
check_build_predicate(
col("ts").gt_eq(lit(ScalarValue::TimestampSecond(Some(1), None))),
TimestampRange::from_start(Timestamp::new_millisecond(1000)),
);

// operator LT
// ts < 1ms
check_build_predicate(
col("ts").lt(lit(ScalarValue::TimestampMillisecond(Some(1), None))),
TimestampRange::until_end(Timestamp::new_millisecond(1), false),
);

// 1ms < ts
check_build_predicate(
lit(ScalarValue::TimestampMillisecond(Some(1), None)).lt(col("ts")),
TimestampRange::from_start(Timestamp::new_millisecond(2)),
);

// 1001us < ts
check_build_predicate(
lit(ScalarValue::TimestampMicrosecond(Some(1001), None)).lt(col("ts")),
TimestampRange::from_start(Timestamp::new_millisecond(2)),
);

// ts < 1001us
check_build_predicate(
col("ts").lt(lit(ScalarValue::TimestampMicrosecond(Some(1001), None))),
TimestampRange::until_end(Timestamp::new_millisecond(1), true),
);

// 1s < ts
check_build_predicate(
lit(ScalarValue::TimestampSecond(Some(1), None)).lt(col("ts")),
TimestampRange::from_start(Timestamp::new_millisecond(1001)),
);

// ts < 1s
check_build_predicate(
col("ts").lt(lit(ScalarValue::TimestampSecond(Some(1), None))),
TimestampRange::until_end(Timestamp::new_millisecond(1000), false),
);

// operator LT_EQ
// ts <= 1ms
check_build_predicate(
col("ts").lt_eq(lit(ScalarValue::TimestampMillisecond(Some(1), None))),
TimestampRange::until_end(Timestamp::new_millisecond(1), true),
);

// 1ms <= ts
check_build_predicate(
lit(ScalarValue::TimestampMillisecond(Some(1), None)).lt_eq(col("ts")),
TimestampRange::from_start(Timestamp::new_millisecond(1)),
);

// 1001us <= ts
check_build_predicate(
lit(ScalarValue::TimestampMicrosecond(Some(1001), None)).lt_eq(col("ts")),
TimestampRange::from_start(Timestamp::new_millisecond(2)),
);

// ts <= 1001us
check_build_predicate(
col("ts").lt_eq(lit(ScalarValue::TimestampMicrosecond(Some(1001), None))),
TimestampRange::until_end(Timestamp::new_millisecond(1), true),
);

// 1s <= ts
check_build_predicate(
lit(ScalarValue::TimestampSecond(Some(1), None)).lt_eq(col("ts")),
TimestampRange::from_start(Timestamp::new_millisecond(1000)),
);

// ts <= 1s
check_build_predicate(
col("ts").lt_eq(lit(ScalarValue::TimestampSecond(Some(1), None))),
TimestampRange::until_end(Timestamp::new_millisecond(1000), true),
);
}

async fn gen_test_parquet_file(dir: &TempDir, cnt: usize) -> (String, Arc<Schema>) {
let path = dir
.path()
Expand Down Expand Up @@ -397,13 +606,15 @@ mod tests {
}

fn gen_predicate(max_val: i32, op: Operator) -> Vec<common_query::logical_plan::Expr> {
vec![common_query::logical_plan::Expr::from(Expr::BinaryExpr(
BinaryExpr {
left: Box::new(Expr::Column(Column::from_name("cnt"))),
vec![common_query::logical_plan::Expr::from(
datafusion_expr::Expr::BinaryExpr(BinaryExpr {
left: Box::new(datafusion_expr::Expr::Column(Column::from_name("cnt"))),
op,
right: Box::new(Expr::Literal(ScalarValue::Int32(Some(max_val)))),
},
))]
right: Box::new(datafusion_expr::Expr::Literal(ScalarValue::Int32(Some(
max_val,
)))),
}),
)]
}

#[tokio::test]
Expand Down Expand Up @@ -469,9 +680,9 @@ mod tests {
#[tokio::test]
async fn test_or() {
// cnt > 30 or cnt < 20
let e = Expr::Column(Column::from_name("cnt"))
let e = datafusion_expr::Expr::Column(Column::from_name("cnt"))
.gt(30.lit())
.or(Expr::Column(Column::from_name("cnt")).lt(20.lit()));
.or(datafusion_expr::Expr::Column(Column::from_name("cnt")).lt(20.lit()));
assert_prune(40, vec![e.into()], vec![true, true, false, true]).await;
}
}

0 comments on commit 1829f13

Please sign in to comment.