diff --git a/datafusion/core/tests/fuzz_cases/join_fuzz.rs b/datafusion/core/tests/fuzz_cases/join_fuzz.rs index 1c2d8ece2f360..70437b9b10f6b 100644 --- a/datafusion/core/tests/fuzz_cases/join_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/join_fuzz.rs @@ -27,6 +27,7 @@ use datafusion_common::ScalarValue; use datafusion_physical_expr::expressions::Literal; use datafusion_physical_expr::PhysicalExprRef; +use itertools::Itertools; use rand::Rng; use datafusion::common::JoinSide; @@ -225,15 +226,13 @@ async fn test_semi_join_1k() { #[tokio::test] async fn test_semi_join_1k_filtered() { - // NLJ vs HJ gives wrong result - // Tracked in https://github.com/apache/datafusion/issues/11537 JoinFuzzTestCase::new( make_staggered_batches(1000), make_staggered_batches(1000), JoinType::LeftSemi, Some(Box::new(col_lt_col_filter)), ) - .run_test(&[JoinTestType::HjSmj], false) + .run_test(&[JoinTestType::HjSmj, JoinTestType::NljHj], false) .await } @@ -292,27 +291,6 @@ impl JoinFuzzTestCase { } } - fn column_indices(&self) -> Vec { - vec![ - ColumnIndex { - index: 0, - side: JoinSide::Left, - }, - ColumnIndex { - index: 1, - side: JoinSide::Left, - }, - ColumnIndex { - index: 0, - side: JoinSide::Right, - }, - ColumnIndex { - index: 1, - side: JoinSide::Right, - }, - ] - } - fn on_columns(&self) -> Vec<(PhysicalExprRef, PhysicalExprRef)> { let schema1 = self.input1[0].schema(); let schema2 = self.input2[0].schema(); @@ -328,10 +306,20 @@ impl JoinFuzzTestCase { ] } + /// Helper function for building NLJoin filter, returning intermediate + /// schema as a union of origin filter intermediate schema and + /// on-condition schema fn intermediate_schema(&self) -> Schema { + let filter_schema = if let Some(filter) = self.join_filter() { + filter.schema().to_owned() + } else { + Schema::empty() + }; + let schema1 = self.input1[0].schema(); let schema2 = self.input2[0].schema(); - Schema::new(vec![ + + let on_schema = Schema::new(vec![ schema1 .field_with_name("a") .unwrap() @@ -344,7 +332,81 @@ impl JoinFuzzTestCase { .with_nullable(true), schema2.field_with_name("a").unwrap().to_owned(), schema2.field_with_name("b").unwrap().to_owned(), - ]) + ]); + + Schema::new( + filter_schema + .fields + .into_iter() + .cloned() + .chain(on_schema.fields.into_iter().cloned()) + .collect_vec(), + ) + } + + /// Helper function for building NLJoin filter, returns the union + /// of original filter expression and on-condition expression + fn composite_filter_expression(&self) -> PhysicalExprRef { + let (filter_expression, column_idx_offset) = + if let Some(filter) = self.join_filter() { + ( + filter.expression().to_owned(), + filter.schema().fields().len(), + ) + } else { + (Arc::new(Literal::new(ScalarValue::from(true))) as _, 0) + }; + + let equal_a = Arc::new(BinaryExpr::new( + Arc::new(Column::new("a", column_idx_offset)), + Operator::Eq, + Arc::new(Column::new("a", column_idx_offset + 2)), + )); + let equal_b = Arc::new(BinaryExpr::new( + Arc::new(Column::new("b", column_idx_offset + 1)), + Operator::Eq, + Arc::new(Column::new("b", column_idx_offset + 3)), + )); + let on_expression = Arc::new(BinaryExpr::new(equal_a, Operator::And, equal_b)); + + Arc::new(BinaryExpr::new( + filter_expression, + Operator::And, + on_expression, + )) + } + + /// Helper function for building NLJoin filter, returning the union + /// of original filter column indices and on-condition column indices. + /// Result must match intermediate schema. + fn column_indices(&self) -> Vec { + let mut column_indices = if let Some(filter) = self.join_filter() { + filter.column_indices().to_vec() + } else { + vec![] + }; + + let on_column_indices = vec![ + ColumnIndex { + index: 0, + side: JoinSide::Left, + }, + ColumnIndex { + index: 1, + side: JoinSide::Left, + }, + ColumnIndex { + index: 0, + side: JoinSide::Right, + }, + ColumnIndex { + index: 1, + side: JoinSide::Right, + }, + ]; + + column_indices.extend(on_column_indices); + column_indices } fn left_right(&self) -> (Arc, Arc) { @@ -400,26 +462,15 @@ impl JoinFuzzTestCase { fn nested_loop_join(&self) -> Arc { let (left, right) = self.left_right(); - // Nested loop join uses filter for joining records + let column_indices = self.column_indices(); let intermediate_schema = self.intermediate_schema(); + let expression = self.composite_filter_expression(); - let equal_a = Arc::new(BinaryExpr::new( - Arc::new(Column::new("a", 0)), - Operator::Eq, - Arc::new(Column::new("a", 2)), - )) as _; - let equal_b = Arc::new(BinaryExpr::new( - Arc::new(Column::new("b", 1)), - Operator::Eq, - Arc::new(Column::new("b", 3)), - )) as _; - let expression = Arc::new(BinaryExpr::new(equal_a, Operator::And, equal_b)) as _; - - let on_filter = JoinFilter::new(expression, column_indices, intermediate_schema); + let filter = JoinFilter::new(expression, column_indices, intermediate_schema); Arc::new( - NestedLoopJoinExec::try_new(left, right, Some(on_filter), &self.join_type) + NestedLoopJoinExec::try_new(left, right, Some(filter), &self.join_type) .unwrap(), ) }