Skip to content

Commit

Permalink
HashJoin partial batch emitting
Browse files Browse the repository at this point in the history
  • Loading branch information
korowa committed Oct 31, 2023
1 parent d24228a commit 5d22ed9
Show file tree
Hide file tree
Showing 2 changed files with 172 additions and 93 deletions.
261 changes: 169 additions & 92 deletions datafusion/physical-plan/src/joins/hash_join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,9 @@ use futures::{ready, Stream, StreamExt, TryStreamExt};

/// Data collected for the left side of the join: a [`JoinHashMap`], a
/// [`RecordBatch`] and a [`MemoryReservation`].
///
/// NOTE(review): elements `.0` and `.1` are passed as the build-side hash map
/// and build-side batch when computing join indices; the reservation
/// presumably accounts for the memory held by both — confirm against the
/// code that constructs this tuple.
type JoinLeftData = (JoinHashMap, RecordBatch, MemoryReservation);

/// Resume point for a probe batch whose join output was only partially emitted.
///
/// The tuple is `(probe_row, next_build_index)`:
/// * `probe_row` — index of the probe-side row that was being matched when the
///   output limit was reached;
/// * `next_build_index` — the next value read from the build-side chain for
///   that row; `0` means the chain for that row has no further entries.
type JoinStreamState = (usize, usize);

/// Join execution plan executes partitions in parallel and combines them into a set of
/// partitions.
///
Expand Down Expand Up @@ -465,6 +468,8 @@ impl ExecutionPlan for HashJoinExec {
}
};

let batch_size = context.session_config().batch_size();

let reservation = MemoryConsumer::new(format!("HashJoinStream[{partition}]"))
.register(context.memory_pool());

Expand All @@ -487,6 +492,9 @@ impl ExecutionPlan for HashJoinExec {
null_equals_null: self.null_equals_null,
is_exhausted: false,
reservation,
batch_size,
join_stream_state: None,
probe_batch: None,
}))
}

Expand Down Expand Up @@ -682,6 +690,12 @@ struct HashJoinStream {
null_equals_null: bool,
/// Memory reservation
reservation: MemoryReservation,
/// Batch size from the session config, used as the output limit when building join indices
batch_size: usize,
/// Last matched probe/build indices to resume from when a probe batch was only partially emitted
join_stream_state: Option<JoinStreamState>,
/// Probe batch currently being joined; kept across polls until it is fully emitted
probe_batch: Option<RecordBatch>,
}

impl RecordBatchStream for HashJoinStream {
Expand Down Expand Up @@ -734,7 +748,9 @@ pub fn build_equal_condition_join_indices<T: JoinHashMapType>(
filter: Option<&JoinFilter>,
build_side: JoinSide,
deleted_offset: Option<usize>,
) -> Result<(UInt64Array, UInt32Array)> {
output_limit: usize,
initial_state: Option<JoinStreamState>,
) -> Result<(UInt64Array, UInt32Array, Option<JoinStreamState>)> {
let keys_values = probe_on
.iter()
.map(|c| Ok(c.evaluate(probe_batch)?.into_array(probe_batch.num_rows())))
Expand Down Expand Up @@ -783,16 +799,39 @@ pub fn build_equal_condition_join_indices<T: JoinHashMapType>(
// With this approach, the lexicographic order on both the probe side and the build side is preserved.
let hash_map = build_hashmap.get_map();
let next_chain = build_hashmap.get_list();
for (row, hash_value) in hash_values.iter().enumerate().rev() {

let mut output_tuples = 0 as usize;
let mut next_join_stream_state = None;

let (initial_probe, initial_build) = if let Some(state) = initial_state {
(state.0, state.1)
} else {
(probe_batch.num_rows(), 0)
};



'probe: for (row, hash_value) in hash_values.iter().take(initial_probe + 1).enumerate().rev()
{
let index = if initial_state.is_some() && row == initial_probe {
Some(initial_build as u64)
} else if let Some((_, index)) = hash_map.get(*hash_value, |(hash, _)| *hash_value == *hash) {
Some(*index)
} else {
None
};
// Get the hash and find it in the build index

// For every item on the build and probe we check if it matches
// This possibly contains rows with hash collisions,
// So we have to check here whether rows are equal or not
if let Some((_, index)) =
hash_map.get(*hash_value, |(hash, _)| *hash_value == *hash)
if let Some(index) = index
{
let mut i = *index - 1;
if index == 0 {
continue;
}
let mut i = index - 1;

loop {
let build_row_value = if let Some(offset) = deleted_offset {
// This argument means that we prune the next index way before here.
Expand All @@ -806,8 +845,17 @@ pub fn build_equal_condition_join_indices<T: JoinHashMapType>(
};
build_indices.append(build_row_value);
probe_indices.append(row as u32);

output_tuples += 1;

// Follow the chain to get the next index value
let next = next_chain[build_row_value as usize];

if output_tuples >= output_limit {
next_join_stream_state = Some((row, next as usize));
break 'probe;
}

if next == 0 {
// end of list
break;
Expand All @@ -816,6 +864,12 @@ pub fn build_equal_condition_join_indices<T: JoinHashMapType>(
}
}
}

next_join_stream_state = match next_join_stream_state {
Some((0, 0)) => None,
_ => next_join_stream_state,
};

// Reversing both sets of indices
build_indices.as_slice_mut().reverse();
probe_indices.as_slice_mut().reverse();
Expand All @@ -837,13 +891,15 @@ pub fn build_equal_condition_join_indices<T: JoinHashMapType>(
(left, right)
};

equal_rows_arr(
let matched_indices = equal_rows_arr(
&left,
&right,
&build_join_values,
&keys_values,
null_equals_null,
)
)?;

return Ok((matched_indices.0, matched_indices.1, next_join_stream_state));
}

// version of eq_dyn supporting equality on null arrays
Expand Down Expand Up @@ -942,107 +998,126 @@ impl HashJoinStream {
}
});
let mut hashes_buffer = vec![];
self.right
.poll_next_unpin(cx)
.map(|maybe_batch| match maybe_batch {
// one right batch in the join loop
Some(Ok(batch)) => {
self.join_metrics.input_batches.add(1);
self.join_metrics.input_rows.add(batch.num_rows());
let timer = self.join_metrics.join_time.timer();

// get the matched two indices for the on condition
let left_right_indices = build_equal_condition_join_indices(
&left_data.0,
&left_data.1,
&batch,
&self.on_left,
&self.on_right,
&self.random_state,
self.null_equals_null,
&mut hashes_buffer,
self.filter.as_ref(),
JoinSide::Left,
None,
);

let result = match left_right_indices {
Ok((left_side, right_side)) => {
// set the left bitmap
// and only left, full, left semi, left anti need the left bitmap
if need_produce_result_in_final(self.join_type) {
left_side.iter().flatten().for_each(|x| {
visited_left_side.set_bit(x as usize, true);
});
}

// adjust the two side indices based on the join type
let (left_side, right_side) = adjust_indices_by_join_type(
left_side,
right_side,
batch.num_rows(),
self.join_type,
);

let result = build_batch_from_indices(
&self.schema,
&left_data.1,
&batch,
&left_side,
&right_side,
&self.column_indices,
JoinSide::Left,
);
self.join_metrics.output_batches.add(1);
self.join_metrics.output_rows.add(batch.num_rows());
Some(result)
}
Err(err) => Some(exec_err!(
"Fail to build join indices in HashJoinExec, error:{err}"
)),
};
timer.done();
result
// Fetch next probe batch
if self.probe_batch.is_none() {
match self.right.poll_next_unpin(cx) {
Poll::Ready(Some(Ok(batch))) => {
self.probe_batch = Some(batch);
}
None => {
let timer = self.join_metrics.join_time.timer();
if need_produce_result_in_final(self.join_type) && !self.is_exhausted
{
// use the global left bitmap to produce the left indices and right indices
let (left_side, right_side) = get_final_indices_from_bit_map(
visited_left_side,
Poll::Ready(None) => {
self.probe_batch = None;
}
Poll::Ready(Some(err)) => return Poll::Ready(Some(err)),
Poll::Pending => return Poll::Pending,
}
}

let output_batch = match &self.probe_batch {
// one right batch in the join loop
Some(batch) => {
self.join_metrics.input_batches.add(1);
self.join_metrics.input_rows.add(batch.num_rows());
let timer = self.join_metrics.join_time.timer();

// get the matched two indices for the on condition
let left_right_indices = build_equal_condition_join_indices(
&left_data.0,
&left_data.1,
batch,
&self.on_left,
&self.on_right,
&self.random_state,
self.null_equals_null,
&mut hashes_buffer,
self.filter.as_ref(),
JoinSide::Left,
None,
self.batch_size,
self.join_stream_state,
);

let result = match left_right_indices {
Ok((left_side, right_side, next_state)) => {
// set the left bitmap
// and only left, full, left semi, left anti need the left bitmap
if need_produce_result_in_final(self.join_type) {
left_side.iter().flatten().for_each(|x| {
visited_left_side.set_bit(x as usize, true);
});
}

// adjust the two side indices based on the join type
let (left_side, right_side) = adjust_indices_by_join_type(
left_side,
right_side,
batch.num_rows(),
self.join_type,
);
let empty_right_batch =
RecordBatch::new_empty(self.right.schema());
// use the left and right indices to produce the batch result

let result = build_batch_from_indices(
&self.schema,
&left_data.1,
&empty_right_batch,
batch,
&left_side,
&right_side,
&self.column_indices,
JoinSide::Left,
);
self.join_metrics.output_batches.add(1);
self.join_metrics.output_rows.add(batch.num_rows());

if let Ok(ref batch) = result {
self.join_metrics.input_batches.add(1);
self.join_metrics.input_rows.add(batch.num_rows());
if next_state.is_none() {
self.probe_batch = None;
};
self.join_stream_state = next_state;

self.join_metrics.output_batches.add(1);
self.join_metrics.output_rows.add(batch.num_rows());
}
timer.done();
self.is_exhausted = true;
Some(result)
} else {
// end of the join loop
None
}
Err(err) => Some(exec_err!(
"Fail to build join indices in HashJoinExec, error:{err}"
)),
};

timer.done();
result
}
None => {
let timer = self.join_metrics.join_time.timer();
if need_produce_result_in_final(self.join_type) && !self.is_exhausted {
// use the global left bitmap to produce the left indices and right indices
let (left_side, right_side) =
get_final_indices_from_bit_map(visited_left_side, self.join_type);
let empty_right_batch = RecordBatch::new_empty(self.right.schema());
// use the left and right indices to produce the batch result
let result = build_batch_from_indices(
&self.schema,
&left_data.1,
&empty_right_batch,
&left_side,
&right_side,
&self.column_indices,
JoinSide::Left,
);

if let Ok(ref batch) = result {
self.join_metrics.input_batches.add(1);
self.join_metrics.input_rows.add(batch.num_rows());

self.join_metrics.output_batches.add(1);
self.join_metrics.output_rows.add(batch.num_rows());
}
timer.done();
self.is_exhausted = true;
Some(result)
} else {
// end of the join loop
None
}
Some(err) => Some(err),
})
}
};

Poll::Ready(output_batch)
}
}

Expand Down Expand Up @@ -2406,7 +2481,7 @@ mod tests {
},
left,
);
let (l, r) = build_equal_condition_join_indices(
let (l, r, _) = build_equal_condition_join_indices(
&left_data.0,
&left_data.1,
&right,
Expand All @@ -2418,6 +2493,8 @@ mod tests {
None,
JoinSide::Left,
None,
64,
None,
)?;

let mut left_ids = UInt64Builder::with_capacity(0);
Expand Down
4 changes: 3 additions & 1 deletion datafusion/physical-plan/src/joins/symmetric_hash_join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -818,7 +818,7 @@ pub(crate) fn join_with_probe_batch(
if build_hash_joiner.input_buffer.num_rows() == 0 || probe_batch.num_rows() == 0 {
return Ok(None);
}
let (build_indices, probe_indices) = build_equal_condition_join_indices(
let (build_indices, probe_indices, _) = build_equal_condition_join_indices(
&build_hash_joiner.hashmap,
&build_hash_joiner.input_buffer,
probe_batch,
Expand All @@ -830,6 +830,8 @@ pub(crate) fn join_with_probe_batch(
filter,
build_hash_joiner.build_side,
Some(build_hash_joiner.deleted_offset),
usize::MAX,
None,
)?;
if need_to_produce_result_in_final(build_hash_joiner.build_side, join_type) {
record_visited_indices(
Expand Down

0 comments on commit 5d22ed9

Please sign in to comment.