From d6ef7a75debde9a12f248bbd7fb87880d7aef2ac Mon Sep 17 00:00:00 2001 From: Ruihang Xia Date: Tue, 5 Dec 2023 11:25:29 +0800 Subject: [PATCH] fix: type conversion rule reverses operands (#2871) Signed-off-by: Ruihang Xia --- src/query/src/optimizer/type_conversion.rs | 73 ++++++++++++++++------ 1 file changed, 53 insertions(+), 20 deletions(-) diff --git a/src/query/src/optimizer/type_conversion.rs b/src/query/src/optimizer/type_conversion.rs index cfc8a5e33f3a..07cba75d7f44 100644 --- a/src/query/src/optimizer/type_conversion.rs +++ b/src/query/src/optimizer/type_conversion.rs @@ -150,41 +150,39 @@ impl TypeConverter { } } - fn convert_type<'b>(&self, mut left: &'b Expr, mut right: &'b Expr) -> Result<(Expr, Expr)> { + fn convert_type<'b>(&self, left: &'b Expr, right: &'b Expr) -> Result<(Expr, Expr)> { let left_type = self.column_type(left); let right_type = self.column_type(right); - let mut reverse = false; - let left_type = match (&left_type, &right_type) { + let target_type = match (&left_type, &right_type) { (Some(v), None) => v, - (None, Some(v)) => { - reverse = true; - std::mem::swap(&mut left, &mut right); - v - } + (None, Some(v)) => v, _ => return Ok((left.clone(), right.clone())), }; // only try to convert timestamp or boolean types - if !matches!(left_type, DataType::Timestamp(_, _)) - && !matches!(left_type, DataType::Boolean) - { + if !matches!(target_type, DataType::Timestamp(_, _) | DataType::Boolean) { return Ok((left.clone(), right.clone())); } match (left, right) { (Expr::Column(col), Expr::Literal(value)) => { - let casted_right = Self::cast_scalar_value(value, left_type)?; + let casted_right = Self::cast_scalar_value(value, target_type)?; if casted_right.is_null() { return Err(DataFusionError::Plan(format!( - "column:{col:?}. Casting value:{value:?} to {left_type:?} is invalid", + "column:{col:?}. Casting value:{value:?} to {target_type:?} is invalid", ))); } - if reverse { - Ok((Expr::Literal(casted_right), left.clone())) - } else { - Ok((left.clone(), Expr::Literal(casted_right))) + Ok((left.clone(), Expr::Literal(casted_right))) + } + (Expr::Literal(value), Expr::Column(col)) => { + let casted_left = Self::cast_scalar_value(value, target_type)?; + if casted_left.is_null() { + return Err(DataFusionError::Plan(format!( + "column:{col:?}. Casting value:{value:?} to {target_type:?} is invalid", + ))); } + Ok((Expr::Literal(casted_left), right.clone())) } _ => Ok((left.clone(), right.clone())), } @@ -250,7 +248,6 @@ impl TreeNodeRewriter for TypeConverter { ScalarValue::TimestampMillisecond(Some(i), _) => { timestamp_to_timestamp_ms_expr(i, TimeUnit::Millisecond) } - ScalarValue::TimestampMicrosecond(Some(i), _) => { timestamp_to_timestamp_ms_expr(i, TimeUnit::Microsecond) } @@ -425,6 +422,13 @@ mod tests { ScalarValue::Utf8(Some("1970-01-01 00:00:00+08:00".to_string())), ))) .unwrap() + .filter( + Expr::Literal(ScalarValue::Utf8(Some( + "1970-01-01 00:00:00+08:00".to_string(), + ))) + .lt_eq(Expr::Column(Column::from_name("column3"))), + ) + .unwrap() .aggregate( Vec::::new(), vec![Expr::AggregateFunction(AggrExpr { @@ -444,8 +448,37 @@ mod tests { .unwrap(); let expected = String::from( "Aggregate: groupBy=[[]], aggr=[[COUNT(column1)]]\ - \n Filter: column3 > TimestampSecond(-28800, None)\ - \n Values: (Int64(1), Float64(1), TimestampMillisecond(1, None))", + \n Filter: TimestampSecond(-28800, None) <= column3\ + \n Filter: column3 > TimestampSecond(-28800, None)\ + \n Values: (Int64(1), Float64(1), TimestampMillisecond(1, None))", + ); + assert_eq!(format!("{}", transformed_plan.display_indent()), expected); + } + + #[test] + fn test_reverse_non_ts_type() { + let plan = + LogicalPlanBuilder::values(vec![vec![Expr::Literal(ScalarValue::Float64(Some(1.0)))]]) + .unwrap() + .filter( + Expr::Column(Column::from_name("column1")) + .gt_eq(Expr::Literal(ScalarValue::Utf8(Some("1.2345".to_string())))), + ) + .unwrap() + .filter( + Expr::Literal(ScalarValue::Utf8(Some("1.2345".to_string()))) + .lt(Expr::Column(Column::from_name("column1"))), + ) + .unwrap() + .build() + .unwrap(); + let transformed_plan = TypeConversionRule + .analyze(plan, &ConfigOptions::default()) + .unwrap(); + let expected = String::from( + "Filter: Utf8(\"1.2345\") < column1\ + \n Filter: column1 >= Utf8(\"1.2345\")\ + \n Values: (Float64(1))", ); assert_eq!(format!("{}", transformed_plan.display_indent()), expected); }