Skip to content

Commit

Permalink
refactor(common): implement Datum::[from|to]_protobuf (#12679)
Browse files Browse the repository at this point in the history
Signed-off-by: Bugen Zhao <[email protected]>
  • Loading branch information
BugenZhao authored Oct 8, 2023
1 parent d3ebe9f commit 8b18a7e
Show file tree
Hide file tree
Showing 13 changed files with 75 additions and 76 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions src/common/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ arrow-cast = { workspace = true }
arrow-schema = { workspace = true }
async-trait = "0.1"
auto_enums = "0.8"
auto_impl = "1"
bitflags = "2"
byteorder = "1"
bytes = "1"
Expand Down
7 changes: 1 addition & 6 deletions src/common/src/types/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -558,6 +558,7 @@ impl ToOwnedDatum for DatumRef<'_> {
}
}

#[auto_impl::auto_impl(&)]
pub trait ToDatumRef: PartialEq + Eq + Debug {
/// Convert the datum to [`DatumRef`].
fn to_datum_ref(&self) -> DatumRef<'_>;
Expand All @@ -569,12 +570,6 @@ impl ToDatumRef for Datum {
self.as_ref().map(|d| d.as_scalar_ref_impl())
}
}
impl ToDatumRef for &Datum {
#[inline(always)]
fn to_datum_ref(&self) -> DatumRef<'_> {
self.as_ref().map(|d| d.as_scalar_ref_impl())
}
}
impl ToDatumRef for Option<&ScalarImpl> {
#[inline(always)]
fn to_datum_ref(&self) -> DatumRef<'_> {
Expand Down
22 changes: 12 additions & 10 deletions src/common/src/util/memcmp_encoding.rs
Original file line number Diff line number Diff line change
Expand Up @@ -430,14 +430,14 @@ mod tests {
fn test_memcomparable_structs() {
// NOTE: `NULL`s inside composite type values are always the largest.

let struct_none = None;
let struct_1 = Some(
let struct_none = Datum::None;
let struct_1 = Datum::Some(
StructValue::new(vec![Some(ScalarImpl::from(1)), Some(ScalarImpl::from(2))]).into(),
);
let struct_2 = Some(
let struct_2 = Datum::Some(
StructValue::new(vec![Some(ScalarImpl::from(1)), Some(ScalarImpl::from(3))]).into(),
);
let struct_3 = Some(StructValue::new(vec![Some(ScalarImpl::from(1)), None]).into());
let struct_3 = Datum::Some(StructValue::new(vec![Some(ScalarImpl::from(1)), None]).into());

{
// ASC NULLS FIRST (NULLS SMALLEST)
Expand Down Expand Up @@ -489,12 +489,14 @@ mod tests {
fn test_memcomparable_lists() {
// NOTE: `NULL`s inside composite type values are always the largest.

let list_none = None;
let list_1 =
Some(ListValue::new(vec![Some(ScalarImpl::from(1)), Some(ScalarImpl::from(2))]).into());
let list_2 =
Some(ListValue::new(vec![Some(ScalarImpl::from(1)), Some(ScalarImpl::from(3))]).into());
let list_3 = Some(ListValue::new(vec![Some(ScalarImpl::from(1)), None]).into());
let list_none = Datum::None;
let list_1 = Datum::Some(
ListValue::new(vec![Some(ScalarImpl::from(1)), Some(ScalarImpl::from(2))]).into(),
);
let list_2 = Datum::Some(
ListValue::new(vec![Some(ScalarImpl::from(1)), Some(ScalarImpl::from(3))]).into(),
);
let list_3 = Datum::Some(ListValue::new(vec![Some(ScalarImpl::from(1)), None]).into());

{
// ASC NULLS FIRST (NULLS SMALLEST)
Expand Down
19 changes: 19 additions & 0 deletions src/common/src/util/value_encoding/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ use bytes::{Buf, BufMut};
use chrono::{Datelike, Timelike};
use either::{for_both, Either};
use enum_as_inner::EnumAsInner;
use risingwave_pb::data::PbDatum;

use crate::array::{ArrayImpl, ListRef, ListValue, StructRef, StructValue};
use crate::row::{Row, RowDeserializer as BasicDeserializer};
Expand Down Expand Up @@ -165,6 +166,24 @@ pub fn estimate_serialize_datum_size(datum_ref: impl ToDatumRef) -> usize {
}
}

#[easy_ext::ext(DatumFromProtoExt)]
impl Datum {
/// Create a datum from the protobuf representation with the given data type.
pub fn from_protobuf(proto: &PbDatum, data_type: &DataType) -> Result<Datum> {
deserialize_datum(proto.body.as_slice(), data_type)
}
}

#[easy_ext::ext(DatumToProtoExt)]
impl<D: ToDatumRef> D {
/// Convert the datum to the protobuf representation.
pub fn to_protobuf(&self) -> PbDatum {
PbDatum {
body: serialize_datum(self),
}
}
}

/// Deserialize bytes into a datum (Not order guarantee, used in value encoding).
pub fn deserialize_datum(mut data: impl Buf, ty: &DataType) -> Result<Datum> {
inner_deserialize_datum(&mut data, ty)
Expand Down
10 changes: 3 additions & 7 deletions src/expr/core/src/aggregate/def.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,9 @@ use std::sync::Arc;
use itertools::Itertools;
use parse_display::{Display, FromStr};
use risingwave_common::bail;
use risingwave_common::types::DataType;
use risingwave_common::types::{DataType, Datum};
use risingwave_common::util::sort_util::{ColumnOrder, OrderType};
use risingwave_common::util::value_encoding;
use risingwave_common::util::value_encoding::DatumFromProtoExt;
use risingwave_pb::expr::agg_call::PbType;
use risingwave_pb::expr::{PbAggCall, PbInputRef};

Expand Down Expand Up @@ -78,11 +78,7 @@ impl AggCall {
let data_type = DataType::from(arg.get_type().unwrap());
LiteralExpression::new(
data_type.clone(),
value_encoding::deserialize_datum(
arg.get_datum().unwrap().get_body().as_slice(),
&data_type,
)
.unwrap(),
Datum::from_protobuf(arg.get_datum().unwrap(), &data_type).unwrap(),
)
})
.collect_vec();
Expand Down
4 changes: 2 additions & 2 deletions src/expr/core/src/expr/expr_field.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ use anyhow::anyhow;
use risingwave_common::array::{ArrayImpl, ArrayRef, DataChunk};
use risingwave_common::row::OwnedRow;
use risingwave_common::types::{DataType, Datum, ScalarImpl};
use risingwave_common::util::value_encoding::deserialize_datum;
use risingwave_common::util::value_encoding::DatumFromProtoExt;
use risingwave_pb::expr::expr_node::{RexNode, Type};
use risingwave_pb::expr::ExprNode;

Expand Down Expand Up @@ -89,7 +89,7 @@ impl Build for FieldExpression {
let RexNode::Constant(value) = second.get_rex_node().unwrap() else {
bail!("Expected Constant as 1st argument");
};
let index = deserialize_datum(value.body.as_slice(), &DataType::Int32)
let index = Datum::from_protobuf(value, &DataType::Int32)
.map_err(|e| anyhow!("Failed to deserialize i32, reason: {:?}", e))?
.unwrap()
.as_int32()
Expand Down
14 changes: 5 additions & 9 deletions src/expr/core/src/expr/expr_in.rs
Original file line number Diff line number Diff line change
Expand Up @@ -131,10 +131,10 @@ mod tests {
use risingwave_common::array::DataChunk;
use risingwave_common::row::OwnedRow;
use risingwave_common::test_prelude::DataChunkTestExt;
use risingwave_common::types::{DataType, ScalarImpl};
use risingwave_common::util::value_encoding::serialize_datum;
use risingwave_common::types::{DataType, Datum, ScalarImpl};
use risingwave_common::util::value_encoding::DatumToProtoExt;
use risingwave_pb::data::data_type::TypeName;
use risingwave_pb::data::{PbDataType, PbDatum};
use risingwave_pb::data::PbDataType;
use risingwave_pb::expr::expr_node::{RexNode, Type};
use risingwave_pb::expr::{ExprNode, FunctionCall};

Expand All @@ -158,19 +158,15 @@ mod tests {
type_name: TypeName::Varchar as i32,
..Default::default()
}),
rex_node: Some(RexNode::Constant(PbDatum {
body: serialize_datum(Some("ABC".into()).as_ref()),
})),
rex_node: Some(RexNode::Constant(Datum::Some("ABC".into()).to_protobuf())),
},
ExprNode {
function_type: Type::Unspecified as i32,
return_type: Some(PbDataType {
type_name: TypeName::Varchar as i32,
..Default::default()
}),
rex_node: Some(RexNode::Constant(PbDatum {
body: serialize_datum(Some("def".into()).as_ref()),
})),
rex_node: Some(RexNode::Constant(Datum::Some("def".into()).to_protobuf())),
},
];
let mut in_children = vec![input_ref_expr_node];
Expand Down
13 changes: 6 additions & 7 deletions src/expr/core/src/expr/expr_literal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
use risingwave_common::array::DataChunk;
use risingwave_common::row::OwnedRow;
use risingwave_common::types::{literal_type_match, DataType, Datum};
use risingwave_common::util::value_encoding::deserialize_datum;
use risingwave_common::util::value_encoding::DatumFromProtoExt;
use risingwave_pb::expr::ExprNode;

use super::{Build, ValueImpl};
Expand Down Expand Up @@ -74,9 +74,8 @@ impl Build for LiteralExpression {

let prost_value = prost.get_rex_node().unwrap().as_constant().unwrap();

// TODO: We need to unify these
let value = deserialize_datum(
prost_value.get_body().as_slice(),
let value = Datum::from_protobuf(
prost_value,
&DataType::from(prost.get_return_type().unwrap()),
)
.map_err(|e| ExprError::Internal(e.into()))?;
Expand All @@ -92,7 +91,7 @@ mod tests {
use risingwave_common::array::{I32Array, StructValue};
use risingwave_common::types::test_utils::IntervalTestExt;
use risingwave_common::types::{Decimal, Interval, IntoOrdered, Scalar, ScalarImpl};
use risingwave_common::util::value_encoding::serialize_datum;
use risingwave_common::util::value_encoding::{serialize_datum, DatumToProtoExt};
use risingwave_pb::data::data_type::{IntervalType, TypeName};
use risingwave_pb::data::{PbDataType, PbDatum};
use risingwave_pb::expr::expr_node::RexNode::{self, Constant};
Expand All @@ -108,7 +107,7 @@ mod tests {
Some(2.into()),
None,
]);
let body = serialize_datum(Some(value.clone().to_scalar_value()).as_ref());
let pb_datum = Some(value.clone().to_scalar_value()).to_protobuf();
let expr = ExprNode {
function_type: Type::Unspecified as i32,
return_type: Some(PbDataType {
Expand All @@ -129,7 +128,7 @@ mod tests {
],
..Default::default()
}),
rex_node: Some(Constant(PbDatum { body })),
rex_node: Some(Constant(pb_datum)),
};
let expr = LiteralExpression::build_for_test(&expr).unwrap();
assert_eq!(value.to_scalar_value(), expr.literal().unwrap());
Expand Down
16 changes: 8 additions & 8 deletions src/expr/core/src/expr/test_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@ use std::num::NonZeroUsize;

use num_traits::CheckedSub;
use risingwave_common::types::{DataType, Interval, ScalarImpl};
use risingwave_common::util::value_encoding::serialize_datum;
use risingwave_common::util::value_encoding::DatumToProtoExt;
use risingwave_pb::data::data_type::TypeName;
use risingwave_pb::data::{PbDataType, PbDatum};
use risingwave_pb::data::PbDataType;
use risingwave_pb::expr::expr_node::Type::Field;
use risingwave_pb::expr::expr_node::{self, RexNode, Type};
use risingwave_pb::expr::{ExprNode, FunctionCall};
Expand Down Expand Up @@ -57,9 +57,9 @@ pub fn make_i32_literal(data: i32) -> ExprNode {
type_name: TypeName::Int32 as i32,
..Default::default()
}),
rex_node: Some(RexNode::Constant(PbDatum {
body: serialize_datum(Some(ScalarImpl::Int32(data)).as_ref()),
})),
rex_node: Some(RexNode::Constant(
Some(ScalarImpl::Int32(data)).to_protobuf(),
)),
}
}

Expand All @@ -70,9 +70,9 @@ fn make_interval_literal(data: Interval) -> ExprNode {
type_name: TypeName::Interval as i32,
..Default::default()
}),
rex_node: Some(RexNode::Constant(PbDatum {
body: serialize_datum(Some(ScalarImpl::Interval(data)).as_ref()),
})),
rex_node: Some(RexNode::Constant(
Some(ScalarImpl::Interval(data)).to_protobuf(),
)),
}
}

Expand Down
25 changes: 10 additions & 15 deletions src/frontend/src/expr/literal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,7 @@

use risingwave_common::array::list_array::display_for_explain;
use risingwave_common::types::{literal_type_match, DataType, Datum, ToText};
use risingwave_common::util::value_encoding::{deserialize_datum, serialize_datum};
use risingwave_pb::data::PbDatum;
use risingwave_common::util::value_encoding::{DatumFromProtoExt, DatumToProtoExt};
use risingwave_pb::expr::expr_node::RexNode;

use super::Expr;
Expand Down Expand Up @@ -121,8 +120,7 @@ impl Expr for Literal {

/// Convert a literal value (datum) into protobuf.
pub fn literal_to_value_encoding(d: &Datum) -> RexNode {
let body = serialize_datum(d.as_ref());
RexNode::Constant(PbDatum { body })
RexNode::Constant(d.to_protobuf())
}

/// Convert protobuf into a literal value (datum).
Expand All @@ -132,7 +130,7 @@ fn value_encoding_to_literal(
) -> risingwave_common::error::Result<Datum> {
if let Some(rex_node) = proto {
if let RexNode::Constant(prost_datum) = rex_node {
let datum = deserialize_datum(prost_datum.body.as_ref(), ty)?;
let datum = Datum::from_protobuf(prost_datum, ty)?;
Ok(datum)
} else {
unreachable!()
Expand All @@ -145,8 +143,8 @@ fn value_encoding_to_literal(
#[cfg(test)]
mod tests {
use risingwave_common::array::{ListValue, StructValue};
use risingwave_common::types::{DataType, ScalarImpl};
use risingwave_common::util::value_encoding::deserialize_datum;
use risingwave_common::types::{DataType, Datum, ScalarImpl};
use risingwave_common::util::value_encoding::DatumFromProtoExt;
use risingwave_pb::expr::expr_node::RexNode;

use crate::expr::literal::literal_to_value_encoding;
Expand All @@ -161,8 +159,8 @@ mod tests {
let data = Some(ScalarImpl::Struct(value.clone()));
let node = literal_to_value_encoding(&data);
if let RexNode::Constant(prost) = node {
let data2 = deserialize_datum(
prost.get_body().as_slice(),
let data2 = Datum::from_protobuf(
&prost,
&DataType::new_struct(
vec![DataType::Varchar, DataType::Int32, DataType::Int32],
vec![],
Expand All @@ -184,12 +182,9 @@ mod tests {
let data = Some(ScalarImpl::List(value.clone()));
let node = literal_to_value_encoding(&data);
if let RexNode::Constant(prost) = node {
let data2 = deserialize_datum(
prost.get_body().as_slice(),
&DataType::List(Box::new(DataType::Varchar)),
)
.unwrap()
.unwrap();
let data2 = Datum::from_protobuf(&prost, &DataType::List(Box::new(DataType::Varchar)))
.unwrap()
.unwrap();
assert_eq!(ScalarImpl::List(value), data2);
}
}
Expand Down
7 changes: 2 additions & 5 deletions src/frontend/src/optimizer/plan_node/generic/agg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,9 @@ use risingwave_common::catalog::{Field, FieldDisplay, Schema};
use risingwave_common::types::DataType;
use risingwave_common::util::iter_util::ZipEqFast;
use risingwave_common::util::sort_util::{ColumnOrder, ColumnOrderDisplay, OrderType};
use risingwave_common::util::value_encoding;
use risingwave_common::util::value_encoding::DatumToProtoExt;
use risingwave_expr::aggregate::{agg_kinds, AggKind};
use risingwave_expr::sig::agg::AGG_FUNC_SIG_MAP;
use risingwave_pb::data::PbDatum;
use risingwave_pb::expr::{PbAggCall, PbConstant};
use risingwave_pb::stream_plan::{agg_call_state, AggCallState as AggCallStatePb};

Expand Down Expand Up @@ -715,9 +714,7 @@ impl PlanAggCall {
.direct_args
.iter()
.map(|x| PbConstant {
datum: Some(PbDatum {
body: value_encoding::serialize_datum(x.get_data()),
}),
datum: Some(x.get_data().to_protobuf()),
r#type: Some(x.return_type().to_protobuf()),
})
.collect(),
Expand Down
Loading

0 comments on commit 8b18a7e

Please sign in to comment.