Skip to content

Commit

Permalink
StreamJoinStateResult to StatefulStreamResult
Browse files Browse the repository at this point in the history
  • Loading branch information
korowa committed Dec 12, 2023
1 parent 349b6f5 commit 130c1ff
Show file tree
Hide file tree
Showing 4 changed files with 55 additions and 55 deletions.
20 changes: 10 additions & 10 deletions datafusion/physical-plan/src/joins/hash_join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ use crate::{
joins::utils::{
adjust_right_output_partitioning, build_join_schema, check_join_is_valid,
estimate_join_statistics, partitioned_join_output_partitioning,
BuildProbeJoinMetrics, ColumnIndex, JoinFilter, JoinOn, StreamJoinStateResult,
BuildProbeJoinMetrics, ColumnIndex, JoinFilter, JoinOn, StatefulStreamResult,
},
metrics::{ExecutionPlanMetricsSet, MetricsSet},
DisplayFormatType, Distribution, ExecutionPlan, Partitioning, PhysicalExpr,
Expand Down Expand Up @@ -1185,7 +1185,7 @@ impl HashJoinStream {
fn collect_build_side(
&mut self,
cx: &mut std::task::Context<'_>,
) -> Poll<Result<StreamJoinStateResult<Option<RecordBatch>>>> {
) -> Poll<Result<StatefulStreamResult<Option<RecordBatch>>>> {
let build_timer = self.join_metrics.build_time.timer();
// build hash table from left (build) side, if not yet done
let left_data = ready!(self
Expand Down Expand Up @@ -1225,7 +1225,7 @@ impl HashJoinStream {
visited_left_side,
});

Poll::Ready(Ok(StreamJoinStateResult::Continue))
Poll::Ready(Ok(StatefulStreamResult::Continue))
}

/// Fetches next batch from probe-side
Expand All @@ -1235,7 +1235,7 @@ impl HashJoinStream {
fn fetch_probe_batch(
&mut self,
cx: &mut std::task::Context<'_>,
) -> Poll<Result<StreamJoinStateResult<Option<RecordBatch>>>> {
) -> Poll<Result<StatefulStreamResult<Option<RecordBatch>>>> {
match ready!(self.right.poll_next_unpin(cx)) {
None => {
self.state = HashJoinStreamState::ExhaustedProbeSide;
Expand All @@ -1249,15 +1249,15 @@ impl HashJoinStream {
Some(Err(err)) => return Poll::Ready(Err(err)),
};

Poll::Ready(Ok(StreamJoinStateResult::Continue))
Poll::Ready(Ok(StatefulStreamResult::Continue))
}

/// Joins the current probe batch with build-side data and produces a batch with the matched output
///
/// Updates state to `FetchProbeBatch`
fn process_probe_batch(
&mut self,
) -> Result<StreamJoinStateResult<Option<RecordBatch>>> {
) -> Result<StatefulStreamResult<Option<RecordBatch>>> {
let state = self.state.try_as_process_probe_batch()?;
let build_side = self.build_side_state.try_into_ready_mut()?;

Expand Down Expand Up @@ -1320,21 +1320,21 @@ impl HashJoinStream {

self.state = HashJoinStreamState::FetchProbeBatch;

Ok(StreamJoinStateResult::Ready(Some(result?)))
Ok(StatefulStreamResult::Ready(Some(result?)))
}

/// Processes unmatched build-side rows for certain join types and produces an output batch
///
/// Updates state to `Completed`
fn process_unmatched_build_batch(
&mut self,
) -> Result<StreamJoinStateResult<Option<RecordBatch>>> {
) -> Result<StatefulStreamResult<Option<RecordBatch>>> {
let timer = self.join_metrics.join_time.timer();

if !need_produce_result_in_final(self.join_type) {
self.state = HashJoinStreamState::Completed;

return Ok(StreamJoinStateResult::Continue);
return Ok(StatefulStreamResult::Continue);
}

let build_side = self.build_side_state.try_into_ready()?;
Expand Down Expand Up @@ -1365,7 +1365,7 @@ impl HashJoinStream {

self.state = HashJoinStreamState::Completed;

Ok(StreamJoinStateResult::Ready(Some(result?)))
Ok(StatefulStreamResult::Ready(Some(result?)))
}
}

Expand Down
50 changes: 25 additions & 25 deletions datafusion/physical-plan/src/joins/stream_join_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ use std::sync::Arc;
use std::task::{Context, Poll};
use std::usize;

use crate::joins::utils::{JoinFilter, JoinHashMapType, StreamJoinStateResult};
use crate::joins::utils::{JoinFilter, JoinHashMapType, StatefulStreamResult};
use crate::{handle_async_state, handle_state};

use arrow::compute::concat_batches;
Expand Down Expand Up @@ -751,10 +751,10 @@ pub trait EagerJoinStream {
///
/// # Returns
///
/// * `Result<StreamJoinStateResult<Option<RecordBatch>>>` - The state result after pulling the batch.
/// * `Result<StatefulStreamResult<Option<RecordBatch>>>` - The state result after pulling the batch.
async fn fetch_next_from_right_stream(
&mut self,
) -> Result<StreamJoinStateResult<Option<RecordBatch>>> {
) -> Result<StatefulStreamResult<Option<RecordBatch>>> {
match self.right_stream().next().await {
Some(Ok(batch)) => {
self.set_state(EagerJoinStreamState::PullLeft);
Expand All @@ -763,7 +763,7 @@ pub trait EagerJoinStream {
Some(Err(e)) => Err(e),
None => {
self.set_state(EagerJoinStreamState::RightExhausted);
Ok(StreamJoinStateResult::Continue)
Ok(StatefulStreamResult::Continue)
}
}
}
Expand All @@ -776,10 +776,10 @@ pub trait EagerJoinStream {
///
/// # Returns
///
/// * `Result<StreamJoinStateResult<Option<RecordBatch>>>` - The state result after pulling the batch.
/// * `Result<StatefulStreamResult<Option<RecordBatch>>>` - The state result after pulling the batch.
async fn fetch_next_from_left_stream(
&mut self,
) -> Result<StreamJoinStateResult<Option<RecordBatch>>> {
) -> Result<StatefulStreamResult<Option<RecordBatch>>> {
match self.left_stream().next().await {
Some(Ok(batch)) => {
self.set_state(EagerJoinStreamState::PullRight);
Expand All @@ -788,7 +788,7 @@ pub trait EagerJoinStream {
Some(Err(e)) => Err(e),
None => {
self.set_state(EagerJoinStreamState::LeftExhausted);
Ok(StreamJoinStateResult::Continue)
Ok(StatefulStreamResult::Continue)
}
}
}
Expand All @@ -802,18 +802,18 @@ pub trait EagerJoinStream {
///
/// # Returns
///
/// * `Result<StreamJoinStateResult<Option<RecordBatch>>>` - The state result after checking the exhaustion state.
/// * `Result<StatefulStreamResult<Option<RecordBatch>>>` - The state result after checking the exhaustion state.
async fn handle_right_stream_end(
&mut self,
) -> Result<StreamJoinStateResult<Option<RecordBatch>>> {
) -> Result<StatefulStreamResult<Option<RecordBatch>>> {
match self.left_stream().next().await {
Some(Ok(batch)) => self.process_batch_after_right_end(batch),
Some(Err(e)) => Err(e),
None => {
self.set_state(EagerJoinStreamState::BothExhausted {
final_result: false,
});
Ok(StreamJoinStateResult::Continue)
Ok(StatefulStreamResult::Continue)
}
}
}
Expand All @@ -827,18 +827,18 @@ pub trait EagerJoinStream {
///
/// # Returns
///
/// * `Result<StreamJoinStateResult<Option<RecordBatch>>>` - The state result after checking the exhaustion state.
/// * `Result<StatefulStreamResult<Option<RecordBatch>>>` - The state result after checking the exhaustion state.
async fn handle_left_stream_end(
&mut self,
) -> Result<StreamJoinStateResult<Option<RecordBatch>>> {
) -> Result<StatefulStreamResult<Option<RecordBatch>>> {
match self.right_stream().next().await {
Some(Ok(batch)) => self.process_batch_after_left_end(batch),
Some(Err(e)) => Err(e),
None => {
self.set_state(EagerJoinStreamState::BothExhausted {
final_result: false,
});
Ok(StreamJoinStateResult::Continue)
Ok(StatefulStreamResult::Continue)
}
}
}
Expand All @@ -851,10 +851,10 @@ pub trait EagerJoinStream {
///
/// # Returns
///
/// * `Result<StreamJoinStateResult<Option<RecordBatch>>>` - The state result after both streams are exhausted.
/// * `Result<StatefulStreamResult<Option<RecordBatch>>>` - The state result after both streams are exhausted.
fn prepare_for_final_results_after_exhaustion(
&mut self,
) -> Result<StreamJoinStateResult<Option<RecordBatch>>> {
) -> Result<StatefulStreamResult<Option<RecordBatch>>> {
self.set_state(EagerJoinStreamState::BothExhausted { final_result: true });
self.process_batches_before_finalization()
}
Expand All @@ -867,11 +867,11 @@ pub trait EagerJoinStream {
///
/// # Returns
///
/// * `Result<StreamJoinStateResult<Option<RecordBatch>>>` - The state result after processing the batch.
/// * `Result<StatefulStreamResult<Option<RecordBatch>>>` - The state result after processing the batch.
fn process_batch_from_right(
&mut self,
batch: RecordBatch,
) -> Result<StreamJoinStateResult<Option<RecordBatch>>>;
) -> Result<StatefulStreamResult<Option<RecordBatch>>>;

/// Handles a pulled batch from the left stream.
///
Expand All @@ -881,11 +881,11 @@ pub trait EagerJoinStream {
///
/// # Returns
///
/// * `Result<StreamJoinStateResult<Option<RecordBatch>>>` - The state result after processing the batch.
/// * `Result<StatefulStreamResult<Option<RecordBatch>>>` - The state result after processing the batch.
fn process_batch_from_left(
&mut self,
batch: RecordBatch,
) -> Result<StreamJoinStateResult<Option<RecordBatch>>>;
) -> Result<StatefulStreamResult<Option<RecordBatch>>>;

/// Handles the situation when only the left stream is exhausted.
///
Expand All @@ -895,11 +895,11 @@ pub trait EagerJoinStream {
///
/// # Returns
///
/// * `Result<StreamJoinStateResult<Option<RecordBatch>>>` - The state result after the left stream is exhausted.
/// * `Result<StatefulStreamResult<Option<RecordBatch>>>` - The state result after the left stream is exhausted.
fn process_batch_after_left_end(
&mut self,
right_batch: RecordBatch,
) -> Result<StreamJoinStateResult<Option<RecordBatch>>>;
) -> Result<StatefulStreamResult<Option<RecordBatch>>>;

/// Handles the situation when only the right stream is exhausted.
///
Expand All @@ -909,20 +909,20 @@ pub trait EagerJoinStream {
///
/// # Returns
///
/// * `Result<StreamJoinStateResult<Option<RecordBatch>>>` - The state result after the right stream is exhausted.
/// * `Result<StatefulStreamResult<Option<RecordBatch>>>` - The state result after the right stream is exhausted.
fn process_batch_after_right_end(
&mut self,
left_batch: RecordBatch,
) -> Result<StreamJoinStateResult<Option<RecordBatch>>>;
) -> Result<StatefulStreamResult<Option<RecordBatch>>>;

/// Handles the final state after both streams are exhausted.
///
/// # Returns
///
/// * `Result<StreamJoinStateResult<Option<RecordBatch>>>` - The final state result after processing.
/// * `Result<StatefulStreamResult<Option<RecordBatch>>>` - The final state result after processing.
fn process_batches_before_finalization(
&mut self,
) -> Result<StreamJoinStateResult<Option<RecordBatch>>>;
) -> Result<StatefulStreamResult<Option<RecordBatch>>>;

/// Provides mutable access to the right stream.
///
Expand Down
24 changes: 12 additions & 12 deletions datafusion/physical-plan/src/joins/symmetric_hash_join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ use crate::joins::stream_join_utils::{
use crate::joins::utils::{
build_batch_from_indices, build_join_schema, check_join_is_valid,
partitioned_join_output_partitioning, prepare_sorted_exprs, ColumnIndex, JoinFilter,
JoinOn, StreamJoinStateResult,
JoinOn, StatefulStreamResult,
};
use crate::{
expressions::{Column, PhysicalSortExpr},
Expand Down Expand Up @@ -1014,48 +1014,48 @@ impl EagerJoinStream for SymmetricHashJoinStream {
fn process_batch_from_right(
&mut self,
batch: RecordBatch,
) -> Result<StreamJoinStateResult<Option<RecordBatch>>> {
) -> Result<StatefulStreamResult<Option<RecordBatch>>> {
self.perform_join_for_given_side(batch, JoinSide::Right)
.map(|maybe_batch| {
if maybe_batch.is_some() {
StreamJoinStateResult::Ready(maybe_batch)
StatefulStreamResult::Ready(maybe_batch)
} else {
StreamJoinStateResult::Continue
StatefulStreamResult::Continue
}
})
}

fn process_batch_from_left(
&mut self,
batch: RecordBatch,
) -> Result<StreamJoinStateResult<Option<RecordBatch>>> {
) -> Result<StatefulStreamResult<Option<RecordBatch>>> {
self.perform_join_for_given_side(batch, JoinSide::Left)
.map(|maybe_batch| {
if maybe_batch.is_some() {
StreamJoinStateResult::Ready(maybe_batch)
StatefulStreamResult::Ready(maybe_batch)
} else {
StreamJoinStateResult::Continue
StatefulStreamResult::Continue
}
})
}

fn process_batch_after_left_end(
&mut self,
right_batch: RecordBatch,
) -> Result<StreamJoinStateResult<Option<RecordBatch>>> {
) -> Result<StatefulStreamResult<Option<RecordBatch>>> {
self.process_batch_from_right(right_batch)
}

fn process_batch_after_right_end(
&mut self,
left_batch: RecordBatch,
) -> Result<StreamJoinStateResult<Option<RecordBatch>>> {
) -> Result<StatefulStreamResult<Option<RecordBatch>>> {
self.process_batch_from_left(left_batch)
}

fn process_batches_before_finalization(
&mut self,
) -> Result<StreamJoinStateResult<Option<RecordBatch>>> {
) -> Result<StatefulStreamResult<Option<RecordBatch>>> {
// Get the left side results:
let left_result = build_side_determined_results(
&self.left,
Expand Down Expand Up @@ -1083,9 +1083,9 @@ impl EagerJoinStream for SymmetricHashJoinStream {
// Update the metrics:
self.metrics.output_batches.add(1);
self.metrics.output_rows.add(batch.num_rows());
return Ok(StreamJoinStateResult::Ready(result));
return Ok(StatefulStreamResult::Ready(result));
}
Ok(StreamJoinStateResult::Continue)
Ok(StatefulStreamResult::Continue)
}

fn right_stream(&mut self) -> &mut SendableRecordBatchStream {
Expand Down
16 changes: 8 additions & 8 deletions datafusion/physical-plan/src/joins/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1296,29 +1296,29 @@ pub fn prepare_sorted_exprs(

/// The `handle_state` macro is designed to process the result of a state-changing
/// operation, encountered e.g. in implementations of `EagerJoinStream`. It
/// operates on a `StreamJoinStateResult` by matching its variants and executing
/// operates on a `StatefulStreamResult` by matching its variants and executing
/// corresponding actions. This macro is used to streamline code that deals with
/// state transitions, reducing boilerplate and improving readability.
///
/// # Cases
///
/// - `Ok(StreamJoinStateResult::Continue)`: Continues the loop, indicating the
/// - `Ok(StatefulStreamResult::Continue)`: Continues the loop, indicating the
/// stream join operation should proceed to the next step.
/// - `Ok(StreamJoinStateResult::Ready(result))`: Returns a `Poll::Ready` with the
/// - `Ok(StatefulStreamResult::Ready(result))`: Returns a `Poll::Ready` with the
/// result, which either yields a value (`Some`) or signals that the stream
/// has ended (`None`).
/// - `Err(e)`: Returns a `Poll::Ready` containing an error, signaling an issue
/// during the stream join operation.
///
/// # Arguments
///
/// * `$match_case`: An expression that evaluates to a `Result<StreamJoinStateResult<_>>`.
/// * `$match_case`: An expression that evaluates to a `Result<StatefulStreamResult<_>>`.
#[macro_export]
macro_rules! handle_state {
($match_case:expr) => {
match $match_case {
Ok(StreamJoinStateResult::Continue) => continue,
Ok(StreamJoinStateResult::Ready(result)) => {
Ok(StatefulStreamResult::Continue) => continue,
Ok(StatefulStreamResult::Ready(result)) => {
Poll::Ready(Ok(result).transpose())
}
Err(e) => Poll::Ready(Some(Err(e))),
Expand All @@ -1335,7 +1335,7 @@ macro_rules! handle_state {
/// # Arguments
///
/// * `$state_func`: An async function or future that returns a
/// `Result<StreamJoinStateResult<_>>`.
/// `Result<StatefulStreamResult<_>>`.
/// * `$cx`: The context to be passed for polling, usually of type `&mut Context`.
///
#[macro_export]
Expand All @@ -1356,7 +1356,7 @@ macro_rules! handle_async_state {
/// processing or more data. When this variant is returned, it typically means that the
/// current invocation of the state did not produce a final result, and the operation
/// should be invoked again later with more data and possibly with a different state.
pub enum StreamJoinStateResult<T> {
pub enum StatefulStreamResult<T> {
Ready(T),
Continue,
}
Expand Down

0 comments on commit 130c1ff

Please sign in to comment.