diff --git a/src/frontend/src/binder/expr/column.rs b/src/frontend/src/binder/expr/column.rs index 9fb17c6e43520..c2fd536c1ccc4 100644 --- a/src/frontend/src/binder/expr/column.rs +++ b/src/frontend/src/binder/expr/column.rs @@ -45,7 +45,7 @@ impl Binder { // to the name of the defined sql udf parameters stored in `udf_context`. // If so, we will treat this bind as an special bind, the actual expression // stored in `udf_context` will then be bound instead of binding the non-existing column. - if self.udf_binding_flag { + if self.udf_context.global_count() != 0 { if let Some(expr) = self.udf_context.get_expr(&column_name) { return Ok(expr.clone()); } else { diff --git a/src/frontend/src/binder/expr/function.rs b/src/frontend/src/binder/expr/function.rs index 520bf9737ca35..efd7e3f0e423d 100644 --- a/src/frontend/src/binder/expr/function.rs +++ b/src/frontend/src/binder/expr/function.rs @@ -239,23 +239,13 @@ impl Binder { } if let Ok(expr) = UdfContext::extract_udf_expression(ast) { - self.set_udf_binding_flag(); let bind_result = self.bind_expr(expr); // We should properly decrement global count after a successful binding - // Since the subsequent `unset flag` operation relies on global counting + // Since the subsequent probe operation in `bind_column` or + // `bind_parameter` relies on global counting self.udf_context.decr_global_count(); - // Only the top-most sql udf binding should unset the flag - // Otherwise the subsequent binding may not be able to - // find the corresponding context, consider the following example: - // e.g., `select add_wrapper(a INT, b INT) returns int language sql as 'select add(a, b) + a';` - // The inner `add` should not unset the flag, otherwise the `a` will be treated as - // a normal column, which is then invalid in this context. - if self.udf_context.global_count() == 0 { - self.unset_udf_binding_flag(); - } - // Restore context information for subsequent binding self.udf_context.update_context(stashed_udf_context); diff --git a/src/frontend/src/binder/expr/mod.rs b/src/frontend/src/binder/expr/mod.rs index 7da1518ceb552..c7e17e79e6425 100644 --- a/src/frontend/src/binder/expr/mod.rs +++ b/src/frontend/src/binder/expr/mod.rs @@ -386,11 +386,13 @@ impl Binder { fn bind_parameter(&mut self, index: u64) -> Result { // Special check for sql udf - // Note: This is specific to anonymous sql udf, since the + // Note: This is specific to sql udf with unnamed parameters, since the // parameters will be parsed and treated as `Parameter`. // For detailed explanation, consider checking `bind_column`. - if let Some(expr) = self.udf_context.get_expr(&format!("${index}")) { - return Ok(expr.clone()); + if self.udf_context.global_count() != 0 { + if let Some(expr) = self.udf_context.get_expr(&format!("${index}")) { + return Ok(expr.clone()); + } } Ok(Parameter::new(index, self.param_types.clone()).into()) @@ -469,6 +471,60 @@ impl Binder { Ok(func_call.into()) } + /// The optimization check for the following case-when expression pattern + /// e.g., select case 1 when (...) then (...) else (...) end; + fn check_constant_case_when_optimization( + &mut self, + conditions: Vec, + results_expr: Vec, + operand: Option>, + fallback: Option, + constant_case_when_eval_inputs: &mut Vec, + ) -> bool { + // The operand value to be compared later + let operand_value; + + if let Some(operand) = operand { + let Ok(operand) = self.bind_expr_inner(*operand) else { + return false; + }; + if !operand.is_const() { + return false; + } + operand_value = operand; + } else { + return false; + } + + for (condition, result) in zip_eq_fast(conditions, results_expr) { + if let Expr::Value(_) = condition.clone() { + let Ok(res) = self.bind_expr_inner(condition.clone()) else { + return false; + }; + // Found a match + if res == operand_value { + constant_case_when_eval_inputs.push(result); + return true; + } + } else { + return false; + } + } + + // Otherwise this will eventually go through fallback arm + debug_assert!( + constant_case_when_eval_inputs.is_empty(), + "expect `inputs` to be empty" + ); + + let Some(fallback) = fallback else { + return false; + }; + + constant_case_when_eval_inputs.push(fallback); + true + } + /// The helper function to check if the current case-when /// expression in `bind_case` could be optimized /// into `ConstantLookupExpression` @@ -491,6 +547,12 @@ impl Binder { let Ok(operand) = self.bind_expr_inner(*operand) else { return false; }; + // This optimization should be done in subsequent optimization phase + // if the operand is const + // e.g., select case 1 when 1 then 114514 else 1919810 end; + if operand.is_const() { + return false; + } constant_lookup_inputs.push(operand); } else { return false; @@ -504,7 +566,7 @@ impl Binder { constant_lookup_inputs.push(input); } else { // If at least one condition is not in the simple form / not constant, - // we can NOT do the subsequent optimization then + // we can NOT do the subsequent optimization pass return false; } @@ -536,6 +598,26 @@ impl Binder { .transpose()?; let mut constant_lookup_inputs = Vec::new(); + let mut constant_case_when_eval_inputs = Vec::new(); + + let constant_case_when_flag = self.check_constant_case_when_optimization( + conditions.clone(), + results_expr.clone(), + operand.clone(), + else_result_expr.clone(), + &mut constant_case_when_eval_inputs, + ); + + if constant_case_when_flag { + // Here we reuse the `ConstantLookup` as the `FunctionCall` + // to avoid creating a dummy `ConstCaseWhenEval` expression type + // since we do not need to go through backend + return Ok(FunctionCall::new( + ExprType::ConstantLookup, + constant_case_when_eval_inputs, + )? + .into()); + } // See if the case-when expression can be optimized let optimize_flag = self.check_bind_case_optimization( diff --git a/src/frontend/src/binder/mod.rs b/src/frontend/src/binder/mod.rs index fc6e825eab979..083d5d990c91e 100644 --- a/src/frontend/src/binder/mod.rs +++ b/src/frontend/src/binder/mod.rs @@ -122,10 +122,6 @@ pub struct Binder { /// The sql udf context that will be used during binding phase udf_context: UdfContext, - - /// Udf binding flag, used to distinguish between - /// columns and named parameters during sql udf binding - udf_binding_flag: bool, } #[derive(Clone, Debug, Default)] @@ -509,14 +505,6 @@ impl Binder { pub fn udf_context_mut(&mut self) -> &mut UdfContext { &mut self.udf_context } - - pub fn set_udf_binding_flag(&mut self) { - self.udf_binding_flag = true; - } - - pub fn unset_udf_binding_flag(&mut self) { - self.udf_binding_flag = false; - } } /// The column name stored in [`BindContext`] for a column without an alias. diff --git a/src/frontend/src/handler/create_sql_function.rs b/src/frontend/src/handler/create_sql_function.rs index 346cbda0effd7..5b8d474813d6d 100644 --- a/src/frontend/src/handler/create_sql_function.rs +++ b/src/frontend/src/handler/create_sql_function.rs @@ -188,10 +188,8 @@ pub async fn handle_create_sql_function( arg_names.clone(), )); - binder.set_udf_binding_flag(); - // Need to set the initial global count to 1 - // otherwise the flag will be unset during the semantic check + // otherwise the context will not be probed during the semantic check binder.udf_context_mut().incr_global_count(); if let Ok(expr) = UdfContext::extract_udf_expression(ast) {