diff --git a/Cargo.lock b/Cargo.lock index 950a49e1b2296..d1de130adfc0b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -7022,6 +7022,7 @@ dependencies = [ "arrow-schema", "async-trait", "auto_enums", + "auto_impl", "bitflags 2.4.0", "byteorder", "bytes", diff --git a/src/common/Cargo.toml b/src/common/Cargo.toml index 233945f94eeec..a9b1ef1c95d7f 100644 --- a/src/common/Cargo.toml +++ b/src/common/Cargo.toml @@ -22,6 +22,7 @@ arrow-cast = { workspace = true } arrow-schema = { workspace = true } async-trait = "0.1" auto_enums = "0.8" +auto_impl = "1" bitflags = "2" byteorder = "1" bytes = "1" diff --git a/src/common/src/types/mod.rs b/src/common/src/types/mod.rs index 77738b0bb0bdc..946030b05ef76 100644 --- a/src/common/src/types/mod.rs +++ b/src/common/src/types/mod.rs @@ -558,6 +558,7 @@ impl ToOwnedDatum for DatumRef<'_> { } } +#[auto_impl::auto_impl(&)] pub trait ToDatumRef: PartialEq + Eq + Debug { /// Convert the datum to [`DatumRef`]. fn to_datum_ref(&self) -> DatumRef<'_>; @@ -569,12 +570,6 @@ impl ToDatumRef for Datum { self.as_ref().map(|d| d.as_scalar_ref_impl()) } } -impl ToDatumRef for &Datum { - #[inline(always)] - fn to_datum_ref(&self) -> DatumRef<'_> { - self.as_ref().map(|d| d.as_scalar_ref_impl()) - } -} impl ToDatumRef for Option<&ScalarImpl> { #[inline(always)] fn to_datum_ref(&self) -> DatumRef<'_> { diff --git a/src/common/src/util/memcmp_encoding.rs b/src/common/src/util/memcmp_encoding.rs index 8593071e18c71..58ad76900b081 100644 --- a/src/common/src/util/memcmp_encoding.rs +++ b/src/common/src/util/memcmp_encoding.rs @@ -430,14 +430,14 @@ mod tests { fn test_memcomparable_structs() { // NOTE: `NULL`s inside composite type values are always the largest. - let struct_none = None; - let struct_1 = Some( + let struct_none = Datum::None; + let struct_1 = Datum::Some( StructValue::new(vec![Some(ScalarImpl::from(1)), Some(ScalarImpl::from(2))]).into(), ); - let struct_2 = Some( + let struct_2 = Datum::Some( StructValue::new(vec![Some(ScalarImpl::from(1)), Some(ScalarImpl::from(3))]).into(), ); - let struct_3 = Some(StructValue::new(vec![Some(ScalarImpl::from(1)), None]).into()); + let struct_3 = Datum::Some(StructValue::new(vec![Some(ScalarImpl::from(1)), None]).into()); { // ASC NULLS FIRST (NULLS SMALLEST) @@ -489,12 +489,14 @@ mod tests { fn test_memcomparable_lists() { // NOTE: `NULL`s inside composite type values are always the largest. - let list_none = None; - let list_1 = - Some(ListValue::new(vec![Some(ScalarImpl::from(1)), Some(ScalarImpl::from(2))]).into()); - let list_2 = - Some(ListValue::new(vec![Some(ScalarImpl::from(1)), Some(ScalarImpl::from(3))]).into()); - let list_3 = Some(ListValue::new(vec![Some(ScalarImpl::from(1)), None]).into()); + let list_none = Datum::None; + let list_1 = Datum::Some( + ListValue::new(vec![Some(ScalarImpl::from(1)), Some(ScalarImpl::from(2))]).into(), + ); + let list_2 = Datum::Some( + ListValue::new(vec![Some(ScalarImpl::from(1)), Some(ScalarImpl::from(3))]).into(), + ); + let list_3 = Datum::Some(ListValue::new(vec![Some(ScalarImpl::from(1)), None]).into()); { // ASC NULLS FIRST (NULLS SMALLEST) diff --git a/src/common/src/util/value_encoding/mod.rs b/src/common/src/util/value_encoding/mod.rs index f1584ae799181..e3c4386f39a20 100644 --- a/src/common/src/util/value_encoding/mod.rs +++ b/src/common/src/util/value_encoding/mod.rs @@ -18,6 +18,7 @@ use bytes::{Buf, BufMut}; use chrono::{Datelike, Timelike}; use either::{for_both, Either}; use enum_as_inner::EnumAsInner; +use risingwave_pb::data::PbDatum; use crate::array::{ArrayImpl, ListRef, ListValue, StructRef, StructValue}; use crate::row::{Row, RowDeserializer as BasicDeserializer}; @@ -165,6 +166,24 @@ pub fn estimate_serialize_datum_size(datum_ref: impl ToDatumRef) -> usize { } } +#[easy_ext::ext(DatumFromProtoExt)] +impl Datum { + /// Create a datum from the protobuf representation with the given data type. + pub fn from_protobuf(proto: &PbDatum, data_type: &DataType) -> Result { + deserialize_datum(proto.body.as_slice(), data_type) + } +} + +#[easy_ext::ext(DatumToProtoExt)] +impl D { + /// Convert the datum to the protobuf representation. + pub fn to_protobuf(&self) -> PbDatum { + PbDatum { + body: serialize_datum(self), + } + } +} + /// Deserialize bytes into a datum (Not order guarantee, used in value encoding). pub fn deserialize_datum(mut data: impl Buf, ty: &DataType) -> Result { inner_deserialize_datum(&mut data, ty) diff --git a/src/expr/core/src/aggregate/def.rs b/src/expr/core/src/aggregate/def.rs index a89c1413efd81..2d6763130cc4a 100644 --- a/src/expr/core/src/aggregate/def.rs +++ b/src/expr/core/src/aggregate/def.rs @@ -20,9 +20,9 @@ use std::sync::Arc; use itertools::Itertools; use parse_display::{Display, FromStr}; use risingwave_common::bail; -use risingwave_common::types::DataType; +use risingwave_common::types::{DataType, Datum}; use risingwave_common::util::sort_util::{ColumnOrder, OrderType}; -use risingwave_common::util::value_encoding; +use risingwave_common::util::value_encoding::DatumFromProtoExt; use risingwave_pb::expr::agg_call::PbType; use risingwave_pb::expr::{PbAggCall, PbInputRef}; @@ -78,11 +78,7 @@ impl AggCall { let data_type = DataType::from(arg.get_type().unwrap()); LiteralExpression::new( data_type.clone(), - value_encoding::deserialize_datum( - arg.get_datum().unwrap().get_body().as_slice(), - &data_type, - ) - .unwrap(), + Datum::from_protobuf(arg.get_datum().unwrap(), &data_type).unwrap(), ) }) .collect_vec(); diff --git a/src/expr/core/src/expr/expr_field.rs b/src/expr/core/src/expr/expr_field.rs index eb4183edd2f57..a4101301308ed 100644 --- a/src/expr/core/src/expr/expr_field.rs +++ b/src/expr/core/src/expr/expr_field.rs @@ -16,7 +16,7 @@ use anyhow::anyhow; use risingwave_common::array::{ArrayImpl, ArrayRef, DataChunk}; use risingwave_common::row::OwnedRow; use risingwave_common::types::{DataType, Datum, ScalarImpl}; -use risingwave_common::util::value_encoding::deserialize_datum; +use risingwave_common::util::value_encoding::DatumFromProtoExt; use risingwave_pb::expr::expr_node::{RexNode, Type}; use risingwave_pb::expr::ExprNode; @@ -89,7 +89,7 @@ impl Build for FieldExpression { let RexNode::Constant(value) = second.get_rex_node().unwrap() else { bail!("Expected Constant as 1st argument"); }; - let index = deserialize_datum(value.body.as_slice(), &DataType::Int32) + let index = Datum::from_protobuf(value, &DataType::Int32) .map_err(|e| anyhow!("Failed to deserialize i32, reason: {:?}", e))? .unwrap() .as_int32() diff --git a/src/expr/core/src/expr/expr_in.rs b/src/expr/core/src/expr/expr_in.rs index f599bd4a64e42..cbc5cd244b528 100644 --- a/src/expr/core/src/expr/expr_in.rs +++ b/src/expr/core/src/expr/expr_in.rs @@ -131,10 +131,10 @@ mod tests { use risingwave_common::array::DataChunk; use risingwave_common::row::OwnedRow; use risingwave_common::test_prelude::DataChunkTestExt; - use risingwave_common::types::{DataType, ScalarImpl}; - use risingwave_common::util::value_encoding::serialize_datum; + use risingwave_common::types::{DataType, Datum, ScalarImpl}; + use risingwave_common::util::value_encoding::DatumToProtoExt; use risingwave_pb::data::data_type::TypeName; - use risingwave_pb::data::{PbDataType, PbDatum}; + use risingwave_pb::data::PbDataType; use risingwave_pb::expr::expr_node::{RexNode, Type}; use risingwave_pb::expr::{ExprNode, FunctionCall}; @@ -158,9 +158,7 @@ mod tests { type_name: TypeName::Varchar as i32, ..Default::default() }), - rex_node: Some(RexNode::Constant(PbDatum { - body: serialize_datum(Some("ABC".into()).as_ref()), - })), + rex_node: Some(RexNode::Constant(Datum::Some("ABC".into()).to_protobuf())), }, ExprNode { function_type: Type::Unspecified as i32, @@ -168,9 +166,7 @@ mod tests { type_name: TypeName::Varchar as i32, ..Default::default() }), - rex_node: Some(RexNode::Constant(PbDatum { - body: serialize_datum(Some("def".into()).as_ref()), - })), + rex_node: Some(RexNode::Constant(Datum::Some("def".into()).to_protobuf())), }, ]; let mut in_children = vec![input_ref_expr_node]; diff --git a/src/expr/core/src/expr/expr_literal.rs b/src/expr/core/src/expr/expr_literal.rs index 4009fc346cc41..54202ba732d3e 100644 --- a/src/expr/core/src/expr/expr_literal.rs +++ b/src/expr/core/src/expr/expr_literal.rs @@ -15,7 +15,7 @@ use risingwave_common::array::DataChunk; use risingwave_common::row::OwnedRow; use risingwave_common::types::{literal_type_match, DataType, Datum}; -use risingwave_common::util::value_encoding::deserialize_datum; +use risingwave_common::util::value_encoding::DatumFromProtoExt; use risingwave_pb::expr::ExprNode; use super::{Build, ValueImpl}; @@ -74,9 +74,8 @@ impl Build for LiteralExpression { let prost_value = prost.get_rex_node().unwrap().as_constant().unwrap(); - // TODO: We need to unify these - let value = deserialize_datum( - prost_value.get_body().as_slice(), + let value = Datum::from_protobuf( + prost_value, &DataType::from(prost.get_return_type().unwrap()), ) .map_err(|e| ExprError::Internal(e.into()))?; @@ -92,7 +91,7 @@ mod tests { use risingwave_common::array::{I32Array, StructValue}; use risingwave_common::types::test_utils::IntervalTestExt; use risingwave_common::types::{Decimal, Interval, IntoOrdered, Scalar, ScalarImpl}; - use risingwave_common::util::value_encoding::serialize_datum; + use risingwave_common::util::value_encoding::{serialize_datum, DatumToProtoExt}; use risingwave_pb::data::data_type::{IntervalType, TypeName}; use risingwave_pb::data::{PbDataType, PbDatum}; use risingwave_pb::expr::expr_node::RexNode::{self, Constant}; @@ -108,7 +107,7 @@ mod tests { Some(2.into()), None, ]); - let body = serialize_datum(Some(value.clone().to_scalar_value()).as_ref()); + let pb_datum = Some(value.clone().to_scalar_value()).to_protobuf(); let expr = ExprNode { function_type: Type::Unspecified as i32, return_type: Some(PbDataType { @@ -129,7 +128,7 @@ mod tests { ], ..Default::default() }), - rex_node: Some(Constant(PbDatum { body })), + rex_node: Some(Constant(pb_datum)), }; let expr = LiteralExpression::build_for_test(&expr).unwrap(); assert_eq!(value.to_scalar_value(), expr.literal().unwrap()); diff --git a/src/expr/core/src/expr/test_utils.rs b/src/expr/core/src/expr/test_utils.rs index d276413c02678..56ebcdfddf784 100644 --- a/src/expr/core/src/expr/test_utils.rs +++ b/src/expr/core/src/expr/test_utils.rs @@ -18,9 +18,9 @@ use std::num::NonZeroUsize; use num_traits::CheckedSub; use risingwave_common::types::{DataType, Interval, ScalarImpl}; -use risingwave_common::util::value_encoding::serialize_datum; +use risingwave_common::util::value_encoding::DatumToProtoExt; use risingwave_pb::data::data_type::TypeName; -use risingwave_pb::data::{PbDataType, PbDatum}; +use risingwave_pb::data::PbDataType; use risingwave_pb::expr::expr_node::Type::Field; use risingwave_pb::expr::expr_node::{self, RexNode, Type}; use risingwave_pb::expr::{ExprNode, FunctionCall}; @@ -57,9 +57,9 @@ pub fn make_i32_literal(data: i32) -> ExprNode { type_name: TypeName::Int32 as i32, ..Default::default() }), - rex_node: Some(RexNode::Constant(PbDatum { - body: serialize_datum(Some(ScalarImpl::Int32(data)).as_ref()), - })), + rex_node: Some(RexNode::Constant( + Some(ScalarImpl::Int32(data)).to_protobuf(), + )), } } @@ -70,9 +70,9 @@ fn make_interval_literal(data: Interval) -> ExprNode { type_name: TypeName::Interval as i32, ..Default::default() }), - rex_node: Some(RexNode::Constant(PbDatum { - body: serialize_datum(Some(ScalarImpl::Interval(data)).as_ref()), - })), + rex_node: Some(RexNode::Constant( + Some(ScalarImpl::Interval(data)).to_protobuf(), + )), } } diff --git a/src/frontend/src/expr/literal.rs b/src/frontend/src/expr/literal.rs index bf4c95b2114d4..2882243f93170 100644 --- a/src/frontend/src/expr/literal.rs +++ b/src/frontend/src/expr/literal.rs @@ -14,8 +14,7 @@ use risingwave_common::array::list_array::display_for_explain; use risingwave_common::types::{literal_type_match, DataType, Datum, ToText}; -use risingwave_common::util::value_encoding::{deserialize_datum, serialize_datum}; -use risingwave_pb::data::PbDatum; +use risingwave_common::util::value_encoding::{DatumFromProtoExt, DatumToProtoExt}; use risingwave_pb::expr::expr_node::RexNode; use super::Expr; @@ -121,8 +120,7 @@ impl Expr for Literal { /// Convert a literal value (datum) into protobuf. pub fn literal_to_value_encoding(d: &Datum) -> RexNode { - let body = serialize_datum(d.as_ref()); - RexNode::Constant(PbDatum { body }) + RexNode::Constant(d.to_protobuf()) } /// Convert protobuf into a literal value (datum). @@ -132,7 +130,7 @@ fn value_encoding_to_literal( ) -> risingwave_common::error::Result { if let Some(rex_node) = proto { if let RexNode::Constant(prost_datum) = rex_node { - let datum = deserialize_datum(prost_datum.body.as_ref(), ty)?; + let datum = Datum::from_protobuf(prost_datum, ty)?; Ok(datum) } else { unreachable!() @@ -145,8 +143,8 @@ fn value_encoding_to_literal( #[cfg(test)] mod tests { use risingwave_common::array::{ListValue, StructValue}; - use risingwave_common::types::{DataType, ScalarImpl}; - use risingwave_common::util::value_encoding::deserialize_datum; + use risingwave_common::types::{DataType, Datum, ScalarImpl}; + use risingwave_common::util::value_encoding::DatumFromProtoExt; use risingwave_pb::expr::expr_node::RexNode; use crate::expr::literal::literal_to_value_encoding; @@ -161,8 +159,8 @@ mod tests { let data = Some(ScalarImpl::Struct(value.clone())); let node = literal_to_value_encoding(&data); if let RexNode::Constant(prost) = node { - let data2 = deserialize_datum( - prost.get_body().as_slice(), + let data2 = Datum::from_protobuf( + &prost, &DataType::new_struct( vec![DataType::Varchar, DataType::Int32, DataType::Int32], vec![], @@ -184,12 +182,9 @@ mod tests { let data = Some(ScalarImpl::List(value.clone())); let node = literal_to_value_encoding(&data); if let RexNode::Constant(prost) = node { - let data2 = deserialize_datum( - prost.get_body().as_slice(), - &DataType::List(Box::new(DataType::Varchar)), - ) - .unwrap() - .unwrap(); + let data2 = Datum::from_protobuf(&prost, &DataType::List(Box::new(DataType::Varchar))) + .unwrap() + .unwrap(); assert_eq!(ScalarImpl::List(value), data2); } } diff --git a/src/frontend/src/optimizer/plan_node/generic/agg.rs b/src/frontend/src/optimizer/plan_node/generic/agg.rs index e6392a0ba14e6..4db0ac0780f62 100644 --- a/src/frontend/src/optimizer/plan_node/generic/agg.rs +++ b/src/frontend/src/optimizer/plan_node/generic/agg.rs @@ -22,10 +22,9 @@ use risingwave_common::catalog::{Field, FieldDisplay, Schema}; use risingwave_common::types::DataType; use risingwave_common::util::iter_util::ZipEqFast; use risingwave_common::util::sort_util::{ColumnOrder, ColumnOrderDisplay, OrderType}; -use risingwave_common::util::value_encoding; +use risingwave_common::util::value_encoding::DatumToProtoExt; use risingwave_expr::aggregate::{agg_kinds, AggKind}; use risingwave_expr::sig::agg::AGG_FUNC_SIG_MAP; -use risingwave_pb::data::PbDatum; use risingwave_pb::expr::{PbAggCall, PbConstant}; use risingwave_pb::stream_plan::{agg_call_state, AggCallState as AggCallStatePb}; @@ -715,9 +714,7 @@ impl PlanAggCall { .direct_args .iter() .map(|x| PbConstant { - datum: Some(PbDatum { - body: value_encoding::serialize_datum(x.get_data()), - }), + datum: Some(x.get_data().to_protobuf()), r#type: Some(x.return_type().to_protobuf()), }) .collect(), diff --git a/src/stream/src/executor/mod.rs b/src/stream/src/executor/mod.rs index 663f1c0236e85..fffa62f4794f8 100644 --- a/src/stream/src/executor/mod.rs +++ b/src/stream/src/executor/mod.rs @@ -26,13 +26,13 @@ use risingwave_common::array::StreamChunk; use risingwave_common::buffer::Bitmap; use risingwave_common::catalog::Schema; use risingwave_common::row::OwnedRow; -use risingwave_common::types::{DataType, DefaultOrd, ScalarImpl}; +use risingwave_common::types::{DataType, Datum, DefaultOrd, ScalarImpl}; use risingwave_common::util::epoch::{Epoch, EpochPair}; use risingwave_common::util::tracing::TracingContext; -use risingwave_common::util::value_encoding::{deserialize_datum, serialize_datum}; +use risingwave_common::util::value_encoding::{DatumFromProtoExt, DatumToProtoExt}; use risingwave_connector::source::SplitImpl; use risingwave_expr::expr::BoxedExpression; -use risingwave_pb::data::{PbDatum, PbEpoch}; +use risingwave_pb::data::PbEpoch; use risingwave_pb::expr::PbInputRef; use risingwave_pb::stream_plan::barrier::{BarrierKind, PbMutation}; use risingwave_pb::stream_plan::stream_message::StreamMessage; @@ -675,16 +675,14 @@ impl Watermark { index: self.col_idx as _, r#type: Some(self.data_type.to_protobuf()), }), - val: Some(PbDatum { - body: serialize_datum(Some(&self.val)), - }), + val: Some(&self.val).to_protobuf().into(), } } pub fn from_protobuf(prost: &PbWatermark) -> StreamExecutorResult { let col_ref = prost.get_column()?; let data_type = DataType::from(col_ref.get_type()?); - let val = deserialize_datum(prost.get_val()?.get_body().as_slice(), &data_type)? + let val = Datum::from_protobuf(prost.get_val()?, &data_type)? .expect("watermark value cannot be null"); Ok(Self::new(col_ref.get_index() as _, data_type, val)) }