Skip to content

Commit

Permalink
fix: array contains does not support nested types (#13290)
Browse files Browse the repository at this point in the history
Co-authored-by: thexia <[email protected]>
  • Loading branch information
thexiay and thexia authored Nov 8, 2023
1 parent 7b3f8fc commit 8fae5b5
Show file tree
Hide file tree
Showing 4 changed files with 55 additions and 7 deletions.
14 changes: 10 additions & 4 deletions src/expr/impl/src/scalar/array_contain.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,19 +35,25 @@ use risingwave_expr::function;
/// f
///
/// query I
/// select array[1,2,3] @> NULL;
/// select array[[[1,2],[3,4]],[[5,6],[7,8]]] @> array[2,3];
/// ----
/// t
///
/// query I
/// select array[1,2,3] @> null;
/// ----
/// NULL
///
/// query I
/// select NULL @> array[3,4];
/// select null @> array[3,4];
/// ----
/// NULL
/// ```
#[function("array_contains(anyarray, anyarray) -> boolean")]
fn array_contains(left: ListRef<'_>, right: ListRef<'_>) -> bool {
let set: HashSet<_> = left.iter().collect();
right.iter().all(|item| set.contains(&item))
let flatten = left.flatten();
let set: HashSet<_> = flatten.iter().collect();
right.flatten().iter().all(|item| set.contains(&item))
}

#[function("array_contained(anyarray, anyarray) -> boolean")]
Expand Down
8 changes: 8 additions & 0 deletions src/frontend/planner_test/tests/testdata/input/array.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,10 @@
sql: select array[1,2] @> array[2,3];
expected_outputs:
- logical_plan
- name: array_contains(int[][], int[]) -> bool
sql: select array[[1,2]] @> array[2,3];
expected_outputs:
- logical_plan
- name: array_contains(int[], int) -> bool
sql: select array[1,2] @> 2;
expected_outputs:
Expand All @@ -157,6 +161,10 @@
sql: select array[2,3] @> array['1'];
expected_outputs:
- binder_error
- name: array_contains(int[][], varchar[][]) -> bool
sql: select array[array[1,2]] @> array[array['2','3']];
expected_outputs:
- binder_error
- name: any contains(null, null) -> bool
sql: select '{}' @> '{}';
expected_outputs:
Expand Down
14 changes: 13 additions & 1 deletion src/frontend/planner_test/tests/testdata/output/array.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,11 @@
logical_plan: |-
LogicalProject { exprs: [(Array(1:Int32, 2:Int32) @> Array(2:Int32, 3:Int32)) as $expr1] }
└─LogicalValues { rows: [[]], schema: Schema { fields: [] } }
- name: array_contains(int[][], int[]) -> bool
sql: select array[[1,2]] @> array[2,3];
logical_plan: |-
LogicalProject { exprs: [(Array(Array(1:Int32, 2:Int32)) @> Array(2:Int32, 3:Int32)) as $expr1] }
└─LogicalValues { rows: [[]], schema: Schema { fields: [] } }
- name: array_contains(int[], int) -> bool
sql: select array[1,2] @> 2;
binder_error: |-
Expand All @@ -200,7 +205,14 @@
Bind error: failed to bind expression: ARRAY[2, 3] @> ARRAY['1']
Caused by:
Bind error: types List(Int32) and List(Varchar) cannot be matched
Bind error: Cannot array_contains unnested type integer to unnested type character varying
- name: array_contains(int[][], varchar[][]) -> bool
sql: select array[array[1,2]] @> array[array['2','3']];
binder_error: |-
Bind error: failed to bind expression: ARRAY[ARRAY[1, 2]] @> ARRAY[ARRAY['2', '3']]
Caused by:
Bind error: Cannot array_contains unnested type integer to unnested type character varying
- name: any contains(null, null) -> bool
sql: select '{}' @> '{}';
binder_error: |-
Expand Down
26 changes: 24 additions & 2 deletions src/frontend/src/expr/type_inference/func.rs
Original file line number Diff line number Diff line change
Expand Up @@ -535,8 +535,30 @@ fn infer_type_for_special(
}
ExprType::ArrayContains | ExprType::ArrayContained => {
ensure_arity!("array_contains/array_contained", | inputs | == 2);
align_types(inputs.iter_mut())?;
Ok(Some(DataType::Boolean))
let left_type = (!inputs[0].is_untyped()).then(|| inputs[0].return_type());
let right_type = (!inputs[1].is_untyped()).then(|| inputs[1].return_type());
match (left_type, right_type) {
(None, Some(DataType::List(_))) | (Some(DataType::List(_)), None) => {
align_types(inputs.iter_mut())?;
Ok(Some(DataType::Boolean))
}
(Some(DataType::List(left)), Some(DataType::List(right))) => {
// cannot directly cast, find unnest type and judge if they are same type
let left = left.unnest_list();
let right = right.unnest_list();
if left.equals_datatype(right) {
Ok(Some(DataType::Boolean))
} else {
Err(ErrorCode::BindError(format!(
"Cannot array_contains unnested type {} to unnested type {}",
left, right
))
.into())
}
}
// any other condition cannot determine polymorphic type
_ => Ok(None),
}
}
ExprType::Vnode => {
ensure_arity!("vnode", 1 <= | inputs |);
Expand Down

0 comments on commit 8fae5b5

Please sign in to comment.