diff --git a/src/binder/expr.rs b/src/binder/expr.rs index 90d5e4d1..dca6e527 100644 --- a/src/binder/expr.rs +++ b/src/binder/expr.rs @@ -128,6 +128,17 @@ impl<'a, T: Transaction> Binder<'a, T> { results, else_result, } => { + let fn_check_ty = |ty: &mut LogicalType, result_ty| { + if result_ty != LogicalType::SqlNull { + if ty == &LogicalType::SqlNull { + *ty = result_ty; + } else if ty != &result_ty { + return Err(DatabaseError::Incomparable(*ty, result_ty)); + } + } + + Ok(()) + }; let mut operand_expr = None; let mut ty = LogicalType::SqlNull; if let Some(expr) = operand { @@ -138,19 +149,17 @@ impl<'a, T: Transaction> Binder<'a, T> { let result = self.bind_expr(&results[i])?; let result_ty = result.return_type(); - if result_ty != LogicalType::SqlNull { - if ty == LogicalType::SqlNull { - ty = result_ty; - } else if ty != result_ty { - return Err(DatabaseError::Incomparable(ty, result_ty)); - } - } + fn_check_ty(&mut ty, result_ty)?; expr_pairs.push((self.bind_expr(&conditions[i])?, result)) } let mut else_expr = None; if let Some(expr) = else_result { - else_expr = Some(Box::new(self.bind_expr(expr)?)); + let temp_expr = Box::new(self.bind_expr(expr)?); + let else_ty = temp_expr.return_type(); + + fn_check_ty(&mut ty, else_ty)?; + else_expr = Some(temp_expr); } Ok(ScalarExpression::CaseWhen { diff --git a/src/expression/evaluator.rs b/src/expression/evaluator.rs index f722c963..e60a699a 100644 --- a/src/expression/evaluator.rs +++ b/src/expression/evaluator.rs @@ -28,6 +28,13 @@ macro_rules! eval_to_num { impl ScalarExpression { pub fn eval(&self, tuple: &Tuple) -> Result { + let check_cast = |value: ValueRef, return_type: &LogicalType| { + if value.logical_type() != *return_type { + return Ok(Arc::new(DataValue::clone(&value).cast(return_type)?)); + } + Ok(value) + }; + match self { ScalarExpression::Constant(val) => Ok(val.clone()), ScalarExpression::ColumnRef(col) => { @@ -186,39 +193,39 @@ impl ScalarExpression { condition, left_expr, right_expr, - .. + ty, } => { if condition.eval(tuple)?.is_true()? { - left_expr.eval(tuple) + check_cast(left_expr.eval(tuple)?, ty) } else { - right_expr.eval(tuple) + check_cast(right_expr.eval(tuple)?, ty) } } ScalarExpression::IfNull { left_expr, right_expr, - .. + ty, } => { - let value = left_expr.eval(tuple)?; + let mut value = left_expr.eval(tuple)?; if value.is_null() { - return right_expr.eval(tuple); + value = right_expr.eval(tuple)?; } - Ok(value) + check_cast(value, ty) } ScalarExpression::NullIf { left_expr, right_expr, - .. + ty, } => { - let value = left_expr.eval(tuple)?; + let mut value = left_expr.eval(tuple)?; if right_expr.eval(tuple)? == value { - return Ok(NULL_VALUE.clone()); + value = NULL_VALUE.clone(); } - Ok(value) + check_cast(value, ty) } - ScalarExpression::Coalesce { exprs, .. } => { + ScalarExpression::Coalesce { exprs, ty } => { let mut value = None; for expr in exprs { @@ -229,13 +236,13 @@ impl ScalarExpression { break; } } - Ok(value.unwrap_or_else(|| NULL_VALUE.clone())) + check_cast(value.unwrap_or_else(|| NULL_VALUE.clone()), ty) } ScalarExpression::CaseWhen { operand_expr, expr_pairs, else_expr, - .. + ty, } => { let mut operand_value = None; let mut result = None; @@ -262,7 +269,7 @@ impl ScalarExpression { result = Some(expr.eval(tuple)?); } } - Ok(result.unwrap_or_else(|| NULL_VALUE.clone())) + check_cast(result.unwrap_or_else(|| NULL_VALUE.clone()), ty) } } }