Skip to content

Commit

Permalink
refactor(optimizer): record error contexts when casting structs
Browse files Browse the repository at this point in the history
Signed-off-by: Bugen Zhao <[email protected]>
  • Loading branch information
BugenZhao committed Nov 20, 2024
1 parent c1162ab commit 4e686e0
Show file tree
Hide file tree
Showing 4 changed files with 89 additions and 47 deletions.
53 changes: 27 additions & 26 deletions src/frontend/src/expr/function_call.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,13 @@

use itertools::Itertools;
use risingwave_common::catalog::Schema;
use risingwave_common::error::{bail, def_anyhow_newtype};
use risingwave_common::types::{DataType, ScalarImpl};
use risingwave_common::util::iter_util::ZipEqFast;
use thiserror::Error;
use thiserror_ext::AsReport;

use super::{cast_ok, infer_some_all, infer_type, CastContext, Expr, ExprImpl, Literal};
use super::type_inference::cast;
use super::{infer_some_all, infer_type, CastContext, Expr, ExprImpl, Literal};
use crate::error::{ErrorCode, Result as RwResult};
use crate::expr::{ExprDisplay, ExprType, ExprVisitor, ImpureAnalyzer};

Expand Down Expand Up @@ -144,22 +145,23 @@ impl FunctionCall {
// else when eager parsing fails, just proceed as normal.
// Some callers are not ready to handle `'a'::int` error here.
}

let source = child.return_type();
if source == target {
Ok(())
// Casting from unknown is allowed in all context. And PostgreSQL actually does the parsing
// in frontend.
} else if child.is_untyped() || cast_ok(&source, &target, allows) {
// Always Ok below. Safe to mutate `child`.
let owned = std::mem::replace(child, ExprImpl::literal_bool(false));
*child = Self::new_unchecked(ExprType::Cast, vec![owned], target).into();
Ok(())
return Ok(());
}

if child.is_untyped() {
// Casting from unknown is allowed in all context. And PostgreSQL actually does the parsing
// in frontend.
} else {
Err(CastError(format!(
"cannot cast type \"{}\" to \"{}\" in {:?} context",
source, target, allows
)))
cast(&source, &target, allows)?;
}

// Always Ok below. Safe to mutate `child`.
let owned = std::mem::replace(child, ExprImpl::literal_bool(false));
*child = Self::new_unchecked(ExprType::Cast, vec![owned], target).into();
Ok(())
}

/// Cast a `ROW` expression to the target type. We intentionally disallow casting arbitrary
Expand All @@ -170,13 +172,13 @@ impl FunctionCall {
target_type: DataType,
allows: CastContext,
) -> Result<(), CastError> {
// Can only cast to a struct type.
let DataType::Struct(t) = &target_type else {
return Err(CastError(format!(
"cannot cast type \"{}\" to \"{}\" in {:?} context",
func.return_type(),
bail!(
"cannot cast type \"{}\" to \"{}\"",
func.return_type(), // typically "record"
target_type,
allows
)));
);
};
match t.len().cmp(&func.inputs.len()) {
std::cmp::Ordering::Equal => {
Expand All @@ -189,10 +191,8 @@ impl FunctionCall {
func.return_type = target_type;
Ok(())
}
std::cmp::Ordering::Less => Err(CastError("Input has too few columns.".to_string())),
std::cmp::Ordering::Greater => {
Err(CastError("Input has too many columns.".to_string()))
}
std::cmp::Ordering::Less => bail!("input has too few columns"),
std::cmp::Ordering::Greater => bail!("input has too many columns"),
}
}

Expand Down Expand Up @@ -423,9 +423,10 @@ pub fn is_row_function(expr: &ExprImpl) -> bool {
false
}

#[derive(Debug, Error)]
#[error("{0}")]
pub struct CastError(pub(super) String);
def_anyhow_newtype! {
pub CastError,
}
pub type CastResult<T = ()> = Result<T, CastError>;

impl From<CastError> for ErrorCode {
fn from(value: CastError) -> Self {
Expand Down
2 changes: 1 addition & 1 deletion src/frontend/src/expr/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -300,7 +300,7 @@ impl ExprImpl {
))),
DataType::Int32 => Ok(self),
dt if dt.is_int() => Ok(self.cast_explicit(DataType::Int32)?),
_ => Err(CastError("Unsupported input type".to_string())),
_ => bail!("unsupported input type"),
}
}

Expand Down
79 changes: 60 additions & 19 deletions src/frontend/src/expr/type_inference/cast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,15 @@
use std::collections::BTreeMap;
use std::sync::LazyLock;

use anyhow::Context;
use itertools::Itertools as _;
use parse_display::Display;
use risingwave_common::error::bail;
use risingwave_common::types::{DataType, DataTypeName};
use risingwave_common::util::iter_util::ZipEqFast;

use crate::error::ErrorCode;
use crate::expr::function_call::{CastError, CastResult};
use crate::expr::{Expr as _, ExprImpl, InputRef, Literal};

/// Find the least restrictive type. Used by `VALUES`, `CASE`, `UNION`, etc.
Expand Down Expand Up @@ -114,12 +117,45 @@ pub fn align_array_and_element(
Ok(array_type)
}

fn canmeh(ok: bool) -> CastResult {
if ok {
Ok(())
} else {
bail!("")
}
}
fn cannot() -> CastResult {
canmeh(false)
}

pub fn cast(source: &DataType, target: &DataType, allows: CastContext) -> Result<(), CastError> {
macro_rules! any {
($f:ident) => {
source.$f() || target.$f()
};
}

if any!(is_struct) {
cast_ok_struct(source, target, allows)
} else if any!(is_array) {
cast_ok_array(source, target, allows)
} else if any!(is_map) {
cast_ok_map(source, target, allows)
} else {
canmeh(cast_ok_base(source, target, allows))
}
.with_context(|| {
format!(
"cannot cast type \"{}\" to \"{}\" in {:?} context",
source, target, allows
)
})
.map_err(Into::into)
}

/// Checks whether casting from `source` to `target` is ok in `allows` context.
pub fn cast_ok(source: &DataType, target: &DataType, allows: CastContext) -> bool {
cast_ok_struct(source, target, allows)
|| cast_ok_array(source, target, allows)
|| cast_ok_map(source, target, allows)
|| cast_ok_base(source, target, allows)
cast(source, target, allows).is_ok()
}

/// Checks whether casting from `source` to `target` is ok in `allows` context.
Expand All @@ -128,52 +164,57 @@ pub fn cast_ok_base(source: &DataType, target: &DataType, allows: CastContext) -
matches!(CAST_MAP.get(&(source.into(), target.into())), Some(context) if *context <= allows)
}

fn cast_ok_struct(source: &DataType, target: &DataType, allows: CastContext) -> bool {
fn cast_ok_struct(source: &DataType, target: &DataType, allows: CastContext) -> CastResult {
match (source, target) {
(DataType::Struct(lty), DataType::Struct(rty)) => {
if lty.is_empty() || rty.is_empty() {
unreachable!("record type should be already processed at this point");
}
if lty.len() != rty.len() {
// only cast structs of the same length
return false;
bail!("cannot cast structs of different lengths");
}
// ... and all fields are castable
lty.types()
.zip_eq_fast(rty.types())
.all(|(src, dst)| src == dst || cast_ok(src, dst, allows))
.try_for_each(|(src, dst)| {
if src == dst {
Ok(())
} else {
cast(src, dst, allows)
}
})
}
// The automatic casts to string types are treated as assignment casts, while the automatic
// casts from string types are explicit-only.
// https://www.postgresql.org/docs/14/sql-createcast.html#id-1.9.3.58.7.4
(DataType::Varchar, DataType::Struct(_)) => CastContext::Explicit <= allows,
(DataType::Struct(_), DataType::Varchar) => CastContext::Assign <= allows,
_ => false,
(DataType::Varchar, DataType::Struct(_)) => canmeh(CastContext::Explicit <= allows),
(DataType::Struct(_), DataType::Varchar) => canmeh(CastContext::Assign <= allows),
_ => cannot(),
}
}

fn cast_ok_array(source: &DataType, target: &DataType, allows: CastContext) -> bool {
fn cast_ok_array(source: &DataType, target: &DataType, allows: CastContext) -> CastResult {
match (source, target) {
(DataType::List(source_elem), DataType::List(target_elem)) => {
cast_ok(source_elem, target_elem, allows)
cast(source_elem, target_elem, allows)
}
// The automatic casts to string types are treated as assignment casts, while the automatic
// casts from string types are explicit-only.
// https://www.postgresql.org/docs/14/sql-createcast.html#id-1.9.3.58.7.4
(DataType::Varchar, DataType::List(_)) => CastContext::Explicit <= allows,
(DataType::List(_), DataType::Varchar) => CastContext::Assign <= allows,
_ => false,
(DataType::Varchar, DataType::List(_)) => canmeh(CastContext::Explicit <= allows),
(DataType::List(_), DataType::Varchar) => canmeh(CastContext::Assign <= allows),
_ => cannot(),
}
}

fn cast_ok_map(source: &DataType, target: &DataType, allows: CastContext) -> bool {
fn cast_ok_map(source: &DataType, target: &DataType, allows: CastContext) -> CastResult {
match (source, target) {
(DataType::Map(source_elem), DataType::Map(target_elem)) => cast_ok(
(DataType::Map(source_elem), DataType::Map(target_elem)) => cast(
&source_elem.clone().into_list(),
&target_elem.clone().into_list(),
allows,
),
_ => false,
_ => cannot(),
}
}

Expand Down
2 changes: 1 addition & 1 deletion src/frontend/src/expr/type_inference/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,6 @@
mod cast;
mod func;
pub use cast::{
align_types, cast_map_array, cast_ok, cast_ok_base, cast_sigs, CastContext, CastSig,
align_types, cast_map_array, cast_ok, cast, cast_ok_base, cast_sigs, CastContext, CastSig,
};
pub use func::{infer_some_all, infer_type, infer_type_name, infer_type_with_sigmap, FuncSign};

0 comments on commit 4e686e0

Please sign in to comment.