Skip to content

Commit

Permalink
fix(expr): fix SOME/ALL/ANY expression (#14221)
Browse files Browse the repository at this point in the history
Signed-off-by: Runji Wang <[email protected]>
  • Loading branch information
wangrunji0408 authored and wangrunji0408 committed Dec 27, 2023
1 parent 393f17f commit 4260598
Show file tree
Hide file tree
Showing 3 changed files with 91 additions and 6 deletions.
2 changes: 0 additions & 2 deletions src/common/src/array/data_chunk.rs
Original file line number Diff line number Diff line change
Expand Up @@ -817,8 +817,6 @@ impl DataChunkTestExt for DataChunk {
{
let datum = match val_str {
"." => None,
"t" => Some(true.into()),
"f" => Some(false.into()),
"(empty)" => Some("".into()),
_ => Some(ScalarImpl::from_text(val_str.as_bytes(), ty).unwrap()),
};
Expand Down
10 changes: 8 additions & 2 deletions src/common/src/types/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -864,7 +864,7 @@ impl ScalarImpl {
let res = match data_type {
DataType::Varchar => Self::Utf8(str.to_string().into()),
DataType::Boolean => {
Self::Bool(bool::from_str(str).map_err(|_| FromSqlError::from_text(str))?)
Self::Bool(str_to_bool(str).map_err(|_| FromSqlError::from_text(str))?)
}
DataType::Int16 => {
Self::Int16(i16::from_str(str).map_err(|_| FromSqlError::from_text(str))?)
Expand Down Expand Up @@ -921,7 +921,13 @@ impl ScalarImpl {
}
let mut builder = elem_type.create_array_builder(0);
for s in str[1..str.len() - 1].split(',') {
builder.append(Some(Self::from_text(s.trim().as_bytes(), elem_type)?));
if s.is_empty() {
continue;
} else if s.eq_ignore_ascii_case("null") {
builder.append_null();
} else {
builder.append(Some(Self::from_text(s.trim().as_bytes(), elem_type)?));
}
}
Self::List(ListValue::new(builder.finish()))
}
Expand Down
85 changes: 83 additions & 2 deletions src/expr/core/src/expr/expr_some_all.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ impl SomeAllExpression {
}
}

// Notice that this function may not exhaust the iterator,
// so never pass an iterator created `by_ref`.
fn resolve_bools(&self, bools: impl Iterator<Item = Option<bool>>) -> Option<bool> {
match self.expr_type {
Type::Some => {
Expand Down Expand Up @@ -160,12 +162,17 @@ impl Expression for SomeAllExpression {
);

let func_results = self.func.eval(&data_chunk).await?;
let mut func_results_iter = func_results.as_bool().iter();
let bools = func_results.as_bool();
let mut offset = 0;
Ok(Arc::new(
num_array
.into_iter()
.map(|num| match num {
Some(num) => self.resolve_bools(func_results_iter.by_ref().take(num)),
Some(num) => {
let range = offset..offset + num;
offset += num;
self.resolve_bools(range.map(|i| bools.value_at(i)))
}
None => None,
})
.collect::<BoolArray>()
Expand Down Expand Up @@ -262,3 +269,77 @@ impl Build for SomeAllExpression {
))
}
}

#[cfg(test)]
mod tests {
use risingwave_common::array::DataChunk;
use risingwave_common::row::Row;
use risingwave_common::test_prelude::DataChunkTestExt;
use risingwave_common::types::ToOwnedDatum;
use risingwave_common::util::iter_util::ZipEqDebug;
use risingwave_expr::expr::build_from_pretty;

use super::*;

#[tokio::test]
async fn test_some() {
let expr = SomeAllExpression::new(
build_from_pretty("0:int4"),
build_from_pretty("$0:boolean"),
Type::Some,
build_from_pretty("$1:boolean"),
);
let (input, expected) = DataChunk::from_pretty(
"B[] B
. .
{} f
{NULL} .
{NULL,f} .
{NULL,t} t
{t,f} t
{f,t} t", // <- regression test for #14214
)
.split_column_at(1);

// test eval
let output = expr.eval(&input).await.unwrap();
assert_eq!(&output, expected.column_at(0));

// test eval_row
for (row, expected) in input.rows().zip_eq_debug(expected.rows()) {
let result = expr.eval_row(&row.to_owned_row()).await.unwrap();
assert_eq!(result, expected.datum_at(0).to_owned_datum());
}
}

#[tokio::test]
async fn test_all() {
let expr = SomeAllExpression::new(
build_from_pretty("0:int4"),
build_from_pretty("$0:boolean"),
Type::All,
build_from_pretty("$1:boolean"),
);
let (input, expected) = DataChunk::from_pretty(
"B[] B
. .
{} t
{NULL} .
{NULL,t} .
{NULL,f} f
{f,f} f
{t} t", // <- regression test for #14214
)
.split_column_at(1);

// test eval
let output = expr.eval(&input).await.unwrap();
assert_eq!(&output, expected.column_at(0));

// test eval_row
for (row, expected) in input.rows().zip_eq_debug(expected.rows()) {
let result = expr.eval_row(&row.to_owned_row()).await.unwrap();
assert_eq!(result, expected.datum_at(0).to_owned_datum());
}
}
}

0 comments on commit 4260598

Please sign in to comment.