From a487a0e132696e471f1a492757d67db036ae04e1 Mon Sep 17 00:00:00 2001 From: Eduard Karacharov Date: Thu, 2 Nov 2023 11:28:58 +0300 Subject: [PATCH] batch splitting tests --- .../physical-plan/src/joins/hash_join.rs | 140 +++++++++++++++++- .../src/joins/nested_loop_join.rs | 6 +- datafusion/physical-plan/src/joins/utils.rs | 60 ++++---- 3 files changed, 174 insertions(+), 32 deletions(-) diff --git a/datafusion/physical-plan/src/joins/hash_join.rs b/datafusion/physical-plan/src/joins/hash_join.rs index a20e643b2e4a3..265315d2555e7 100644 --- a/datafusion/physical-plan/src/joins/hash_join.rs +++ b/datafusion/physical-plan/src/joins/hash_join.rs @@ -1266,7 +1266,9 @@ mod tests { use arrow::array::{ArrayRef, Date32Array, Int32Array, UInt32Builder, UInt64Builder}; use arrow::datatypes::{DataType, Field, Schema}; - use datafusion_common::{assert_batches_sorted_eq, assert_contains, ScalarValue}; + use datafusion_common::{ + assert_batches_eq, assert_batches_sorted_eq, assert_contains, ScalarValue, + }; use datafusion_expr::Operator; use datafusion_physical_expr::expressions::Literal; use hashbrown::raw::RawTable; @@ -2973,6 +2975,142 @@ mod tests { } } + #[tokio::test] + async fn join_splitted_batch() { + let left = build_table( + ("a1", &vec![1, 2, 3, 4]), + ("b1", &vec![1, 1, 1, 1]), + ("c1", &vec![0, 0, 0, 0]), + ); + let right = build_table( + ("a2", &vec![10, 20, 30, 40, 50]), + ("b2", &vec![1, 1, 1, 1, 1]), + ("c2", &vec![0, 0, 0, 0, 0]), + ); + let on = vec![( + Column::new_with_schema("b1", &left.schema()).unwrap(), + Column::new_with_schema("b2", &right.schema()).unwrap(), + )]; + + let join_types = vec![ + JoinType::Inner, + JoinType::Left, + JoinType::Right, + JoinType::Full, + JoinType::RightSemi, + JoinType::RightAnti, + JoinType::LeftSemi, + JoinType::LeftAnti, + ]; + let expected_resultset_records = 20; + let common_result = [ + "+----+----+----+----+----+----+", + "| a1 | b1 | c1 | a2 | b2 | c2 |", + "+----+----+----+----+----+----+", + "| 4 | 1 | 0 | 10 
| 1 | 0 |", + "| 3 | 1 | 0 | 10 | 1 | 0 |", + "| 2 | 1 | 0 | 10 | 1 | 0 |", + "| 1 | 1 | 0 | 10 | 1 | 0 |", + "| 4 | 1 | 0 | 20 | 1 | 0 |", + "| 3 | 1 | 0 | 20 | 1 | 0 |", + "| 2 | 1 | 0 | 20 | 1 | 0 |", + "| 1 | 1 | 0 | 20 | 1 | 0 |", + "| 4 | 1 | 0 | 30 | 1 | 0 |", + "| 3 | 1 | 0 | 30 | 1 | 0 |", + "| 2 | 1 | 0 | 30 | 1 | 0 |", + "| 1 | 1 | 0 | 30 | 1 | 0 |", + "| 4 | 1 | 0 | 40 | 1 | 0 |", + "| 3 | 1 | 0 | 40 | 1 | 0 |", + "| 2 | 1 | 0 | 40 | 1 | 0 |", + "| 1 | 1 | 0 | 40 | 1 | 0 |", + "| 4 | 1 | 0 | 50 | 1 | 0 |", + "| 3 | 1 | 0 | 50 | 1 | 0 |", + "| 2 | 1 | 0 | 50 | 1 | 0 |", + "| 1 | 1 | 0 | 50 | 1 | 0 |", + "+----+----+----+----+----+----+", + ]; + let left_batch = [ + "+----+----+----+", + "| a1 | b1 | c1 |", + "+----+----+----+", + "| 1 | 1 | 0 |", + "| 2 | 1 | 0 |", + "| 3 | 1 | 0 |", + "| 4 | 1 | 0 |", + "+----+----+----+", + ]; + let right_batch = [ + "+----+----+----+", + "| a2 | b2 | c2 |", + "+----+----+----+", + "| 10 | 1 | 0 |", + "| 20 | 1 | 0 |", + "| 30 | 1 | 0 |", + "| 40 | 1 | 0 |", + "| 50 | 1 | 0 |", + "+----+----+----+", + ]; + let right_empty = [ + "+----+----+----+", + "| a2 | b2 | c2 |", + "+----+----+----+", + "+----+----+----+", + ]; + let left_empty = [ + "+----+----+----+", + "| a1 | b1 | c1 |", + "+----+----+----+", + "+----+----+----+", + ]; + + // validation of partial join results output for different batch_size setting + for join_type in join_types { + for batch_size in (1..21).rev() { + let session_config = SessionConfig::default().with_batch_size(batch_size); + let task_ctx = TaskContext::default().with_session_config(session_config); + let task_ctx = Arc::new(task_ctx); + + let join = + join(left.clone(), right.clone(), on.clone(), &join_type, false) + .unwrap(); + + let stream = join.execute(0, task_ctx).unwrap(); + let batches = common::collect(stream).await.unwrap(); + + // For inner/right join expected batch count equals ceil_div result, + // as there is no need to append non-joined build side data. 
+ // For other join types it'll be ceil_div + 1 -- for additional batch + // containing not visited build side rows (empty in this test case). + let expected_batch_count = match join_type { + JoinType::Inner + | JoinType::Right + | JoinType::RightSemi + | JoinType::RightAnti => { + (expected_resultset_records + batch_size - 1) / batch_size + } + _ => (expected_resultset_records + batch_size - 1) / batch_size + 1, + }; + assert_eq!( + batches.len(), + expected_batch_count, + "expected {} output batches for {} join with batch_size = {}", + expected_batch_count, + join_type, + batch_size + ); + + let expected = match join_type { + JoinType::RightSemi => right_batch.to_vec(), + JoinType::RightAnti => right_empty.to_vec(), + JoinType::LeftSemi => left_batch.to_vec(), + JoinType::LeftAnti => left_empty.to_vec(), + _ => common_result.to_vec(), + }; + assert_batches_eq!(expected, &batches); + } + } + } + #[tokio::test] async fn single_partition_join_overallocation() -> Result<()> { let left = build_table( diff --git a/datafusion/physical-plan/src/joins/nested_loop_join.rs b/datafusion/physical-plan/src/joins/nested_loop_join.rs index 73fd5c1caec77..e629ab59278a6 100644 --- a/datafusion/physical-plan/src/joins/nested_loop_join.rs +++ b/datafusion/physical-plan/src/joins/nested_loop_join.rs @@ -648,20 +648,20 @@ fn adjust_indices_by_join_type( // matched // unmatched left row will be produced in this batch let left_unmatched_indices = - get_anti_u64_indices(count_left_batch, &left_indices); + get_anti_u64_indices(0..count_left_batch, &left_indices); // combine the matched and unmatched left result together append_left_indices(left_indices, right_indices, left_unmatched_indices) } JoinType::LeftSemi => { // need to remove the duplicated record in the left side - let left_indices = get_semi_u64_indices(count_left_batch, &left_indices); + let left_indices = get_semi_u64_indices(0..count_left_batch, &left_indices); // the right_indices will not be used later for the `left semi` 
join (left_indices, right_indices) } JoinType::LeftAnti => { // need to remove the duplicated record in the left side // get the anti index for the left side - let left_indices = get_anti_u64_indices(count_left_batch, &left_indices); + let left_indices = get_anti_u64_indices(0..count_left_batch, &left_indices); // the right_indices will not be used later for the `left anti` join (left_indices, right_indices) } diff --git a/datafusion/physical-plan/src/joins/utils.rs b/datafusion/physical-plan/src/joins/utils.rs index 01d43c2b51b10..8edeae6fef801 100644 --- a/datafusion/physical-plan/src/joins/utils.rs +++ b/datafusion/physical-plan/src/joins/utils.rs @@ -921,81 +921,85 @@ pub(crate) fn append_right_indices( /// Get unmatched and deduplicated indices for specified range of indices pub(crate) fn get_anti_indices( - rg: Range<usize>, + range: Range<usize>, input_indices: &UInt32Array, ) -> UInt32Array { - let mut bitmap = BooleanBufferBuilder::new(rg.len()); - bitmap.append_n(rg.len(), false); + let mut bitmap = BooleanBufferBuilder::new(range.len()); + bitmap.append_n(range.len(), false); input_indices .iter() .flatten() .map(|v| v as usize) - .filter(|v| rg.contains(v)) + .filter(|v| range.contains(v)) .for_each(|v| { - bitmap.set_bit(v - rg.start, true); + bitmap.set_bit(v - range.start, true); }); - let offset = rg.start; + let offset = range.start; // get the anti index - (rg).filter_map(|idx| (!bitmap.get_bit(idx - offset)).then_some(idx as u32)) + (range).filter_map(|idx| (!bitmap.get_bit(idx - offset)).then_some(idx as u32)) .collect::<UInt32Array>() } /// Get unmatched and deduplicated indices pub(crate) fn get_anti_u64_indices( - row_count: usize, + range: Range<usize>, input_indices: &UInt64Array, ) -> UInt64Array { - let mut bitmap = BooleanBufferBuilder::new(row_count); - bitmap.append_n(row_count, false); - input_indices.iter().flatten().for_each(|v| { - bitmap.set_bit(v as usize, true); + let mut bitmap = BooleanBufferBuilder::new(range.len()); + bitmap.append_n(range.len(), false); + 
input_indices.iter().flatten().map(|v| v as usize).filter(|v| range.contains(v)).for_each(|v| { + bitmap.set_bit(v - range.start, true); }); + let offset = range.start; + // get the anti index - (0..row_count) - .filter_map(|idx| (!bitmap.get_bit(idx)).then_some(idx as u64)) + (range) + .filter_map(|idx| (!bitmap.get_bit(idx - offset)).then_some(idx as u64)) .collect::<UInt64Array>() } /// Get matched and deduplicated indices for specified range of indices pub(crate) fn get_semi_indices( - rg: Range<usize>, + range: Range<usize>, input_indices: &UInt32Array, ) -> UInt32Array { - let mut bitmap = BooleanBufferBuilder::new(rg.len()); - bitmap.append_n(rg.len(), false); + let mut bitmap = BooleanBufferBuilder::new(range.len()); + bitmap.append_n(range.len(), false); input_indices .iter() .flatten() .map(|v| v as usize) - .filter(|v| rg.contains(v)) + .filter(|v| range.contains(v)) .for_each(|v| { - bitmap.set_bit(v - rg.start, true); + bitmap.set_bit(v - range.start, true); }); - let offset = rg.start; + let offset = range.start; // get the semi index - (rg).filter_map(|idx| (bitmap.get_bit(idx - offset)).then_some(idx as u32)) + (range).filter_map(|idx| (bitmap.get_bit(idx - offset)).then_some(idx as u32)) .collect::<UInt32Array>() } /// Get matched and deduplicated indices pub(crate) fn get_semi_u64_indices( - row_count: usize, + range: Range<usize>, input_indices: &UInt64Array, ) -> UInt64Array { - let mut bitmap = BooleanBufferBuilder::new(row_count); - bitmap.append_n(row_count, false); - input_indices.iter().flatten().for_each(|v| { - bitmap.set_bit(v as usize, true); + let mut bitmap = BooleanBufferBuilder::new(range.len()); + bitmap.append_n(range.len(), false); + input_indices.iter().flatten().map(|v| v as usize).filter(|v| range.contains(v)).for_each(|v| { + bitmap.set_bit(v - range.start, true); }); + let offset = range.start; + // get the semi index - (0..row_count) - .filter_map(|idx| (bitmap.get_bit(idx)).then_some(idx as u64)) + (range) + .filter_map(|idx| (bitmap.get_bit(idx - 
offset)).then_some(idx as u64)) .collect::<UInt64Array>() }