From e1de1857a6b96c1070d1273cfcebd8eef52a62c3 Mon Sep 17 00:00:00 2001 From: xxchan Date: Wed, 14 Aug 2024 13:33:53 +0800 Subject: [PATCH] feat: support `map_access` (#17986) Signed-off-by: xxchan --- e2e_test/batch/types/map.slt.part | 6 ++ proto/expr.proto | 1 + src/common/src/array/list_array.rs | 18 ++++++ src/common/src/array/map_array.rs | 4 ++ src/expr/impl/src/scalar/array.rs | 63 ++++++++++++++++++- src/expr/impl/src/scalar/array_positions.rs | 5 +- src/expr/impl/src/scalar/case.rs | 4 +- src/expr/impl/src/scalar/cast.rs | 8 +-- src/expr/impl/src/scalar/coalesce.rs | 2 +- src/expr/impl/src/scalar/external/iceberg.rs | 2 +- src/expr/impl/src/scalar/field.rs | 2 +- src/expr/impl/src/scalar/jsonb_record.rs | 7 ++- src/expr/macro/src/gen.rs | 7 ++- .../binder/expr/function/builtin_scalar.rs | 1 + src/frontend/src/binder/mod.rs | 4 +- src/frontend/src/expr/literal.rs | 7 ++- src/frontend/src/expr/mod.rs | 2 +- src/frontend/src/expr/pure.rs | 3 +- src/frontend/src/expr/type_inference/cast.rs | 8 ++- src/frontend/src/expr/type_inference/func.rs | 18 ++++++ src/frontend/src/expr/type_inference/mod.rs | 3 +- .../src/optimizer/plan_expr_visitor/strong.rs | 1 + 22 files changed, 149 insertions(+), 27 deletions(-) diff --git a/e2e_test/batch/types/map.slt.part b/e2e_test/batch/types/map.slt.part index 251a5798fd00a..bcdc92103e936 100644 --- a/e2e_test/batch/types/map.slt.part +++ b/e2e_test/batch/types/map.slt.part @@ -116,5 +116,11 @@ select * from t; {"a":1,"b":2,"c":3} NULL NULL NULL NULL {"a":1,"b":2,"c":3} {"1":t,"2":f,"3":t} {"a":{"a1":a2},"b":{"b1":b2}} {"{\"a\":1,\"b\":2,\"c\":3}","{\"d\":4,\"e\":5,\"f\":6}"} ("{""a"":(1),""b"":(2),""c"":(3)}") +query ????? rowsort +select to_jsonb(m1), to_jsonb(m2), to_jsonb(m3), to_jsonb(l), to_jsonb(s) from t; +---- +{"a": 1.0, "b": 2.0, "c": 3.0} null null null null +{"a": 1.0, "b": 2.0, "c": 3.0} {"1": true, "2": false, "3": true} {"a": {"a1": "a2"}, "b": {"b1": "b2"}} [{"a": 1, "b": 2, "c": 3}, {"d": 4, "e": 5, "f": 6}] {"m": {"a": {"x": 1}, "b": {"x": 2}, "c": {"x": 3}}} + statement ok drop table t; diff --git a/proto/expr.proto b/proto/expr.proto index 1531984291028..0f543d3514e3b 100644 --- a/proto/expr.proto +++ b/proto/expr.proto @@ -285,6 +285,7 @@ message ExprNode { // Map functions MAP_FROM_ENTRIES = 700; + MAP_ACCESS = 701; // Non-pure functions below (> 1000) // ------------------------ diff --git a/src/common/src/array/list_array.rs b/src/common/src/array/list_array.rs index e7d4d780ea8f5..c30229852c0aa 100644 --- a/src/common/src/array/list_array.rs +++ b/src/common/src/array/list_array.rs @@ -597,6 +597,24 @@ impl<'a> ListRef<'a> { _ => None, } } + + /// # Panics + /// Panics if the list is not a map's internal representation (See [`super::MapArray`]). 
+    pub(super) fn as_map_kv(self) -> (ListRef<'a>, ListRef<'a>) {
+        let (k, v) = self.array.as_struct().fields().collect_tuple().unwrap();
+        (
+            ListRef {
+                array: k,
+                start: self.start,
+                end: self.end,
+            },
+            ListRef {
+                array: v,
+                start: self.start,
+                end: self.end,
+            },
+        )
+    }
 }
 
 impl PartialEq for ListRef<'_> {
diff --git a/src/common/src/array/map_array.rs b/src/common/src/array/map_array.rs
index 2534deb40518b..6e9c819a14638 100644
--- a/src/common/src/array/map_array.rs
+++ b/src/common/src/array/map_array.rs
@@ -268,6 +268,10 @@ mod scalar {
         pub fn into_inner(self) -> ListRef<'a> {
             self.0
         }
+
+        pub fn into_kv(self) -> (ListRef<'a>, ListRef<'a>) {
+            self.0.as_map_kv()
+        }
     }
 
     impl Scalar for MapValue {
diff --git a/src/expr/impl/src/scalar/array.rs b/src/expr/impl/src/scalar/array.rs
index 48ee281b63c2d..08de9714ce058 100644
--- a/src/expr/impl/src/scalar/array.rs
+++ b/src/expr/impl/src/scalar/array.rs
@@ -14,16 +14,20 @@
 
 use risingwave_common::array::{ListValue, StructValue};
 use risingwave_common::row::Row;
-use risingwave_common::types::{DataType, ListRef, MapType, MapValue, ToOwnedDatum};
+use risingwave_common::types::{
+    DataType, ListRef, MapRef, MapType, MapValue, ScalarRefImpl, ToOwnedDatum,
+};
 use risingwave_expr::expr::Context;
 use risingwave_expr::{function, ExprError};
 
-#[function("array(...) -> anyarray", type_infer = "panic")]
+use super::array_positions::array_position;
+
+#[function("array(...) -> anyarray", type_infer = "unreachable")]
 fn array(row: impl Row, ctx: &Context) -> ListValue {
     ListValue::from_datum_iter(ctx.return_type.as_list(), row.iter())
 }
 
-#[function("row(...) -> struct", type_infer = "panic")]
+#[function("row(...) -> struct", type_infer = "unreachable")]
 fn row_(row: impl Row) -> StructValue {
     StructValue::new(row.iter().map(|d| d.to_owned_datum()).collect())
 }
@@ -54,6 +58,59 @@ fn map(key: ListRef<'_>, value: ListRef<'_>) -> Result<MapValue> {
     MapValue::try_from_kv(key.to_owned(), value.to_owned()).map_err(ExprError::Custom)
 }
 
+/// # Example
+///
+/// ```slt
+/// query T
+/// select map_access(map_from_entries(array[1,2,3], array[100,200,300]), 3);
+/// ----
+/// 300
+///
+/// query T
+/// select map_access(map_from_entries(array[1,2,3], array[100,200,300]), '3');
+/// ----
+/// 300
+///
+/// query error
+/// select map_access(map_from_entries(array[1,2,3], array[100,200,300]), 1.0);
+/// ----
+/// db error: ERROR: Failed to run the query
+///
+/// Caused by these errors (recent errors listed first):
+///   1: Failed to bind expression: map_access(map_from_entries(ARRAY[1, 2, 3], ARRAY[100, 200, 300]), 1.0)
+///   2: Bind error: Cannot access numeric in map(integer,integer)
+///
+///
+/// query T
+/// select map_access(map_from_entries(array['a','b','c'], array[1,2,3]), 'a');
+/// ----
+/// 1
+///
+/// query T
+/// select map_access(map_from_entries(array['a','b','c'], array[1,2,3]), 'd');
+/// ----
+/// NULL
+///
+/// query T
+/// select map_access(map_from_entries(array['a','b','c'], array[1,2,3]), null);
+/// ----
+/// NULL
+/// ```
+#[function("map_access(anymap, any) -> any")]
+fn map_access<'a>(
+    map: MapRef<'a>,
+    key: ScalarRefImpl<'_>,
+) -> Result<Option<ScalarRefImpl<'a>>, ExprError> {
+    // FIXME: DatumRef in return value is not supported by the macro yet.
+
+    let (keys, values) = map.into_kv();
+    let idx = array_position(keys, Some(key))?;
+    match idx {
+        Some(idx) => Ok(values.get((idx - 1) as usize).unwrap()),
+        None => Ok(None),
+    }
+}
+
 #[cfg(test)]
 mod tests {
     use risingwave_common::array::DataChunk;
diff --git a/src/expr/impl/src/scalar/array_positions.rs b/src/expr/impl/src/scalar/array_positions.rs
index cbae53c001439..22c5f67d40e0e 100644
--- a/src/expr/impl/src/scalar/array_positions.rs
+++ b/src/expr/impl/src/scalar/array_positions.rs
@@ -66,7 +66,10 @@ use risingwave_expr::{function, ExprError, Result};
 /// 2
 /// ```
 #[function("array_position(anyarray, any) -> int4")]
-fn array_position(array: ListRef<'_>, element: Option<ScalarRefImpl<'_>>) -> Result<Option<i32>> {
+pub(super) fn array_position(
+    array: ListRef<'_>,
+    element: Option<ScalarRefImpl<'_>>,
+) -> Result<Option<i32>> {
     array_position_common(array, element, 0)
 }
 
diff --git a/src/expr/impl/src/scalar/case.rs b/src/expr/impl/src/scalar/case.rs
index f7fb9d89ef41b..1c92e76ce4e30 100644
--- a/src/expr/impl/src/scalar/case.rs
+++ b/src/expr/impl/src/scalar/case.rs
@@ -208,7 +208,7 @@ impl Expression for ConstantLookupExpression {
     }
 }
 
-#[build_function("constant_lookup(...) -> any", type_infer = "panic")]
+#[build_function("constant_lookup(...) -> any", type_infer = "unreachable")]
 fn build_constant_lookup_expr(
     return_type: DataType,
     children: Vec<BoxedExpression>,
@@ -249,7 +249,7 @@ fn build_constant_lookup_expr(
     )))
 }
 
-#[build_function("case(...) -> any", type_infer = "panic")]
+#[build_function("case(...) -> any", type_infer = "unreachable")]
 fn build_case_expr(
     return_type: DataType,
     children: Vec<BoxedExpression>,
diff --git a/src/expr/impl/src/scalar/cast.rs b/src/expr/impl/src/scalar/cast.rs
index 2010a1876b0a8..e0dd1a8bb3fc8 100644
--- a/src/expr/impl/src/scalar/cast.rs
+++ b/src/expr/impl/src/scalar/cast.rs
@@ -189,13 +189,13 @@ pub fn str_to_bytea(elem: &str) -> Result<Box<[u8]>> {
     cast::str_to_bytea(elem).map_err(|err| ExprError::Parse(err.into()))
 }
 
-#[function("cast(varchar) -> anyarray", type_infer = "panic")]
+#[function("cast(varchar) -> anyarray", type_infer = "unreachable")]
 fn str_to_list(input: &str, ctx: &Context) -> Result<ListValue> {
     ListValue::from_str(input, &ctx.return_type).map_err(|err| ExprError::Parse(err.into()))
 }
 
 /// Cast array with `source_elem_type` into array with `target_elem_type` by casting each element.
-#[function("cast(anyarray) -> anyarray", type_infer = "panic")]
+#[function("cast(anyarray) -> anyarray", type_infer = "unreachable")]
 fn list_cast(input: ListRef<'_>, ctx: &Context) -> Result<ListValue> {
     let cast = build_func(
         PbType::Cast,
@@ -213,7 +213,7 @@ fn list_cast(input: ListRef<'_>, ctx: &Context) -> Result<ListValue> {
 }
 
 /// Cast struct of `source_elem_type` to `target_elem_type` by casting each element.
-#[function("cast(struct) -> struct", type_infer = "panic")]
+#[function("cast(struct) -> struct", type_infer = "unreachable")]
 fn struct_cast(input: StructRef<'_>, ctx: &Context) -> Result<StructValue> {
     let fields = (input.iter_fields_ref())
         .zip_eq_fast(ctx.arg_types[0].as_struct().types())
@@ -242,7 +242,7 @@ fn struct_cast(input: StructRef<'_>, ctx: &Context) -> Result<StructValue> {
 }
 
 /// Cast array with `source_elem_type` into array with `target_elem_type` by casting each element.
-#[function("cast(anymap) -> anymap", type_infer = "panic")] +#[function("cast(anymap) -> anymap", type_infer = "unreachable")] fn map_cast(map: MapRef<'_>, ctx: &Context) -> Result { let new_ctx = Context { arg_types: vec![ctx.arg_types[0].clone().as_map().clone().into_list()], diff --git a/src/expr/impl/src/scalar/coalesce.rs b/src/expr/impl/src/scalar/coalesce.rs index af3d753867559..6176a54a23d16 100644 --- a/src/expr/impl/src/scalar/coalesce.rs +++ b/src/expr/impl/src/scalar/coalesce.rs @@ -74,7 +74,7 @@ impl Expression for CoalesceExpression { } } -#[build_function("coalesce(...) -> any", type_infer = "panic")] +#[build_function("coalesce(...) -> any", type_infer = "unreachable")] fn build(return_type: DataType, children: Vec) -> Result { Ok(Box::new(CoalesceExpression { return_type, diff --git a/src/expr/impl/src/scalar/external/iceberg.rs b/src/expr/impl/src/scalar/external/iceberg.rs index 902545d01c25d..5fbc9b003305a 100644 --- a/src/expr/impl/src/scalar/external/iceberg.rs +++ b/src/expr/impl/src/scalar/external/iceberg.rs @@ -75,7 +75,7 @@ impl risingwave_expr::expr::Expression for IcebergTransform { } } -#[build_function("iceberg_transform(varchar, any) -> any", type_infer = "panic")] +#[build_function("iceberg_transform(varchar, any) -> any", type_infer = "unreachable")] fn build(return_type: DataType, mut children: Vec) -> Result { let transform_type = { let datum = children[0].eval_const()?.unwrap(); diff --git a/src/expr/impl/src/scalar/field.rs b/src/expr/impl/src/scalar/field.rs index 1d26fe9c85dbb..681b4ab6caacf 100644 --- a/src/expr/impl/src/scalar/field.rs +++ b/src/expr/impl/src/scalar/field.rs @@ -54,7 +54,7 @@ impl Expression for FieldExpression { } } -#[build_function("field(struct, int4) -> any", type_infer = "panic")] +#[build_function("field(struct, int4) -> any", type_infer = "unreachable")] fn build(return_type: DataType, children: Vec) -> Result { // Field `func_call_node` have 2 child nodes, the first is Field `FuncCall` or // `InputRef`, the second is i32 `Literal`. diff --git a/src/expr/impl/src/scalar/jsonb_record.rs b/src/expr/impl/src/scalar/jsonb_record.rs index b1d399d35a5f9..b85feb9190d2a 100644 --- a/src/expr/impl/src/scalar/jsonb_record.rs +++ b/src/expr/impl/src/scalar/jsonb_record.rs @@ -115,7 +115,7 @@ fn jsonb_populate_recordset<'a>( /// ---- /// 1 [1,2,3] {1,2,3} NULL (123,"a b c") /// ``` -#[function("jsonb_to_record(jsonb) -> struct", type_infer = "panic")] +#[function("jsonb_to_record(jsonb) -> struct", type_infer = "unreachable")] fn jsonb_to_record(jsonb: JsonbRef<'_>, ctx: &Context) -> Result { let output_type = ctx.return_type.as_struct(); jsonb.to_struct(output_type).map_err(parse_err) @@ -135,7 +135,10 @@ fn jsonb_to_record(jsonb: JsonbRef<'_>, ctx: &Context) -> Result { /// 1 foo /// 2 NULL /// ``` -#[function("jsonb_to_recordset(jsonb) -> setof struct", type_infer = "panic")] +#[function( + "jsonb_to_recordset(jsonb) -> setof struct", + type_infer = "unreachable" +)] fn jsonb_to_recordset<'a>( jsonb: JsonbRef<'a>, ctx: &'a Context, diff --git a/src/expr/macro/src/gen.rs b/src/expr/macro/src/gen.rs index 9b25f4b2557b7..ce5c8a884abdf 100644 --- a/src/expr/macro/src/gen.rs +++ b/src/expr/macro/src/gen.rs @@ -86,9 +86,10 @@ impl FunctionAttr { /// Generate the type infer function: `fn(&[DataType]) -> Result` fn generate_type_infer_fn(&self) -> Result { if let Some(func) = &self.type_infer { - // XXX: should this be called "placeholder" or "unreachable"? - if func == "panic" { - return Ok(quote! 
{ |_| panic!("type inference function is not implemented") }); + if func == "unreachable" { + return Ok( + quote! { |_| unreachable!("type inference for this function should be specially handled in frontend, and should not call sig.type_infer") }, + ); } // use the user defined type inference function return Ok(func.parse().unwrap()); diff --git a/src/frontend/src/binder/expr/function/builtin_scalar.rs b/src/frontend/src/binder/expr/function/builtin_scalar.rs index 3987334b89ced..ed10b808b8bdc 100644 --- a/src/frontend/src/binder/expr/function/builtin_scalar.rs +++ b/src/frontend/src/binder/expr/function/builtin_scalar.rs @@ -392,6 +392,7 @@ impl Binder { ("jsonb_set", raw_call(ExprType::JsonbSet)), // map ("map_from_entries", raw_call(ExprType::MapFromEntries)), + ("map_access",raw_call(ExprType::MapAccess)), // Functions that return a constant value ("pi", pi()), // greatest and least diff --git a/src/frontend/src/binder/mod.rs b/src/frontend/src/binder/mod.rs index 7cd9032890091..58aff26bd811d 100644 --- a/src/frontend/src/binder/mod.rs +++ b/src/frontend/src/binder/mod.rs @@ -260,8 +260,8 @@ impl UdfContext { } } -/// `ParameterTypes` is used to record the types of the parameters during binding. It works -/// following the rules: +/// `ParameterTypes` is used to record the types of the parameters during binding prepared stataments. +/// It works by following the rules: /// 1. At the beginning, it contains the user specified parameters type. /// 2. When the binder encounters a parameter, it will record it as unknown(call `record_new_param`) /// if it didn't exist in `ParameterTypes`. diff --git a/src/frontend/src/expr/literal.rs b/src/frontend/src/expr/literal.rs index 29ac0948b6c0b..ed76d1a9bf998 100644 --- a/src/frontend/src/expr/literal.rs +++ b/src/frontend/src/expr/literal.rs @@ -70,7 +70,12 @@ impl std::fmt::Debug for Literal { impl Literal { pub fn new(data: Datum, data_type: DataType) -> Self { - assert!(literal_type_match(&data_type, data.as_ref())); + assert!( + literal_type_match(&data_type, data.as_ref()), + "data_type: {:?}, data: {:?}", + data_type, + data + ); Literal { data, data_type: Some(data_type), diff --git a/src/frontend/src/expr/mod.rs b/src/frontend/src/expr/mod.rs index d42317b00f10b..73becd7bc86c8 100644 --- a/src/frontend/src/expr/mod.rs +++ b/src/frontend/src/expr/mod.rs @@ -68,7 +68,7 @@ pub use subquery::{Subquery, SubqueryKind}; pub use table_function::{TableFunction, TableFunctionType}; pub use type_inference::{ align_types, cast_map_array, cast_ok, cast_sigs, infer_some_all, infer_type, infer_type_name, - infer_type_with_sigmap, least_restrictive, CastContext, CastSig, FuncSign, + infer_type_with_sigmap, CastContext, CastSig, FuncSign, }; pub use user_defined_function::UserDefinedFunction; pub use utils::*; diff --git a/src/frontend/src/expr/pure.rs b/src/frontend/src/expr/pure.rs index 90aa65d8e549d..59f087672417c 100644 --- a/src/frontend/src/expr/pure.rs +++ b/src/frontend/src/expr/pure.rs @@ -250,7 +250,8 @@ impl ExprVisitor for ImpureAnalyzer { | Type::InetAton | Type::QuoteLiteral | Type::QuoteNullable - | Type::MapFromEntries => + | Type::MapFromEntries + | Type::MapAccess => // expression output is deterministic(same result for the same input) { func_call diff --git a/src/frontend/src/expr/type_inference/cast.rs b/src/frontend/src/expr/type_inference/cast.rs index 2a62fc679a359..51441c3f70c5b 100644 --- a/src/frontend/src/expr/type_inference/cast.rs +++ b/src/frontend/src/expr/type_inference/cast.rs @@ -29,7 +29,10 @@ use crate::expr::{Expr as 
_, ExprImpl, InputRef, Literal}; /// /// If you also need to cast them to this type, and there are more than 2 exprs, check out /// [`align_types`]. -pub fn least_restrictive(lhs: DataType, rhs: DataType) -> std::result::Result { +/// +/// Note: be careful that literal strings are considered untyped. +/// e.g., `align_types(1, '1')` will be `Int32`, but `least_restrictive(Int32, Varchar)` will return error. +fn least_restrictive(lhs: DataType, rhs: DataType) -> std::result::Result { if lhs == rhs { Ok(lhs) } else if cast_ok(&lhs, &rhs, CastContext::Implicit) { @@ -81,6 +84,7 @@ pub fn align_array_and_element( element_indices: &[usize], inputs: &mut [ExprImpl], ) -> std::result::Result { + tracing::trace!(?inputs, "align_array_and_element begin"); let mut dummy_element = match inputs[array_idx].is_untyped() { // when array is unknown type, make an unknown typed value (e.g. null) true => ExprImpl::from(Literal::new_untyped(None)), @@ -106,7 +110,7 @@ pub fn align_array_and_element( // elements are already casted by `align_types`, we cast the array argument here inputs[array_idx].cast_implicit_mut(array_type.clone())?; - + tracing::trace!(?inputs, "align_array_and_element done"); Ok(array_type) } diff --git a/src/frontend/src/expr/type_inference/func.rs b/src/frontend/src/expr/type_inference/func.rs index 1797770cc7611..746460e2b6363 100644 --- a/src/frontend/src/expr/type_inference/func.rs +++ b/src/frontend/src/expr/type_inference/func.rs @@ -31,6 +31,7 @@ use crate::expr::{cast_ok, is_row_function, Expr as _, ExprImpl, ExprType, Funct /// is not supported on backend. /// /// It also mutates the `inputs` by adding necessary casts. +#[tracing::instrument(level = "trace", skip(sig_map))] pub fn infer_type_with_sigmap( func_name: FuncName, inputs: &mut [ExprImpl], @@ -65,6 +66,7 @@ pub fn infer_type_with_sigmap( }) .collect_vec(); let sig = infer_type_name(sig_map, func_name, &actuals)?; + tracing::trace!(?actuals, ?sig, "infer_type_name"); // add implicit casts to inputs for (expr, t) in inputs.iter_mut().zip_eq_fast(&sig.inputs_type) { @@ -82,6 +84,7 @@ pub fn infer_type_with_sigmap( let input_types = inputs.iter().map(|expr| expr.return_type()).collect_vec(); let return_type = (sig.type_infer)(&input_types)?; + tracing::trace!(?input_types, ?return_type, "finished type inference"); Ok(return_type) } @@ -608,6 +611,21 @@ fn infer_type_for_special( _ => Ok(None), } } + ExprType::MapAccess => { + ensure_arity!("map_access", | inputs | == 2); + let map_type = inputs[0].return_type().into_map(); + // We do not align the map's key type with the input type here, but cast the latter to the former instead. + // e.g., for {1:'a'}[1.0], if we align them, we will get "numeric" as the key type, which violates the map type's restriction. 
+ match inputs[1].cast_implicit_mut(map_type.key().clone()) { + Ok(()) => Ok(Some(map_type.value().clone())), + Err(_) => Err(ErrorCode::BindError(format!( + "Cannot access {} in {}", + inputs[1].return_type(), + inputs[0].return_type(), + )) + .into()), + } + } ExprType::Vnode => { ensure_arity!("vnode", 1 <= | inputs |); Ok(Some(VirtualNode::RW_TYPE)) diff --git a/src/frontend/src/expr/type_inference/mod.rs b/src/frontend/src/expr/type_inference/mod.rs index 5f191a898614c..2845f05ec0dae 100644 --- a/src/frontend/src/expr/type_inference/mod.rs +++ b/src/frontend/src/expr/type_inference/mod.rs @@ -18,7 +18,6 @@ mod cast; mod func; pub use cast::{ - align_types, cast_map_array, cast_ok, cast_ok_base, cast_sigs, least_restrictive, CastContext, - CastSig, + align_types, cast_map_array, cast_ok, cast_ok_base, cast_sigs, CastContext, CastSig, }; pub use func::{infer_some_all, infer_type, infer_type_name, infer_type_with_sigmap, FuncSign}; diff --git a/src/frontend/src/optimizer/plan_expr_visitor/strong.rs b/src/frontend/src/optimizer/plan_expr_visitor/strong.rs index d744f1ba14a14..5e9a3ce05392c 100644 --- a/src/frontend/src/optimizer/plan_expr_visitor/strong.rs +++ b/src/frontend/src/optimizer/plan_expr_visitor/strong.rs @@ -292,6 +292,7 @@ impl Strong { | ExprType::JsonbToRecord | ExprType::JsonbSet | ExprType::MapFromEntries + | ExprType::MapAccess | ExprType::Vnode | ExprType::TestPaidTier | ExprType::Proctime
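For reference, a usage sketch of the new binding (a hypothetical session, not part of the patch; it assumes the usual implicit `smallint` to `int` cast, which the `MapAccess` branch of `infer_type_for_special` above applies to the key before the lookup):

```sql
-- The key argument is implicitly cast to the map's key type (int here) before the
-- positional lookup via array_position, so a smallint key is accepted, while a
-- numeric key is rejected at bind time (see the doctests in scalar/array.rs).
select map_access(map_from_entries(array[1,2,3], array[100,200,300]), 2::smallint);
-- expected: 200
```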