Skip to content

Commit

Permalink
add map_length/contains/cat/insert/delete
Browse files Browse the repository at this point in the history
  • Loading branch information
xxchan committed Aug 18, 2024
1 parent b92834f commit 0e6017f
Show file tree
Hide file tree
Showing 10 changed files with 240 additions and 43 deletions.
5 changes: 5 additions & 0 deletions proto/expr.proto
Original file line number Diff line number Diff line change
Expand Up @@ -290,6 +290,11 @@ message ExprNode {
MAP_VALUES = 703;
MAP_ENTRIES = 704;
MAP_FROM_KEY_VALUES = 705;
MAP_LENGTH = 706;
MAP_CONTAINS = 707;
MAP_CAT = 708;
MAP_INSERT = 709;
MAP_DELETE = 710;

// Non-pure functions below (> 1000)
// ------------------------
Expand Down
4 changes: 2 additions & 2 deletions src/common/src/array/list_array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -372,11 +372,11 @@ impl ListValue {

/// Creates a new `ListValue` from an iterator of `Datum`.
pub fn from_datum_iter<T: ToDatumRef>(
datatype: &DataType,
elem_datatype: &DataType,
iter: impl IntoIterator<Item = T>,
) -> Self {
let iter = iter.into_iter();
let mut builder = datatype.create_array_builder(iter.size_hint().0);
let mut builder = elem_datatype.create_array_builder(iter.size_hint().0);
for datum in iter {
builder.append(datum);
}
Expand Down
59 changes: 59 additions & 0 deletions src/common/src/array/map_array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ use serde::Serializer;
use super::{
Array, ArrayBuilder, ArrayImpl, ArrayResult, DatumRef, DefaultOrdered, ListArray,
ListArrayBuilder, ListRef, ListValue, MapType, ScalarRef, ScalarRefImpl, StructArray,
StructRef,
};
use crate::bitmap::Bitmap;
use crate::types::{DataType, Scalar, ToText};
Expand Down Expand Up @@ -196,6 +197,7 @@ mod scalar {
use std::collections::HashSet;

use super::*;
use crate::array::{Datum, ScalarImpl, StructValue};

/// Refer to [`MapArray`] for the invariants of a map value.
#[derive(Clone, Eq, EstimateSize)]
Expand Down Expand Up @@ -267,6 +269,46 @@ mod scalar {
);
Ok(MapValue(ListValue::new(struct_array.into())))
}

/// # Panics
/// Panics if `m1` and `m2` have different types.
pub fn concat(m1: MapRef<'_>, m2: MapRef<'_>) -> Self {
debug_assert_eq!(m1.inner().data_type(), m2.inner().data_type());
let m2_keys = m2.keys();
let l = ListValue::from_datum_iter(
&m1.inner().data_type(),
m1.iter_struct()
.filter(|s| !m2_keys.contains(&s.field_at(0).expect("map key is not null")))
.chain(m2.iter_struct())
.map(|s| Some(ScalarRefImpl::Struct(s))),
);
Self::from_entries(l)
}

pub fn insert(m: MapRef<'_>, key: ScalarImpl, value: Datum) -> Self {
let l = ListValue::from_datum_iter(
&m.inner().data_type(),
m.iter_struct()
.filter(|s| {
key.as_scalar_ref_impl() != s.field_at(0).expect("map key is not null")
})
.chain(std::iter::once(
StructValue::new(vec![Some(key.clone()), value]).as_scalar_ref(),
))
.map(|s| Some(ScalarRefImpl::Struct(s))),
);
Self::from_entries(l)
}

pub fn delete(m: MapRef<'_>, key: ScalarRefImpl<'_>) -> Self {
let l = ListValue::from_datum_iter(
&m.inner().data_type(),
m.iter_struct()
.filter(|s| key != s.field_at(0).expect("map key is not null"))
.map(|s| Some(ScalarRefImpl::Struct(s))),
);
Self::from_entries(l)
}
}

impl<'a> MapRef<'a> {
Expand All @@ -287,6 +329,14 @@ mod scalar {
pub fn into_kv(self) -> (ListRef<'a>, ListRef<'a>) {
self.0.as_map_kv()
}

pub fn keys(&self) -> HashSet<ScalarRefImpl<'_>> {
self.iter().map(|(k, _v)| k).collect()
}

pub fn to_owned(self) -> MapValue {
MapValue(self.0.to_owned())
}
}

impl Scalar for MapValue {
Expand Down Expand Up @@ -394,6 +444,15 @@ impl<'a> MapRef<'a> {
})
}

pub fn iter_struct(
self,
) -> impl DoubleEndedIterator + ExactSizeIterator<Item = StructRef<'a>> + 'a {
self.inner().iter().map(|list_elem| {
let list_elem = list_elem.expect("the list element in map should not be null");
list_elem.into_struct()
})
}

pub fn iter_sorted(
self,
) -> impl DoubleEndedIterator + ExactSizeIterator<Item = (ScalarRefImpl<'a>, DatumRef<'a>)> + 'a
Expand Down
9 changes: 9 additions & 0 deletions src/common/src/array/struct_array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -393,6 +393,15 @@ impl<'a> StructRef<'a> {
iter_fields_ref!(self, it, { Either::Left(it) }, { Either::Right(it) })
}

/// # Panics
/// Panics if the index is out of bounds.
pub fn field_at(&self, i: usize) -> DatumRef<'a> {
match self {
StructRef::Indexed { arr, idx } => arr.field_at(i).value_at(*idx),
StructRef::ValueRef { val } => val.fields[i].to_datum_ref(),
}
}

pub fn memcmp_serialize(
self,
serializer: &mut memcomparable::Serializer<impl BufMut>,
Expand Down
138 changes: 100 additions & 38 deletions src/expr/impl/src/scalar/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -76,37 +76,37 @@ fn map_from_entries(entries: ListRef<'_>) -> Result<MapValue, ExprError> {
///
/// ```slt
/// query T
/// select map_access(map(array[1,2,3], array[100,200,300]), 3);
/// select map_access(map_from_key_values(array[1,2,3], array[100,200,300]), 3);
/// ----
/// 300
///
/// query T
/// select map_access(map(array[1,2,3], array[100,200,300]), '3');
/// select map_access(map_from_key_values(array[1,2,3], array[100,200,300]), '3');
/// ----
/// 300
///
/// query error
/// select map_access(map(array[1,2,3], array[100,200,300]), 1.0);
/// select map_access(map_from_key_values(array[1,2,3], array[100,200,300]), 1.0);
/// ----
/// db error: ERROR: Failed to run the query
///
/// Caused by these errors (recent errors listed first):
/// 1: Failed to bind expression: map_access(map(ARRAY[1, 2, 3], ARRAY[100, 200, 300]), 1.0)
/// 2: Bind error: Cannot access numeric in map_from_key_values(integer,integer)
/// 1: Failed to bind expression: map_access(map_from_key_values(ARRAY[1, 2, 3], ARRAY[100, 200, 300]), 1.0)
/// 2: Bind error: Cannot access numeric in map(integer,integer)
///
///
/// query T
/// select map_access(map(array['a','b','c'], array[1,2,3]), 'a');
/// select map_access(map_from_key_values(array['a','b','c'], array[1,2,3]), 'a');
/// ----
/// 1
///
/// query T
/// select map_access(map(array['a','b','c'], array[1,2,3]), 'd');
/// select map_access(map_from_key_values(array['a','b','c'], array[1,2,3]), 'd');
/// ----
/// NULL
///
/// query T
/// select map_access(map(array['a','b','c'], array[1,2,3]), null);
/// select map_access(map_from_key_values(array['a','b','c'], array[1,2,3]), null);
/// ----
/// NULL
/// ```
Expand All @@ -125,58 +125,120 @@ fn map_access<'a>(
}
}

fn map_keys_type_infer(args: &[DataType]) -> Result<DataType, ExprError> {
Ok(DataType::List(args[0].as_map().key().clone().into()))
}

fn map_values_type_infer(args: &[DataType]) -> Result<DataType, ExprError> {
Ok(DataType::List(args[0].as_map().value().clone().into()))
/// ```slt
/// query T
/// select
/// map_contains(MAP{1:1}, 1),
/// map_contains(MAP{1:1}, 2),
/// map_contains(MAP{1:1}, NULL::varchar),
/// map_contains(MAP{1:1}, 1.0)
/// ----
/// t f NULL f
/// ```
#[function("map_contains(anymap, any) -> boolean")]
fn map_contains(map: MapRef<'_>, key: ScalarRefImpl<'_>) -> Result<bool, ExprError> {
let (keys, _values) = map.into_kv();
let idx = array_position(keys, Some(key))?;
Ok(idx.is_some())
}

fn map_entries_type_infer(args: &[DataType]) -> Result<DataType, ExprError> {
Ok(args[0].as_map().clone().into_list())
/// ```slt
/// query I
/// select
/// map_length(NULL::map(int,int)),
/// map_length(MAP {}::map(int,int)),
/// map_length(MAP {1:1,2:2}::map(int,int))
/// ----
/// NULL 0 2
/// ```
#[function("map_length(anymap) -> int4")]
fn map_length<T: TryFrom<usize>>(map: MapRef<'_>) -> Result<T, ExprError> {
map.inner()
.len()
.try_into()
.map_err(|_| ExprError::NumericOverflow)
}

/// # Example
/// If both `m1` and `m2` have a value with the same key, then the output map contains the value from `m2`.
///
/// ```slt
/// query I
/// select map_keys(map{'a':1, 'b':2})
/// query T
/// select map_cat(MAP{'a':1,'b':2},null::map(varchar,int));
/// ----
/// {a:1,b:2}
///
/// query T
/// select map_cat(MAP{'a':1,'b':2},MAP{'b':3,'c':4});
/// ----
/// {a:1,b:3,c:4}
///
/// # implicit type cast
/// query T
/// select map_cat(MAP{'a':1,'b':2},MAP{'b':3.0,'c':4.0});
/// ----
/// {a,b}
/// {a:1,b:3.0,c:4.0}
/// ```
#[function("map_keys(anymap) -> anyarray", type_infer = "map_keys_type_infer")]
fn map_keys(map: MapRef<'_>) -> ListRef<'_> {
map.into_kv().0
#[function("map_cat(anymap, anymap) -> anymap")]
fn map_cat(m1: Option<MapRef<'_>>, m2: Option<MapRef<'_>>) -> Result<Option<MapValue>, ExprError> {
match (m1, m2) {
(None, None) => Ok(None),
(Some(m), None) | (None, Some(m)) => Ok(Some(m.to_owned())),
(Some(m1), Some(m2)) => Ok(Some(MapValue::concat(m1, m2))),
}
}

/// Inserts a key-value pair into the map. If the key already exists, the value is updated.
///
/// # Example
///
/// ```slt
/// query I
/// select map_values(map{'a':1, 'b':2})
/// query T
/// select map_insert(map{'a':1, 'b':2}, 'c', 3);
/// ----
/// {a:1,b:2,c:3}
///
/// query T
/// select map_insert(map{'a':1, 'b':2}, 'b', 4);
/// ----
/// {1,2}
/// {a:1,b:4}
/// ```
#[function("map_values(anymap) -> anyarray", type_infer = "map_values_type_infer")]
fn map_values(map: MapRef<'_>) -> ListRef<'_> {
map.into_kv().1
///
/// TODO: support variadic arguments
#[function("map_insert(anymap, any, any) -> anymap")]
fn map_insert(
map: MapRef<'_>,
key: Option<ScalarRefImpl<'_>>,
value: Option<ScalarRefImpl<'_>>,
) -> MapValue {
let Some(key) = key else {
return map.to_owned();
};
MapValue::insert(map, key.into_scalar_impl(), value.to_owned_datum())
}

/// Deletes a key-value pair from the map.
///
/// # Example
///
/// ```slt
/// query I
/// select map_entries(map{'a':1, 'b':2})
/// query T
/// select map_delete(map{'a':1, 'b':2, 'c':3}, 'b');
/// ----
/// {a:1,c:3}
///
/// query T
/// select map_delete(map{'a':1, 'b':2, 'c':3}, 'd');
/// ----
/// {"(a,1)","(b,2)"}
/// {a:1,b:2,c:3}
/// ```
#[function(
"map_entries(anymap) -> anyarray",
type_infer = "map_entries_type_infer"
)]
fn map_entries(map: MapRef<'_>) -> ListRef<'_> {
map.into_inner()
///
/// TODO: support variadic arguments
#[function("map_delete(anymap, any) -> anymap")]
fn map_delete(map: MapRef<'_>, key: Option<ScalarRefImpl<'_>>) -> MapValue {
let Some(key) = key else {
return map.to_owned();
};
MapValue::delete(map, key)
}

#[cfg(test)]
Expand Down
5 changes: 5 additions & 0 deletions src/frontend/src/binder/expr/function/builtin_scalar.rs
Original file line number Diff line number Diff line change
Expand Up @@ -406,6 +406,11 @@ impl Binder {
("map_values", raw_call(ExprType::MapValues)),
("map_entries", raw_call(ExprType::MapEntries)),
("map_from_key_values", raw_call(ExprType::MapFromKeyValues)),
("map_cat", raw_call(ExprType::MapCat)),
("map_contains", raw_call(ExprType::MapContains)),
("map_delete", raw_call(ExprType::MapDelete)),
("map_insert", raw_call(ExprType::MapInsert)),
("map_length", raw_call(ExprType::MapLength)),
// Functions that return a constant value
("pi", pi()),
// greatest and least
Expand Down
15 changes: 14 additions & 1 deletion src/frontend/src/expr/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ use fixedbitset::FixedBitSet;
use futures::FutureExt;
use paste::paste;
use risingwave_common::array::ListValue;
use risingwave_common::types::{DataType, Datum, JsonbVal, Scalar, ScalarImpl};
use risingwave_common::types::{DataType, Datum, JsonbVal, MapType, Scalar, ScalarImpl};
use risingwave_expr::aggregate::PbAggKind;
use risingwave_expr::expr::build_from_prost;
use risingwave_pb::expr::expr_node::RexNode;
Expand Down Expand Up @@ -324,6 +324,19 @@ impl ExprImpl {
}
}

/// Ensure the return type of this expression is a map of some type.
pub fn try_into_map_type(&self) -> Result<MapType, ErrorCode> {
if self.is_untyped() {
return Err(ErrorCode::BindError(
"could not determine polymorphic type because input has type unknown".into(),
));
}
match self.return_type() {
DataType::Map(m) => Ok(m),
t => Err(ErrorCode::BindError(format!("expects map but got {t}"))),
}
}

/// Shorthand to enforce implicit cast to boolean
pub fn enforce_bool_clause(self, clause: &str) -> RwResult<ExprImpl> {
if self.is_untyped() {
Expand Down
7 changes: 6 additions & 1 deletion src/frontend/src/expr/pure.rs
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,12 @@ impl ExprVisitor for ImpureAnalyzer {
| Type::MapKeys
| Type::MapValues
| Type::MapEntries
| Type::MapFromKeyValues =>
| Type::MapFromKeyValues
| Type::MapCat
| Type::MapContains
| Type::MapDelete
| Type::MapInsert
| Type::MapLength =>
// expression output is deterministic(same result for the same input)
{
func_call
Expand Down
Loading

0 comments on commit 0e6017f

Please sign in to comment.