diff --git a/src/batch/src/executor/iceberg_scan.rs b/src/batch/src/executor/iceberg_scan.rs
index e20b1aaedb8fe..31abcc73a2aac 100644
--- a/src/batch/src/executor/iceberg_scan.rs
+++ b/src/batch/src/executor/iceberg_scan.rs
@@ -13,13 +13,13 @@
 // limitations under the License.
 
 use std::hash::{DefaultHasher, Hash, Hasher};
-use std::sync::Arc;
 
 use anyhow::anyhow;
 use arrow_array::RecordBatch;
 use futures_async_stream::try_stream;
 use futures_util::stream::StreamExt;
 use icelake::io::{FileScan, TableScan};
+use risingwave_common::array::arrow::{FromArrow, IcebergArrowConvert};
 use risingwave_common::catalog::Schema;
 use risingwave_connector::sink::iceberg::IcebergConfig;
 
@@ -150,12 +150,7 @@ impl IcebergScanExecutor {
     }
 
     fn record_batch_to_chunk(record_batch: RecordBatch) -> Result<DataChunk> {
-        let mut columns = Vec::with_capacity(record_batch.num_columns());
-        for array in record_batch.columns() {
-            let column = Arc::new(array.try_into()?);
-            columns.push(column);
-        }
-        Ok(DataChunk::new(columns, record_batch.num_rows()))
+        Ok(IcebergArrowConvert.from_record_batch(&record_batch)?)
     }
 }
 
diff --git a/src/common/src/array/arrow/arrow_default.rs b/src/common/src/array/arrow/arrow_default.rs
deleted file mode 100644
index 5d04527b354ba..0000000000000
--- a/src/common/src/array/arrow/arrow_default.rs
+++ /dev/null
@@ -1,30 +0,0 @@
-// Copyright 2024 RisingWave Labs
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-//! This is for arrow dependency named `arrow-xxx` such as `arrow-array` in the cargo workspace.
-//!
-//! This should the default arrow version to be used in our system.
-//!
-//! The corresponding version of arrow is currently used by `udf` and `iceberg` sink.
- -#![allow(unused_imports)] -pub use arrow_impl::{ - to_record_batch_with_schema, ToArrowArrayConvert, ToArrowArrayWithTypeConvert, - ToArrowTypeConvert, -}; -use {arrow_array, arrow_buffer, arrow_cast, arrow_schema}; - -#[expect(clippy::duplicate_mod)] -#[path = "./arrow_impl.rs"] -mod arrow_impl; diff --git a/src/common/src/array/arrow/arrow_deltalake.rs b/src/common/src/array/arrow/arrow_deltalake.rs index c55cae305b07f..c9f4052e2036d 100644 --- a/src/common/src/array/arrow/arrow_deltalake.rs +++ b/src/common/src/array/arrow/arrow_deltalake.rs @@ -21,25 +21,29 @@ use std::ops::{Div, Mul}; use std::sync::Arc; use arrow_array::ArrayRef; -use arrow_schema::DataType; -use itertools::Itertools; use num_traits::abs; use { arrow_array_deltalake as arrow_array, arrow_buffer_deltalake as arrow_buffer, arrow_cast_deltalake as arrow_cast, arrow_schema_deltalake as arrow_schema, }; -use self::arrow_impl::ToArrowArrayWithTypeConvert; -use crate::array::arrow::arrow_deltalake::arrow_impl::FromIntoArrow; -use crate::array::{Array, ArrayError, ArrayImpl, DataChunk, Decimal, DecimalArray, ListArray}; -use crate::util::iter_util::ZipEqFast; +use self::arrow_impl::ToArrow; +use crate::array::{Array, ArrayError, DataChunk, Decimal, DecimalArray}; #[expect(clippy::duplicate_mod)] #[path = "./arrow_impl.rs"] mod arrow_impl; -struct DeltaLakeConvert; +pub struct DeltaLakeConvert; impl DeltaLakeConvert { + pub fn to_record_batch( + &self, + schema: arrow_schema::SchemaRef, + chunk: &DataChunk, + ) -> Result { + ToArrow::to_record_batch(self, schema, chunk) + } + fn decimal_to_i128(decimal: Decimal, precision: u8, max_scale: i8) -> Option { match decimal { crate::array::Decimal::Normalized(e) => { @@ -67,7 +71,7 @@ impl DeltaLakeConvert { } } -impl arrow_impl::ToArrowArrayWithTypeConvert for DeltaLakeConvert { +impl ToArrow for DeltaLakeConvert { fn decimal_to_arrow( &self, data_type: &arrow_schema::DataType, @@ -89,189 +93,6 @@ impl arrow_impl::ToArrowArrayWithTypeConvert for DeltaLakeConvert { .map_err(ArrayError::from_arrow)?; Ok(Arc::new(array) as ArrayRef) } - - #[inline] - fn list_to_arrow( - &self, - data_type: &arrow_schema::DataType, - array: &ListArray, - ) -> Result { - use arrow_array::builder::*; - fn build( - array: &ListArray, - a: &A, - builder: B, - mut append: F, - ) -> arrow_array::ListArray - where - A: Array, - B: arrow_array::builder::ArrayBuilder, - F: FnMut(&mut B, Option>), - { - let mut builder = ListBuilder::with_capacity(builder, a.len()); - for i in 0..array.len() { - for j in array.offsets[i]..array.offsets[i + 1] { - append(builder.values(), a.value_at(j as usize)); - } - builder.append(!array.is_null(i)); - } - builder.finish() - } - let inner_type = match data_type { - arrow_schema::DataType::List(inner) => inner.data_type(), - _ => return Err(ArrayError::to_arrow("Invalid list type")), - }; - let arr: arrow_array::ListArray = match &*array.value { - ArrayImpl::Int16(a) => build(array, a, Int16Builder::with_capacity(a.len()), |b, v| { - b.append_option(v) - }), - ArrayImpl::Int32(a) => build(array, a, Int32Builder::with_capacity(a.len()), |b, v| { - b.append_option(v) - }), - ArrayImpl::Int64(a) => build(array, a, Int64Builder::with_capacity(a.len()), |b, v| { - b.append_option(v) - }), - - ArrayImpl::Float32(a) => { - build(array, a, Float32Builder::with_capacity(a.len()), |b, v| { - b.append_option(v.map(|f| f.0)) - }) - } - ArrayImpl::Float64(a) => { - build(array, a, Float64Builder::with_capacity(a.len()), |b, v| { - b.append_option(v.map(|f| f.0)) - }) - } - 
ArrayImpl::Utf8(a) => build( - array, - a, - StringBuilder::with_capacity(a.len(), a.data().len()), - |b, v| b.append_option(v), - ), - ArrayImpl::Int256(a) => build( - array, - a, - Decimal256Builder::with_capacity(a.len()).with_data_type( - arrow_schema::DataType::Decimal256(arrow_schema::DECIMAL256_MAX_PRECISION, 0), - ), - |b, v| b.append_option(v.map(Into::into)), - ), - ArrayImpl::Bool(a) => { - build(array, a, BooleanBuilder::with_capacity(a.len()), |b, v| { - b.append_option(v) - }) - } - ArrayImpl::Decimal(a) => { - let (precision, max_scale) = match inner_type { - arrow_schema::DataType::Decimal128(precision, scale) => (*precision, *scale), - _ => return Err(ArrayError::to_arrow("Invalid decimal type")), - }; - build( - array, - a, - Decimal128Builder::with_capacity(a.len()) - .with_data_type(DataType::Decimal128(precision, max_scale)), - |b, v| { - let v = v.and_then(|v| { - DeltaLakeConvert::decimal_to_i128(v, precision, max_scale) - }); - b.append_option(v); - }, - ) - } - ArrayImpl::Interval(a) => build( - array, - a, - IntervalMonthDayNanoBuilder::with_capacity(a.len()), - |b, v| b.append_option(v.map(|d| d.into_arrow())), - ), - ArrayImpl::Date(a) => build(array, a, Date32Builder::with_capacity(a.len()), |b, v| { - b.append_option(v.map(|d| d.into_arrow())) - }), - ArrayImpl::Timestamp(a) => build( - array, - a, - TimestampMicrosecondBuilder::with_capacity(a.len()), - |b, v| b.append_option(v.map(|d| d.into_arrow())), - ), - ArrayImpl::Timestamptz(a) => build( - array, - a, - TimestampMicrosecondBuilder::with_capacity(a.len()), - |b, v| b.append_option(v.map(|d| d.into_arrow())), - ), - ArrayImpl::Time(a) => build( - array, - a, - Time64MicrosecondBuilder::with_capacity(a.len()), - |b, v| b.append_option(v.map(|d| d.into_arrow())), - ), - ArrayImpl::Jsonb(a) => build( - array, - a, - LargeStringBuilder::with_capacity(a.len(), a.len() * 16), - |b, v| b.append_option(v.map(|j| j.to_string())), - ), - ArrayImpl::Serial(_) => todo!("list of serial"), - ArrayImpl::Struct(a) => { - let values = Arc::new(arrow_array::StructArray::try_from(a)?); - arrow_array::ListArray::new( - Arc::new(arrow_schema::Field::new( - "item", - a.data_type().try_into()?, - true, - )), - arrow_buffer::OffsetBuffer::new(arrow_buffer::ScalarBuffer::from( - array - .offsets() - .iter() - .map(|o| *o as i32) - .collect::>(), - )), - values, - Some(array.null_bitmap().into()), - ) - } - ArrayImpl::List(_) => todo!("list of list"), - ArrayImpl::Bytea(a) => build( - array, - a, - BinaryBuilder::with_capacity(a.len(), a.data().len()), - |b, v| b.append_option(v), - ), - }; - Ok(Arc::new(arr)) - } -} - -/// Converts RisingWave array to Arrow array with the schema. -/// This function will try to convert the array if the type is not same with the schema. 
-pub fn to_deltalake_record_batch_with_schema( - schema: arrow_schema::SchemaRef, - chunk: &DataChunk, -) -> Result { - if !chunk.is_compacted() { - let c = chunk.clone(); - return to_deltalake_record_batch_with_schema(schema, &c.compact()); - } - let columns: Vec<_> = chunk - .columns() - .iter() - .zip_eq_fast(schema.fields().iter()) - .map(|(column, field)| { - let column: arrow_array::ArrayRef = - DeltaLakeConvert.to_arrow_with_type(field.data_type(), column)?; - if column.data_type() == field.data_type() { - Ok(column) - } else { - arrow_cast::cast(&column, field.data_type()).map_err(ArrayError::from_arrow) - } - }) - .try_collect::<_, _, ArrayError>()?; - - let opts = arrow_array::RecordBatchOptions::default().with_row_count(Some(chunk.capacity())); - arrow_array::RecordBatch::try_new_with_options(schema, columns, &opts) - .map_err(ArrayError::to_arrow) } #[cfg(test)] @@ -283,6 +104,7 @@ mod test { use arrow_schema::Field; use {arrow_array_deltalake as arrow_array, arrow_schema_deltalake as arrow_schema}; + use crate::array::arrow::arrow_deltalake::DeltaLakeConvert; use crate::array::{ArrayImpl, Decimal, DecimalArray, ListArray, ListValue}; use crate::buffer::Bitmap; @@ -304,13 +126,14 @@ mod test { arrow_schema::DataType::List(Arc::new(Field::new( "test", arrow_schema::DataType::Decimal128(10, 0), - false, + true, ))), false, )]); - let record_batch = - super::to_deltalake_record_batch_with_schema(Arc::new(schema), &chunk).unwrap(); + let record_batch = DeltaLakeConvert + .to_record_batch(Arc::new(schema), &chunk) + .unwrap(); let expect_array = Arc::new( arrow_array::Decimal128Array::from(vec![ None, diff --git a/src/common/src/array/arrow/arrow_iceberg.rs b/src/common/src/array/arrow/arrow_iceberg.rs index 2dd7900da5da1..72a49cb349370 100644 --- a/src/common/src/array/arrow/arrow_iceberg.rs +++ b/src/common/src/array/arrow/arrow_iceberg.rs @@ -15,25 +15,22 @@ use std::ops::{Div, Mul}; use std::sync::Arc; -use arrow_array::{ArrayRef, StructArray}; -use arrow_schema::DataType; -use itertools::Itertools; +use arrow_array::ArrayRef; use num_traits::abs; -use super::{ToArrowArrayWithTypeConvert, ToArrowTypeConvert}; -use crate::array::{Array, ArrayError, DataChunk, DecimalArray}; -use crate::util::iter_util::ZipEqFast; +use super::{FromArrow, ToArrow}; +use crate::array::{Array, ArrayError, DecimalArray}; -struct IcebergArrowConvert; +pub struct IcebergArrowConvert; -impl ToArrowTypeConvert for IcebergArrowConvert { +impl ToArrow for IcebergArrowConvert { #[inline] - fn decimal_type_to_arrow(&self) -> arrow_schema::DataType { - arrow_schema::DataType::Decimal128(arrow_schema::DECIMAL128_MAX_PRECISION, 0) + fn decimal_type_to_arrow(&self, name: &str) -> arrow_schema::Field { + let data_type = + arrow_schema::DataType::Decimal128(arrow_schema::DECIMAL128_MAX_PRECISION, 0); + arrow_schema::Field::new(name, data_type, true) } -} -impl ToArrowArrayWithTypeConvert for IcebergArrowConvert { fn decimal_to_arrow( &self, data_type: &arrow_schema::DataType, @@ -85,63 +82,7 @@ impl ToArrowArrayWithTypeConvert for IcebergArrowConvert { } } -/// Converts RisingWave array to Arrow array with the schema. -/// The behavior is specified for iceberg: -/// For different struct type, try to use fields in schema to cast. 
-pub fn to_iceberg_record_batch_with_schema( - schema: arrow_schema::SchemaRef, - chunk: &DataChunk, -) -> Result { - if !chunk.is_compacted() { - let c = chunk.clone(); - return to_iceberg_record_batch_with_schema(schema, &c.compact()); - } - let columns: Vec<_> = chunk - .columns() - .iter() - .zip_eq_fast(schema.fields().iter()) - .map(|(column, field)| { - let column: arrow_array::ArrayRef = - IcebergArrowConvert {}.to_arrow_with_type(field.data_type(), column)?; - if column.data_type() == field.data_type() { - Ok(column) - } else if let DataType::Struct(actual) = column.data_type() - && let DataType::Struct(expect) = field.data_type() - { - // Special case for iceberg - if actual.len() != expect.len() { - return Err(ArrayError::to_arrow(format!( - "Struct field count mismatch, expect {}, actual {}", - expect.len(), - actual.len() - ))); - } - let column = column - .as_any() - .downcast_ref::() - .unwrap() - .clone(); - let (_, struct_columns, nullable) = column.into_parts(); - Ok(Arc::new( - StructArray::try_new(expect.clone(), struct_columns, nullable) - .map_err(ArrayError::from_arrow)?, - ) as ArrayRef) - } else { - arrow_cast::cast(&column, field.data_type()).map_err(ArrayError::from_arrow) - } - }) - .try_collect::<_, _, ArrayError>()?; - - let opts = arrow_array::RecordBatchOptions::default().with_row_count(Some(chunk.capacity())); - arrow_array::RecordBatch::try_new_with_options(schema, columns, &opts) - .map_err(ArrayError::to_arrow) -} - -pub fn iceberg_to_arrow_type( - data_type: &crate::array::DataType, -) -> Result { - IcebergArrowConvert {}.to_arrow_type(data_type) -} +impl FromArrow for IcebergArrowConvert {} #[cfg(test)] mod test { @@ -150,7 +91,7 @@ mod test { use arrow_array::ArrayRef; use crate::array::arrow::arrow_iceberg::IcebergArrowConvert; - use crate::array::arrow::ToArrowArrayWithTypeConvert; + use crate::array::arrow::ToArrow; use crate::array::{Decimal, DecimalArray}; #[test] diff --git a/src/common/src/array/arrow/arrow_impl.rs b/src/common/src/array/arrow/arrow_impl.rs index e426caa306d55..514d3b299769c 100644 --- a/src/common/src/array/arrow/arrow_impl.rs +++ b/src/common/src/array/arrow/arrow_impl.rs @@ -39,6 +39,7 @@ use std::fmt::Write; use std::sync::Arc; +use arrow_buffer::OffsetBuffer; use chrono::{NaiveDateTime, NaiveTime}; use itertools::Itertools; @@ -50,317 +51,76 @@ use crate::buffer::Bitmap; use crate::types::*; use crate::util::iter_util::ZipEqFast; -/// Converts RisingWave array to Arrow array with the schema. -/// This function will try to convert the array if the type is not same with the schema. -#[allow(dead_code)] -pub fn to_record_batch_with_schema( - schema: arrow_schema::SchemaRef, - chunk: &DataChunk, -) -> Result { - if !chunk.is_compacted() { - let c = chunk.clone(); - return to_record_batch_with_schema(schema, &c.compact()); - } - let columns: Vec<_> = chunk - .columns() - .iter() - .zip_eq_fast(schema.fields().iter()) - .map(|(column, field)| { - let column: arrow_array::ArrayRef = column.as_ref().try_into()?; - if column.data_type() == field.data_type() { - Ok(column) - } else { - arrow_cast::cast(&column, field.data_type()).map_err(ArrayError::from_arrow) - } - }) - .try_collect::<_, _, ArrayError>()?; - - let opts = arrow_array::RecordBatchOptions::default().with_row_count(Some(chunk.capacity())); - arrow_array::RecordBatch::try_new_with_options(schema, columns, &opts) - .map_err(ArrayError::to_arrow) -} - -// Implement bi-directional `From` between `DataChunk` and `arrow_array::RecordBatch`. 
-impl TryFrom<&DataChunk> for arrow_array::RecordBatch {
-    type Error = ArrayError;
-
-    fn try_from(chunk: &DataChunk) -> Result<Self, Self::Error> {
+/// Defines how to convert RisingWave arrays to Arrow arrays.
+pub trait ToArrow {
+    /// Converts RisingWave `DataChunk` to Arrow `RecordBatch` with the specified schema.
+    ///
+    /// This function will try to convert the array if its type is not the same as in the schema.
+    fn to_record_batch(
+        &self,
+        schema: arrow_schema::SchemaRef,
+        chunk: &DataChunk,
+    ) -> Result<arrow_array::RecordBatch, ArrayError> {
+        // compact the chunk if it's not compacted
         if !chunk.is_compacted() {
             let c = chunk.clone();
-            return Self::try_from(&c.compact());
+            return self.to_record_batch(schema, &c.compact());
         }
+
+        // convert each column to arrow array
         let columns: Vec<_> = chunk
             .columns()
             .iter()
-            .map(|column| column.as_ref().try_into())
-            .try_collect::<_, _, Self::Error>()?;
-
-        let fields: Vec<_> = columns
-            .iter()
-            .map(|array: &Arc<dyn arrow_array::Array>| {
-                let nullable = array.null_count() > 0;
-                let data_type = array.data_type().clone();
-                arrow_schema::Field::new("", data_type, nullable)
-            })
-            .collect();
+            .zip_eq_fast(schema.fields().iter())
+            .map(|(column, field)| self.to_array(field.data_type(), column))
+            .try_collect()?;
 
-        let schema = Arc::new(arrow_schema::Schema::new(fields));
+        // create record batch
         let opts =
             arrow_array::RecordBatchOptions::default().with_row_count(Some(chunk.capacity()));
         arrow_array::RecordBatch::try_new_with_options(schema, columns, &opts)
             .map_err(ArrayError::to_arrow)
     }
-}
-
-impl TryFrom<&arrow_array::RecordBatch> for DataChunk {
-    type Error = ArrayError;
-
-    fn try_from(batch: &arrow_array::RecordBatch) -> Result<Self, Self::Error> {
-        let mut columns = Vec::with_capacity(batch.num_columns());
-        for array in batch.columns() {
-            let column = Arc::new(array.try_into()?);
-            columns.push(column);
-        }
-        Ok(DataChunk::new(columns, batch.num_rows()))
-    }
-}
 
-/// Provides the default conversion logic for RisingWave array to Arrow array with type info.
-pub trait ToArrowArrayWithTypeConvert {
-    fn to_arrow_with_type(
+    /// Converts RisingWave array to Arrow array.
+ fn to_array( &self, data_type: &arrow_schema::DataType, array: &ArrayImpl, ) -> Result { - match array { - ArrayImpl::Int16(array) => self.int16_to_arrow(data_type, array), - ArrayImpl::Int32(array) => self.int32_to_arrow(data_type, array), - ArrayImpl::Int64(array) => self.int64_to_arrow(data_type, array), - ArrayImpl::Float32(array) => self.float32_to_arrow(data_type, array), - ArrayImpl::Float64(array) => self.float64_to_arrow(data_type, array), - ArrayImpl::Utf8(array) => self.utf8_to_arrow(data_type, array), - ArrayImpl::Bool(array) => self.bool_to_arrow(data_type, array), - ArrayImpl::Decimal(array) => self.decimal_to_arrow(data_type, array), - ArrayImpl::Int256(array) => self.int256_to_arrow(data_type, array), - ArrayImpl::Date(array) => self.date_to_arrow(data_type, array), - ArrayImpl::Timestamp(array) => self.timestamp_to_arrow(data_type, array), - ArrayImpl::Timestamptz(array) => self.timestamptz_to_arrow(data_type, array), - ArrayImpl::Time(array) => self.time_to_arrow(data_type, array), - ArrayImpl::Interval(array) => self.interval_to_arrow(data_type, array), - ArrayImpl::Struct(array) => self.struct_to_arrow(data_type, array), - ArrayImpl::List(array) => self.list_to_arrow(data_type, array), - ArrayImpl::Bytea(array) => self.bytea_to_arrow(data_type, array), - ArrayImpl::Jsonb(array) => self.jsonb_to_arrow(data_type, array), - ArrayImpl::Serial(array) => self.serial_to_arrow(data_type, array), - } - } - - #[inline] - fn int16_to_arrow( - &self, - _data_type: &arrow_schema::DataType, - array: &I16Array, - ) -> Result { - Ok(Arc::new(arrow_array::Int16Array::from(array))) - } - - #[inline] - fn int32_to_arrow( - &self, - _data_type: &arrow_schema::DataType, - array: &I32Array, - ) -> Result { - Ok(Arc::new(arrow_array::Int32Array::from(array))) - } - - #[inline] - fn int64_to_arrow( - &self, - _data_type: &arrow_schema::DataType, - array: &I64Array, - ) -> Result { - Ok(Arc::new(arrow_array::Int64Array::from(array))) - } - - #[inline] - fn float32_to_arrow( - &self, - _data_type: &arrow_schema::DataType, - array: &F32Array, - ) -> Result { - Ok(Arc::new(arrow_array::Float32Array::from(array))) - } - - #[inline] - fn float64_to_arrow( - &self, - _data_type: &arrow_schema::DataType, - array: &F64Array, - ) -> Result { - Ok(Arc::new(arrow_array::Float64Array::from(array))) - } - - #[inline] - fn utf8_to_arrow( - &self, - _data_type: &arrow_schema::DataType, - array: &Utf8Array, - ) -> Result { - Ok(Arc::new(arrow_array::StringArray::from(array))) - } - - #[inline] - fn bool_to_arrow( - &self, - _data_type: &arrow_schema::DataType, - array: &BoolArray, - ) -> Result { - Ok(Arc::new(arrow_array::BooleanArray::from(array))) - } - - // Decimal values are stored as ASCII text representation in a large binary array. 
- #[inline] - fn decimal_to_arrow( - &self, - _data_type: &arrow_schema::DataType, - array: &DecimalArray, - ) -> Result { - Ok(Arc::new(arrow_array::LargeBinaryArray::from(array))) - } - - #[inline] - fn int256_to_arrow( - &self, - _data_type: &arrow_schema::DataType, - array: &Int256Array, - ) -> Result { - Ok(Arc::new(arrow_array::Decimal256Array::from(array))) - } - - #[inline] - fn date_to_arrow( - &self, - _data_type: &arrow_schema::DataType, - array: &DateArray, - ) -> Result { - Ok(Arc::new(arrow_array::Date32Array::from(array))) - } - - #[inline] - fn timestamp_to_arrow( - &self, - _data_type: &arrow_schema::DataType, - array: &TimestampArray, - ) -> Result { - Ok(Arc::new(arrow_array::TimestampMicrosecondArray::from( - array, - ))) - } - - #[inline] - fn timestamptz_to_arrow( - &self, - _data_type: &arrow_schema::DataType, - array: &TimestamptzArray, - ) -> Result { - Ok(Arc::new( - arrow_array::TimestampMicrosecondArray::from(array).with_timezone_utc(), - )) - } - - #[inline] - fn time_to_arrow( - &self, - _data_type: &arrow_schema::DataType, - array: &TimeArray, - ) -> Result { - Ok(Arc::new(arrow_array::Time64MicrosecondArray::from(array))) - } - - #[inline] - fn interval_to_arrow( - &self, - _data_type: &arrow_schema::DataType, - array: &IntervalArray, - ) -> Result { - Ok(Arc::new(arrow_array::IntervalMonthDayNanoArray::from( - array, - ))) - } - - #[inline] - fn struct_to_arrow( - &self, - _data_type: &arrow_schema::DataType, - array: &StructArray, - ) -> Result { - Ok(Arc::new(arrow_array::StructArray::try_from(array)?)) - } - - #[inline] - fn list_to_arrow( - &self, - _data_type: &arrow_schema::DataType, - array: &ListArray, - ) -> Result { - Ok(Arc::new(arrow_array::ListArray::try_from(array)?)) - } - - #[inline] - fn bytea_to_arrow( - &self, - _data_type: &arrow_schema::DataType, - array: &BytesArray, - ) -> Result { - Ok(Arc::new(arrow_array::BinaryArray::from(array))) - } - - // JSON values are stored as text representation in a large string array. - #[inline] - fn jsonb_to_arrow( - &self, - _data_type: &arrow_schema::DataType, - array: &JsonbArray, - ) -> Result { - Ok(Arc::new(arrow_array::LargeStringArray::from(array))) - } - - #[inline] - fn serial_to_arrow( - &self, - _data_type: &arrow_schema::DataType, - _array: &SerialArray, - ) -> Result { - todo!("serial type is not supported to convert to arrow") - } -} - -/// Provides the default conversion logic for RisingWave array to Arrow array with type info. 
-pub trait ToArrowArrayConvert { - fn to_arrow(&self, array: &ArrayImpl) -> Result { - match array { + let arrow_array = match array { + ArrayImpl::Bool(array) => self.bool_to_arrow(array), ArrayImpl::Int16(array) => self.int16_to_arrow(array), ArrayImpl::Int32(array) => self.int32_to_arrow(array), ArrayImpl::Int64(array) => self.int64_to_arrow(array), + ArrayImpl::Int256(array) => self.int256_to_arrow(array), ArrayImpl::Float32(array) => self.float32_to_arrow(array), ArrayImpl::Float64(array) => self.float64_to_arrow(array), - ArrayImpl::Utf8(array) => self.utf8_to_arrow(array), - ArrayImpl::Bool(array) => self.bool_to_arrow(array), - ArrayImpl::Decimal(array) => self.decimal_to_arrow(array), - ArrayImpl::Int256(array) => self.int256_to_arrow(array), ArrayImpl::Date(array) => self.date_to_arrow(array), + ArrayImpl::Time(array) => self.time_to_arrow(array), ArrayImpl::Timestamp(array) => self.timestamp_to_arrow(array), ArrayImpl::Timestamptz(array) => self.timestamptz_to_arrow(array), - ArrayImpl::Time(array) => self.time_to_arrow(array), ArrayImpl::Interval(array) => self.interval_to_arrow(array), - ArrayImpl::Struct(array) => self.struct_to_arrow(array), - ArrayImpl::List(array) => self.list_to_arrow(array), + ArrayImpl::Utf8(array) => self.utf8_to_arrow(array), ArrayImpl::Bytea(array) => self.bytea_to_arrow(array), + ArrayImpl::Decimal(array) => self.decimal_to_arrow(data_type, array), ArrayImpl::Jsonb(array) => self.jsonb_to_arrow(array), ArrayImpl::Serial(array) => self.serial_to_arrow(array), + ArrayImpl::List(array) => self.list_to_arrow(data_type, array), + ArrayImpl::Struct(array) => self.struct_to_arrow(data_type, array), + }?; + if arrow_array.data_type() != data_type { + arrow_cast::cast(&arrow_array, data_type).map_err(ArrayError::to_arrow) + } else { + Ok(arrow_array) } } + #[inline] + fn bool_to_arrow(&self, array: &BoolArray) -> Result { + Ok(Arc::new(arrow_array::BooleanArray::from(array))) + } + #[inline] fn int16_to_arrow(&self, array: &I16Array) -> Result { Ok(Arc::new(arrow_array::Int16Array::from(array))) @@ -391,17 +151,6 @@ pub trait ToArrowArrayConvert { Ok(Arc::new(arrow_array::StringArray::from(array))) } - #[inline] - fn bool_to_arrow(&self, array: &BoolArray) -> Result { - Ok(Arc::new(arrow_array::BooleanArray::from(array))) - } - - // Decimal values are stored as ASCII text representation in a large binary array. - #[inline] - fn decimal_to_arrow(&self, array: &DecimalArray) -> Result { - Ok(Arc::new(arrow_array::LargeBinaryArray::from(array))) - } - #[inline] fn int256_to_arrow(&self, array: &Int256Array) -> Result { Ok(Arc::new(arrow_array::Decimal256Array::from(array))) @@ -448,56 +197,103 @@ pub trait ToArrowArrayConvert { } #[inline] - fn struct_to_arrow(&self, array: &StructArray) -> Result { - Ok(Arc::new(arrow_array::StructArray::try_from(array)?)) + fn bytea_to_arrow(&self, array: &BytesArray) -> Result { + Ok(Arc::new(arrow_array::BinaryArray::from(array))) } + // Decimal values are stored as ASCII text representation in a string array. #[inline] - fn list_to_arrow(&self, array: &ListArray) -> Result { - Ok(Arc::new(arrow_array::ListArray::try_from(array)?)) + fn decimal_to_arrow( + &self, + _data_type: &arrow_schema::DataType, + array: &DecimalArray, + ) -> Result { + Ok(Arc::new(arrow_array::StringArray::from(array))) } + // JSON values are stored as text representation in a string array. 
#[inline] - fn bytea_to_arrow(&self, array: &BytesArray) -> Result { - Ok(Arc::new(arrow_array::BinaryArray::from(array))) + fn jsonb_to_arrow(&self, array: &JsonbArray) -> Result { + Ok(Arc::new(arrow_array::StringArray::from(array))) } - // JSON values are stored as text representation in a large string array. #[inline] - fn jsonb_to_arrow(&self, array: &JsonbArray) -> Result { - Ok(Arc::new(arrow_array::LargeStringArray::from(array))) + fn serial_to_arrow(&self, array: &SerialArray) -> Result { + Ok(Arc::new(arrow_array::Int64Array::from(array))) } #[inline] - fn serial_to_arrow(&self, _array: &SerialArray) -> Result { - todo!("serial type is not supported to convert to arrow") + fn list_to_arrow( + &self, + data_type: &arrow_schema::DataType, + array: &ListArray, + ) -> Result { + let arrow_schema::DataType::List(field) = data_type else { + return Err(ArrayError::to_arrow("Invalid list type")); + }; + let values = self.to_array(field.data_type(), array.values())?; + let offsets = OffsetBuffer::new(array.offsets().iter().map(|&o| o as i32).collect()); + let nulls = (!array.null_bitmap().all()).then(|| array.null_bitmap().into()); + Ok(Arc::new(arrow_array::ListArray::new( + field.clone(), + offsets, + values, + nulls, + ))) + } + + #[inline] + fn struct_to_arrow( + &self, + data_type: &arrow_schema::DataType, + array: &StructArray, + ) -> Result { + let arrow_schema::DataType::Struct(fields) = data_type else { + return Err(ArrayError::to_arrow("Invalid struct type")); + }; + Ok(Arc::new(arrow_array::StructArray::new( + fields.clone(), + array + .fields() + .zip_eq_fast(fields) + .map(|(arr, field)| self.to_array(field.data_type(), arr)) + .try_collect::<_, _, ArrayError>()?, + Some(array.null_bitmap().into()), + ))) } -} -pub trait ToArrowTypeConvert { - fn to_arrow_type(&self, value: &DataType) -> Result { - match value { + /// Convert RisingWave data type to Arrow data type. + /// + /// This function returns a `Field` instead of `DataType` because some may be converted to + /// extension types which require additional metadata in the field. 
+ fn to_arrow_field( + &self, + name: &str, + value: &DataType, + ) -> Result { + let data_type = match value { // using the inline function - DataType::Boolean => Ok(self.bool_type_to_arrow()), - DataType::Int16 => Ok(self.int16_type_to_arrow()), - DataType::Int32 => Ok(self.int32_type_to_arrow()), - DataType::Int64 => Ok(self.int64_type_to_arrow()), - DataType::Int256 => Ok(self.int256_type_to_arrow()), - DataType::Float32 => Ok(self.float32_type_to_arrow()), - DataType::Float64 => Ok(self.float64_type_to_arrow()), - DataType::Date => Ok(self.date_type_to_arrow()), - DataType::Timestamp => Ok(self.timestamp_type_to_arrow()), - DataType::Timestamptz => Ok(self.timestamptz_type_to_arrow()), - DataType::Time => Ok(self.time_type_to_arrow()), - DataType::Interval => Ok(self.interval_type_to_arrow()), - DataType::Varchar => Ok(self.varchar_type_to_arrow()), - DataType::Jsonb => Ok(self.jsonb_type_to_arrow()), - DataType::Bytea => Ok(self.bytea_type_to_arrow()), - DataType::Decimal => Ok(self.decimal_type_to_arrow()), - DataType::Serial => Ok(self.serial_type_to_arrow()), - DataType::Struct(fields) => self.struct_type_to_arrow(fields), - DataType::List(datatype) => self.list_type_to_arrow(datatype), - } + DataType::Boolean => self.bool_type_to_arrow(), + DataType::Int16 => self.int16_type_to_arrow(), + DataType::Int32 => self.int32_type_to_arrow(), + DataType::Int64 => self.int64_type_to_arrow(), + DataType::Int256 => self.int256_type_to_arrow(), + DataType::Float32 => self.float32_type_to_arrow(), + DataType::Float64 => self.float64_type_to_arrow(), + DataType::Date => self.date_type_to_arrow(), + DataType::Time => self.time_type_to_arrow(), + DataType::Timestamp => self.timestamp_type_to_arrow(), + DataType::Timestamptz => self.timestamptz_type_to_arrow(), + DataType::Interval => self.interval_type_to_arrow(), + DataType::Varchar => self.varchar_type_to_arrow(), + DataType::Bytea => self.bytea_type_to_arrow(), + DataType::Serial => self.serial_type_to_arrow(), + DataType::Decimal => return Ok(self.decimal_type_to_arrow(name)), + DataType::Jsonb => return Ok(self.jsonb_type_to_arrow(name)), + DataType::Struct(fields) => self.struct_type_to_arrow(fields)?, + DataType::List(datatype) => self.list_type_to_arrow(datatype)?, + }; + Ok(arrow_schema::Field::new(name, data_type, true)) } #[inline] @@ -505,6 +301,11 @@ pub trait ToArrowTypeConvert { arrow_schema::DataType::Boolean } + #[inline] + fn int16_type_to_arrow(&self) -> arrow_schema::DataType { + arrow_schema::DataType::Int16 + } + #[inline] fn int32_type_to_arrow(&self) -> arrow_schema::DataType { arrow_schema::DataType::Int32 @@ -515,12 +316,6 @@ pub trait ToArrowTypeConvert { arrow_schema::DataType::Int64 } - // generate function for each type for me using inline - #[inline] - fn int16_type_to_arrow(&self) -> arrow_schema::DataType { - arrow_schema::DataType::Int16 - } - #[inline] fn int256_type_to_arrow(&self) -> arrow_schema::DataType { arrow_schema::DataType::Decimal256(arrow_schema::DECIMAL256_MAX_PRECISION, 0) @@ -541,6 +336,11 @@ pub trait ToArrowTypeConvert { arrow_schema::DataType::Date32 } + #[inline] + fn time_type_to_arrow(&self) -> arrow_schema::DataType { + arrow_schema::DataType::Time64(arrow_schema::TimeUnit::Microsecond) + } + #[inline] fn timestamp_type_to_arrow(&self) -> arrow_schema::DataType { arrow_schema::DataType::Timestamp(arrow_schema::TimeUnit::Microsecond, None) @@ -554,11 +354,6 @@ pub trait ToArrowTypeConvert { ) } - #[inline] - fn time_type_to_arrow(&self) -> arrow_schema::DataType { - 
arrow_schema::DataType::Time64(arrow_schema::TimeUnit::Microsecond) - } - #[inline] fn interval_type_to_arrow(&self) -> arrow_schema::DataType { arrow_schema::DataType::Interval(arrow_schema::IntervalUnit::MonthDayNano) @@ -570,8 +365,9 @@ pub trait ToArrowTypeConvert { } #[inline] - fn jsonb_type_to_arrow(&self) -> arrow_schema::DataType { - arrow_schema::DataType::LargeUtf8 + fn jsonb_type_to_arrow(&self, name: &str) -> arrow_schema::Field { + arrow_schema::Field::new(name, arrow_schema::DataType::Utf8, true) + .with_metadata([("ARROW:extension:name".into(), "arrowudf.json".into())].into()) } #[inline] @@ -580,22 +376,23 @@ pub trait ToArrowTypeConvert { } #[inline] - fn decimal_type_to_arrow(&self) -> arrow_schema::DataType { - arrow_schema::DataType::LargeBinary + fn decimal_type_to_arrow(&self, name: &str) -> arrow_schema::Field { + arrow_schema::Field::new(name, arrow_schema::DataType::Utf8, true) + .with_metadata([("ARROW:extension:name".into(), "arrowudf.decimal".into())].into()) } #[inline] fn serial_type_to_arrow(&self) -> arrow_schema::DataType { - todo!("serial type is not supported to convert to arrow") + arrow_schema::DataType::Int64 } #[inline] fn list_type_to_arrow( &self, - datatype: &DataType, + elem_type: &DataType, ) -> Result { Ok(arrow_schema::DataType::List(Arc::new( - arrow_schema::Field::new("item", datatype.try_into()?, true), + self.to_arrow_field("item", elem_type)?, ))) } @@ -607,160 +404,287 @@ pub trait ToArrowTypeConvert { Ok(arrow_schema::DataType::Struct( fields .iter() - .map(|(name, ty)| Ok(arrow_schema::Field::new(name, ty.try_into()?, true))) + .map(|(name, ty)| self.to_arrow_field(name, ty)) .try_collect::<_, _, ArrayError>()?, )) } } -struct DefaultArrowConvert; -impl ToArrowArrayConvert for DefaultArrowConvert {} - -/// Implement bi-directional `From` between `ArrayImpl` and `arrow_array::ArrayRef`. -macro_rules! converts_generic { - ($({ $ArrowType:ty, $ArrowPattern:pat, $ArrayImplPattern:path }),*) => { - // RisingWave array -> Arrow array - impl TryFrom<&ArrayImpl> for arrow_array::ArrayRef { - type Error = ArrayError; - fn try_from(array: &ArrayImpl) -> Result { - DefaultArrowConvert{}.to_arrow(array) - } +/// Defines how to convert Arrow arrays to RisingWave arrays. +#[allow(clippy::wrong_self_convention)] +pub trait FromArrow { + /// Converts Arrow `RecordBatch` to RisingWave `DataChunk`. + fn from_record_batch(&self, batch: &arrow_array::RecordBatch) -> Result { + let mut columns = Vec::with_capacity(batch.num_columns()); + for (array, field) in batch.columns().iter().zip_eq_fast(batch.schema().fields()) { + let column = Arc::new(self.from_array(field, array)?); + columns.push(column); } - // Arrow array -> RisingWave array - impl TryFrom<&arrow_array::ArrayRef> for ArrayImpl { - type Error = ArrayError; - fn try_from(array: &arrow_array::ArrayRef) -> Result { - use arrow_schema::DataType::*; - use arrow_schema::IntervalUnit::*; - use arrow_schema::TimeUnit::*; - match array.data_type() { - $($ArrowPattern => Ok($ArrayImplPattern( - array - .as_any() - .downcast_ref::<$ArrowType>() - .unwrap() - .try_into()?, - )),)* - Timestamp(Microsecond, Some(_)) => Ok(ArrayImpl::Timestamptz( - array - .as_any() - .downcast_ref::() - .unwrap() - .try_into()?, - )), - // This arrow decimal type is used by iceberg source to read iceberg decimal into RW decimal. 
- Decimal128(_, _) => Ok(ArrayImpl::Decimal( - array - .as_any() - .downcast_ref::() - .unwrap() - .try_into()?, - )), - t => Err(ArrayError::from_arrow(format!("unsupported data type: {t:?}"))), - } + Ok(DataChunk::new(columns, batch.num_rows())) + } + + /// Converts Arrow `Fields` to RisingWave `StructType`. + fn from_fields(&self, fields: &arrow_schema::Fields) -> Result { + Ok(StructType::new( + fields + .iter() + .map(|f| Ok((f.name().clone(), self.from_field(f)?))) + .try_collect::<_, _, ArrayError>()?, + )) + } + + /// Converts Arrow `Field` to RisingWave `DataType`. + fn from_field(&self, field: &arrow_schema::Field) -> Result { + use arrow_schema::DataType::*; + use arrow_schema::IntervalUnit::*; + use arrow_schema::TimeUnit::*; + + // extension type + if let Some(type_name) = field.metadata().get("ARROW:extension:name") { + return self.from_extension_type(type_name, field.data_type()); + } + + Ok(match field.data_type() { + Boolean => DataType::Boolean, + Int16 => DataType::Int16, + Int32 => DataType::Int32, + Int64 => DataType::Int64, + Float32 => DataType::Float32, + Float64 => DataType::Float64, + Decimal128(_, _) => DataType::Decimal, + Decimal256(_, _) => DataType::Int256, + Date32 => DataType::Date, + Time64(Microsecond) => DataType::Time, + Timestamp(Microsecond, None) => DataType::Timestamp, + Timestamp(Microsecond, Some(_)) => DataType::Timestamptz, + Interval(MonthDayNano) => DataType::Interval, + Utf8 => DataType::Varchar, + Binary => DataType::Bytea, + LargeUtf8 => self.from_large_utf8()?, + LargeBinary => self.from_large_binary()?, + List(field) => DataType::List(Box::new(self.from_field(field)?)), + Struct(fields) => DataType::Struct(self.from_fields(fields)?), + t => { + return Err(ArrayError::from_arrow(format!( + "unsupported arrow data type: {t:?}" + ))) } + }) + } + + /// Converts Arrow `LargeUtf8` type to RisingWave data type. + fn from_large_utf8(&self) -> Result { + Ok(DataType::Varchar) + } + + /// Converts Arrow `LargeBinary` type to RisingWave data type. + fn from_large_binary(&self) -> Result { + Ok(DataType::Bytea) + } + + /// Converts Arrow extension type to RisingWave `DataType`. + fn from_extension_type( + &self, + type_name: &str, + physical_type: &arrow_schema::DataType, + ) -> Result { + match (type_name, physical_type) { + ("arrowudf.decimal", arrow_schema::DataType::Utf8) => Ok(DataType::Decimal), + ("arrowudf.json", arrow_schema::DataType::Utf8) => Ok(DataType::Jsonb), + _ => Err(ArrayError::from_arrow(format!( + "unsupported extension type: {type_name:?}" + ))), } - }; -} -converts_generic! { - { arrow_array::Int16Array, Int16, ArrayImpl::Int16 }, - { arrow_array::Int32Array, Int32, ArrayImpl::Int32 }, - { arrow_array::Int64Array, Int64, ArrayImpl::Int64 }, - { arrow_array::Float32Array, Float32, ArrayImpl::Float32 }, - { arrow_array::Float64Array, Float64, ArrayImpl::Float64 }, - { arrow_array::StringArray, Utf8, ArrayImpl::Utf8 }, - { arrow_array::BooleanArray, Boolean, ArrayImpl::Bool }, - // Arrow doesn't have a data type to represent unconstrained numeric (`DECIMAL` in RisingWave and - // Postgres). So we pick a special type `LargeBinary` for it. - // Values stored in the array are the string representation of the decimal. e.g. 
b"1.234", b"+inf" - { arrow_array::LargeBinaryArray, LargeBinary, ArrayImpl::Decimal }, - { arrow_array::Decimal256Array, Decimal256(_, _), ArrayImpl::Int256 }, - { arrow_array::Date32Array, Date32, ArrayImpl::Date }, - { arrow_array::TimestampMicrosecondArray, Timestamp(Microsecond, None), ArrayImpl::Timestamp }, - { arrow_array::Time64MicrosecondArray, Time64(Microsecond), ArrayImpl::Time }, - { arrow_array::IntervalMonthDayNanoArray, Interval(MonthDayNano), ArrayImpl::Interval }, - { arrow_array::StructArray, Struct(_), ArrayImpl::Struct }, - { arrow_array::ListArray, List(_), ArrayImpl::List }, - { arrow_array::BinaryArray, Binary, ArrayImpl::Bytea }, - { arrow_array::LargeStringArray, LargeUtf8, ArrayImpl::Jsonb } // we use LargeUtf8 to represent Jsonb in arrow -} + } -// Arrow Datatype -> Risingwave Datatype -impl From<&arrow_schema::DataType> for DataType { - fn from(value: &arrow_schema::DataType) -> Self { + /// Converts Arrow `Array` to RisingWave `ArrayImpl`. + fn from_array( + &self, + field: &arrow_schema::Field, + array: &arrow_array::ArrayRef, + ) -> Result { use arrow_schema::DataType::*; use arrow_schema::IntervalUnit::*; use arrow_schema::TimeUnit::*; - match value { - Boolean => Self::Boolean, - Int16 => Self::Int16, - Int32 => Self::Int32, - Int64 => Self::Int64, - Float32 => Self::Float32, - Float64 => Self::Float64, - LargeBinary => Self::Decimal, - Decimal256(_, _) => Self::Int256, - Date32 => Self::Date, - Time64(Microsecond) => Self::Time, - Timestamp(Microsecond, None) => Self::Timestamp, - Timestamp(Microsecond, Some(_)) => Self::Timestamptz, - Interval(MonthDayNano) => Self::Interval, - Binary => Self::Bytea, - Utf8 => Self::Varchar, - LargeUtf8 => Self::Jsonb, - Struct(fields) => Self::Struct(fields.into()), - List(field) => Self::List(Box::new(field.data_type().into())), - Decimal128(_, _) => Self::Decimal, - _ => todo!("Unsupported arrow data type: {value:?}"), + + // extension type + if let Some(type_name) = field.metadata().get("ARROW:extension:name") { + return self.from_extension_array(type_name, array); + } + + match array.data_type() { + Boolean => self.from_bool_array(array.as_any().downcast_ref().unwrap()), + Int16 => self.from_int16_array(array.as_any().downcast_ref().unwrap()), + Int32 => self.from_int32_array(array.as_any().downcast_ref().unwrap()), + Int64 => self.from_int64_array(array.as_any().downcast_ref().unwrap()), + Decimal256(_, _) => self.from_int256_array(array.as_any().downcast_ref().unwrap()), + Float32 => self.from_float32_array(array.as_any().downcast_ref().unwrap()), + Float64 => self.from_float64_array(array.as_any().downcast_ref().unwrap()), + Date32 => self.from_date32_array(array.as_any().downcast_ref().unwrap()), + Time64(Microsecond) => self.from_time64us_array(array.as_any().downcast_ref().unwrap()), + Timestamp(Microsecond, _) => { + self.from_timestampus_array(array.as_any().downcast_ref().unwrap()) + } + Interval(MonthDayNano) => { + self.from_interval_array(array.as_any().downcast_ref().unwrap()) + } + Utf8 => self.from_utf8_array(array.as_any().downcast_ref().unwrap()), + Binary => self.from_binary_array(array.as_any().downcast_ref().unwrap()), + LargeUtf8 => self.from_large_utf8_array(array.as_any().downcast_ref().unwrap()), + LargeBinary => self.from_large_binary_array(array.as_any().downcast_ref().unwrap()), + List(_) => self.from_list_array(array.as_any().downcast_ref().unwrap()), + Struct(_) => self.from_struct_array(array.as_any().downcast_ref().unwrap()), + t => Err(ArrayError::from_arrow(format!( + "unsupported 
arrow data type: {t:?}", + ))), } } -} -impl From<&arrow_schema::Fields> for StructType { - fn from(fields: &arrow_schema::Fields) -> Self { - Self::new( - fields - .iter() - .map(|f| (f.name().clone(), f.data_type().into())) - .collect(), - ) + /// Converts Arrow extension array to RisingWave `ArrayImpl`. + fn from_extension_array( + &self, + type_name: &str, + array: &arrow_array::ArrayRef, + ) -> Result { + match type_name { + "arrowudf.decimal" => { + let array: &arrow_array::StringArray = + array.as_any().downcast_ref().ok_or_else(|| { + ArrayError::from_arrow( + "expected string array for `arrowudf.decimal`".to_string(), + ) + })?; + Ok(ArrayImpl::Decimal(array.try_into()?)) + } + "arrowudf.json" => { + let array: &arrow_array::StringArray = + array.as_any().downcast_ref().ok_or_else(|| { + ArrayError::from_arrow( + "expected string array for `arrowudf.json`".to_string(), + ) + })?; + Ok(ArrayImpl::Jsonb(array.try_into()?)) + } + _ => Err(ArrayError::from_arrow(format!( + "unsupported extension type: {type_name:?}" + ))), + } } -} -impl TryFrom<&StructType> for arrow_schema::Fields { - type Error = ArrayError; + fn from_bool_array(&self, array: &arrow_array::BooleanArray) -> Result { + Ok(ArrayImpl::Bool(array.into())) + } - fn try_from(struct_type: &StructType) -> Result { - struct_type - .iter() - .map(|(name, ty)| Ok(arrow_schema::Field::new(name, ty.try_into()?, true))) - .try_collect() + fn from_int16_array(&self, array: &arrow_array::Int16Array) -> Result { + Ok(ArrayImpl::Int16(array.into())) } -} -impl From for DataType { - fn from(value: arrow_schema::DataType) -> Self { - (&value).into() + fn from_int32_array(&self, array: &arrow_array::Int32Array) -> Result { + Ok(ArrayImpl::Int32(array.into())) } -} -struct DefaultArrowTypeConvert; + fn from_int64_array(&self, array: &arrow_array::Int64Array) -> Result { + Ok(ArrayImpl::Int64(array.into())) + } -impl ToArrowTypeConvert for DefaultArrowTypeConvert {} + fn from_int256_array( + &self, + array: &arrow_array::Decimal256Array, + ) -> Result { + Ok(ArrayImpl::Int256(array.into())) + } -impl TryFrom<&DataType> for arrow_schema::DataType { - type Error = ArrayError; + fn from_float32_array( + &self, + array: &arrow_array::Float32Array, + ) -> Result { + Ok(ArrayImpl::Float32(array.into())) + } - fn try_from(value: &DataType) -> Result { - DefaultArrowTypeConvert {}.to_arrow_type(value) + fn from_float64_array( + &self, + array: &arrow_array::Float64Array, + ) -> Result { + Ok(ArrayImpl::Float64(array.into())) } -} -impl TryFrom for arrow_schema::DataType { - type Error = ArrayError; + fn from_date32_array(&self, array: &arrow_array::Date32Array) -> Result { + Ok(ArrayImpl::Date(array.into())) + } + + fn from_time64us_array( + &self, + array: &arrow_array::Time64MicrosecondArray, + ) -> Result { + Ok(ArrayImpl::Time(array.into())) + } + + fn from_timestampus_array( + &self, + array: &arrow_array::TimestampMicrosecondArray, + ) -> Result { + Ok(ArrayImpl::Timestamp(array.into())) + } + + fn from_interval_array( + &self, + array: &arrow_array::IntervalMonthDayNanoArray, + ) -> Result { + Ok(ArrayImpl::Interval(array.into())) + } + + fn from_utf8_array(&self, array: &arrow_array::StringArray) -> Result { + Ok(ArrayImpl::Utf8(array.into())) + } + + fn from_binary_array(&self, array: &arrow_array::BinaryArray) -> Result { + Ok(ArrayImpl::Bytea(array.into())) + } + + fn from_large_utf8_array( + &self, + array: &arrow_array::LargeStringArray, + ) -> Result { + Ok(ArrayImpl::Utf8(array.into())) + } + + fn from_large_binary_array( + &self, 
+ array: &arrow_array::LargeBinaryArray, + ) -> Result { + Ok(ArrayImpl::Bytea(array.into())) + } + + fn from_list_array(&self, array: &arrow_array::ListArray) -> Result { + use arrow_array::Array; + let arrow_schema::DataType::List(field) = array.data_type() else { + panic!("nested field types cannot be determined."); + }; + Ok(ArrayImpl::List(ListArray { + value: Box::new(self.from_array(field, array.values())?), + bitmap: match array.nulls() { + Some(nulls) => nulls.iter().collect(), + None => Bitmap::ones(array.len()), + }, + offsets: array.offsets().iter().map(|o| *o as u32).collect(), + })) + } - fn try_from(value: DataType) -> Result { - (&value).try_into() + fn from_struct_array(&self, array: &arrow_array::StructArray) -> Result { + use arrow_array::Array; + let arrow_schema::DataType::Struct(fields) = array.data_type() else { + panic!("nested field types cannot be determined."); + }; + Ok(ArrayImpl::Struct(StructArray::new( + self.from_fields(fields)?, + array + .columns() + .iter() + .zip_eq_fast(fields) + .map(|(array, field)| self.from_array(field, array).map(Arc::new)) + .try_collect()?, + (0..array.len()).map(|i| array.is_valid(i)).collect(), + ))) } } @@ -830,12 +754,15 @@ converts!(I64Array, arrow_array::Int64Array); converts!(F32Array, arrow_array::Float32Array, @map); converts!(F64Array, arrow_array::Float64Array, @map); converts!(BytesArray, arrow_array::BinaryArray); +converts!(BytesArray, arrow_array::LargeBinaryArray); converts!(Utf8Array, arrow_array::StringArray); +converts!(Utf8Array, arrow_array::LargeStringArray); converts!(DateArray, arrow_array::Date32Array, @map); converts!(TimeArray, arrow_array::Time64MicrosecondArray, @map); converts!(TimestampArray, arrow_array::TimestampMicrosecondArray, @map); converts!(TimestamptzArray, arrow_array::TimestampMicrosecondArray, @map); converts!(IntervalArray, arrow_array::IntervalMonthDayNanoArray, @map); +converts!(SerialArray, arrow_array::Int64Array, @map); /// Converts RisingWave value from and into Arrow value. pub trait FromIntoArrow { @@ -845,6 +772,18 @@ pub trait FromIntoArrow { fn into_arrow(self) -> Self::ArrowType; } +impl FromIntoArrow for Serial { + type ArrowType = i64; + + fn from_arrow(value: Self::ArrowType) -> Self { + value.into() + } + + fn into_arrow(self) -> Self::ArrowType { + self.into() + } +} + impl FromIntoArrow for F32 { type ArrowType = f32; @@ -973,6 +912,17 @@ impl From<&DecimalArray> for arrow_array::LargeBinaryArray { } } +impl From<&DecimalArray> for arrow_array::StringArray { + fn from(array: &DecimalArray) -> Self { + let mut builder = + arrow_array::builder::StringBuilder::with_capacity(array.len(), array.len() * 8); + for value in array.iter() { + builder.append_option(value.map(|d| d.to_string())); + } + builder.finish() + } +} + // This arrow decimal type is used by iceberg source to read iceberg decimal into RW decimal. 
impl TryFrom<&arrow_array::Decimal128Array> for DecimalArray { type Error = ArrayError; @@ -1020,6 +970,57 @@ impl TryFrom<&arrow_array::LargeBinaryArray> for DecimalArray { } } +impl TryFrom<&arrow_array::StringArray> for DecimalArray { + type Error = ArrayError; + + fn try_from(array: &arrow_array::StringArray) -> Result { + array + .iter() + .map(|o| { + o.map(|s| { + s.parse() + .map_err(|_| ArrayError::from_arrow(format!("invalid decimal: {s:?}"))) + }) + .transpose() + }) + .try_collect() + } +} + +impl From<&JsonbArray> for arrow_array::StringArray { + fn from(array: &JsonbArray) -> Self { + let mut builder = + arrow_array::builder::StringBuilder::with_capacity(array.len(), array.len() * 16); + for value in array.iter() { + match value { + Some(jsonb) => { + write!(&mut builder, "{}", jsonb).unwrap(); + builder.append_value(""); + } + None => builder.append_null(), + } + } + builder.finish() + } +} + +impl TryFrom<&arrow_array::StringArray> for JsonbArray { + type Error = ArrayError; + + fn try_from(array: &arrow_array::StringArray) -> Result { + array + .iter() + .map(|o| { + o.map(|s| { + s.parse() + .map_err(|_| ArrayError::from_arrow(format!("invalid json: {s}"))) + }) + .transpose() + }) + .try_collect() + } +} + impl From<&JsonbArray> for arrow_array::LargeStringArray { fn from(array: &JsonbArray) -> Self { let mut builder = @@ -1088,195 +1089,8 @@ impl From<&arrow_array::Decimal256Array> for Int256Array { } } -impl TryFrom<&ListArray> for arrow_array::ListArray { - type Error = ArrayError; - - fn try_from(array: &ListArray) -> Result { - use arrow_array::builder::*; - fn build( - array: &ListArray, - a: &A, - builder: B, - mut append: F, - ) -> arrow_array::ListArray - where - A: Array, - B: arrow_array::builder::ArrayBuilder, - F: FnMut(&mut B, Option>), - { - let mut builder = ListBuilder::with_capacity(builder, a.len()); - for i in 0..array.len() { - for j in array.offsets[i]..array.offsets[i + 1] { - append(builder.values(), a.value_at(j as usize)); - } - builder.append(!array.is_null(i)); - } - builder.finish() - } - Ok(match &*array.value { - ArrayImpl::Int16(a) => build(array, a, Int16Builder::with_capacity(a.len()), |b, v| { - b.append_option(v) - }), - ArrayImpl::Int32(a) => build(array, a, Int32Builder::with_capacity(a.len()), |b, v| { - b.append_option(v) - }), - ArrayImpl::Int64(a) => build(array, a, Int64Builder::with_capacity(a.len()), |b, v| { - b.append_option(v) - }), - - ArrayImpl::Float32(a) => { - build(array, a, Float32Builder::with_capacity(a.len()), |b, v| { - b.append_option(v.map(|f| f.0)) - }) - } - ArrayImpl::Float64(a) => { - build(array, a, Float64Builder::with_capacity(a.len()), |b, v| { - b.append_option(v.map(|f| f.0)) - }) - } - ArrayImpl::Utf8(a) => build( - array, - a, - StringBuilder::with_capacity(a.len(), a.data().len()), - |b, v| b.append_option(v), - ), - ArrayImpl::Int256(a) => build( - array, - a, - Decimal256Builder::with_capacity(a.len()).with_data_type( - arrow_schema::DataType::Decimal256(arrow_schema::DECIMAL256_MAX_PRECISION, 0), - ), - |b, v| b.append_option(v.map(Into::into)), - ), - ArrayImpl::Bool(a) => { - build(array, a, BooleanBuilder::with_capacity(a.len()), |b, v| { - b.append_option(v) - }) - } - ArrayImpl::Decimal(a) => build( - array, - a, - LargeBinaryBuilder::with_capacity(a.len(), a.len() * 8), - |b, v| b.append_option(v.map(|d| d.to_string())), - ), - ArrayImpl::Interval(a) => build( - array, - a, - IntervalMonthDayNanoBuilder::with_capacity(a.len()), - |b, v| b.append_option(v.map(|d| d.into_arrow())), - ), - 
ArrayImpl::Date(a) => build(array, a, Date32Builder::with_capacity(a.len()), |b, v| { - b.append_option(v.map(|d| d.into_arrow())) - }), - ArrayImpl::Timestamp(a) => build( - array, - a, - TimestampMicrosecondBuilder::with_capacity(a.len()), - |b, v| b.append_option(v.map(|d| d.into_arrow())), - ), - ArrayImpl::Timestamptz(a) => build( - array, - a, - TimestampMicrosecondBuilder::with_capacity(a.len()), - |b, v| b.append_option(v.map(|d| d.into_arrow())), - ), - ArrayImpl::Time(a) => build( - array, - a, - Time64MicrosecondBuilder::with_capacity(a.len()), - |b, v| b.append_option(v.map(|d| d.into_arrow())), - ), - ArrayImpl::Jsonb(a) => build( - array, - a, - LargeStringBuilder::with_capacity(a.len(), a.len() * 16), - |b, v| b.append_option(v.map(|j| j.to_string())), - ), - ArrayImpl::Serial(_) => todo!("list of serial"), - ArrayImpl::Struct(a) => { - let values = Arc::new(arrow_array::StructArray::try_from(a)?); - arrow_array::ListArray::new( - Arc::new(arrow_schema::Field::new( - "item", - a.data_type().try_into()?, - true, - )), - arrow_buffer::OffsetBuffer::new(arrow_buffer::ScalarBuffer::from( - array - .offsets() - .iter() - .map(|o| *o as i32) - .collect::>(), - )), - values, - Some(array.null_bitmap().into()), - ) - } - ArrayImpl::List(_) => todo!("list of list"), - ArrayImpl::Bytea(a) => build( - array, - a, - BinaryBuilder::with_capacity(a.len(), a.data().len()), - |b, v| b.append_option(v), - ), - }) - } -} - -impl TryFrom<&arrow_array::ListArray> for ListArray { - type Error = ArrayError; - - fn try_from(array: &arrow_array::ListArray) -> Result { - use arrow_array::Array; - Ok(ListArray { - value: Box::new(ArrayImpl::try_from(array.values())?), - bitmap: match array.nulls() { - Some(nulls) => nulls.iter().collect(), - None => Bitmap::ones(array.len()), - }, - offsets: array.offsets().iter().map(|o| *o as u32).collect(), - }) - } -} - -impl TryFrom<&StructArray> for arrow_array::StructArray { - type Error = ArrayError; - - fn try_from(array: &StructArray) -> Result { - Ok(arrow_array::StructArray::new( - array.data_type().as_struct().try_into()?, - array - .fields() - .map(|arr| arr.as_ref().try_into()) - .try_collect::<_, _, ArrayError>()?, - Some(array.null_bitmap().into()), - )) - } -} - -impl TryFrom<&arrow_array::StructArray> for StructArray { - type Error = ArrayError; - - fn try_from(array: &arrow_array::StructArray) -> Result { - use arrow_array::Array; - let arrow_schema::DataType::Struct(fields) = array.data_type() else { - panic!("nested field types cannot be determined."); - }; - Ok(StructArray::new( - fields.into(), - array - .columns() - .iter() - .map(|a| ArrayImpl::try_from(a).map(Arc::new)) - .try_collect()?, - (0..array.len()).map(|i| !array.is_null(i)).collect(), - )) - } -} - #[cfg(test)] mod tests { - use super::arrow_array::Array as _; use super::*; #[test] @@ -1293,6 +1107,20 @@ mod tests { assert_eq!(I16Array::from(&arrow), array); } + #[test] + fn i32() { + let array = I32Array::from_iter([None, Some(-7), Some(25)]); + let arrow = arrow_array::Int32Array::from(&array); + assert_eq!(I32Array::from(&arrow), array); + } + + #[test] + fn i64() { + let array = I64Array::from_iter([None, Some(-7), Some(25)]); + let arrow = arrow_array::Int64Array::from(&array); + assert_eq!(I64Array::from(&arrow), array); + } + #[test] fn f32() { let array = F32Array::from_iter([None, Some(-7.0), Some(25.0)]); @@ -1300,6 +1128,13 @@ mod tests { assert_eq!(F32Array::from(&arrow), array); } + #[test] + fn f64() { + let array = F64Array::from_iter([None, Some(-7.0), 
Some(25.0)]); + let arrow = arrow_array::Float64Array::from(&array); + assert_eq!(F64Array::from(&arrow), array); + } + #[test] fn date() { let array = DateArray::from_iter([ @@ -1352,6 +1187,13 @@ mod tests { assert_eq!(Utf8Array::from(&arrow), array); } + #[test] + fn binary() { + let array = BytesArray::from_iter([None, Some("array".as_bytes())]); + let arrow = arrow_array::BinaryArray::from(&array); + assert_eq!(BytesArray::from(&arrow), array); + } + #[test] fn decimal() { let array = DecimalArray::from_iter([ @@ -1364,6 +1206,9 @@ mod tests { ]); let arrow = arrow_array::LargeBinaryArray::from(&array); assert_eq!(DecimalArray::try_from(&arrow).unwrap(), array); + + let arrow = arrow_array::StringArray::from(&array); + assert_eq!(DecimalArray::try_from(&arrow).unwrap(), array); } #[test] @@ -1378,6 +1223,9 @@ mod tests { ]); let arrow = arrow_array::LargeStringArray::from(&array); assert_eq!(JsonbArray::try_from(&arrow).unwrap(), array); + + let arrow = arrow_array::StringArray::from(&array); + assert_eq!(JsonbArray::try_from(&arrow).unwrap(), array); } #[test] @@ -1403,62 +1251,4 @@ mod tests { let arrow = arrow_array::Decimal256Array::from(&array); assert_eq!(Int256Array::from(&arrow), array); } - - #[test] - fn struct_array() { - // Empty array - risingwave to arrow conversion. - let test_arr = StructArray::new(StructType::empty(), vec![], Bitmap::ones(0)); - assert_eq!( - arrow_array::StructArray::try_from(&test_arr).unwrap().len(), - 0 - ); - - // Empty array - arrow to risingwave conversion. - let test_arr_2 = arrow_array::StructArray::from(vec![]); - assert_eq!(StructArray::try_from(&test_arr_2).unwrap().len(), 0); - - // Struct array with primitive types. arrow to risingwave conversion. - let test_arrow_struct_array = arrow_array::StructArray::try_from(vec![ - ( - "a", - Arc::new(arrow_array::BooleanArray::from(vec![ - Some(false), - Some(false), - Some(true), - None, - ])) as arrow_array::ArrayRef, - ), - ( - "b", - Arc::new(arrow_array::Int32Array::from(vec![ - Some(42), - Some(28), - Some(19), - None, - ])) as arrow_array::ArrayRef, - ), - ]) - .unwrap(); - let actual_risingwave_struct_array = - StructArray::try_from(&test_arrow_struct_array).unwrap(); - let expected_risingwave_struct_array = StructArray::new( - StructType::new(vec![("a", DataType::Boolean), ("b", DataType::Int32)]), - vec![ - BoolArray::from_iter([Some(false), Some(false), Some(true), None]).into_ref(), - I32Array::from_iter([Some(42), Some(28), Some(19), None]).into_ref(), - ], - [true, true, true, true].into_iter().collect(), - ); - assert_eq!( - expected_risingwave_struct_array, - actual_risingwave_struct_array - ); - } - - #[test] - fn list() { - let array = ListArray::from_iter([None, Some(vec![0, -127, 127, 50]), Some(vec![0; 0])]); - let arrow = arrow_array::ListArray::try_from(&array).unwrap(); - assert_eq!(ListArray::try_from(&arrow).unwrap(), array); - } } diff --git a/src/common/src/array/arrow/arrow_udf.rs b/src/common/src/array/arrow/arrow_udf.rs new file mode 100644 index 0000000000000..e2f9e39ad385a --- /dev/null +++ b/src/common/src/array/arrow/arrow_udf.rs @@ -0,0 +1,177 @@ +// Copyright 2024 RisingWave Labs +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+//! This is for the arrow dependency named `arrow-xxx`, such as `arrow-array`, in the cargo workspace.
+//!
+//! This should be the default arrow version to be used in our system.
+//!
+//! The corresponding version of arrow is currently used by `udf` and the `iceberg` sink.
+
+use std::sync::Arc;
+
+pub use arrow_impl::{FromArrow, ToArrow};
+use {arrow_array, arrow_buffer, arrow_cast, arrow_schema};
+
+use crate::array::{ArrayError, ArrayImpl, DataType, DecimalArray, JsonbArray};
+
+#[expect(clippy::duplicate_mod)]
+#[path = "./arrow_impl.rs"]
+mod arrow_impl;
+
+/// Arrow conversion for the current version of UDF. This is in use but will be deprecated soon.
+///
+/// In the current version of the UDF protocol, decimal and jsonb types are mapped to Arrow `LargeBinary` and `LargeUtf8` types.
+pub struct UdfArrowConvert;
+
+impl ToArrow for UdfArrowConvert {
+    // Decimal values are stored as ASCII text representation in a large binary array.
+    fn decimal_to_arrow(
+        &self,
+        _data_type: &arrow_schema::DataType,
+        array: &DecimalArray,
+    ) -> Result<arrow_array::ArrayRef, ArrayError> {
+        Ok(Arc::new(arrow_array::LargeBinaryArray::from(array)))
+    }
+
+    // JSON values are stored as text representation in a large string array.
+    fn jsonb_to_arrow(&self, array: &JsonbArray) -> Result<arrow_array::ArrayRef, ArrayError> {
+        Ok(Arc::new(arrow_array::LargeStringArray::from(array)))
+    }
+
+    fn jsonb_type_to_arrow(&self, name: &str) -> arrow_schema::Field {
+        arrow_schema::Field::new(name, arrow_schema::DataType::LargeUtf8, true)
+    }
+
+    fn decimal_type_to_arrow(&self, name: &str) -> arrow_schema::Field {
+        arrow_schema::Field::new(name, arrow_schema::DataType::LargeBinary, true)
+    }
+}
+
+impl FromArrow for UdfArrowConvert {
+    fn from_large_utf8(&self) -> Result<DataType, ArrayError> {
+        Ok(DataType::Jsonb)
+    }
+
+    fn from_large_binary(&self) -> Result<DataType, ArrayError> {
+        Ok(DataType::Decimal)
+    }
+
+    fn from_large_utf8_array(
+        &self,
+        array: &arrow_array::LargeStringArray,
+    ) -> Result<ArrayImpl, ArrayError> {
+        Ok(ArrayImpl::Jsonb(array.try_into()?))
+    }
+
+    fn from_large_binary_array(
+        &self,
+        array: &arrow_array::LargeBinaryArray,
+    ) -> Result<ArrayImpl, ArrayError> {
+        Ok(ArrayImpl::Decimal(array.try_into()?))
+    }
+}
+
+/// Arrow conversion for the next version of UDF. This is unused for now.
+///
+/// In the next version of the UDF protocol, decimal and jsonb types will be mapped to Arrow extension types.
+/// See .
+pub struct NewUdfArrowConvert;
+
+impl ToArrow for NewUdfArrowConvert {}
+impl FromArrow for NewUdfArrowConvert {}
+
+#[cfg(test)]
+mod tests {
+    use std::sync::Arc;
+
+    use super::*;
+    use crate::array::*;
+    use crate::buffer::Bitmap;
+
+    #[test]
+    fn struct_array() {
+        // Empty array - risingwave to arrow conversion.
+        let test_arr = StructArray::new(StructType::empty(), vec![], Bitmap::ones(0));
+        assert_eq!(
+            UdfArrowConvert
+                .struct_to_arrow(
+                    &arrow_schema::DataType::Struct(arrow_schema::Fields::empty()),
+                    &test_arr
+                )
+                .unwrap()
+                .len(),
+            0
+        );
+
+        // Empty array - arrow to risingwave conversion.
+        let test_arr_2 = arrow_array::StructArray::from(vec![]);
+        assert_eq!(
+            UdfArrowConvert
+                .from_struct_array(&test_arr_2)
+                .unwrap()
+                .len(),
+            0
+        );
+
+        // Struct array with primitive types. arrow to risingwave conversion.
+        let test_arrow_struct_array = arrow_array::StructArray::try_from(vec![
+            (
+                "a",
+                Arc::new(arrow_array::BooleanArray::from(vec![
+                    Some(false),
+                    Some(false),
+                    Some(true),
+                    None,
+                ])) as arrow_array::ArrayRef,
+            ),
+            (
+                "b",
+                Arc::new(arrow_array::Int32Array::from(vec![
+                    Some(42),
+                    Some(28),
+                    Some(19),
+                    None,
+                ])) as arrow_array::ArrayRef,
+            ),
+        ])
+        .unwrap();
+        let actual_risingwave_struct_array = UdfArrowConvert
+            .from_struct_array(&test_arrow_struct_array)
+            .unwrap()
+            .into_struct();
+        let expected_risingwave_struct_array = StructArray::new(
+            StructType::new(vec![("a", DataType::Boolean), ("b", DataType::Int32)]),
+            vec![
+                BoolArray::from_iter([Some(false), Some(false), Some(true), None]).into_ref(),
+                I32Array::from_iter([Some(42), Some(28), Some(19), None]).into_ref(),
+            ],
+            [true, true, true, true].into_iter().collect(),
+        );
+        assert_eq!(
+            expected_risingwave_struct_array,
+            actual_risingwave_struct_array
+        );
+    }
+
+    #[test]
+    fn list() {
+        let array = ListArray::from_iter([None, Some(vec![0, -127, 127, 50]), Some(vec![0; 0])]);
+        let data_type = arrow_schema::DataType::new_list(arrow_schema::DataType::Int32, true);
+        let arrow = UdfArrowConvert.list_to_arrow(&data_type, &array).unwrap();
+        let rw_array = UdfArrowConvert
+            .from_list_array(arrow.as_any().downcast_ref().unwrap())
+            .unwrap();
+        assert_eq!(rw_array.as_list(), &array);
+    }
+}
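To make the new trait pair concrete, here is a minimal usage sketch (illustrative only, not part of the patch). It assumes only the APIs introduced above — `to_arrow_field`, `to_record_batch`, and `from_record_batch` — and shows the legacy mapping of decimals to Arrow `LargeBinary`; `arrow_schema` refers to the workspace's default arrow version, as in the UDF code paths:

    use std::sync::Arc;

    use risingwave_common::array::arrow::{FromArrow, ToArrow, UdfArrowConvert};
    use risingwave_common::array::{ArrayError, DataChunk};
    use risingwave_common::types::DataType;

    fn round_trip(chunk: &DataChunk) -> Result<DataChunk, ArrayError> {
        // Under the current UDF protocol, a decimal column is described as LargeBinary.
        let field = UdfArrowConvert.to_arrow_field("v", &DataType::Decimal)?;
        assert_eq!(field.data_type(), &arrow_schema::DataType::LargeBinary);

        // Encode the chunk with that schema, then decode it back.
        let schema = Arc::new(arrow_schema::Schema::new(vec![field]));
        let batch = UdfArrowConvert.to_record_batch(schema, chunk)?;
        UdfArrowConvert.from_record_batch(&batch)
    }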
diff --git a/src/common/src/array/arrow/mod.rs b/src/common/src/array/arrow/mod.rs
index 4baea60f11b3e..67490b22315a1 100644
--- a/src/common/src/array/arrow/mod.rs
+++ b/src/common/src/array/arrow/mod.rs
@@ -12,12 +12,10 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-mod arrow_default;
 mod arrow_deltalake;
 mod arrow_iceberg;
+mod arrow_udf;
 
-pub use arrow_default::{
-    to_record_batch_with_schema, ToArrowArrayWithTypeConvert, ToArrowTypeConvert,
-};
-pub use arrow_deltalake::to_deltalake_record_batch_with_schema;
-pub use arrow_iceberg::{iceberg_to_arrow_type, to_iceberg_record_batch_with_schema};
+pub use arrow_deltalake::DeltaLakeConvert;
+pub use arrow_iceberg::IcebergArrowConvert;
+pub use arrow_udf::{FromArrow, ToArrow, UdfArrowConvert};
diff --git a/src/common/src/array/bytes_array.rs b/src/common/src/array/bytes_array.rs
index 1257730b63c96..2019c37271919 100644
--- a/src/common/src/array/bytes_array.rs
+++ b/src/common/src/array/bytes_array.rs
@@ -108,12 +108,6 @@ impl Array for BytesArray {
     }
 }
 
-impl BytesArray {
-    pub(super) fn data(&self) -> &[u8] {
-        &self.data
-    }
-}
-
 impl<'a> FromIterator<Option<&'a [u8]>> for BytesArray {
     fn from_iter<I: IntoIterator<Item = Option<&'a [u8]>>>(iter: I) -> Self {
         let iter = iter.into_iter();
diff --git a/src/common/src/array/list_array.rs b/src/common/src/array/list_array.rs
index dae9a4a94bc93..748ab6777fa49 100644
--- a/src/common/src/array/list_array.rs
+++ b/src/common/src/array/list_array.rs
@@ -239,6 +239,11 @@ impl ListArray {
         }
     }
 
+    /// Return the inner array of the list array.
+    pub fn values(&self) -> &ArrayImpl {
+        &self.value
+    }
+
     pub fn from_protobuf(array: &PbArray) -> ArrayResult<ArrayImpl> {
         ensure!(
             array.values.is_empty(),
diff --git a/src/common/src/array/mod.rs b/src/common/src/array/mod.rs
index 93f6255038e9e..268c518b70c01 100644
--- a/src/common/src/array/mod.rs
+++ b/src/common/src/array/mod.rs
@@ -14,11 +14,7 @@
 
 //! `Array` defines all in-memory representations of vectorized execution framework.
 
-mod arrow;
-pub use arrow::{
-    iceberg_to_arrow_type, to_deltalake_record_batch_with_schema,
-    to_iceberg_record_batch_with_schema, to_record_batch_with_schema,
-};
+pub mod arrow;
 mod bool_array;
 pub mod bytes_array;
 mod chrono_array;
diff --git a/src/common/src/array/utf8_array.rs b/src/common/src/array/utf8_array.rs
index eddcc50bb8ec2..72068d80733a2 100644
--- a/src/common/src/array/utf8_array.rs
+++ b/src/common/src/array/utf8_array.rs
@@ -112,10 +112,6 @@ impl Utf8Array {
         }
         builder.finish()
     }
-
-    pub(super) fn data(&self) -> &[u8] {
-        self.bytes.data()
-    }
 }
 
 /// `Utf8ArrayBuilder` use `&str` to build an `Utf8Array`.
@@ -297,12 +293,6 @@ mod tests {
 
         let array = Utf8Array::from_iter(&input);
         assert_eq!(array.len(), input.len());
-
-        assert_eq!(
-            array.bytes.data().len(),
-            input.iter().map(|s| s.unwrap_or("").len()).sum::<usize>()
-        );
-
         assert_eq!(input, array.iter().collect_vec());
     }
diff --git a/src/connector/src/sink/deltalake.rs b/src/connector/src/sink/deltalake.rs
index 388a75c264a71..e81bb141a4c37 100644
--- a/src/connector/src/sink/deltalake.rs
+++ b/src/connector/src/sink/deltalake.rs
@@ -25,7 +25,8 @@ use deltalake::table::builder::s3_storage_options::{
 };
 use deltalake::writer::{DeltaWriter, RecordBatchWriter};
 use deltalake::DeltaTable;
-use risingwave_common::array::{to_deltalake_record_batch_with_schema, StreamChunk};
+use risingwave_common::array::arrow::DeltaLakeConvert;
+use risingwave_common::array::StreamChunk;
 use risingwave_common::bail;
 use risingwave_common::buffer::Bitmap;
 use risingwave_common::catalog::Schema;
@@ -388,7 +389,8 @@ impl DeltaLakeSinkWriter {
     }
 
     async fn write(&mut self, chunk: StreamChunk) -> Result<()> {
-        let a = to_deltalake_record_batch_with_schema(self.dl_schema.clone(), &chunk)
+        let a = DeltaLakeConvert
+            .to_record_batch(self.dl_schema.clone(), &chunk)
             .context("convert record batch error")
             .map_err(SinkError::DeltaLake)?;
         self.writer.write(a).await?;
diff --git a/src/connector/src/sink/iceberg/mod.rs b/src/connector/src/sink/iceberg/mod.rs
index 22972d5629076..f2e3b45d5915c 100644
--- a/src/connector/src/sink/iceberg/mod.rs
+++ b/src/connector/src/sink/iceberg/mod.rs
@@ -41,9 +41,8 @@ use icelake::transaction::Transaction;
 use icelake::types::{data_file_from_json, data_file_to_json, Any, DataFile};
 use icelake::{Table, TableIdentifier};
 use itertools::Itertools;
-use risingwave_common::array::{
-    iceberg_to_arrow_type, to_iceberg_record_batch_with_schema, Op, StreamChunk,
-};
+use risingwave_common::array::arrow::{IcebergArrowConvert, ToArrow};
+use risingwave_common::array::{Op, StreamChunk};
 use risingwave_common::bail;
 use risingwave_common::buffer::Bitmap;
 use risingwave_common::catalog::Schema;
@@ -475,7 +474,7 @@
             .try_into()
             .map_err(|err: icelake::Error| SinkError::Iceberg(anyhow!(err)))?;
 
-        try_matches_arrow_schema(&sink_schema, &iceberg_schema, false)
+        try_matches_arrow_schema(&sink_schema, &iceberg_schema)
             .map_err(|err| SinkError::Iceberg(anyhow!(err)))?;
 
         Ok(table)
@@ -797,14 +796,15 @@ impl SinkWriter for IcebergWriter {
                 let filters =
                     chunk.visibility() & ops.iter().map(|op| *op == Op::Insert).collect::<Bitmap>();
                 chunk.set_visibility(filters);
-                let chunk =
-                    to_iceberg_record_batch_with_schema(self.schema.clone(), &chunk.compact())
-                        .map_err(|err| SinkError::Iceberg(anyhow!(err)))?;
+                let chunk = IcebergArrowConvert
+                    .to_record_batch(self.schema.clone(), &chunk.compact())
+                    .map_err(|err| SinkError::Iceberg(anyhow!(err)))?;
 
                 writer.write(chunk).await?;
             }
             IcebergWriterEnum::Upsert(writer) => {
-                let chunk = to_iceberg_record_batch_with_schema(self.schema.clone(), &chunk)
+                let chunk = IcebergArrowConvert
+                    .to_record_batch(self.schema.clone(), &chunk)
                     .map_err(|err| SinkError::Iceberg(anyhow!(err)))?;
 
                 writer
@@ -1002,11 +1002,9 @@ impl SinkCommitCoordinator for IcebergSinkCommitter {
 }
 
 /// Try to match our schema with iceberg schema.
-/// `for_source` = true means the schema is used for source, otherwise it's used for sink.
 pub fn try_matches_arrow_schema(
     rw_schema: &Schema,
     arrow_schema: &ArrowSchema,
-    for_source: bool,
 ) -> anyhow::Result<()> {
     if rw_schema.fields.len() != arrow_schema.fields().len() {
         bail!(
@@ -1029,17 +1027,11 @@ pub fn try_matches_arrow_schema(
             .ok_or_else(|| anyhow!("Field {} not found in our schema", arrow_field.name()))?;
 
         // Iceberg source should be able to read iceberg decimal type.
-        // Since the arrow type default conversion is used by udf, in udf, decimal is converted to
-        // large binary type which is not compatible with iceberg decimal type,
-        // so we need to convert it to decimal type manually.
-        let converted_arrow_data_type = if for_source
-            && matches!(our_field_type, risingwave_common::types::DataType::Decimal)
-        {
-            // RisingWave decimal type cannot specify precision and scale, so we use the default value.
-            ArrowDataType::Decimal128(38, 0)
-        } else {
-            iceberg_to_arrow_type(our_field_type).map_err(|e| anyhow!(e))?
-        };
+        let converted_arrow_data_type = IcebergArrowConvert
+            .to_arrow_field("", our_field_type)
+            .map_err(|e| anyhow!(e))?
+            .data_type()
+            .clone();
 
         let compatible = match (&converted_arrow_data_type, arrow_field.data_type()) {
             (ArrowDataType::Decimal128(_, _), ArrowDataType::Decimal128(_, _)) => true,
@@ -1080,7 +1072,7 @@ mod test {
             ArrowField::new("c", ArrowDataType::Int32, false),
         ]);
 
-        try_matches_arrow_schema(&risingwave_schema, &arrow_schema, false).unwrap();
+        try_matches_arrow_schema(&risingwave_schema, &arrow_schema).unwrap();
 
         let risingwave_schema = Schema::new(vec![
             Field::with_name(DataType::Int32, "d"),
@@ -1094,7 +1086,7 @@ mod test {
             ArrowField::new("d", ArrowDataType::Int32, false),
             ArrowField::new("c", ArrowDataType::Int32, false),
         ]);
-        try_matches_arrow_schema(&risingwave_schema, &arrow_schema, false).unwrap();
+        try_matches_arrow_schema(&risingwave_schema, &arrow_schema).unwrap();
     }
 
     #[test]
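For reference, a hedged sketch of what the simplified check now computes per field (`expected_arrow_type` is a hypothetical helper; `to_arrow_field` is the API used above). Because source and sink now share one mapping, a RisingWave `Decimal` consistently becomes an Arrow `Decimal128`, which the match arm above already treats as compatible with any iceberg `Decimal128`:

    use risingwave_common::array::arrow::{IcebergArrowConvert, ToArrow};
    use risingwave_common::types::DataType;

    fn expected_arrow_type(rw_type: &DataType) -> anyhow::Result<arrow_schema::DataType> {
        // One shared mapping for source and sink; no `for_source` special case.
        Ok(IcebergArrowConvert
            .to_arrow_field("", rw_type)?
            .data_type()
            .clone())
    }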
diff --git a/src/connector/src/source/pulsar/source/reader.rs b/src/connector/src/source/pulsar/source/reader.rs
index ce808b96200cf..967874b62c335 100644
--- a/src/connector/src/source/pulsar/source/reader.rs
+++ b/src/connector/src/source/pulsar/source/reader.rs
@@ -27,7 +27,8 @@ use itertools::Itertools;
 use pulsar::consumer::InitialPosition;
 use pulsar::message::proto::MessageIdData;
 use pulsar::{Consumer, ConsumerBuilder, ConsumerOptions, Pulsar, SubType, TokioExecutor};
-use risingwave_common::array::{DataChunk, StreamChunk};
+use risingwave_common::array::arrow::{FromArrow, IcebergArrowConvert};
+use risingwave_common::array::StreamChunk;
 use risingwave_common::catalog::ROWID_PREFIX;
 use risingwave_common::{bail, ensure};
 use thiserror_ext::AsReport;
@@ -508,7 +509,8 @@ impl PulsarIcebergReader {
             offsets.push(offset);
         }
 
-        let data_chunk = DataChunk::try_from(&record_batch.project(&field_indices)?)
+        let data_chunk = IcebergArrowConvert
+            .from_record_batch(&record_batch.project(&field_indices)?)
             .context("failed to convert arrow record batch to data chunk")?;
 
         let stream_chunk = StreamChunk::from(data_chunk);
diff --git a/src/expr/core/src/expr/expr_udf.rs b/src/expr/core/src/expr/expr_udf.rs
index ed7d597cce52a..b9103b62649e5 100644
--- a/src/expr/core/src/expr/expr_udf.rs
+++ b/src/expr/core/src/expr/expr_udf.rs
@@ -13,13 +13,12 @@
 // limitations under the License.
 
 use std::collections::HashMap;
-use std::convert::TryFrom;
 use std::sync::atomic::{AtomicU8, Ordering};
 use std::sync::{Arc, LazyLock, Weak};
 use std::time::Duration;
 
 use anyhow::{Context, Error};
-use arrow_schema::{Field, Fields, Schema};
+use arrow_schema::{Fields, Schema, SchemaRef};
 use arrow_udf_js::{CallMode as JsCallMode, Runtime as JsRuntime};
 #[cfg(feature = "embedded-deno-udf")]
 use arrow_udf_js_deno::{CallMode as DenoCallMode, Runtime as DenoRuntime};
@@ -29,14 +28,14 @@ use arrow_udf_wasm::Runtime as WasmRuntime;
 use await_tree::InstrumentAwait;
 use cfg_or_panic::cfg_or_panic;
 use moka::sync::Cache;
-use risingwave_common::array::{ArrayError, ArrayRef, DataChunk};
+use risingwave_common::array::arrow::{FromArrow, ToArrow, UdfArrowConvert};
+use risingwave_common::array::{ArrayRef, DataChunk};
 use risingwave_common::row::OwnedRow;
 use risingwave_common::types::{DataType, Datum};
 use risingwave_expr::expr_context::FRAGMENT_ID;
 use risingwave_pb::expr::ExprNode;
 use risingwave_udf::metrics::GLOBAL_METRICS;
 use risingwave_udf::ArrowFlightUdfClient;
-use thiserror_ext::AsReport;
 
 use super::{BoxedExpression, Build};
 use crate::expr::Expression;
@@ -47,8 +46,7 @@ pub struct UserDefinedFunction {
     children: Vec<BoxedExpression>,
     arg_types: Vec<DataType>,
     return_type: DataType,
-    #[expect(dead_code)]
-    arg_schema: Arc<Schema>,
+    arg_schema: SchemaRef,
     imp: UdfImpl,
     identifier: String,
     span: await_tree::Span,
@@ -117,7 +115,7 @@ impl Expression for UserDefinedFunction {
 impl UserDefinedFunction {
     async fn eval_inner(&self, input: &DataChunk) -> Result<ArrayRef> {
         // this will drop invisible rows
-        let arrow_input = arrow_array::RecordBatch::try_from(input)?;
+        let arrow_input = UdfArrowConvert.to_record_batch(self.arg_schema.clone(), input)?;
 
         // metrics
         let metrics = &*GLOBAL_METRICS;
@@ -225,7 +223,7 @@ impl UserDefinedFunction {
             );
         }
 
-        let output = DataChunk::try_from(&arrow_output)?;
+        let output = UdfArrowConvert.from_record_batch(&arrow_output)?;
         let output = output.uncompact(input.visibility().clone());
 
         let Some(array) = output.columns().first() else {
@@ -251,6 +249,11 @@ impl Build for UserDefinedFunction {
         let return_type = DataType::from(prost.get_return_type().unwrap());
         let udf = prost.get_rex_node().unwrap().as_udf().unwrap();
 
+        let arrow_return_type = UdfArrowConvert
+            .to_arrow_field("", &return_type)?
+            .data_type()
+            .clone();
+
         #[cfg(not(feature = "embedded-deno-udf"))]
         let runtime = "quickjs";
 
@@ -280,7 +283,7 @@
                 );
                 rt.add_function(
                     identifier,
-                    arrow_schema::DataType::try_from(&return_type)?,
+                    arrow_return_type,
                     JsCallMode::CalledOnNullInput,
                     &body,
                 )?;
@@ -321,7 +324,7 @@
 
                 futures::executor::block_on(rt.add_function(
                     identifier,
-                    arrow_schema::DataType::try_from(&return_type)?,
+                    arrow_return_type,
                     DenoCallMode::CalledOnNullInput,
                     &body,
                 ))?;
@@ -334,7 +337,7 @@
                 let body = udf.get_body()?;
                 rt.add_function(
                     identifier,
-                    arrow_schema::DataType::try_from(&return_type)?,
+                    arrow_return_type,
                     PythonCallMode::CalledOnNullInput,
                     body,
                 )?;
@@ -352,15 +355,7 @@
         let arg_schema = Arc::new(Schema::new(
             udf.arg_types
                 .iter()
-                .map::<Result<_>, _>(|t| {
-                    Ok(Field::new(
-                        "",
-                        DataType::from(t).try_into().map_err(|e: ArrayError| {
-                            risingwave_udf::Error::unsupported(e.to_report_string())
-                        })?,
-                        true,
-                    ))
-                })
+                .map(|t| UdfArrowConvert.to_arrow_field("", &DataType::from(t)))
                 .try_collect::<Fields>()?,
         ));
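A minimal sketch of the conversion path that `eval_inner` now follows (hypothetical free function, not in the PR; `call_udf` stands in for the Arrow Flight round trip):

    use arrow_array::RecordBatch;
    use arrow_schema::SchemaRef;
    use risingwave_common::array::arrow::{FromArrow, ToArrow, UdfArrowConvert};
    use risingwave_common::array::DataChunk;

    fn eval_via_arrow(
        arg_schema: SchemaRef,
        input: &DataChunk,
        call_udf: impl FnOnce(RecordBatch) -> anyhow::Result<RecordBatch>,
    ) -> anyhow::Result<DataChunk> {
        // Encoding with the precomputed `arg_schema` drops invisible rows...
        let arrow_input = UdfArrowConvert.to_record_batch(arg_schema, input)?;
        let arrow_output = call_udf(arrow_input)?;
        let output = UdfArrowConvert.from_record_batch(&arrow_output)?;
        // ...so the decoded chunk must be realigned with the input's visibility.
        Ok(output.uncompact(input.visibility().clone()))
    }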
diff --git a/src/expr/core/src/table_function/mod.rs b/src/expr/core/src/table_function/mod.rs
index 9d2d747cc4ed5..f1593cc4015c8 100644
--- a/src/expr/core/src/table_function/mod.rs
+++ b/src/expr/core/src/table_function/mod.rs
@@ -16,7 +16,6 @@ use either::Either;
 use futures_async_stream::try_stream;
 use futures_util::stream::BoxStream;
 use futures_util::StreamExt;
-use itertools::Itertools;
 use risingwave_common::array::{Array, ArrayBuilder, ArrayImpl, ArrayRef, DataChunk};
 use risingwave_common::types::{DataType, DatumRef};
 use risingwave_pb::expr::project_set_select_item::SelectItem;
@@ -129,6 +128,7 @@ pub fn build(
     chunk_size: usize,
     children: Vec<BoxedExpression>,
 ) -> Result<BoxedTableFunction> {
+    use itertools::Itertools;
     let args = children.iter().map(|t| t.return_type()).collect_vec();
     let desc = crate::sig::FUNCTION_REGISTRY
         .get(func, &args, &return_type)
diff --git a/src/expr/core/src/table_function/user_defined.rs b/src/expr/core/src/table_function/user_defined.rs
index b65ee5e77758b..4362ff27b57b9 100644
--- a/src/expr/core/src/table_function/user_defined.rs
+++ b/src/expr/core/src/table_function/user_defined.rs
@@ -16,7 +16,7 @@ use std::sync::Arc;
 
 use anyhow::Context;
 use arrow_array::RecordBatch;
-use arrow_schema::{Field, Fields, Schema, SchemaRef};
+use arrow_schema::{Fields, Schema, SchemaRef};
 use arrow_udf_js::{CallMode as JsCallMode, Runtime as JsRuntime};
 #[cfg(feature = "embedded-deno-udf")]
 use arrow_udf_js_deno::{CallMode as DenoCallMode, Runtime as DenoRuntime};
@@ -24,9 +24,9 @@ use arrow_udf_js_deno::{CallMode as DenoCallMode, Runtime as DenoRuntime};
 use arrow_udf_python::{CallMode as PythonCallMode, Runtime as PythonRuntime};
 use cfg_or_panic::cfg_or_panic;
 use futures_util::stream;
-use risingwave_common::array::{ArrayError, DataChunk, I32Array};
+use risingwave_common::array::arrow::{FromArrow, ToArrow, UdfArrowConvert};
+use risingwave_common::array::{DataChunk, I32Array};
 use risingwave_common::bail;
-use thiserror_ext::AsReport;
 
 use super::*;
 use crate::expr::expr_udf::UdfImpl;
@@ -34,7 +34,6 @@ use crate::expr::expr_udf::UdfImpl;
 #[derive(Debug)]
 pub struct UserDefinedTableFunction {
     children: Vec<BoxedExpression>,
-    #[allow(dead_code)]
     arg_schema: SchemaRef,
     return_type: DataType,
     client: UdfImpl,
@@ -109,9 +108,10 @@ impl UserDefinedTableFunction {
         let direct_input = DataChunk::new(columns, input.visibility().clone());
 
         // compact the input chunk and record the row mapping
-        let visible_rows = direct_input.visibility().iter_ones().collect_vec();
-        let compacted_input = direct_input.compact_cow();
-        let arrow_input = RecordBatch::try_from(compacted_input.as_ref())?;
+        let visible_rows = direct_input.visibility().iter_ones().collect::<Vec<_>>();
+        // this will drop invisible rows
+        let arrow_input =
+            UdfArrowConvert.to_record_batch(self.arg_schema.clone(), &direct_input)?;
 
         // call UDTF
         #[for_await]
@@ -119,7 +119,7 @@ impl UserDefinedTableFunction {
             .client
             .call_table_function(&self.identifier, arrow_input)
         {
-            let output = DataChunk::try_from(&res?)?;
+            let output = UdfArrowConvert.from_record_batch(&res?)?;
             self.check_output(&output)?;
 
             // we send the compacted input to UDF, so we need to map the row indices back to the
@@ -182,21 +182,18 @@ pub fn new_user_defined(prost: &PbTableFunction, chunk_size: usize) -> Result<BoxedTableFunction> {
     let arg_schema = Arc::new(Schema::new(
         udtf.arg_types
             .iter()
-            .map::<Result<_>, _>(|t| {
-                Ok(Field::new(
-                    "",
-                    DataType::from(t).try_into().map_err(|e: ArrayError| {
-                        risingwave_udf::Error::unsupported(e.to_report_string())
-                    })?,
-                    true,
-                ))
-            })
-            .try_collect::<_, Fields, _>()?,
+            .map(|t| UdfArrowConvert.to_arrow_field("", &DataType::from(t)))
+            .try_collect::<Fields>()?,
     ));
 
     let identifier = udtf.get_identifier()?;
     let return_type = DataType::from(prost.get_return_type()?);
 
+    let arrow_return_type = UdfArrowConvert
+        .to_arrow_field("", &return_type)?
+        .data_type()
+        .clone();
+
     #[cfg(not(feature = "embedded-deno-udf"))]
     let runtime = "quickjs";
 
@@ -224,7 +221,7 @@ pub fn new_user_defined(prost: &PbTableFunction, chunk_size: usize) -> Result<BoxedTableFunction> {
 
             rt.add_function(
                 identifier,
-                arrow_schema::DataType::try_from(&return_type)?,
+                arrow_return_type,
                 JsCallMode::CalledOnNullInput,
                 &body,
             )?;
diff --git a/src/frontend/src/handler/create_function.rs b/src/frontend/src/handler/create_function.rs
--- a/src/frontend/src/handler/create_function.rs
+++ b/src/frontend/src/handler/create_function.rs
@@ -91,15 +95,21 @@
-    fn to_field(data_type: arrow_schema::DataType) -> arrow_schema::Field {
-        arrow_schema::Field::new("", data_type, true)
+    fn to_field(data_type: &DataType) -> Result<arrow_schema::Field> {
+        Ok(UdfArrowConvert.to_arrow_field("", data_type)?)
     }
 
     let args = arrow_schema::Schema::new(
         arg_types
             .iter()
-            .map::<Result<_>, _>(|t| Ok(to_field(t.try_into()?)))
+            .map(to_field)
             .try_collect::<_, Fields, _>()?,
     );
     let returns = arrow_schema::Schema::new(match kind {
-        Kind::Scalar(_) => vec![to_field(return_type.clone().try_into()?)],
+        Kind::Scalar(_) => vec![to_field(&return_type)?],
         Kind::Table(_) => vec![
             arrow_schema::Field::new("row_index", arrow_schema::DataType::Int32, true),
-            to_field(return_type.clone().try_into()?),
+            to_field(&return_type)?,
         ],
         _ => unreachable!(),
     });
diff --git a/src/frontend/src/handler/create_sink.rs b/src/frontend/src/handler/create_sink.rs
index bed409de178f1..50f6914eee252 100644
--- a/src/frontend/src/handler/create_sink.rs
+++ b/src/frontend/src/handler/create_sink.rs
@@ -22,8 +22,9 @@ use either::Either;
 use itertools::Itertools;
 use maplit::{convert_args, hashmap};
 use pgwire::pg_response::{PgResponse, StatementType};
+use risingwave_common::array::arrow::{FromArrow, IcebergArrowConvert};
 use risingwave_common::catalog::{ConnectionId, DatabaseId, SchemaId, TableId, UserId};
-use risingwave_common::types::{DataType, Datum};
+use risingwave_common::types::Datum;
 use risingwave_common::util::value_encoding::DatumFromProtoExt;
 use risingwave_common::{bail, catalog};
 use risingwave_connector::sink::catalog::{SinkCatalog, SinkFormatDesc, SinkType};
@@ -351,14 +352,14 @@ async fn get_partition_compute_info_for_iceberg(
         })
         .collect::<Result<Vec<_>>>()?;
 
-    let DataType::Struct(partition_type) = arrow_type.into() else {
+    let ArrowDataType::Struct(partition_type) = arrow_type else {
         return Err(RwError::from(ErrorCode::SinkError(
             "Partition type of iceberg should be a struct type".into(),
         )));
    };
 
     Ok(Some(PartitionComputeInfo::Iceberg(IcebergPartitionInfo {
-        partition_type,
+        partition_type: IcebergArrowConvert.from_fields(&partition_type)?,
         partition_fields,
     })))
 }
diff --git a/src/frontend/src/handler/create_source.rs b/src/frontend/src/handler/create_source.rs
index 0830cdb5392de..6a19cb13fad31 100644
--- a/src/frontend/src/handler/create_source.rs
+++ b/src/frontend/src/handler/create_source.rs
@@ -21,6 +21,7 @@ use either::Either;
 use itertools::Itertools;
 use maplit::{convert_args, hashmap};
 use pgwire::pg_response::{PgResponse, StatementType};
+use risingwave_common::array::arrow::{FromArrow, IcebergArrowConvert};
 use risingwave_common::bail_not_implemented;
 use risingwave_common::catalog::{
     is_column_ids_dedup, ColumnCatalog, ColumnDesc, ColumnId, Schema, TableId,
@@ -1219,11 +1220,10 @@ pub async fn extract_iceberg_columns(
         .iter()
         .enumerate()
         .map(|(i, field)| {
-            let data_type = field.data_type().clone();
             let column_desc = ColumnDesc::named(
                 field.name(),
                 ColumnId::new((i as u32).try_into().unwrap()),
-                data_type.into(),
+                IcebergArrowConvert.from_field(field).unwrap(),
             );
             ColumnCatalog {
                 column_desc,
@@ -1288,11 +1288,7 @@ pub async fn check_iceberg_source(
         .collect::<Vec<_>>();
     let new_iceberg_schema = arrow_schema::Schema::new(new_iceberg_field);
 
-    risingwave_connector::sink::iceberg::try_matches_arrow_schema(
-        &schema,
-        &new_iceberg_schema,
-        true,
-    )?;
+    risingwave_connector::sink::iceberg::try_matches_arrow_schema(&schema, &new_iceberg_schema)?;
 
     Ok(())
 }
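Finally, a small sketch of the source-side mapping that `extract_iceberg_columns` relies on (illustrative only; it assumes just `FromArrow::from_field` as introduced in this PR):

    use risingwave_common::array::arrow::{FromArrow, IcebergArrowConvert};
    use risingwave_common::array::ArrayError;
    use risingwave_common::types::DataType;

    // Derive RisingWave column types from an iceberg table's Arrow schema.
    fn column_types(schema: &arrow_schema::Schema) -> Result<Vec<DataType>, ArrayError> {
        schema
            .fields()
            .iter()
            .map(|field| IcebergArrowConvert.from_field(field))
            .collect()
    }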