diff --git a/src/frontend/src/catalog/system_catalog/pg_catalog/pg_cast.rs b/src/frontend/src/catalog/system_catalog/pg_catalog/pg_cast.rs index d5b1332c25b3..291743ea4ba2 100644 --- a/src/frontend/src/catalog/system_catalog/pg_catalog/pg_cast.rs +++ b/src/frontend/src/catalog/system_catalog/pg_catalog/pg_cast.rs @@ -12,11 +12,12 @@ // See the License for the specific language governing permissions and // limitations under the License. +use itertools::Itertools; use risingwave_common::types::{DataType, Fields}; use risingwave_frontend_macro::system_catalog; use crate::catalog::system_catalog::SysCatalogReaderImpl; -use crate::expr::cast_map_array; +use crate::expr::CAST_TABLE; /// The catalog `pg_cast` stores data type conversion paths. /// Ref: [`https://www.postgresql.org/docs/current/catalog-pg-cast.html`] @@ -31,12 +32,11 @@ struct PgCast { #[system_catalog(table, "pg_catalog.pg_cast")] fn read_pg_cast(_: &SysCatalogReaderImpl) -> Vec { - let mut cast_array = cast_map_array(); - cast_array.sort(); - cast_array + CAST_TABLE .iter() + .sorted() .enumerate() - .map(|(idx, (src, target, ctx))| PgCast { + .map(|(idx, ((src, target), ctx))| PgCast { oid: idx as i32, castsource: DataType::try_from(*src).unwrap().to_oid(), casttarget: DataType::try_from(*target).unwrap().to_oid(), diff --git a/src/frontend/src/expr/function_call.rs b/src/frontend/src/expr/function_call.rs index a8a5598b7349..e6da13ccbbce 100644 --- a/src/frontend/src/expr/function_call.rs +++ b/src/frontend/src/expr/function_call.rs @@ -423,6 +423,7 @@ pub fn is_row_function(expr: &ExprImpl) -> bool { false } +/// A stack of error messages for the cast operation. #[derive(Error, Debug, Box, Macro)] #[thiserror_ext(newtype(name = CastError), macro(path = "crate::expr::function_call"))] #[error("{message}")] @@ -433,6 +434,7 @@ pub struct CastErrorInner { pub type CastResult = Result; +// TODO(error-handling): do not use report string but directly make it a source of `ErrorCode`. impl From for ErrorCode { fn from(value: CastError) -> Self { ErrorCode::BindError(value.to_report_string()) diff --git a/src/frontend/src/expr/mod.rs b/src/frontend/src/expr/mod.rs index cf1d0cc21879..96a90efd9db7 100644 --- a/src/frontend/src/expr/mod.rs +++ b/src/frontend/src/expr/mod.rs @@ -68,8 +68,8 @@ pub use session_timezone::{SessionTimezone, TimestamptzExprFinder}; pub use subquery::{Subquery, SubqueryKind}; pub use table_function::{TableFunction, TableFunctionType}; pub use type_inference::{ - align_types, cast_map_array, cast_ok, cast_sigs, infer_some_all, infer_type, infer_type_name, - infer_type_with_sigmap, CastContext, CastSig, FuncSign, + align_types, cast_ok, cast_sigs, infer_some_all, infer_type, infer_type_name, + infer_type_with_sigmap, CastContext, CastSig, FuncSign, CAST_TABLE, }; pub use user_defined_function::UserDefinedFunction; pub use utils::*; diff --git a/src/frontend/src/expr/type_inference/cast.rs b/src/frontend/src/expr/type_inference/cast.rs index 806d13f9fb30..ee804c811a8a 100644 --- a/src/frontend/src/expr/type_inference/cast.rs +++ b/src/frontend/src/expr/type_inference/cast.rs @@ -115,6 +115,8 @@ pub fn align_array_and_element( Ok(array_type) } +/// Returns `Ok` if `ok` is true, otherwise returns a placeholder [`CastError`] to be further +/// wrapped with a more informative context in [`cast`]. fn canmeh(ok: bool) -> CastResult { if ok { Ok(()) @@ -122,10 +124,13 @@ fn canmeh(ok: bool) -> CastResult { bail_cast_error!() } } +/// Equivalent to `canmeh(false)`. fn cannot() -> CastResult { canmeh(false) } +/// Checks whether casting from `source` to `target` is ok in `allows` context. +/// Returns an error if the cast is not possible. pub fn cast(source: &DataType, target: &DataType, allows: CastContext) -> Result<(), CastError> { macro_rules! any { ($f:ident) => { @@ -134,11 +139,11 @@ pub fn cast(source: &DataType, target: &DataType, allows: CastContext) -> Result } if any!(is_struct) { - cast_ok_struct(source, target, allows) + cast_struct(source, target, allows) } else if any!(is_array) { - cast_ok_array(source, target, allows) + cast_array(source, target, allows) } else if any!(is_map) { - cast_ok_map(source, target, allows) + cast_map(source, target, allows) } else { canmeh(cast_ok_base(source, target, allows)) } @@ -154,6 +159,8 @@ pub fn cast(source: &DataType, target: &DataType, allows: CastContext) -> Result } /// Checks whether casting from `source` to `target` is ok in `allows` context. +/// +/// Equivalent to `cast(..).is_ok()`, but [`cast`] may be preferred for its error messages. pub fn cast_ok(source: &DataType, target: &DataType, allows: CastContext) -> bool { cast(source, target, allows).is_ok() } @@ -161,10 +168,10 @@ pub fn cast_ok(source: &DataType, target: &DataType, allows: CastContext) -> boo /// Checks whether casting from `source` to `target` is ok in `allows` context. /// Both `source` and `target` must be base types, i.e. not struct or array. pub fn cast_ok_base(source: &DataType, target: &DataType, allows: CastContext) -> bool { - matches!(CAST_MAP.get(&(source.into(), target.into())), Some(context) if *context <= allows) + matches!(CAST_TABLE.get(&(source.into(), target.into())), Some(context) if *context <= allows) } -fn cast_ok_struct(source: &DataType, target: &DataType, allows: CastContext) -> CastResult { +fn cast_struct(source: &DataType, target: &DataType, allows: CastContext) -> CastResult { match (source, target) { (DataType::Struct(lty), DataType::Struct(rty)) => { if lty.is_empty() || rty.is_empty() { @@ -193,7 +200,7 @@ fn cast_ok_struct(source: &DataType, target: &DataType, allows: CastContext) -> } } -fn cast_ok_array(source: &DataType, target: &DataType, allows: CastContext) -> CastResult { +fn cast_array(source: &DataType, target: &DataType, allows: CastContext) -> CastResult { match (source, target) { (DataType::List(source_elem), DataType::List(target_elem)) => { cast(source_elem, target_elem, allows) @@ -207,7 +214,7 @@ fn cast_ok_array(source: &DataType, target: &DataType, allows: CastContext) -> C } } -fn cast_ok_map(source: &DataType, target: &DataType, allows: CastContext) -> CastResult { +fn cast_map(source: &DataType, target: &DataType, allows: CastContext) -> CastResult { match (source, target) { (DataType::Map(source_elem), DataType::Map(target_elem)) => cast( &source_elem.clone().into_list(), @@ -218,13 +225,6 @@ fn cast_ok_map(source: &DataType, target: &DataType, allows: CastContext) -> Cas } } -pub fn cast_map_array() -> Vec<(DataTypeName, DataTypeName, CastContext)> { - CAST_MAP - .iter() - .map(|((src, target), ctx)| (*src, *target, *ctx)) - .collect_vec() -} - #[derive(Clone, Debug)] pub struct CastSig { pub from_type: DataTypeName, @@ -245,10 +245,10 @@ pub enum CastContext { Explicit, } -pub type CastMap = BTreeMap<(DataTypeName, DataTypeName), CastContext>; +pub type CastTable = BTreeMap<(DataTypeName, DataTypeName), CastContext>; pub fn cast_sigs() -> impl Iterator { - CAST_MAP + CAST_TABLE .iter() .map(|((from_type, to_type), context)| CastSig { from_type: *from_type, @@ -257,7 +257,7 @@ pub fn cast_sigs() -> impl Iterator { }) } -pub static CAST_MAP: LazyLock = LazyLock::new(|| { +pub static CAST_TABLE: LazyLock = LazyLock::new(|| { // cast rules: // 1. implicit cast operations in PG are organized in 3 sequences, // with the reverse direction being assign cast operations. @@ -347,7 +347,7 @@ mod tests { fn test_cast_ok() { // With the help of a script we can obtain the 3 expected cast tables from PG. They are // slightly modified on same-type cast and from-string cast for reasons explained above in - // `build_cast_map`. + // `build_`. let actual = gen_cast_table(CastContext::Implicit); assert_eq!( diff --git a/src/frontend/src/expr/type_inference/mod.rs b/src/frontend/src/expr/type_inference/mod.rs index 7f10d3aec3f6..bc4e6477bf2d 100644 --- a/src/frontend/src/expr/type_inference/mod.rs +++ b/src/frontend/src/expr/type_inference/mod.rs @@ -18,6 +18,6 @@ mod cast; mod func; pub use cast::{ - align_types, cast_map_array, cast_ok, cast, cast_ok_base, cast_sigs, CastContext, CastSig, + align_types, cast, cast_ok, cast_ok_base, cast_sigs, CastContext, CastSig, CAST_TABLE, }; pub use func::{infer_some_all, infer_type, infer_type_name, infer_type_with_sigmap, FuncSign};