From ebaf213441007b8ca09df971d6eee6ac50ce88e8 Mon Sep 17 00:00:00 2001 From: Bugen Zhao Date: Tue, 19 Nov 2024 16:28:08 +0800 Subject: [PATCH] refactor(optimizer): record error contexts when casting structs Signed-off-by: Bugen Zhao --- src/frontend/src/expr/function_call.rs | 53 ++++++------- src/frontend/src/expr/mod.rs | 2 +- src/frontend/src/expr/type_inference/cast.rs | 79 +++++++++++++++----- src/frontend/src/expr/type_inference/mod.rs | 2 +- 4 files changed, 89 insertions(+), 47 deletions(-) diff --git a/src/frontend/src/expr/function_call.rs b/src/frontend/src/expr/function_call.rs index af1f84b321eb5..101bdea17c004 100644 --- a/src/frontend/src/expr/function_call.rs +++ b/src/frontend/src/expr/function_call.rs @@ -14,12 +14,13 @@ use itertools::Itertools; use risingwave_common::catalog::Schema; +use risingwave_common::error::{bail, def_anyhow_newtype}; use risingwave_common::types::{DataType, ScalarImpl}; use risingwave_common::util::iter_util::ZipEqFast; -use thiserror::Error; use thiserror_ext::AsReport; -use super::{cast_ok, infer_some_all, infer_type, CastContext, Expr, ExprImpl, Literal}; +use super::type_inference::cast; +use super::{infer_some_all, infer_type, CastContext, Expr, ExprImpl, Literal}; use crate::error::{ErrorCode, Result as RwResult}; use crate::expr::{ExprDisplay, ExprType, ExprVisitor, ImpureAnalyzer}; @@ -144,22 +145,23 @@ impl FunctionCall { // else when eager parsing fails, just proceed as normal. // Some callers are not ready to handle `'a'::int` error here. } + let source = child.return_type(); if source == target { - Ok(()) - // Casting from unknown is allowed in all context. And PostgreSQL actually does the parsing - // in frontend. - } else if child.is_untyped() || cast_ok(&source, &target, allows) { - // Always Ok below. Safe to mutate `child`. - let owned = std::mem::replace(child, ExprImpl::literal_bool(false)); - *child = Self::new_unchecked(ExprType::Cast, vec![owned], target).into(); - Ok(()) + return Ok(()); + } + + if child.is_untyped() { + // Casting from unknown is allowed in all context. And PostgreSQL actually does the parsing + // in frontend. } else { - Err(CastError(format!( - "cannot cast type \"{}\" to \"{}\" in {:?} context", - source, target, allows - ))) + cast(&source, &target, allows)?; } + + // Always Ok below. Safe to mutate `child`. + let owned = std::mem::replace(child, ExprImpl::literal_bool(false)); + *child = Self::new_unchecked(ExprType::Cast, vec![owned], target).into(); + Ok(()) } /// Cast a `ROW` expression to the target type. We intentionally disallow casting arbitrary @@ -170,13 +172,13 @@ impl FunctionCall { target_type: DataType, allows: CastContext, ) -> Result<(), CastError> { + // Can only cast to a struct type. let DataType::Struct(t) = &target_type else { - return Err(CastError(format!( - "cannot cast type \"{}\" to \"{}\" in {:?} context", - func.return_type(), + bail!( + "cannot cast type \"{}\" to \"{}\"", + func.return_type(), // typically "record" target_type, - allows - ))); + ); }; match t.len().cmp(&func.inputs.len()) { std::cmp::Ordering::Equal => { @@ -189,10 +191,8 @@ impl FunctionCall { func.return_type = target_type; Ok(()) } - std::cmp::Ordering::Less => Err(CastError("Input has too few columns.".to_string())), - std::cmp::Ordering::Greater => { - Err(CastError("Input has too many columns.".to_string())) - } + std::cmp::Ordering::Less => bail!("input has too few columns"), + std::cmp::Ordering::Greater => bail!("input has too many columns"), } } @@ -423,9 +423,10 @@ pub fn is_row_function(expr: &ExprImpl) -> bool { false } -#[derive(Debug, Error)] -#[error("{0}")] -pub struct CastError(pub(super) String); +def_anyhow_newtype! { + pub CastError, +} +pub type CastResult = Result; impl From for ErrorCode { fn from(value: CastError) -> Self { diff --git a/src/frontend/src/expr/mod.rs b/src/frontend/src/expr/mod.rs index c7acdfa5c4a3c..bc36679b3547f 100644 --- a/src/frontend/src/expr/mod.rs +++ b/src/frontend/src/expr/mod.rs @@ -300,7 +300,7 @@ impl ExprImpl { ))), DataType::Int32 => Ok(self), dt if dt.is_int() => Ok(self.cast_explicit(DataType::Int32)?), - _ => Err(CastError("Unsupported input type".to_string())), + _ => bail!("unsupported input type"), } } diff --git a/src/frontend/src/expr/type_inference/cast.rs b/src/frontend/src/expr/type_inference/cast.rs index 51441c3f70c5b..66ad6969ae316 100644 --- a/src/frontend/src/expr/type_inference/cast.rs +++ b/src/frontend/src/expr/type_inference/cast.rs @@ -15,12 +15,15 @@ use std::collections::BTreeMap; use std::sync::LazyLock; +use anyhow::Context; use itertools::Itertools as _; use parse_display::Display; +use risingwave_common::error::bail; use risingwave_common::types::{DataType, DataTypeName}; use risingwave_common::util::iter_util::ZipEqFast; use crate::error::ErrorCode; +use crate::expr::function_call::{CastError, CastResult}; use crate::expr::{Expr as _, ExprImpl, InputRef, Literal}; /// Find the least restrictive type. Used by `VALUES`, `CASE`, `UNION`, etc. @@ -114,12 +117,45 @@ pub fn align_array_and_element( Ok(array_type) } +fn canmeh(ok: bool) -> CastResult { + if ok { + Ok(()) + } else { + bail!("") + } +} +fn cannot() -> CastResult { + canmeh(false) +} + +pub fn cast(source: &DataType, target: &DataType, allows: CastContext) -> Result<(), CastError> { + macro_rules! any { + ($f:ident) => { + source.$f() || target.$f() + }; + } + + if any!(is_struct) { + cast_ok_struct(source, target, allows) + } else if any!(is_array) { + cast_ok_array(source, target, allows) + } else if any!(is_map) { + cast_ok_map(source, target, allows) + } else { + canmeh(cast_ok_base(source, target, allows)) + } + .with_context(|| { + format!( + "cannot cast type \"{}\" to \"{}\" in {:?} context", + source, target, allows + ) + }) + .map_err(Into::into) +} + /// Checks whether casting from `source` to `target` is ok in `allows` context. pub fn cast_ok(source: &DataType, target: &DataType, allows: CastContext) -> bool { - cast_ok_struct(source, target, allows) - || cast_ok_array(source, target, allows) - || cast_ok_map(source, target, allows) - || cast_ok_base(source, target, allows) + cast(source, target, allows).is_ok() } /// Checks whether casting from `source` to `target` is ok in `allows` context. @@ -128,52 +164,57 @@ pub fn cast_ok_base(source: &DataType, target: &DataType, allows: CastContext) - matches!(CAST_MAP.get(&(source.into(), target.into())), Some(context) if *context <= allows) } -fn cast_ok_struct(source: &DataType, target: &DataType, allows: CastContext) -> bool { +fn cast_ok_struct(source: &DataType, target: &DataType, allows: CastContext) -> CastResult { match (source, target) { (DataType::Struct(lty), DataType::Struct(rty)) => { if lty.is_empty() || rty.is_empty() { unreachable!("record type should be already processed at this point"); } if lty.len() != rty.len() { - // only cast structs of the same length - return false; + bail!("cannot cast structs of different lengths"); } // ... and all fields are castable lty.types() .zip_eq_fast(rty.types()) - .all(|(src, dst)| src == dst || cast_ok(src, dst, allows)) + .try_for_each(|(src, dst)| { + if src == dst { + Ok(()) + } else { + cast(src, dst, allows) + } + }) } // The automatic casts to string types are treated as assignment casts, while the automatic // casts from string types are explicit-only. // https://www.postgresql.org/docs/14/sql-createcast.html#id-1.9.3.58.7.4 - (DataType::Varchar, DataType::Struct(_)) => CastContext::Explicit <= allows, - (DataType::Struct(_), DataType::Varchar) => CastContext::Assign <= allows, - _ => false, + (DataType::Varchar, DataType::Struct(_)) => canmeh(CastContext::Explicit <= allows), + (DataType::Struct(_), DataType::Varchar) => canmeh(CastContext::Assign <= allows), + _ => cannot(), } } -fn cast_ok_array(source: &DataType, target: &DataType, allows: CastContext) -> bool { +fn cast_ok_array(source: &DataType, target: &DataType, allows: CastContext) -> CastResult { match (source, target) { (DataType::List(source_elem), DataType::List(target_elem)) => { - cast_ok(source_elem, target_elem, allows) + cast(source_elem, target_elem, allows) } // The automatic casts to string types are treated as assignment casts, while the automatic // casts from string types are explicit-only. // https://www.postgresql.org/docs/14/sql-createcast.html#id-1.9.3.58.7.4 - (DataType::Varchar, DataType::List(_)) => CastContext::Explicit <= allows, - (DataType::List(_), DataType::Varchar) => CastContext::Assign <= allows, - _ => false, + (DataType::Varchar, DataType::List(_)) => canmeh(CastContext::Explicit <= allows), + (DataType::List(_), DataType::Varchar) => canmeh(CastContext::Assign <= allows), + _ => cannot(), } } -fn cast_ok_map(source: &DataType, target: &DataType, allows: CastContext) -> bool { +fn cast_ok_map(source: &DataType, target: &DataType, allows: CastContext) -> CastResult { match (source, target) { - (DataType::Map(source_elem), DataType::Map(target_elem)) => cast_ok( + (DataType::Map(source_elem), DataType::Map(target_elem)) => cast( &source_elem.clone().into_list(), &target_elem.clone().into_list(), allows, ), - _ => false, + _ => cannot(), } } diff --git a/src/frontend/src/expr/type_inference/mod.rs b/src/frontend/src/expr/type_inference/mod.rs index 2845f05ec0dae..7f10d3aec3f6a 100644 --- a/src/frontend/src/expr/type_inference/mod.rs +++ b/src/frontend/src/expr/type_inference/mod.rs @@ -18,6 +18,6 @@ mod cast; mod func; pub use cast::{ - align_types, cast_map_array, cast_ok, cast_ok_base, cast_sigs, CastContext, CastSig, + align_types, cast_map_array, cast_ok, cast, cast_ok_base, cast_sigs, CastContext, CastSig, }; pub use func::{infer_some_all, infer_type, infer_type_name, infer_type_with_sigmap, FuncSign};