Skip to content

Commit

Permalink
fix: type conversion rule reverses operands (#2871)
Browse files Browse the repository at this point in the history
Signed-off-by: Ruihang Xia <[email protected]>
  • Loading branch information
waynexia authored Dec 5, 2023
1 parent 6344b1e commit d6ef7a7
Showing 1 changed file with 53 additions and 20 deletions.
73 changes: 53 additions & 20 deletions src/query/src/optimizer/type_conversion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -150,41 +150,39 @@ impl TypeConverter {
}
}

fn convert_type<'b>(&self, mut left: &'b Expr, mut right: &'b Expr) -> Result<(Expr, Expr)> {
fn convert_type<'b>(&self, left: &'b Expr, right: &'b Expr) -> Result<(Expr, Expr)> {
let left_type = self.column_type(left);
let right_type = self.column_type(right);

let mut reverse = false;
let left_type = match (&left_type, &right_type) {
let target_type = match (&left_type, &right_type) {
(Some(v), None) => v,
(None, Some(v)) => {
reverse = true;
std::mem::swap(&mut left, &mut right);
v
}
(None, Some(v)) => v,
_ => return Ok((left.clone(), right.clone())),
};

// only try to convert timestamp or boolean types
if !matches!(left_type, DataType::Timestamp(_, _))
&& !matches!(left_type, DataType::Boolean)
{
if !matches!(target_type, DataType::Timestamp(_, _) | DataType::Boolean) {
return Ok((left.clone(), right.clone()));
}

match (left, right) {
(Expr::Column(col), Expr::Literal(value)) => {
let casted_right = Self::cast_scalar_value(value, left_type)?;
let casted_right = Self::cast_scalar_value(value, target_type)?;
if casted_right.is_null() {
return Err(DataFusionError::Plan(format!(
"column:{col:?}. Casting value:{value:?} to {left_type:?} is invalid",
"column:{col:?}. Casting value:{value:?} to {target_type:?} is invalid",
)));
}
if reverse {
Ok((Expr::Literal(casted_right), left.clone()))
} else {
Ok((left.clone(), Expr::Literal(casted_right)))
Ok((left.clone(), Expr::Literal(casted_right)))
}
(Expr::Literal(value), Expr::Column(col)) => {
let casted_left = Self::cast_scalar_value(value, target_type)?;
if casted_left.is_null() {
return Err(DataFusionError::Plan(format!(
"column:{col:?}. Casting value:{value:?} to {target_type:?} is invalid",
)));
}
Ok((Expr::Literal(casted_left), right.clone()))
}
_ => Ok((left.clone(), right.clone())),
}
Expand Down Expand Up @@ -250,7 +248,6 @@ impl TreeNodeRewriter for TypeConverter {
ScalarValue::TimestampMillisecond(Some(i), _) => {
timestamp_to_timestamp_ms_expr(i, TimeUnit::Millisecond)
}

ScalarValue::TimestampMicrosecond(Some(i), _) => {
timestamp_to_timestamp_ms_expr(i, TimeUnit::Microsecond)
}
Expand Down Expand Up @@ -425,6 +422,13 @@ mod tests {
ScalarValue::Utf8(Some("1970-01-01 00:00:00+08:00".to_string())),
)))
.unwrap()
.filter(
Expr::Literal(ScalarValue::Utf8(Some(
"1970-01-01 00:00:00+08:00".to_string(),
)))
.lt_eq(Expr::Column(Column::from_name("column3"))),
)
.unwrap()
.aggregate(
Vec::<Expr>::new(),
vec![Expr::AggregateFunction(AggrExpr {
Expand All @@ -444,8 +448,37 @@ mod tests {
.unwrap();
let expected = String::from(
"Aggregate: groupBy=[[]], aggr=[[COUNT(column1)]]\
\n Filter: column3 > TimestampSecond(-28800, None)\
\n Values: (Int64(1), Float64(1), TimestampMillisecond(1, None))",
\n Filter: TimestampSecond(-28800, None) <= column3\
\n Filter: column3 > TimestampSecond(-28800, None)\
\n Values: (Int64(1), Float64(1), TimestampMillisecond(1, None))",
);
assert_eq!(format!("{}", transformed_plan.display_indent()), expected);
}

#[test]
fn test_reverse_non_ts_type() {
let plan =
LogicalPlanBuilder::values(vec![vec![Expr::Literal(ScalarValue::Float64(Some(1.0)))]])
.unwrap()
.filter(
Expr::Column(Column::from_name("column1"))
.gt_eq(Expr::Literal(ScalarValue::Utf8(Some("1.2345".to_string())))),
)
.unwrap()
.filter(
Expr::Literal(ScalarValue::Utf8(Some("1.2345".to_string())))
.lt(Expr::Column(Column::from_name("column1"))),
)
.unwrap()
.build()
.unwrap();
let transformed_plan = TypeConversionRule
.analyze(plan, &ConfigOptions::default())
.unwrap();
let expected = String::from(
"Filter: Utf8(\"1.2345\") < column1\
\n Filter: column1 >= Utf8(\"1.2345\")\
\n Values: (Float64(1))",
);
assert_eq!(format!("{}", transformed_plan.display_indent()), expected);
}
Expand Down

0 comments on commit d6ef7a7

Please sign in to comment.