diff --git a/src/common/src/array/arrow/arrow_default.rs b/src/common/src/array/arrow/arrow_default.rs index 7944c687827f4..5d04527b354ba 100644 --- a/src/common/src/array/arrow/arrow_default.rs +++ b/src/common/src/array/arrow/arrow_default.rs @@ -18,7 +18,11 @@ //! //! The corresponding version of arrow is currently used by `udf` and `iceberg` sink. -pub use arrow_impl::to_record_batch_with_schema; +#![allow(unused_imports)] +pub use arrow_impl::{ + to_record_batch_with_schema, ToArrowArrayConvert, ToArrowArrayWithTypeConvert, + ToArrowTypeConvert, +}; use {arrow_array, arrow_buffer, arrow_cast, arrow_schema}; #[expect(clippy::duplicate_mod)] diff --git a/src/common/src/array/arrow/arrow_impl.rs b/src/common/src/array/arrow/arrow_impl.rs index 39991bacc48d8..f22ee55e9eece 100644 --- a/src/common/src/array/arrow/arrow_impl.rs +++ b/src/common/src/array/arrow/arrow_impl.rs @@ -124,6 +124,497 @@ impl TryFrom<&arrow_array::RecordBatch> for DataChunk { } } +/// Provides the default conversion logic for RisingWave array to Arrow array with type info. 
+pub trait ToArrowArrayWithTypeConvert {
+    /// Converts `array` into an Arrow array by dispatching on its `ArrayImpl` variant.
+    ///
+    /// `data_type` is forwarded unchanged to the per-variant hook so an implementor can pick a
+    /// different Arrow representation. Every default hook in this trait ignores it.
+    ///
+    /// # Errors
+    ///
+    /// Returns whatever the selected hook returns. With the default hooks that is a failed
+    /// struct or list conversion, or the "not supported" error for serial arrays.
+    fn to_arrow_with_type(
+        &self,
+        data_type: &arrow_schema::DataType,
+        array: &ArrayImpl,
+    ) -> Result<arrow_array::ArrayRef, ArrayError> {
+        match array {
+            ArrayImpl::Int16(array) => self.int16_to_arrow(data_type, array),
+            ArrayImpl::Int32(array) => self.int32_to_arrow(data_type, array),
+            ArrayImpl::Int64(array) => self.int64_to_arrow(data_type, array),
+            ArrayImpl::Float32(array) => self.float32_to_arrow(data_type, array),
+            ArrayImpl::Float64(array) => self.float64_to_arrow(data_type, array),
+            ArrayImpl::Utf8(array) => self.utf8_to_arrow(data_type, array),
+            ArrayImpl::Bool(array) => self.bool_to_arrow(data_type, array),
+            ArrayImpl::Decimal(array) => self.decimal_to_arrow(data_type, array),
+            ArrayImpl::Int256(array) => self.int256_to_arrow(data_type, array),
+            ArrayImpl::Date(array) => self.date_to_arrow(data_type, array),
+            ArrayImpl::Timestamp(array) => self.timestamp_to_arrow(data_type, array),
+            ArrayImpl::Timestamptz(array) => self.timestamptz_to_arrow(data_type, array),
+            ArrayImpl::Time(array) => self.time_to_arrow(data_type, array),
+            ArrayImpl::Interval(array) => self.interval_to_arrow(data_type, array),
+            ArrayImpl::Struct(array) => self.struct_to_arrow(data_type, array),
+            ArrayImpl::List(array) => self.list_to_arrow(data_type, array),
+            ArrayImpl::Bytea(array) => self.bytea_to_arrow(data_type, array),
+            ArrayImpl::Jsonb(array) => self.jsonb_to_arrow(data_type, array),
+            ArrayImpl::Serial(array) => self.serial_to_arrow(data_type, array),
+        }
+    }
+
+    /// Default: builds an `arrow_array::Int16Array`.
+    #[inline]
+    fn int16_to_arrow(
+        &self,
+        _data_type: &arrow_schema::DataType,
+        array: &I16Array,
+    ) -> Result<arrow_array::ArrayRef, ArrayError> {
+        Ok(Arc::new(arrow_array::Int16Array::from(array)))
+    }
+
+    /// Default: builds an `arrow_array::Int32Array`.
+    #[inline]
+    fn int32_to_arrow(
+        &self,
+        _data_type: &arrow_schema::DataType,
+        array: &I32Array,
+    ) -> Result<arrow_array::ArrayRef, ArrayError> {
+        Ok(Arc::new(arrow_array::Int32Array::from(array)))
+    }
+
+    /// Default: builds an `arrow_array::Int64Array`.
+    #[inline]
+    fn int64_to_arrow(
+        &self,
+        _data_type: &arrow_schema::DataType,
+        array: &I64Array,
+    ) -> Result<arrow_array::ArrayRef, ArrayError> {
+        Ok(Arc::new(arrow_array::Int64Array::from(array)))
+    }
+
+    /// Default: builds an `arrow_array::Float32Array`.
+    #[inline]
+    fn float32_to_arrow(
+        &self,
+        _data_type: &arrow_schema::DataType,
+        array: &F32Array,
+    ) -> Result<arrow_array::ArrayRef, ArrayError> {
+        Ok(Arc::new(arrow_array::Float32Array::from(array)))
+    }
+
+    /// Default: builds an `arrow_array::Float64Array`.
+    #[inline]
+    fn float64_to_arrow(
+        &self,
+        _data_type: &arrow_schema::DataType,
+        array: &F64Array,
+    ) -> Result<arrow_array::ArrayRef, ArrayError> {
+        Ok(Arc::new(arrow_array::Float64Array::from(array)))
+    }
+
+    /// Default: builds an `arrow_array::StringArray`.
+    #[inline]
+    fn utf8_to_arrow(
+        &self,
+        _data_type: &arrow_schema::DataType,
+        array: &Utf8Array,
+    ) -> Result<arrow_array::ArrayRef, ArrayError> {
+        Ok(Arc::new(arrow_array::StringArray::from(array)))
+    }
+
+    /// Default: builds an `arrow_array::BooleanArray`.
+    #[inline]
+    fn bool_to_arrow(
+        &self,
+        _data_type: &arrow_schema::DataType,
+        array: &BoolArray,
+    ) -> Result<arrow_array::ArrayRef, ArrayError> {
+        Ok(Arc::new(arrow_array::BooleanArray::from(array)))
+    }
+
+    /// Default: builds an `arrow_array::LargeBinaryArray`.
+    ///
+    /// Decimal values are stored as ASCII text representation in a large binary array.
+    #[inline]
+    fn decimal_to_arrow(
+        &self,
+        _data_type: &arrow_schema::DataType,
+        array: &DecimalArray,
+    ) -> Result<arrow_array::ArrayRef, ArrayError> {
+        Ok(Arc::new(arrow_array::LargeBinaryArray::from(array)))
+    }
+
+    /// Default: builds an `arrow_array::Decimal256Array`.
+    #[inline]
+    fn int256_to_arrow(
+        &self,
+        _data_type: &arrow_schema::DataType,
+        array: &Int256Array,
+    ) -> Result<arrow_array::ArrayRef, ArrayError> {
+        Ok(Arc::new(arrow_array::Decimal256Array::from(array)))
+    }
+
+    /// Default: builds an `arrow_array::Date32Array`.
+    #[inline]
+    fn date_to_arrow(
+        &self,
+        _data_type: &arrow_schema::DataType,
+        array: &DateArray,
+    ) -> Result<arrow_array::ArrayRef, ArrayError> {
+        Ok(Arc::new(arrow_array::Date32Array::from(array)))
+    }
+
+    /// Default: builds an `arrow_array::TimestampMicrosecondArray` without a timezone.
+    #[inline]
+    fn timestamp_to_arrow(
+        &self,
+        _data_type: &arrow_schema::DataType,
+        array: &TimestampArray,
+    ) -> Result<arrow_array::ArrayRef, ArrayError> {
+        Ok(Arc::new(arrow_array::TimestampMicrosecondArray::from(
+            array,
+        )))
+    }
+
+    /// Default: builds an `arrow_array::TimestampMicrosecondArray` tagged with the UTC timezone.
+    #[inline]
+    fn timestamptz_to_arrow(
+        &self,
+        _data_type: &arrow_schema::DataType,
+        array: &TimestamptzArray,
+    ) -> Result<arrow_array::ArrayRef, ArrayError> {
+        Ok(Arc::new(
+            arrow_array::TimestampMicrosecondArray::from(array).with_timezone_utc(),
+        ))
+    }
+
+    /// Default: builds an `arrow_array::Time64MicrosecondArray`.
+    #[inline]
+    fn time_to_arrow(
+        &self,
+        _data_type: &arrow_schema::DataType,
+        array: &TimeArray,
+    ) -> Result<arrow_array::ArrayRef, ArrayError> {
+        Ok(Arc::new(arrow_array::Time64MicrosecondArray::from(array)))
+    }
+
+    /// Default: builds an `arrow_array::IntervalMonthDayNanoArray`.
+    #[inline]
+    fn interval_to_arrow(
+        &self,
+        _data_type: &arrow_schema::DataType,
+        array: &IntervalArray,
+    ) -> Result<arrow_array::ArrayRef, ArrayError> {
+        Ok(Arc::new(arrow_array::IntervalMonthDayNanoArray::from(
+            array,
+        )))
+    }
+
+    /// Default: delegates to `arrow_array::StructArray::try_from` and propagates its error.
+    ///
+    /// NOTE(review): the conversion does not go through `self`, so hooks overridden by an
+    /// implementor are presumably not applied to the struct's fields -- confirm.
+    #[inline]
+    fn struct_to_arrow(
+        &self,
+        _data_type: &arrow_schema::DataType,
+        array: &StructArray,
+    ) -> Result<arrow_array::ArrayRef, ArrayError> {
+        Ok(Arc::new(arrow_array::StructArray::try_from(array)?))
+    }
+
+    /// Default: delegates to `arrow_array::ListArray::try_from` and propagates its error.
+    ///
+    /// NOTE(review): the conversion does not go through `self`, so hooks overridden by an
+    /// implementor are presumably not applied to the list's elements -- confirm.
+    #[inline]
+    fn list_to_arrow(
+        &self,
+        _data_type: &arrow_schema::DataType,
+        array: &ListArray,
+    ) -> Result<arrow_array::ArrayRef, ArrayError> {
+        Ok(Arc::new(arrow_array::ListArray::try_from(array)?))
+    }
+
+    /// Default: builds an `arrow_array::BinaryArray`.
+    #[inline]
+    fn bytea_to_arrow(
+        &self,
+        _data_type: &arrow_schema::DataType,
+        array: &BytesArray,
+    ) -> Result<arrow_array::ArrayRef, ArrayError> {
+        Ok(Arc::new(arrow_array::BinaryArray::from(array)))
+    }
+
+    /// Default: builds an `arrow_array::LargeStringArray`.
+    ///
+    /// JSON values are stored as text representation in a large string array.
+    #[inline]
+    fn jsonb_to_arrow(
+        &self,
+        _data_type: &arrow_schema::DataType,
+        array: &JsonbArray,
+    ) -> Result<arrow_array::ArrayRef, ArrayError> {
+        Ok(Arc::new(arrow_array::LargeStringArray::from(array)))
+    }
+
+    /// Default: serial arrays have no Arrow conversion here.
+    ///
+    /// # Errors
+    ///
+    /// Always returns an error rather than panicking, since this method already returns a
+    /// `Result` that callers handle.
+    #[inline]
+    fn serial_to_arrow(
+        &self,
+        _data_type: &arrow_schema::DataType,
+        _array: &SerialArray,
+    ) -> Result<arrow_array::ArrayRef, ArrayError> {
+        Err(ArrayError::to_arrow("serial type is not supported to convert to arrow"))
+    }
+}
+
+/// Provides the default conversion logic for RisingWave array to Arrow array without type info.
+pub trait ToArrowArrayConvert {
+    /// Converts `array` into an Arrow array by dispatching on its `ArrayImpl` variant.
+    ///
+    /// Each variant is handed to the matching `*_to_arrow` hook, which implementors may
+    /// override individually.
+    ///
+    /// # Errors
+    ///
+    /// Returns whatever the selected hook returns. With the default hooks that is a failed
+    /// struct or list conversion, or the "not supported" error for serial arrays.
+    fn to_arrow(&self, array: &ArrayImpl) -> Result<arrow_array::ArrayRef, ArrayError> {
+        match array {
+            ArrayImpl::Int16(array) => self.int16_to_arrow(array),
+            ArrayImpl::Int32(array) => self.int32_to_arrow(array),
+            ArrayImpl::Int64(array) => self.int64_to_arrow(array),
+            ArrayImpl::Float32(array) => self.float32_to_arrow(array),
+            ArrayImpl::Float64(array) => self.float64_to_arrow(array),
+            ArrayImpl::Utf8(array) => self.utf8_to_arrow(array),
+            ArrayImpl::Bool(array) => self.bool_to_arrow(array),
+            ArrayImpl::Decimal(array) => self.decimal_to_arrow(array),
+            ArrayImpl::Int256(array) => self.int256_to_arrow(array),
+            ArrayImpl::Date(array) => self.date_to_arrow(array),
+            ArrayImpl::Timestamp(array) => self.timestamp_to_arrow(array),
+            ArrayImpl::Timestamptz(array) => self.timestamptz_to_arrow(array),
+            ArrayImpl::Time(array) => self.time_to_arrow(array),
+            ArrayImpl::Interval(array) => self.interval_to_arrow(array),
+            ArrayImpl::Struct(array) => self.struct_to_arrow(array),
+            ArrayImpl::List(array) => self.list_to_arrow(array),
+            ArrayImpl::Bytea(array) => self.bytea_to_arrow(array),
+            ArrayImpl::Jsonb(array) => self.jsonb_to_arrow(array),
+            ArrayImpl::Serial(array) => self.serial_to_arrow(array),
+        }
+    }
+
+    /// Default: builds an `arrow_array::Int16Array`.
+    #[inline]
+    fn int16_to_arrow(&self, array: &I16Array) -> Result<arrow_array::ArrayRef, ArrayError> {
+        Ok(Arc::new(arrow_array::Int16Array::from(array)))
+    }
+
+    /// Default: builds an `arrow_array::Int32Array`.
+    #[inline]
+    fn int32_to_arrow(&self, array: &I32Array) -> Result<arrow_array::ArrayRef, ArrayError> {
+        Ok(Arc::new(arrow_array::Int32Array::from(array)))
+    }
+
+    /// Default: builds an `arrow_array::Int64Array`.
+    #[inline]
+    fn int64_to_arrow(&self, array: &I64Array) -> Result<arrow_array::ArrayRef, ArrayError> {
+        Ok(Arc::new(arrow_array::Int64Array::from(array)))
+    }
+
+    /// Default: builds an `arrow_array::Float32Array`.
+    #[inline]
+    fn float32_to_arrow(&self, array: &F32Array) -> Result<arrow_array::ArrayRef, ArrayError> {
+        Ok(Arc::new(arrow_array::Float32Array::from(array)))
+    }
+
+    /// Default: builds an `arrow_array::Float64Array`.
+    #[inline]
+    fn float64_to_arrow(&self, array: &F64Array) -> Result<arrow_array::ArrayRef, ArrayError> {
+        Ok(Arc::new(arrow_array::Float64Array::from(array)))
+    }
+
+    /// Default: builds an `arrow_array::StringArray`.
+    #[inline]
+    fn utf8_to_arrow(&self, array: &Utf8Array) -> Result<arrow_array::ArrayRef, ArrayError> {
+        Ok(Arc::new(arrow_array::StringArray::from(array)))
+    }
+
+    /// Default: builds an `arrow_array::BooleanArray`.
+    #[inline]
+    fn bool_to_arrow(&self, array: &BoolArray) -> Result<arrow_array::ArrayRef, ArrayError> {
+        Ok(Arc::new(arrow_array::BooleanArray::from(array)))
+    }
+
+    /// Default: builds an `arrow_array::LargeBinaryArray`.
+    ///
+    /// Decimal values are stored as ASCII text representation in a large binary array.
+    #[inline]
+    fn decimal_to_arrow(&self, array: &DecimalArray) -> Result<arrow_array::ArrayRef, ArrayError> {
+        Ok(Arc::new(arrow_array::LargeBinaryArray::from(array)))
+    }
+
+    /// Default: builds an `arrow_array::Decimal256Array`.
+    #[inline]
+    fn int256_to_arrow(&self, array: &Int256Array) -> Result<arrow_array::ArrayRef, ArrayError> {
+        Ok(Arc::new(arrow_array::Decimal256Array::from(array)))
+    }
+
+    /// Default: builds an `arrow_array::Date32Array`.
+    #[inline]
+    fn date_to_arrow(&self, array: &DateArray) -> Result<arrow_array::ArrayRef, ArrayError> {
+        Ok(Arc::new(arrow_array::Date32Array::from(array)))
+    }
+
+    /// Default: builds an `arrow_array::TimestampMicrosecondArray` without a timezone.
+    #[inline]
+    fn timestamp_to_arrow(
+        &self,
+        array: &TimestampArray,
+    ) -> Result<arrow_array::ArrayRef, ArrayError> {
+        Ok(Arc::new(arrow_array::TimestampMicrosecondArray::from(
+            array,
+        )))
+    }
+
+    /// Default: builds an `arrow_array::TimestampMicrosecondArray` tagged with the UTC timezone.
+    #[inline]
+    fn timestamptz_to_arrow(
+        &self,
+        array: &TimestamptzArray,
+    ) -> Result<arrow_array::ArrayRef, ArrayError> {
+        Ok(Arc::new(
+            arrow_array::TimestampMicrosecondArray::from(array).with_timezone_utc(),
+        ))
+    }
+
+    /// Default: builds an `arrow_array::Time64MicrosecondArray`.
+    #[inline]
+    fn time_to_arrow(&self, array: &TimeArray) -> Result<arrow_array::ArrayRef, ArrayError> {
+        Ok(Arc::new(arrow_array::Time64MicrosecondArray::from(array)))
+    }
+
+    /// Default: builds an `arrow_array::IntervalMonthDayNanoArray`.
+    #[inline]
+    fn interval_to_arrow(
+        &self,
+        array: &IntervalArray,
+    ) -> Result<arrow_array::ArrayRef, ArrayError> {
+        Ok(Arc::new(arrow_array::IntervalMonthDayNanoArray::from(
+            array,
+        )))
+    }
+
+    /// Default: delegates to `arrow_array::StructArray::try_from` and propagates its error.
+    #[inline]
+    fn struct_to_arrow(&self, array: &StructArray) -> Result<arrow_array::ArrayRef, ArrayError> {
+        Ok(Arc::new(arrow_array::StructArray::try_from(array)?))
+    }
+
+    /// Default: delegates to `arrow_array::ListArray::try_from` and propagates its error.
+    #[inline]
+    fn list_to_arrow(&self, array: &ListArray) -> Result<arrow_array::ArrayRef, ArrayError> {
+        Ok(Arc::new(arrow_array::ListArray::try_from(array)?))
+    }
+
+    /// Default: builds an `arrow_array::BinaryArray`.
+    #[inline]
+    fn bytea_to_arrow(&self, array: &BytesArray) -> Result<arrow_array::ArrayRef, ArrayError> {
+        Ok(Arc::new(arrow_array::BinaryArray::from(array)))
+    }
+
+    /// Default: builds an `arrow_array::LargeStringArray`.
+    ///
+    /// JSON values are stored as text representation in a large string array.
+    #[inline]
+    fn jsonb_to_arrow(&self, array: &JsonbArray) -> Result<arrow_array::ArrayRef, ArrayError> {
+        Ok(Arc::new(arrow_array::LargeStringArray::from(array)))
+    }
+
+    /// Default: serial arrays have no Arrow conversion here.
+    ///
+    /// # Errors
+    ///
+    /// Always returns an error rather than panicking, since this method already returns a
+    /// `Result` that callers handle.
+    #[inline]
+    fn serial_to_arrow(&self, _array: &SerialArray) -> Result<arrow_array::ArrayRef, ArrayError> {
+        Err(ArrayError::to_arrow("serial type is not supported to convert to arrow"))
+    }
+}
+
+/// Provides the default mapping from a RisingWave `DataType` to an Arrow data type.
+///
+/// `to_arrow_type` dispatches to one hook per type; implementors override only the hooks whose
+/// Arrow representation they want to change.
+pub trait ToArrowTypeConvert {
+    /// Maps `value` to an Arrow data type through the matching `*_type_to_arrow` hook.
+    ///
+    /// # Errors
+    ///
+    /// Returns the error of `struct_type_to_arrow` / `list_type_to_arrow` when a nested type
+    /// cannot be converted.
+    ///
+    /// # Panics
+    ///
+    /// Panics for `DataType::Serial` unless `serial_type_to_arrow` is overridden, because the
+    /// default hook is a `todo!`.
+    fn to_arrow_type(&self, value: &DataType) -> Result<arrow_schema::DataType, ArrayError> {
+        match value {
+            DataType::Boolean => Ok(self.bool_type_to_arrow()),
+            DataType::Int16 => Ok(self.int16_type_to_arrow()),
+            DataType::Int32 => Ok(self.int32_type_to_arrow()),
+            DataType::Int64 => Ok(self.int64_type_to_arrow()),
+            DataType::Int256 => Ok(self.int256_type_to_arrow()),
+            DataType::Float32 => Ok(self.float32_type_to_arrow()),
+            DataType::Float64 => Ok(self.float64_type_to_arrow()),
+            DataType::Date => Ok(self.date_type_to_arrow()),
+            DataType::Timestamp => Ok(self.timestamp_type_to_arrow()),
+            DataType::Timestamptz => Ok(self.timestamptz_type_to_arrow()),
+            DataType::Time => Ok(self.time_type_to_arrow()),
+            DataType::Interval => Ok(self.interval_type_to_arrow()),
+            DataType::Varchar => Ok(self.varchar_type_to_arrow()),
+            DataType::Jsonb => Ok(self.jsonb_type_to_arrow()),
+            DataType::Bytea => Ok(self.bytea_type_to_arrow()),
+            DataType::Decimal => Ok(self.decimal_type_to_arrow()),
+            DataType::Serial => Ok(self.serial_type_to_arrow()),
+            DataType::Struct(fields) => self.struct_type_to_arrow(fields),
+            DataType::List(datatype) => self.list_type_to_arrow(datatype),
+        }
+    }
+
+    /// Default: `Boolean`.
+    #[inline]
+    fn bool_type_to_arrow(&self) -> arrow_schema::DataType {
+        arrow_schema::DataType::Boolean
+    }
+
+    /// Default: `Int16`.
+    #[inline]
+    fn int16_type_to_arrow(&self) -> arrow_schema::DataType {
+        arrow_schema::DataType::Int16
+    }
+
+    /// Default: `Int32`.
+    #[inline]
+    fn int32_type_to_arrow(&self) -> arrow_schema::DataType {
+        arrow_schema::DataType::Int32
+    }
+
+    /// Default: `Int64`.
+    #[inline]
+    fn int64_type_to_arrow(&self) -> arrow_schema::DataType {
+        arrow_schema::DataType::Int64
+    }
+
+    /// Default: `Decimal256` with `DECIMAL256_MAX_PRECISION` and scale 0.
+    #[inline]
+    fn int256_type_to_arrow(&self) -> arrow_schema::DataType {
+        arrow_schema::DataType::Decimal256(arrow_schema::DECIMAL256_MAX_PRECISION, 0)
+    }
+
+    /// Default: `Float32`.
+    #[inline]
+    fn float32_type_to_arrow(&self) -> arrow_schema::DataType {
+        arrow_schema::DataType::Float32
+    }
+
+    /// Default: `Float64`.
+    #[inline]
+    fn float64_type_to_arrow(&self) -> arrow_schema::DataType {
+        arrow_schema::DataType::Float64
+    }
+
+    /// Default: `Date32`.
+    #[inline]
+    fn date_type_to_arrow(&self) -> arrow_schema::DataType {
+        arrow_schema::DataType::Date32
+    }
+
+    /// Default: microsecond `Timestamp` without a timezone.
+    #[inline]
+    fn timestamp_type_to_arrow(&self) -> arrow_schema::DataType {
+        arrow_schema::DataType::Timestamp(arrow_schema::TimeUnit::Microsecond, None)
+    }
+
+    /// Default: microsecond `Timestamp` with the timezone string `"+00:00"`.
+    #[inline]
+    fn timestamptz_type_to_arrow(&self) -> arrow_schema::DataType {
+        arrow_schema::DataType::Timestamp(
+            arrow_schema::TimeUnit::Microsecond,
+            Some("+00:00".into()),
+        )
+    }
+
+    /// Default: microsecond `Time64`.
+    #[inline]
+    fn time_type_to_arrow(&self) -> arrow_schema::DataType {
+        arrow_schema::DataType::Time64(arrow_schema::TimeUnit::Microsecond)
+    }
+
+    /// Default: `Interval` with the `MonthDayNano` unit.
+    #[inline]
+    fn interval_type_to_arrow(&self) -> arrow_schema::DataType {
+        arrow_schema::DataType::Interval(arrow_schema::IntervalUnit::MonthDayNano)
+    }
+
+    /// Default: `Utf8`.
+    #[inline]
+    fn varchar_type_to_arrow(&self) -> arrow_schema::DataType {
+        arrow_schema::DataType::Utf8
+    }
+
+    /// Default: `LargeUtf8`.
+    #[inline]
+    fn jsonb_type_to_arrow(&self) -> arrow_schema::DataType {
+        arrow_schema::DataType::LargeUtf8
+    }
+
+    /// Default: `Binary`.
+    #[inline]
+    fn bytea_type_to_arrow(&self) -> arrow_schema::DataType {
+        arrow_schema::DataType::Binary
+    }
+
+    /// Default: `LargeBinary`.
+    #[inline]
+    fn decimal_type_to_arrow(&self) -> arrow_schema::DataType {
+        arrow_schema::DataType::LargeBinary
+    }
+
+    /// Default: serial has no Arrow type mapping.
+    ///
+    /// # Panics
+    ///
+    /// Always panics via `todo!`. The signature has no error channel, so an implementor that
+    /// can receive serial columns must override this hook.
+    #[inline]
+    fn serial_type_to_arrow(&self) -> arrow_schema::DataType {
+        todo!("serial type is not supported to convert to arrow")
+    }
+
+    /// Default: `List` of a nullable field named `"item"`.
+    ///
+    /// NOTE(review): the element type is converted with `try_into()`, i.e. through the
+    /// `TryFrom<&DataType>` impl and its default converter, not through `self`. Hooks
+    /// overridden by an implementor are therefore not applied to the element type.
+    #[inline]
+    fn list_type_to_arrow(
+        &self,
+        datatype: &DataType,
+    ) -> Result<arrow_schema::DataType, ArrayError> {
+        Ok(arrow_schema::DataType::List(Arc::new(
+            arrow_schema::Field::new("item", datatype.try_into()?, true),
+        )))
+    }
+
+    /// Default: `Struct` with one nullable field per `(name, type)` pair of `fields`.
+    ///
+    /// NOTE(review): field types are converted with `try_into()`, i.e. through the
+    /// `TryFrom<&DataType>` impl and its default converter, not through `self`. Hooks
+    /// overridden by an implementor are therefore not applied to the field types.
+    #[inline]
+    fn struct_type_to_arrow(
+        &self,
+        fields: &StructType,
+    ) -> Result<arrow_schema::DataType, ArrayError> {
+        Ok(arrow_schema::DataType::Struct(
+            fields
+                .iter()
+                .map(|(name, ty)| Ok(arrow_schema::Field::new(name, ty.try_into()?, true)))
+                .try_collect::<_, _, ArrayError>()?,
+        ))
+    }
+}
+
+/// Array converter that uses every default hook of `ToArrowArrayConvert` unchanged.
+struct DefaultArrowConvert;
+impl ToArrowArrayConvert for DefaultArrowConvert {}
+
 /// Implement bi-directional `From` between `ArrayImpl` and `arrow_array::ArrayRef`.
 macro_rules! converts_generic {
     ($({ $ArrowType:ty, $ArrowPattern:pat, $ArrayImplPattern:path }),*) => {
@@ -131,11 +622,7 @@ macro_rules! converts_generic {
         impl TryFrom<&ArrayImpl> for arrow_array::ArrayRef {
             type Error = ArrayError;
             fn try_from(array: &ArrayImpl) -> Result {
-                match array {
-                    $($ArrayImplPattern(a) => Ok(Arc::new(<$ArrowType>::try_from(a)?)),)*
-                    ArrayImpl::Timestamptz(a) => Ok(Arc::new(arrow_array::TimestampMicrosecondArray::try_from(a)?. with_timezone_utc())),
-                    _ => todo!("unsupported array"),
-                }
+                DefaultArrowConvert{}.to_arrow(array)
             }
         }
         // Arrow array -> RisingWave array
@@ -256,45 +743,15 @@ impl From for DataType {
     }
 }
 
+/// Type converter that uses every default hook of `ToArrowTypeConvert` unchanged.
+struct DefaultArrowTypeConvert;
+
+impl ToArrowTypeConvert for DefaultArrowTypeConvert {}
+
 impl TryFrom<&DataType> for arrow_schema::DataType {
     type Error = ArrayError;
 
     fn try_from(value: &DataType) -> Result {
-        match value {
-            DataType::Boolean => Ok(Self::Boolean),
-            DataType::Int16 => Ok(Self::Int16),
-            DataType::Int32 => Ok(Self::Int32),
-            DataType::Int64 => Ok(Self::Int64),
-            DataType::Int256 => Ok(Self::Decimal256(arrow_schema::DECIMAL256_MAX_PRECISION, 0)),
-            DataType::Float32 => Ok(Self::Float32),
-            DataType::Float64 => Ok(Self::Float64),
-            DataType::Date => Ok(Self::Date32),
-            DataType::Timestamp => Ok(Self::Timestamp(arrow_schema::TimeUnit::Microsecond, None)),
-            DataType::Timestamptz => Ok(Self::Timestamp(
-                arrow_schema::TimeUnit::Microsecond,
-                Some("+00:00".into()),
-            )),
-            DataType::Time => Ok(Self::Time64(arrow_schema::TimeUnit::Microsecond)),
-            DataType::Interval => Ok(Self::Interval(arrow_schema::IntervalUnit::MonthDayNano)),
-            DataType::Varchar => Ok(Self::Utf8),
-            DataType::Jsonb => Ok(Self::LargeUtf8),
-            DataType::Bytea => Ok(Self::Binary),
-            DataType::Decimal => Ok(Self::LargeBinary),
-            DataType::Struct(struct_type) => Ok(Self::Struct(
-                struct_type
-                    .iter()
-                    .map(|(name, ty)| Ok(arrow_schema::Field::new(name, ty.try_into()?, true)))
-                    .try_collect::<_, _, ArrayError>()?,
-            )),
-            DataType::List(datatype) => Ok(Self::List(Arc::new(arrow_schema::Field::new(
-                "item",
-                datatype.as_ref().try_into()?,
-                true,
-            )))),
-            DataType::Serial => Err(ArrayError::to_arrow(
-                "Serial type is not supported to convert to arrow",
-            )),
-        }
+        // The default `serial_type_to_arrow` hook panics. Report serial as an error here, as
+        // this impl did before it delegated to `DefaultArrowTypeConvert`. Nested list/struct
+        // element types recurse through this impl, so they hit the same guard.
+        if matches!(value, DataType::Serial) {
+            return Err(ArrayError::to_arrow(
+                "Serial type is not supported to convert to arrow",
+            ));
+        }
+        DefaultArrowTypeConvert {}.to_arrow_type(value)
     }
 }
 
@@ -525,17 +982,21 @@ impl TryFrom<&arrow_array::Decimal128Array> for DecimalArray {
         }
         let from_arrow = |value| {
             const NAN: i128 = i128::MIN + 1;
-            match value {
+            let res = match value {
                 NAN => Decimal::NaN,
                 i128::MAX => Decimal::PositiveInf,
                 i128::MIN => Decimal::NegativeInf,
-                _ => Decimal::Normalized(rust_decimal::Decimal::from_i128_with_scale(
-                    value,
-                    array.scale() as u32,
-                )),
-            }
+                _ => Decimal::Normalized(
+                    rust_decimal::Decimal::try_from_i128_with_scale(value, array.scale() as u32)
+                        .map_err(ArrayError::internal)?,
+                ),
+            };
+            Ok(res)
         };
-        Ok(array.iter().map(|o| o.map(from_arrow)).collect())
+        array
+            .iter()
+            .map(|o| o.map(from_arrow).transpose())
+            .collect::>()
     }
 }