Skip to content

Commit

Permalink
refactor: replace Expr with datafusion::Expr
Browse files Browse the repository at this point in the history
  • Loading branch information
poltao committed May 21, 2024
1 parent 179c8c7 commit b2cd442
Show file tree
Hide file tree
Showing 30 changed files with 184 additions and 227 deletions.
2 changes: 2 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

76 changes: 38 additions & 38 deletions src/catalog/src/information_schema/predicate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,10 @@

use arrow::array::StringArray;
use arrow::compute::kernels::comparison;
use common_query::logical_plan::DfExpr;
use datafusion::common::ScalarValue;
use datafusion::logical_expr::expr::Like;
use datafusion::logical_expr::Operator;
use datafusion::prelude::Expr;
use datatypes::value::Value;
use store_api::storage::ScanRequest;

Expand Down Expand Up @@ -118,23 +118,23 @@ impl Predicate {
}

/// Try to create a predicate from datafusion [`Expr`], return None if fails.
fn from_expr(expr: DfExpr) -> Option<Predicate> {
fn from_expr(expr: Expr) -> Option<Predicate> {
match expr {
// NOT expr
DfExpr::Not(expr) => Some(Predicate::Not(Box::new(Self::from_expr(*expr)?))),
Expr::Not(expr) => Some(Predicate::Not(Box::new(Self::from_expr(*expr)?))),
// expr LIKE pattern
DfExpr::Like(Like {
Expr::Like(Like {
negated,
expr,
pattern,
case_insensitive,
..
}) if is_column(&expr) && is_string_literal(&pattern) => {
// Safety: ensured by gurad
let DfExpr::Column(c) = *expr else {
let Expr::Column(c) = *expr else {
unreachable!();
};
let DfExpr::Literal(ScalarValue::Utf8(Some(pattern))) = *pattern else {
let Expr::Literal(ScalarValue::Utf8(Some(pattern))) = *pattern else {
unreachable!();
};

Expand All @@ -147,19 +147,19 @@ impl Predicate {
}
}
// left OP right
DfExpr::BinaryExpr(bin) => match (*bin.left, bin.op, *bin.right) {
Expr::BinaryExpr(bin) => match (*bin.left, bin.op, *bin.right) {
// left == right
(DfExpr::Literal(scalar), Operator::Eq, DfExpr::Column(c))
| (DfExpr::Column(c), Operator::Eq, DfExpr::Literal(scalar)) => {
(Expr::Literal(scalar), Operator::Eq, Expr::Column(c))
| (Expr::Column(c), Operator::Eq, Expr::Literal(scalar)) => {
let Ok(v) = Value::try_from(scalar) else {
return None;
};

Some(Predicate::Eq(c.name, v))
}
// left != right
(DfExpr::Literal(scalar), Operator::NotEq, DfExpr::Column(c))
| (DfExpr::Column(c), Operator::NotEq, DfExpr::Literal(scalar)) => {
(Expr::Literal(scalar), Operator::NotEq, Expr::Column(c))
| (Expr::Column(c), Operator::NotEq, Expr::Literal(scalar)) => {
let Ok(v) = Value::try_from(scalar) else {
return None;
};
Expand All @@ -183,14 +183,14 @@ impl Predicate {
_ => None,
},
// [NOT] IN (LIST)
DfExpr::InList(list) => {
Expr::InList(list) => {
match (*list.expr, list.list, list.negated) {
// column [NOT] IN (v1, v2, v3, ...)
(DfExpr::Column(c), list, negated) if is_all_scalars(&list) => {
(Expr::Column(c), list, negated) if is_all_scalars(&list) => {
let mut values = Vec::with_capacity(list.len());
for scalar in list {
// Safety: checked by `is_all_scalars`
let DfExpr::Literal(scalar) = scalar else {
let Expr::Literal(scalar) = scalar else {
unreachable!();
};

Expand Down Expand Up @@ -237,12 +237,12 @@ fn like_utf8(s: &str, pattern: &str, case_insensitive: &bool) -> Option<bool> {
Some(booleans.value(0))
}

fn is_string_literal(expr: &DfExpr) -> bool {
matches!(expr, DfExpr::Literal(ScalarValue::Utf8(Some(_))))
fn is_string_literal(expr: &Expr) -> bool {
matches!(expr, Expr::Literal(ScalarValue::Utf8(Some(_))))
}

fn is_column(expr: &DfExpr) -> bool {
matches!(expr, DfExpr::Column(_))
fn is_column(expr: &Expr) -> bool {
matches!(expr, Expr::Column(_))
}

/// A list of predicate
Expand All @@ -257,7 +257,7 @@ impl Predicates {
let mut predicates = Vec::with_capacity(request.filters.len());

for filter in &request.filters {
if let Some(predicate) = Predicate::from_expr(filter.df_expr().clone()) {
if let Some(predicate) = Predicate::from_expr(filter.clone()) {
predicates.push(predicate);
}
}
Expand Down Expand Up @@ -286,8 +286,8 @@ impl Predicates {
}

/// Returns true when the values are all [`DfExpr::Literal`].
fn is_all_scalars(list: &[DfExpr]) -> bool {
list.iter().all(|v| matches!(v, DfExpr::Literal(_)))
fn is_all_scalars(list: &[Expr]) -> bool {
list.iter().all(|v| matches!(v, Expr::Literal(_)))
}

#[cfg(test)]
Expand Down Expand Up @@ -376,7 +376,7 @@ mod tests {
#[test]
fn test_predicate_like() {
// case insensitive
let expr = DfExpr::Like(Like {
let expr = Expr::Like(Like {
negated: false,
expr: Box::new(column("a")),
pattern: Box::new(string_literal("%abc")),
Expand All @@ -403,7 +403,7 @@ mod tests {
assert!(p.eval(&[]).is_none());

// case sensitive
let expr = DfExpr::Like(Like {
let expr = Expr::Like(Like {
negated: false,
expr: Box::new(column("a")),
pattern: Box::new(string_literal("%abc")),
Expand All @@ -423,7 +423,7 @@ mod tests {
assert!(p.eval(&[]).is_none());

// not like
let expr = DfExpr::Like(Like {
let expr = Expr::Like(Like {
negated: true,
expr: Box::new(column("a")),
pattern: Box::new(string_literal("%abc")),
Expand All @@ -437,15 +437,15 @@ mod tests {
assert!(p.eval(&[]).is_none());
}

fn column(name: &str) -> DfExpr {
DfExpr::Column(Column {
fn column(name: &str) -> Expr {
Expr::Column(Column {
relation: None,
name: name.to_string(),
})
}

fn string_literal(v: &str) -> DfExpr {
DfExpr::Literal(ScalarValue::Utf8(Some(v.to_string())))
fn string_literal(v: &str) -> Expr {
Expr::Literal(ScalarValue::Utf8(Some(v.to_string())))
}

fn match_string_value(v: &Value, expected: &str) -> bool {
Expand All @@ -463,14 +463,14 @@ mod tests {
result
}

fn mock_exprs() -> (DfExpr, DfExpr) {
let expr1 = DfExpr::BinaryExpr(BinaryExpr {
fn mock_exprs() -> (Expr, Expr) {
let expr1 = Expr::BinaryExpr(BinaryExpr {
left: Box::new(column("a")),
op: Operator::Eq,
right: Box::new(string_literal("a_value")),
});

let expr2 = DfExpr::BinaryExpr(BinaryExpr {
let expr2 = Expr::BinaryExpr(BinaryExpr {
left: Box::new(column("b")),
op: Operator::NotEq,
right: Box::new(string_literal("b_value")),
Expand All @@ -491,17 +491,17 @@ mod tests {
assert!(matches!(&p2, Predicate::NotEq(column, v) if column == "b"
&& match_string_value(v, "b_value")));

let and_expr = DfExpr::BinaryExpr(BinaryExpr {
let and_expr = Expr::BinaryExpr(BinaryExpr {
left: Box::new(expr1.clone()),
op: Operator::And,
right: Box::new(expr2.clone()),
});
let or_expr = DfExpr::BinaryExpr(BinaryExpr {
let or_expr = Expr::BinaryExpr(BinaryExpr {
left: Box::new(expr1.clone()),
op: Operator::Or,
right: Box::new(expr2.clone()),
});
let not_expr = DfExpr::Not(Box::new(expr1.clone()));
let not_expr = Expr::Not(Box::new(expr1.clone()));

let and_p = Predicate::from_expr(and_expr).unwrap();
assert!(matches!(and_p, Predicate::And(left, right) if *left == p1 && *right == p2));
Expand All @@ -510,7 +510,7 @@ mod tests {
let not_p = Predicate::from_expr(not_expr).unwrap();
assert!(matches!(not_p, Predicate::Not(p) if *p == p1));

let inlist_expr = DfExpr::InList(InList {
let inlist_expr = Expr::InList(InList {
expr: Box::new(column("a")),
list: vec![string_literal("a1"), string_literal("a2")],
negated: false,
Expand All @@ -520,7 +520,7 @@ mod tests {
assert!(matches!(&inlist_p, Predicate::InList(c, values) if c == "a"
&& match_string_values(values, &["a1", "a2"])));

let inlist_expr = DfExpr::InList(InList {
let inlist_expr = Expr::InList(InList {
expr: Box::new(column("a")),
list: vec![string_literal("a1"), string_literal("a2")],
negated: true,
Expand All @@ -540,7 +540,7 @@ mod tests {
let (expr1, expr2) = mock_exprs();

let request = ScanRequest {
filters: vec![expr1.into(), expr2.into()],
filters: vec![expr1, expr2],
..Default::default()
};
let predicates = Predicates::from_scan_request(&Some(request));
Expand Down Expand Up @@ -578,7 +578,7 @@ mod tests {

let (expr1, expr2) = mock_exprs();
let request = ScanRequest {
filters: vec![expr1.into(), expr2.into()],
filters: vec![expr1, expr2],
..Default::default()
};
let predicates = Predicates::from_scan_request(&Some(request));
Expand Down
1 change: 0 additions & 1 deletion src/common/query/src/logical_plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ use datatypes::prelude::ConcreteDataType;
pub use expr::build_filter_from_timestamp;

pub use self::accumulator::{Accumulator, AggregateFunctionCreator, AggregateFunctionCreatorRef};
pub use self::expr::{DfExpr, Expr};
pub use self::udaf::AggregateFunction;
pub use self::udf::ScalarUdf;
use crate::function::{ReturnTypeFunction, ScalarFunctionImplementation};
Expand Down
51 changes: 20 additions & 31 deletions src/common/query/src/logical_plan/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,41 +16,22 @@ use common_time::range::TimestampRange;
use common_time::timestamp::TimeUnit;
use common_time::Timestamp;
use datafusion_common::{Column, ScalarValue};
pub use datafusion_expr::expr::Expr as DfExpr;
use datafusion_expr::expr::Expr;
use datafusion_expr::{and, binary_expr, Operator};

/// Central struct of query API.
/// Represent logical expressions such as `A + 1`, or `CAST(c1 AS int)`.
#[derive(Clone, PartialEq, Eq, Hash, Debug)]
pub struct Expr {
df_expr: DfExpr,
}

impl Expr {
pub fn df_expr(&self) -> &DfExpr {
&self.df_expr
}
}

impl From<DfExpr> for Expr {
fn from(df_expr: DfExpr) -> Self {
Self { df_expr }
}
}

/// Builds an `Expr` that filters timestamp column from given timestamp range.
/// Returns [None] if time range is [None] or full time range.
pub fn build_filter_from_timestamp(
ts_col_name: &str,
time_range: Option<&TimestampRange>,
) -> Option<Expr> {
let time_range = time_range?;
let ts_col_expr = DfExpr::Column(Column {
let ts_col_expr = Expr::Column(Column {
relation: None,
name: ts_col_name.to_string(),
});

let df_expr = match (time_range.start(), time_range.end()) {
match (time_range.start(), time_range.end()) {
(None, None) => None,
(Some(start), None) => Some(binary_expr(
ts_col_expr,
Expand All @@ -70,32 +51,40 @@ pub fn build_filter_from_timestamp(
),
binary_expr(ts_col_expr, Operator::Lt, timestamp_to_literal(end)),
)),
};

df_expr.map(Expr::from)
}
}

/// Converts a [Timestamp] to datafusion literal value.
fn timestamp_to_literal(timestamp: &Timestamp) -> DfExpr {
fn timestamp_to_literal(timestamp: &Timestamp) -> Expr {
let scalar_value = match timestamp.unit() {
TimeUnit::Second => ScalarValue::TimestampSecond(Some(timestamp.value()), None),
TimeUnit::Millisecond => ScalarValue::TimestampMillisecond(Some(timestamp.value()), None),
TimeUnit::Microsecond => ScalarValue::TimestampMicrosecond(Some(timestamp.value()), None),
TimeUnit::Nanosecond => ScalarValue::TimestampNanosecond(Some(timestamp.value()), None),
};
DfExpr::Literal(scalar_value)
Expr::Literal(scalar_value)
}

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn test_from_df_expr() {
let df_expr = DfExpr::Wildcard { qualifier: None };
fn test_timestamp_to_literal() {
let timestamp = Timestamp::new(123456789, TimeUnit::Second);
let expected = Expr::Literal(ScalarValue::TimestampSecond(Some(123456789), None));
assert_eq!(timestamp_to_literal(&timestamp), expected);

let timestamp = Timestamp::new(123456789, TimeUnit::Millisecond);
let expected = Expr::Literal(ScalarValue::TimestampMillisecond(Some(123456789), None));
assert_eq!(timestamp_to_literal(&timestamp), expected);

let expr: Expr = df_expr.into();
let timestamp = Timestamp::new(123456789, TimeUnit::Microsecond);
let expected = Expr::Literal(ScalarValue::TimestampMicrosecond(Some(123456789), None));
assert_eq!(timestamp_to_literal(&timestamp), expected);

assert_eq!(DfExpr::Wildcard { qualifier: None }, *expr.df_expr());
let timestamp = Timestamp::new(123456789, TimeUnit::Nanosecond);
let expected = Expr::Literal(ScalarValue::TimestampNanosecond(Some(123456789), None));
assert_eq!(timestamp_to_literal(&timestamp), expected);
}
}
2 changes: 1 addition & 1 deletion src/common/query/src/prelude.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ pub use datafusion_common::ScalarValue;

pub use crate::columnar_value::ColumnarValue;
pub use crate::function::*;
pub use crate::logical_plan::{create_udf, AggregateFunction, Expr, ScalarUdf};
pub use crate::logical_plan::{create_udf, AggregateFunction, ScalarUdf};
pub use crate::signature::{Signature, TypeSignature, Volatility};

/// Default timestamp column name for Prometheus metrics.
Expand Down
4 changes: 2 additions & 2 deletions src/file-engine/src/query.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,11 @@ use std::task::{Context, Poll};

use common_datasource::object_store::build_backend;
use common_error::ext::BoxedError;
use common_query::prelude::Expr;
use common_recordbatch::adapter::RecordBatchMetrics;
use common_recordbatch::error::{CastVectorSnafu, ExternalSnafu, Result as RecordBatchResult};
use common_recordbatch::{OrderOption, RecordBatch, RecordBatchStream, SendableRecordBatchStream};
use datafusion::logical_expr::utils as df_logical_expr_utils;
use datafusion_expr::expr::Expr;
use datatypes::prelude::ConcreteDataType;
use datatypes::schema::{ColumnSchema, Schema, SchemaRef};
use datatypes::vectors::VectorRef;
Expand Down Expand Up @@ -113,7 +113,7 @@ impl FileRegion {

let mut aux_column_set = HashSet::new();
for scan_filter in scan_filters {
df_logical_expr_utils::expr_to_columns(scan_filter.df_expr(), &mut aux_column_set)
df_logical_expr_utils::expr_to_columns(scan_filter, &mut aux_column_set)
.context(ExtractColumnFromFilterSnafu)?;

let all_file_columns = aux_column_set
Expand Down
Loading

0 comments on commit b2cd442

Please sign in to comment.