Skip to content

Commit

Permalink
HashJoin partial batch emitting
Browse files Browse the repository at this point in the history
  • Loading branch information
korowa committed Nov 1, 2023
1 parent d24228a commit 3d5e612
Show file tree
Hide file tree
Showing 4 changed files with 219 additions and 117 deletions.
285 changes: 190 additions & 95 deletions datafusion/physical-plan/src/joins/hash_join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,9 @@ use futures::{ready, Stream, StreamExt, TryStreamExt};

type JoinLeftData = (JoinHashMap, RecordBatch, MemoryReservation);

/// Tuple representing last matched probe-build side indices for partial join output
type JoinStreamState = (usize, usize);

/// Join execution plan executes partitions in parallel and combines them into a set of
/// partitions.
///
Expand Down Expand Up @@ -465,6 +468,8 @@ impl ExecutionPlan for HashJoinExec {
}
};

let batch_size = context.session_config().batch_size();

let reservation = MemoryConsumer::new(format!("HashJoinStream[{partition}]"))
.register(context.memory_pool());

Expand All @@ -487,6 +492,9 @@ impl ExecutionPlan for HashJoinExec {
null_equals_null: self.null_equals_null,
is_exhausted: false,
reservation,
batch_size,
join_stream_state: None,
probe_batch: None,
}))
}

Expand Down Expand Up @@ -682,6 +690,12 @@ struct HashJoinStream {
null_equals_null: bool,
/// Memory reservation
reservation: MemoryReservation,
/// Batch size
batch_size: usize,
/// Probe side index
join_stream_state: Option<JoinStreamState>,
/// Probe batch
probe_batch: Option<RecordBatch>,
}

impl RecordBatchStream for HashJoinStream {
Expand Down Expand Up @@ -734,7 +748,9 @@ pub fn build_equal_condition_join_indices<T: JoinHashMapType>(
filter: Option<&JoinFilter>,
build_side: JoinSide,
deleted_offset: Option<usize>,
) -> Result<(UInt64Array, UInt32Array)> {
output_limit: usize,
initial_state: Option<JoinStreamState>,
) -> Result<(UInt64Array, UInt32Array, Option<JoinStreamState>)> {
let keys_values = probe_on
.iter()
.map(|c| Ok(c.evaluate(probe_batch)?.into_array(probe_batch.num_rows())))
Expand Down Expand Up @@ -783,16 +799,41 @@ pub fn build_equal_condition_join_indices<T: JoinHashMapType>(
// With this approach, the lexicographic order on both the probe side and the build side is preserved.
let hash_map = build_hashmap.get_map();
let next_chain = build_hashmap.get_list();
for (row, hash_value) in hash_values.iter().enumerate().rev() {

let mut output_tuples = 0 as usize;
let mut last_join_stream_state = None;

let (initial_probe, initial_build) = if let Some(state) = initial_state {
(state.0, state.1)
} else {
(0, 0)
};

'probe: for (row, hash_value) in hash_values.iter().enumerate().skip(initial_probe)
{
let index = if initial_state.is_some() && row == initial_probe {
let next = next_chain[initial_build as usize];
if next == 0 {
continue;
}
Some(next)
} else if let Some((_, index)) = hash_map.get(*hash_value, |(hash, _)| *hash_value == *hash) {
Some(*index)
} else {
None
};
// Get the hash and find it in the build index

// For every item on the build and probe we check if it matches
// This possibly contains rows with hash collisions,
// So we have to check here whether rows are equal or not
if let Some((_, index)) =
hash_map.get(*hash_value, |(hash, _)| *hash_value == *hash)
if let Some(index) = index
{
let mut i = *index - 1;
if index == 0 {
continue;
}
let mut i = index - 1;

loop {
let build_row_value = if let Some(offset) = deleted_offset {
// This arguments means that we prune the next index way before here.
Expand All @@ -806,8 +847,17 @@ pub fn build_equal_condition_join_indices<T: JoinHashMapType>(
};
build_indices.append(build_row_value);
probe_indices.append(row as u32);

output_tuples += 1;

if output_tuples >= output_limit {
last_join_stream_state = Some((row, i as usize));
break 'probe;
}

// Follow the chain to get the next index value
let next = next_chain[build_row_value as usize];

if next == 0 {
// end of list
break;
Expand All @@ -816,9 +866,18 @@ pub fn build_equal_condition_join_indices<T: JoinHashMapType>(
}
}
}
// Reversing both sets of indices
build_indices.as_slice_mut().reverse();
probe_indices.as_slice_mut().reverse();

// finalizing state -- if probe_batch scanned & no more records left on build-side -- return None
last_join_stream_state = match last_join_stream_state {
Some(state) => {
if state.0 == probe_batch.num_rows() - 1 && state.1 == 0 {
None
} else {
last_join_stream_state
}
}
_ => last_join_stream_state,
};

let left: UInt64Array = PrimitiveArray::new(build_indices.finish().into(), None);
let right: UInt32Array = PrimitiveArray::new(probe_indices.finish().into(), None);
Expand All @@ -837,13 +896,15 @@ pub fn build_equal_condition_join_indices<T: JoinHashMapType>(
(left, right)
};

equal_rows_arr(
let matched_indices = equal_rows_arr(
&left,
&right,
&build_join_values,
&keys_values,
null_equals_null,
)
)?;

return Ok((matched_indices.0, matched_indices.1, last_join_stream_state));
}

// version of eq_dyn supporting equality on null arrays
Expand Down Expand Up @@ -942,107 +1003,139 @@ impl HashJoinStream {
}
});
let mut hashes_buffer = vec![];
self.right
.poll_next_unpin(cx)
.map(|maybe_batch| match maybe_batch {
// one right batch in the join loop
Some(Ok(batch)) => {
self.join_metrics.input_batches.add(1);
self.join_metrics.input_rows.add(batch.num_rows());
let timer = self.join_metrics.join_time.timer();

// get the matched two indices for the on condition
let left_right_indices = build_equal_condition_join_indices(
&left_data.0,
&left_data.1,
&batch,
&self.on_left,
&self.on_right,
&self.random_state,
self.null_equals_null,
&mut hashes_buffer,
self.filter.as_ref(),
JoinSide::Left,
None,
);

let result = match left_right_indices {
Ok((left_side, right_side)) => {
// set the left bitmap
// and only left, full, left semi, left anti need the left bitmap
if need_produce_result_in_final(self.join_type) {
left_side.iter().flatten().for_each(|x| {
visited_left_side.set_bit(x as usize, true);
});
}

// adjust the two side indices base on the join type
let (left_side, right_side) = adjust_indices_by_join_type(
left_side,
right_side,
batch.num_rows(),
self.join_type,
);

let result = build_batch_from_indices(
&self.schema,
&left_data.1,
&batch,
&left_side,
&right_side,
&self.column_indices,
JoinSide::Left,
);
self.join_metrics.output_batches.add(1);
self.join_metrics.output_rows.add(batch.num_rows());
Some(result)
}
Err(err) => Some(exec_err!(
"Fail to build join indices in HashJoinExec, error:{err}"
)),
};
timer.done();
result
// Fetch next probe batch
if self.probe_batch.is_none() {
match self.right.poll_next_unpin(cx) {
Poll::Ready(Some(Ok(batch))) => {
self.probe_batch = Some(batch);
}
Poll::Ready(None) => {
self.probe_batch = None;
}
None => {
let timer = self.join_metrics.join_time.timer();
if need_produce_result_in_final(self.join_type) && !self.is_exhausted
{
// use the global left bitmap to produce the left indices and right indices
let (left_side, right_side) = get_final_indices_from_bit_map(
visited_left_side,
Poll::Ready(Some(err)) => return Poll::Ready(Some(err)),
Poll::Pending => return Poll::Pending,
}
}

let output_batch = match &self.probe_batch {
// one right batch in the join loop
Some(batch) => {
self.join_metrics.input_batches.add(1);
self.join_metrics.input_rows.add(batch.num_rows());
let timer = self.join_metrics.join_time.timer();

// get the matched two indices for the on condition
let left_right_indices = build_equal_condition_join_indices(
&left_data.0,
&left_data.1,
batch,
&self.on_left,
&self.on_right,
&self.random_state,
self.null_equals_null,
&mut hashes_buffer,
self.filter.as_ref(),
JoinSide::Left,
None,
self.batch_size,
self.join_stream_state,
);

let result = match left_right_indices {
Ok((left_side, right_side, next_state)) => {
// println!("left indices: {:?}", left_side);
// println!("right indices: {:?}", right_side);
// println!("next state: {:?}", next_state);
// println!("curr state: {:?}", self.join_stream_state);
// set the left bitmap
// and only left, full, left semi, left anti need the left bitmap
if need_produce_result_in_final(self.join_type) {
left_side.iter().flatten().for_each(|x| {
visited_left_side.set_bit(x as usize, true);
});
}

// adjust the two side indices base on the join type
// as next_state (if some) always contains last joined probe tuple
// due to this fact if join_stream_state is some - it also contains matched once index - no need to adjust it
let adjust_range = match (self.join_stream_state, next_state) {
(None, None) => 0..batch.num_rows(),
(None, Some((range_end, _))) => 0..range_end+1,
(Some((range_start, _)), None) => range_start+1..batch.num_rows(),
(Some((range_start, _)), Some((range_end, _))) => range_start+1..range_end+1,
};

let (left_side, right_side) = adjust_indices_by_join_type(
left_side,
right_side,
adjust_range,
self.join_type,
);
let empty_right_batch =
RecordBatch::new_empty(self.right.schema());
// use the left and right indices to produce the batch result

let result = build_batch_from_indices(
&self.schema,
&left_data.1,
&empty_right_batch,
batch,
&left_side,
&right_side,
&self.column_indices,
JoinSide::Left,
);
self.join_metrics.output_batches.add(1);
self.join_metrics.output_rows.add(batch.num_rows());

if let Ok(ref batch) = result {
self.join_metrics.input_batches.add(1);
self.join_metrics.input_rows.add(batch.num_rows());
if next_state.is_none() {
self.probe_batch = None;
};
self.join_stream_state = next_state;

self.join_metrics.output_batches.add(1);
self.join_metrics.output_rows.add(batch.num_rows());
}
timer.done();
self.is_exhausted = true;
Some(result)
} else {
// end of the join loop
None
}
Err(err) => Some(exec_err!(
"Fail to build join indices in HashJoinExec, error:{err}"
)),
};

timer.done();
result
}
None => {
let timer = self.join_metrics.join_time.timer();
if need_produce_result_in_final(self.join_type) && !self.is_exhausted {
// use the global left bitmap to produce the left indices and right indices
let (left_side, right_side) =
get_final_indices_from_bit_map(visited_left_side, self.join_type);
let empty_right_batch = RecordBatch::new_empty(self.right.schema());
// use the left and right indices to produce the batch result
let result = build_batch_from_indices(
&self.schema,
&left_data.1,
&empty_right_batch,
&left_side,
&right_side,
&self.column_indices,
JoinSide::Left,
);

if let Ok(ref batch) = result {
self.join_metrics.input_batches.add(1);
self.join_metrics.input_rows.add(batch.num_rows());

self.join_metrics.output_batches.add(1);
self.join_metrics.output_rows.add(batch.num_rows());
}
timer.done();
self.is_exhausted = true;
Some(result)
} else {
// end of the join loop
None
}
Some(err) => Some(err),
})
}
};

Poll::Ready(output_batch)
}
}

Expand Down Expand Up @@ -2406,7 +2499,7 @@ mod tests {
},
left,
);
let (l, r) = build_equal_condition_join_indices(
let (l, r, _) = build_equal_condition_join_indices(
&left_data.0,
&left_data.1,
&right,
Expand All @@ -2418,6 +2511,8 @@ mod tests {
None,
JoinSide::Left,
None,
64,
None,
)?;

let mut left_ids = UInt64Builder::with_capacity(0);
Expand Down
Loading

0 comments on commit 3d5e612

Please sign in to comment.