Skip to content

Commit

Permalink
support any type
Browse files Browse the repository at this point in the history
Signed-off-by: Runji Wang <[email protected]>
  • Loading branch information
wangrunji0408 committed Sep 13, 2023
1 parent 1a7adac commit ebcc575
Show file tree
Hide file tree
Showing 22 changed files with 156 additions and 127 deletions.
6 changes: 6 additions & 0 deletions src/common/src/array/list_array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -606,6 +606,12 @@ impl ToText for ListRef<'_> {
}
}

impl<'a> From<&'a ListValue> for ListRef<'a> {
fn from(val: &'a ListValue) -> Self {
ListRef::ValueRef { val }
}
}

#[cfg(test)]
mod tests {
use more_asserts::{assert_gt, assert_lt};
Expand Down
20 changes: 20 additions & 0 deletions src/common/src/types/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -351,6 +351,14 @@ impl DataType {
DataTypeName::from(self).is_scalar()
}

/// Returns `true` if this data type is the `List` variant, whatever its element type.
pub fn is_array(&self) -> bool {
matches!(self, DataType::List(_))
}

/// Returns `true` if this data type is the `Struct` variant, whatever its payload.
pub fn is_struct(&self) -> bool {
matches!(self, DataType::Struct(_))
}

/// Returns `true` if this data type is one of `Int16`, `Int32` or `Int64`.
pub fn is_int(&self) -> bool {
matches!(self, DataType::Int16 | DataType::Int32 | DataType::Int64)
}
Expand Down Expand Up @@ -956,6 +964,18 @@ impl ScalarImpl {
}
}

/// Converts a [`ScalarRefImpl`] into a [`ScalarImpl`].
///
/// This is a thin wrapper over `ScalarRefImpl::into_scalar_impl`, provided so that
/// `.into()` / `ScalarImpl::from` can be used at call sites.
impl From<ScalarRefImpl<'_>> for ScalarImpl {
    fn from(value: ScalarRefImpl<'_>) -> Self {
        value.into_scalar_impl()
    }
}

/// Converts a borrowed [`ScalarImpl`] into a [`ScalarRefImpl`] with the same lifetime.
///
/// This is a thin wrapper over `ScalarImpl::as_scalar_ref_impl`, provided so that
/// `.into()` / `ScalarRefImpl::from` can be used at call sites.
impl<'a> From<&'a ScalarImpl> for ScalarRefImpl<'a> {
    fn from(value: &'a ScalarImpl) -> Self {
        value.as_scalar_ref_impl()
    }
}

impl ScalarImpl {
/// Converts [`ScalarImpl`] to [`ScalarRefImpl`]
pub fn as_scalar_ref_impl(&self) -> ScalarRefImpl<'_> {
Expand Down
16 changes: 8 additions & 8 deletions src/expr/macro/src/gen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,10 +53,11 @@ impl FunctionAttr {
Ok(func.parse().unwrap())
} else {
if matches!(self.ret.as_str(), "any" | "anyarray" | "struct") {
return Err(Error::new(
Span::call_site(),
format!("type inference function is required for {}", self.ret),
));
return Ok(quote! { |_| todo!("type infer") });
// return Err(Error::new(
// Span::call_site(),
// format!("type inference function is required for {}", self.ret),
// ));
}
let ty = data_type(&self.ret);
Ok(quote! { |_| Ok(#ty) })
Expand Down Expand Up @@ -537,11 +538,9 @@ impl FunctionAttr {
.enumerate()
.map(|(i, arg)| {
let array = format_ident!("a{i}");
let variant: TokenStream2 = types::variant(arg).parse().unwrap();
let array_type: TokenStream2 = types::array_type(arg).parse().unwrap();
quote! {
let ArrayImpl::#variant(#array) = &**input.column_at(#i) else {
bail!("input type mismatch. expect: {}", stringify!(#variant));
};
let #array: &#array_type = input.column_at(#i).as_ref().into();
}
})
.collect_vec();
Expand Down Expand Up @@ -939,6 +938,7 @@ fn match_type(ty: &str) -> TokenStream2 {
match ty {
"any" => quote! { MatchType::Any },
"anyarray" => quote! { MatchType::AnyArray },
"struct" => quote! { MatchType::AnyStruct },
_ => {
let datatype = data_type(ty);
quote! { MatchType::Exact(#datatype) }
Expand Down
58 changes: 26 additions & 32 deletions src/expr/macro/src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,60 +14,53 @@

//! This module provides utility functions for SQL data type conversion and manipulation.
// name data type variant array type owned type ref type primitive
// name data type array type owned type ref type primitive
const TYPE_MATRIX: &str = "
boolean Boolean Bool BoolArray bool bool _
int16 Int16 Int16 I16Array i16 i16 y
int32 Int32 Int32 I32Array i32 i32 y
int64 Int64 Int64 I64Array i64 i64 y
int256 Int256 Int256 Int256Array Int256 Int256Ref<'_> _
float32 Float32 Float32 F32Array F32 F32 y
float64 Float64 Float64 F64Array F64 F64 y
decimal Decimal Decimal DecimalArray Decimal Decimal y
serial Serial Serial SerialArray Serial Serial y
date Date Date DateArray Date Date y
time Time Time TimeArray Time Time y
timestamp Timestamp Timestamp TimestampArray Timestamp Timestamp y
timestamptz Timestamptz Timestamptz TimestamptzArray Timestamptz Timestamptz y
interval Interval Interval IntervalArray Interval Interval y
varchar Varchar Utf8 Utf8Array Box<str> &str _
bytea Bytea Bytea BytesArray Box<[u8]> &[u8] _
jsonb Jsonb Jsonb JsonbArray JsonbVal JsonbRef<'_> _
anyarray List List ListArray ListValue ListRef<'_> _
struct Struct Struct StructArray StructValue StructRef<'_> _
boolean Boolean BoolArray bool bool _
int16 Int16 I16Array i16 i16 y
int32 Int32 I32Array i32 i32 y
int64 Int64 I64Array i64 i64 y
int256 Int256 Int256Array Int256 Int256Ref<'_> _
float32 Float32 F32Array F32 F32 y
float64 Float64 F64Array F64 F64 y
decimal Decimal DecimalArray Decimal Decimal y
serial Serial SerialArray Serial Serial y
date Date DateArray Date Date y
time Time TimeArray Time Time y
timestamp Timestamp TimestampArray Timestamp Timestamp y
timestamptz Timestamptz TimestamptzArray Timestamptz Timestamptz y
interval Interval IntervalArray Interval Interval y
varchar Varchar Utf8Array Box<str> &str _
bytea Bytea BytesArray Box<[u8]> &[u8] _
jsonb Jsonb JsonbArray JsonbVal JsonbRef<'_> _
anyarray List ListArray ListValue ListRef<'_> _
struct Struct StructArray StructValue StructRef<'_> _
any ??? ArrayImpl ScalarImpl ScalarRefImpl<'_> _
";

/// Maps a data type to its corresponding data type name.
///
/// Delegates to `lookup_matrix` with index 1, i.e. the "data type" column named in the
/// header comment of `TYPE_MATRIX` (for example the `int32` row lists `Int32` there).
pub fn data_type(ty: &str) -> &str {
lookup_matrix(ty, 1)
}

/// Maps a data type to its corresponding variant name.
pub fn variant(ty: &str) -> &str {
lookup_matrix(ty, 2)
}

/// Maps a data type to its corresponding array type name.
pub fn array_type(ty: &str) -> &str {
if ty == "any" {
return "ArrayImpl";
}
lookup_matrix(ty, 3)
lookup_matrix(ty, 2)
}

/// Maps a data type to its corresponding `Scalar` type name.
pub fn owned_type(ty: &str) -> &str {
lookup_matrix(ty, 4)
lookup_matrix(ty, 3)
}

/// Maps a data type to its corresponding `ScalarRef` type name.
pub fn ref_type(ty: &str) -> &str {
lookup_matrix(ty, 5)
lookup_matrix(ty, 4)
}

/// Checks if a data type is primitive.
pub fn is_primitive(ty: &str) -> bool {
lookup_matrix(ty, 6) == "y"
lookup_matrix(ty, 5) == "y"
}

fn lookup_matrix(mut ty: &str, idx: usize) -> &str {
Expand All @@ -94,6 +87,7 @@ pub fn expand_type_wildcard(ty: &str) -> Vec<&str> {
.trim()
.lines()
.map(|l| l.split_whitespace().next().unwrap())
.filter(|l| *l != "any")
.collect(),
"*int" => vec!["int16", "int32", "int64"],
"*numeric" => vec!["decimal"],
Expand Down
2 changes: 1 addition & 1 deletion src/expr/src/agg/array_agg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
// limitations under the License.

use risingwave_common::array::ListValue;
use risingwave_common::types::{Datum, ScalarRef};
use risingwave_common::types::{Datum, ScalarRefImpl, ToOwnedDatum};
use risingwave_expr_macro::aggregate;

#[aggregate("array_agg(any) -> anyarray")]
Expand Down
1 change: 0 additions & 1 deletion src/expr/src/agg/jsonb_agg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use risingwave_common::bail;
use risingwave_common::types::JsonbVal;
use risingwave_expr_macro::aggregate;
use serde_json::Value;
Expand Down
2 changes: 1 addition & 1 deletion src/expr/src/agg/mode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ use super::{AggStateDyn, AggregateFunction, AggregateState, BoxedAggregateFuncti
use crate::agg::AggCall;
use crate::Result;

#[build_aggregate("mode(*) -> auto")]
#[build_aggregate("mode(any) -> any")]
fn build(agg: &AggCall) -> Result<BoxedAggregateFunction> {
Ok(Box::new(Mode {
return_type: agg.return_type.clone(),
Expand Down
2 changes: 1 addition & 1 deletion src/expr/src/agg/percentile_disc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ use crate::Result;
/// statement ok
/// drop table t;
/// ```
#[build_aggregate("percentile_disc(*) -> auto")]
#[build_aggregate("percentile_disc(any) -> any")]
fn build(agg: &AggCall) -> Result<BoxedAggregateFunction> {
let fractions = agg.direct_args[0]
.literal()
Expand Down
1 change: 0 additions & 1 deletion src/expr/src/agg/string_agg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use risingwave_common::bail;
use risingwave_expr_macro::aggregate;

#[aggregate("string_agg(varchar, varchar) -> varchar")]
Expand Down
28 changes: 28 additions & 0 deletions src/expr/src/sig/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,16 @@ impl FuncSign {
}
}

/// Returns true if the function is a scalar function,
/// i.e. its `name` is the `FuncName::Scalar` variant.
pub const fn is_scalar(&self) -> bool {
matches!(self.name, FuncName::Scalar(_))
}

/// Returns true if the function is an aggregate function,
/// i.e. its `name` is the `FuncName::Aggregate` variant.
pub const fn is_aggregate(&self) -> bool {
matches!(self.name, FuncName::Aggregate(_))
}

pub fn build_scalar(
&self,
return_type: DataType,
Expand Down Expand Up @@ -220,6 +230,20 @@ impl FuncName {
/// Returns true if this name is the `Table` variant.
const fn is_table(&self) -> bool {
matches!(self, Self::Table(_))
}

/// Returns the [`ScalarFunctionType`] carried by the `Scalar` variant.
///
/// # Panics
///
/// Panics if `self` is any variant other than `Scalar`.
pub fn as_scalar(&self) -> ScalarFunctionType {
    let Self::Scalar(func_type) = self else {
        panic!("Expected a scalar function");
    };
    *func_type
}

/// Returns the [`AggregateFunctionType`] carried by the `Aggregate` variant.
///
/// # Panics
///
/// Panics if `self` is any variant other than `Aggregate`.
pub fn as_aggregate(&self) -> AggregateFunctionType {
    let Self::Aggregate(func_type) = self else {
        panic!("Expected an aggregate function");
    };
    *func_type
}
}

#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
Expand All @@ -230,6 +254,8 @@ pub enum MatchType {
Any,
/// Any array data type
AnyArray,
/// Any struct type
AnyStruct,
}

impl From<DataType> for MatchType {
Expand All @@ -244,6 +270,7 @@ impl std::fmt::Display for MatchType {
Self::Exact(dt) => write!(f, "{}", dt),
Self::Any => write!(f, "any"),
Self::AnyArray => write!(f, "anyarray"),
Self::AnyStruct => write!(f, "anystruct"),
}
}
}
Expand All @@ -255,6 +282,7 @@ impl MatchType {
Self::Exact(ty) => ty == dt,
Self::Any => true,
Self::AnyArray => dt.is_array(),
Self::AnyStruct => dt.is_struct(),
}
}
}
Expand Down
2 changes: 1 addition & 1 deletion src/expr/src/table_function/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ use futures_util::stream::BoxStream;
use futures_util::StreamExt;
use itertools::Itertools;
use risingwave_common::array::{Array, ArrayBuilder, ArrayImpl, ArrayRef, DataChunk};
use risingwave_common::types::{DataType, DataTypeName, DatumRef};
use risingwave_common::types::{DataType, DatumRef};
use risingwave_pb::expr::project_set_select_item::SelectItem;
use risingwave_pb::expr::table_function::PbType;
use risingwave_pb::expr::{PbProjectSetSelectItem, PbTableFunction};
Expand Down
49 changes: 22 additions & 27 deletions src/expr/src/vector_op/array_access.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,19 +13,20 @@
// limitations under the License.

use risingwave_common::array::ListRef;
use risingwave_common::types::{Scalar, ToOwnedDatum};
use risingwave_common::types::ScalarRefImpl;
use risingwave_expr_macro::function;

use crate::Result;

#[function("array_access(anyarray, int32) -> any")]
fn array_access(list: ListRef<'_>, index: i32) -> DatumRef<'_> {
#[function(
"array_access(anyarray, int32) -> any",
type_infer = "|args| Ok(args[0].as_list().clone())"
)]
fn array_access(list: ListRef<'_>, index: i32) -> Option<ScalarRefImpl<'_>> {
// index must be greater than 0 following a one-based numbering convention for arrays
if index < 1 {
return None;
}
// returns `NULL` if index is out of bounds
list.elem_at(index as usize - 1).flatten();
list.elem_at(index as usize - 1).flatten()
}

#[cfg(test)]
Expand All @@ -45,10 +46,10 @@ mod tests {
]);
let l1 = ListRef::ValueRef { val: &v1 };

assert_eq!(array_access::<i32>(l1, 1).unwrap(), Some(1));
assert_eq!(array_access::<i32>(l1, -1).unwrap(), None);
assert_eq!(array_access::<i32>(l1, 0).unwrap(), None);
assert_eq!(array_access::<i32>(l1, 4).unwrap(), None);
assert_eq!(array_access(l1, 1), Some(1.into()));
assert_eq!(array_access(l1, -1), None);
assert_eq!(array_access(l1, 0), None);
assert_eq!(array_access(l1, 4), None);
}

#[test]
Expand All @@ -69,18 +70,9 @@ mod tests {
let l2 = ListRef::ValueRef { val: &v2 };
let l3 = ListRef::ValueRef { val: &v3 };

assert_eq!(
array_access::<Box<str>>(l1, 1).unwrap(),
Some("来自".into())
);
assert_eq!(
array_access::<Box<str>>(l2, 2).unwrap(),
Some("荷兰".into())
);
assert_eq!(
array_access::<Box<str>>(l3, 3).unwrap(),
Some("的爱".into())
);
assert_eq!(array_access(l1, 1), Some("来自".into()));
assert_eq!(array_access(l2, 2), Some("荷兰".into()));
assert_eq!(array_access(l3, 3), Some("的爱".into()));
}

#[test]
Expand All @@ -97,11 +89,14 @@ mod tests {
]);
let l = ListRef::ValueRef { val: &v };
assert_eq!(
array_access::<ListValue>(l, 1).unwrap(),
Some(ListValue::new(vec![
Some(ScalarImpl::Utf8("foo".into())),
Some(ScalarImpl::Utf8("bar".into())),
]))
array_access(l, 1),
Some(
ListRef::from(&ListValue::new(vec![
Some(ScalarImpl::Utf8("foo".into())),
Some(ScalarImpl::Utf8("bar".into())),
]))
.into()
)
);
}
}
5 changes: 4 additions & 1 deletion src/expr/src/vector_op/array_distinct.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,10 @@ use risingwave_expr_macro::function;
/// select array_distinct(null);
/// ```
#[function("array_distinct(anyarray) -> anyarray")]
#[function(
"array_distinct(anyarray) -> anyarray",
type_infer = "|args| Ok(args[0].clone())"
)]
pub fn array_distinct(list: ListRef<'_>) -> ListValue {
ListValue::new(list.iter().unique().map(|x| x.to_owned_datum()).collect())
}
Expand Down
Loading

0 comments on commit ebcc575

Please sign in to comment.