From 4695ad1239b1c160228dd7bf6f473634f57c9834 Mon Sep 17 00:00:00 2001 From: Runji Wang Date: Wed, 27 Dec 2023 17:57:22 +0800 Subject: [PATCH] fix(expr): fix `SOME/ALL/ANY` expression (#14221) Signed-off-by: Runji Wang --- src/common/src/array/data_chunk.rs | 2 - src/common/src/types/mod.rs | 10 ++- src/expr/core/src/expr/expr_some_all.rs | 85 ++++++++++++++++++++++++- 3 files changed, 91 insertions(+), 6 deletions(-) diff --git a/src/common/src/array/data_chunk.rs b/src/common/src/array/data_chunk.rs index 90c2560cadcb2..a1ef272ec0eca 100644 --- a/src/common/src/array/data_chunk.rs +++ b/src/common/src/array/data_chunk.rs @@ -760,8 +760,6 @@ impl DataChunkTestExt for DataChunk { { let datum = match val_str { "." => None, - "t" => Some(true.into()), - "f" => Some(false.into()), "(empty)" => Some("".into()), _ => Some(ScalarImpl::from_text(val_str.as_bytes(), ty).unwrap()), }; diff --git a/src/common/src/types/mod.rs b/src/common/src/types/mod.rs index 03b477fa09e57..d8bcae757d530 100644 --- a/src/common/src/types/mod.rs +++ b/src/common/src/types/mod.rs @@ -874,7 +874,7 @@ impl ScalarImpl { let res = match data_type { DataType::Varchar => Self::Utf8(str.to_string().into()), DataType::Boolean => { - Self::Bool(bool::from_str(str).map_err(|_| FromSqlError::from_text(str))?) + Self::Bool(str_to_bool(str).map_err(|_| FromSqlError::from_text(str))?) } DataType::Int16 => { Self::Int16(i16::from_str(str).map_err(|_| FromSqlError::from_text(str))?) @@ -931,7 +931,13 @@ impl ScalarImpl { } let mut builder = elem_type.create_array_builder(0); for s in str[1..str.len() - 1].split(',') { - builder.append(Some(Self::from_text(s.trim().as_bytes(), elem_type)?)); + if s.is_empty() { + continue; + } else if s.eq_ignore_ascii_case("null") { + builder.append_null(); + } else { + builder.append(Some(Self::from_text(s.trim().as_bytes(), elem_type)?)); + } } Self::List(ListValue::new(builder.finish())) } diff --git a/src/expr/core/src/expr/expr_some_all.rs b/src/expr/core/src/expr/expr_some_all.rs index 2cc84a0be79fd..8b7bd4c9667d7 100644 --- a/src/expr/core/src/expr/expr_some_all.rs +++ b/src/expr/core/src/expr/expr_some_all.rs @@ -49,6 +49,8 @@ impl SomeAllExpression { } } + // Notice that this function may not exhaust the iterator, + // so never pass an iterator created `by_ref`. fn resolve_bools(&self, bools: impl Iterator>) -> Option { match self.expr_type { Type::Some => { @@ -160,12 +162,17 @@ impl Expression for SomeAllExpression { ); let func_results = self.func.eval(&data_chunk).await?; - let mut func_results_iter = func_results.as_bool().iter(); + let bools = func_results.as_bool(); + let mut offset = 0; Ok(Arc::new( num_array .into_iter() .map(|num| match num { - Some(num) => self.resolve_bools(func_results_iter.by_ref().take(num)), + Some(num) => { + let range = offset..offset + num; + offset += num; + self.resolve_bools(range.map(|i| bools.value_at(i))) + } None => None, }) .collect::() @@ -262,3 +269,77 @@ impl Build for SomeAllExpression { )) } } + +#[cfg(test)] +mod tests { + use risingwave_common::array::DataChunk; + use risingwave_common::row::Row; + use risingwave_common::test_prelude::DataChunkTestExt; + use risingwave_common::types::ToOwnedDatum; + use risingwave_common::util::iter_util::ZipEqDebug; + use risingwave_expr::expr::build_from_pretty; + + use super::*; + + #[tokio::test] + async fn test_some() { + let expr = SomeAllExpression::new( + build_from_pretty("0:int4"), + build_from_pretty("$0:boolean"), + Type::Some, + build_from_pretty("$1:boolean"), + ); + let (input, expected) = DataChunk::from_pretty( + "B[] B + . . + {} f + {NULL} . + {NULL,f} . + {NULL,t} t + {t,f} t + {f,t} t", // <- regression test for #14214 + ) + .split_column_at(1); + + // test eval + let output = expr.eval(&input).await.unwrap(); + assert_eq!(&output, expected.column_at(0)); + + // test eval_row + for (row, expected) in input.rows().zip_eq_debug(expected.rows()) { + let result = expr.eval_row(&row.to_owned_row()).await.unwrap(); + assert_eq!(result, expected.datum_at(0).to_owned_datum()); + } + } + + #[tokio::test] + async fn test_all() { + let expr = SomeAllExpression::new( + build_from_pretty("0:int4"), + build_from_pretty("$0:boolean"), + Type::All, + build_from_pretty("$1:boolean"), + ); + let (input, expected) = DataChunk::from_pretty( + "B[] B + . . + {} t + {NULL} . + {NULL,t} . + {NULL,f} f + {f,f} f + {t} t", // <- regression test for #14214 + ) + .split_column_at(1); + + // test eval + let output = expr.eval(&input).await.unwrap(); + assert_eq!(&output, expected.column_at(0)); + + // test eval_row + for (row, expected) in input.rows().zip_eq_debug(expected.rows()) { + let result = expr.eval_row(&row.to_owned_row()).await.unwrap(); + assert_eq!(result, expected.datum_at(0).to_owned_datum()); + } + } +}