Skip to content

Commit

Permalink
remove udf_binding_flag
Browse files Browse the repository at this point in the history
  • Loading branch information
xzhseh committed Feb 4, 2024
1 parent 82f2ab2 commit 449465c
Show file tree
Hide file tree
Showing 5 changed files with 90 additions and 32 deletions.
2 changes: 1 addition & 1 deletion src/frontend/src/binder/expr/column.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ impl Binder {
// to the name of the defined sql udf parameters stored in `udf_context`.
// If so, we will treat this bind as a special bind: the actual expression
// stored in `udf_context` will then be bound instead of binding the non-existing column.
if self.udf_binding_flag {
if self.udf_context.global_count() != 0 {
if let Some(expr) = self.udf_context.get_expr(&column_name) {
return Ok(expr.clone());
} else {
Expand Down
14 changes: 2 additions & 12 deletions src/frontend/src/binder/expr/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -239,23 +239,13 @@ impl Binder {
}

if let Ok(expr) = UdfContext::extract_udf_expression(ast) {
self.set_udf_binding_flag();
let bind_result = self.bind_expr(expr);

// We should properly decrement global count after a successful binding
// Since the subsequent `unset flag` operation relies on global counting
// Since the subsequent probe operation in `bind_column` or
// `bind_parameter` relies on global counting
self.udf_context.decr_global_count();

// Only the top-most sql udf binding should unset the flag
// Otherwise the subsequent binding may not be able to
// find the corresponding context, consider the following example:
// e.g., `select add_wrapper(a INT, b INT) returns int language sql as 'select add(a, b) + a';`
// The inner `add` should not unset the flag, otherwise the `a` will be treated as
// a normal column, which is then invalid in this context.
if self.udf_context.global_count() == 0 {
self.unset_udf_binding_flag();
}

// Restore context information for subsequent binding
self.udf_context.update_context(stashed_udf_context);

Expand Down
90 changes: 86 additions & 4 deletions src/frontend/src/binder/expr/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -386,11 +386,13 @@ impl Binder {

fn bind_parameter(&mut self, index: u64) -> Result<ExprImpl> {
// Special check for sql udf
// Note: This is specific to anonymous sql udf, since the
// Note: This is specific to sql udf with unnamed parameters, since the
// parameters will be parsed and treated as `Parameter`.
// For detailed explanation, consider checking `bind_column`.
if let Some(expr) = self.udf_context.get_expr(&format!("${index}")) {
return Ok(expr.clone());
if self.udf_context.global_count() != 0 {
if let Some(expr) = self.udf_context.get_expr(&format!("${index}")) {
return Ok(expr.clone());
}
}

Ok(Parameter::new(index, self.param_types.clone()).into())
Expand Down Expand Up @@ -469,6 +471,60 @@ impl Binder {
Ok(func_call.into())
}

/// The optimization check for the following case-when expression pattern
/// e.g., select case 1 when (...) then (...) else (...) end;
///
/// Binds `operand` and, if it is constant, walks the `when` arms in order.
/// Every inspected condition must be a plain `Expr::Value`; the first one
/// whose bound form equals the bound operand selects its paired result.
/// If no arm matches, `fallback` (the `else` arm) is selected instead.
///
/// Returns `true` when an expression was selected, in which case exactly one
/// element has been pushed onto `constant_case_when_eval_inputs`. Returns
/// `false` (pushing nothing) when there is no operand, the operand fails to
/// bind or is not constant, a condition is not an `Expr::Value` or fails to
/// bind, or no arm matched and there is no `fallback`.
///
/// NOTE(review): arms are compared with `==` on the bound expressions;
/// confirm this gives the intended result for `NULL` operands and for
/// literals of differing types.
fn check_constant_case_when_optimization(
    &mut self,
    conditions: Vec<Expr>,
    results_expr: Vec<ExprImpl>,
    operand: Option<Box<Expr>>,
    fallback: Option<ExprImpl>,
    constant_case_when_eval_inputs: &mut Vec<ExprImpl>,
) -> bool {
    // Without an operand there is nothing to compare the arms against.
    let Some(operand) = operand else {
        return false;
    };
    // The operand value to be compared against each `when` condition.
    let Ok(operand_value) = self.bind_expr_inner(*operand) else {
        return false;
    };
    if !operand_value.is_const() {
        return false;
    }

    for (condition, result) in zip_eq_fast(conditions, results_expr) {
        // Only plain value conditions can be decided here. Inspect by
        // reference so `condition` can be moved into the binder below
        // without cloning.
        if !matches!(&condition, Expr::Value(_)) {
            return false;
        }
        let Ok(bound_condition) = self.bind_expr_inner(condition) else {
            return false;
        };
        // Found a match
        if bound_condition == operand_value {
            constant_case_when_eval_inputs.push(result);
            return true;
        }
    }

    // Otherwise this will eventually go through fallback arm
    debug_assert!(
        constant_case_when_eval_inputs.is_empty(),
        "expect `inputs` to be empty"
    );

    let Some(fallback) = fallback else {
        return false;
    };

    constant_case_when_eval_inputs.push(fallback);
    true
}

/// The helper function to check if the current case-when
/// expression in `bind_case` could be optimized
/// into `ConstantLookupExpression`
Expand All @@ -491,6 +547,12 @@ impl Binder {
let Ok(operand) = self.bind_expr_inner(*operand) else {
return false;
};
// This optimization should be done in a subsequent optimization phase
// if the operand is const
// e.g., select case 1 when 1 then 114514 else 1919810 end;
if operand.is_const() {
return false;
}
constant_lookup_inputs.push(operand);
} else {
return false;
Expand All @@ -504,7 +566,7 @@ impl Binder {
constant_lookup_inputs.push(input);
} else {
// If at least one condition is not in the simple form / not constant,
// we can NOT do the subsequent optimization then
// we can NOT do the subsequent optimization pass
return false;
}

Expand Down Expand Up @@ -536,6 +598,26 @@ impl Binder {
.transpose()?;

let mut constant_lookup_inputs = Vec::new();
let mut constant_case_when_eval_inputs = Vec::new();

let constant_case_when_flag = self.check_constant_case_when_optimization(
conditions.clone(),
results_expr.clone(),
operand.clone(),
else_result_expr.clone(),
&mut constant_case_when_eval_inputs,
);

if constant_case_when_flag {
// Here we reuse the `ConstantLookup` as the `FunctionCall`
// to avoid creating a dummy `ConstCaseWhenEval` expression type
// since we do not need to go through backend
return Ok(FunctionCall::new(
ExprType::ConstantLookup,
constant_case_when_eval_inputs,
)?
.into());
}

// See if the case-when expression can be optimized
let optimize_flag = self.check_bind_case_optimization(
Expand Down
12 changes: 0 additions & 12 deletions src/frontend/src/binder/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -122,10 +122,6 @@ pub struct Binder {

/// The sql udf context that will be used during binding phase
udf_context: UdfContext,

/// Udf binding flag, used to distinguish between
/// columns and named parameters during sql udf binding
udf_binding_flag: bool,
}

#[derive(Clone, Debug, Default)]
Expand Down Expand Up @@ -509,14 +505,6 @@ impl Binder {
pub fn udf_context_mut(&mut self) -> &mut UdfContext {
&mut self.udf_context
}

/// Sets `udf_binding_flag` to `true`.
pub fn set_udf_binding_flag(&mut self) {
self.udf_binding_flag = true;
}

/// Resets `udf_binding_flag` to `false`.
pub fn unset_udf_binding_flag(&mut self) {
self.udf_binding_flag = false;
}
}

/// The column name stored in [`BindContext`] for a column without an alias.
Expand Down
4 changes: 1 addition & 3 deletions src/frontend/src/handler/create_sql_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -188,10 +188,8 @@ pub async fn handle_create_sql_function(
arg_names.clone(),
));

binder.set_udf_binding_flag();

// Need to set the initial global count to 1
// otherwise the flag will be unset during the semantic check
// otherwise the context will not be probed during the semantic check
binder.udf_context_mut().incr_global_count();

if let Ok(expr) = UdfContext::extract_udf_expression(ast) {
Expand Down

0 comments on commit 449465c

Please sign in to comment.