introduce NewUdfArrowConvert
Signed-off-by: Runji Wang <[email protected]>
wangrunji0408 committed May 7, 2024
1 parent 3c6e87a commit b1f3cf6
Showing 4 changed files with 146 additions and 25 deletions.
91 changes: 72 additions & 19 deletions src/common/src/array/arrow/arrow_impl.rs
@@ -201,20 +201,20 @@ pub trait ToArrow {
Ok(Arc::new(arrow_array::BinaryArray::from(array)))
}

// Decimal values are stored as ASCII text representation in a large binary array.
// Decimal values are stored as ASCII text representation in a string array.
#[inline]
fn decimal_to_arrow(
&self,
_data_type: &arrow_schema::DataType,
array: &DecimalArray,
) -> Result<arrow_array::ArrayRef, ArrayError> {
Ok(Arc::new(arrow_array::LargeBinaryArray::from(array)))
Ok(Arc::new(arrow_array::StringArray::from(array)))
}

// JSON values are stored as text representation in a large string array.
// JSON values are stored as text representation in a string array.
#[inline]
fn jsonb_to_arrow(&self, array: &JsonbArray) -> Result<arrow_array::ArrayRef, ArrayError> {
Ok(Arc::new(arrow_array::LargeStringArray::from(array)))
Ok(Arc::new(arrow_array::StringArray::from(array)))
}

#[inline]
@@ -366,7 +366,8 @@ pub trait ToArrow {

#[inline]
fn jsonb_type_to_arrow(&self, name: &str) -> arrow_schema::Field {
arrow_schema::Field::new(name, arrow_schema::DataType::LargeUtf8, true)
arrow_schema::Field::new(name, arrow_schema::DataType::Utf8, true)
.with_metadata([("ARROW:extension:name".into(), "arrowudf.json".into())].into())
}

#[inline]
@@ -376,7 +377,8 @@

#[inline]
fn decimal_type_to_arrow(&self, name: &str) -> arrow_schema::Field {
arrow_schema::Field::new(name, arrow_schema::DataType::LargeBinary, true)
arrow_schema::Field::new(name, arrow_schema::DataType::Utf8, true)
.with_metadata([("ARROW:extension:name".into(), "arrowudf.decimal".into())].into())
}

#[inline]
@@ -414,8 +416,8 @@ pub trait FromArrow {
/// Converts Arrow `RecordBatch` to RisingWave `DataChunk`.
fn from_record_batch(&self, batch: &arrow_array::RecordBatch) -> Result<DataChunk, ArrayError> {
let mut columns = Vec::with_capacity(batch.num_columns());
for array in batch.columns() {
let column = Arc::new(self.from_array(array)?);
for (array, field) in batch.columns().iter().zip_eq_fast(batch.schema().fields()) {
let column = Arc::new(self.from_array(field, array)?);
columns.push(column);
}
Ok(DataChunk::new(columns, batch.num_rows()))
@@ -472,30 +474,44 @@

/// Converts Arrow `LargeUtf8` type to RisingWave data type.
fn from_large_utf8(&self) -> Result<DataType, ArrayError> {
Ok(DataType::Jsonb)
Ok(DataType::Varchar)
}

/// Converts Arrow `LargeBinary` type to RisingWave data type.
fn from_large_binary(&self) -> Result<DataType, ArrayError> {
Ok(DataType::Decimal)
Ok(DataType::Bytea)
}

/// Converts Arrow extension type to RisingWave `DataType`.
fn from_extension_type(
&self,
type_name: &str,
_physical_type: &arrow_schema::DataType,
physical_type: &arrow_schema::DataType,
) -> Result<DataType, ArrayError> {
Err(ArrayError::from_arrow(format!(
"unsupported extension type: {type_name:?}"
)))
match (type_name, physical_type) {
("arrowudf.decimal", arrow_schema::DataType::Utf8) => Ok(DataType::Decimal),
("arrowudf.json", arrow_schema::DataType::Utf8) => Ok(DataType::Jsonb),
_ => Err(ArrayError::from_arrow(format!(
"unsupported extension type: {type_name:?}"
))),
}
}

/// Converts Arrow `Array` to RisingWave `ArrayImpl`.
fn from_array(&self, array: &arrow_array::ArrayRef) -> Result<ArrayImpl, ArrayError> {
fn from_array(
&self,
field: &arrow_schema::Field,
array: &arrow_array::ArrayRef,
) -> Result<ArrayImpl, ArrayError> {
use arrow_schema::DataType::*;
use arrow_schema::IntervalUnit::*;
use arrow_schema::TimeUnit::*;

// extension type
if let Some(type_name) = field.metadata().get("ARROW:extension:name") {
return self.from_extension_array(type_name, array);
}

match array.data_type() {
Boolean => self.from_bool_array(array.as_any().downcast_ref().unwrap()),
Int16 => self.from_int16_array(array.as_any().downcast_ref().unwrap()),
@@ -524,6 +540,37 @@ pub trait FromArrow {
}
}

/// Converts Arrow extension array to RisingWave `ArrayImpl`.
fn from_extension_array(
&self,
type_name: &str,
array: &arrow_array::ArrayRef,
) -> Result<ArrayImpl, ArrayError> {
match type_name {
"arrowudf.decimal" => {
let array: &arrow_array::StringArray =
array.as_any().downcast_ref().ok_or_else(|| {
ArrayError::from_arrow(
"expected string array for `arrowudf.decimal`".to_string(),
)
})?;
Ok(ArrayImpl::Decimal(array.try_into()?))
}
"arrowudf.json" => {
let array: &arrow_array::StringArray =
array.as_any().downcast_ref().ok_or_else(|| {
ArrayError::from_arrow(
"expected string array for `arrowudf.json`".to_string(),
)
})?;
Ok(ArrayImpl::Jsonb(array.try_into()?))
}
_ => Err(ArrayError::from_arrow(format!(
"unsupported extension type: {type_name:?}"
))),
}
}

fn from_bool_array(&self, array: &arrow_array::BooleanArray) -> Result<ArrayImpl, ArrayError> {
Ok(ArrayImpl::Bool(array.into()))
}
@@ -598,20 +645,23 @@ pub trait FromArrow {
&self,
array: &arrow_array::LargeStringArray,
) -> Result<ArrayImpl, ArrayError> {
Ok(ArrayImpl::Jsonb(array.try_into()?))
Ok(ArrayImpl::Utf8(array.into()))
}

fn from_large_binary_array(
&self,
array: &arrow_array::LargeBinaryArray,
) -> Result<ArrayImpl, ArrayError> {
Ok(ArrayImpl::Decimal(array.try_into()?))
Ok(ArrayImpl::Bytea(array.into()))
}

fn from_list_array(&self, array: &arrow_array::ListArray) -> Result<ArrayImpl, ArrayError> {
use arrow_array::Array;
let arrow_schema::DataType::List(field) = array.data_type() else {
panic!("nested field types cannot be determined.");
};
Ok(ArrayImpl::List(ListArray {
value: Box::new(self.from_array(array.values())?),
value: Box::new(self.from_array(field, array.values())?),
bitmap: match array.nulls() {
Some(nulls) => nulls.iter().collect(),
None => Bitmap::ones(array.len()),
@@ -630,7 +680,8 @@
array
.columns()
.iter()
.map(|a| self.from_array(a).map(Arc::new))
.zip_eq_fast(fields)
.map(|(array, field)| self.from_array(field, array).map(Arc::new))
.try_collect()?,
(0..array.len()).map(|i| array.is_valid(i)).collect(),
)))
@@ -703,7 +754,9 @@ converts!(I64Array, arrow_array::Int64Array);
converts!(F32Array, arrow_array::Float32Array, @map);
converts!(F64Array, arrow_array::Float64Array, @map);
converts!(BytesArray, arrow_array::BinaryArray);
converts!(BytesArray, arrow_array::LargeBinaryArray);
converts!(Utf8Array, arrow_array::StringArray);
converts!(Utf8Array, arrow_array::LargeStringArray);
converts!(DateArray, arrow_array::Date32Array, @map);
converts!(TimeArray, arrow_array::Time64MicrosecondArray, @map);
converts!(TimestampArray, arrow_array::TimestampMicrosecondArray, @map);
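For context, the new dispatch path can be illustrated with a small self-contained sketch that uses only the arrow-rs crates (arrow_array, arrow_schema); the field name, sample values, and the main function are illustrative and not part of this commit. It builds a Utf8 column tagged with the ARROW:extension:name metadata that the reworked from_array consults before falling back to the physical type:

use std::collections::HashMap;
use std::sync::Arc;

use arrow_array::{Array, ArrayRef, StringArray};
use arrow_schema::{DataType, Field};

fn main() {
    // A Utf8 field tagged as the `arrowudf.decimal` extension type, mirroring
    // what the new `decimal_type_to_arrow` produces.
    let field = Field::new("amount", DataType::Utf8, true).with_metadata(HashMap::from([(
        "ARROW:extension:name".to_string(),
        "arrowudf.decimal".to_string(),
    )]));

    // Decimal values travel as their ASCII text representation.
    let array: ArrayRef = Arc::new(StringArray::from(vec![Some("1.23"), None, Some("-4.56")]));

    // Extension metadata takes precedence over the physical type, which is the
    // check `FromArrow::from_array` performs before matching on `data_type()`.
    match field.metadata().get("ARROW:extension:name").map(|s| s.as_str()) {
        Some("arrowudf.decimal") => println!("decode {} values as Decimal", array.len()),
        Some("arrowudf.json") => println!("decode {} values as Jsonb", array.len()),
        _ => println!("fall back to physical type {:?}", array.data_type()),
    }
}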
src/common/src/array/arrow/arrow_udf.rs
@@ -18,17 +18,78 @@
//!
//! The corresponding version of arrow is currently used by `udf` and `iceberg` sink.
use std::sync::Arc;

pub use arrow_impl::{FromArrow, ToArrow};
use {arrow_array, arrow_buffer, arrow_cast, arrow_schema};

use crate::array::{ArrayError, ArrayImpl, DataType, DecimalArray, JsonbArray};

#[expect(clippy::duplicate_mod)]
#[path = "./arrow_impl.rs"]
mod arrow_impl;

/// Arrow conversion for the current version of UDF. This is in use but will be deprecated soon.
///
/// In the current version of UDF protocol, decimal and jsonb types are mapped to Arrow `LargeBinary` and `LargeUtf8` types.
pub struct UdfArrowConvert;

impl ToArrow for UdfArrowConvert {}
impl FromArrow for UdfArrowConvert {}
impl ToArrow for UdfArrowConvert {
// Decimal values are stored as ASCII text representation in a large binary array.
fn decimal_to_arrow(
&self,
_data_type: &arrow_schema::DataType,
array: &DecimalArray,
) -> Result<arrow_array::ArrayRef, ArrayError> {
Ok(Arc::new(arrow_array::LargeBinaryArray::from(array)))
}

// JSON values are stored as text representation in a large string array.
fn jsonb_to_arrow(&self, array: &JsonbArray) -> Result<arrow_array::ArrayRef, ArrayError> {
Ok(Arc::new(arrow_array::LargeStringArray::from(array)))
}

fn jsonb_type_to_arrow(&self, name: &str) -> arrow_schema::Field {
arrow_schema::Field::new(name, arrow_schema::DataType::LargeUtf8, true)
}

fn decimal_type_to_arrow(&self, name: &str) -> arrow_schema::Field {
arrow_schema::Field::new(name, arrow_schema::DataType::LargeBinary, true)
}
}

impl FromArrow for UdfArrowConvert {
fn from_large_utf8(&self) -> Result<DataType, ArrayError> {
Ok(DataType::Jsonb)
}

fn from_large_binary(&self) -> Result<DataType, ArrayError> {
Ok(DataType::Decimal)
}

fn from_large_utf8_array(
&self,
array: &arrow_array::LargeStringArray,
) -> Result<ArrayImpl, ArrayError> {
Ok(ArrayImpl::Jsonb(array.try_into()?))
}

fn from_large_binary_array(
&self,
array: &arrow_array::LargeBinaryArray,
) -> Result<ArrayImpl, ArrayError> {
Ok(ArrayImpl::Decimal(array.try_into()?))
}
}

/// Arrow conversion for the next version of UDF. This is unused for now.
///
/// In the next version of UDF protocol, decimal and jsonb types will be mapped to Arrow extension types.
/// See <https://github.com/risingwavelabs/arrow-udf/tree/main#extension-types>.
pub struct NewUdfArrowConvert;

impl ToArrow for NewUdfArrowConvert {}
impl FromArrow for NewUdfArrowConvert {}

#[cfg(test)]
mod tests {
@@ -108,7 +169,9 @@ mod tests {
let array = ListArray::from_iter([None, Some(vec![0, -127, 127, 50]), Some(vec![0; 0])]);
let data_type = arrow_schema::DataType::new_list(arrow_schema::DataType::Int32, true);
let arrow = UdfArrowConvert.list_to_arrow(&data_type, &array).unwrap();
let rw_array = UdfArrowConvert.from_array(&arrow).unwrap();
let rw_array = UdfArrowConvert
.from_list_array(arrow.as_any().downcast_ref().unwrap())
.unwrap();
assert_eq!(rw_array.as_list(), &array);
}
}
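A possible extra unit test, sketched here as an illustration rather than as part of this commit, could sit next to the tests module above; it assumes UdfArrowConvert, NewUdfArrowConvert, ToArrow, arrow_schema, and the RisingWave DataType are in scope there, and exercises the differing jsonb mappings of the two converters:

#[test]
fn jsonb_field_mapping_sketch() {
    // Current protocol: jsonb is shipped as LargeUtf8 with no extra metadata.
    let legacy = UdfArrowConvert
        .to_arrow_field("payload", &DataType::Jsonb)
        .unwrap();
    assert_eq!(legacy.data_type(), &arrow_schema::DataType::LargeUtf8);

    // Next protocol: jsonb is shipped as plain Utf8 tagged with the
    // `arrowudf.json` extension name, so the logical type survives the trip.
    let next = NewUdfArrowConvert
        .to_arrow_field("payload", &DataType::Jsonb)
        .unwrap();
    assert_eq!(next.data_type(), &arrow_schema::DataType::Utf8);
    assert_eq!(
        next.metadata().get("ARROW:extension:name").map(|s| s.as_str()),
        Some("arrowudf.json")
    );
}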
4 changes: 2 additions & 2 deletions src/common/src/array/arrow/mod.rs
@@ -12,10 +12,10 @@
// See the License for the specific language governing permissions and
// limitations under the License.

mod arrow_default;
mod arrow_deltalake;
mod arrow_iceberg;
mod arrow_udf;

pub use arrow_default::{FromArrow, ToArrow, UdfArrowConvert};
pub use arrow_deltalake::DeltaLakeConvert;
pub use arrow_iceberg::IcebergArrowConvert;
pub use arrow_udf::{FromArrow, ToArrow, UdfArrowConvert};
7 changes: 6 additions & 1 deletion src/expr/impl/src/scalar/external/iceberg.rs
@@ -35,6 +35,7 @@ pub struct IcebergTransform {
child: BoxedExpression,
transform: BoxedTransformFunction,
input_arrow_type: arrow_schema::DataType,
output_arrow_field: arrow_schema::Field,
return_type: DataType,
}

@@ -61,7 +62,9 @@ impl risingwave_expr::expr::Expression for IcebergTransform {
// Transform
let res_array = self.transform.transform(arrow_array).unwrap();
// Convert back to array ref and return it
Ok(Arc::new(IcebergArrowConvert.from_array(&res_array)?))
Ok(Arc::new(
IcebergArrowConvert.from_array(&self.output_arrow_field, &res_array)?,
))
}

async fn eval_row(&self, _row: &OwnedRow) -> Result<Datum> {
@@ -96,6 +99,7 @@ fn build(return_type: DataType, mut children: Vec<BoxedExpression>) -> Result<Bo
.to_arrow_field("", &children[1].return_type())?
.data_type()
.clone();
let output_arrow_field = IcebergArrowConvert.to_arrow_field("", &return_type)?;
let input_type = IcelakeDataType::try_from(input_arrow_type.clone()).map_err(|err| {
ExprError::InvalidParam {
name: "input type in iceberg_transform",
@@ -146,6 +150,7 @@ fn build(return_type: DataType, mut children: Vec<BoxedExpression>) -> Result<Bo
transform: create_transform_function(&transform_type)
.map_err(|err| ExprError::Internal(err.into()))?,
input_arrow_type,
output_arrow_field,
return_type,
}))
}
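Finally, a standalone sketch (again arrow-rs only, with illustrative names such as main, id, and doc) of why a column's physical type is no longer sufficient on its own: from_record_batch and the transform above now pair each array with its schema field so the ARROW:extension:name metadata can be consulted:

use std::collections::HashMap;
use std::sync::Arc;

use arrow_array::{Array, ArrayRef, Int32Array, RecordBatch, StringArray};
use arrow_schema::{DataType, Field, Schema};

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // One plain Int32 column plus one Utf8 column tagged as `arrowudf.json`.
    let schema = Arc::new(Schema::new(vec![
        Field::new("id", DataType::Int32, false),
        Field::new("doc", DataType::Utf8, true).with_metadata(HashMap::from([(
            "ARROW:extension:name".to_string(),
            "arrowudf.json".to_string(),
        )])),
    ]));
    let columns: Vec<ArrayRef> = vec![
        Arc::new(Int32Array::from(vec![1, 2])),
        Arc::new(StringArray::from(vec![Some(r#"{"a":1}"#), None])),
    ];
    let batch = RecordBatch::try_new(schema, columns)?;

    // `Utf8` alone cannot tell jsonb apart from varchar; the extension name on
    // the field is what disambiguates them, hence the column/field zip.
    for (field, column) in batch.schema().fields().iter().zip(batch.columns()) {
        let ext = field.metadata().get("ARROW:extension:name");
        println!(
            "{}: physical {:?}, extension {:?}",
            field.name(),
            column.data_type(),
            ext
        );
    }
    Ok(())
}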
