From 24105cd2d5da024c65dc7c6d959dcd18bffc8bb2 Mon Sep 17 00:00:00 2001 From: xxchan Date: Mon, 19 Aug 2024 17:30:42 +0800 Subject: [PATCH] feat: support more map functions (#18073) --- e2e_test/batch/types/map.slt.part | 139 ++++++++++++--- proto/expr.proto | 9 + src/common/src/array/list_array.rs | 4 +- src/common/src/array/map_array.rs | 92 +++++++++- src/common/src/array/struct_array.rs | 9 + src/common/src/test_utils/rand_array.rs | 2 +- src/common/src/types/map_type.rs | 43 +++-- src/common/src/types/mod.rs | 8 +- src/common/src/util/value_encoding/mod.rs | 2 +- src/connector/codec/src/decoder/avro/mod.rs | 2 +- src/expr/impl/src/scalar/array.rs | 160 ++++++++++++++++-- src/expr/impl/src/scalar/cast.rs | 2 +- .../binder/expr/function/builtin_scalar.rs | 9 + src/frontend/src/binder/expr/mod.rs | 2 +- src/frontend/src/binder/expr/value.rs | 4 +- src/frontend/src/expr/mod.rs | 15 +- src/frontend/src/expr/pure.rs | 11 +- src/frontend/src/expr/type_inference/func.rs | 36 +++- .../src/optimizer/plan_expr_visitor/strong.rs | 9 + 19 files changed, 482 insertions(+), 76 deletions(-) diff --git a/e2e_test/batch/types/map.slt.part b/e2e_test/batch/types/map.slt.part index 5f68bcad22740..b4b4be7e5cba7 100644 --- a/e2e_test/batch/types/map.slt.part +++ b/e2e_test/batch/types/map.slt.part @@ -8,66 +8,66 @@ create table t (m map (float, float)); db error: ERROR: Failed to run the query Caused by: - invalid map key type: double precision + Bind error: invalid map key type: double precision query error -select map_from_entries(array[1.0,2.0,3.0], array[1,2,3]); +select map_from_key_values(array[1.0,2.0,3.0], array[1,2,3]); ---- db error: ERROR: Failed to run the query Caused by these errors (recent errors listed first): - 1: Failed to bind expression: map_from_entries(ARRAY[1.0, 2.0, 3.0], ARRAY[1, 2, 3]) + 1: Failed to bind expression: map_from_key_values(ARRAY[1.0, 2.0, 3.0], ARRAY[1, 2, 3]) 2: Expr error 3: invalid map key type: numeric query error -select 
map_from_entries(array[1,1,3], array[1,2,3]); +select map_from_key_values(array[1,1,3], array[1,2,3]); ---- db error: ERROR: Failed to run the query Caused by these errors (recent errors listed first): 1: Expr error - 2: error while evaluating expression `map_from_entries('{1,1,3}', '{1,2,3}')` + 2: error while evaluating expression `map_from_key_values('{1,1,3}', '{1,2,3}')` 3: map keys must be unique query ? -select map_from_entries(array[1,2,3], array[1,null,3]); +select map_from_key_values(array[1,2,3], array[1,null,3]); ---- {1:1,2:NULL,3:3} query error -select map_from_entries(array[1,null,3], array[1,2,3]); +select map_from_key_values(array[1,null,3], array[1,2,3]); ---- db error: ERROR: Failed to run the query Caused by these errors (recent errors listed first): 1: Expr error - 2: error while evaluating expression `map_from_entries('{1,NULL,3}', '{1,2,3}')` + 2: error while evaluating expression `map_from_key_values('{1,NULL,3}', '{1,2,3}')` 3: map keys must not be NULL query error -select map_from_entries(array[1,3], array[1,2,3]); +select map_from_key_values(array[1,3], array[1,2,3]); ---- db error: ERROR: Failed to run the query Caused by these errors (recent errors listed first): 1: Expr error - 2: error while evaluating expression `map_from_entries('{1,3}', '{1,2,3}')` + 2: error while evaluating expression `map_from_key_values('{1,3}', '{1,2,3}')` 3: map keys and values have different length query error -select map_from_entries(array[1,2], array[1,2]) = map_from_entries(array[2,1], array[2,1]); +select map_from_key_values(array[1,2], array[1,2]) = map_from_key_values(array[2,1], array[2,1]); ---- db error: ERROR: Failed to run the query Caused by these errors (recent errors listed first): - 1: Failed to bind expression: map_from_entries(ARRAY[1, 2], ARRAY[1, 2]) = map_from_entries(ARRAY[2, 1], ARRAY[2, 1]) + 1: Failed to bind expression: map_from_key_values(ARRAY[1, 2], ARRAY[1, 2]) = map_from_key_values(ARRAY[2, 1], ARRAY[2, 1]) 2: function 
equal(map(integer,integer), map(integer,integer)) does not exist @@ -83,32 +83,32 @@ create table t ( statement ok insert into t values ( - map_from_entries(array['a','b','c'], array[1.0,2.0,3.0]::float[]), - map_from_entries(array[1,2,3], array[true,false,true]), - map_from_entries(array['a','b'], + map_from_key_values(array['a','b','c'], array[1.0,2.0,3.0]::float[]), + map_from_key_values(array[1,2,3], array[true,false,true]), + map_from_key_values(array['a','b'], array[ - map_from_entries(array['a1'], array['a2']), - map_from_entries(array['b1'], array['b2']) + map_from_key_values(array['a1'], array['a2']), + map_from_key_values(array['b1'], array['b2']) ] ), array[ - map_from_entries(array['a','b','c'], array[1,2,3]), - map_from_entries(array['d','e','f'], array[4,5,6]) + map_from_key_values(array['a','b','c'], array[1,2,3]), + map_from_key_values(array['d','e','f'], array[4,5,6]) ], row( - map_from_entries(array['a','b','c'], array[row(1),row(2),row(3)]::struct[]) + map_from_key_values(array['a','b','c'], array[row(1),row(2),row(3)]::struct[]) ) ); # cast(map(character varying,integer)) -> map(character varying,double precision) query ? -select map_from_entries(array['a','b','c'], array[1,2,3])::map(varchar,float); +select map_from_key_values(array['a','b','c'], array[1,2,3])::map(varchar,float); ---- {a:1,b:2,c:3} statement ok -insert into t(m1) values (map_from_entries(array['a','b','c'], array[1,2,3])); +insert into t(m1) values (map_from_key_values(array['a','b','c'], array[1,2,3])); query ????? 
rowsort select * from t; @@ -144,7 +144,7 @@ db error: ERROR: Failed to run the query Caused by these errors (recent errors listed first): 1: Expr error - 2: error while evaluating expression `map_from_entries('{a,a}', '{1,2}')` + 2: error while evaluating expression `map_from_key_values('{a,a}', '{1,2}')` 3: map keys must be unique @@ -165,3 +165,96 @@ select MAP{1:'a',2:'b'}::MAP(VARCHAR,VARCHAR) ---- {} {1:a,2:b} + +query error +select map_from_entries(array[]); +---- +db error: ERROR: Failed to run the query + +Caused by these errors (recent errors listed first): + 1: Failed to bind expression: map_from_entries(ARRAY[]) + 2: Bind error: cannot determine type of empty array +HINT: Explicitly cast to the desired type, for example ARRAY[]::integer[]. + + +query error +select map_from_entries(array[]::int[]); +---- +db error: ERROR: Failed to run the query + +Caused by these errors (recent errors listed first): + 1: Failed to bind expression: map_from_entries(CAST(ARRAY[] AS INT[])) + 2: Expr error + 3: invalid map entries type, expected struct, got: integer + + +query error +select map_from_entries(array[]::struct[]); +---- +db error: ERROR: Failed to run the query + +Caused by these errors (recent errors listed first): + 1: Failed to bind expression: map_from_entries(CAST(ARRAY[] AS STRUCT[])) + 2: Expr error + 3: invalid map key type: double precision + + +query ? +select map_from_entries(array[]::struct[]); +---- +{} + + +query ? 
+select map_from_entries(array[row('a',1), row('b',2), row('c',3)]); +---- +{a:1,b:2,c:3} + + +query error +select map_from_entries(array[row('a',1), row('a',2), row('c',3)]); +---- +db error: ERROR: Failed to run the query + +Caused by these errors (recent errors listed first): + 1: Expr error + 2: error while evaluating expression `map_from_entries('{"(a,1)","(a,2)","(c,3)"}')` + 3: map keys must be unique + + +query error +select map_from_entries(array[row('a',1,2)]); +---- +db error: ERROR: Failed to run the query + +Caused by these errors (recent errors listed first): + 1: Failed to bind expression: map_from_entries(ARRAY[ROW('a', 1, 2)]) + 2: Expr error + 3: the underlying struct for map must have exactly two fields, got: StructType { field_names: [], field_types: [Varchar, Int32, Int32] } + + +query error +select map_from_entries(array[row(1.0,1)]); +---- +db error: ERROR: Failed to run the query + +Caused by these errors (recent errors listed first): + 1: Failed to bind expression: map_from_entries(ARRAY[ROW(1.0, 1)]) + 2: Expr error + 3: invalid map key type: numeric + + +query error +select map_from_entries(null); +---- +db error: ERROR: Failed to run the query + +Caused by these errors (recent errors listed first): + 1: Failed to bind expression: map_from_entries(NULL) + 2: Bind error: Cannot implicitly cast 'null:Varchar' to polymorphic type AnyArray + + +query ? 
+select map_from_entries(null::struct[]); +---- +NULL diff --git a/proto/expr.proto b/proto/expr.proto index 0f543d3514e3b..e5b5fb73ba8ff 100644 --- a/proto/expr.proto +++ b/proto/expr.proto @@ -286,6 +286,15 @@ message ExprNode { // Map functions MAP_FROM_ENTRIES = 700; MAP_ACCESS = 701; + MAP_KEYS = 702; + MAP_VALUES = 703; + MAP_ENTRIES = 704; + MAP_FROM_KEY_VALUES = 705; + MAP_LENGTH = 706; + MAP_CONTAINS = 707; + MAP_CAT = 708; + MAP_INSERT = 709; + MAP_DELETE = 710; // Non-pure functions below (> 1000) // ------------------------ diff --git a/src/common/src/array/list_array.rs b/src/common/src/array/list_array.rs index c30229852c0aa..745b1f6bbab05 100644 --- a/src/common/src/array/list_array.rs +++ b/src/common/src/array/list_array.rs @@ -372,11 +372,11 @@ impl ListValue { /// Creates a new `ListValue` from an iterator of `Datum`. pub fn from_datum_iter( - datatype: &DataType, + elem_datatype: &DataType, iter: impl IntoIterator, ) -> Self { let iter = iter.into_iter(); - let mut builder = datatype.create_array_builder(iter.size_hint().0); + let mut builder = elem_datatype.create_array_builder(iter.size_hint().0); for datum in iter { builder.append(datum); } diff --git a/src/common/src/array/map_array.rs b/src/common/src/array/map_array.rs index f0904211f5edc..2f0da9bbf816f 100644 --- a/src/common/src/array/map_array.rs +++ b/src/common/src/array/map_array.rs @@ -24,6 +24,7 @@ use serde::Serializer; use super::{ Array, ArrayBuilder, ArrayImpl, ArrayResult, DatumRef, DefaultOrdered, ListArray, ListArrayBuilder, ListRef, ListValue, MapType, ScalarRef, ScalarRefImpl, StructArray, + StructRef, }; use crate::bitmap::Bitmap; use crate::types::{DataType, Scalar, ToText}; @@ -162,7 +163,7 @@ impl Array for MapArray { fn data_type(&self) -> DataType { let list_value_type = self.inner.values().data_type(); - DataType::Map(MapType::from_list_entries(list_value_type)) + DataType::Map(MapType::from_entries(list_value_type)) } } @@ -193,7 +194,10 @@ pub use scalar::{MapRef, 
MapValue}; /// We only check the invariants in the constructors. /// After they are constructed, we assume the invariants holds. mod scalar { + use std::collections::HashSet; + use super::*; + use crate::array::{Datum, ScalarImpl, StructValue}; /// Refer to [`MapArray`] for the invariants of a map value. #[derive(Clone, Eq, EstimateSize)] @@ -221,20 +225,33 @@ mod scalar { /// # Panics /// Panics if [map invariants](`super::MapArray`) are violated. - pub fn from_list_entries(list: ListValue) -> Self { + pub fn from_entries(entries: ListValue) -> Self { + Self::try_from_entries(entries).unwrap() + } + + /// Returns error if [map invariants](`super::MapArray`) are violated. + pub fn try_from_entries(entries: ListValue) -> Result { // validates list type is valid - _ = MapType::from_list_entries(list.data_type()); - // TODO: validate the values is valid - MapValue(list) + let _ = MapType::try_from_entries(entries.data_type())?; + let mut keys = HashSet::with_capacity(entries.len()); + let struct_array = entries.into_array(); + for key in struct_array.as_struct().field_at(0).iter() { + let Some(key) = key else { + return Err("map keys must not be NULL".to_string()); + }; + if !keys.insert(key) { + return Err("map keys must be unique".to_string()); + } + } + Ok(MapValue(ListValue::new(struct_array))) } - /// # Panics - /// Panics if [map invariants](`super::MapArray`) are violated. + /// Returns error if [map invariants](`super::MapArray`) are violated. 
pub fn try_from_kv(key: ListValue, value: ListValue) -> Result { if key.len() != value.len() { return Err("map keys and values have different length".to_string()); } - let unique_keys = key.iter().unique().collect_vec(); + let unique_keys: HashSet<_> = key.iter().unique().collect(); if unique_keys.len() != key.len() { return Err("map keys must be unique".to_string()); } @@ -252,6 +269,46 @@ mod scalar { ); Ok(MapValue(ListValue::new(struct_array.into()))) } + + /// # Panics + /// Panics if `m1` and `m2` have different types. + pub fn concat(m1: MapRef<'_>, m2: MapRef<'_>) -> Self { + debug_assert_eq!(m1.inner().data_type(), m2.inner().data_type()); + let m2_keys = m2.keys(); + let l = ListValue::from_datum_iter( + &m1.inner().data_type(), + m1.iter_struct() + .filter(|s| !m2_keys.contains(&s.field_at(0).expect("map key is not null"))) + .chain(m2.iter_struct()) + .map(|s| Some(ScalarRefImpl::Struct(s))), + ); + Self::from_entries(l) + } + + pub fn insert(m: MapRef<'_>, key: ScalarImpl, value: Datum) -> Self { + let l = ListValue::from_datum_iter( + &m.inner().data_type(), + m.iter_struct() + .filter(|s| { + key.as_scalar_ref_impl() != s.field_at(0).expect("map key is not null") + }) + .chain(std::iter::once( + StructValue::new(vec![Some(key.clone()), value]).as_scalar_ref(), + )) + .map(|s| Some(ScalarRefImpl::Struct(s))), + ); + Self::from_entries(l) + } + + pub fn delete(m: MapRef<'_>, key: ScalarRefImpl<'_>) -> Self { + let l = ListValue::from_datum_iter( + &m.inner().data_type(), + m.iter_struct() + .filter(|s| key != s.field_at(0).expect("map key is not null")) + .map(|s| Some(ScalarRefImpl::Struct(s))), + ); + Self::from_entries(l) + } } impl<'a> MapRef<'a> { @@ -272,6 +329,14 @@ mod scalar { pub fn into_kv(self) -> (ListRef<'a>, ListRef<'a>) { self.0.as_map_kv() } + + pub fn keys(&self) -> HashSet> { + self.iter().map(|(k, _v)| k).collect() + } + + pub fn to_owned(self) -> MapValue { + MapValue(self.0.to_owned()) + } } impl Scalar for MapValue { @@ -379,6 
+444,15 @@ impl<'a> MapRef<'a> { }) } + pub fn iter_struct( + self, + ) -> impl DoubleEndedIterator + ExactSizeIterator> + 'a { + self.inner().iter().map(|list_elem| { + let list_elem = list_elem.expect("the list element in map should not be null"); + list_elem.into_struct() + }) + } + pub fn iter_sorted( self, ) -> impl DoubleEndedIterator + ExactSizeIterator, DatumRef<'a>)> + 'a @@ -411,7 +485,7 @@ impl MapValue { deserializer: &mut memcomparable::Deserializer, ) -> memcomparable::Result { let list = ListValue::memcmp_deserialize(&datatype.clone().into_struct(), deserializer)?; - Ok(Self::from_list_entries(list)) + Ok(Self::from_entries(list)) } } diff --git a/src/common/src/array/struct_array.rs b/src/common/src/array/struct_array.rs index 9c3bd23653815..ebf224f581616 100644 --- a/src/common/src/array/struct_array.rs +++ b/src/common/src/array/struct_array.rs @@ -393,6 +393,15 @@ impl<'a> StructRef<'a> { iter_fields_ref!(self, it, { Either::Left(it) }, { Either::Right(it) }) } + /// # Panics + /// Panics if the index is out of bounds. 
+ pub fn field_at(&self, i: usize) -> DatumRef<'a> { + match self { + StructRef::Indexed { arr, idx } => arr.field_at(i).value_at(*idx), + StructRef::ValueRef { val } => val.fields[i].to_datum_ref(), + } + } + pub fn memcmp_serialize( self, serializer: &mut memcomparable::Serializer, diff --git a/src/common/src/test_utils/rand_array.rs b/src/common/src/test_utils/rand_array.rs index f201b4de33843..33cf42bc403e8 100644 --- a/src/common/src/test_utils/rand_array.rs +++ b/src/common/src/test_utils/rand_array.rs @@ -154,7 +154,7 @@ impl RandValue for ListValue { impl RandValue for MapValue { fn rand_value(_rand: &mut R) -> Self { // dummy value - MapValue::from_list_entries(ListValue::empty(&DataType::Struct( + MapValue::from_entries(ListValue::empty(&DataType::Struct( MapType::struct_type_for_map(DataType::Varchar, DataType::Varchar), ))) } diff --git a/src/common/src/types/map_type.rs b/src/common/src/types/map_type.rs index 4d9ec3dc5f143..e0dae8d9bc102 100644 --- a/src/common/src/types/map_type.rs +++ b/src/common/src/types/map_type.rs @@ -36,26 +36,37 @@ impl MapType { Self(Box::new((key, value))) } - pub fn try_from_kv(key: DataType, value: DataType) -> Result { + pub fn try_from_kv(key: DataType, value: DataType) -> Result { Self::check_key_type_valid(&key)?; Ok(Self(Box::new((key, value)))) } + pub fn try_from_entries(list_entries_type: DataType) -> Result { + match list_entries_type { + DataType::Struct(s) => { + let Some((k, v)) = s.iter().collect_tuple() else { + return Err(format!( + "the underlying struct for map must have exactly two fields, got: {s:?}" + )); + }; + // the field names are not strictly enforced + // Currently this panics for SELECT * FROM t + // if cfg!(debug_assertions) { + // itertools::assert_equal(struct_type.names(), ["key", "value"]); + // } + Self::try_from_kv(k.1.clone(), v.1.clone()) + } + _ => Err(format!( + "invalid map entries type, expected struct, got: {list_entries_type}" + )), + } + } + /// # Panics /// Panics if the key 
type is not valid for a map, or the /// entries type is not a valid struct type. - pub fn from_list_entries(list_entries_type: DataType) -> Self { - let struct_type = list_entries_type.as_struct(); - let (k, v) = struct_type - .iter() - .collect_tuple() - .expect("the underlying struct for map must have exactly two fields"); - // the field names are not strictly enforced - // Currently this panics for SELECT * FROM t - // if cfg!(debug_assertions) { - // itertools::assert_equal(struct_type.names(), ["key", "value"]); - // } - Self::from_kv(k.1.clone(), v.1.clone()) + pub fn from_entries(list_entries_type: DataType) -> Self { + Self::try_from_entries(list_entries_type).unwrap() } /// # Panics @@ -89,7 +100,7 @@ impl MapType { /// /// Note that this isn't definitive. /// Just be conservative at the beginning, but not too restrictive (like only allowing strings). - pub fn check_key_type_valid(data_type: &DataType) -> anyhow::Result<()> { + pub fn check_key_type_valid(data_type: &DataType) -> Result<(), String> { let ok = match data_type { DataType::Int16 | DataType::Int32 | DataType::Int64 => true, DataType::Varchar => true, @@ -111,7 +122,7 @@ impl MapType { | DataType::Map(_) => false, }; if !ok { - Err(anyhow::anyhow!("invalid map key type: {data_type}")) + Err(format!("invalid map key type: {data_type}")) } else { Ok(()) } @@ -128,7 +139,7 @@ impl FromStr for MapType { if let Some((key, value)) = s[4..s.len() - 1].split(',').collect_tuple() { let key = key.parse().context("failed to parse map key type")?; let value = value.parse().context("failed to parse map value type")?; - MapType::try_from_kv(key, value) + MapType::try_from_kv(key, value).map_err(|e| anyhow::anyhow!(e)) } else { Err(anyhow::anyhow!("expect map(...,...)")) } diff --git a/src/common/src/types/mod.rs b/src/common/src/types/mod.rs index b86e70b85d8bf..1fe1f3958e33c 100644 --- a/src/common/src/types/mod.rs +++ b/src/common/src/types/mod.rs @@ -251,7 +251,7 @@ impl From<&PbDataType> for DataType { 
// Map is physically the same as a list. // So the first (and only) item is the list element type. let list_entries_type: DataType = (&proto.field_type[0]).into(); - DataType::Map(MapType::from_list_entries(list_entries_type)) + DataType::Map(MapType::from_entries(list_entries_type)) } PbTypeName::Int256 => DataType::Int256, } @@ -849,6 +849,12 @@ impl From for ScalarImpl { } } +impl From> for ScalarImpl { + fn from(list: ListRef<'_>) -> Self { + Self::List(list.to_owned_scalar()) + } +} + impl ScalarImpl { /// Creates a scalar from pgwire "BINARY" format. /// diff --git a/src/common/src/util/value_encoding/mod.rs b/src/common/src/util/value_encoding/mod.rs index 3b4167331cb7e..3fdb8078fdef4 100644 --- a/src/common/src/util/value_encoding/mod.rs +++ b/src/common/src/util/value_encoding/mod.rs @@ -360,7 +360,7 @@ fn deserialize_value(ty: &DataType, data: &mut impl Buf) -> Result { DataType::Map(map_type) => { // FIXME: clone type everytime here is inefficient let list = deserialize_list(&map_type.clone().into_struct(), data)?.into_list(); - ScalarImpl::Map(MapValue::from_list_entries(list)) + ScalarImpl::Map(MapValue::from_entries(list)) } }) } diff --git a/src/connector/codec/src/decoder/avro/mod.rs b/src/connector/codec/src/decoder/avro/mod.rs index 738535ec9410c..dc4dae49ca7c4 100644 --- a/src/connector/codec/src/decoder/avro/mod.rs +++ b/src/connector/codec/src/decoder/avro/mod.rs @@ -344,7 +344,7 @@ impl<'a> AvroParseOptions<'a> { ); } let list = ListValue::new(builder.finish()); - MapValue::from_list_entries(list).into() + MapValue::from_entries(list).into() } (_expected, _got) => Err(create_error())?, diff --git a/src/expr/impl/src/scalar/array.rs b/src/expr/impl/src/scalar/array.rs index cee7de36c7175..d5f53213bf277 100644 --- a/src/expr/impl/src/scalar/array.rs +++ b/src/expr/impl/src/scalar/array.rs @@ -32,8 +32,14 @@ fn row_(row: impl Row) -> StructValue { StructValue::new(row.iter().map(|d| d.to_owned_datum()).collect()) } -fn map_type_infer(args: 
&[DataType]) -> Result { - let map = MapType::try_from_kv(args[0].as_list().clone(), args[1].as_list().clone())?; +fn map_from_key_values_type_infer(args: &[DataType]) -> Result { + let map = MapType::try_from_kv(args[0].as_list().clone(), args[1].as_list().clone()) + .map_err(ExprError::Custom)?; + Ok(map.into()) +} + +fn map_from_entries_type_infer(args: &[DataType]) -> Result { + let map = MapType::try_from_entries(args[0].as_list().clone()).map_err(ExprError::Custom)?; Ok(map.into()) } @@ -41,62 +47,70 @@ fn map_type_infer(args: &[DataType]) -> Result { /// /// ```slt /// query T -/// select map_from_entries(null::int[], array[1,2,3]); +/// select map_from_key_values(null::int[], array[1,2,3]); /// ---- /// NULL /// /// query T -/// select map_from_entries(array['a','b','c'], array[1,2,3]); +/// select map_from_key_values(array['a','b','c'], array[1,2,3]); /// ---- /// {a:1,b:2,c:3} /// ``` #[function( - "map_from_entries(anyarray, anyarray) -> anymap", - type_infer = "map_type_infer" + "map_from_key_values(anyarray, anyarray) -> anymap", + type_infer = "map_from_key_values_type_infer" )] -fn map_from_entries(key: ListRef<'_>, value: ListRef<'_>) -> Result { +fn map_from_key_values(key: ListRef<'_>, value: ListRef<'_>) -> Result { MapValue::try_from_kv(key.to_owned(), value.to_owned()).map_err(ExprError::Custom) } +#[function( + "map_from_entries(anyarray) -> anymap", + type_infer = "map_from_entries_type_infer" +)] +fn map_from_entries(entries: ListRef<'_>) -> Result { + MapValue::try_from_entries(entries.to_owned()).map_err(ExprError::Custom) +} + /// # Example /// /// ```slt /// query T -/// select map_access(map_from_entries(array[1,2,3], array[100,200,300]), 3); +/// select map_access(map_from_key_values(array[1,2,3], array[100,200,300]), 3); /// ---- /// 300 /// /// query T -/// select map_access(map_from_entries(array[1,2,3], array[100,200,300]), '3'); +/// select map_access(map_from_key_values(array[1,2,3], array[100,200,300]), '3'); /// ---- /// 300 
/// /// query error -/// select map_access(map_from_entries(array[1,2,3], array[100,200,300]), 1.0); +/// select map_access(map_from_key_values(array[1,2,3], array[100,200,300]), 1.0); /// ---- /// db error: ERROR: Failed to run the query /// /// Caused by these errors (recent errors listed first): -/// 1: Failed to bind expression: map_access(map_from_entries(ARRAY[1, 2, 3], ARRAY[100, 200, 300]), 1.0) +/// 1: Failed to bind expression: map_access(map_from_key_values(ARRAY[1, 2, 3], ARRAY[100, 200, 300]), 1.0) /// 2: Bind error: Cannot access numeric in map(integer,integer) /// /// /// query T -/// select map_access(map_from_entries(array['a','b','c'], array[1,2,3]), 'a'); +/// select map_access(map_from_key_values(array['a','b','c'], array[1,2,3]), 'a'); /// ---- /// 1 /// /// query T -/// select map_access(map_from_entries(array['a','b','c'], array[1,2,3]), 'd'); +/// select map_access(map_from_key_values(array['a','b','c'], array[1,2,3]), 'd'); /// ---- /// NULL /// /// query T -/// select map_access(map_from_entries(array['a','b','c'], array[1,2,3]), null); +/// select map_access(map_from_key_values(array['a','b','c'], array[1,2,3]), null); /// ---- /// NULL /// ``` -#[function("map_access(anymap, any) -> any")] +#[function("map_access(anymap, any) -> any", type_infer = "unreachable")] fn map_access<'a>( map: MapRef<'a>, key: ScalarRefImpl<'_>, @@ -111,6 +125,122 @@ fn map_access<'a>( } } +/// ```slt +/// query T +/// select +/// map_contains(MAP{1:1}, 1), +/// map_contains(MAP{1:1}, 2), +/// map_contains(MAP{1:1}, NULL::varchar), +/// map_contains(MAP{1:1}, 1.0) +/// ---- +/// t f NULL f +/// ``` +#[function("map_contains(anymap, any) -> boolean")] +fn map_contains(map: MapRef<'_>, key: ScalarRefImpl<'_>) -> Result { + let (keys, _values) = map.into_kv(); + let idx = array_position(keys, Some(key))?; + Ok(idx.is_some()) +} + +/// ```slt +/// query I +/// select +/// map_length(NULL::map(int,int)), +/// map_length(MAP {}::map(int,int)), +/// map_length(MAP 
{1:1,2:2}::map(int,int)) +/// ---- +/// NULL 0 2 +/// ``` +#[function("map_length(anymap) -> int4")] +fn map_length>(map: MapRef<'_>) -> Result { + map.inner() + .len() + .try_into() + .map_err(|_| ExprError::NumericOverflow) +} + +/// If both `m1` and `m2` have a value with the same key, then the output map contains the value from `m2`. +/// +/// ```slt +/// query T +/// select map_cat(MAP{'a':1,'b':2},null::map(varchar,int)); +/// ---- +/// {a:1,b:2} +/// +/// query T +/// select map_cat(MAP{'a':1,'b':2},MAP{'b':3,'c':4}); +/// ---- +/// {a:1,b:3,c:4} +/// +/// # implicit type cast +/// query T +/// select map_cat(MAP{'a':1,'b':2},MAP{'b':3.0,'c':4.0}); +/// ---- +/// {a:1,b:3.0,c:4.0} +/// ``` +#[function("map_cat(anymap, anymap) -> anymap")] +fn map_cat(m1: Option>, m2: Option>) -> Result, ExprError> { + match (m1, m2) { + (None, None) => Ok(None), + (Some(m), None) | (None, Some(m)) => Ok(Some(m.to_owned())), + (Some(m1), Some(m2)) => Ok(Some(MapValue::concat(m1, m2))), + } +} + +/// Inserts a key-value pair into the map. If the key already exists, the value is updated. +/// +/// # Example +/// +/// ```slt +/// query T +/// select map_insert(map{'a':1, 'b':2}, 'c', 3); +/// ---- +/// {a:1,b:2,c:3} +/// +/// query T +/// select map_insert(map{'a':1, 'b':2}, 'b', 4); +/// ---- +/// {a:1,b:4} +/// ``` +/// +/// TODO: support variadic arguments +#[function("map_insert(anymap, any, any) -> anymap")] +fn map_insert( + map: MapRef<'_>, + key: Option>, + value: Option>, +) -> MapValue { + let Some(key) = key else { + return map.to_owned(); + }; + MapValue::insert(map, key.into_scalar_impl(), value.to_owned_datum()) +} + +/// Deletes a key-value pair from the map. 
+/// +/// # Example +/// +/// ```slt +/// query T +/// select map_delete(map{'a':1, 'b':2, 'c':3}, 'b'); +/// ---- +/// {a:1,c:3} +/// +/// query T +/// select map_delete(map{'a':1, 'b':2, 'c':3}, 'd'); +/// ---- +/// {a:1,b:2,c:3} +/// ``` +/// +/// TODO: support variadic arguments +#[function("map_delete(anymap, any) -> anymap")] +fn map_delete(map: MapRef<'_>, key: Option>) -> MapValue { + let Some(key) = key else { + return map.to_owned(); + }; + MapValue::delete(map, key) +} + #[cfg(test)] mod tests { use risingwave_common::array::DataChunk; diff --git a/src/expr/impl/src/scalar/cast.rs b/src/expr/impl/src/scalar/cast.rs index e0dd1a8bb3fc8..41c51d95445ec 100644 --- a/src/expr/impl/src/scalar/cast.rs +++ b/src/expr/impl/src/scalar/cast.rs @@ -249,7 +249,7 @@ fn map_cast(map: MapRef<'_>, ctx: &Context) -> Result { return_type: ctx.return_type.as_map().clone().into_list(), variadic: ctx.variadic, }; - list_cast(map.into_inner(), &new_ctx).map(MapValue::from_list_entries) + list_cast(map.into_inner(), &new_ctx).map(MapValue::from_entries) } #[cfg(test)] diff --git a/src/frontend/src/binder/expr/function/builtin_scalar.rs b/src/frontend/src/binder/expr/function/builtin_scalar.rs index 824f08cf36b73..73eb722b26011 100644 --- a/src/frontend/src/binder/expr/function/builtin_scalar.rs +++ b/src/frontend/src/binder/expr/function/builtin_scalar.rs @@ -402,6 +402,15 @@ impl Binder { // map ("map_from_entries", raw_call(ExprType::MapFromEntries)), ("map_access",raw_call(ExprType::MapAccess)), + ("map_keys", raw_call(ExprType::MapKeys)), + ("map_values", raw_call(ExprType::MapValues)), + ("map_entries", raw_call(ExprType::MapEntries)), + ("map_from_key_values", raw_call(ExprType::MapFromKeyValues)), + ("map_cat", raw_call(ExprType::MapCat)), + ("map_contains", raw_call(ExprType::MapContains)), + ("map_delete", raw_call(ExprType::MapDelete)), + ("map_insert", raw_call(ExprType::MapInsert)), + ("map_length", raw_call(ExprType::MapLength)), // Functions that return a constant 
value ("pi", pi()), // greatest and least diff --git a/src/frontend/src/binder/expr/mod.rs b/src/frontend/src/binder/expr/mod.rs index 3c127c7da7c40..85ed93c7dc0ca 100644 --- a/src/frontend/src/binder/expr/mod.rs +++ b/src/frontend/src/binder/expr/mod.rs @@ -1017,7 +1017,7 @@ pub fn bind_data_type(data_type: &AstDataType) -> Result { AstDataType::Map(kv) => { let key = bind_data_type(&kv.0)?; let value = bind_data_type(&kv.1)?; - DataType::Map(MapType::try_from_kv(key, value)?) + DataType::Map(MapType::try_from_kv(key, value).map_err(ErrorCode::BindError)?) } AstDataType::Custom(qualified_type_name) => { let idents = qualified_type_name diff --git a/src/frontend/src/binder/expr/value.rs b/src/frontend/src/binder/expr/value.rs index 711aa6bbb6979..961306408a43e 100644 --- a/src/frontend/src/binder/expr/value.rs +++ b/src/frontend/src/binder/expr/value.rs @@ -159,7 +159,7 @@ impl Binder { .into(); let expr: ExprImpl = FunctionCall::new_unchecked( - ExprType::MapFromEntries, + ExprType::MapFromKeyValues, vec![keys, values], DataType::Map(MapType::from_kv(key_type, value_type)), ) @@ -209,7 +209,7 @@ impl Binder { .into(); let expr: ExprImpl = FunctionCall::new_unchecked( - ExprType::MapFromEntries, + ExprType::MapFromKeyValues, vec![keys, values], DataType::Map(map_type), ) diff --git a/src/frontend/src/expr/mod.rs b/src/frontend/src/expr/mod.rs index 73becd7bc86c8..f650fa3cb521b 100644 --- a/src/frontend/src/expr/mod.rs +++ b/src/frontend/src/expr/mod.rs @@ -17,7 +17,7 @@ use fixedbitset::FixedBitSet; use futures::FutureExt; use paste::paste; use risingwave_common::array::ListValue; -use risingwave_common::types::{DataType, Datum, JsonbVal, Scalar, ScalarImpl}; +use risingwave_common::types::{DataType, Datum, JsonbVal, MapType, Scalar, ScalarImpl}; use risingwave_expr::aggregate::PbAggKind; use risingwave_expr::expr::build_from_prost; use risingwave_pb::expr::expr_node::RexNode; @@ -324,6 +324,19 @@ impl ExprImpl { } } + /// Ensure the return type of this expression 
is a map of some type. + pub fn try_into_map_type(&self) -> Result { + if self.is_untyped() { + return Err(ErrorCode::BindError( + "could not determine polymorphic type because input has type unknown".into(), + )); + } + match self.return_type() { + DataType::Map(m) => Ok(m), + t => Err(ErrorCode::BindError(format!("expects map but got {t}"))), + } + } + /// Shorthand to enforce implicit cast to boolean pub fn enforce_bool_clause(self, clause: &str) -> RwResult { if self.is_untyped() { diff --git a/src/frontend/src/expr/pure.rs b/src/frontend/src/expr/pure.rs index 59f087672417c..3e6c83d8330fb 100644 --- a/src/frontend/src/expr/pure.rs +++ b/src/frontend/src/expr/pure.rs @@ -251,7 +251,16 @@ impl ExprVisitor for ImpureAnalyzer { | Type::QuoteLiteral | Type::QuoteNullable | Type::MapFromEntries - | Type::MapAccess => + | Type::MapAccess + | Type::MapKeys + | Type::MapValues + | Type::MapEntries + | Type::MapFromKeyValues + | Type::MapCat + | Type::MapContains + | Type::MapDelete + | Type::MapInsert + | Type::MapLength => // expression output is deterministic(same result for the same input) { func_call diff --git a/src/frontend/src/expr/type_inference/func.rs b/src/frontend/src/expr/type_inference/func.rs index 746460e2b6363..9ed7530499921 100644 --- a/src/frontend/src/expr/type_inference/func.rs +++ b/src/frontend/src/expr/type_inference/func.rs @@ -613,7 +613,7 @@ fn infer_type_for_special( } ExprType::MapAccess => { ensure_arity!("map_access", | inputs | == 2); - let map_type = inputs[0].return_type().into_map(); + let map_type = inputs[0].try_into_map_type()?; // We do not align the map's key type with the input type here, but cast the latter to the former instead. // e.g., for {1:'a'}[1.0], if we align them, we will get "numeric" as the key type, which violates the map type's restriction. 
match inputs[1].cast_implicit_mut(map_type.key().clone()) { @@ -626,6 +626,40 @@ fn infer_type_for_special( .into()), } } + ExprType::MapCat => { + ensure_arity!("map_cat", | inputs | == 2); + Ok(Some(align_types(inputs.iter_mut())?)) + } + ExprType::MapInsert => { + ensure_arity!("map_insert", | inputs | == 3); + let map_type = inputs[0].try_into_map_type()?; + let rk = inputs[1].cast_implicit_mut(map_type.key().clone()); + let rv = inputs[2].cast_implicit_mut(map_type.value().clone()); + match (rk, rv) { + (Ok(()), Ok(())) => Ok(Some(map_type.into())), + _ => Err(ErrorCode::BindError(format!( + "Cannot insert ({},{}) to {}", + inputs[1].return_type(), + inputs[2].return_type(), + inputs[0].return_type(), + )) + .into()), + } + } + ExprType::MapDelete => { + ensure_arity!("map_delete", | inputs | == 2); + let map_type = inputs[0].try_into_map_type()?; + let rk = inputs[1].cast_implicit_mut(map_type.key().clone()); + match rk { + Ok(()) => Ok(Some(map_type.into())), + _ => Err(ErrorCode::BindError(format!( + "Cannot delete {} from {}", + inputs[1].return_type(), + inputs[0].return_type(), + )) + .into()), + } + } ExprType::Vnode => { ensure_arity!("vnode", 1 <= | inputs |); Ok(Some(VirtualNode::RW_TYPE)) diff --git a/src/frontend/src/optimizer/plan_expr_visitor/strong.rs b/src/frontend/src/optimizer/plan_expr_visitor/strong.rs index 5e9a3ce05392c..2c14fc730877d 100644 --- a/src/frontend/src/optimizer/plan_expr_visitor/strong.rs +++ b/src/frontend/src/optimizer/plan_expr_visitor/strong.rs @@ -293,6 +293,15 @@ impl Strong { | ExprType::JsonbSet | ExprType::MapFromEntries | ExprType::MapAccess + | ExprType::MapKeys + | ExprType::MapValues + | ExprType::MapEntries + | ExprType::MapFromKeyValues + | ExprType::MapCat + | ExprType::MapContains + | ExprType::MapDelete + | ExprType::MapInsert + | ExprType::MapLength | ExprType::Vnode | ExprType::TestPaidTier | ExprType::Proctime