Skip to content

Commit

Permalink
refactor(streaming): improve hash join error message (#14515)
Browse files Browse the repository at this point in the history
  • Loading branch information
yuhao-su authored and Little-Wallace committed Jan 20, 2024
1 parent 154e0e3 commit 042a94b
Show file tree
Hide file tree
Showing 4 changed files with 104 additions and 37 deletions.
4 changes: 4 additions & 0 deletions src/storage/src/table/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,10 @@ impl<T: AsRef<[u8]>> KeyedRow<T> {
self.vnode_prefixed_key.key_part()
}

/// Returns a shared reference to the row stored in this `KeyedRow` without consuming it.
pub fn row(&self) -> &OwnedRow {
&self.row
}

/// Consumes the `KeyedRow` and splits it into its two components:
/// the `vnode_prefixed_key` and the owned row, in that order.
pub fn into_parts(self) -> (TableKey<T>, OwnedRow) {
    let table_key = self.vnode_prefixed_key;
    let owned_row = self.row;
    (table_key, owned_row)
}
Expand Down
45 changes: 26 additions & 19 deletions src/stream/src/executor/hash_join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -480,29 +480,32 @@ impl<K: HashKey, S: StateStore, const T: JoinTypePrimitive> HashJoinExecutor<K,
let state_order_key_indices_l = state_table_l.pk_indices();
let state_order_key_indices_r = state_table_r.pk_indices();

let join_key_indices_l = params_l.join_key_indices;
let join_key_indices_r = params_r.join_key_indices;
let state_join_key_indices_l = params_l.join_key_indices;
let state_join_key_indices_r = params_r.join_key_indices;

let degree_pk_indices_l = (join_key_indices_l.len()
..join_key_indices_l.len() + params_l.deduped_pk_indices.len())
let degree_join_key_indices_l = (0..state_join_key_indices_l.len()).collect_vec();
let degree_join_key_indices_r = (0..state_join_key_indices_r.len()).collect_vec();

let degree_pk_indices_l = (state_join_key_indices_l.len()
..state_join_key_indices_l.len() + params_l.deduped_pk_indices.len())
.collect_vec();
let degree_pk_indices_r = (join_key_indices_r.len()
..join_key_indices_r.len() + params_r.deduped_pk_indices.len())
let degree_pk_indices_r = (state_join_key_indices_r.len()
..state_join_key_indices_r.len() + params_r.deduped_pk_indices.len())
.collect_vec();

// If pk is contained in join key.
let pk_contained_in_jk_l = is_subset(state_pk_indices_l, join_key_indices_l.clone());
let pk_contained_in_jk_r = is_subset(state_pk_indices_r, join_key_indices_r.clone());
let pk_contained_in_jk_l = is_subset(state_pk_indices_l, state_join_key_indices_l.clone());
let pk_contained_in_jk_r = is_subset(state_pk_indices_r, state_join_key_indices_r.clone());

// Check whether the join key contains the pk on both sides.
let append_only_optimize = is_append_only && pk_contained_in_jk_l && pk_contained_in_jk_r;

let join_key_data_types_l = join_key_indices_l
let join_key_data_types_l = state_join_key_indices_l
.iter()
.map(|idx| state_all_data_types_l[*idx].clone())
.collect_vec();

let join_key_data_types_r = join_key_indices_r
let join_key_data_types_r = state_join_key_indices_r
.iter()
.map(|idx| state_all_data_types_r[*idx].clone())
.collect_vec();
Expand Down Expand Up @@ -609,9 +612,11 @@ impl<K: HashKey, S: StateStore, const T: JoinTypePrimitive> HashJoinExecutor<K,
ht: JoinHashMap::new(
watermark_epoch.clone(),
join_key_data_types_l,
state_join_key_indices_l.clone(),
state_all_data_types_l.clone(),
state_table_l,
params_l.deduped_pk_indices,
degree_join_key_indices_l,
degree_all_data_types_l,
degree_state_table_l,
degree_pk_indices_l,
Expand All @@ -623,7 +628,7 @@ impl<K: HashKey, S: StateStore, const T: JoinTypePrimitive> HashJoinExecutor<K,
ctx.fragment_id,
"left",
),
join_key_indices: join_key_indices_l,
join_key_indices: state_join_key_indices_l,
all_data_types: state_all_data_types_l,
i2o_mapping: left_to_output,
i2o_mapping_indexed: l2o_indexed,
Expand All @@ -637,9 +642,11 @@ impl<K: HashKey, S: StateStore, const T: JoinTypePrimitive> HashJoinExecutor<K,
ht: JoinHashMap::new(
watermark_epoch,
join_key_data_types_r,
state_join_key_indices_r.clone(),
state_all_data_types_r.clone(),
state_table_r,
params_r.deduped_pk_indices,
degree_join_key_indices_r,
degree_all_data_types_r,
degree_state_table_r,
degree_pk_indices_r,
Expand All @@ -651,7 +658,7 @@ impl<K: HashKey, S: StateStore, const T: JoinTypePrimitive> HashJoinExecutor<K,
ctx.fragment_id,
"right",
),
join_key_indices: join_key_indices_r,
join_key_indices: state_join_key_indices_r,
all_data_types: state_all_data_types_r,
start_pos: side_l_column_n,
i2o_mapping: right_to_output,
Expand Down Expand Up @@ -1144,14 +1151,14 @@ impl<K: HashKey, S: StateStore, const T: JoinTypePrimitive> HashJoinExecutor<K,
side_match.ht.update_state(key, matched_rows);
for matched_row in matched_rows_to_clean {
if side_match.need_degree_table {
side_match.ht.delete(key, matched_row);
side_match.ht.delete(key, matched_row)?;
} else {
side_match.ht.delete_row(key, matched_row.row);
side_match.ht.delete_row(key, matched_row.row)?;
}
}

if append_only_optimize && let Some(row) = append_only_matched_row {
side_match.ht.delete(key, row);
side_match.ht.delete(key, row)?;
} else if side_update.need_degree_table {
side_update
.ht
Expand Down Expand Up @@ -1243,18 +1250,18 @@ impl<K: HashKey, S: StateStore, const T: JoinTypePrimitive> HashJoinExecutor<K,
side_match.ht.update_state(key, matched_rows);
for matched_row in matched_rows_to_clean {
if side_match.need_degree_table {
side_match.ht.delete(key, matched_row);
side_match.ht.delete(key, matched_row)?;
} else {
side_match.ht.delete_row(key, matched_row.row);
side_match.ht.delete_row(key, matched_row.row)?;
}
}

if append_only_optimize {
unreachable!();
} else if side_update.need_degree_table {
side_update.ht.delete(key, JoinRow::new(row, degree));
side_update.ht.delete(key, JoinRow::new(row, degree))?;
} else {
side_update.ht.delete_row(key, row);
side_update.ht.delete_row(key, row)?;
};
} else {
// We do not store row which violates null-safe bitmap.
Expand Down
26 changes: 21 additions & 5 deletions src/stream/src/executor/managed_state/join/join_entry_state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
// limitations under the License.

use risingwave_common::estimate_size::KvSize;
use thiserror::Error;

use super::*;

Expand All @@ -36,19 +37,34 @@ impl EstimateSize for JoinEntryState {
}
}

/// Errors produced when mutating the cached entries of a `JoinEntryState`.
#[derive(Error, Debug)]
pub enum JoinEntryError {
/// Returned by `insert` when the cache already holds an entry for the given key.
#[error("double inserting a join state entry")]
OccupiedError,
/// Returned by `remove` when the cache holds no entry for the given key.
#[error("removing a join state entry but it is not in the cache")]
RemoveError,
}

impl JoinEntryState {
/// Insert into the cache.
pub fn insert(&mut self, key: PkType, value: StateValueType) {
pub fn insert(
&mut self,
key: PkType,
value: StateValueType,
) -> Result<&mut StateValueType, JoinEntryError> {
self.kv_heap_size.add(&key, &value);
self.cached.try_insert(key, value).unwrap();
self.cached
.try_insert(key, value)
.map_err(|_| JoinEntryError::OccupiedError)
}

/// Delete from the cache.
pub fn remove(&mut self, pk: PkType) {
pub fn remove(&mut self, pk: PkType) -> Result<(), JoinEntryError> {
if let Some(value) = self.cached.remove(&pk) {
self.kv_heap_size.sub(&pk, &value);
Ok(())
} else {
panic!("pk {:?} should be in the cache", pk);
Err(JoinEntryError::RemoveError)
}
}

Expand Down Expand Up @@ -98,7 +114,7 @@ mod tests {
// Pk is only a `i64` here, so encoding method does not matter.
let pk = OwnedRow::new(pk).project(&value_indices).value_serialize();
let join_row = JoinRow { row, degree: 0 };
managed_state.insert(pk, join_row.encode());
managed_state.insert(pk, join_row.encode()).unwrap();
}
}

Expand Down
66 changes: 53 additions & 13 deletions src/stream/src/executor/managed_state/join/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ use std::alloc::Global;
use std::ops::{Bound, Deref, DerefMut};
use std::sync::Arc;

use anyhow::Context;
use futures::future::try_join;
use futures::StreamExt;
use futures_async_stream::for_await;
Expand Down Expand Up @@ -267,7 +268,10 @@ pub struct JoinHashMap<K: HashKey, S: StateStore> {
}

struct TableInner<S: StateStore> {
/// Indices of the (cache) pk in a state row
pk_indices: Vec<usize>,
/// Indices of the join key in a state row
join_key_indices: Vec<usize>,
// This should be identical to the pk in state table.
order_key_indices: Vec<usize>,
// This should be identical to the data types in table schema.
Expand All @@ -276,15 +280,31 @@ struct TableInner<S: StateStore> {
pub(crate) table: StateTable<S>,
}

impl<S: StateStore> TableInner<S> {
    /// Builds a human-readable description of `row` for attaching to errors.
    ///
    /// The message contains the row's join key columns, its pk columns, the
    /// full row, and the id of the underlying state table.
    fn error_context(&self, row: &impl Row) -> String {
        let Self {
            pk_indices,
            join_key_indices,
            table,
            ..
        } = self;

        // Project in the same order as before: pk first, then join key.
        let pk_columns = row.project(pk_indices);
        let join_key_columns = row.project(join_key_indices);
        let table_id = table.table_id();

        format!(
            "join key: {}, pk: {}, row: {}, state_table_id: {}",
            join_key_columns.display(),
            pk_columns.display(),
            row.display(),
            table_id
        )
    }
}

impl<K: HashKey, S: StateStore> JoinHashMap<K, S> {
/// Create a [`JoinHashMap`] with the given LRU capacity.
#[allow(clippy::too_many_arguments)]
pub fn new(
watermark_epoch: AtomicU64Ref,
join_key_data_types: Vec<DataType>,
state_join_key_indices: Vec<usize>,
state_all_data_types: Vec<DataType>,
state_table: StateTable<S>,
state_pk_indices: Vec<usize>,
degree_join_key_indices: Vec<usize>,
degree_all_data_types: Vec<DataType>,
degree_table: StateTable<S>,
degree_pk_indices: Vec<usize>,
Expand All @@ -311,13 +331,15 @@ impl<K: HashKey, S: StateStore> JoinHashMap<K, S> {
let degree_table_id = degree_table.table_id();
let state = TableInner {
pk_indices: state_pk_indices,
join_key_indices: state_join_key_indices,
order_key_indices: state_table.pk_indices().to_vec(),
all_data_types: state_all_data_types,
table: state_table,
};

let degree_state = TableInner {
pk_indices: degree_pk_indices,
join_key_indices: degree_join_key_indices,
order_key_indices: degree_table.pk_indices().to_vec(),
all_data_types: degree_all_data_types,
table: degree_table,
Expand Down Expand Up @@ -445,10 +467,12 @@ impl<K: HashKey, S: StateStore> JoinHashMap<K, S> {
let degree_i64 = degree_row
.datum_at(degree_row.len() - 1)
.expect("degree should not be NULL");
entry_state.insert(
pk,
JoinRow::new(row.into_owned_row(), degree_i64.into_int64() as u64).encode(),
);
entry_state
.insert(
pk,
JoinRow::new(row.row(), degree_i64.into_int64() as u64).encode(),
)
.with_context(|| self.state.error_context(row.row()))?;
}
} else {
let sub_range: &(Bound<OwnedRow>, Bound<OwnedRow>) =
Expand All @@ -466,7 +490,9 @@ impl<K: HashKey, S: StateStore> JoinHashMap<K, S> {
.as_ref()
.project(&self.state.pk_indices)
.memcmp_serialize(&self.pk_serializer);
entry_state.insert(pk, JoinRow::new(row.into_owned_row(), 0).encode());
entry_state
.insert(pk, JoinRow::new(row.row(), 0).encode())
.with_context(|| self.state.error_context(row.row()))?;
}
};

Expand Down Expand Up @@ -498,12 +524,16 @@ impl<K: HashKey, S: StateStore> JoinHashMap<K, S> {
if self.inner.contains(key) {
// Update cache
let mut entry = self.inner.get_mut(key).unwrap();
entry.insert(pk, value.encode());
entry
.insert(pk, value.encode())
.with_context(|| self.state.error_context(&value.row))?;
} else if self.pk_contained_in_jk {
// Refill cache when the join key exists in neither cache nor storage.
self.metrics.insert_cache_miss_count += 1;
let mut state = JoinEntryState::default();
state.insert(pk, value.encode());
state
.insert(pk, value.encode())
.with_context(|| self.state.error_context(&value.row))?;
self.update_state(key, state.into());
}

Expand All @@ -528,12 +558,16 @@ impl<K: HashKey, S: StateStore> JoinHashMap<K, S> {
if self.inner.contains(key) {
// Update cache
let mut entry = self.inner.get_mut(key).unwrap();
entry.insert(pk, join_row.encode());
entry
.insert(pk, join_row.encode())
.with_context(|| self.state.error_context(&value))?;
} else if self.pk_contained_in_jk {
// Refill cache when the join key exists in neither cache nor storage.
self.metrics.insert_cache_miss_count += 1;
let mut state = JoinEntryState::default();
state.insert(pk, join_row.encode());
state
.insert(pk, join_row.encode())
.with_context(|| self.state.error_context(&value))?;
self.update_state(key, state.into());
}

Expand All @@ -543,32 +577,38 @@ impl<K: HashKey, S: StateStore> JoinHashMap<K, S> {
}

/// Delete a join row
pub fn delete(&mut self, key: &K, value: JoinRow<impl Row>) {
pub fn delete(&mut self, key: &K, value: JoinRow<impl Row>) -> StreamExecutorResult<()> {
if let Some(mut entry) = self.inner.get_mut(key) {
let pk = (&value.row)
.project(&self.state.pk_indices)
.memcmp_serialize(&self.pk_serializer);
entry.remove(pk);
entry
.remove(pk)
.with_context(|| self.state.error_context(&value.row))?;
}

// If no cache maintained, only update the state table.
let (row, degree) = value.to_table_rows(&self.state.order_key_indices);
self.state.table.delete(row);
self.degree_state.table.delete(degree);
Ok(())
}

/// Delete a row
/// Used when the side does not need to update degree.
pub fn delete_row(&mut self, key: &K, value: impl Row) {
pub fn delete_row(&mut self, key: &K, value: impl Row) -> StreamExecutorResult<()> {
if let Some(mut entry) = self.inner.get_mut(key) {
let pk = (&value)
.project(&self.state.pk_indices)
.memcmp_serialize(&self.pk_serializer);
entry.remove(pk);
entry
.remove(pk)
.with_context(|| self.state.error_context(&value))?;
}

// If no cache maintained, only update the state table.
self.state.table.delete(value);
Ok(())
}

/// Update a [`JoinEntryState`] into the hash table.
Expand Down

0 comments on commit 042a94b

Please sign in to comment.