diff --git a/proto/data.proto b/proto/data.proto index 06d223d142bf9..9bb15ebcc8d62 100644 --- a/proto/data.proto +++ b/proto/data.proto @@ -52,6 +52,7 @@ message DataType { JSONB = 18; SERIAL = 19; INT256 = 20; + MAP = 21; } TypeName type_name = 1; // Data length for char. @@ -102,6 +103,7 @@ enum ArrayType { JSONB = 16; SERIAL = 17; INT256 = 18; + MAP = 20; } message Array { diff --git a/proto/expr.proto b/proto/expr.proto index dedfa3f3cd3b7..a3ba11baafb27 100644 --- a/proto/expr.proto +++ b/proto/expr.proto @@ -282,6 +282,9 @@ message ExprNode { JSONB_TO_RECORD = 630; JSONB_SET = 631; + // Map functions + MAP_FROM_ENTRIES = 700; + // Non-pure functions below (> 1000) // ------------------------ // Internal functions diff --git a/src/common/src/array/arrow/arrow_impl.rs b/src/common/src/array/arrow/arrow_impl.rs index 2ecef10e7aa3f..bc8e4e6faef27 100644 --- a/src/common/src/array/arrow/arrow_impl.rs +++ b/src/common/src/array/arrow/arrow_impl.rs @@ -42,6 +42,8 @@ use std::fmt::Write; +use arrow_array::cast::AsArray; +use arrow_array_iceberg::array; use arrow_buffer::OffsetBuffer; use chrono::{DateTime, NaiveDateTime, NaiveTime}; use itertools::Itertools; @@ -113,6 +115,7 @@ pub trait ToArrow { ArrayImpl::Serial(array) => self.serial_to_arrow(array), ArrayImpl::List(array) => self.list_to_arrow(data_type, array), ArrayImpl::Struct(array) => self.struct_to_arrow(data_type, array), + ArrayImpl::Map(array) => self.map_to_arrow(data_type, array), }?; if arrow_array.data_type() != data_type { arrow_cast::cast(&arrow_array, data_type).map_err(ArrayError::to_arrow) @@ -267,6 +270,33 @@ pub trait ToArrow { ))) } + #[inline] + fn map_to_arrow( + &self, + data_type: &arrow_schema::DataType, + array: &MapArray, + ) -> Result { + let arrow_schema::DataType::Map(field, ordered) = data_type else { + return Err(ArrayError::to_arrow("Invalid map type")); + }; + if *ordered { + return Err(ArrayError::to_arrow("Sorted map is not supported")); + } + let values = self + .struct_to_arrow(field.data_type(), array.as_struct())? + .as_struct() + .clone(); + let offsets = OffsetBuffer::new(array.offsets().iter().map(|&o| o as i32).collect()); + let nulls = (!array.null_bitmap().all()).then(|| array.null_bitmap().into()); + Ok(Arc::new(arrow_array::MapArray::new( + field.clone(), + offsets, + values, + nulls, + *ordered, + ))) + } + /// Convert RisingWave data type to Arrow data type. /// /// This function returns a `Field` instead of `DataType` because some may be converted to @@ -297,6 +327,7 @@ pub trait ToArrow { DataType::Jsonb => return Ok(self.jsonb_type_to_arrow(name)), DataType::Struct(fields) => self.struct_type_to_arrow(fields)?, DataType::List(datatype) => self.list_type_to_arrow(datatype)?, + DataType::Map(datatype) => self.map_type_to_arrow(datatype)?, }; Ok(arrow_schema::Field::new(name, data_type, true)) } @@ -413,6 +444,20 @@ pub trait ToArrow { .try_collect::<_, _, ArrayError>()?, )) } + + #[inline] + fn map_type_to_arrow(&self, map_type: &MapType) -> Result { + let sorted = false; + let list_type = map_type.clone().into_list(); + Ok(arrow_schema::DataType::Map( + Arc::new(arrow_schema::Field::new( + "entries", + self.list_type_to_arrow(&list_type)?, + true, + )), + sorted, + )) + } } /// Defines how to convert Arrow arrays to RisingWave arrays. 
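As an illustrative sketch of what the new `map_type_to_arrow` produces (assuming the arrow-rs `arrow_schema` crate; `example_map_layout` is a hypothetical helper, and exact field nullability may differ since the patch derives the entries field via `list_type_to_arrow`), a RisingWave `map(varchar, int)` would be represented in Arrow as a non-sorted Map whose single "entries" field is a Struct of key and value:

use std::sync::Arc;
use arrow_schema::{DataType, Field, Fields};

// Hypothetical helper: the Arrow layout corresponding to a RisingWave `map(varchar, int)`.
fn example_map_layout() -> DataType {
    let entries = Field::new(
        "entries",
        DataType::Struct(Fields::from(vec![
            Field::new("key", DataType::Utf8, true),
            Field::new("value", DataType::Int32, true),
        ])),
        true,
    );
    // `false` means map keys are not guaranteed to be sorted, matching the
    // `sorted = false` choice in `map_type_to_arrow` above.
    DataType::Map(Arc::new(entries), false)
}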
diff --git a/src/common/src/array/list_array.rs b/src/common/src/array/list_array.rs index 7fc1fdecee6fe..86a30e5adac84 100644 --- a/src/common/src/array/list_array.rs +++ b/src/common/src/array/list_array.rs @@ -29,6 +29,7 @@ use super::{ Array, ArrayBuilder, ArrayBuilderImpl, ArrayImpl, ArrayResult, BoolArray, PrimitiveArray, PrimitiveArrayItemType, RowRef, Utf8Array, }; +use crate::array::struct_array::{quote_if_need, PG_NEED_QUOTE_CHARS}; use crate::bitmap::{Bitmap, BitmapBuilder}; use crate::row::Row; use crate::types::{ @@ -56,6 +57,7 @@ impl ArrayBuilder for ListArrayBuilder { #[cfg(test)] fn new(capacity: usize) -> Self { + // TODO: deprecate this Self::with_type( capacity, // Default datatype @@ -500,6 +502,7 @@ impl From for ArrayImpl { } } +/// A slice of an array #[derive(Copy, Clone)] pub struct ListRef<'a> { array: &'a ArrayImpl, @@ -649,12 +652,7 @@ impl ToText for ListRef<'_> { let need_quote = !matches!(datum_ref, None | Some(ScalarRefImpl::List(_))) && (s.is_empty() || s.to_ascii_lowercase() == "null" - || s.contains([ - '"', '\\', '{', '}', ',', - // PostgreSQL `array_isspace` includes '\x0B' but rust - // [`char::is_ascii_whitespace`] does not. - ' ', '\t', '\n', '\r', '\x0B', '\x0C', - ])); + || s.contains(PG_NEED_QUOTE_CHARS)); if need_quote { f(&"\"")?; s.chars().try_for_each(|c| { diff --git a/src/common/src/array/map_array.rs b/src/common/src/array/map_array.rs new file mode 100644 index 0000000000000..b684480f0e236 --- /dev/null +++ b/src/common/src/array/map_array.rs @@ -0,0 +1,317 @@ +// Copyright 2024 RisingWave Labs +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::cmp::Ordering; +use std::fmt::{self, Debug, Display}; + +use itertools::Itertools; +use risingwave_common_estimate_size::EstimateSize; +use risingwave_pb::data::{PbArray, PbArrayType}; + +use super::{ + Array, ArrayBuilder, ArrayImpl, ArrayResult, DatumRef, ListArray, ListArrayBuilder, ListRef, + ListValue, MapType, ScalarRefImpl, StructArray, +}; +use crate::bitmap::Bitmap; +use crate::types::{DataType, Scalar, ToText}; + +#[derive(Debug, Clone, EstimateSize)] +pub struct MapArrayBuilder { + inner: ListArrayBuilder, +} + +impl ArrayBuilder for MapArrayBuilder { + type ArrayType = MapArray; + + fn new(_capacity: usize) -> Self { + panic!("please use `MapArrayBuilder::with_type` instead"); + } + + // TODO: think how we really build a map array in exprs before implementing these methods. 
+ + fn with_type(capacity: usize, ty: DataType) -> Self { + let inner = ListArrayBuilder::with_type(capacity, ty.into_map().into_list()); + Self { inner } + } + + fn append_n(&mut self, n: usize, value: Option>) { + self.inner.append_n(n, value.map(|v| v.0)); + } + + fn append_array(&mut self, other: &MapArray) { + self.inner.append_array(&other.inner); + } + + fn pop(&mut self) -> Option<()> { + self.inner.pop() + } + + fn len(&self) -> usize { + self.inner.len() + } + + fn finish(self) -> MapArray { + let inner = self.inner.finish(); + MapArray { inner } + } +} + +/// `MapArray` is physically just a `List>` array, but with some additional restrictions +/// on the `key`. +#[derive(Debug, Clone, Eq)] +pub struct MapArray { + pub(super) inner: ListArray, +} + +impl EstimateSize for MapArray { + fn estimated_heap_size(&self) -> usize { + self.inner.estimated_heap_size() + } +} + +impl Array for MapArray { + type Builder = MapArrayBuilder; + type OwnedItem = MapValue; + type RefItem<'a> = MapRef<'a>; + + unsafe fn raw_value_at_unchecked(&self, idx: usize) -> Self::RefItem<'_> { + let list = self.inner.raw_value_at_unchecked(idx); + MapRef(list) + } + + fn len(&self) -> usize { + self.inner.len() + } + + fn to_protobuf(&self) -> PbArray { + let mut array = self.inner.to_protobuf(); + array.array_type = PbArrayType::Map as i32; + array + } + + fn null_bitmap(&self) -> &Bitmap { + self.inner.null_bitmap() + } + + fn into_null_bitmap(self) -> Bitmap { + self.inner.into_null_bitmap() + } + + fn set_bitmap(&mut self, bitmap: Bitmap) { + self.inner.set_bitmap(bitmap) + } + + fn data_type(&self) -> DataType { + let list_value_type = self.inner.values().data_type(); + let struct_type = list_value_type.as_struct(); + if cfg!(debug_assertions) { + // the field names are not strictly enforced + itertools::assert_equal(struct_type.names(), ["key", "value"]); + } + let (key_t, value_t) = struct_type + .types() + .collect_tuple() + .expect("the struct in a map must contains exactly 2 fields"); + DataType::Map(MapType::from_kv(key_t.clone(), value_t.clone())) + } +} + +impl MapArray { + pub fn from_protobuf(array: &PbArray) -> ArrayResult { + // TODO: avoid the clone + let mut array = array.clone(); + array.array_type = PbArrayType::List as i32; + let inner = ListArray::from_protobuf(&array)?.into_list(); + Ok(Self { inner }.into()) + } + + /// Return the inner struct array of the list array. + pub fn as_struct(&self) -> &StructArray { + self.inner.values().as_struct() + } + + /// Returns the offsets of this map. + pub fn offsets(&self) -> &[u32] { + self.inner.offsets() + } +} + +impl FromIterator for MapArray { + fn from_iter>(iter: I) -> Self { + todo!() + // let mut iter = iter.into_iter(); + // let first = iter.next().expect("empty iterator"); + // let mut builder = MapArrayBuilder::with_type( + // iter.size_hint().0, + // DataType::Map(Box::new(first.data_type())), + // ); + // builder.append(Some(first.as_scalar_ref())); + // for v in iter { + // builder.append(Some(v.as_scalar_ref())); + // } + // builder.finish() + } +} + +#[derive(Clone, Eq, EstimateSize)] +pub struct MapValue(pub(crate) ListValue); + +mod cmp { + use super::*; + impl PartialEq for MapArray { + fn eq(&self, _other: &Self) -> bool { + unreachable!("map is not comparable. Such usage should be banned in frontend.") + } + } + + impl PartialEq for MapValue { + fn eq(&self, _other: &Self) -> bool { + unreachable!("map is not comparable. 
Such usage should be banned in frontend.") + } + } + + impl PartialOrd for MapValue { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } + } + + impl Ord for MapValue { + fn cmp(&self, _other: &Self) -> Ordering { + unreachable!("map is not comparable. Such usage should be banned in frontend.") + } + } +} + +impl Debug for MapValue { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + self.as_scalar_ref().fmt(f) + } +} + +impl Display for MapValue { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + self.as_scalar_ref().write(f) + } +} + +impl MapValue { + pub fn from_list_entries(list: ListValue) -> Self { + if cfg!(debug_assertions) { + // validates list type is valid + _ = MapType::from_list_entries(list.data_type()); + } + // TODO: validate the values is valid + MapValue(list) + } + + pub fn from_kv(keys: ListValue, values: ListValue) -> Self { + if cfg!(debug_assertions) { + assert_eq!( + keys.len(), + values.len(), + "keys: {keys:?}, values: {values:?}" + ); + let unique_keys = keys.iter().unique().collect_vec(); + assert!( + unique_keys.len() == keys.len(), + "non unique keys in map: {keys:?}" + ); + assert!(!unique_keys.contains(&None), "null key in map: {keys:?}"); + } + + let len = keys.len(); + let key_type = keys.data_type(); + let value_type = values.data_type(); + let struct_array = StructArray::new( + MapType::struct_type_for_map(key_type, value_type), + vec![keys.into_array().into_ref(), values.into_array().into_ref()], + Bitmap::ones(len), + ); + MapValue(ListValue::new(struct_array.into())) + } +} + +/// A map is just a slice of the underlying struct array. +/// +/// XXX: perhaps we can make it `MapRef<'a, 'b>(ListRef<'a>, ListRef<'b>);`. +/// Then we can build a map ref from 2 list refs without copying the data. +/// Currently it's impossible. +#[derive(Copy, Clone, Eq)] +pub struct MapRef<'a>(pub(crate) ListRef<'a>); + +impl<'a> MapRef<'a> { + /// Iterates over the elements of the map. + pub fn iter( + self, + ) -> impl DoubleEndedIterator + ExactSizeIterator, DatumRef<'a>)> + 'a + { + self.0.iter().map(|list_elem| { + let list_elem = list_elem.expect("the list element in map should not be null"); + let struct_ = list_elem.into_struct(); + let (k, v) = struct_ + .iter_fields_ref() + .next_tuple() + .expect("the struct in map should have exactly 2 fields"); + (k.expect("map key should not be null"), v) + }) + } +} + +impl PartialEq for MapRef<'_> { + fn eq(&self, _other: &Self) -> bool { + unreachable!("map is not comparable. Such usage should be banned in frontend.") + } +} + +impl Ord for MapRef<'_> { + fn cmp(&self, _other: &Self) -> Ordering { + unreachable!("map is not comparable. Such usage should be banned in frontend.") + } +} + +impl PartialOrd for MapRef<'_> { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +impl Debug for MapRef<'_> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_list().entries(self.0.iter()).finish() + } +} + +impl ToText for MapRef<'_> { + fn write(&self, f: &mut W) -> std::fmt::Result { + // Note: This is arbitrarily decided... + write!( + f, + "{{{}}}", + self.iter().format_with(",", |(key, value), f| { + let key = key.to_text(); + let value = value.to_text(); + // TODO: consider quote like list and struct + f(&format_args!("\"{}\":{}", key, value)) + }) + ) + } + + fn write_with_type(&self, ty: &DataType, f: &mut W) -> std::fmt::Result { + match ty { + DataType::Map { .. 
} => self.write(f), + _ => unreachable!(), + } + } +} diff --git a/src/common/src/array/mod.rs b/src/common/src/array/mod.rs index 89b3b06266786..132540387d590 100644 --- a/src/common/src/array/mod.rs +++ b/src/common/src/array/mod.rs @@ -26,6 +26,7 @@ pub mod interval_array; mod iterator; mod jsonb_array; pub mod list_array; +mod map_array; mod num256_array; mod primitive_array; mod proto_reader; @@ -53,6 +54,7 @@ pub use interval_array::{IntervalArray, IntervalArrayBuilder}; pub use iterator::ArrayIterator; pub use jsonb_array::{JsonbArray, JsonbArrayBuilder}; pub use list_array::{ListArray, ListArrayBuilder, ListRef, ListValue}; +pub use map_array::{MapArray, MapArrayBuilder, MapRef, MapValue}; use paste::paste; pub use primitive_array::{PrimitiveArray, PrimitiveArrayBuilder, PrimitiveArrayItemType}; use risingwave_common_estimate_size::EstimateSize; @@ -104,6 +106,7 @@ pub trait ArrayBuilder: Send + Sync + Sized + 'static { type ArrayType: Array; /// Create a new builder with `capacity`. + /// TODO: remove this function from the trait. Let it be methods of each concrete builders. fn new(capacity: usize) -> Self; /// # Panics @@ -138,6 +141,8 @@ pub trait ArrayBuilder: Send + Sync + Sized + 'static { /// # Returns /// /// Returns `None` if there is no elements in the builder. + /// + /// XXX: This seems useless. Perhaps we can delete it. fn pop(&mut self) -> Option<()>; /// Append an element in another array into builder. @@ -331,6 +336,8 @@ macro_rules! array_impl_enum { for_all_array_variants! { array_impl_enum } +// XXX: We can merge the From impl into impl_convert + impl From> for ArrayImpl { fn from(arr: PrimitiveArray) -> Self { T::erase_array_type(arr) @@ -379,6 +386,12 @@ impl From for ArrayImpl { } } +impl From for ArrayImpl { + fn from(arr: MapArray) -> Self { + Self::Map(arr) + } +} + /// `impl_convert` implements several conversions for `Array` and `ArrayBuilder`. /// * `ArrayImpl -> &Array` with `impl.as_int16()`. /// * `ArrayImpl -> Array` with `impl.into_int16()`. @@ -390,6 +403,9 @@ macro_rules! impl_convert { $( paste! { impl ArrayImpl { + /// # Panics + /// + /// Panics if type mismatches. pub fn [](&self) -> &$array { match self { Self::$variant_name(ref array) => array, @@ -397,6 +413,9 @@ macro_rules! impl_convert { } } + /// # Panics + /// + /// Panics if type mismatches. pub fn [](self) -> $array { match self { Self::$variant_name(array) => array, @@ -405,6 +424,7 @@ macro_rules! impl_convert { } } + // FIXME: panic in From here is not proper. impl <'a> From<&'a ArrayImpl> for &'a $array { fn from(array: &'a ArrayImpl) -> Self { match array { diff --git a/src/common/src/array/proto_reader.rs b/src/common/src/array/proto_reader.rs index aa296900190df..665b8f9e92758 100644 --- a/src/common/src/array/proto_reader.rs +++ b/src/common/src/array/proto_reader.rs @@ -43,6 +43,7 @@ impl ArrayImpl { PbArrayType::List => ListArray::from_protobuf(array)?, PbArrayType::Bytea => read_string_array::(array, cardinality)?, PbArrayType::Int256 => Int256Array::from_protobuf(array, cardinality)?, + PbArrayType::Map => MapArray::from_protobuf(array)?, }; Ok(array) } diff --git a/src/common/src/array/struct_array.rs b/src/common/src/array/struct_array.rs index 22aae00c84f4c..d8477a01d4d63 100644 --- a/src/common/src/array/struct_array.rs +++ b/src/common/src/array/struct_array.rs @@ -497,15 +497,23 @@ impl ToText for StructRef<'_> { } } -/// Double quote a string if it contains any special characters. 
-fn quote_if_need(input: &str, writer: &mut impl Write) -> std::fmt::Result { +pub const PG_NEED_QUOTE_CHARS: [char; 11] = [ + '"', '\\', '(', ')', ',', + // PostgreSQL `array_isspace` includes '\x0B' but rust + // [`char::is_ascii_whitespace`] does not. + ' ', '\t', '\n', '\r', '\x0B', '\x0C', +]; + +/// Double quote a string if it contains any special characters./// +pub fn quote_if_need(input: &str, writer: &mut impl Write) -> std::fmt::Result { + // Note: for struct here, 'null' as a string is not quoted, but for list it's quoted: + // ```sql + // select row('a','a b','null'), array['a','a b','null']; + // ---- + // (a,"a b",null) {a,"a b","null"} + // ``` if !input.is_empty() // non-empty - && !input.contains([ - '"', '\\', '(', ')', ',', - // PostgreSQL `array_isspace` includes '\x0B' but rust - // [`char::is_ascii_whitespace`] does not. - ' ', '\t', '\n', '\r', '\x0B', '\x0C', - ]) + && !input.contains(PG_NEED_QUOTE_CHARS) { return writer.write_str(input); } diff --git a/src/common/src/hash/key.rs b/src/common/src/hash/key.rs index e9f7e83ac9146..d08ed7c8c317c 100644 --- a/src/common/src/hash/key.rs +++ b/src/common/src/hash/key.rs @@ -33,7 +33,7 @@ use risingwave_common_estimate_size::EstimateSize; use smallbitset::Set64; use static_assertions::const_assert_eq; -use crate::array::{ListValue, StructValue}; +use crate::array::{ListValue, MapValue, StructValue}; use crate::types::{ DataType, Date, Decimal, Int256, Int256Ref, JsonbVal, Scalar, ScalarRef, ScalarRefImpl, Serial, Time, Timestamp, Timestamptz, F32, F64, @@ -627,6 +627,7 @@ impl_value_encoding_hash_key_serde!(JsonbVal); // use the memcmp encoding for safety. impl_memcmp_encoding_hash_key_serde!(StructValue); impl_memcmp_encoding_hash_key_serde!(ListValue); +impl_memcmp_encoding_hash_key_serde!(MapValue); #[cfg(test)] mod tests { diff --git a/src/common/src/test_utils/rand_array.rs b/src/common/src/test_utils/rand_array.rs index b7c1d3630b9b7..338539ee5b89c 100644 --- a/src/common/src/test_utils/rand_array.rs +++ b/src/common/src/test_utils/rand_array.rs @@ -24,7 +24,7 @@ use rand::prelude::Distribution; use rand::rngs::SmallRng; use rand::{Rng, SeedableRng}; -use crate::array::{Array, ArrayBuilder, ArrayRef, ListValue, StructValue}; +use crate::array::{Array, ArrayBuilder, ArrayRef, ListValue, MapValue, StructValue}; use crate::types::{ Date, Decimal, Int256, Interval, JsonbVal, NativeType, Scalar, Serial, Time, Timestamp, Timestamptz, @@ -151,6 +151,12 @@ impl RandValue for ListValue { } } +impl RandValue for MapValue { + fn rand_value(rand: &mut R) -> Self { + todo!() + } +} + pub fn rand_array(rand: &mut R, size: usize, null_ratio: f64) -> A where A: Array, diff --git a/src/common/src/test_utils/rand_chunk.rs b/src/common/src/test_utils/rand_chunk.rs index 3e537fd9b6a49..9c604b6205cc3 100644 --- a/src/common/src/test_utils/rand_chunk.rs +++ b/src/common/src/test_utils/rand_chunk.rs @@ -43,10 +43,11 @@ pub fn gen_chunk(data_types: &[DataType], size: usize, seed: u64, null_ratio: f6 } DataType::Interval => seed_rand_array_ref::(size, seed, null_ratio), DataType::Int256 => seed_rand_array_ref::(size, seed, null_ratio), - DataType::Struct(_) | DataType::Bytea | DataType::Jsonb => { - todo!() - } - DataType::List(_) => { + DataType::Struct(_) + | DataType::Bytea + | DataType::Jsonb + | DataType::List(_) + | DataType::Map(_) => { todo!() } }); diff --git a/src/common/src/types/macros.rs b/src/common/src/types/macros.rs index 520e4ab8f45ee..1dd29156dd651 100644 --- a/src/common/src/types/macros.rs +++ 
b/src/common/src/types/macros.rs @@ -58,6 +58,7 @@ macro_rules! for_all_variants { { Serial, Serial, serial, $crate::types::Serial, $crate::types::Serial, $crate::array::SerialArray, $crate::array::SerialArrayBuilder }, { Struct, Struct, struct, $crate::types::StructValue, $crate::types::StructRef<'scalar>, $crate::array::StructArray, $crate::array::StructArrayBuilder }, { List, List, list, $crate::types::ListValue, $crate::types::ListRef<'scalar>, $crate::array::ListArray, $crate::array::ListArrayBuilder }, + { Map, Map, map, $crate::types::MapValue, $crate::types::MapRef<'scalar>, $crate::array::MapArray, $crate::array::MapArrayBuilder }, { Bytea, Bytea, bytea, Box<[u8]>, &'scalar [u8], $crate::array::BytesArray, $crate::array::BytesArrayBuilder } } }; diff --git a/src/common/src/types/mod.rs b/src/common/src/types/mod.rs index b13df14a93308..97aaed5427c05 100644 --- a/src/common/src/types/mod.rs +++ b/src/common/src/types/mod.rs @@ -37,7 +37,8 @@ use thiserror_ext::AsReport; use crate::array::{ ArrayBuilderImpl, ArrayError, ArrayResult, PrimitiveArrayItemType, NULL_VAL_FOR_HASH, }; -pub use crate::array::{ListRef, ListValue, StructRef, StructValue}; +// Complex type's value is based on the array +pub use crate::array::{ListRef, ListValue, MapRef, MapValue, StructRef, StructValue}; use crate::cast::{str_to_bool, str_to_bytea}; use crate::error::BoxedError; use crate::{ @@ -166,6 +167,110 @@ pub enum DataType { #[display("rw_int256")] #[from_str(regex = "(?i)^rw_int256$")] Int256, + // FIXME: what is the syntax for map? + #[display("map{0}")] + #[from_str(regex = "(?i)^map(?P<0>.+)$")] + Map(MapType), +} + +pub use map_type::MapType; +mod map_type { + use std::fmt::Formatter; + + use anyhow::*; + + use super::*; + // TODO: check the trait impls + #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] + pub struct MapType(Box<(DataType, DataType)>); + + impl From for DataType { + fn from(value: MapType) -> Self { + DataType::Map(value) + } + } + + impl MapType { + /// # Panics + /// Panics if the key type is not valid for a map. + pub fn from_kv(key: DataType, value: DataType) -> Self { + Self::debug_assert_key_type_valid(&key); + Self(Box::new((key, value))) + } + + /// # Panics + /// Panics if the key type is not valid for a map, or the + /// entries type is not a valid struct type. 
+ pub fn from_list_entries(list_entries_type: DataType) -> Self { + let struct_type = list_entries_type.as_struct(); + let (k, v) = struct_type + .iter() + .collect_tuple() + .expect("the underlying struct for map must have exactly two fields"); + debug_assert!(k.0 == "key" && v.0 == "value", "k: {k:?}, v: {v:?}"); + Self::from_kv(k.1.clone(), v.1.clone()) + } + + pub fn struct_type_for_map(key_type: DataType, value_type: DataType) -> StructType { + MapType::debug_assert_key_type_valid(&key_type); + StructType::new(vec![("key", key_type), ("value", value_type)]) + } + + pub fn key(&self) -> &DataType { + &self.0 .0 + } + + pub fn value(&self) -> &DataType { + &self.0 .1 + } + + pub fn into_struct(self) -> StructType { + let (key, value) = *self.0; + Self::struct_type_for_map(key, value) + } + + pub fn into_list(self) -> DataType { + DataType::List(Box::new(DataType::Struct(self.into_struct()))) + } + + pub fn debug_assert_key_type_valid(data_type: &DataType) { + let valid = match data_type { + DataType::Int16 | DataType::Int32 | DataType::Int64 => true, + DataType::Varchar => true, + DataType::Boolean + | DataType::Float32 + | DataType::Float64 + | DataType::Decimal + | DataType::Date + | DataType::Time + | DataType::Timestamp + | DataType::Timestamptz + | DataType::Interval + | DataType::Struct(_) + | DataType::List(_) + | DataType::Bytea + | DataType::Jsonb + | DataType::Serial + | DataType::Int256 + | DataType::Map(_) => false, + }; + debug_assert!(valid, "invalid map key type: {data_type}"); + } + } + + impl FromStr for MapType { + type Err = anyhow::Error; + + fn from_str(s: &str) -> Result { + todo!() + } + } + + impl std::fmt::Display for MapType { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + todo!() + } + } } impl std::str::FromStr for Box { @@ -200,7 +305,7 @@ impl TryFrom for DataType { DataTypeName::Time => Ok(DataType::Time), DataTypeName::Interval => Ok(DataType::Interval), DataTypeName::Jsonb => Ok(DataType::Jsonb), - DataTypeName::Struct | DataTypeName::List => { + DataTypeName::Struct | DataTypeName::List | DataTypeName::Map => { Err("Functions returning struct or list can not be inferred. Please use `FunctionCall::new_unchecked`.") } } @@ -236,6 +341,12 @@ impl From<&PbDataType> for DataType { // The first (and only) item is the list element type. Box::new((&proto.field_type[0]).into()), ), + PbTypeName::Map => { + // Map is physically the same as a list. + // So the first (and only) item is the list element type. 
+ let list_entries_type: DataType = (&proto.field_type[0]).into(); + DataType::Map(MapType::from_list_entries(list_entries_type)) + } PbTypeName::Int256 => DataType::Int256, } } @@ -263,6 +374,7 @@ impl From for PbTypeName { DataTypeName::Struct => PbTypeName::Struct, DataTypeName::List => PbTypeName::List, DataTypeName::Int256 => PbTypeName::Int256, + DataTypeName::Map => PbTypeName::Map, } } } @@ -324,6 +436,11 @@ impl DataType { DataType::List(datatype) => { pb.field_type = vec![datatype.to_protobuf()]; } + DataType::Map(datatype) => { + // Same as List> + pb.field_type = + vec![DataType::Struct(datatype.clone().into_struct()).to_protobuf()]; + } DataType::Boolean | DataType::Int16 | DataType::Int32 @@ -366,6 +483,10 @@ impl DataType { matches!(self, DataType::Struct(_)) } + pub fn is_map(&self) -> bool { + matches!(self, DataType::Map(_)) + } + pub fn is_int(&self) -> bool { matches!(self, DataType::Int16 | DataType::Int32 | DataType::Int64) } @@ -383,14 +504,32 @@ impl DataType { Self::Struct(StructType::from_parts(field_names, fields)) } + pub fn new_unnamed_struct(fields: Vec) -> Self { + Self::Struct(StructType::unnamed(fields)) + } + pub fn as_struct(&self) -> &StructType { match self { DataType::Struct(t) => t, - _ => panic!("expect struct type"), + t => panic!("expect struct type, got {t}"), + } + } + + pub fn as_map(&self) -> &MapType { + match self { + DataType::Map(t) => t, + t => panic!("expect map type, got {t}"), + } + } + + pub fn into_map(self) -> MapType { + match self { + DataType::Map(t) => t, + t => panic!("expect map type, got {t}"), } } - /// Returns the inner type of a list type. + /// Returns the inner element's type of a list type. /// /// # Panics /// @@ -398,11 +537,13 @@ impl DataType { pub fn as_list(&self) -> &DataType { match self { DataType::List(t) => t, - _ => panic!("expect list type"), + t => panic!("expect list type, got {t}"), } } - /// Return a new type that removes the outer list. + /// Return a new type that removes the outer list, and get the innermost element type. + /// + /// Use [`DataType::as_list`] if you only want the element type of a list. /// /// ``` /// use risingwave_common::types::DataType::*; @@ -447,6 +588,10 @@ impl From for PbDataType { mod private { use super::*; + // Note: put pub trait inside a private mod just makes the name private, + // The trait methods will still be publicly available... + // a.k.a. ["Voldemort type"](https://rust-lang.github.io/rfcs/2145-type-privacy.html#lint-3-voldemort-types-its-reachable-but-i-cant-name-it) + /// Common trait bounds of scalar and scalar reference types. /// /// NOTE(rc): `Hash` is not in the trait bound list, it's implemented as [`ScalarRef::hash_scalar`]. @@ -610,7 +755,7 @@ macro_rules! impl_self_as_scalar_ref { )* }; } -impl_self_as_scalar_ref! { &str, &[u8], Int256Ref<'_>, JsonbRef<'_>, ListRef<'_>, StructRef<'_>, ScalarRefImpl<'_> } +impl_self_as_scalar_ref! { &str, &[u8], Int256Ref<'_>, JsonbRef<'_>, ListRef<'_>, StructRef<'_>, ScalarRefImpl<'_>, MapRef<'_> } /// `for_all_native_types` includes all native variants of our scalar types. 
/// @@ -831,7 +976,7 @@ impl ScalarImpl { .ok_or_else(|| "invalid value of Jsonb".to_string())?, ), DataType::Int256 => Self::Int256(Int256::from_binary(bytes)?), - DataType::Struct(_) | DataType::List(_) => { + DataType::Struct(_) | DataType::List(_) | DataType::Map(_) => { return Err(format!("unsupported data type: {}", data_type).into()); } }; @@ -864,6 +1009,9 @@ impl ScalarImpl { DataType::Struct(_) => StructValue::from_str(s, data_type)?.into(), DataType::Jsonb => JsonbVal::from_str(s)?.into(), DataType::Bytea => str_to_bytea(s)?.into(), + DataType::Map(_) => { + todo!() + } }) } } @@ -930,7 +1078,8 @@ impl ScalarRefImpl<'_> { self.to_text_with_type(data_type) } - /// Serialize the scalar. + /// Serialize the scalar into the `memcomparable` format. + /// TODO: use serde? pub fn serialize( &self, ser: &mut memcomparable::Serializer, @@ -961,6 +1110,9 @@ impl ScalarRefImpl<'_> { Self::Jsonb(v) => v.memcmp_serialize(ser)?, Self::Struct(v) => v.memcmp_serialize(ser)?, Self::List(v) => v.memcmp_serialize(ser)?, + // Map should not be used as key. + // This should be banned in frontend and this branch should actually be unreachable. + Self::Map(_) => Err(memcomparable::Error::NotSupported("map"))?, }; Ok(()) } @@ -1015,6 +1167,7 @@ impl ScalarImpl { Ty::Jsonb => Self::Jsonb(JsonbVal::memcmp_deserialize(de)?), Ty::Struct(t) => StructValue::memcmp_deserialize(t.types(), de)?.to_scalar_value(), Ty::List(t) => ListValue::memcmp_deserialize(t, de)?.to_scalar_value(), + Ty::Map(_) => todo!(), }) } @@ -1194,6 +1347,7 @@ mod tests { ScalarImpl::List(ListValue::from_iter([233i64, 2333])), DataType::List(Box::new(DataType::Int64)), ), + DataTypeName::Map => todo!(), }; test(Some(scalar), data_type.clone()); diff --git a/src/common/src/types/postgres_type.rs b/src/common/src/types/postgres_type.rs index ae147e9c9660e..d85f08ed59cc3 100644 --- a/src/common/src/types/postgres_type.rs +++ b/src/common/src/types/postgres_type.rs @@ -54,6 +54,12 @@ pub struct UnsupportedOid(i32); /// Get type information compatible with Postgres type, such as oid, type length. impl DataType { + /// For a fixed-size type, typlen is the number of bytes in the internal representation of the type. + /// But for a variable-length type, typlen is negative. + /// -1 indicates a “varlena” type (one that has a length word), + /// -2 indicates a null-terminated C string. + /// + /// pub fn type_len(&self) -> i16 { macro_rules! impl_type_len { ($( { $enum:ident | $oid:literal | $oid_array:literal | $name:ident | $input:ident | $len:literal } )*) => { @@ -63,7 +69,7 @@ impl DataType { )* DataType::Serial => 8, DataType::Int256 => -1, - DataType::List(_) | DataType::Struct(_) => -1, + DataType::List(_) | DataType::Struct(_) | DataType::Map(_) => -1, } } } @@ -96,6 +102,7 @@ impl DataType { for_all_base_types! { impl_from_oid } } + /// Refer to [`Self::from_oid`] pub fn to_oid(&self) -> i32 { macro_rules! impl_to_oid { ($( { $enum:ident | $oid:literal | $oid_array:literal | $name:ident | $input:ident | $len:literal } )*) => { @@ -111,10 +118,14 @@ impl DataType { DataType::Serial => 1016, DataType::Struct(_) => -1, DataType::List { .. } => unreachable!("Never reach here!"), + DataType::Map(_) => 1304, } DataType::Serial => 20, + // XXX: what does the oid mean here? Why we don't have from_oid for them? DataType::Int256 => 1301, + DataType::Map(_) => 1303, // TODO: Support to give a new oid for custom struct type. 
#9434 + // 1043 is varchar DataType::Struct(_) => 1043, } } @@ -133,6 +144,7 @@ impl DataType { DataType::List(_) => "list", DataType::Serial => "serial", DataType::Int256 => "rw_int256", + DataType::Map(_) => "map", } } } diff --git a/src/common/src/types/scalar_impl.rs b/src/common/src/types/scalar_impl.rs index 43742f74c7b51..34cc9692079d4 100644 --- a/src/common/src/types/scalar_impl.rs +++ b/src/common/src/types/scalar_impl.rs @@ -91,6 +91,14 @@ impl Scalar for ListValue { } } +impl Scalar for MapValue { + type ScalarRefType<'a> = MapRef<'a>; + + fn as_scalar_ref(&self) -> MapRef<'_> { + MapRef(self.0.as_scalar_ref()) + } +} + /// Implement `ScalarRef` for `Box`. /// `Box` could be converted to `&str`. impl<'a> ScalarRef<'a> for &'a str { @@ -316,6 +324,18 @@ impl<'a> ScalarRef<'a> for ListRef<'a> { } } +impl<'a> ScalarRef<'a> for MapRef<'a> { + type ScalarType = MapValue; + + fn to_owned_scalar(&self) -> MapValue { + MapValue(self.0.to_owned_scalar()) + } + + fn hash_scalar(&self, _state: &mut H) { + unreachable!("map is not hashable. Such usage should be banned in frontend.") + } +} + impl ScalarImpl { pub fn get_ident(&self) -> &'static str { dispatch_scalar_variants!(self, [I = VARIANT_NAME], { I }) diff --git a/src/common/src/types/struct_type.rs b/src/common/src/types/struct_type.rs index a18f452af7a74..edc4b73311533 100644 --- a/src/common/src/types/struct_type.rs +++ b/src/common/src/types/struct_type.rs @@ -37,11 +37,11 @@ impl Debug for StructType { #[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] struct StructTypeInner { - // Details about a struct type. There are 2 cases for a struct: - // 1. `field_names.len() == field_types.len()`: it represents a struct with named fields, - // e.g. `STRUCT`. - // 2. `field_names.len() == 0`: it represents a struct with unnamed fields, - // e.g. `ROW(1, 2)`. + /// Details about a struct type. There are 2 cases for a struct: + /// 1. `field_names.len() == field_types.len()`: it represents a struct with named fields, + /// e.g. `STRUCT`. + /// 2. `field_names.len() == 0`: it represents a struct with unnamed fields, + /// e.g. `ROW(1, 2)`. 
field_names: Box<[String]>, field_types: Box<[DataType]>, } @@ -71,6 +71,8 @@ impl StructType { } pub(super) fn from_parts(field_names: Vec, field_types: Vec) -> Self { + // TODO: enable this assertion + // debug_assert!(field_names.len() == field_types.len()); Self(Arc::new(StructTypeInner { field_types: field_types.into(), field_names: field_names.into(), diff --git a/src/common/src/types/to_binary.rs b/src/common/src/types/to_binary.rs index 56eea301f3f61..da7f75f0a2a3f 100644 --- a/src/common/src/types/to_binary.rs +++ b/src/common/src/types/to_binary.rs @@ -102,6 +102,7 @@ impl ToBinary for ScalarRefImpl<'_> { issue = 7949, "the pgwire extended-mode encoding for {ty} is unsupported" ), + ScalarRefImpl::Map(_) => todo!(), } } } diff --git a/src/common/src/types/to_sql.rs b/src/common/src/types/to_sql.rs index 3ece8a574c450..57aab11daf4d7 100644 --- a/src/common/src/types/to_sql.rs +++ b/src/common/src/types/to_sql.rs @@ -46,6 +46,7 @@ impl ToSql for ScalarImpl { ScalarImpl::Int256(_) | ScalarImpl::Struct(_) | ScalarImpl::List(_) => { bail_not_implemented!("the postgres encoding for {ty} is unsupported") } + ScalarImpl::Map(_) => todo!(), } } diff --git a/src/common/src/util/memcmp_encoding.rs b/src/common/src/util/memcmp_encoding.rs index 5a5ad598093af..c58f06f908520 100644 --- a/src/common/src/util/memcmp_encoding.rs +++ b/src/common/src/util/memcmp_encoding.rs @@ -170,6 +170,7 @@ fn calculate_encoded_size_inner( DataType::Varchar => deserializer.skip_bytes()?, DataType::Bytea => deserializer.skip_bytes()?, DataType::Int256 => Int256::MEMCMP_ENCODED_SIZE, + DataType::Map(_) => todo!(), }; // consume offset of fixed_type diff --git a/src/common/src/util/value_encoding/mod.rs b/src/common/src/util/value_encoding/mod.rs index a3da88911ad9a..44b56257ae938 100644 --- a/src/common/src/util/value_encoding/mod.rs +++ b/src/common/src/util/value_encoding/mod.rs @@ -13,7 +13,8 @@ // limitations under the License. //! Value encoding is an encoding format which converts the data into a binary form (not -//! memcomparable). +//! memcomparable, i.e., Key encoding). + use bytes::{Buf, BufMut}; use chrono::{Datelike, Timelike}; use either::{for_both, Either}; @@ -226,6 +227,7 @@ fn serialize_scalar(value: ScalarRefImpl<'_>, buf: &mut impl BufMut) { ScalarRefImpl::Jsonb(v) => serialize_str(&v.value_serialize(), buf), ScalarRefImpl::Struct(s) => serialize_struct(s, buf), ScalarRefImpl::List(v) => serialize_list(v, buf), + ScalarRefImpl::Map(m) => serialize_list(m.0, buf), } } @@ -251,6 +253,7 @@ fn estimate_serialize_scalar_size(value: ScalarRefImpl<'_>) -> usize { ScalarRefImpl::Jsonb(v) => v.capacity(), ScalarRefImpl::Struct(s) => estimate_serialize_struct_size(s), ScalarRefImpl::List(v) => estimate_serialize_list_size(v), + ScalarRefImpl::Map(_) => todo!(), } } @@ -354,6 +357,12 @@ fn deserialize_value(ty: &DataType, data: &mut impl Buf) -> Result { DataType::Struct(struct_def) => deserialize_struct(struct_def, data)?, DataType::Bytea => ScalarImpl::Bytea(deserialize_bytea(data).into()), DataType::List(item_type) => deserialize_list(item_type, data)?, + DataType::Map(map_type) => { + // FIXME: clone type everytime here is inefficient + let list = deserialize_list(&DataType::Struct(map_type.clone().into_struct()), data)? 
+ .into_list(); + ScalarImpl::Map(MapValue::from_list_entries(list)) + } }) } diff --git a/src/connector/src/parser/mysql.rs b/src/connector/src/parser/mysql.rs index a28dddc9aa65a..fe9b77c643de7 100644 --- a/src/connector/src/parser/mysql.rs +++ b/src/connector/src/parser/mysql.rs @@ -127,8 +127,10 @@ pub fn mysql_row_to_owned_row(mysql_row: &mut MysqlRow, schema: &Schema) -> Owne | DataType::Struct(_) | DataType::List(_) | DataType::Int256 - | DataType::Serial => { + | DataType::Serial + | DataType::Map(_) => { // Interval, Struct, List, Int256 are not supported + // XXX: is this branch reachable? if let Ok(suppressed_count) = LOG_SUPPERSSER.check() { tracing::warn!(column = rw_field.name, ?rw_field.data_type, suppressed_count, "unsupported data type, set to null"); } diff --git a/src/connector/src/parser/postgres.rs b/src/connector/src/parser/postgres.rs index da17ea256ba3c..f55fe28f878f9 100644 --- a/src/connector/src/parser/postgres.rs +++ b/src/connector/src/parser/postgres.rs @@ -116,7 +116,8 @@ fn postgres_cell_to_scalar_impl( } } }, - DataType::Struct(_) | DataType::Serial => { + DataType::Struct(_) | DataType::Serial | DataType::Map(_) => { + // Is this branch reachable? // Struct and Serial are not supported tracing::warn!(name, ?data_type, "unsupported data type, set to null"); None diff --git a/src/connector/src/sink/big_query.rs b/src/connector/src/sink/big_query.rs index 04f3360b1a02a..c175ad77f8c4a 100644 --- a/src/connector/src/sink/big_query.rs +++ b/src/connector/src/sink/big_query.rs @@ -261,6 +261,7 @@ impl BigQuerySink { DataType::Int256 => Err(SinkError::BigQuery(anyhow::anyhow!( "Bigquery cannot support Int256" ))), + DataType::Map(_) => todo!(), } } @@ -310,6 +311,7 @@ impl BigQuerySink { "Bigquery cannot support Int256" ))) } + DataType::Map(_) => todo!(), }; Ok(tfs) } @@ -815,6 +817,7 @@ fn build_protobuf_field( "Don't support Float32 and Int256" ))) } + DataType::Map(_) => todo!(), } Ok((field, None)) } diff --git a/src/connector/src/sink/clickhouse.rs b/src/connector/src/sink/clickhouse.rs index ac4930460eced..3175014333569 100644 --- a/src/connector/src/sink/clickhouse.rs +++ b/src/connector/src/sink/clickhouse.rs @@ -399,6 +399,9 @@ impl ClickHouseSink { risingwave_common::types::DataType::Int256 => Err(SinkError::ClickHouse( "clickhouse can not support Int256".to_string(), )), + risingwave_common::types::DataType::Map(_) => Err(SinkError::ClickHouse( + "clickhouse can not support Map".to_string(), + )), }; if !is_match? 
{ return Err(SinkError::ClickHouse(format!( @@ -941,6 +944,11 @@ impl ClickHouseFieldWithNull { "clickhouse can not support Bytea".to_string(), )) } + ScalarRefImpl::Map(_) => { + return Err(SinkError::ClickHouse( + "clickhouse can not support Map".to_string(), + )) + } }; let data = if clickhouse_schema_feature.can_null { vec![ClickHouseFieldWithNull::WithSome(data)] diff --git a/src/connector/src/sink/doris.rs b/src/connector/src/sink/doris.rs index 35c438a534992..f6b806df841f4 100644 --- a/src/connector/src/sink/doris.rs +++ b/src/connector/src/sink/doris.rs @@ -188,6 +188,9 @@ impl DorisSink { risingwave_common::types::DataType::Int256 => { Err(SinkError::Doris("doris can not support Int256".to_string())) } + risingwave_common::types::DataType::Map(_) => { + Err(SinkError::Doris("doris can not support Map".to_string())) + } } } } diff --git a/src/connector/src/sink/dynamodb.rs b/src/connector/src/sink/dynamodb.rs index 35b48c6e31faf..2df15f517ca0b 100644 --- a/src/connector/src/sink/dynamodb.rs +++ b/src/connector/src/sink/dynamodb.rs @@ -395,6 +395,7 @@ fn map_data_type( } AttributeValue::M(map) } + DataType::Map(_) => todo!(), }; Ok(attr) } diff --git a/src/connector/src/sink/encoder/avro.rs b/src/connector/src/sink/encoder/avro.rs index 8122126727298..4a2060f0a8c6c 100644 --- a/src/connector/src/sink/encoder/avro.rs +++ b/src/connector/src/sink/encoder/avro.rs @@ -454,6 +454,10 @@ fn encode_field( DataType::Int256 => { return no_match_err(); } + DataType::Map(_) => { + // TODO: + return no_match_err(); + } }; D::handle_union(value, opt_idx) diff --git a/src/connector/src/sink/encoder/json.rs b/src/connector/src/sink/encoder/json.rs index 3652f38bacbb2..6dc8809f42933 100644 --- a/src/connector/src/sink/encoder/json.rs +++ b/src/connector/src/sink/encoder/json.rs @@ -401,6 +401,7 @@ pub(crate) fn schema_type_mapping(rw_type: &DataType) -> &'static str { DataType::Jsonb => "string", DataType::Serial => "string", DataType::Int256 => "string", + DataType::Map(_) => "map", } } diff --git a/src/connector/src/sink/encoder/proto.rs b/src/connector/src/sink/encoder/proto.rs index a0e4d41dc58de..8046606b5690c 100644 --- a/src/connector/src/sink/encoder/proto.rs +++ b/src/connector/src/sink/encoder/proto.rs @@ -420,6 +420,10 @@ fn encode_field( DataType::Int256 => { return no_match_err(); } + DataType::Map(_) => { + // TODO: + return no_match_err(); + } }; Ok(value) diff --git a/src/connector/src/sink/formatter/debezium_json.rs b/src/connector/src/sink/formatter/debezium_json.rs index fd4813e78541a..9419b6aa5ecd9 100644 --- a/src/connector/src/sink/formatter/debezium_json.rs +++ b/src/connector/src/sink/formatter/debezium_json.rs @@ -311,6 +311,7 @@ pub(crate) fn field_to_json(field: &Field) -> Value { // we do the same here risingwave_common::types::DataType::Struct(_) => ("string", ""), risingwave_common::types::DataType::List { .. 
} => ("string", ""), + risingwave_common::types::DataType::Map(_) => todo!("map"), }; if name.is_empty() { diff --git a/src/connector/src/sink/remote.rs b/src/connector/src/sink/remote.rs index d0edded78d6db..c27e2b45e5285 100644 --- a/src/connector/src/sink/remote.rs +++ b/src/connector/src/sink/remote.rs @@ -211,7 +211,7 @@ async fn validate_remote_sink(param: &SinkParam, sink_name: &str) -> ConnectorRe ))) } }, - DataType::Serial | DataType::Int256 => Err(SinkError::Remote(anyhow!( + DataType::Serial | DataType::Int256 | DataType::Map(_) => Err(SinkError::Remote(anyhow!( "remote sink supports Int16, Int32, Int64, Float32, Float64, Boolean, Decimal, Time, Date, Interval, Jsonb, Timestamp, Timestamptz, Bytea, List and Varchar, (Es sink support Struct) got {:?}: {:?}", col.name, col.data_type, diff --git a/src/connector/src/sink/sqlserver.rs b/src/connector/src/sink/sqlserver.rs index 959513e38b349..701eb3eed51bd 100644 --- a/src/connector/src/sink/sqlserver.rs +++ b/src/connector/src/sink/sqlserver.rs @@ -550,6 +550,7 @@ fn bind_params( ScalarRefImpl::List(_) => return Err(data_type_not_supported("List")), ScalarRefImpl::Int256(_) => return Err(data_type_not_supported("Int256")), ScalarRefImpl::Serial(_) => return Err(data_type_not_supported("Serial")), + ScalarRefImpl::Map(_) => return Err(data_type_not_supported("Map")), }, None => match schema[col_idx].data_type { DataType::Boolean => { @@ -597,6 +598,7 @@ fn bind_params( DataType::Jsonb => return Err(data_type_not_supported("Jsonb")), DataType::Serial => return Err(data_type_not_supported("Serial")), DataType::Int256 => return Err(data_type_not_supported("Int256")), + DataType::Map(_) => return Err(data_type_not_supported("Map")), }, }; } @@ -630,6 +632,7 @@ fn check_data_type_compatibility(data_type: &DataType) -> Result<()> { DataType::Jsonb => Err(data_type_not_supported("Jsonb")), DataType::Serial => Err(data_type_not_supported("Serial")), DataType::Int256 => Err(data_type_not_supported("Int256")), + DataType::Map(_) => Err(data_type_not_supported("Map")), } } diff --git a/src/connector/src/sink/starrocks.rs b/src/connector/src/sink/starrocks.rs index 56c352bfb4a9e..26cfc32fbb894 100644 --- a/src/connector/src/sink/starrocks.rs +++ b/src/connector/src/sink/starrocks.rs @@ -246,6 +246,9 @@ impl StarrocksSink { risingwave_common::types::DataType::Int256 => Err(SinkError::Starrocks( "INT256 is not supported for Starrocks sink.".to_string(), )), + risingwave_common::types::DataType::Map(_) => Err(SinkError::Starrocks( + "MAP is not supported for Starrocks sink.".to_string(), + )), } } } diff --git a/src/expr/core/src/error.rs b/src/expr/core/src/error.rs index e02c5f4521cf5..4bceb284fbfd9 100644 --- a/src/expr/core/src/error.rs +++ b/src/expr/core/src/error.rs @@ -88,6 +88,7 @@ pub enum ExprError { #[error("More than one row returned by {0} used as an expression")] MaxOneRow(&'static str), + /// TODO: deprecate in favor of `Function` #[error(transparent)] Internal( #[from] @@ -111,6 +112,7 @@ pub enum ExprError { InvalidState(String), /// Function error message returned by UDF. + /// TODO: replace with `Function` #[error("{0}")] Custom(String), diff --git a/src/expr/core/src/sig/mod.rs b/src/expr/core/src/sig/mod.rs index 124a002f6519e..7be6c9df9936a 100644 --- a/src/expr/core/src/sig/mod.rs +++ b/src/expr/core/src/sig/mod.rs @@ -394,6 +394,8 @@ pub enum SigDataType { AnyArray, /// Accepts any struct data type AnyStruct, + /// TODO: not all type can be used as a map key. 
+ AnyMap, } impl From for SigDataType { @@ -409,6 +411,7 @@ impl std::fmt::Display for SigDataType { Self::Any => write!(f, "any"), Self::AnyArray => write!(f, "anyarray"), Self::AnyStruct => write!(f, "anystruct"), + Self::AnyMap => write!(f, "anymap"), } } } @@ -421,6 +424,7 @@ impl SigDataType { Self::Any => true, Self::AnyArray => dt.is_array(), Self::AnyStruct => dt.is_struct(), + Self::AnyMap => dt.is_map(), } } diff --git a/src/expr/impl/src/scalar/array.rs b/src/expr/impl/src/scalar/array.rs index aaefd17bba07d..b1c00382041b9 100644 --- a/src/expr/impl/src/scalar/array.rs +++ b/src/expr/impl/src/scalar/array.rs @@ -12,11 +12,12 @@ // See the License for the specific language governing permissions and // limitations under the License. +use itertools::Itertools; use risingwave_common::array::{ListValue, StructValue}; use risingwave_common::row::Row; -use risingwave_common::types::ToOwnedDatum; +use risingwave_common::types::{ListRef, MapType, MapValue, ToOwnedDatum}; use risingwave_expr::expr::Context; -use risingwave_expr::function; +use risingwave_expr::{function, ExprError}; #[function("array(...) -> anyarray", type_infer = "panic")] fn array(row: impl Row, ctx: &Context) -> ListValue { @@ -28,6 +29,36 @@ fn row_(row: impl Row) -> StructValue { StructValue::new(row.iter().map(|d| d.to_owned_datum()).collect()) } +/// # Example +/// +/// ```slt +/// query T +/// select map_from_entries(null::int[], [1,2,3]); +/// ---- +/// null +/// ``` +#[function( + "map_from_entries(anyarray, anyarray) -> anymap", + type_infer = "|args| Ok(MapType::from_kv(args[0].as_list().clone(), args[1].as_list().clone()).into())" +)] +fn map(key: ListRef<'_>, value: ListRef<'_>) -> Result { + // TODO: restrict key's type (where? in the macro?) + if key.len() != value.len() { + return Err(ExprError::InvalidParam { + name: "key", + reason: "Map keys and values have different length".into(), + }); + } + if key.iter().duplicates().next().is_some() { + return Err(ExprError::InvalidParam { + name: "key", + reason: "Map keys must be unique".into(), + }); + } + + Ok(MapValue::from_kv(key.to_owned(), value.to_owned())) +} + #[cfg(test)] mod tests { use risingwave_common::array::DataChunk; diff --git a/src/expr/impl/src/scalar/to_jsonb.rs b/src/expr/impl/src/scalar/to_jsonb.rs index bb381954cc76b..c11d4474dc43b 100644 --- a/src/expr/impl/src/scalar/to_jsonb.rs +++ b/src/expr/impl/src/scalar/to_jsonb.rs @@ -16,8 +16,8 @@ use std::fmt::Debug; use jsonbb::Builder; use risingwave_common::types::{ - DataType, Date, Decimal, Int256Ref, Interval, JsonbRef, JsonbVal, ListRef, ScalarRefImpl, - Serial, StructRef, Time, Timestamp, Timestamptz, ToText, F32, F64, + DataType, Date, Decimal, Int256Ref, Interval, JsonbRef, JsonbVal, ListRef, MapRef, + ScalarRefImpl, Serial, StructRef, Time, Timestamp, Timestamptz, ToText, F32, F64, }; use risingwave_common::util::iter_util::ZipEqDebug; use risingwave_expr::expr::Context; @@ -72,6 +72,7 @@ impl ToJsonb for ScalarRefImpl<'_> { Timestamptz(v) => v.add_to(ty, builder), Struct(v) => v.add_to(ty, builder), List(v) => v.add_to(ty, builder), + Map(v) => v.add_to(ty, builder), } } } @@ -227,6 +228,20 @@ impl ToJsonb for ListRef<'_> { } } +impl ToJsonb for MapRef<'_> { + fn add_to(self, data_type: &DataType, builder: &mut Builder) -> Result<()> { + let value_type = data_type.as_map().value(); + builder.begin_object(); + for (k, v) in self.iter() { + // XXX: is to_text here reasonable? 
+ builder.add_string(&k.to_text()); + v.add_to(value_type, builder)?; + } + builder.end_object(); + Ok(()) + } +} + impl ToJsonb for StructRef<'_> { fn add_to(self, data_type: &DataType, builder: &mut Builder) -> Result<()> { builder.begin_object(); diff --git a/src/expr/macro/src/gen.rs b/src/expr/macro/src/gen.rs index 3494f52406193..3a82d5dd73025 100644 --- a/src/expr/macro/src/gen.rs +++ b/src/expr/macro/src/gen.rs @@ -83,9 +83,10 @@ impl FunctionAttr { attrs } - /// Generate the type infer function. + /// Generate the type infer function: `fn(&[DataType]) -> Result` fn generate_type_infer_fn(&self) -> Result { if let Some(func) = &self.type_infer { + // XXX: should this be called "placeholder" or "unreachable"? if func == "panic" { return Ok(quote! { |_| panic!("type inference function is not implemented") }); } @@ -115,6 +116,11 @@ impl FunctionAttr { // infer as the type of "struct" argument return Ok(quote! { |args| Ok(args[#i].clone()) }); } + } else if self.ret == "anymap" { + if let Some(i) = self.args.iter().position(|t| t == "anymap") { + // infer as the type of "anymap" argument + return Ok(quote! { |args| Ok(args[#i].clone()) }); + } } else { // the return type is fixed let ty = data_type(&self.ret); @@ -122,13 +128,17 @@ impl FunctionAttr { } Err(Error::new( Span::call_site(), - "type inference function is required", + "type inference function cannot be automatically derived. You should provide: `type_infer = \"|args| Ok(...)\"`", )) } - /// Generate a descriptor of the scalar or table function. + /// Generate a descriptor (`FuncSign`) of the scalar or table function. /// /// The types of arguments and return value should not contain wildcard. + /// + /// # Arguments + /// `build_fn`: whether the user provided a function is a build function. + /// (from the `#[build_function]` macro) pub fn generate_function_descriptor( &self, user_fn: &UserFunctionAttr, @@ -156,6 +166,7 @@ impl FunctionAttr { } else if self.rewritten { quote! { |_, _| Err(ExprError::UnsupportedFunction(#name.into())) } } else { + // This is the core logic for `#[function]` self.generate_build_scalar_function(user_fn, true)? }; let type_infer_fn = self.generate_type_infer_fn()?; @@ -1302,6 +1313,7 @@ fn sig_data_type(ty: &str) -> TokenStream2 { match ty { "any" => quote! { SigDataType::Any }, "anyarray" => quote! { SigDataType::AnyArray }, + "anymap" => quote! { SigDataType::AnyMap }, "struct" => quote! { SigDataType::AnyStruct }, _ if ty.starts_with("struct") && ty.contains("any") => quote! { SigDataType::AnyStruct }, _ => { @@ -1320,6 +1332,12 @@ fn data_type(ty: &str) -> TokenStream2 { return quote! { DataType::Struct(#ty.parse().expect("invalid struct type")) }; } let variant = format_ident!("{}", types::data_type(ty)); + // TODO: enable the check + // assert!( + // !matches!(ty, "any" | "anyarray" | "anymap" | "struct"), + // "{ty}, {variant}" + // ); + quote! { DataType::#variant } } diff --git a/src/expr/macro/src/lib.rs b/src/expr/macro/src/lib.rs index 3a905165c2ee2..8fd03e344db89 100644 --- a/src/expr/macro/src/lib.rs +++ b/src/expr/macro/src/lib.rs @@ -30,7 +30,7 @@ mod utils; /// Defining the RisingWave SQL function from a Rust function. /// -/// [Online version of this doc.](https://risingwavelabs.github.io/risingwave/risingwave_expr_macro/attr.function.html) +/// [Online version of this doc.](https://risingwavelabs.github.io/risingwave/rustdoc/risingwave_expr_macro/attr.function.html) /// /// # Table of Contents /// @@ -70,8 +70,8 @@ mod utils; /// name ( [arg_types],* [...] 
) [ -> [setof] return_type ] /// ``` /// -/// Where `name` is the function name in `snake_case`, which must match the function name defined -/// in `prost`. +/// Where `name` is the function name in `snake_case`, which must match the function name (in `UPPER_CASE`) defined +/// in `proto/expr.proto`. /// /// `arg_types` is a comma-separated list of argument types. The allowed data types are listed in /// in the `name` column of the appendix's [type matrix]. Wildcards or `auto` can also be used, as @@ -98,7 +98,7 @@ mod utils; /// } /// ``` /// -/// ## Type Expansion +/// ## Type Expansion with `*` /// /// Types can be automatically expanded to multiple types using wildcards. Here are some examples: /// @@ -115,13 +115,17 @@ mod utils; /// #[function("cast(varchar) -> int64")] /// ``` /// -/// Please note the difference between `*` and `any`. `*` will generate a function for each type, +/// Please note the difference between `*` and `any`: `*` will generate a function for each type, /// whereas `any` will only generate one function with a dynamic data type `Scalar`. +/// This is similar to `impl T` and `dyn T` in Rust. The performance of using `*` would be much better than `any`. +/// But we do not always prefer `*` due to better performance. In some cases, using `any` is more convenient. +/// For example, in array functions, the element type of `ListValue` is `Scalar(Ref)Impl`. +/// It is unnecessary to convert it from/into various `T`. /// -/// ## Automatic Type Inference +/// ## Automatic Type Inference with `auto` /// /// Correspondingly, the return type can be denoted as `auto` to be automatically inferred based on -/// the input types. It will be inferred as the smallest type that can accommodate all input types. +/// the input types. It will be inferred as the _smallest type_ that can accommodate all input types. /// /// For example, `#[function("add(*int, *int) -> auto")]` will be expanded to: /// @@ -142,10 +146,10 @@ mod utils; /// #[function("neg(int64) -> int64")] /// ``` /// -/// ## Custom Type Inference Function +/// ## Custom Type Inference Function with `type_infer` /// /// A few functions might have a return type that dynamically changes based on the input argument -/// types, such as `unnest`. +/// types, such as `unnest`. This is mainly for composite types like `anyarray`, `struct`, and `anymap`. /// /// In such cases, the `type_infer` option can be used to specify a function to infer the return /// type based on the input argument types. Its function signature is @@ -163,7 +167,7 @@ mod utils; /// )] /// ``` /// -/// This type inference function will be invoked at the frontend. +/// This type inference function will be invoked at the frontend (`infer_type_with_sigmap`). /// /// # Rust Function Signature /// @@ -182,8 +186,9 @@ mod utils; /// /// ## Nullable Arguments /// -/// The functions above will only be called when all arguments are not null. If null arguments need -/// to be considered, the `Option` type can be used: +/// The functions above will only be called when all arguments are not null. +/// It will return null if any argument is null. +/// If null arguments need to be considered, the `Option` type can be used: /// /// ```ignore /// #[function("trim_array(anyarray, int32) -> anyarray")] @@ -192,11 +197,11 @@ mod utils; /// /// This function will be called when `n` is null, but not when `array` is null. 
/// -/// ## Return Value +/// ## Return `NULL`s and Errors /// /// Similarly, the return value type can be one of the following: /// -/// - `T`: Indicates that a non-null value is always returned, and errors will not occur. +/// - `T`: Indicates that a non-null value is always returned (for non-null inputs), and errors will not occur. /// - `Option`: Indicates that a null value may be returned, but errors will not occur. /// - `Result`: Indicates that an error may occur, but a null value will not be returned. /// - `Result>`: Indicates that a null value may be returned, and an error may also occur. @@ -419,6 +424,16 @@ pub fn function(attr: TokenStream, item: TokenStream) -> TokenStream { } } +/// Different from `#[function]`, which implements the `Expression` trait for a rust scalar function, +/// `#[build_function]` is used when you already implemented `Expression` manually. +/// +/// The expected input is a "build" function: +/// ```ignore +/// fn(data_type: DataType, children: Vec) -> Result +/// ``` +/// +/// It generates the function descriptor using the "build" function and +/// registers the description to the `FUNC_SIG_MAP`. #[proc_macro_attribute] pub fn build_function(attr: TokenStream, item: TokenStream) -> TokenStream { fn inner(attr: TokenStream, item: TokenStream) -> Result { diff --git a/src/expr/macro/src/types.rs b/src/expr/macro/src/types.rs index f2219a1c34bd6..4f07162d038a0 100644 --- a/src/expr/macro/src/types.rs +++ b/src/expr/macro/src/types.rs @@ -35,6 +35,7 @@ const TYPE_MATRIX: &str = " jsonb Jsonb JsonbArray JsonbVal JsonbRef<'_> _ anyarray List ListArray ListValue ListRef<'_> _ struct Struct StructArray StructValue StructRef<'_> _ + anymap Map MapArray MapValue MapRef<'_> _ any ??? ArrayImpl ScalarImpl ScalarRefImpl<'_> _ "; @@ -81,7 +82,7 @@ fn lookup_matrix(mut ty: &str, idx: usize) -> &str { None } }); - s.unwrap_or_else(|| panic!("unknown type: {}", ty)) + s.unwrap_or_else(|| panic!("failed to lookup type matrix: unknown type: {}", ty)) } /// Expands a type wildcard string into a list of concrete types. diff --git a/src/frontend/src/binder/expr/function.rs b/src/frontend/src/binder/expr/function.rs index 897b43a2f3669..ee1858f22c71a 100644 --- a/src/frontend/src/binder/expr/function.rs +++ b/src/frontend/src/binder/expr/function.rs @@ -923,6 +923,9 @@ impl Binder { ) } + // XXX: can we unify this with FUNC_SIG_MAP? + // For raw_call here, it seems unnecessary to declare it again here. + // For some functions, we have validation logic here. Is it still useful now? static HANDLES: LazyLock> = LazyLock::new(|| { [ ( @@ -1180,6 +1183,8 @@ impl Binder { ("jsonb_path_query_array", raw_call(ExprType::JsonbPathQueryArray)), ("jsonb_path_query_first", raw_call(ExprType::JsonbPathQueryFirst)), ("jsonb_set", raw_call(ExprType::JsonbSet)), + // map + ("map_from_entries", raw_call(ExprType::MapFromEntries)), // Functions that return a constant value ("pi", pi()), // greatest and least @@ -1485,6 +1490,7 @@ impl Binder { return Ok(FunctionCall::new(func, inputs)?.into()); } + // Note: for raw_call, we only check name here. The type check is done later. 
match HANDLES.get(function_name) { Some(handle) => handle(self, inputs), None => { diff --git a/src/frontend/src/binder/expr/value.rs b/src/frontend/src/binder/expr/value.rs index e1fc78e884e02..5b69610f13bfe 100644 --- a/src/frontend/src/binder/expr/value.rs +++ b/src/frontend/src/binder/expr/value.rs @@ -212,7 +212,7 @@ impl Binder { .map(|e| self.bind_expr_inner(e)) .collect::>>()?; let data_type = - DataType::new_struct(exprs.iter().map(|e| e.return_type()).collect_vec(), vec![]); + DataType::new_unnamed_struct(exprs.iter().map(|e| e.return_type()).collect_vec()); let expr: ExprImpl = FunctionCall::new_unchecked(ExprType::Row, exprs, data_type).into(); Ok(expr) } diff --git a/src/frontend/src/expr/literal.rs b/src/frontend/src/expr/literal.rs index d44a1b859d289..8a1c9fe73d754 100644 --- a/src/frontend/src/expr/literal.rs +++ b/src/frontend/src/expr/literal.rs @@ -60,6 +60,7 @@ impl std::fmt::Debug for Literal { v.as_scalar_ref_impl().to_text_with_type(&data_type) ), DataType::List { .. } => write!(f, "{}", v.as_list().display_for_explain()), + DataType::Map(_) => todo!(), }, }?; write!(f, ":{:?}", data_type) diff --git a/src/frontend/src/expr/pure.rs b/src/frontend/src/expr/pure.rs index 1882e70c12b93..6af8ab6cdd4d0 100644 --- a/src/frontend/src/expr/pure.rs +++ b/src/frontend/src/expr/pure.rs @@ -248,7 +248,8 @@ impl ExprVisitor for ImpureAnalyzer { | Type::InetNtoa | Type::InetAton | Type::QuoteLiteral - | Type::QuoteNullable => + | Type::QuoteNullable + | Type::MapFromEntries => // expression output is deterministic(same result for the same input) { func_call diff --git a/src/frontend/src/expr/type_inference/func.rs b/src/frontend/src/expr/type_inference/func.rs index 0fecee8ab45c0..9cb22d87264a1 100644 --- a/src/frontend/src/expr/type_inference/func.rs +++ b/src/frontend/src/expr/type_inference/func.rs @@ -730,6 +730,8 @@ pub fn infer_type_name<'a>( }; if candidates.is_empty() { + // TODO: when type mismatches, show what are supported signatures for the + // function with the given name. bail_no_function!("{}", sig()); } diff --git a/src/frontend/src/optimizer/plan_expr_visitor/strong.rs b/src/frontend/src/optimizer/plan_expr_visitor/strong.rs index ed24ba75b524a..b9395e1b8f32e 100644 --- a/src/frontend/src/optimizer/plan_expr_visitor/strong.rs +++ b/src/frontend/src/optimizer/plan_expr_visitor/strong.rs @@ -290,6 +290,7 @@ impl Strong { | ExprType::JsonbPopulateRecord | ExprType::JsonbToRecord | ExprType::JsonbSet + | ExprType::MapFromEntries | ExprType::Vnode | ExprType::TestPaidTier | ExprType::Proctime diff --git a/src/frontend/src/optimizer/rule/index_selection_rule.rs b/src/frontend/src/optimizer/rule/index_selection_rule.rs index e65b249379750..548fda7b92af4 100644 --- a/src/frontend/src/optimizer/rule/index_selection_rule.rs +++ b/src/frontend/src/optimizer/rule/index_selection_rule.rs @@ -746,7 +746,7 @@ impl<'a> TableScanIoEstimator<'a> { .sum::() } - pub fn estimate_data_type_size(data_type: &DataType) -> usize { + fn estimate_data_type_size(data_type: &DataType) -> usize { use std::mem::size_of; match data_type { @@ -769,6 +769,7 @@ impl<'a> TableScanIoEstimator<'a> { DataType::Jsonb => 20, DataType::Struct { .. } => 20, DataType::List { .. 
} => 20, + DataType::Map(_) => 20, } } diff --git a/src/tests/sqlsmith/src/sql_gen/types.rs b/src/tests/sqlsmith/src/sql_gen/types.rs index 141accc71abc2..62e5d6f63f642 100644 --- a/src/tests/sqlsmith/src/sql_gen/types.rs +++ b/src/tests/sqlsmith/src/sql_gen/types.rs @@ -53,6 +53,7 @@ pub(super) fn data_type_to_ast_data_type(data_type: &DataType) -> AstDataType { .collect(), ), DataType::List(ref typ) => AstDataType::Array(Box::new(data_type_to_ast_data_type(typ))), + DataType::Map(_) => todo!(), } }
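To illustrate the value-level API introduced above, here is a minimal sketch (assuming the `risingwave_common::array` re-exports added in this patch; `build_example_map` is a hypothetical helper) that constructs the same map as `map_from_entries(array[1, 2], array[10, 20])` directly through `MapValue::from_kv`:

use risingwave_common::array::{ListValue, MapValue};

// Hypothetical helper mirroring `map_from_entries(array[1, 2], array[10, 20])`.
fn build_example_map() -> MapValue {
    let keys = ListValue::from_iter([1i64, 2]);
    let values = ListValue::from_iter([10i64, 20]);
    // In debug builds `from_kv` asserts equal lengths and unique, non-null keys;
    // `map_from_entries` reports length mismatch and duplicate keys as
    // `ExprError::InvalidParam`.
    MapValue::from_kv(keys, values)
}

Per the `ToText` impl for `MapRef`, the text form of this value would be {"1":10,"2":20}.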