From b1f3cf66686cb3f14cd57554e6c3288057856384 Mon Sep 17 00:00:00 2001 From: Runji Wang Date: Tue, 7 May 2024 15:56:07 +0800 Subject: [PATCH] introduce NewUdfArrowConvert Signed-off-by: Runji Wang --- src/common/src/array/arrow/arrow_impl.rs | 91 +++++++++++++++---- .../arrow/{arrow_default.rs => arrow_udf.rs} | 69 +++++++++++++- src/common/src/array/arrow/mod.rs | 4 +- src/expr/impl/src/scalar/external/iceberg.rs | 7 +- 4 files changed, 146 insertions(+), 25 deletions(-) rename src/common/src/array/arrow/{arrow_default.rs => arrow_udf.rs} (61%) diff --git a/src/common/src/array/arrow/arrow_impl.rs b/src/common/src/array/arrow/arrow_impl.rs index 5c1f8ac45fba3..514d3b299769c 100644 --- a/src/common/src/array/arrow/arrow_impl.rs +++ b/src/common/src/array/arrow/arrow_impl.rs @@ -201,20 +201,20 @@ pub trait ToArrow { Ok(Arc::new(arrow_array::BinaryArray::from(array))) } - // Decimal values are stored as ASCII text representation in a large binary array. + // Decimal values are stored as ASCII text representation in a string array. #[inline] fn decimal_to_arrow( &self, _data_type: &arrow_schema::DataType, array: &DecimalArray, ) -> Result { - Ok(Arc::new(arrow_array::LargeBinaryArray::from(array))) + Ok(Arc::new(arrow_array::StringArray::from(array))) } - // JSON values are stored as text representation in a large string array. + // JSON values are stored as text representation in a string array. 
#[inline] fn jsonb_to_arrow(&self, array: &JsonbArray) -> Result { - Ok(Arc::new(arrow_array::LargeStringArray::from(array))) + Ok(Arc::new(arrow_array::StringArray::from(array))) } #[inline] @@ -366,7 +366,8 @@ pub trait ToArrow { #[inline] fn jsonb_type_to_arrow(&self, name: &str) -> arrow_schema::Field { - arrow_schema::Field::new(name, arrow_schema::DataType::LargeUtf8, true) + arrow_schema::Field::new(name, arrow_schema::DataType::Utf8, true) + .with_metadata([("ARROW:extension:name".into(), "arrowudf.json".into())].into()) } #[inline] @@ -376,7 +377,8 @@ pub trait ToArrow { #[inline] fn decimal_type_to_arrow(&self, name: &str) -> arrow_schema::Field { - arrow_schema::Field::new(name, arrow_schema::DataType::LargeBinary, true) + arrow_schema::Field::new(name, arrow_schema::DataType::Utf8, true) + .with_metadata([("ARROW:extension:name".into(), "arrowudf.decimal".into())].into()) } #[inline] @@ -414,8 +416,8 @@ pub trait FromArrow { /// Converts Arrow `RecordBatch` to RisingWave `DataChunk`. fn from_record_batch(&self, batch: &arrow_array::RecordBatch) -> Result { let mut columns = Vec::with_capacity(batch.num_columns()); - for array in batch.columns() { - let column = Arc::new(self.from_array(array)?); + for (array, field) in batch.columns().iter().zip_eq_fast(batch.schema().fields()) { + let column = Arc::new(self.from_array(field, array)?); columns.push(column); } Ok(DataChunk::new(columns, batch.num_rows())) @@ -472,30 +474,44 @@ pub trait FromArrow { /// Converts Arrow `LargeUtf8` type to RisingWave data type. fn from_large_utf8(&self) -> Result { - Ok(DataType::Jsonb) + Ok(DataType::Varchar) } /// Converts Arrow `LargeBinary` type to RisingWave data type. fn from_large_binary(&self) -> Result { - Ok(DataType::Decimal) + Ok(DataType::Bytea) } /// Converts Arrow extension type to RisingWave `DataType`. 
fn from_extension_type( &self, type_name: &str, - _physical_type: &arrow_schema::DataType, + physical_type: &arrow_schema::DataType, ) -> Result { - Err(ArrayError::from_arrow(format!( - "unsupported extension type: {type_name:?}" - ))) + match (type_name, physical_type) { + ("arrowudf.decimal", arrow_schema::DataType::Utf8) => Ok(DataType::Decimal), + ("arrowudf.json", arrow_schema::DataType::Utf8) => Ok(DataType::Jsonb), + _ => Err(ArrayError::from_arrow(format!( + "unsupported extension type: {type_name:?}" + ))), + } } /// Converts Arrow `Array` to RisingWave `ArrayImpl`. - fn from_array(&self, array: &arrow_array::ArrayRef) -> Result { + fn from_array( + &self, + field: &arrow_schema::Field, + array: &arrow_array::ArrayRef, + ) -> Result { use arrow_schema::DataType::*; use arrow_schema::IntervalUnit::*; use arrow_schema::TimeUnit::*; + + // extension type + if let Some(type_name) = field.metadata().get("ARROW:extension:name") { + return self.from_extension_array(type_name, array); + } + match array.data_type() { Boolean => self.from_bool_array(array.as_any().downcast_ref().unwrap()), Int16 => self.from_int16_array(array.as_any().downcast_ref().unwrap()), @@ -524,6 +540,37 @@ pub trait FromArrow { } } + /// Converts Arrow extension array to RisingWave `ArrayImpl`. 
+ fn from_extension_array( + &self, + type_name: &str, + array: &arrow_array::ArrayRef, + ) -> Result { + match type_name { + "arrowudf.decimal" => { + let array: &arrow_array::StringArray = + array.as_any().downcast_ref().ok_or_else(|| { + ArrayError::from_arrow( + "expected string array for `arrowudf.decimal`".to_string(), + ) + })?; + Ok(ArrayImpl::Decimal(array.try_into()?)) + } + "arrowudf.json" => { + let array: &arrow_array::StringArray = + array.as_any().downcast_ref().ok_or_else(|| { + ArrayError::from_arrow( + "expected string array for `arrowudf.json`".to_string(), + ) + })?; + Ok(ArrayImpl::Jsonb(array.try_into()?)) + } + _ => Err(ArrayError::from_arrow(format!( + "unsupported extension type: {type_name:?}" + ))), + } + } + fn from_bool_array(&self, array: &arrow_array::BooleanArray) -> Result { Ok(ArrayImpl::Bool(array.into())) } @@ -598,20 +645,23 @@ pub trait FromArrow { &self, array: &arrow_array::LargeStringArray, ) -> Result { - Ok(ArrayImpl::Jsonb(array.try_into()?)) + Ok(ArrayImpl::Utf8(array.into())) } fn from_large_binary_array( &self, array: &arrow_array::LargeBinaryArray, ) -> Result { - Ok(ArrayImpl::Decimal(array.try_into()?)) + Ok(ArrayImpl::Bytea(array.into())) } fn from_list_array(&self, array: &arrow_array::ListArray) -> Result { use arrow_array::Array; + let arrow_schema::DataType::List(field) = array.data_type() else { + panic!("nested field types cannot be determined."); + }; Ok(ArrayImpl::List(ListArray { - value: Box::new(self.from_array(array.values())?), + value: Box::new(self.from_array(field, array.values())?), bitmap: match array.nulls() { Some(nulls) => nulls.iter().collect(), None => Bitmap::ones(array.len()), @@ -630,7 +680,8 @@ pub trait FromArrow { array .columns() .iter() - .map(|a| self.from_array(a).map(Arc::new)) + .zip_eq_fast(fields) + .map(|(array, field)| self.from_array(field, array).map(Arc::new)) .try_collect()?, (0..array.len()).map(|i| array.is_valid(i)).collect(), ))) @@ -703,7 +754,9 @@ converts!(I64Array, 
arrow_array::Int64Array); converts!(F32Array, arrow_array::Float32Array, @map); converts!(F64Array, arrow_array::Float64Array, @map); converts!(BytesArray, arrow_array::BinaryArray); +converts!(BytesArray, arrow_array::LargeBinaryArray); converts!(Utf8Array, arrow_array::StringArray); +converts!(Utf8Array, arrow_array::LargeStringArray); converts!(DateArray, arrow_array::Date32Array, @map); converts!(TimeArray, arrow_array::Time64MicrosecondArray, @map); converts!(TimestampArray, arrow_array::TimestampMicrosecondArray, @map); diff --git a/src/common/src/array/arrow/arrow_default.rs b/src/common/src/array/arrow/arrow_udf.rs similarity index 61% rename from src/common/src/array/arrow/arrow_default.rs rename to src/common/src/array/arrow/arrow_udf.rs index b2867d4fdf583..e2f9e39ad385a 100644 --- a/src/common/src/array/arrow/arrow_default.rs +++ b/src/common/src/array/arrow/arrow_udf.rs @@ -18,17 +18,78 @@ //! //! The corresponding version of arrow is currently used by `udf` and `iceberg` sink. +use std::sync::Arc; + pub use arrow_impl::{FromArrow, ToArrow}; use {arrow_array, arrow_buffer, arrow_cast, arrow_schema}; +use crate::array::{ArrayError, ArrayImpl, DataType, DecimalArray, JsonbArray}; + #[expect(clippy::duplicate_mod)] #[path = "./arrow_impl.rs"] mod arrow_impl; +/// Arrow conversion for the current version of UDF. This is in use but will be deprecated soon. +/// +/// In the current version of UDF protocol, decimal and jsonb types are mapped to Arrow `LargeBinary` and `LargeUtf8` types. pub struct UdfArrowConvert; -impl ToArrow for UdfArrowConvert {} -impl FromArrow for UdfArrowConvert {} +impl ToArrow for UdfArrowConvert { + // Decimal values are stored as ASCII text representation in a large binary array. + fn decimal_to_arrow( + &self, + _data_type: &arrow_schema::DataType, + array: &DecimalArray, + ) -> Result { + Ok(Arc::new(arrow_array::LargeBinaryArray::from(array))) + } + + // JSON values are stored as text representation in a large string array. 
+ fn jsonb_to_arrow(&self, array: &JsonbArray) -> Result { + Ok(Arc::new(arrow_array::LargeStringArray::from(array))) + } + + fn jsonb_type_to_arrow(&self, name: &str) -> arrow_schema::Field { + arrow_schema::Field::new(name, arrow_schema::DataType::LargeUtf8, true) + } + + fn decimal_type_to_arrow(&self, name: &str) -> arrow_schema::Field { + arrow_schema::Field::new(name, arrow_schema::DataType::LargeBinary, true) + } +} + +impl FromArrow for UdfArrowConvert { + fn from_large_utf8(&self) -> Result { + Ok(DataType::Jsonb) + } + + fn from_large_binary(&self) -> Result { + Ok(DataType::Decimal) + } + + fn from_large_utf8_array( + &self, + array: &arrow_array::LargeStringArray, + ) -> Result { + Ok(ArrayImpl::Jsonb(array.try_into()?)) + } + + fn from_large_binary_array( + &self, + array: &arrow_array::LargeBinaryArray, + ) -> Result { + Ok(ArrayImpl::Decimal(array.try_into()?)) + } +} + +/// Arrow conversion for the next version of UDF. This is unused for now. +/// +/// In the next version of UDF protocol, decimal and jsonb types will be mapped to Arrow extension types. +/// See the arrow-udf extension types (`arrowudf.decimal` and `arrowudf.json`). 
+pub struct NewUdfArrowConvert; + +impl ToArrow for NewUdfArrowConvert {} +impl FromArrow for NewUdfArrowConvert {} #[cfg(test)] mod tests { @@ -108,7 +169,9 @@ mod tests { let array = ListArray::from_iter([None, Some(vec![0, -127, 127, 50]), Some(vec![0; 0])]); let data_type = arrow_schema::DataType::new_list(arrow_schema::DataType::Int32, true); let arrow = UdfArrowConvert.list_to_arrow(&data_type, &array).unwrap(); - let rw_array = UdfArrowConvert.from_array(&arrow).unwrap(); + let rw_array = UdfArrowConvert + .from_list_array(arrow.as_any().downcast_ref().unwrap()) + .unwrap(); assert_eq!(rw_array.as_list(), &array); } } diff --git a/src/common/src/array/arrow/mod.rs b/src/common/src/array/arrow/mod.rs index cb726721c867b..67490b22315a1 100644 --- a/src/common/src/array/arrow/mod.rs +++ b/src/common/src/array/arrow/mod.rs @@ -12,10 +12,10 @@ // See the License for the specific language governing permissions and // limitations under the License. -mod arrow_default; mod arrow_deltalake; mod arrow_iceberg; +mod arrow_udf; -pub use arrow_default::{FromArrow, ToArrow, UdfArrowConvert}; pub use arrow_deltalake::DeltaLakeConvert; pub use arrow_iceberg::IcebergArrowConvert; +pub use arrow_udf::{FromArrow, ToArrow, UdfArrowConvert}; diff --git a/src/expr/impl/src/scalar/external/iceberg.rs b/src/expr/impl/src/scalar/external/iceberg.rs index 2194d8b1355be..ea39ea7ef989d 100644 --- a/src/expr/impl/src/scalar/external/iceberg.rs +++ b/src/expr/impl/src/scalar/external/iceberg.rs @@ -35,6 +35,7 @@ pub struct IcebergTransform { child: BoxedExpression, transform: BoxedTransformFunction, input_arrow_type: arrow_schema::DataType, + output_arrow_field: arrow_schema::Field, return_type: DataType, } @@ -61,7 +62,9 @@ impl risingwave_expr::expr::Expression for IcebergTransform { // Transform let res_array = self.transform.transform(arrow_array).unwrap(); // Convert back to array ref and return it - Ok(Arc::new(IcebergArrowConvert.from_array(&res_array)?)) + Ok(Arc::new( + 
IcebergArrowConvert.from_array(&self.output_arrow_field, &res_array)?, + )) } async fn eval_row(&self, _row: &OwnedRow) -> Result { @@ -96,6 +99,7 @@ fn build(return_type: DataType, mut children: Vec) -> Result) -> Result