refactor(streaming): improve hash join error message #14515

Merged 4 commits on Jan 12, 2024
Changes from 2 commits
15 changes: 15 additions & 0 deletions src/common/src/row/project.rs
@@ -82,6 +82,21 @@ impl<R: Row> Hash for Project<'_, R> {
}
}

impl<R: Row> std::fmt::Display for Project<'_, R> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "[")?;
for (i, datum) in self.iter().enumerate() {
if i != self.indices.len() - 1 {
write!(f, "{:?}, ", datum)?;
} else {
write!(f, "{:?}", datum)?;
}
}

write!(f, "]")
}
}

#[cfg(test)]
mod tests {
use super::*;
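The new impl joins datums with ", " and omits the separator after the last element. A minimal standalone sketch of the same pattern (not part of this diff), with a hypothetical Datums wrapper standing in for the actual Project<'_, R>:

use std::fmt;

// Hypothetical stand-in for a projected row of datums.
struct Datums(Vec<Option<i64>>);

impl fmt::Display for Datums {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "[")?;
        for (i, datum) in self.0.iter().enumerate() {
            // Emit ", " after every datum except the last.
            if i != self.0.len() - 1 {
                write!(f, "{:?}, ", datum)?;
            } else {
                write!(f, "{:?}", datum)?;
            }
        }
        write!(f, "]")
    }
}

fn main() {
    // Prints: [Some(1), None, Some(3)]
    println!("{}", Datums(vec![Some(1), None, Some(3)]));
}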
45 changes: 26 additions & 19 deletions src/stream/src/executor/hash_join.rs
@@ -480,29 +480,32 @@ impl<K: HashKey, S: StateStore, const T: JoinTypePrimitive> HashJoinExecutor<K,
let state_order_key_indices_l = state_table_l.pk_indices();
let state_order_key_indices_r = state_table_r.pk_indices();

let join_key_indices_l = params_l.join_key_indices;
let join_key_indices_r = params_r.join_key_indices;
let state_join_key_indices_l = params_l.join_key_indices;
let state_join_key_indices_r = params_r.join_key_indices;

let degree_pk_indices_l = (join_key_indices_l.len()
..join_key_indices_l.len() + params_l.deduped_pk_indices.len())
let degree_join_key_indices_l = (0..state_join_key_indices_l.len()).collect_vec();
let degree_join_key_indices_r = (0..state_join_key_indices_r.len()).collect_vec();

let degree_pk_indices_l = (state_join_key_indices_l.len()
..state_join_key_indices_l.len() + params_l.deduped_pk_indices.len())
.collect_vec();
let degree_pk_indices_r = (join_key_indices_r.len()
..join_key_indices_r.len() + params_r.deduped_pk_indices.len())
let degree_pk_indices_r = (state_join_key_indices_r.len()
..state_join_key_indices_r.len() + params_r.deduped_pk_indices.len())
.collect_vec();

// If pk is contained in join key.
let pk_contained_in_jk_l = is_subset(state_pk_indices_l, join_key_indices_l.clone());
let pk_contained_in_jk_r = is_subset(state_pk_indices_r, join_key_indices_r.clone());
let pk_contained_in_jk_l = is_subset(state_pk_indices_l, state_join_key_indices_l.clone());
let pk_contained_in_jk_r = is_subset(state_pk_indices_r, state_join_key_indices_r.clone());

// check whether the join key contains the pk on both sides
let append_only_optimize = is_append_only && pk_contained_in_jk_l && pk_contained_in_jk_r;

let join_key_data_types_l = join_key_indices_l
let join_key_data_types_l = state_join_key_indices_l
.iter()
.map(|idx| state_all_data_types_l[*idx].clone())
.collect_vec();

let join_key_data_types_r = join_key_indices_r
let join_key_data_types_r = state_join_key_indices_r
.iter()
.map(|idx| state_all_data_types_r[*idx].clone())
.collect_vec();
@@ -609,9 +612,11 @@ impl<K: HashKey, S: StateStore, const T: JoinTypePrimitive> HashJoinExecutor<K,
ht: JoinHashMap::new(
watermark_epoch.clone(),
join_key_data_types_l,
state_join_key_indices_l.clone(),
state_all_data_types_l.clone(),
state_table_l,
params_l.deduped_pk_indices,
degree_join_key_indices_l,
degree_all_data_types_l,
degree_state_table_l,
degree_pk_indices_l,
@@ -623,7 +628,7 @@ impl<K: HashKey, S: StateStore, const T: JoinTypePrimitive> HashJoinExecutor<K,
ctx.fragment_id,
"left",
),
join_key_indices: join_key_indices_l,
join_key_indices: state_join_key_indices_l,
all_data_types: state_all_data_types_l,
i2o_mapping: left_to_output,
i2o_mapping_indexed: l2o_indexed,
@@ -637,9 +642,11 @@ impl<K: HashKey, S: StateStore, const T: JoinTypePrimitive> HashJoinExecutor<K,
ht: JoinHashMap::new(
watermark_epoch,
join_key_data_types_r,
state_join_key_indices_r.clone(),
state_all_data_types_r.clone(),
state_table_r,
params_r.deduped_pk_indices,
degree_join_key_indices_r,
degree_all_data_types_r,
degree_state_table_r,
degree_pk_indices_r,
@@ -651,7 +658,7 @@ impl<K: HashKey, S: StateStore, const T: JoinTypePrimitive> HashJoinExecutor<K,
ctx.fragment_id,
"right",
),
join_key_indices: join_key_indices_r,
join_key_indices: state_join_key_indices_r,
all_data_types: state_all_data_types_r,
start_pos: side_l_column_n,
i2o_mapping: right_to_output,
@@ -1144,14 +1151,14 @@ impl<K: HashKey, S: StateStore, const T: JoinTypePrimitive> HashJoinExecutor<K,
side_match.ht.update_state(key, matched_rows);
for matched_row in matched_rows_to_clean {
if side_match.need_degree_table {
side_match.ht.delete(key, matched_row);
side_match.ht.delete(key, matched_row)?;
} else {
side_match.ht.delete_row(key, matched_row.row);
side_match.ht.delete_row(key, matched_row.row)?;
}
}

if append_only_optimize && let Some(row) = append_only_matched_row {
side_match.ht.delete(key, row);
side_match.ht.delete(key, row)?;
} else if side_update.need_degree_table {
side_update
.ht
@@ -1243,18 +1250,18 @@ impl<K: HashKey, S: StateStore, const T: JoinTypePrimitive> HashJoinExecutor<K,
side_match.ht.update_state(key, matched_rows);
for matched_row in matched_rows_to_clean {
if side_match.need_degree_table {
side_match.ht.delete(key, matched_row);
side_match.ht.delete(key, matched_row)?;
} else {
side_match.ht.delete_row(key, matched_row.row);
side_match.ht.delete_row(key, matched_row.row)?;
}
}

if append_only_optimize {
unreachable!();
} else if side_update.need_degree_table {
side_update.ht.delete(key, JoinRow::new(row, degree));
side_update.ht.delete(key, JoinRow::new(row, degree))?;
} else {
side_update.ht.delete_row(key, row);
side_update.ht.delete_row(key, row)?;
};
} else {
// We do not store row which violates null-safe bitmap.
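The executor changes above replace fire-and-forget deletes with ?-propagation, so a cache inconsistency now surfaces as a stream-executor error instead of a panic inside the hash table. A minimal standalone sketch of that pattern built on anyhow, with hypothetical types standing in for JoinHashMap and StreamExecutorResult:

use std::collections::HashSet;

use anyhow::{bail, Result};

// Hypothetical stand-in for the join hash table.
struct Ht(HashSet<i64>);

impl Ht {
    // Like the new `delete`: report a missing entry instead of panicking.
    fn delete(&mut self, key: i64) -> Result<()> {
        if !self.0.remove(&key) {
            bail!("removing a join state entry but it is not in the cache");
        }
        Ok(())
    }
}

// Stand-in for the executor body: errors bubble up via `?`.
fn process(ht: &mut Ht) -> Result<()> {
    ht.delete(1)?;
    Ok(())
}

fn main() {
    let mut ht = Ht([1].into_iter().collect());
    assert!(process(&mut ht).is_ok());
    assert!(process(&mut ht).is_err()); // the key is already gone
}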
26 changes: 21 additions & 5 deletions src/stream/src/executor/managed_state/join/join_entry_state.rs
@@ -13,6 +13,7 @@
// limitations under the License.

use risingwave_common::estimate_size::KvSize;
use thiserror::Error;

use super::*;

@@ -36,19 +37,34 @@ impl EstimateSize for JoinEntryState {
}
}

#[derive(Error, Debug)]
pub enum JoinEntryError {
#[error("double inserting a join state entry")]
OccupiedError,
#[error("removing a join state entry but it is not in the cache")]
RemoveError,
}

impl JoinEntryState {
/// Insert into the cache.
pub fn insert(&mut self, key: PkType, value: StateValueType) {
pub fn insert(
&mut self,
key: PkType,
value: StateValueType,
) -> Result<&mut StateValueType, JoinEntryError> {
self.kv_heap_size.add(&key, &value);
self.cached.try_insert(key, value).unwrap();
self.cached
.try_insert(key, value)
.map_err(|_| JoinEntryError::OccupiedError)
}

/// Delete from the cache.
pub fn remove(&mut self, pk: PkType) {
pub fn remove(&mut self, pk: PkType) -> Result<(), JoinEntryError> {
if let Some(value) = self.cached.remove(&pk) {
self.kv_heap_size.sub(&pk, &value);
Ok(())
} else {
panic!("pk {:?} should be in the cache", pk);
Err(JoinEntryError::RemoveError)
}
}

@@ -98,7 +114,7 @@ mod tests {
// Pk is only an `i64` here, so the encoding method does not matter.
let pk = OwnedRow::new(pk).project(&value_indices).value_serialize();
let join_row = JoinRow { row, degree: 0 };
managed_state.insert(pk, join_row.encode());
managed_state.insert(pk, join_row.encode()).unwrap();
}
}

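For context, thiserror derives Display and std::error::Error from the #[error(...)] attributes, so each variant carries its message. A standalone sketch mirroring the enum above, with a plain HashMap standing in for the actual cache type:

use std::collections::hash_map::{Entry, HashMap};

use thiserror::Error;

#[derive(Error, Debug)]
pub enum JoinEntryError {
    #[error("double inserting a join state entry")]
    OccupiedError,
    #[error("removing a join state entry but it is not in the cache")]
    RemoveError,
}

// Sketch of the insert path: reject a second insert for the same key.
fn insert(
    map: &mut HashMap<String, i64>,
    key: String,
    value: i64,
) -> Result<&mut i64, JoinEntryError> {
    match map.entry(key) {
        Entry::Vacant(e) => Ok(e.insert(value)),
        Entry::Occupied(_) => Err(JoinEntryError::OccupiedError),
    }
}

fn main() {
    let mut map = HashMap::new();
    insert(&mut map, "pk-1".into(), 1).unwrap();
    let err = insert(&mut map, "pk-1".into(), 2).unwrap_err();
    // The Display text comes straight from the #[error] attribute:
    assert_eq!(err.to_string(), "double inserting a join state entry");
}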
63 changes: 50 additions & 13 deletions src/stream/src/executor/managed_state/join/mod.rs
@@ -19,6 +19,7 @@ use std::alloc::Global;
use std::ops::{Bound, Deref, DerefMut};
use std::sync::Arc;

use anyhow::Context;
use futures::future::try_join;
use futures::StreamExt;
use futures_async_stream::for_await;
@@ -267,7 +268,10 @@ pub struct JoinHashMap<K: HashKey, S: StateStore> {
}

struct TableInner<S: StateStore> {
/// Indices of the (cache) pk in a state row
pk_indices: Vec<usize>,
/// Indices of the join key in a state row
join_key_indices: Vec<usize>,
// This should be identical to the pk in state table.
order_key_indices: Vec<usize>,
// This should be identical to the data types in table schema.
@@ -276,15 +280,28 @@ struct TableInner<S: StateStore> {
pub(crate) table: StateTable<S>,
}

impl<S: StateStore> TableInner<S> {
fn error_context(&self, row: &impl Row) -> String {
let pk = row.project(&self.pk_indices);
let jk = row.project(&self.join_key_indices);
format!(
"join key: {jk}, pk: {pk}, row: {row:?}, state_table_id: {}",
self.table.table_id()
)
}
}
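With this helper, a failed cache operation reports the join key, pk, full row, and state table involved. A rough reconstruction of the message it builds, with hypothetical values in place of projected datums:

fn main() {
    // Hypothetical values; the real error_context projects the state row
    // through join_key_indices / pk_indices to obtain jk and pk.
    let (jk, pk, row) = ("[Some(1)]", "[Some(42)]", "(1, 42, foo)");
    let table_id = 1002;
    println!("join key: {jk}, pk: {pk}, row: {row:?}, state_table_id: {table_id}");
    // join key: [Some(1)], pk: [Some(42)], row: "(1, 42, foo)", state_table_id: 1002
}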

impl<K: HashKey, S: StateStore> JoinHashMap<K, S> {
/// Create a [`JoinHashMap`] with the given LRU capacity.
#[allow(clippy::too_many_arguments)]
pub fn new(
watermark_epoch: AtomicU64Ref,
join_key_data_types: Vec<DataType>,
state_join_key_indices: Vec<usize>,
state_all_data_types: Vec<DataType>,
state_table: StateTable<S>,
state_pk_indices: Vec<usize>,
degree_join_key_indices: Vec<usize>,
degree_all_data_types: Vec<DataType>,
degree_table: StateTable<S>,
degree_pk_indices: Vec<usize>,
@@ -311,13 +328,15 @@ impl<K: HashKey, S: StateStore> JoinHashMap<K, S> {
let degree_table_id = degree_table.table_id();
let state = TableInner {
pk_indices: state_pk_indices,
join_key_indices: state_join_key_indices,
order_key_indices: state_table.pk_indices().to_vec(),
all_data_types: state_all_data_types,
table: state_table,
};

let degree_state = TableInner {
pk_indices: degree_pk_indices,
join_key_indices: degree_join_key_indices,
order_key_indices: degree_table.pk_indices().to_vec(),
all_data_types: degree_all_data_types,
table: degree_table,
@@ -445,10 +464,12 @@ impl<K: HashKey, S: StateStore> JoinHashMap<K, S> {
let degree_i64 = degree_row
.datum_at(degree_row.len() - 1)
.expect("degree should not be NULL");
entry_state.insert(
pk,
JoinRow::new(row.into_owned_row(), degree_i64.into_int64() as u64).encode(),
);
entry_state
.insert(
pk,
JoinRow::new(row.into_owned_row(), degree_i64.into_int64() as u64).encode(),
)
.expect("duplicated pk");
}
} else {
let sub_range: &(Bound<OwnedRow>, Bound<OwnedRow>) =
@@ -466,7 +487,9 @@ impl<K: HashKey, S: StateStore> JoinHashMap<K, S> {
.as_ref()
.project(&self.state.pk_indices)
.memcmp_serialize(&self.pk_serializer);
entry_state.insert(pk, JoinRow::new(row.into_owned_row(), 0).encode());
entry_state
.insert(pk, JoinRow::new(row.into_owned_row(), 0).encode())
.expect("duplicated pk");
}
};

@@ -498,12 +521,16 @@ impl<K: HashKey, S: StateStore> JoinHashMap<K, S> {
if self.inner.contains(key) {
// Update cache
let mut entry = self.inner.get_mut(key).unwrap();
entry.insert(pk, value.encode());
entry
.insert(pk, value.encode())
.with_context(|| self.state.error_context(&value.row))?;
} else if self.pk_contained_in_jk {
// Refill the cache when the join key exists in neither the cache nor storage.
self.metrics.insert_cache_miss_count += 1;
let mut state = JoinEntryState::default();
state.insert(pk, value.encode());
state
.insert(pk, value.encode())
.with_context(|| self.state.error_context(&value.row))?;
self.update_state(key, state.into());
}

@@ -528,12 +555,16 @@ impl<K: HashKey, S: StateStore> JoinHashMap<K, S> {
if self.inner.contains(key) {
// Update cache
let mut entry = self.inner.get_mut(key).unwrap();
entry.insert(pk, join_row.encode());
entry
.insert(pk, join_row.encode())
.with_context(|| self.state.error_context(&value))?;
} else if self.pk_contained_in_jk {
// Refill cache when the join key exist in neither cache or storage.
self.metrics.insert_cache_miss_count += 1;
let mut state = JoinEntryState::default();
state.insert(pk, join_row.encode());
state
.insert(pk, join_row.encode())
.with_context(|| self.state.error_context(&value))?;
self.update_state(key, state.into());
}

@@ -543,32 +574,38 @@ impl<K: HashKey, S: StateStore> JoinHashMap<K, S> {
}

/// Delete a join row
pub fn delete(&mut self, key: &K, value: JoinRow<impl Row>) {
pub fn delete(&mut self, key: &K, value: JoinRow<impl Row>) -> StreamExecutorResult<()> {
if let Some(mut entry) = self.inner.get_mut(key) {
let pk = (&value.row)
.project(&self.state.pk_indices)
.memcmp_serialize(&self.pk_serializer);
entry.remove(pk);
entry
.remove(pk)
.with_context(|| self.state.error_context(&value.row))?;
}

// If no cache maintained, only update the state table.
let (row, degree) = value.to_table_rows(&self.state.order_key_indices);
self.state.table.delete(row);
self.degree_state.table.delete(degree);
Ok(())
}

/// Delete a row
/// Used when the side does not need to update degree.
pub fn delete_row(&mut self, key: &K, value: impl Row) {
pub fn delete_row(&mut self, key: &K, value: impl Row) -> StreamExecutorResult<()> {
if let Some(mut entry) = self.inner.get_mut(key) {
let pk = (&value)
.project(&self.state.pk_indices)
.memcmp_serialize(&self.pk_serializer);
entry.remove(pk);
entry
.remove(pk)
.with_context(|| self.state.error_context(&value))?;
}

// If no cache maintained, only update the state table.
self.state.table.delete(value);
Ok(())
}

/// Update a [`JoinEntryState`] into the hash table.
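The with_context calls in this file attach those diagnostics lazily: the closure only runs on the error path, so the formatting cost is skipped on success. A minimal standalone sketch of the pattern with anyhow (message and values hypothetical):

use anyhow::{bail, Context, Result};

fn remove_entry(present: bool) -> Result<()> {
    if !present {
        bail!("removing a join state entry but it is not in the cache");
    }
    Ok(())
}

fn delete_row(present: bool) -> Result<()> {
    // The closure is evaluated only if remove_entry fails, so the diagnostic
    // string is never built on the happy path.
    remove_entry(present)
        .with_context(|| format!("join key: [Some(1)], pk: [Some(42)], state_table_id: {}", 1002))
}

fn main() {
    assert!(delete_row(true).is_ok());
    let err = delete_row(false).unwrap_err();
    // "{:#}" renders the full chain, context first:
    // join key: [Some(1)], pk: [Some(42)], state_table_id: 1002: removing a join state entry but it is not in the cache
    println!("{:#}", err);
}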