Skip to content

Commit

Permalink
refactor(common): add MemcmpEncoded struct to represent memcmp enco…
Browse files Browse the repository at this point in the history
…ded data (#10319)

Signed-off-by: Richard Chien <[email protected]>
  • Loading branch information
stdrc authored Jun 14, 2023
1 parent ff91a4a commit ede3278
Show file tree
Hide file tree
Showing 9 changed files with 111 additions and 36 deletions.
6 changes: 3 additions & 3 deletions src/batch/src/executor/top_n.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ use risingwave_common::estimate_size::EstimateSize;
use risingwave_common::memory::MemoryContext;
use risingwave_common::row::{OwnedRow, Row};
use risingwave_common::util::chunk_coalesce::DataChunkBuilder;
use risingwave_common::util::memcmp_encoding::encode_chunk;
use risingwave_common::util::memcmp_encoding::{encode_chunk, MemcmpEncoded};
use risingwave_common::util::sort_util::ColumnOrder;
use risingwave_pb::batch_plan::plan_node::NodeBody;

Expand Down Expand Up @@ -200,7 +200,7 @@ impl TopNHeap {

#[derive(Clone, EstimateSize)]
pub struct HeapElem {
encoded_row: Vec<u8>,
encoded_row: MemcmpEncoded,
row: OwnedRow,
}

Expand All @@ -225,7 +225,7 @@ impl Ord for HeapElem {
}

impl HeapElem {
pub fn new(encoded_row: Vec<u8>, row: impl Row) -> Self {
pub fn new(encoded_row: MemcmpEncoded, row: impl Row) -> Self {
Self {
encoded_row,
row: row.into_owned_row(),
Expand Down
3 changes: 2 additions & 1 deletion src/common/benches/bench_encoding.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ use risingwave_common::array::{ListValue, StructValue};
use risingwave_common::types::{
DataType, Date, Datum, Interval, ScalarImpl, StructType, Time, Timestamp,
};
use risingwave_common::util::memcmp_encoding::MemcmpEncoded;
use risingwave_common::util::sort_util::OrderType;
use risingwave_common::util::{memcmp_encoding, value_encoding};

Expand All @@ -42,7 +43,7 @@ impl Case {
}
}

fn key_serialization(datum: &Datum) -> Vec<u8> {
fn key_serialization(datum: &Datum) -> MemcmpEncoded {
let result = memcmp_encoding::encode_value(
datum.as_ref().map(ScalarImpl::as_scalar_ref_impl),
OrderType::default(),
Expand Down
109 changes: 95 additions & 14 deletions src/common/src/util/memcmp_encoding.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,15 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use std::ops::Deref;

use bytes::{Buf, BufMut};
use itertools::Itertools;
use serde::{Deserialize, Serialize};

use super::iter_util::{ZipEqDebug, ZipEqFast};
use crate::array::{ArrayImpl, DataChunk};
use crate::estimate_size::EstimateSize;
use crate::row::{OwnedRow, Row};
use crate::types::{
DataType, Date, Datum, Int256, ScalarImpl, Serial, Time, Timestamp, ToDatumRef, F32, F64,
Expand Down Expand Up @@ -180,12 +183,83 @@ fn calculate_encoded_size_inner(
Ok(deserializer.position() - base_position)
}

pub fn encode_value(value: impl ToDatumRef, order: OrderType) -> memcomparable::Result<Vec<u8>> {
/// An owned, immutable byte string holding memcmp-encoded data.
///
/// This is a newtype over `Box<[u8]>`. The derived `PartialEq`/`Eq`/`Hash` and
/// `PartialOrd`/`Ord` implementations operate directly on the wrapped bytes, so two values
/// compare exactly as their underlying byte slices do.
#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord, EstimateSize)]
pub struct MemcmpEncoded(Box<[u8]>);

impl MemcmpEncoded {
pub fn as_inner(&self) -> &[u8] {
&self.0
}

pub fn into_inner(self) -> Box<[u8]> {
self.0
}
}

/// Lets a [`MemcmpEncoded`] be passed wherever `AsRef<[u8]>` is accepted.
impl AsRef<[u8]> for MemcmpEncoded {
    fn as_ref(&self) -> &[u8] {
        &self.0[..]
    }
}

/// Dereferences to the encoded bytes, so slice methods can be called directly on a
/// [`MemcmpEncoded`] and `&MemcmpEncoded` coerces to `&[u8]`.
impl Deref for MemcmpEncoded {
    type Target = [u8];

    fn deref(&self) -> &Self::Target {
        // `self.0` is the inner `Box<[u8]>`, so this resolves to the box's own `deref`.
        self.0.deref()
    }
}

/// Consumes the encoded value and iterates over its bytes from first to last.
impl IntoIterator for MemcmpEncoded {
    type Item = u8;
    type IntoIter = std::vec::IntoIter<u8>;

    fn into_iter(self) -> Self::IntoIter {
        let bytes = Vec::from(self.0);
        bytes.into_iter()
    }
}

/// Builds a [`MemcmpEncoded`] by collecting bytes from any `u8` iterator.
impl FromIterator<u8> for MemcmpEncoded {
    fn from_iter<T: IntoIterator<Item = u8>>(iter: T) -> Self {
        let bytes: Box<[u8]> = iter.into_iter().collect();
        Self(bytes)
    }
}

/// Wraps an owned byte vector, converting it into a boxed slice first.
impl From<Vec<u8>> for MemcmpEncoded {
    fn from(v: Vec<u8>) -> Self {
        let bytes: Box<[u8]> = Box::from(v);
        Self(bytes)
    }
}

impl From<Box<[u8]>> for MemcmpEncoded {
fn from(v: Box<[u8]>) -> Self {
Self(v)
}
}

impl From<MemcmpEncoded> for Vec<u8> {
fn from(v: MemcmpEncoded) -> Self {
v.0.into()
}
}

impl From<MemcmpEncoded> for Box<[u8]> {
fn from(v: MemcmpEncoded) -> Self {
v.0
}
}

/// Encode a datum into memcomparable format.
///
/// Serializes `value` under the given `order` with `serialize_datum` into a fresh, initially
/// empty buffer and returns the resulting bytes as a [`MemcmpEncoded`].
///
/// # Errors
///
/// Propagates any error returned by `serialize_datum`.
pub fn encode_value(
value: impl ToDatumRef,
order: OrderType,
) -> memcomparable::Result<MemcmpEncoded> {
let mut serializer = memcomparable::Serializer::new(vec![]);
serialize_datum(value, order, &mut serializer)?;
Ok(serializer.into_inner())
Ok(serializer.into_inner().into())
}

/// Decode a datum from memcomparable format.
pub fn decode_value(
ty: &DataType,
encoded_value: &[u8],
Expand All @@ -195,21 +269,23 @@ pub fn decode_value(
deserialize_datum(ty, order, &mut deserializer)
}

pub fn encode_array(array: &ArrayImpl, order: OrderType) -> memcomparable::Result<Vec<Vec<u8>>> {
/// Encode an array into memcomparable format.
///
/// Every datum of `array` is encoded on its own with [`encode_value`] under the same `order`;
/// the returned vector holds one [`MemcmpEncoded`] per datum, in iteration order.
///
/// # Errors
///
/// Returns the first error produced by [`encode_value`]; nothing is returned for the datums
/// encoded before it.
pub fn encode_array(
    array: &ArrayImpl,
    order: OrderType,
) -> memcomparable::Result<Vec<MemcmpEncoded>> {
    // Reserve room for one encoded value per datum up front.
    let mut encoded = Vec::with_capacity(array.len());
    for datum in array.iter() {
        let bytes = encode_value(datum, order)?;
        encoded.push(bytes);
    }
    Ok(encoded)
}

/// This function is used to accelerate the comparison of tuples. It takes a data chunk and a
/// user-defined order as input, and yields an order-preserving encoded binary string for each
/// tuple in the data chunk.
/// Encode a chunk into memcomparable format.
pub fn encode_chunk(
chunk: &DataChunk,
column_orders: &[ColumnOrder],
) -> memcomparable::Result<Vec<Vec<u8>>> {
) -> memcomparable::Result<Vec<MemcmpEncoded>> {
let encoded_columns: Vec<_> = column_orders
.iter()
.map(|o| encode_array(chunk.column_at(o.column_index), o.order_type))
Expand All @@ -222,18 +298,22 @@ pub fn encode_chunk(
}
}

Ok(encoded_chunk)
Ok(encoded_chunk.into_iter().map(Into::into).collect())
}

/// Encode a row into memcomparable format.
///
/// Each datum of `row` is paired with the entry of `order_types` at the same position and
/// serialized, in order, into one shared buffer, so the result is a single byte string for the
/// whole row.
///
/// NOTE(review): `zip_eq_debug` presumably expects `row` and `order_types` to have the same
/// length — confirm against its definition.
///
/// # Errors
///
/// Stops at, and returns, the first error produced by `serialize_datum`.
pub fn encode_row(row: impl Row, order_types: &[OrderType]) -> memcomparable::Result<Vec<u8>> {
pub fn encode_row(
row: impl Row,
order_types: &[OrderType],
) -> memcomparable::Result<MemcmpEncoded> {
let mut serializer = memcomparable::Serializer::new(vec![]);
row.iter()
.zip_eq_debug(order_types)
.try_for_each(|(datum, order)| serialize_datum(datum, *order, &mut serializer))?;
Ok(serializer.into_inner())
Ok(serializer.into_inner().into())
}

/// Decode a row from memcomparable format.
pub fn decode_row(
encoded_row: &[u8],
data_types: &[DataType],
Expand All @@ -259,11 +339,12 @@ mod tests {
use crate::array::{DataChunk, ListValue, StructValue};
use crate::row::{OwnedRow, RowExt};
use crate::types::{DataType, FloatExt, ScalarImpl, F32};
use crate::util::iter_util::ZipEqFast;
use crate::util::sort_util::{ColumnOrder, OrderType};

#[test]
fn test_memcomparable() {
fn encode_num(num: Option<i32>, order_type: OrderType) -> Vec<u8> {
fn encode_num(num: Option<i32>, order_type: OrderType) -> MemcmpEncoded {
encode_value(num.map(ScalarImpl::from), order_type).unwrap()
}

Expand Down Expand Up @@ -465,11 +546,11 @@ mod tests {
use num_traits::*;
use rand::seq::SliceRandom;

fn serialize(f: F32) -> Vec<u8> {
fn serialize(f: F32) -> MemcmpEncoded {
encode_value(&Some(ScalarImpl::from(f)), OrderType::default()).unwrap()
}

fn deserialize(data: Vec<u8>) -> F32 {
fn deserialize(data: MemcmpEncoded) -> F32 {
decode_value(&DataType::Float32, &data, OrderType::default())
.unwrap()
.unwrap()
Expand Down Expand Up @@ -539,7 +620,7 @@ mod tests {
let concated_encoded_row1 = encoded_v10
.into_iter()
.chain(encoded_v11.into_iter())
.collect_vec();
.collect();
assert_eq!(encoded_row1, concated_encoded_row1);

let encoded_row2 = encode_row(row2.project(&order_col_indices), &order_types).unwrap();
Expand Down
5 changes: 3 additions & 2 deletions src/stream/src/executor/aggregation/agg_state_cache.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,15 @@ use risingwave_common::array::{ArrayImpl, Op};
use risingwave_common::buffer::Bitmap;
use risingwave_common::estimate_size::EstimateSize;
use risingwave_common::types::{Datum, DatumRef};
use risingwave_common::util::memcmp_encoding::MemcmpEncoded;
use risingwave_common::util::row_serde::OrderedRowSerde;
use smallvec::SmallVec;

use super::minput_agg_impl::MInputAggregator;
use crate::common::cache::{StateCache, StateCacheFiller};

/// Cache key type.
type CacheKey = Vec<u8>;
type CacheKey = MemcmpEncoded;

// TODO(yuchao): May extract common logic here to `struct [Data/Stream]ChunkRef` if there's other
// usage in the future. https://github.com/risingwavelabs/risingwave/pull/5908#discussion_r1002896176
Expand Down Expand Up @@ -76,7 +77,7 @@ impl<'a> Iterator for StateCacheInputBatch<'a> {
.map(|col_idx| self.columns[*col_idx].value_at(self.idx)),
&mut key,
);
key
key.into()
};
let value = self
.arg_col_indices
Expand Down
2 changes: 1 addition & 1 deletion src/stream/src/executor/aggregation/minput.rs
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ impl<S: StateStore> MaterializedInputState<S> {
.project(&self.state_table_order_col_indices),
&mut cache_key,
);
cache_key
cache_key.into()
};
let cache_value = self
.state_table_arg_col_indices
Expand Down
12 changes: 4 additions & 8 deletions src/stream/src/executor/over_window/eowc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,15 +24,14 @@ use risingwave_common::estimate_size::EstimateSize;
use risingwave_common::row::{OwnedRow, Row, RowExt};
use risingwave_common::types::{DataType, ToDatumRef, ToOwnedDatum};
use risingwave_common::util::iter_util::{ZipEqDebug, ZipEqFast};
use risingwave_common::util::memcmp_encoding;
use risingwave_common::util::memcmp_encoding::{self, MemcmpEncoded};
use risingwave_common::util::sort_util::OrderType;
use risingwave_common::{must_match, row};
use risingwave_expr::function::window::WindowFuncCall;
use risingwave_storage::store::PrefetchOptions;
use risingwave_storage::StateStore;

use super::state::{create_window_state, EstimatedVecDeque, WindowState};
use super::MemcmpEncoded;
use crate::cache::{new_unbounded, ManagedLruCache};
use crate::common::table::state_table::StateTable;
use crate::executor::over_window::state::{StateEvictHint, StateKey};
Expand Down Expand Up @@ -241,8 +240,7 @@ impl<S: StateStore> EowcOverWindowExecutor<S> {
let encoded_pk = memcmp_encoding::encode_row(
(&row).project(&this.input_pk_indices),
&vec![OrderType::ascending(); this.input_pk_indices.len()],
)?
.into_boxed_slice();
)?;
let key = StateKey {
order_key: order_key.into(),
encoded_pk,
Expand Down Expand Up @@ -292,8 +290,7 @@ impl<S: StateStore> EowcOverWindowExecutor<S> {
let encoded_partition_key = memcmp_encoding::encode_row(
&partition_key,
&vec![OrderType::ascending(); this.partition_key_indices.len()],
)?
.into_boxed_slice();
)?;

// Get the partition.
Self::ensure_key_in_cache(
Expand All @@ -316,8 +313,7 @@ impl<S: StateStore> EowcOverWindowExecutor<S> {
let encoded_pk = memcmp_encoding::encode_row(
input_row.project(&this.input_pk_indices),
&vec![OrderType::ascending(); this.input_pk_indices.len()],
)?
.into_boxed_slice();
)?;
let key = StateKey {
order_key: order_key.into(),
encoded_pk,
Expand Down
2 changes: 0 additions & 2 deletions src/stream/src/executor/over_window/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,5 +16,3 @@ mod eowc;
mod state;

pub use eowc::{EowcOverWindowExecutor, EowcOverWindowExecutorArgs};

type MemcmpEncoded = Box<[u8]>;
2 changes: 1 addition & 1 deletion src/stream/src/executor/over_window/state/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,10 @@ use std::collections::{BTreeSet, VecDeque};
use educe::Educe;
use risingwave_common::estimate_size::{EstimateSize, KvSize};
use risingwave_common::types::{Datum, DefaultOrdered, ScalarImpl};
use risingwave_common::util::memcmp_encoding::MemcmpEncoded;
use risingwave_expr::function::window::{WindowFuncCall, WindowFuncKind};
use smallvec::SmallVec;

use super::MemcmpEncoded;
use crate::executor::{StreamExecutorError, StreamExecutorResult};

mod buffer;
Expand Down
6 changes: 2 additions & 4 deletions src/stream/src/executor/sort_buffer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ use risingwave_common::row::{self, OwnedRow, Row, RowExt};
use risingwave_common::types::{
DefaultOrd, DefaultOrdered, ScalarImpl, ScalarRefImpl, ToOwnedDatum,
};
use risingwave_common::util::memcmp_encoding::MemcmpEncoded;
use risingwave_storage::row_serde::row_serde_util::deserialize_pk_with_vnode;
use risingwave_storage::store::PrefetchOptions;
use risingwave_storage::StateStore;
Expand All @@ -35,9 +36,6 @@ use super::{StreamExecutorError, StreamExecutorResult};
use crate::common::cache::{OrderedStateCache, StateCache, StateCacheFiller};
use crate::common::table::state_table::StateTable;

// TODO(rc): This should be a struct in `memcmp_encoding` module. See #8606.
type MemcmpEncoded = Box<[u8]>;

type CacheKey = (
DefaultOrdered<ScalarImpl>, // sort (watermark) column value
MemcmpEncoded, // memcmp-encoded pk
Expand All @@ -56,7 +54,7 @@ fn row_to_cache_key<S: StateStore>(
buffer_table
.pk_serde()
.serialize((&row).project(buffer_table.pk_indices()), &mut pk);
(timestamp_val.into(), pk.into_boxed_slice())
(timestamp_val.into(), pk.into())
}

/// [`SortBuffer`] is a common component that consume an unordered stream and produce an ordered
Expand Down

0 comments on commit ede3278

Please sign in to comment.