From 0fab1ba963eafa4061c896ac295f330c45f92143 Mon Sep 17 00:00:00 2001 From: Kould <2435992353@qq.com> Date: Mon, 19 Feb 2024 18:02:39 +0800 Subject: [PATCH 1/2] fix: when an expression such as if returns a result, return the result as the expected type --- src/binder/expr.rs | 25 +++++++++++++++++-------- src/expression/evaluator.rs | 37 ++++++++++++++++++++++--------------- 2 files changed, 39 insertions(+), 23 deletions(-) diff --git a/src/binder/expr.rs b/src/binder/expr.rs index 90d5e4d1..dca6e527 100644 --- a/src/binder/expr.rs +++ b/src/binder/expr.rs @@ -128,6 +128,17 @@ impl<'a, T: Transaction> Binder<'a, T> { results, else_result, } => { + let fn_check_ty = |ty: &mut LogicalType, result_ty| { + if result_ty != LogicalType::SqlNull { + if ty == &LogicalType::SqlNull { + *ty = result_ty; + } else if ty != &result_ty { + return Err(DatabaseError::Incomparable(*ty, result_ty)); + } + } + + Ok(()) + }; let mut operand_expr = None; let mut ty = LogicalType::SqlNull; if let Some(expr) = operand { @@ -138,19 +149,17 @@ impl<'a, T: Transaction> Binder<'a, T> { let result = self.bind_expr(&results[i])?; let result_ty = result.return_type(); - if result_ty != LogicalType::SqlNull { - if ty == LogicalType::SqlNull { - ty = result_ty; - } else if ty != result_ty { - return Err(DatabaseError::Incomparable(ty, result_ty)); - } - } + fn_check_ty(&mut ty, result_ty)?; expr_pairs.push((self.bind_expr(&conditions[i])?, result)) } let mut else_expr = None; if let Some(expr) = else_result { - else_expr = Some(Box::new(self.bind_expr(expr)?)); + let temp_expr = Box::new(self.bind_expr(expr)?); + let else_ty = temp_expr.return_type(); + + fn_check_ty(&mut ty, else_ty)?; + else_expr = Some(temp_expr); } Ok(ScalarExpression::CaseWhen { diff --git a/src/expression/evaluator.rs b/src/expression/evaluator.rs index f722c963..40bee4f0 100644 --- a/src/expression/evaluator.rs +++ b/src/expression/evaluator.rs @@ -28,6 +28,13 @@ macro_rules! eval_to_num { impl ScalarExpression { pub fn eval(&self, tuple: &Tuple) -> Result { + let check_cast = |value: ValueRef, return_type: &LogicalType| { + if value.logical_type() != *return_type { + return Ok(Arc::new(DataValue::clone(&value).cast(return_type)?)); + } + Ok(value) + }; + match self { ScalarExpression::Constant(val) => Ok(val.clone()), ScalarExpression::ColumnRef(col) => { @@ -186,39 +193,39 @@ impl ScalarExpression { condition, left_expr, right_expr, - .. + ty } => { if condition.eval(tuple)?.is_true()? { - left_expr.eval(tuple) + check_cast(left_expr.eval(tuple)?, ty) } else { - right_expr.eval(tuple) + check_cast(right_expr.eval(tuple)?, ty) } } ScalarExpression::IfNull { left_expr, right_expr, - .. + ty } => { - let value = left_expr.eval(tuple)?; + let mut value = left_expr.eval(tuple)?; if value.is_null() { - return right_expr.eval(tuple); + value = right_expr.eval(tuple)?; } - Ok(value) + check_cast(value, ty) } ScalarExpression::NullIf { left_expr, right_expr, - .. + ty } => { - let value = left_expr.eval(tuple)?; + let mut value = left_expr.eval(tuple)?; if right_expr.eval(tuple)? == value { - return Ok(NULL_VALUE.clone()); + value = NULL_VALUE.clone(); } - Ok(value) + check_cast(value, ty) } - ScalarExpression::Coalesce { exprs, .. } => { + ScalarExpression::Coalesce { exprs, ty } => { let mut value = None; for expr in exprs { @@ -229,13 +236,13 @@ impl ScalarExpression { break; } } - Ok(value.unwrap_or_else(|| NULL_VALUE.clone())) + check_cast(value.unwrap_or_else(|| NULL_VALUE.clone()), ty) } ScalarExpression::CaseWhen { operand_expr, expr_pairs, else_expr, - .. + ty } => { let mut operand_value = None; let mut result = None; @@ -262,7 +269,7 @@ impl ScalarExpression { result = Some(expr.eval(tuple)?); } } - Ok(result.unwrap_or_else(|| NULL_VALUE.clone())) + check_cast(result.unwrap_or_else(|| NULL_VALUE.clone()), ty) } } } From b62506f9c11bc572f38b9073fe1a90d8aacaddc2 Mon Sep 17 00:00:00 2001 From: Kould <2435992353@qq.com> Date: Mon, 19 Feb 2024 18:09:12 +0800 Subject: [PATCH 2/2] code fmt --- src/expression/evaluator.rs | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/expression/evaluator.rs b/src/expression/evaluator.rs index 40bee4f0..e60a699a 100644 --- a/src/expression/evaluator.rs +++ b/src/expression/evaluator.rs @@ -193,7 +193,7 @@ impl ScalarExpression { condition, left_expr, right_expr, - ty + ty, } => { if condition.eval(tuple)?.is_true()? { check_cast(left_expr.eval(tuple)?, ty) @@ -204,19 +204,19 @@ impl ScalarExpression { ScalarExpression::IfNull { left_expr, right_expr, - ty + ty, } => { let mut value = left_expr.eval(tuple)?; if value.is_null() { - value = right_expr.eval(tuple)?; + value = right_expr.eval(tuple)?; } check_cast(value, ty) } ScalarExpression::NullIf { left_expr, right_expr, - ty + ty, } => { let mut value = left_expr.eval(tuple)?; @@ -242,7 +242,7 @@ impl ScalarExpression { operand_expr, expr_pairs, else_expr, - ty + ty, } => { let mut operand_value = None; let mut result = None;