Skip to content

Commit

Permalink
auto type inference for any and anyarray
Browse files Browse the repository at this point in the history
Signed-off-by: Runji Wang <[email protected]>
  • Loading branch information
wangrunji0408 committed Sep 13, 2023
1 parent 9a0be69 commit f8099f8
Show file tree
Hide file tree
Showing 9 changed files with 42 additions and 38 deletions.
40 changes: 31 additions & 9 deletions src/expr/macro/src/gen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,18 +50,40 @@ impl FunctionAttr {
/// Generate the type infer function.
fn generate_type_infer_fn(&self) -> Result<TokenStream2> {
if let Some(func) = &self.type_infer {
Ok(func.parse().unwrap())
} else {
if matches!(self.ret.as_str(), "any" | "anyarray" | "struct") {
return Ok(quote! { |_| todo!("type infer") });
// return Err(Error::new(
// Span::call_site(),
// format!("type inference function is required for {}", self.ret),
// ));
if func == "panic" {
return Ok(quote! { |_| panic!("type inference function is not implemented") });
}
// use the user defined type inference function
return Ok(func.parse().unwrap());
} else if self.ret == "any" {
// TODO: if there are multiple "any", they should be the same type
if let Some(i) = self.args.iter().position(|t| t == "any") {
// infer as the type of "any" argument
return Ok(quote! { |args| Ok(args[#i].clone()) });
}
if let Some(i) = self.args.iter().position(|t| t == "anyarray") {
// infer as the element type of "anyarray" argument
return Ok(quote! { |args| Ok(args[#i].as_list().clone()) });
}
} else if self.ret == "anyarray" {
if let Some(i) = self.args.iter().position(|t| t == "anyarray") {
// infer as the type of "anyarray" argument
return Ok(quote! { |args| Ok(args[#i].clone()) });
}
} else if self.ret == "struct" {
if let Some(i) = self.args.iter().position(|t| t == "struct") {
// infer as the type of "struct" argument
return Ok(quote! { |args| Ok(args[#i].clone()) });
}
} else {
// the return type is fixed
let ty = data_type(&self.ret);
Ok(quote! { |_| Ok(#ty) })
return Ok(quote! { |_| Ok(#ty) });
}
Err(Error::new(
Span::call_site(),
"type inference function is required",
))
}

/// Generate a descriptor of the scalar or table function.
Expand Down
4 changes: 2 additions & 2 deletions src/expr/src/vector_op/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,12 @@ use risingwave_common::row::Row;
use risingwave_common::types::ToOwnedDatum;
use risingwave_expr_macro::function;

#[function("array(...) -> anyarray")]
#[function("array(...) -> anyarray", type_infer = "panic")]
fn array(row: impl Row) -> ListValue {
ListValue::new(row.iter().map(|d| d.to_owned_datum()).collect())
}

#[function("row(...) -> struct")]
#[function("row(...) -> struct", type_infer = "panic")]
fn row_(row: impl Row) -> StructValue {
StructValue::new(row.iter().map(|d| d.to_owned_datum()).collect())
}
5 changes: 1 addition & 4 deletions src/expr/src/vector_op/array_access.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,7 @@ use risingwave_common::array::ListRef;
use risingwave_common::types::ScalarRefImpl;
use risingwave_expr_macro::function;

#[function(
"array_access(anyarray, int32) -> any",
type_infer = "|args| Ok(args[0].as_list().clone())"
)]
#[function("array_access(anyarray, int32) -> any")]
fn array_access(list: ListRef<'_>, index: i32) -> Option<ScalarRefImpl<'_>> {
// index must be greater than 0 following a one-based numbering convention for arrays
if index < 1 {
Expand Down
5 changes: 1 addition & 4 deletions src/expr/src/vector_op/array_distinct.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,10 +50,7 @@ use risingwave_expr_macro::function;
/// select array_distinct(null);
/// ```
#[function(
"array_distinct(anyarray) -> anyarray",
type_infer = "|args| Ok(args[0].clone())"
)]
#[function("array_distinct(anyarray) -> anyarray")]
pub fn array_distinct(list: ListRef<'_>) -> ListValue {
ListValue::new(list.iter().unique().map(|x| x.to_owned_datum()).collect())
}
Expand Down
5 changes: 1 addition & 4 deletions src/expr/src/vector_op/array_range_access.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,7 @@ use risingwave_expr_macro::function;

/// If the case is `array[1,2,3][:2]`, then start will be 0 set by the frontend
/// If the case is `array[1,2,3][1:]`, then end will be `i32::MAX` set by the frontend
#[function(
"array_range_access(anyarray, int32, int32) -> anyarray",
type_infer = "|args| Ok(args[0].clone())"
)]
#[function("array_range_access(anyarray, int32, int32) -> anyarray")]
pub fn array_range_access(list: ListRef<'_>, start: i32, end: i32) -> Option<ListValue> {
let mut data = vec![];
let list_all_values = list.iter();
Expand Down
5 changes: 1 addition & 4 deletions src/expr/src/vector_op/array_remove.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,10 +66,7 @@ use risingwave_expr_macro::function;
/// statement error
/// select array_remove(ARRAY[array[1],array[2],array[3],array[2],null], array[true]);
/// ```
#[function(
"array_remove(anyarray, any) -> anyarray",
type_infer = "|args| Ok(args[0].clone())"
)]
#[function("array_remove(anyarray, any) -> anyarray")]
fn array_remove(array: Option<ListRef<'_>>, elem: Option<ScalarRefImpl<'_>>) -> Option<ListValue> {
Some(ListValue::new(
array?
Expand Down
5 changes: 1 addition & 4 deletions src/expr/src/vector_op/array_replace.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,10 +54,7 @@ use risingwave_expr_macro::function;
/// statement error
/// select array_replace(array[array[array[0, 1], array[2, 3]], array[array[4, 5], array[6, 7]]], array[4, 5], array[8, 9]);
/// ```
#[function(
"array_replace(anyarray, any, any) -> anyarray",
type_infer = "|args| Ok(args[0].clone())"
)]
#[function("array_replace(anyarray, any, any) -> anyarray")]
fn array_replace(
array: Option<ListRef<'_>>,
elem_from: Option<ScalarRefImpl<'_>>,
Expand Down
6 changes: 3 additions & 3 deletions src/expr/src/vector_op/cast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -314,7 +314,7 @@ fn unnest(input: &str) -> Result<Vec<&str>> {
Ok(items)
}

#[function("cast(varchar) -> anyarray")]
#[function("cast(varchar) -> anyarray", type_infer = "panic")]
fn str_to_list(input: &str, ctx: &Context) -> Result<ListValue> {
let cast = build_func(
PbType::Cast,
Expand All @@ -334,7 +334,7 @@ fn str_to_list(input: &str, ctx: &Context) -> Result<ListValue> {
}

/// Cast array with `source_elem_type` into array with `target_elem_type` by casting each element.
#[function("cast(anyarray) -> anyarray")]
#[function("cast(anyarray) -> anyarray", type_infer = "panic")]
fn list_cast(input: ListRef<'_>, ctx: &Context) -> Result<ListValue> {
let cast = build_func(
PbType::Cast,
Expand All @@ -355,7 +355,7 @@ fn list_cast(input: ListRef<'_>, ctx: &Context) -> Result<ListValue> {
}

/// Cast struct of `source_elem_type` to `target_elem_type` by casting each element.
#[function("cast(struct) -> struct")]
#[function("cast(struct) -> struct", type_infer = "panic")]
fn struct_cast(input: StructRef<'_>, ctx: &Context) -> Result<StructValue> {
let fields = (input.iter_fields_ref())
.zip_eq_fast(ctx.arg_types[0].as_struct().types())
Expand Down
5 changes: 1 addition & 4 deletions src/expr/src/vector_op/trim_array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,10 +73,7 @@ use crate::Result;
/// statement error
/// select trim_array(array[1,2,3,4,5,null], true);
/// ```
#[function(
"trim_array(anyarray, int32) -> anyarray",
type_infer = "|args| Ok(args[0].clone())"
)]
#[function("trim_array(anyarray, int32) -> anyarray")]
fn trim_array(array: ListRef<'_>, n: i32) -> Result<ListValue> {
let values = array.iter();
let len_to_trim: usize = n.try_into().map_err(|_| ExprError::InvalidParam {
Expand Down

0 comments on commit f8099f8

Please sign in to comment.