diff --git a/src/common/src/array/arrow/arrow_impl.rs b/src/common/src/array/arrow/arrow_impl.rs index 35057f62f7740..3d7ec1110fbac 100644 --- a/src/common/src/array/arrow/arrow_impl.rs +++ b/src/common/src/array/arrow/arrow_impl.rs @@ -54,6 +54,9 @@ use crate::types::*; use crate::util::iter_util::ZipEqFast; /// Defines how to convert RisingWave arrays to Arrow arrays. +/// +/// This trait allows for customized conversion logic for different external systems using Arrow. +/// The default implementation is based on the `From` implemented in this mod. pub trait ToArrow { /// Converts RisingWave `DataChunk` to Arrow `RecordBatch` with specified schema. /// @@ -767,7 +770,7 @@ converts!(IntervalArray, arrow_array::IntervalMonthDayNanoArray, @map); converts!(SerialArray, arrow_array::Int64Array, @map); /// Converts RisingWave value from and into Arrow value. -pub trait FromIntoArrow { +trait FromIntoArrow { /// The corresponding element type in the Arrow array. type ArrowType; fn from_arrow(value: Self::ArrowType) -> Self; diff --git a/src/common/src/array/proto_reader.rs b/src/common/src/array/proto_reader.rs index 073ad0b3de7ba..7c3b05437770c 100644 --- a/src/common/src/array/proto_reader.rs +++ b/src/common/src/array/proto_reader.rs @@ -26,6 +26,7 @@ impl ArrayImpl { pub fn from_protobuf(array: &PbArray, cardinality: usize) -> ArrayResult { use crate::array::value_reader::*; let array = match array.array_type() { + PbArrayType::Unspecified => unreachable!(), PbArrayType::Int16 => read_numeric_array::(array, cardinality)?, PbArrayType::Int32 => read_numeric_array::(array, cardinality)?, PbArrayType::Int64 => read_numeric_array::(array, cardinality)?, @@ -49,7 +50,6 @@ impl ArrayImpl { PbArrayType::Jsonb => JsonbArray::from_protobuf(array)?, PbArrayType::Struct => StructArray::from_protobuf(array)?, PbArrayType::List => ListArray::from_protobuf(array)?, - PbArrayType::Unspecified => unreachable!(), PbArrayType::Bytea => { read_string_array::(array, cardinality)? } diff --git a/src/common/src/types/macros.rs b/src/common/src/types/macros.rs index 35f106aafdffd..520e4ab8f45ee 100644 --- a/src/common/src/types/macros.rs +++ b/src/common/src/types/macros.rs @@ -39,6 +39,7 @@ macro_rules! for_all_variants { ($macro:ident $(, $x:tt)*) => { $macro! { $($x, )* + //data_type variant_name suffix_name scalar scalar_ref array builder { Int16, Int16, int16, i16, i16, $crate::array::I16Array, $crate::array::I16ArrayBuilder }, { Int32, Int32, int32, i32, i32, $crate::array::I32Array, $crate::array::I32ArrayBuilder }, { Int64, Int64, int64, i64, i64, $crate::array::I64Array, $crate::array::I64ArrayBuilder }, diff --git a/src/common/src/types/mod.rs b/src/common/src/types/mod.rs index 3b02b8c38d020..91bebde846f0b 100644 --- a/src/common/src/types/mod.rs +++ b/src/common/src/types/mod.rs @@ -179,67 +179,39 @@ impl std::str::FromStr for Box { impl ZeroHeapSize for DataType {} -impl DataTypeName { - pub fn is_scalar(&self) -> bool { - match self { - DataTypeName::Boolean - | DataTypeName::Int16 - | DataTypeName::Int32 - | DataTypeName::Int64 - | DataTypeName::Int256 - | DataTypeName::Serial - | DataTypeName::Decimal - | DataTypeName::Float32 - | DataTypeName::Float64 - | DataTypeName::Varchar - | DataTypeName::Date - | DataTypeName::Timestamp - | DataTypeName::Timestamptz - | DataTypeName::Time - | DataTypeName::Bytea - | DataTypeName::Jsonb - | DataTypeName::Interval => true, - - DataTypeName::Struct | DataTypeName::List => false, - } - } +impl TryFrom for DataType { + type Error = &'static str; - pub fn to_type(self) -> Option { - let t = match self { - DataTypeName::Boolean => DataType::Boolean, - DataTypeName::Int16 => DataType::Int16, - DataTypeName::Int32 => DataType::Int32, - DataTypeName::Int64 => DataType::Int64, - DataTypeName::Int256 => DataType::Int256, - DataTypeName::Serial => DataType::Serial, - DataTypeName::Decimal => DataType::Decimal, - DataTypeName::Float32 => DataType::Float32, - DataTypeName::Float64 => DataType::Float64, - DataTypeName::Varchar => DataType::Varchar, - DataTypeName::Bytea => DataType::Bytea, - DataTypeName::Date => DataType::Date, - DataTypeName::Timestamp => DataType::Timestamp, - DataTypeName::Timestamptz => DataType::Timestamptz, - DataTypeName::Time => DataType::Time, - DataTypeName::Interval => DataType::Interval, - DataTypeName::Jsonb => DataType::Jsonb, + fn try_from(type_name: DataTypeName) -> Result { + match type_name { + DataTypeName::Boolean => Ok(DataType::Boolean), + DataTypeName::Int16 => Ok(DataType::Int16), + DataTypeName::Int32 => Ok(DataType::Int32), + DataTypeName::Int64 => Ok(DataType::Int64), + DataTypeName::Int256 => Ok(DataType::Int256), + DataTypeName::Serial => Ok(DataType::Serial), + DataTypeName::Decimal => Ok(DataType::Decimal), + DataTypeName::Float32 => Ok(DataType::Float32), + DataTypeName::Float64 => Ok(DataType::Float64), + DataTypeName::Varchar => Ok(DataType::Varchar), + DataTypeName::Bytea => Ok(DataType::Bytea), + DataTypeName::Date => Ok(DataType::Date), + DataTypeName::Timestamp => Ok(DataType::Timestamp), + DataTypeName::Timestamptz => Ok(DataType::Timestamptz), + DataTypeName::Time => Ok(DataType::Time), + DataTypeName::Interval => Ok(DataType::Interval), + DataTypeName::Jsonb => Ok(DataType::Jsonb), DataTypeName::Struct | DataTypeName::List => { - return None; + Err("Functions returning struct or list can not be inferred. Please use `FunctionCall::new_unchecked`.") } - }; - Some(t) - } -} - -impl From for DataType { - fn from(type_name: DataTypeName) -> Self { - type_name.to_type().unwrap_or_else(|| panic!("Functions returning struct or list can not be inferred. Please use `FunctionCall::new_unchecked`.")) + } } } impl From<&PbDataType> for DataType { fn from(proto: &PbDataType) -> DataType { match proto.get_type_name().expect("missing type field") { + PbTypeName::TypeUnspecified => unreachable!(), PbTypeName::Int16 => DataType::Int16, PbTypeName::Int32 => DataType::Int32, PbTypeName::Int64 => DataType::Int64, @@ -265,7 +237,6 @@ impl From<&PbDataType> for DataType { // The first (and only) item is the list element type. Box::new((&proto.field_type[0]).into()), ), - PbTypeName::TypeUnspecified => unreachable!(), PbTypeName::Int256 => DataType::Int256, } } @@ -337,27 +308,7 @@ impl DataType { } pub fn prost_type_name(&self) -> PbTypeName { - match self { - DataType::Int16 => PbTypeName::Int16, - DataType::Int32 => PbTypeName::Int32, - DataType::Int64 => PbTypeName::Int64, - DataType::Int256 => PbTypeName::Int256, - DataType::Serial => PbTypeName::Serial, - DataType::Float32 => PbTypeName::Float, - DataType::Float64 => PbTypeName::Double, - DataType::Boolean => PbTypeName::Boolean, - DataType::Varchar => PbTypeName::Varchar, - DataType::Date => PbTypeName::Date, - DataType::Time => PbTypeName::Time, - DataType::Timestamp => PbTypeName::Timestamp, - DataType::Timestamptz => PbTypeName::Timestamptz, - DataType::Decimal => PbTypeName::Decimal, - DataType::Interval => PbTypeName::Interval, - DataType::Jsonb => PbTypeName::Jsonb, - DataType::Struct { .. } => PbTypeName::Struct, - DataType::List { .. } => PbTypeName::List, - DataType::Bytea => PbTypeName::Bytea, - } + DataTypeName::from(self).into() } pub fn to_protobuf(&self) -> PbDataType { @@ -374,7 +325,23 @@ impl DataType { DataType::List(datatype) => { pb.field_type = vec![datatype.to_protobuf()]; } - _ => {} + DataType::Boolean + | DataType::Int16 + | DataType::Int32 + | DataType::Int64 + | DataType::Float32 + | DataType::Float64 + | DataType::Decimal + | DataType::Date + | DataType::Varchar + | DataType::Time + | DataType::Timestamp + | DataType::Timestamptz + | DataType::Interval + | DataType::Bytea + | DataType::Jsonb + | DataType::Serial + | DataType::Int256 => (), } pb } @@ -392,10 +359,6 @@ impl DataType { ) } - pub fn is_scalar(&self) -> bool { - DataTypeName::from(self).is_scalar() - } - pub fn is_array(&self) -> bool { matches!(self, DataType::List(_)) } @@ -440,37 +403,6 @@ impl DataType { } } - /// WARNING: Currently this should only be used in `WatermarkFilterExecutor`. Please be careful - /// if you want to use this. - pub fn min_value(&self) -> ScalarImpl { - match self { - DataType::Int16 => ScalarImpl::Int16(i16::MIN), - DataType::Int32 => ScalarImpl::Int32(i32::MIN), - DataType::Int64 => ScalarImpl::Int64(i64::MIN), - DataType::Int256 => ScalarImpl::Int256(Int256::min_value()), - DataType::Serial => ScalarImpl::Serial(Serial::from(i64::MIN)), - DataType::Float32 => ScalarImpl::Float32(F32::neg_infinity()), - DataType::Float64 => ScalarImpl::Float64(F64::neg_infinity()), - DataType::Boolean => ScalarImpl::Bool(false), - DataType::Varchar => ScalarImpl::Utf8("".into()), - DataType::Bytea => ScalarImpl::Bytea("".to_string().into_bytes().into()), - DataType::Date => ScalarImpl::Date(Date::MIN), - DataType::Time => ScalarImpl::Time(Time::from_hms_uncheck(0, 0, 0)), - DataType::Timestamp => ScalarImpl::Timestamp(Timestamp::MIN), - DataType::Timestamptz => ScalarImpl::Timestamptz(Timestamptz::MIN), - DataType::Decimal => ScalarImpl::Decimal(Decimal::NegativeInf), - DataType::Interval => ScalarImpl::Interval(Interval::MIN), - DataType::Jsonb => ScalarImpl::Jsonb(JsonbVal::null()), // NOT `min` #7981 - DataType::Struct(data_types) => ScalarImpl::Struct(StructValue::new( - data_types - .types() - .map(|data_type| Some(data_type.min_value())) - .collect_vec(), - )), - DataType::List(data_type) => ScalarImpl::List(ListValue::empty(data_type)), - } - } - /// Return a new type that removes the outer list. /// /// ``` @@ -513,28 +445,32 @@ impl From for PbDataType { } } -/// Common trait bounds of scalar and scalar reference types. -/// -/// NOTE(rc): `Hash` is not in the trait bound list, it's implemented as [`ScalarRef::hash_scalar`]. -pub trait ScalarBounds = Debug - + Send - + Sync - + Clone - + PartialEq - + Eq - // in default ascending order - + PartialOrd - + Ord - + TryFrom - // `ScalarImpl`/`ScalarRefImpl` - + Into; +mod private { + use super::*; + + /// Common trait bounds of scalar and scalar reference types. + /// + /// NOTE(rc): `Hash` is not in the trait bound list, it's implemented as [`ScalarRef::hash_scalar`]. + pub trait ScalarBounds = Debug + + Send + + Sync + + Clone + + PartialEq + + Eq + // in default ascending order + + PartialOrd + + Ord + + TryFrom + // `ScalarImpl`/`ScalarRefImpl` + + Into; +} /// `Scalar` is a trait over all possible owned types in the evaluation /// framework. /// /// `Scalar` is reciprocal to `ScalarRef`. Use `as_scalar_ref` to get a /// reference which has the same lifetime as `self`. -pub trait Scalar: ScalarBounds + 'static { +pub trait Scalar: private::ScalarBounds + 'static { /// Type for reference of `Scalar` type ScalarRefType<'a>: ScalarRef<'a, ScalarType = Self> + 'a where @@ -548,17 +484,12 @@ pub trait Scalar: ScalarBounds + 'static { } } -/// Convert an `Option` to corresponding `Option`. -pub fn option_as_scalar_ref(scalar: &Option) -> Option> { - scalar.as_ref().map(|x| x.as_scalar_ref()) -} - /// `ScalarRef` is a trait over all possible references in the evaluation /// framework. /// /// `ScalarRef` is reciprocal to `Scalar`. Use `to_owned_scalar` to get an /// owned scalar. -pub trait ScalarRef<'a>: ScalarBounds> + 'a + Copy { +pub trait ScalarRef<'a>: private::ScalarBounds> + 'a + Copy { /// `ScalarType` is the owned type of current `ScalarRef`. type ScalarType: Scalar = Self>; @@ -653,6 +584,9 @@ impl ToDatumRef for DatumRef<'_> { } /// To make sure there is `as_scalar_ref` for all scalar ref types. +/// See +/// +/// This is used by the expr macro. pub trait SelfAsScalarRef { fn as_scalar_ref(&self) -> Self; } @@ -1021,7 +955,7 @@ impl ScalarRefImpl<'_> { } impl ScalarImpl { - /// Serialize the scalar. + /// Serialize the scalar into the `memcomparable` format. pub fn serialize( &self, ser: &mut memcomparable::Serializer, @@ -1029,7 +963,7 @@ impl ScalarImpl { self.as_scalar_ref_impl().serialize(ser) } - /// Deserialize the scalar. + /// Deserialize the scalar from the `memcomparable` format. pub fn deserialize( ty: &DataType, de: &mut memcomparable::Deserializer, diff --git a/src/frontend/src/catalog/system_catalog/pg_catalog/pg_cast.rs b/src/frontend/src/catalog/system_catalog/pg_catalog/pg_cast.rs index 11bcabcde0f69..d5b1332c25b3e 100644 --- a/src/frontend/src/catalog/system_catalog/pg_catalog/pg_cast.rs +++ b/src/frontend/src/catalog/system_catalog/pg_catalog/pg_cast.rs @@ -38,8 +38,8 @@ fn read_pg_cast(_: &SysCatalogReaderImpl) -> Vec { .enumerate() .map(|(idx, (src, target, ctx))| PgCast { oid: idx as i32, - castsource: DataType::from(*src).to_oid(), - casttarget: DataType::from(*target).to_oid(), + castsource: DataType::try_from(*src).unwrap().to_oid(), + casttarget: DataType::try_from(*target).unwrap().to_oid(), castcontext: ctx.to_string(), }) .collect()