Skip to content

Commit

Permalink
feat: support map_access (#17986)
Browse files Browse the repository at this point in the history
Signed-off-by: xxchan <[email protected]>
  • Loading branch information
xxchan authored Aug 14, 2024
1 parent 8a1ffc6 commit e1de185
Show file tree
Hide file tree
Showing 22 changed files with 149 additions and 27 deletions.
6 changes: 6 additions & 0 deletions e2e_test/batch/types/map.slt.part
Original file line number Diff line number Diff line change
Expand Up @@ -116,5 +116,11 @@ select * from t;
{"a":1,"b":2,"c":3} NULL NULL NULL NULL
{"a":1,"b":2,"c":3} {"1":t,"2":f,"3":t} {"a":{"a1":a2},"b":{"b1":b2}} {"{\"a\":1,\"b\":2,\"c\":3}","{\"d\":4,\"e\":5,\"f\":6}"} ("{""a"":(1),""b"":(2),""c"":(3)}")

query ????? rowsort
select to_jsonb(m1), to_jsonb(m2), to_jsonb(m3), to_jsonb(l), to_jsonb(s) from t;
----
{"a": 1.0, "b": 2.0, "c": 3.0} null null null null
{"a": 1.0, "b": 2.0, "c": 3.0} {"1": true, "2": false, "3": true} {"a": {"a1": "a2"}, "b": {"b1": "b2"}} [{"a": 1, "b": 2, "c": 3}, {"d": 4, "e": 5, "f": 6}] {"m": {"a": {"x": 1}, "b": {"x": 2}, "c": {"x": 3}}}

statement ok
drop table t;
1 change: 1 addition & 0 deletions proto/expr.proto
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,7 @@ message ExprNode {

// Map functions
MAP_FROM_ENTRIES = 700;
MAP_ACCESS = 701;

// Non-pure functions below (> 1000)
// ------------------------
Expand Down
18 changes: 18 additions & 0 deletions src/common/src/array/list_array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -597,6 +597,24 @@ impl<'a> ListRef<'a> {
_ => None,
}
}

/// Splits a map's internal list representation into a list of keys and a list of values.
///
/// Both returned lists cover the same `start..end` range as `self`; only the backing
/// array differs (first struct field for keys, second struct field for values).
///
/// # Panics
/// Panics if the list is not a map's internal representation (See [`super::MapArray`]),
/// i.e. if the underlying array is not a struct array with exactly two fields.
pub(super) fn as_map_kv(self) -> (ListRef<'a>, ListRef<'a>) {
    let (start, end) = (self.start, self.end);
    // `collect_tuple` only yields `Some` when the number of fields matches the tuple arity.
    let (key_array, value_array) = self.array.as_struct().fields().collect_tuple().unwrap();
    let keys = ListRef {
        array: key_array,
        start,
        end,
    };
    let values = ListRef {
        array: value_array,
        start,
        end,
    };
    (keys, values)
}
}

impl PartialEq for ListRef<'_> {
Expand Down
4 changes: 4 additions & 0 deletions src/common/src/array/map_array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,10 @@ mod scalar {
pub fn into_inner(self) -> ListRef<'a> {
self.0
}

pub fn into_kv(self) -> (ListRef<'a>, ListRef<'a>) {
self.0.as_map_kv()
}
}

impl Scalar for MapValue {
Expand Down
63 changes: 60 additions & 3 deletions src/expr/impl/src/scalar/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,20 @@

use risingwave_common::array::{ListValue, StructValue};
use risingwave_common::row::Row;
use risingwave_common::types::{DataType, ListRef, MapType, MapValue, ToOwnedDatum};
use risingwave_common::types::{
DataType, ListRef, MapRef, MapType, MapValue, ScalarRefImpl, ToOwnedDatum,
};
use risingwave_expr::expr::Context;
use risingwave_expr::{function, ExprError};

#[function("array(...) -> anyarray", type_infer = "panic")]
use super::array_positions::array_position;

#[function("array(...) -> anyarray", type_infer = "unreachable")]
fn array(row: impl Row, ctx: &Context) -> ListValue {
ListValue::from_datum_iter(ctx.return_type.as_list(), row.iter())
}

#[function("row(...) -> struct", type_infer = "panic")]
#[function("row(...) -> struct", type_infer = "unreachable")]
fn row_(row: impl Row) -> StructValue {
StructValue::new(row.iter().map(|d| d.to_owned_datum()).collect())
}
Expand Down Expand Up @@ -54,6 +58,59 @@ fn map(key: ListRef<'_>, value: ListRef<'_>) -> Result<MapValue, ExprError> {
MapValue::try_from_kv(key.to_owned(), value.to_owned()).map_err(ExprError::Custom)
}

/// # Example
///
/// ```slt
/// query T
/// select map_access(map_from_entries(array[1,2,3], array[100,200,300]), 3);
/// ----
/// 300
///
/// query T
/// select map_access(map_from_entries(array[1,2,3], array[100,200,300]), '3');
/// ----
/// 300
///
/// query error
/// select map_access(map_from_entries(array[1,2,3], array[100,200,300]), 1.0);
/// ----
/// db error: ERROR: Failed to run the query
///
/// Caused by these errors (recent errors listed first):
/// 1: Failed to bind expression: map_access(map_from_entries(ARRAY[1, 2, 3], ARRAY[100, 200, 300]), 1.0)
/// 2: Bind error: Cannot access numeric in map(integer,integer)
///
///
/// query T
/// select map_access(map_from_entries(array['a','b','c'], array[1,2,3]), 'a');
/// ----
/// 1
///
/// query T
/// select map_access(map_from_entries(array['a','b','c'], array[1,2,3]), 'd');
/// ----
/// NULL
///
/// query T
/// select map_access(map_from_entries(array['a','b','c'], array[1,2,3]), null);
/// ----
/// NULL
/// ```
#[function("map_access(anymap, any) -> any")]
fn map_access<'a>(
    map: MapRef<'a>,
    key: ScalarRefImpl<'_>,
) -> Result<Option<ScalarRefImpl<'a>>, ExprError> {
    // FIXME: `DatumRef` in the return value is not supported by the macro yet.

    let (keys, values) = map.into_kv();
    // Look the key up in the key list; when it is found, read the value stored at the
    // same slot. The reported position is turned into a zero-based index by subtracting one.
    let value = array_position(keys, Some(key))?
        .and_then(|position| values.get((position - 1) as usize).unwrap());
    Ok(value)
}

#[cfg(test)]
mod tests {
use risingwave_common::array::DataChunk;
Expand Down
5 changes: 4 additions & 1 deletion src/expr/impl/src/scalar/array_positions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,10 @@ use risingwave_expr::{function, ExprError, Result};
/// 2
/// ```
#[function("array_position(anyarray, any) -> int4")]
fn array_position(array: ListRef<'_>, element: Option<ScalarRefImpl<'_>>) -> Result<Option<i32>> {
pub(super) fn array_position(
array: ListRef<'_>,
element: Option<ScalarRefImpl<'_>>,
) -> Result<Option<i32>> {
array_position_common(array, element, 0)
}

Expand Down
4 changes: 2 additions & 2 deletions src/expr/impl/src/scalar/case.rs
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ impl Expression for ConstantLookupExpression {
}
}

#[build_function("constant_lookup(...) -> any", type_infer = "panic")]
#[build_function("constant_lookup(...) -> any", type_infer = "unreachable")]
fn build_constant_lookup_expr(
return_type: DataType,
children: Vec<BoxedExpression>,
Expand Down Expand Up @@ -249,7 +249,7 @@ fn build_constant_lookup_expr(
)))
}

#[build_function("case(...) -> any", type_infer = "panic")]
#[build_function("case(...) -> any", type_infer = "unreachable")]
fn build_case_expr(
return_type: DataType,
children: Vec<BoxedExpression>,
Expand Down
8 changes: 4 additions & 4 deletions src/expr/impl/src/scalar/cast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -189,13 +189,13 @@ pub fn str_to_bytea(elem: &str) -> Result<Box<[u8]>> {
cast::str_to_bytea(elem).map_err(|err| ExprError::Parse(err.into()))
}

#[function("cast(varchar) -> anyarray", type_infer = "panic")]
#[function("cast(varchar) -> anyarray", type_infer = "unreachable")]
fn str_to_list(input: &str, ctx: &Context) -> Result<ListValue> {
ListValue::from_str(input, &ctx.return_type).map_err(|err| ExprError::Parse(err.into()))
}

/// Cast array with `source_elem_type` into array with `target_elem_type` by casting each element.
#[function("cast(anyarray) -> anyarray", type_infer = "panic")]
#[function("cast(anyarray) -> anyarray", type_infer = "unreachable")]
fn list_cast(input: ListRef<'_>, ctx: &Context) -> Result<ListValue> {
let cast = build_func(
PbType::Cast,
Expand All @@ -213,7 +213,7 @@ fn list_cast(input: ListRef<'_>, ctx: &Context) -> Result<ListValue> {
}

/// Cast struct of `source_elem_type` to `target_elem_type` by casting each element.
#[function("cast(struct) -> struct", type_infer = "panic")]
#[function("cast(struct) -> struct", type_infer = "unreachable")]
fn struct_cast(input: StructRef<'_>, ctx: &Context) -> Result<StructValue> {
let fields = (input.iter_fields_ref())
.zip_eq_fast(ctx.arg_types[0].as_struct().types())
Expand Down Expand Up @@ -242,7 +242,7 @@ fn struct_cast(input: StructRef<'_>, ctx: &Context) -> Result<StructValue> {
}

/// Cast array with `source_elem_type` into array with `target_elem_type` by casting each element.
#[function("cast(anymap) -> anymap", type_infer = "panic")]
#[function("cast(anymap) -> anymap", type_infer = "unreachable")]
fn map_cast(map: MapRef<'_>, ctx: &Context) -> Result<MapValue> {
let new_ctx = Context {
arg_types: vec![ctx.arg_types[0].clone().as_map().clone().into_list()],
Expand Down
2 changes: 1 addition & 1 deletion src/expr/impl/src/scalar/coalesce.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ impl Expression for CoalesceExpression {
}
}

#[build_function("coalesce(...) -> any", type_infer = "panic")]
#[build_function("coalesce(...) -> any", type_infer = "unreachable")]
fn build(return_type: DataType, children: Vec<BoxedExpression>) -> Result<BoxedExpression> {
Ok(Box::new(CoalesceExpression {
return_type,
Expand Down
2 changes: 1 addition & 1 deletion src/expr/impl/src/scalar/external/iceberg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ impl risingwave_expr::expr::Expression for IcebergTransform {
}
}

#[build_function("iceberg_transform(varchar, any) -> any", type_infer = "panic")]
#[build_function("iceberg_transform(varchar, any) -> any", type_infer = "unreachable")]
fn build(return_type: DataType, mut children: Vec<BoxedExpression>) -> Result<BoxedExpression> {
let transform_type = {
let datum = children[0].eval_const()?.unwrap();
Expand Down
2 changes: 1 addition & 1 deletion src/expr/impl/src/scalar/field.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ impl Expression for FieldExpression {
}
}

#[build_function("field(struct, int4) -> any", type_infer = "panic")]
#[build_function("field(struct, int4) -> any", type_infer = "unreachable")]
fn build(return_type: DataType, children: Vec<BoxedExpression>) -> Result<BoxedExpression> {
// Field `func_call_node` have 2 child nodes, the first is Field `FuncCall` or
// `InputRef`, the second is i32 `Literal`.
Expand Down
7 changes: 5 additions & 2 deletions src/expr/impl/src/scalar/jsonb_record.rs
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ fn jsonb_populate_recordset<'a>(
/// ----
/// 1 [1,2,3] {1,2,3} NULL (123,"a b c")
/// ```
#[function("jsonb_to_record(jsonb) -> struct", type_infer = "panic")]
#[function("jsonb_to_record(jsonb) -> struct", type_infer = "unreachable")]
fn jsonb_to_record(jsonb: JsonbRef<'_>, ctx: &Context) -> Result<StructValue> {
let output_type = ctx.return_type.as_struct();
jsonb.to_struct(output_type).map_err(parse_err)
Expand All @@ -135,7 +135,10 @@ fn jsonb_to_record(jsonb: JsonbRef<'_>, ctx: &Context) -> Result<StructValue> {
/// 1 foo
/// 2 NULL
/// ```
#[function("jsonb_to_recordset(jsonb) -> setof struct", type_infer = "panic")]
#[function(
"jsonb_to_recordset(jsonb) -> setof struct",
type_infer = "unreachable"
)]
fn jsonb_to_recordset<'a>(
jsonb: JsonbRef<'a>,
ctx: &'a Context,
Expand Down
7 changes: 4 additions & 3 deletions src/expr/macro/src/gen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -86,9 +86,10 @@ impl FunctionAttr {
/// Generate the type infer function: `fn(&[DataType]) -> Result<DataType>`
fn generate_type_infer_fn(&self) -> Result<TokenStream2> {
if let Some(func) = &self.type_infer {
// XXX: should this be called "placeholder" or "unreachable"?
if func == "panic" {
return Ok(quote! { |_| panic!("type inference function is not implemented") });
if func == "unreachable" {
return Ok(
quote! { |_| unreachable!("type inference for this function should be specially handled in frontend, and should not call sig.type_infer") },
);
}
// use the user defined type inference function
return Ok(func.parse().unwrap());
Expand Down
1 change: 1 addition & 0 deletions src/frontend/src/binder/expr/function/builtin_scalar.rs
Original file line number Diff line number Diff line change
Expand Up @@ -392,6 +392,7 @@ impl Binder {
("jsonb_set", raw_call(ExprType::JsonbSet)),
// map
("map_from_entries", raw_call(ExprType::MapFromEntries)),
("map_access",raw_call(ExprType::MapAccess)),
// Functions that return a constant value
("pi", pi()),
// greatest and least
Expand Down
4 changes: 2 additions & 2 deletions src/frontend/src/binder/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -260,8 +260,8 @@ impl UdfContext {
}
}

/// `ParameterTypes` is used to record the types of the parameters during binding. It works
/// following the rules:
/// `ParameterTypes` is used to record the types of the parameters while binding prepared statements.
/// It works by following the rules:
/// 1. At the beginning, it contains the user specified parameters type.
/// 2. When the binder encounters a parameter, it will record it as unknown(call `record_new_param`)
/// if it didn't exist in `ParameterTypes`.
Expand Down
7 changes: 6 additions & 1 deletion src/frontend/src/expr/literal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,12 @@ impl std::fmt::Debug for Literal {

impl Literal {
pub fn new(data: Datum, data_type: DataType) -> Self {
assert!(literal_type_match(&data_type, data.as_ref()));
assert!(
literal_type_match(&data_type, data.as_ref()),
"data_type: {:?}, data: {:?}",
data_type,
data
);
Literal {
data,
data_type: Some(data_type),
Expand Down
2 changes: 1 addition & 1 deletion src/frontend/src/expr/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ pub use subquery::{Subquery, SubqueryKind};
pub use table_function::{TableFunction, TableFunctionType};
pub use type_inference::{
align_types, cast_map_array, cast_ok, cast_sigs, infer_some_all, infer_type, infer_type_name,
infer_type_with_sigmap, least_restrictive, CastContext, CastSig, FuncSign,
infer_type_with_sigmap, CastContext, CastSig, FuncSign,
};
pub use user_defined_function::UserDefinedFunction;
pub use utils::*;
Expand Down
3 changes: 2 additions & 1 deletion src/frontend/src/expr/pure.rs
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,8 @@ impl ExprVisitor for ImpureAnalyzer {
| Type::InetAton
| Type::QuoteLiteral
| Type::QuoteNullable
| Type::MapFromEntries =>
| Type::MapFromEntries
| Type::MapAccess =>
// expression output is deterministic(same result for the same input)
{
func_call
Expand Down
8 changes: 6 additions & 2 deletions src/frontend/src/expr/type_inference/cast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,10 @@ use crate::expr::{Expr as _, ExprImpl, InputRef, Literal};
///
/// If you also need to cast them to this type, and there are more than 2 exprs, check out
/// [`align_types`].
pub fn least_restrictive(lhs: DataType, rhs: DataType) -> std::result::Result<DataType, ErrorCode> {
///
/// Note: be careful that literal strings are considered untyped.
/// e.g., `align_types(1, '1')` will be `Int32`, but `least_restrictive(Int32, Varchar)` will return error.
fn least_restrictive(lhs: DataType, rhs: DataType) -> std::result::Result<DataType, ErrorCode> {
if lhs == rhs {
Ok(lhs)
} else if cast_ok(&lhs, &rhs, CastContext::Implicit) {
Expand Down Expand Up @@ -81,6 +84,7 @@ pub fn align_array_and_element(
element_indices: &[usize],
inputs: &mut [ExprImpl],
) -> std::result::Result<DataType, ErrorCode> {
tracing::trace!(?inputs, "align_array_and_element begin");
let mut dummy_element = match inputs[array_idx].is_untyped() {
// when array is unknown type, make an unknown typed value (e.g. null)
true => ExprImpl::from(Literal::new_untyped(None)),
Expand All @@ -106,7 +110,7 @@ pub fn align_array_and_element(

// elements are already casted by `align_types`, we cast the array argument here
inputs[array_idx].cast_implicit_mut(array_type.clone())?;

tracing::trace!(?inputs, "align_array_and_element done");
Ok(array_type)
}

Expand Down
18 changes: 18 additions & 0 deletions src/frontend/src/expr/type_inference/func.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ use crate::expr::{cast_ok, is_row_function, Expr as _, ExprImpl, ExprType, Funct
/// is not supported on backend.
///
/// It also mutates the `inputs` by adding necessary casts.
#[tracing::instrument(level = "trace", skip(sig_map))]
pub fn infer_type_with_sigmap(
func_name: FuncName,
inputs: &mut [ExprImpl],
Expand Down Expand Up @@ -65,6 +66,7 @@ pub fn infer_type_with_sigmap(
})
.collect_vec();
let sig = infer_type_name(sig_map, func_name, &actuals)?;
tracing::trace!(?actuals, ?sig, "infer_type_name");

// add implicit casts to inputs
for (expr, t) in inputs.iter_mut().zip_eq_fast(&sig.inputs_type) {
Expand All @@ -82,6 +84,7 @@ pub fn infer_type_with_sigmap(

let input_types = inputs.iter().map(|expr| expr.return_type()).collect_vec();
let return_type = (sig.type_infer)(&input_types)?;
tracing::trace!(?input_types, ?return_type, "finished type inference");
Ok(return_type)
}

Expand Down Expand Up @@ -608,6 +611,21 @@ fn infer_type_for_special(
_ => Ok(None),
}
}
ExprType::MapAccess => {
ensure_arity!("map_access", | inputs | == 2);
let map_type = inputs[0].return_type().into_map();
// We do not align the map's key type with the input type here, but cast the latter to the former instead.
// e.g., for {1:'a'}[1.0], if we align them, we will get "numeric" as the key type, which violates the map type's restriction.
match inputs[1].cast_implicit_mut(map_type.key().clone()) {
Ok(()) => Ok(Some(map_type.value().clone())),
Err(_) => Err(ErrorCode::BindError(format!(
"Cannot access {} in {}",
inputs[1].return_type(),
inputs[0].return_type(),
))
.into()),
}
}
ExprType::Vnode => {
ensure_arity!("vnode", 1 <= | inputs |);
Ok(Some(VirtualNode::RW_TYPE))
Expand Down
Loading

0 comments on commit e1de185

Please sign in to comment.