From ee1a12b0a198db533ae201e20cb98e8cd45a6512 Mon Sep 17 00:00:00 2001 From: Runji Wang Date: Tue, 30 Apr 2024 16:54:50 +0800 Subject: [PATCH 01/27] refactor array arrow conversion Signed-off-by: Runji Wang --- src/common/src/array/arrow/arrow_default.rs | 94 +- src/common/src/array/arrow/arrow_deltalake.rs | 209 +-- src/common/src/array/arrow/arrow_iceberg.rs | 81 +- src/common/src/array/arrow/arrow_impl.rs | 1157 +++++++---------- src/common/src/array/arrow/mod.rs | 8 +- src/common/src/array/bytes_array.rs | 6 - src/common/src/array/list_array.rs | 5 + src/common/src/array/mod.rs | 6 +- src/common/src/array/utf8_array.rs | 10 - 9 files changed, 566 insertions(+), 1010 deletions(-) diff --git a/src/common/src/array/arrow/arrow_default.rs b/src/common/src/array/arrow/arrow_default.rs index 5d04527b354ba..b2867d4fdf583 100644 --- a/src/common/src/array/arrow/arrow_default.rs +++ b/src/common/src/array/arrow/arrow_default.rs @@ -18,13 +18,97 @@ //! //! The corresponding version of arrow is currently used by `udf` and `iceberg` sink. -#![allow(unused_imports)] -pub use arrow_impl::{ - to_record_batch_with_schema, ToArrowArrayConvert, ToArrowArrayWithTypeConvert, - ToArrowTypeConvert, -}; +pub use arrow_impl::{FromArrow, ToArrow}; use {arrow_array, arrow_buffer, arrow_cast, arrow_schema}; #[expect(clippy::duplicate_mod)] #[path = "./arrow_impl.rs"] mod arrow_impl; + +pub struct UdfArrowConvert; + +impl ToArrow for UdfArrowConvert {} +impl FromArrow for UdfArrowConvert {} + +#[cfg(test)] +mod tests { + use std::sync::Arc; + + use super::*; + use crate::array::*; + use crate::buffer::Bitmap; + + #[test] + fn struct_array() { + // Empty array - risingwave to arrow conversion. + let test_arr = StructArray::new(StructType::empty(), vec![], Bitmap::ones(0)); + assert_eq!( + UdfArrowConvert + .struct_to_arrow( + &arrow_schema::DataType::Struct(arrow_schema::Fields::empty()), + &test_arr + ) + .unwrap() + .len(), + 0 + ); + + // Empty array - arrow to risingwave conversion. + let test_arr_2 = arrow_array::StructArray::from(vec![]); + assert_eq!( + UdfArrowConvert + .from_struct_array(&test_arr_2) + .unwrap() + .len(), + 0 + ); + + // Struct array with primitive types. arrow to risingwave conversion. 
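+        // (The Arrow struct below has fields `a: Boolean` and `b: Int32`; the conversion
+        // should keep both columns and the per-row validity bitmap intact.)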
+ let test_arrow_struct_array = arrow_array::StructArray::try_from(vec![ + ( + "a", + Arc::new(arrow_array::BooleanArray::from(vec![ + Some(false), + Some(false), + Some(true), + None, + ])) as arrow_array::ArrayRef, + ), + ( + "b", + Arc::new(arrow_array::Int32Array::from(vec![ + Some(42), + Some(28), + Some(19), + None, + ])) as arrow_array::ArrayRef, + ), + ]) + .unwrap(); + let actual_risingwave_struct_array = UdfArrowConvert + .from_struct_array(&test_arrow_struct_array) + .unwrap() + .into_struct(); + let expected_risingwave_struct_array = StructArray::new( + StructType::new(vec![("a", DataType::Boolean), ("b", DataType::Int32)]), + vec![ + BoolArray::from_iter([Some(false), Some(false), Some(true), None]).into_ref(), + I32Array::from_iter([Some(42), Some(28), Some(19), None]).into_ref(), + ], + [true, true, true, true].into_iter().collect(), + ); + assert_eq!( + expected_risingwave_struct_array, + actual_risingwave_struct_array + ); + } + + #[test] + fn list() { + let array = ListArray::from_iter([None, Some(vec![0, -127, 127, 50]), Some(vec![0; 0])]); + let data_type = arrow_schema::DataType::new_list(arrow_schema::DataType::Int32, true); + let arrow = UdfArrowConvert.list_to_arrow(&data_type, &array).unwrap(); + let rw_array = UdfArrowConvert.from_array(&arrow).unwrap(); + assert_eq!(rw_array.as_list(), &array); + } +} diff --git a/src/common/src/array/arrow/arrow_deltalake.rs b/src/common/src/array/arrow/arrow_deltalake.rs index c55cae305b07f..50eb5b86dca1b 100644 --- a/src/common/src/array/arrow/arrow_deltalake.rs +++ b/src/common/src/array/arrow/arrow_deltalake.rs @@ -21,25 +21,29 @@ use std::ops::{Div, Mul}; use std::sync::Arc; use arrow_array::ArrayRef; -use arrow_schema::DataType; -use itertools::Itertools; use num_traits::abs; use { arrow_array_deltalake as arrow_array, arrow_buffer_deltalake as arrow_buffer, arrow_cast_deltalake as arrow_cast, arrow_schema_deltalake as arrow_schema, }; -use self::arrow_impl::ToArrowArrayWithTypeConvert; -use crate::array::arrow::arrow_deltalake::arrow_impl::FromIntoArrow; -use crate::array::{Array, ArrayError, ArrayImpl, DataChunk, Decimal, DecimalArray, ListArray}; -use crate::util::iter_util::ZipEqFast; +use self::arrow_impl::ToArrow; +use crate::array::{Array, ArrayError, DataChunk, Decimal, DecimalArray}; #[expect(clippy::duplicate_mod)] #[path = "./arrow_impl.rs"] mod arrow_impl; -struct DeltaLakeConvert; +pub struct DeltaLakeConvert; impl DeltaLakeConvert { + pub fn to_record_batch( + &self, + schema: arrow_schema::SchemaRef, + chunk: &DataChunk, + ) -> Result { + ToArrow::to_record_batch(self, schema, chunk) + } + fn decimal_to_i128(decimal: Decimal, precision: u8, max_scale: i8) -> Option { match decimal { crate::array::Decimal::Normalized(e) => { @@ -67,7 +71,7 @@ impl DeltaLakeConvert { } } -impl arrow_impl::ToArrowArrayWithTypeConvert for DeltaLakeConvert { +impl ToArrow for DeltaLakeConvert { fn decimal_to_arrow( &self, data_type: &arrow_schema::DataType, @@ -89,189 +93,6 @@ impl arrow_impl::ToArrowArrayWithTypeConvert for DeltaLakeConvert { .map_err(ArrayError::from_arrow)?; Ok(Arc::new(array) as ArrayRef) } - - #[inline] - fn list_to_arrow( - &self, - data_type: &arrow_schema::DataType, - array: &ListArray, - ) -> Result { - use arrow_array::builder::*; - fn build( - array: &ListArray, - a: &A, - builder: B, - mut append: F, - ) -> arrow_array::ListArray - where - A: Array, - B: arrow_array::builder::ArrayBuilder, - F: FnMut(&mut B, Option>), - { - let mut builder = ListBuilder::with_capacity(builder, a.len()); - for i in 
0..array.len() { - for j in array.offsets[i]..array.offsets[i + 1] { - append(builder.values(), a.value_at(j as usize)); - } - builder.append(!array.is_null(i)); - } - builder.finish() - } - let inner_type = match data_type { - arrow_schema::DataType::List(inner) => inner.data_type(), - _ => return Err(ArrayError::to_arrow("Invalid list type")), - }; - let arr: arrow_array::ListArray = match &*array.value { - ArrayImpl::Int16(a) => build(array, a, Int16Builder::with_capacity(a.len()), |b, v| { - b.append_option(v) - }), - ArrayImpl::Int32(a) => build(array, a, Int32Builder::with_capacity(a.len()), |b, v| { - b.append_option(v) - }), - ArrayImpl::Int64(a) => build(array, a, Int64Builder::with_capacity(a.len()), |b, v| { - b.append_option(v) - }), - - ArrayImpl::Float32(a) => { - build(array, a, Float32Builder::with_capacity(a.len()), |b, v| { - b.append_option(v.map(|f| f.0)) - }) - } - ArrayImpl::Float64(a) => { - build(array, a, Float64Builder::with_capacity(a.len()), |b, v| { - b.append_option(v.map(|f| f.0)) - }) - } - ArrayImpl::Utf8(a) => build( - array, - a, - StringBuilder::with_capacity(a.len(), a.data().len()), - |b, v| b.append_option(v), - ), - ArrayImpl::Int256(a) => build( - array, - a, - Decimal256Builder::with_capacity(a.len()).with_data_type( - arrow_schema::DataType::Decimal256(arrow_schema::DECIMAL256_MAX_PRECISION, 0), - ), - |b, v| b.append_option(v.map(Into::into)), - ), - ArrayImpl::Bool(a) => { - build(array, a, BooleanBuilder::with_capacity(a.len()), |b, v| { - b.append_option(v) - }) - } - ArrayImpl::Decimal(a) => { - let (precision, max_scale) = match inner_type { - arrow_schema::DataType::Decimal128(precision, scale) => (*precision, *scale), - _ => return Err(ArrayError::to_arrow("Invalid decimal type")), - }; - build( - array, - a, - Decimal128Builder::with_capacity(a.len()) - .with_data_type(DataType::Decimal128(precision, max_scale)), - |b, v| { - let v = v.and_then(|v| { - DeltaLakeConvert::decimal_to_i128(v, precision, max_scale) - }); - b.append_option(v); - }, - ) - } - ArrayImpl::Interval(a) => build( - array, - a, - IntervalMonthDayNanoBuilder::with_capacity(a.len()), - |b, v| b.append_option(v.map(|d| d.into_arrow())), - ), - ArrayImpl::Date(a) => build(array, a, Date32Builder::with_capacity(a.len()), |b, v| { - b.append_option(v.map(|d| d.into_arrow())) - }), - ArrayImpl::Timestamp(a) => build( - array, - a, - TimestampMicrosecondBuilder::with_capacity(a.len()), - |b, v| b.append_option(v.map(|d| d.into_arrow())), - ), - ArrayImpl::Timestamptz(a) => build( - array, - a, - TimestampMicrosecondBuilder::with_capacity(a.len()), - |b, v| b.append_option(v.map(|d| d.into_arrow())), - ), - ArrayImpl::Time(a) => build( - array, - a, - Time64MicrosecondBuilder::with_capacity(a.len()), - |b, v| b.append_option(v.map(|d| d.into_arrow())), - ), - ArrayImpl::Jsonb(a) => build( - array, - a, - LargeStringBuilder::with_capacity(a.len(), a.len() * 16), - |b, v| b.append_option(v.map(|j| j.to_string())), - ), - ArrayImpl::Serial(_) => todo!("list of serial"), - ArrayImpl::Struct(a) => { - let values = Arc::new(arrow_array::StructArray::try_from(a)?); - arrow_array::ListArray::new( - Arc::new(arrow_schema::Field::new( - "item", - a.data_type().try_into()?, - true, - )), - arrow_buffer::OffsetBuffer::new(arrow_buffer::ScalarBuffer::from( - array - .offsets() - .iter() - .map(|o| *o as i32) - .collect::>(), - )), - values, - Some(array.null_bitmap().into()), - ) - } - ArrayImpl::List(_) => todo!("list of list"), - ArrayImpl::Bytea(a) => build( - array, - a, - 
BinaryBuilder::with_capacity(a.len(), a.data().len()), - |b, v| b.append_option(v), - ), - }; - Ok(Arc::new(arr)) - } -} - -/// Converts RisingWave array to Arrow array with the schema. -/// This function will try to convert the array if the type is not same with the schema. -pub fn to_deltalake_record_batch_with_schema( - schema: arrow_schema::SchemaRef, - chunk: &DataChunk, -) -> Result { - if !chunk.is_compacted() { - let c = chunk.clone(); - return to_deltalake_record_batch_with_schema(schema, &c.compact()); - } - let columns: Vec<_> = chunk - .columns() - .iter() - .zip_eq_fast(schema.fields().iter()) - .map(|(column, field)| { - let column: arrow_array::ArrayRef = - DeltaLakeConvert.to_arrow_with_type(field.data_type(), column)?; - if column.data_type() == field.data_type() { - Ok(column) - } else { - arrow_cast::cast(&column, field.data_type()).map_err(ArrayError::from_arrow) - } - }) - .try_collect::<_, _, ArrayError>()?; - - let opts = arrow_array::RecordBatchOptions::default().with_row_count(Some(chunk.capacity())); - arrow_array::RecordBatch::try_new_with_options(schema, columns, &opts) - .map_err(ArrayError::to_arrow) } #[cfg(test)] @@ -283,6 +104,7 @@ mod test { use arrow_schema::Field; use {arrow_array_deltalake as arrow_array, arrow_schema_deltalake as arrow_schema}; + use crate::array::arrow::arrow_deltalake::DeltaLakeConvert; use crate::array::{ArrayImpl, Decimal, DecimalArray, ListArray, ListValue}; use crate::buffer::Bitmap; @@ -309,8 +131,9 @@ mod test { false, )]); - let record_batch = - super::to_deltalake_record_batch_with_schema(Arc::new(schema), &chunk).unwrap(); + let record_batch = DeltaLakeConvert + .to_record_batch(Arc::new(schema), &chunk) + .unwrap(); let expect_array = Arc::new( arrow_array::Decimal128Array::from(vec![ None, diff --git a/src/common/src/array/arrow/arrow_iceberg.rs b/src/common/src/array/arrow/arrow_iceberg.rs index 2dd7900da5da1..72a49cb349370 100644 --- a/src/common/src/array/arrow/arrow_iceberg.rs +++ b/src/common/src/array/arrow/arrow_iceberg.rs @@ -15,25 +15,22 @@ use std::ops::{Div, Mul}; use std::sync::Arc; -use arrow_array::{ArrayRef, StructArray}; -use arrow_schema::DataType; -use itertools::Itertools; +use arrow_array::ArrayRef; use num_traits::abs; -use super::{ToArrowArrayWithTypeConvert, ToArrowTypeConvert}; -use crate::array::{Array, ArrayError, DataChunk, DecimalArray}; -use crate::util::iter_util::ZipEqFast; +use super::{FromArrow, ToArrow}; +use crate::array::{Array, ArrayError, DecimalArray}; -struct IcebergArrowConvert; +pub struct IcebergArrowConvert; -impl ToArrowTypeConvert for IcebergArrowConvert { +impl ToArrow for IcebergArrowConvert { #[inline] - fn decimal_type_to_arrow(&self) -> arrow_schema::DataType { - arrow_schema::DataType::Decimal128(arrow_schema::DECIMAL128_MAX_PRECISION, 0) + fn decimal_type_to_arrow(&self, name: &str) -> arrow_schema::Field { + let data_type = + arrow_schema::DataType::Decimal128(arrow_schema::DECIMAL128_MAX_PRECISION, 0); + arrow_schema::Field::new(name, data_type, true) } -} -impl ToArrowArrayWithTypeConvert for IcebergArrowConvert { fn decimal_to_arrow( &self, data_type: &arrow_schema::DataType, @@ -85,63 +82,7 @@ impl ToArrowArrayWithTypeConvert for IcebergArrowConvert { } } -/// Converts RisingWave array to Arrow array with the schema. -/// The behavior is specified for iceberg: -/// For different struct type, try to use fields in schema to cast. 
-pub fn to_iceberg_record_batch_with_schema( - schema: arrow_schema::SchemaRef, - chunk: &DataChunk, -) -> Result { - if !chunk.is_compacted() { - let c = chunk.clone(); - return to_iceberg_record_batch_with_schema(schema, &c.compact()); - } - let columns: Vec<_> = chunk - .columns() - .iter() - .zip_eq_fast(schema.fields().iter()) - .map(|(column, field)| { - let column: arrow_array::ArrayRef = - IcebergArrowConvert {}.to_arrow_with_type(field.data_type(), column)?; - if column.data_type() == field.data_type() { - Ok(column) - } else if let DataType::Struct(actual) = column.data_type() - && let DataType::Struct(expect) = field.data_type() - { - // Special case for iceberg - if actual.len() != expect.len() { - return Err(ArrayError::to_arrow(format!( - "Struct field count mismatch, expect {}, actual {}", - expect.len(), - actual.len() - ))); - } - let column = column - .as_any() - .downcast_ref::() - .unwrap() - .clone(); - let (_, struct_columns, nullable) = column.into_parts(); - Ok(Arc::new( - StructArray::try_new(expect.clone(), struct_columns, nullable) - .map_err(ArrayError::from_arrow)?, - ) as ArrayRef) - } else { - arrow_cast::cast(&column, field.data_type()).map_err(ArrayError::from_arrow) - } - }) - .try_collect::<_, _, ArrayError>()?; - - let opts = arrow_array::RecordBatchOptions::default().with_row_count(Some(chunk.capacity())); - arrow_array::RecordBatch::try_new_with_options(schema, columns, &opts) - .map_err(ArrayError::to_arrow) -} - -pub fn iceberg_to_arrow_type( - data_type: &crate::array::DataType, -) -> Result { - IcebergArrowConvert {}.to_arrow_type(data_type) -} +impl FromArrow for IcebergArrowConvert {} #[cfg(test)] mod test { @@ -150,7 +91,7 @@ mod test { use arrow_array::ArrayRef; use crate::array::arrow::arrow_iceberg::IcebergArrowConvert; - use crate::array::arrow::ToArrowArrayWithTypeConvert; + use crate::array::arrow::ToArrow; use crate::array::{Decimal, DecimalArray}; #[test] diff --git a/src/common/src/array/arrow/arrow_impl.rs b/src/common/src/array/arrow/arrow_impl.rs index e426caa306d55..b29f62eaf2b10 100644 --- a/src/common/src/array/arrow/arrow_impl.rs +++ b/src/common/src/array/arrow/arrow_impl.rs @@ -39,6 +39,7 @@ use std::fmt::Write; use std::sync::Arc; +use arrow_buffer::OffsetBuffer; use chrono::{NaiveDateTime, NaiveTime}; use itertools::Itertools; @@ -50,317 +51,76 @@ use crate::buffer::Bitmap; use crate::types::*; use crate::util::iter_util::ZipEqFast; -/// Converts RisingWave array to Arrow array with the schema. -/// This function will try to convert the array if the type is not same with the schema. -#[allow(dead_code)] -pub fn to_record_batch_with_schema( - schema: arrow_schema::SchemaRef, - chunk: &DataChunk, -) -> Result { - if !chunk.is_compacted() { - let c = chunk.clone(); - return to_record_batch_with_schema(schema, &c.compact()); - } - let columns: Vec<_> = chunk - .columns() - .iter() - .zip_eq_fast(schema.fields().iter()) - .map(|(column, field)| { - let column: arrow_array::ArrayRef = column.as_ref().try_into()?; - if column.data_type() == field.data_type() { - Ok(column) - } else { - arrow_cast::cast(&column, field.data_type()).map_err(ArrayError::from_arrow) - } - }) - .try_collect::<_, _, ArrayError>()?; - - let opts = arrow_array::RecordBatchOptions::default().with_row_count(Some(chunk.capacity())); - arrow_array::RecordBatch::try_new_with_options(schema, columns, &opts) - .map_err(ArrayError::to_arrow) -} - -// Implement bi-directional `From` between `DataChunk` and `arrow_array::RecordBatch`. 
-impl TryFrom<&DataChunk> for arrow_array::RecordBatch {
-    type Error = ArrayError;
-
-    fn try_from(chunk: &DataChunk) -> Result {
+/// Defines how to convert RisingWave arrays to Arrow arrays.
+pub trait ToArrow {
+    /// Converts RisingWave `DataChunk` to Arrow `RecordBatch` with the specified schema.
+    ///
+    /// This function will try to convert the array if its type does not match the schema.
+    fn to_record_batch(
+        &self,
+        schema: arrow_schema::SchemaRef,
+        chunk: &DataChunk,
+    ) -> Result {
+        // compact the chunk if it's not compacted
         if !chunk.is_compacted() {
             let c = chunk.clone();
-            return Self::try_from(&c.compact());
+            return self.to_record_batch(schema, &c.compact());
         }
+
+        // convert each column to arrow array
         let columns: Vec<_> = chunk
             .columns()
             .iter()
-            .map(|column| column.as_ref().try_into())
-            .try_collect::<_, _, Self::Error>()?;
-
-        let fields: Vec<_> = columns
-            .iter()
-            .map(|array: &Arc| {
-                let nullable = array.null_count() > 0;
-                let data_type = array.data_type().clone();
-                arrow_schema::Field::new("", data_type, nullable)
-            })
-            .collect();
+            .zip_eq_fast(schema.fields().iter())
+            .map(|(column, field)| self.to_array(field.data_type(), column))
+            .try_collect()?;

-        let schema = Arc::new(arrow_schema::Schema::new(fields));
+        // create record batch
         let opts =
             arrow_array::RecordBatchOptions::default().with_row_count(Some(chunk.capacity()));
         arrow_array::RecordBatch::try_new_with_options(schema, columns, &opts)
             .map_err(ArrayError::to_arrow)
     }
-}
-
-impl TryFrom<&arrow_array::RecordBatch> for DataChunk {
-    type Error = ArrayError;
-
-    fn try_from(batch: &arrow_array::RecordBatch) -> Result {
-        let mut columns = Vec::with_capacity(batch.num_columns());
-        for array in batch.columns() {
-            let column = Arc::new(array.try_into()?);
-            columns.push(column);
-        }
-        Ok(DataChunk::new(columns, batch.num_rows()))
-    }
-}
-
-/// Provides the default conversion logic for RisingWave array to Arrow array with type info.
-pub trait ToArrowArrayWithTypeConvert {
-    fn to_arrow_with_type(
+    /// Converts RisingWave array to Arrow array.
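+    ///
+    /// If the converted Arrow array's type still differs from the requested `data_type`,
+    /// a best-effort `arrow_cast::cast` is applied. A sketch of the expected behavior
+    /// (illustrative only, using the `UdfArrowConvert` impl from `arrow_default.rs`):
+    ///
+    /// ```ignore
+    /// let array: ArrayImpl = I32Array::from_iter([Some(1), None]).into();
+    /// let arrow = UdfArrowConvert.to_array(&arrow_schema::DataType::Int64, &array)?;
+    /// assert_eq!(arrow.data_type(), &arrow_schema::DataType::Int64);
+    /// ```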
+ fn to_array( &self, data_type: &arrow_schema::DataType, array: &ArrayImpl, ) -> Result { - match array { - ArrayImpl::Int16(array) => self.int16_to_arrow(data_type, array), - ArrayImpl::Int32(array) => self.int32_to_arrow(data_type, array), - ArrayImpl::Int64(array) => self.int64_to_arrow(data_type, array), - ArrayImpl::Float32(array) => self.float32_to_arrow(data_type, array), - ArrayImpl::Float64(array) => self.float64_to_arrow(data_type, array), - ArrayImpl::Utf8(array) => self.utf8_to_arrow(data_type, array), - ArrayImpl::Bool(array) => self.bool_to_arrow(data_type, array), - ArrayImpl::Decimal(array) => self.decimal_to_arrow(data_type, array), - ArrayImpl::Int256(array) => self.int256_to_arrow(data_type, array), - ArrayImpl::Date(array) => self.date_to_arrow(data_type, array), - ArrayImpl::Timestamp(array) => self.timestamp_to_arrow(data_type, array), - ArrayImpl::Timestamptz(array) => self.timestamptz_to_arrow(data_type, array), - ArrayImpl::Time(array) => self.time_to_arrow(data_type, array), - ArrayImpl::Interval(array) => self.interval_to_arrow(data_type, array), - ArrayImpl::Struct(array) => self.struct_to_arrow(data_type, array), - ArrayImpl::List(array) => self.list_to_arrow(data_type, array), - ArrayImpl::Bytea(array) => self.bytea_to_arrow(data_type, array), - ArrayImpl::Jsonb(array) => self.jsonb_to_arrow(data_type, array), - ArrayImpl::Serial(array) => self.serial_to_arrow(data_type, array), - } - } - - #[inline] - fn int16_to_arrow( - &self, - _data_type: &arrow_schema::DataType, - array: &I16Array, - ) -> Result { - Ok(Arc::new(arrow_array::Int16Array::from(array))) - } - - #[inline] - fn int32_to_arrow( - &self, - _data_type: &arrow_schema::DataType, - array: &I32Array, - ) -> Result { - Ok(Arc::new(arrow_array::Int32Array::from(array))) - } - - #[inline] - fn int64_to_arrow( - &self, - _data_type: &arrow_schema::DataType, - array: &I64Array, - ) -> Result { - Ok(Arc::new(arrow_array::Int64Array::from(array))) - } - - #[inline] - fn float32_to_arrow( - &self, - _data_type: &arrow_schema::DataType, - array: &F32Array, - ) -> Result { - Ok(Arc::new(arrow_array::Float32Array::from(array))) - } - - #[inline] - fn float64_to_arrow( - &self, - _data_type: &arrow_schema::DataType, - array: &F64Array, - ) -> Result { - Ok(Arc::new(arrow_array::Float64Array::from(array))) - } - - #[inline] - fn utf8_to_arrow( - &self, - _data_type: &arrow_schema::DataType, - array: &Utf8Array, - ) -> Result { - Ok(Arc::new(arrow_array::StringArray::from(array))) - } - - #[inline] - fn bool_to_arrow( - &self, - _data_type: &arrow_schema::DataType, - array: &BoolArray, - ) -> Result { - Ok(Arc::new(arrow_array::BooleanArray::from(array))) - } - - // Decimal values are stored as ASCII text representation in a large binary array. 
- #[inline] - fn decimal_to_arrow( - &self, - _data_type: &arrow_schema::DataType, - array: &DecimalArray, - ) -> Result { - Ok(Arc::new(arrow_array::LargeBinaryArray::from(array))) - } - - #[inline] - fn int256_to_arrow( - &self, - _data_type: &arrow_schema::DataType, - array: &Int256Array, - ) -> Result { - Ok(Arc::new(arrow_array::Decimal256Array::from(array))) - } - - #[inline] - fn date_to_arrow( - &self, - _data_type: &arrow_schema::DataType, - array: &DateArray, - ) -> Result { - Ok(Arc::new(arrow_array::Date32Array::from(array))) - } - - #[inline] - fn timestamp_to_arrow( - &self, - _data_type: &arrow_schema::DataType, - array: &TimestampArray, - ) -> Result { - Ok(Arc::new(arrow_array::TimestampMicrosecondArray::from( - array, - ))) - } - - #[inline] - fn timestamptz_to_arrow( - &self, - _data_type: &arrow_schema::DataType, - array: &TimestamptzArray, - ) -> Result { - Ok(Arc::new( - arrow_array::TimestampMicrosecondArray::from(array).with_timezone_utc(), - )) - } - - #[inline] - fn time_to_arrow( - &self, - _data_type: &arrow_schema::DataType, - array: &TimeArray, - ) -> Result { - Ok(Arc::new(arrow_array::Time64MicrosecondArray::from(array))) - } - - #[inline] - fn interval_to_arrow( - &self, - _data_type: &arrow_schema::DataType, - array: &IntervalArray, - ) -> Result { - Ok(Arc::new(arrow_array::IntervalMonthDayNanoArray::from( - array, - ))) - } - - #[inline] - fn struct_to_arrow( - &self, - _data_type: &arrow_schema::DataType, - array: &StructArray, - ) -> Result { - Ok(Arc::new(arrow_array::StructArray::try_from(array)?)) - } - - #[inline] - fn list_to_arrow( - &self, - _data_type: &arrow_schema::DataType, - array: &ListArray, - ) -> Result { - Ok(Arc::new(arrow_array::ListArray::try_from(array)?)) - } - - #[inline] - fn bytea_to_arrow( - &self, - _data_type: &arrow_schema::DataType, - array: &BytesArray, - ) -> Result { - Ok(Arc::new(arrow_array::BinaryArray::from(array))) - } - - // JSON values are stored as text representation in a large string array. - #[inline] - fn jsonb_to_arrow( - &self, - _data_type: &arrow_schema::DataType, - array: &JsonbArray, - ) -> Result { - Ok(Arc::new(arrow_array::LargeStringArray::from(array))) - } - - #[inline] - fn serial_to_arrow( - &self, - _data_type: &arrow_schema::DataType, - _array: &SerialArray, - ) -> Result { - todo!("serial type is not supported to convert to arrow") - } -} - -/// Provides the default conversion logic for RisingWave array to Arrow array with type info. 
-pub trait ToArrowArrayConvert { - fn to_arrow(&self, array: &ArrayImpl) -> Result { - match array { + let arrow_array = match array { + ArrayImpl::Bool(array) => self.bool_to_arrow(array), ArrayImpl::Int16(array) => self.int16_to_arrow(array), ArrayImpl::Int32(array) => self.int32_to_arrow(array), ArrayImpl::Int64(array) => self.int64_to_arrow(array), + ArrayImpl::Int256(array) => self.int256_to_arrow(array), ArrayImpl::Float32(array) => self.float32_to_arrow(array), ArrayImpl::Float64(array) => self.float64_to_arrow(array), - ArrayImpl::Utf8(array) => self.utf8_to_arrow(array), - ArrayImpl::Bool(array) => self.bool_to_arrow(array), - ArrayImpl::Decimal(array) => self.decimal_to_arrow(array), - ArrayImpl::Int256(array) => self.int256_to_arrow(array), ArrayImpl::Date(array) => self.date_to_arrow(array), + ArrayImpl::Time(array) => self.time_to_arrow(array), ArrayImpl::Timestamp(array) => self.timestamp_to_arrow(array), ArrayImpl::Timestamptz(array) => self.timestamptz_to_arrow(array), - ArrayImpl::Time(array) => self.time_to_arrow(array), ArrayImpl::Interval(array) => self.interval_to_arrow(array), - ArrayImpl::Struct(array) => self.struct_to_arrow(array), - ArrayImpl::List(array) => self.list_to_arrow(array), + ArrayImpl::Utf8(array) => self.utf8_to_arrow(array), ArrayImpl::Bytea(array) => self.bytea_to_arrow(array), + ArrayImpl::Decimal(array) => self.decimal_to_arrow(data_type, array), ArrayImpl::Jsonb(array) => self.jsonb_to_arrow(array), ArrayImpl::Serial(array) => self.serial_to_arrow(array), + ArrayImpl::List(array) => self.list_to_arrow(data_type, array), + ArrayImpl::Struct(array) => self.struct_to_arrow(data_type, array), + }?; + if arrow_array.data_type() != data_type { + arrow_cast::cast(&arrow_array, data_type).map_err(ArrayError::to_arrow) + } else { + Ok(arrow_array) } } + #[inline] + fn bool_to_arrow(&self, array: &BoolArray) -> Result { + Ok(Arc::new(arrow_array::BooleanArray::from(array))) + } + #[inline] fn int16_to_arrow(&self, array: &I16Array) -> Result { Ok(Arc::new(arrow_array::Int16Array::from(array))) @@ -391,17 +151,6 @@ pub trait ToArrowArrayConvert { Ok(Arc::new(arrow_array::StringArray::from(array))) } - #[inline] - fn bool_to_arrow(&self, array: &BoolArray) -> Result { - Ok(Arc::new(arrow_array::BooleanArray::from(array))) - } - - // Decimal values are stored as ASCII text representation in a large binary array. - #[inline] - fn decimal_to_arrow(&self, array: &DecimalArray) -> Result { - Ok(Arc::new(arrow_array::LargeBinaryArray::from(array))) - } - #[inline] fn int256_to_arrow(&self, array: &Int256Array) -> Result { Ok(Arc::new(arrow_array::Decimal256Array::from(array))) @@ -448,18 +197,18 @@ pub trait ToArrowArrayConvert { } #[inline] - fn struct_to_arrow(&self, array: &StructArray) -> Result { - Ok(Arc::new(arrow_array::StructArray::try_from(array)?)) - } - - #[inline] - fn list_to_arrow(&self, array: &ListArray) -> Result { - Ok(Arc::new(arrow_array::ListArray::try_from(array)?)) + fn bytea_to_arrow(&self, array: &BytesArray) -> Result { + Ok(Arc::new(arrow_array::BinaryArray::from(array))) } + // Decimal values are stored as ASCII text representation in a large binary array. #[inline] - fn bytea_to_arrow(&self, array: &BytesArray) -> Result { - Ok(Arc::new(arrow_array::BinaryArray::from(array))) + fn decimal_to_arrow( + &self, + _data_type: &arrow_schema::DataType, + array: &DecimalArray, + ) -> Result { + Ok(Arc::new(arrow_array::LargeBinaryArray::from(array))) } // JSON values are stored as text representation in a large string array. 
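     // For example, `{"a": 1}` is stored as exactly that text; the reverse conversion
     // parses the text back into a JSONB value and errors on invalid JSON.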
@@ -469,35 +218,82 @@ pub trait ToArrowArrayConvert { } #[inline] - fn serial_to_arrow(&self, _array: &SerialArray) -> Result { - todo!("serial type is not supported to convert to arrow") + fn serial_to_arrow(&self, array: &SerialArray) -> Result { + Ok(Arc::new(arrow_array::Int64Array::from(array))) + } + + #[inline] + fn list_to_arrow( + &self, + data_type: &arrow_schema::DataType, + array: &ListArray, + ) -> Result { + let arrow_schema::DataType::List(field) = data_type else { + return Err(ArrayError::to_arrow("Invalid list type")); + }; + let values = self.to_array(field.data_type(), array.values())?; + let offsets = OffsetBuffer::new(array.offsets().iter().map(|&o| o as i32).collect()); + let nulls = array.null_bitmap().into(); + Ok(Arc::new(arrow_array::ListArray::new( + field.clone(), + offsets, + values, + Some(nulls), + ))) + } + + #[inline] + fn struct_to_arrow( + &self, + data_type: &arrow_schema::DataType, + array: &StructArray, + ) -> Result { + let arrow_schema::DataType::Struct(fields) = data_type else { + return Err(ArrayError::to_arrow("Invalid struct type")); + }; + Ok(Arc::new(arrow_array::StructArray::new( + fields.clone(), + array + .fields() + .zip_eq_fast(fields) + .map(|(arr, field)| self.to_array(field.data_type(), arr)) + .try_collect::<_, _, ArrayError>()?, + Some(array.null_bitmap().into()), + ))) } -} -pub trait ToArrowTypeConvert { - fn to_arrow_type(&self, value: &DataType) -> Result { - match value { + /// Convert RisingWave data type to Arrow data type. + /// + /// This function returns a `Field` instead of `DataType` because some may be converted to + /// extension types which require additional metadata in the field. + fn to_arrow_field( + &self, + name: &str, + value: &DataType, + ) -> Result { + let data_type = match value { // using the inline function - DataType::Boolean => Ok(self.bool_type_to_arrow()), - DataType::Int16 => Ok(self.int16_type_to_arrow()), - DataType::Int32 => Ok(self.int32_type_to_arrow()), - DataType::Int64 => Ok(self.int64_type_to_arrow()), - DataType::Int256 => Ok(self.int256_type_to_arrow()), - DataType::Float32 => Ok(self.float32_type_to_arrow()), - DataType::Float64 => Ok(self.float64_type_to_arrow()), - DataType::Date => Ok(self.date_type_to_arrow()), - DataType::Timestamp => Ok(self.timestamp_type_to_arrow()), - DataType::Timestamptz => Ok(self.timestamptz_type_to_arrow()), - DataType::Time => Ok(self.time_type_to_arrow()), - DataType::Interval => Ok(self.interval_type_to_arrow()), - DataType::Varchar => Ok(self.varchar_type_to_arrow()), - DataType::Jsonb => Ok(self.jsonb_type_to_arrow()), - DataType::Bytea => Ok(self.bytea_type_to_arrow()), - DataType::Decimal => Ok(self.decimal_type_to_arrow()), - DataType::Serial => Ok(self.serial_type_to_arrow()), - DataType::Struct(fields) => self.struct_type_to_arrow(fields), - DataType::List(datatype) => self.list_type_to_arrow(datatype), - } + DataType::Boolean => self.bool_type_to_arrow(), + DataType::Int16 => self.int16_type_to_arrow(), + DataType::Int32 => self.int32_type_to_arrow(), + DataType::Int64 => self.int64_type_to_arrow(), + DataType::Int256 => self.int256_type_to_arrow(), + DataType::Float32 => self.float32_type_to_arrow(), + DataType::Float64 => self.float64_type_to_arrow(), + DataType::Date => self.date_type_to_arrow(), + DataType::Time => self.time_type_to_arrow(), + DataType::Timestamp => self.timestamp_type_to_arrow(), + DataType::Timestamptz => self.timestamptz_type_to_arrow(), + DataType::Interval => self.interval_type_to_arrow(), + DataType::Varchar => 
self.varchar_type_to_arrow(), + DataType::Bytea => self.bytea_type_to_arrow(), + DataType::Serial => self.serial_type_to_arrow(), + DataType::Decimal => return Ok(self.decimal_type_to_arrow(name)), + DataType::Jsonb => return Ok(self.jsonb_type_to_arrow(name)), + DataType::Struct(fields) => self.struct_type_to_arrow(fields)?, + DataType::List(datatype) => self.list_type_to_arrow(datatype)?, + }; + Ok(arrow_schema::Field::new(name, data_type, true)) } #[inline] @@ -505,6 +301,11 @@ pub trait ToArrowTypeConvert { arrow_schema::DataType::Boolean } + #[inline] + fn int16_type_to_arrow(&self) -> arrow_schema::DataType { + arrow_schema::DataType::Int16 + } + #[inline] fn int32_type_to_arrow(&self) -> arrow_schema::DataType { arrow_schema::DataType::Int32 @@ -515,12 +316,6 @@ pub trait ToArrowTypeConvert { arrow_schema::DataType::Int64 } - // generate function for each type for me using inline - #[inline] - fn int16_type_to_arrow(&self) -> arrow_schema::DataType { - arrow_schema::DataType::Int16 - } - #[inline] fn int256_type_to_arrow(&self) -> arrow_schema::DataType { arrow_schema::DataType::Decimal256(arrow_schema::DECIMAL256_MAX_PRECISION, 0) @@ -541,6 +336,11 @@ pub trait ToArrowTypeConvert { arrow_schema::DataType::Date32 } + #[inline] + fn time_type_to_arrow(&self) -> arrow_schema::DataType { + arrow_schema::DataType::Time64(arrow_schema::TimeUnit::Microsecond) + } + #[inline] fn timestamp_type_to_arrow(&self) -> arrow_schema::DataType { arrow_schema::DataType::Timestamp(arrow_schema::TimeUnit::Microsecond, None) @@ -554,11 +354,6 @@ pub trait ToArrowTypeConvert { ) } - #[inline] - fn time_type_to_arrow(&self) -> arrow_schema::DataType { - arrow_schema::DataType::Time64(arrow_schema::TimeUnit::Microsecond) - } - #[inline] fn interval_type_to_arrow(&self) -> arrow_schema::DataType { arrow_schema::DataType::Interval(arrow_schema::IntervalUnit::MonthDayNano) @@ -570,8 +365,8 @@ pub trait ToArrowTypeConvert { } #[inline] - fn jsonb_type_to_arrow(&self) -> arrow_schema::DataType { - arrow_schema::DataType::LargeUtf8 + fn jsonb_type_to_arrow(&self, name: &str) -> arrow_schema::Field { + arrow_schema::Field::new(name, arrow_schema::DataType::LargeUtf8, true) } #[inline] @@ -580,22 +375,22 @@ pub trait ToArrowTypeConvert { } #[inline] - fn decimal_type_to_arrow(&self) -> arrow_schema::DataType { - arrow_schema::DataType::LargeBinary + fn decimal_type_to_arrow(&self, name: &str) -> arrow_schema::Field { + arrow_schema::Field::new(name, arrow_schema::DataType::LargeBinary, true) } #[inline] fn serial_type_to_arrow(&self) -> arrow_schema::DataType { - todo!("serial type is not supported to convert to arrow") + arrow_schema::DataType::Int64 } #[inline] fn list_type_to_arrow( &self, - datatype: &DataType, + elem_type: &DataType, ) -> Result { Ok(arrow_schema::DataType::List(Arc::new( - arrow_schema::Field::new("item", datatype.try_into()?, true), + self.to_arrow_field("item", elem_type)?, ))) } @@ -607,160 +402,226 @@ pub trait ToArrowTypeConvert { Ok(arrow_schema::DataType::Struct( fields .iter() - .map(|(name, ty)| Ok(arrow_schema::Field::new(name, ty.try_into()?, true))) + .map(|(name, ty)| self.to_arrow_field(name, ty)) .try_collect::<_, _, ArrayError>()?, )) } } -struct DefaultArrowConvert; -impl ToArrowArrayConvert for DefaultArrowConvert {} - -/// Implement bi-directional `From` between `ArrayImpl` and `arrow_array::ArrayRef`. -macro_rules! 
converts_generic { - ($({ $ArrowType:ty, $ArrowPattern:pat, $ArrayImplPattern:path }),*) => { - // RisingWave array -> Arrow array - impl TryFrom<&ArrayImpl> for arrow_array::ArrayRef { - type Error = ArrayError; - fn try_from(array: &ArrayImpl) -> Result { - DefaultArrowConvert{}.to_arrow(array) - } - } - // Arrow array -> RisingWave array - impl TryFrom<&arrow_array::ArrayRef> for ArrayImpl { - type Error = ArrayError; - fn try_from(array: &arrow_array::ArrayRef) -> Result { - use arrow_schema::DataType::*; - use arrow_schema::IntervalUnit::*; - use arrow_schema::TimeUnit::*; - match array.data_type() { - $($ArrowPattern => Ok($ArrayImplPattern( - array - .as_any() - .downcast_ref::<$ArrowType>() - .unwrap() - .try_into()?, - )),)* - Timestamp(Microsecond, Some(_)) => Ok(ArrayImpl::Timestamptz( - array - .as_any() - .downcast_ref::() - .unwrap() - .try_into()?, - )), - // This arrow decimal type is used by iceberg source to read iceberg decimal into RW decimal. - Decimal128(_, _) => Ok(ArrayImpl::Decimal( - array - .as_any() - .downcast_ref::() - .unwrap() - .try_into()?, - )), - t => Err(ArrayError::from_arrow(format!("unsupported data type: {t:?}"))), - } - } +/// Defines how to convert Arrow arrays to RisingWave arrays. +pub trait FromArrow { + /// Converts Arrow `RecordBatch` to RisingWave `DataChunk`. + fn from_record_batch(&self, batch: &arrow_array::RecordBatch) -> Result { + let mut columns = Vec::with_capacity(batch.num_columns()); + for array in batch.columns() { + let column = Arc::new(self.from_array(array)?); + columns.push(column); } - }; -} -converts_generic! { - { arrow_array::Int16Array, Int16, ArrayImpl::Int16 }, - { arrow_array::Int32Array, Int32, ArrayImpl::Int32 }, - { arrow_array::Int64Array, Int64, ArrayImpl::Int64 }, - { arrow_array::Float32Array, Float32, ArrayImpl::Float32 }, - { arrow_array::Float64Array, Float64, ArrayImpl::Float64 }, - { arrow_array::StringArray, Utf8, ArrayImpl::Utf8 }, - { arrow_array::BooleanArray, Boolean, ArrayImpl::Bool }, - // Arrow doesn't have a data type to represent unconstrained numeric (`DECIMAL` in RisingWave and - // Postgres). So we pick a special type `LargeBinary` for it. - // Values stored in the array are the string representation of the decimal. e.g. b"1.234", b"+inf" - { arrow_array::LargeBinaryArray, LargeBinary, ArrayImpl::Decimal }, - { arrow_array::Decimal256Array, Decimal256(_, _), ArrayImpl::Int256 }, - { arrow_array::Date32Array, Date32, ArrayImpl::Date }, - { arrow_array::TimestampMicrosecondArray, Timestamp(Microsecond, None), ArrayImpl::Timestamp }, - { arrow_array::Time64MicrosecondArray, Time64(Microsecond), ArrayImpl::Time }, - { arrow_array::IntervalMonthDayNanoArray, Interval(MonthDayNano), ArrayImpl::Interval }, - { arrow_array::StructArray, Struct(_), ArrayImpl::Struct }, - { arrow_array::ListArray, List(_), ArrayImpl::List }, - { arrow_array::BinaryArray, Binary, ArrayImpl::Bytea }, - { arrow_array::LargeStringArray, LargeUtf8, ArrayImpl::Jsonb } // we use LargeUtf8 to represent Jsonb in arrow -} + Ok(DataChunk::new(columns, batch.num_rows())) + } + + /// Converts Arrow `Fields` to RisingWave `StructType`. 
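+    ///
+    /// For example, Arrow fields `[a: Boolean, b: Int32]` become the RisingWave type
+    /// `struct<a boolean, b int32>`. Each field goes through `from_field`, so nested
+    /// and extension types are handled recursively.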
+ fn from_fields(&self, fields: &arrow_schema::Fields) -> Result { + Ok(StructType::new( + fields + .iter() + .map(|f| Ok((f.name().clone(), self.from_field(f)?))) + .try_collect::<_, _, ArrayError>()?, + )) + } -// Arrow Datatype -> Risingwave Datatype -impl From<&arrow_schema::DataType> for DataType { - fn from(value: &arrow_schema::DataType) -> Self { + /// Converts Arrow `Field` to RisingWave `DataType`. + fn from_field(&self, field: &arrow_schema::Field) -> Result { use arrow_schema::DataType::*; use arrow_schema::IntervalUnit::*; use arrow_schema::TimeUnit::*; - match value { - Boolean => Self::Boolean, - Int16 => Self::Int16, - Int32 => Self::Int32, - Int64 => Self::Int64, - Float32 => Self::Float32, - Float64 => Self::Float64, - LargeBinary => Self::Decimal, - Decimal256(_, _) => Self::Int256, - Date32 => Self::Date, - Time64(Microsecond) => Self::Time, - Timestamp(Microsecond, None) => Self::Timestamp, - Timestamp(Microsecond, Some(_)) => Self::Timestamptz, - Interval(MonthDayNano) => Self::Interval, - Binary => Self::Bytea, - Utf8 => Self::Varchar, - LargeUtf8 => Self::Jsonb, - Struct(fields) => Self::Struct(fields.into()), - List(field) => Self::List(Box::new(field.data_type().into())), - Decimal128(_, _) => Self::Decimal, - _ => todo!("Unsupported arrow data type: {value:?}"), + + // extension type + if let Some(type_name) = field.metadata().get("ARROW:extension:name") { + return self.from_extension_type(type_name, field.data_type()); } + + Ok(match field.data_type() { + Boolean => DataType::Boolean, + Int16 => DataType::Int16, + Int32 => DataType::Int32, + Int64 => DataType::Int64, + Float32 => DataType::Float32, + Float64 => DataType::Float64, + Decimal128(_, _) => DataType::Decimal, + Decimal256(_, _) => DataType::Int256, + Date32 => DataType::Date, + Time64(Microsecond) => DataType::Time, + Timestamp(Microsecond, None) => DataType::Timestamp, + Timestamp(Microsecond, Some(_)) => DataType::Timestamptz, + Interval(MonthDayNano) => DataType::Interval, + Utf8 => DataType::Varchar, + Binary => DataType::Bytea, + LargeUtf8 => self.from_large_utf8()?, + LargeBinary => self.from_large_binary()?, + List(field) => DataType::List(Box::new(self.from_field(field)?)), + Struct(fields) => DataType::Struct(self.from_fields(fields)?), + t => { + return Err(ArrayError::from_arrow(format!( + "unsupported arrow data type: {t:?}" + ))) + } + }) } -} -impl From<&arrow_schema::Fields> for StructType { - fn from(fields: &arrow_schema::Fields) -> Self { - Self::new( - fields - .iter() - .map(|f| (f.name().clone(), f.data_type().into())) - .collect(), - ) + /// Converts Arrow LargeUtf8 type to RisingWave data type. + fn from_large_utf8(&self) -> Result { + Ok(DataType::Jsonb) } -} -impl TryFrom<&StructType> for arrow_schema::Fields { - type Error = ArrayError; + /// Converts Arrow LargeBinary type to RisingWave data type. + fn from_large_binary(&self) -> Result { + Ok(DataType::Decimal) + } - fn try_from(struct_type: &StructType) -> Result { - struct_type - .iter() - .map(|(name, ty)| Ok(arrow_schema::Field::new(name, ty.try_into()?, true))) - .try_collect() + /// Converts Arrow extension type to RisingWave `DataType`. + fn from_extension_type( + &self, + type_name: &str, + _physical_type: &arrow_schema::DataType, + ) -> Result { + Err(ArrayError::from_arrow(format!( + "unsupported extension type: {type_name:?}" + ))) } -} -impl From for DataType { - fn from(value: arrow_schema::DataType) -> Self { - (&value).into() + /// Converts Arrow `Array` to RisingWave `ArrayImpl`. 
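+    ///
+    /// Dispatches on `array.data_type()`. A sketch (illustrative only, using
+    /// `UdfArrowConvert`): an `arrow_array::Int32Array` lands in `from_int32_array`
+    /// and comes back as `ArrayImpl::Int32`:
+    ///
+    /// ```ignore
+    /// let arrow: arrow_array::ArrayRef = Arc::new(arrow_array::Int32Array::from(vec![Some(1), None]));
+    /// assert_eq!(UdfArrowConvert.from_array(&arrow)?.data_type(), DataType::Int32);
+    /// ```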
+ fn from_array(&self, array: &arrow_array::ArrayRef) -> Result { + use arrow_schema::DataType::*; + use arrow_schema::TimeUnit::*; + match array.data_type() { + Boolean => self.from_bool_array(array.as_any().downcast_ref().unwrap()), + Int16 => self.from_int16_array(array.as_any().downcast_ref().unwrap()), + Int32 => self.from_int32_array(array.as_any().downcast_ref().unwrap()), + Int64 => self.from_int64_array(array.as_any().downcast_ref().unwrap()), + Decimal256(_, _) => self.from_int256_array(array.as_any().downcast_ref().unwrap()), + Float32 => self.from_float32_array(array.as_any().downcast_ref().unwrap()), + Float64 => self.from_float64_array(array.as_any().downcast_ref().unwrap()), + Date32 => self.from_date32_array(array.as_any().downcast_ref().unwrap()), + Time64(Microsecond) => self.from_time64us_array(array.as_any().downcast_ref().unwrap()), + Timestamp(Microsecond, _) => { + self.from_timestampus_array(array.as_any().downcast_ref().unwrap()) + } + Utf8 => self.from_utf8_array(array.as_any().downcast_ref().unwrap()), + Binary => self.from_binary_array(array.as_any().downcast_ref().unwrap()), + LargeUtf8 => self.from_large_utf8_array(array.as_any().downcast_ref().unwrap()), + LargeBinary => self.from_large_binary_array(array.as_any().downcast_ref().unwrap()), + List(_) => self.from_list_array(array.as_any().downcast_ref().unwrap()), + Struct(_) => self.from_struct_array(array.as_any().downcast_ref().unwrap()), + t => Err(ArrayError::from_arrow(format!( + "unsupported arrow data type: {t:?}", + ))), + } } -} -struct DefaultArrowTypeConvert; + fn from_bool_array(&self, array: &arrow_array::BooleanArray) -> Result { + Ok(ArrayImpl::Bool(array.into())) + } -impl ToArrowTypeConvert for DefaultArrowTypeConvert {} + fn from_int16_array(&self, array: &arrow_array::Int16Array) -> Result { + Ok(ArrayImpl::Int16(array.into())) + } -impl TryFrom<&DataType> for arrow_schema::DataType { - type Error = ArrayError; + fn from_int32_array(&self, array: &arrow_array::Int32Array) -> Result { + Ok(ArrayImpl::Int32(array.into())) + } - fn try_from(value: &DataType) -> Result { - DefaultArrowTypeConvert {}.to_arrow_type(value) + fn from_int64_array(&self, array: &arrow_array::Int64Array) -> Result { + Ok(ArrayImpl::Int64(array.into())) } -} -impl TryFrom for arrow_schema::DataType { - type Error = ArrayError; + fn from_int256_array( + &self, + array: &arrow_array::Decimal256Array, + ) -> Result { + Ok(ArrayImpl::Int256(array.into())) + } + + fn from_float32_array( + &self, + array: &arrow_array::Float32Array, + ) -> Result { + Ok(ArrayImpl::Float32(array.into())) + } - fn try_from(value: DataType) -> Result { - (&value).try_into() + fn from_float64_array( + &self, + array: &arrow_array::Float64Array, + ) -> Result { + Ok(ArrayImpl::Float64(array.into())) + } + + fn from_date32_array(&self, array: &arrow_array::Date32Array) -> Result { + Ok(ArrayImpl::Date(array.into())) + } + + fn from_time64us_array( + &self, + array: &arrow_array::Time64MicrosecondArray, + ) -> Result { + Ok(ArrayImpl::Time(array.into())) + } + + fn from_timestampus_array( + &self, + array: &arrow_array::TimestampMicrosecondArray, + ) -> Result { + Ok(ArrayImpl::Timestamp(array.into())) + } + + fn from_utf8_array(&self, array: &arrow_array::StringArray) -> Result { + Ok(ArrayImpl::Utf8(array.into())) + } + + fn from_binary_array(&self, array: &arrow_array::BinaryArray) -> Result { + Ok(ArrayImpl::Bytea(array.into())) + } + + fn from_large_utf8_array( + &self, + array: &arrow_array::LargeStringArray, + ) -> Result { + 
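+        // By default, `LargeUtf8` data is interpreted as JSONB text (see `from_large_utf8`
+        // above), so the conversion fails on strings that are not valid JSON.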
Ok(ArrayImpl::Jsonb(array.try_into()?)) + } + + fn from_large_binary_array( + &self, + array: &arrow_array::LargeBinaryArray, + ) -> Result { + Ok(ArrayImpl::Decimal(array.try_into()?)) + } + + fn from_list_array(&self, array: &arrow_array::ListArray) -> Result { + use arrow_array::Array; + Ok(ArrayImpl::List(ListArray { + value: Box::new(self.from_array(array.values())?), + bitmap: match array.nulls() { + Some(nulls) => nulls.iter().collect(), + None => Bitmap::ones(array.len()), + }, + offsets: array.offsets().iter().map(|o| *o as u32).collect(), + })) + } + + fn from_struct_array(&self, array: &arrow_array::StructArray) -> Result { + use arrow_array::Array; + let arrow_schema::DataType::Struct(fields) = array.data_type() else { + panic!("nested field types cannot be determined."); + }; + Ok(ArrayImpl::Struct(StructArray::new( + self.from_fields(fields)?, + array + .columns() + .iter() + .map(|a| self.from_array(a).map(Arc::new)) + .try_collect()?, + (0..array.len()).map(|i| array.is_valid(i)).collect(), + ))) } } @@ -836,6 +697,7 @@ converts!(TimeArray, arrow_array::Time64MicrosecondArray, @map); converts!(TimestampArray, arrow_array::TimestampMicrosecondArray, @map); converts!(TimestamptzArray, arrow_array::TimestampMicrosecondArray, @map); converts!(IntervalArray, arrow_array::IntervalMonthDayNanoArray, @map); +converts!(SerialArray, arrow_array::Int64Array, @map); /// Converts RisingWave value from and into Arrow value. pub trait FromIntoArrow { @@ -845,6 +707,18 @@ pub trait FromIntoArrow { fn into_arrow(self) -> Self::ArrowType; } +impl FromIntoArrow for Serial { + type ArrowType = i64; + + fn from_arrow(value: Self::ArrowType) -> Self { + value.into() + } + + fn into_arrow(self) -> Self::ArrowType { + self.into() + } +} + impl FromIntoArrow for F32 { type ArrowType = f32; @@ -973,6 +847,17 @@ impl From<&DecimalArray> for arrow_array::LargeBinaryArray { } } +impl From<&DecimalArray> for arrow_array::StringArray { + fn from(array: &DecimalArray) -> Self { + let mut builder = + arrow_array::builder::StringBuilder::with_capacity(array.len(), array.len() * 8); + for value in array.iter() { + builder.append_option(value.map(|d| d.to_string())); + } + builder.finish() + } +} + // This arrow decimal type is used by iceberg source to read iceberg decimal into RW decimal. 
impl TryFrom<&arrow_array::Decimal128Array> for DecimalArray { type Error = ArrayError; @@ -1020,6 +905,57 @@ impl TryFrom<&arrow_array::LargeBinaryArray> for DecimalArray { } } +impl TryFrom<&arrow_array::StringArray> for DecimalArray { + type Error = ArrayError; + + fn try_from(array: &arrow_array::StringArray) -> Result { + array + .iter() + .map(|o| { + o.map(|s| { + s.parse() + .map_err(|_| ArrayError::from_arrow(format!("invalid decimal: {s:?}"))) + }) + .transpose() + }) + .try_collect() + } +} + +impl From<&JsonbArray> for arrow_array::StringArray { + fn from(array: &JsonbArray) -> Self { + let mut builder = + arrow_array::builder::StringBuilder::with_capacity(array.len(), array.len() * 16); + for value in array.iter() { + match value { + Some(jsonb) => { + write!(&mut builder, "{}", jsonb).unwrap(); + builder.append_value(""); + } + None => builder.append_null(), + } + } + builder.finish() + } +} + +impl TryFrom<&arrow_array::StringArray> for JsonbArray { + type Error = ArrayError; + + fn try_from(array: &arrow_array::StringArray) -> Result { + array + .iter() + .map(|o| { + o.map(|s| { + s.parse() + .map_err(|_| ArrayError::from_arrow(format!("invalid json: {s}"))) + }) + .transpose() + }) + .try_collect() + } +} + impl From<&JsonbArray> for arrow_array::LargeStringArray { fn from(array: &JsonbArray) -> Self { let mut builder = @@ -1088,195 +1024,8 @@ impl From<&arrow_array::Decimal256Array> for Int256Array { } } -impl TryFrom<&ListArray> for arrow_array::ListArray { - type Error = ArrayError; - - fn try_from(array: &ListArray) -> Result { - use arrow_array::builder::*; - fn build( - array: &ListArray, - a: &A, - builder: B, - mut append: F, - ) -> arrow_array::ListArray - where - A: Array, - B: arrow_array::builder::ArrayBuilder, - F: FnMut(&mut B, Option>), - { - let mut builder = ListBuilder::with_capacity(builder, a.len()); - for i in 0..array.len() { - for j in array.offsets[i]..array.offsets[i + 1] { - append(builder.values(), a.value_at(j as usize)); - } - builder.append(!array.is_null(i)); - } - builder.finish() - } - Ok(match &*array.value { - ArrayImpl::Int16(a) => build(array, a, Int16Builder::with_capacity(a.len()), |b, v| { - b.append_option(v) - }), - ArrayImpl::Int32(a) => build(array, a, Int32Builder::with_capacity(a.len()), |b, v| { - b.append_option(v) - }), - ArrayImpl::Int64(a) => build(array, a, Int64Builder::with_capacity(a.len()), |b, v| { - b.append_option(v) - }), - - ArrayImpl::Float32(a) => { - build(array, a, Float32Builder::with_capacity(a.len()), |b, v| { - b.append_option(v.map(|f| f.0)) - }) - } - ArrayImpl::Float64(a) => { - build(array, a, Float64Builder::with_capacity(a.len()), |b, v| { - b.append_option(v.map(|f| f.0)) - }) - } - ArrayImpl::Utf8(a) => build( - array, - a, - StringBuilder::with_capacity(a.len(), a.data().len()), - |b, v| b.append_option(v), - ), - ArrayImpl::Int256(a) => build( - array, - a, - Decimal256Builder::with_capacity(a.len()).with_data_type( - arrow_schema::DataType::Decimal256(arrow_schema::DECIMAL256_MAX_PRECISION, 0), - ), - |b, v| b.append_option(v.map(Into::into)), - ), - ArrayImpl::Bool(a) => { - build(array, a, BooleanBuilder::with_capacity(a.len()), |b, v| { - b.append_option(v) - }) - } - ArrayImpl::Decimal(a) => build( - array, - a, - LargeBinaryBuilder::with_capacity(a.len(), a.len() * 8), - |b, v| b.append_option(v.map(|d| d.to_string())), - ), - ArrayImpl::Interval(a) => build( - array, - a, - IntervalMonthDayNanoBuilder::with_capacity(a.len()), - |b, v| b.append_option(v.map(|d| d.into_arrow())), - ), - 
ArrayImpl::Date(a) => build(array, a, Date32Builder::with_capacity(a.len()), |b, v| { - b.append_option(v.map(|d| d.into_arrow())) - }), - ArrayImpl::Timestamp(a) => build( - array, - a, - TimestampMicrosecondBuilder::with_capacity(a.len()), - |b, v| b.append_option(v.map(|d| d.into_arrow())), - ), - ArrayImpl::Timestamptz(a) => build( - array, - a, - TimestampMicrosecondBuilder::with_capacity(a.len()), - |b, v| b.append_option(v.map(|d| d.into_arrow())), - ), - ArrayImpl::Time(a) => build( - array, - a, - Time64MicrosecondBuilder::with_capacity(a.len()), - |b, v| b.append_option(v.map(|d| d.into_arrow())), - ), - ArrayImpl::Jsonb(a) => build( - array, - a, - LargeStringBuilder::with_capacity(a.len(), a.len() * 16), - |b, v| b.append_option(v.map(|j| j.to_string())), - ), - ArrayImpl::Serial(_) => todo!("list of serial"), - ArrayImpl::Struct(a) => { - let values = Arc::new(arrow_array::StructArray::try_from(a)?); - arrow_array::ListArray::new( - Arc::new(arrow_schema::Field::new( - "item", - a.data_type().try_into()?, - true, - )), - arrow_buffer::OffsetBuffer::new(arrow_buffer::ScalarBuffer::from( - array - .offsets() - .iter() - .map(|o| *o as i32) - .collect::>(), - )), - values, - Some(array.null_bitmap().into()), - ) - } - ArrayImpl::List(_) => todo!("list of list"), - ArrayImpl::Bytea(a) => build( - array, - a, - BinaryBuilder::with_capacity(a.len(), a.data().len()), - |b, v| b.append_option(v), - ), - }) - } -} - -impl TryFrom<&arrow_array::ListArray> for ListArray { - type Error = ArrayError; - - fn try_from(array: &arrow_array::ListArray) -> Result { - use arrow_array::Array; - Ok(ListArray { - value: Box::new(ArrayImpl::try_from(array.values())?), - bitmap: match array.nulls() { - Some(nulls) => nulls.iter().collect(), - None => Bitmap::ones(array.len()), - }, - offsets: array.offsets().iter().map(|o| *o as u32).collect(), - }) - } -} - -impl TryFrom<&StructArray> for arrow_array::StructArray { - type Error = ArrayError; - - fn try_from(array: &StructArray) -> Result { - Ok(arrow_array::StructArray::new( - array.data_type().as_struct().try_into()?, - array - .fields() - .map(|arr| arr.as_ref().try_into()) - .try_collect::<_, _, ArrayError>()?, - Some(array.null_bitmap().into()), - )) - } -} - -impl TryFrom<&arrow_array::StructArray> for StructArray { - type Error = ArrayError; - - fn try_from(array: &arrow_array::StructArray) -> Result { - use arrow_array::Array; - let arrow_schema::DataType::Struct(fields) = array.data_type() else { - panic!("nested field types cannot be determined."); - }; - Ok(StructArray::new( - fields.into(), - array - .columns() - .iter() - .map(|a| ArrayImpl::try_from(a).map(Arc::new)) - .try_collect()?, - (0..array.len()).map(|i| !array.is_null(i)).collect(), - )) - } -} - #[cfg(test)] mod tests { - use super::arrow_array::Array as _; use super::*; #[test] @@ -1293,6 +1042,20 @@ mod tests { assert_eq!(I16Array::from(&arrow), array); } + #[test] + fn i32() { + let array = I32Array::from_iter([None, Some(-7), Some(25)]); + let arrow = arrow_array::Int32Array::from(&array); + assert_eq!(I32Array::from(&arrow), array); + } + + #[test] + fn i64() { + let array = I64Array::from_iter([None, Some(-7), Some(25)]); + let arrow = arrow_array::Int64Array::from(&array); + assert_eq!(I64Array::from(&arrow), array); + } + #[test] fn f32() { let array = F32Array::from_iter([None, Some(-7.0), Some(25.0)]); @@ -1300,6 +1063,13 @@ mod tests { assert_eq!(F32Array::from(&arrow), array); } + #[test] + fn f64() { + let array = F64Array::from_iter([None, Some(-7.0), 
Some(25.0)]); + let arrow = arrow_array::Float64Array::from(&array); + assert_eq!(F64Array::from(&arrow), array); + } + #[test] fn date() { let array = DateArray::from_iter([ @@ -1352,6 +1122,13 @@ mod tests { assert_eq!(Utf8Array::from(&arrow), array); } + #[test] + fn binary() { + let array = BytesArray::from_iter([None, Some("array".as_bytes())]); + let arrow = arrow_array::BinaryArray::from(&array); + assert_eq!(BytesArray::from(&arrow), array); + } + #[test] fn decimal() { let array = DecimalArray::from_iter([ @@ -1364,6 +1141,9 @@ mod tests { ]); let arrow = arrow_array::LargeBinaryArray::from(&array); assert_eq!(DecimalArray::try_from(&arrow).unwrap(), array); + + let arrow = arrow_array::StringArray::from(&array); + assert_eq!(DecimalArray::try_from(&arrow).unwrap(), array); } #[test] @@ -1378,6 +1158,9 @@ mod tests { ]); let arrow = arrow_array::LargeStringArray::from(&array); assert_eq!(JsonbArray::try_from(&arrow).unwrap(), array); + + let arrow = arrow_array::StringArray::from(&array); + assert_eq!(JsonbArray::try_from(&arrow).unwrap(), array); } #[test] @@ -1403,62 +1186,4 @@ mod tests { let arrow = arrow_array::Decimal256Array::from(&array); assert_eq!(Int256Array::from(&arrow), array); } - - #[test] - fn struct_array() { - // Empty array - risingwave to arrow conversion. - let test_arr = StructArray::new(StructType::empty(), vec![], Bitmap::ones(0)); - assert_eq!( - arrow_array::StructArray::try_from(&test_arr).unwrap().len(), - 0 - ); - - // Empty array - arrow to risingwave conversion. - let test_arr_2 = arrow_array::StructArray::from(vec![]); - assert_eq!(StructArray::try_from(&test_arr_2).unwrap().len(), 0); - - // Struct array with primitive types. arrow to risingwave conversion. - let test_arrow_struct_array = arrow_array::StructArray::try_from(vec![ - ( - "a", - Arc::new(arrow_array::BooleanArray::from(vec![ - Some(false), - Some(false), - Some(true), - None, - ])) as arrow_array::ArrayRef, - ), - ( - "b", - Arc::new(arrow_array::Int32Array::from(vec![ - Some(42), - Some(28), - Some(19), - None, - ])) as arrow_array::ArrayRef, - ), - ]) - .unwrap(); - let actual_risingwave_struct_array = - StructArray::try_from(&test_arrow_struct_array).unwrap(); - let expected_risingwave_struct_array = StructArray::new( - StructType::new(vec![("a", DataType::Boolean), ("b", DataType::Int32)]), - vec![ - BoolArray::from_iter([Some(false), Some(false), Some(true), None]).into_ref(), - I32Array::from_iter([Some(42), Some(28), Some(19), None]).into_ref(), - ], - [true, true, true, true].into_iter().collect(), - ); - assert_eq!( - expected_risingwave_struct_array, - actual_risingwave_struct_array - ); - } - - #[test] - fn list() { - let array = ListArray::from_iter([None, Some(vec![0, -127, 127, 50]), Some(vec![0; 0])]); - let arrow = arrow_array::ListArray::try_from(&array).unwrap(); - assert_eq!(ListArray::try_from(&arrow).unwrap(), array); - } } diff --git a/src/common/src/array/arrow/mod.rs b/src/common/src/array/arrow/mod.rs index 4baea60f11b3e..cb726721c867b 100644 --- a/src/common/src/array/arrow/mod.rs +++ b/src/common/src/array/arrow/mod.rs @@ -16,8 +16,6 @@ mod arrow_default; mod arrow_deltalake; mod arrow_iceberg; -pub use arrow_default::{ - to_record_batch_with_schema, ToArrowArrayWithTypeConvert, ToArrowTypeConvert, -}; -pub use arrow_deltalake::to_deltalake_record_batch_with_schema; -pub use arrow_iceberg::{iceberg_to_arrow_type, to_iceberg_record_batch_with_schema}; +pub use arrow_default::{FromArrow, ToArrow, UdfArrowConvert}; +pub use arrow_deltalake::DeltaLakeConvert; 
+pub use arrow_iceberg::IcebergArrowConvert;
diff --git a/src/common/src/array/bytes_array.rs b/src/common/src/array/bytes_array.rs
index 1257730b63c96..2019c37271919 100644
--- a/src/common/src/array/bytes_array.rs
+++ b/src/common/src/array/bytes_array.rs
@@ -108,12 +108,6 @@ impl Array for BytesArray {
     }
 }
 
-impl BytesArray {
-    pub(super) fn data(&self) -> &[u8] {
-        &self.data
-    }
-}
-
 impl<'a> FromIterator<Option<&'a [u8]>> for BytesArray {
     fn from_iter<I: IntoIterator<Item = Option<&'a [u8]>>>(iter: I) -> Self {
         let iter = iter.into_iter();
diff --git a/src/common/src/array/list_array.rs b/src/common/src/array/list_array.rs
index dae9a4a94bc93..748ab6777fa49 100644
--- a/src/common/src/array/list_array.rs
+++ b/src/common/src/array/list_array.rs
@@ -239,6 +239,11 @@ impl ListArray {
         }
     }
 
+    /// Return the inner array of the list array.
+    pub fn values(&self) -> &ArrayImpl {
+        &self.value
+    }
+
     pub fn from_protobuf(array: &PbArray) -> ArrayResult<Self> {
         ensure!(
             array.values.is_empty(),
diff --git a/src/common/src/array/mod.rs b/src/common/src/array/mod.rs
index 93f6255038e9e..268c518b70c01 100644
--- a/src/common/src/array/mod.rs
+++ b/src/common/src/array/mod.rs
@@ -14,11 +14,7 @@
 
 //! `Array` defines all in-memory representations of vectorized execution framework.
 
-mod arrow;
-pub use arrow::{
-    iceberg_to_arrow_type, to_deltalake_record_batch_with_schema,
-    to_iceberg_record_batch_with_schema, to_record_batch_with_schema,
-};
+pub mod arrow;
 mod bool_array;
 pub mod bytes_array;
 mod chrono_array;
diff --git a/src/common/src/array/utf8_array.rs b/src/common/src/array/utf8_array.rs
index eddcc50bb8ec2..72068d80733a2 100644
--- a/src/common/src/array/utf8_array.rs
+++ b/src/common/src/array/utf8_array.rs
@@ -112,10 +112,6 @@ impl Utf8Array {
         }
         builder.finish()
     }
-
-    pub(super) fn data(&self) -> &[u8] {
-        self.bytes.data()
-    }
 }
 
 /// `Utf8ArrayBuilder` use `&str` to build an `Utf8Array`.
@@ -297,12 +293,6 @@ mod tests {
 
         let array = Utf8Array::from_iter(&input);
         assert_eq!(array.len(), input.len());
-
-        assert_eq!(
-            array.bytes.data().len(),
-            input.iter().map(|s| s.unwrap_or("").len()).sum::<usize>()
-        );
-
         assert_eq!(input, array.iter().collect_vec());
     }
 

From c37dfc765419376193ee27aecb9bd23c30d2ee90 Mon Sep 17 00:00:00 2001
From: Runji Wang
Date: Tue, 30 Apr 2024 23:30:32 +0800
Subject: [PATCH 02/27] fix usages

Signed-off-by: Runji Wang
---
 src/batch/src/executor/iceberg_scan.rs        |  9 +---
 src/common/src/array/arrow/arrow_impl.rs      |  5 ++-
 src/connector/src/sink/deltalake.rs           |  6 ++-
 src/connector/src/sink/iceberg/mod.rs         | 38 +++++++----------
 .../src/source/pulsar/source/reader.rs        |  6 ++-
 src/expr/core/src/expr/expr_udf.rs            | 35 +++++++---------
 src/expr/core/src/table_function/mod.rs       |  2 +-
 .../core/src/table_function/user_defined.rs   | 38 ++++++++---------
 src/expr/impl/src/scalar/external/iceberg.rs  | 42 ++++++++++++-------
 src/frontend/src/handler/create_function.rs   | 11 ++---
 src/frontend/src/handler/create_sink.rs       |  7 ++--
 src/frontend/src/handler/create_source.rs     | 10 ++---
 12 files changed, 101 insertions(+), 108 deletions(-)

diff --git a/src/batch/src/executor/iceberg_scan.rs b/src/batch/src/executor/iceberg_scan.rs
index e20b1aaedb8fe..31abcc73a2aac 100644
--- a/src/batch/src/executor/iceberg_scan.rs
+++ b/src/batch/src/executor/iceberg_scan.rs
@@ -13,13 +13,13 @@
 // limitations under the License.
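For reference, the pattern this patch migrates every call site to looks roughly like the following. This is an illustrative sketch against the new trait API, not part of the diff; the schema construction and the `ArrayResult` error alias here are assumptions.

```rust
// Sketch: converting a DataChunk to an Arrow RecordBatch and back via the
// converter types introduced in patch 01. The single-column Int32 schema
// is made up for illustration.
use std::sync::Arc;

use risingwave_common::array::arrow::{FromArrow, ToArrow, UdfArrowConvert};
use risingwave_common::array::{ArrayResult, DataChunk};

fn round_trip(chunk: &DataChunk) -> ArrayResult<DataChunk> {
    // The Arrow schema is now passed in explicitly, so dialect-specific
    // converters (Iceberg, DeltaLake, UDF) can map the same RisingWave
    // type to different Arrow types.
    let schema = Arc::new(arrow_schema::Schema::new(vec![
        arrow_schema::Field::new("v", arrow_schema::DataType::Int32, true),
    ]));
    let batch = UdfArrowConvert.to_record_batch(schema, chunk)?;
    // Converting back dispatches on the Arrow data types in the batch.
    UdfArrowConvert.from_record_batch(&batch)
}
```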
 use std::hash::{DefaultHasher, Hash, Hasher};
-use std::sync::Arc;
 
 use anyhow::anyhow;
 use arrow_array::RecordBatch;
 use futures_async_stream::try_stream;
 use futures_util::stream::StreamExt;
 use icelake::io::{FileScan, TableScan};
+use risingwave_common::array::arrow::{FromArrow, IcebergArrowConvert};
 use risingwave_common::catalog::Schema;
 use risingwave_connector::sink::iceberg::IcebergConfig;
 
@@ -150,12 +150,7 @@ impl IcebergScanExecutor {
     }
 
     fn record_batch_to_chunk(record_batch: RecordBatch) -> Result<DataChunk> {
-        let mut columns = Vec::with_capacity(record_batch.num_columns());
-        for array in record_batch.columns() {
-            let column = Arc::new(array.try_into()?);
-            columns.push(column);
-        }
-        Ok(DataChunk::new(columns, record_batch.num_rows()))
+        Ok(IcebergArrowConvert.from_record_batch(&record_batch)?)
     }
 }
 
diff --git a/src/common/src/array/arrow/arrow_impl.rs b/src/common/src/array/arrow/arrow_impl.rs
index b29f62eaf2b10..9a48f4583b6c2 100644
--- a/src/common/src/array/arrow/arrow_impl.rs
+++ b/src/common/src/array/arrow/arrow_impl.rs
@@ -409,6 +409,7 @@ pub trait ToArrow {
 }
 
 /// Defines how to convert Arrow arrays to RisingWave arrays.
+#[allow(clippy::wrong_self_convention)]
 pub trait FromArrow {
     /// Converts Arrow `RecordBatch` to RisingWave `DataChunk`.
     fn from_record_batch(&self, batch: &arrow_array::RecordBatch) -> Result<DataChunk> {
@@ -469,12 +470,12 @@ pub trait FromArrow {
         })
     }
 
-    /// Converts Arrow LargeUtf8 type to RisingWave data type.
+    /// Converts Arrow `LargeUtf8` type to RisingWave data type.
     fn from_large_utf8(&self) -> Result<DataType> {
         Ok(DataType::Jsonb)
     }
 
-    /// Converts Arrow LargeBinary type to RisingWave data type.
+    /// Converts Arrow `LargeBinary` type to RisingWave data type.
     fn from_large_binary(&self) -> Result<DataType> {
         Ok(DataType::Decimal)
     }
diff --git a/src/connector/src/sink/deltalake.rs b/src/connector/src/sink/deltalake.rs
index 388a75c264a71..e81bb141a4c37 100644
--- a/src/connector/src/sink/deltalake.rs
+++ b/src/connector/src/sink/deltalake.rs
@@ -25,7 +25,8 @@ use deltalake::table::builder::s3_storage_options::{
 };
 use deltalake::writer::{DeltaWriter, RecordBatchWriter};
 use deltalake::DeltaTable;
-use risingwave_common::array::{to_deltalake_record_batch_with_schema, StreamChunk};
+use risingwave_common::array::arrow::DeltaLakeConvert;
+use risingwave_common::array::StreamChunk;
 use risingwave_common::bail;
 use risingwave_common::buffer::Bitmap;
 use risingwave_common::catalog::Schema;
@@ -388,7 +389,8 @@ impl DeltaLakeSinkWriter {
     }
 
     async fn write(&mut self, chunk: StreamChunk) -> Result<()> {
-        let a = to_deltalake_record_batch_with_schema(self.dl_schema.clone(), &chunk)
+        let a = DeltaLakeConvert
+            .to_record_batch(self.dl_schema.clone(), &chunk)
             .context("convert record batch error")
             .map_err(SinkError::DeltaLake)?;
         self.writer.write(a).await?;
diff --git a/src/connector/src/sink/iceberg/mod.rs b/src/connector/src/sink/iceberg/mod.rs
index 22972d5629076..f2e3b45d5915c 100644
--- a/src/connector/src/sink/iceberg/mod.rs
+++ b/src/connector/src/sink/iceberg/mod.rs
@@ -41,9 +41,8 @@ use icelake::transaction::Transaction;
 use icelake::types::{data_file_from_json, data_file_to_json, Any, DataFile};
 use icelake::{Table, TableIdentifier};
 use itertools::Itertools;
-use risingwave_common::array::{
-    iceberg_to_arrow_type, to_iceberg_record_batch_with_schema, Op, StreamChunk,
-};
+use risingwave_common::array::arrow::{IcebergArrowConvert, ToArrow};
+use risingwave_common::array::{Op, StreamChunk};
 use risingwave_common::bail;
 use risingwave_common::buffer::Bitmap;
 use risingwave_common::catalog::Schema;
@@ -475,7 +474,7 @@ impl IcebergSink {
             .try_into()
             .map_err(|err: icelake::Error| SinkError::Iceberg(anyhow!(err)))?;
 
-        try_matches_arrow_schema(&sink_schema, &iceberg_schema, false)
+        try_matches_arrow_schema(&sink_schema, &iceberg_schema)
             .map_err(|err| SinkError::Iceberg(anyhow!(err)))?;
 
         Ok(table)
@@ -797,14 +796,15 @@ impl SinkWriter for IcebergWriter {
                 let filters =
                     chunk.visibility() & ops.iter().map(|op| *op == Op::Insert).collect::<Bitmap>();
                 chunk.set_visibility(filters);
-                let chunk =
-                    to_iceberg_record_batch_with_schema(self.schema.clone(), &chunk.compact())
-                        .map_err(|err| SinkError::Iceberg(anyhow!(err)))?;
+                let chunk = IcebergArrowConvert
+                    .to_record_batch(self.schema.clone(), &chunk.compact())
+                    .map_err(|err| SinkError::Iceberg(anyhow!(err)))?;
 
                 writer.write(chunk).await?;
             }
             IcebergWriterEnum::Upsert(writer) => {
-                let chunk = to_iceberg_record_batch_with_schema(self.schema.clone(), &chunk)
+                let chunk = IcebergArrowConvert
+                    .to_record_batch(self.schema.clone(), &chunk)
                     .map_err(|err| SinkError::Iceberg(anyhow!(err)))?;
 
                 writer
@@ -1002,11 +1002,9 @@ impl SinkCommitCoordinator for IcebergSinkCommitter {
 }
 
 /// Try to match our schema with iceberg schema.
-/// `for_source` = true means the schema is used for source, otherwise it's used for sink.
 pub fn try_matches_arrow_schema(
     rw_schema: &Schema,
     arrow_schema: &ArrowSchema,
-    for_source: bool,
 ) -> anyhow::Result<()> {
     if rw_schema.fields.len() != arrow_schema.fields().len() {
         bail!(
@@ -1029,17 +1027,11 @@ pub fn try_matches_arrow_schema(
             .ok_or_else(|| anyhow!("Field {} not found in our schema", arrow_field.name()))?;
 
         // Iceberg source should be able to read iceberg decimal type.
-        // Since the arrow type default conversion is used by udf, in udf, decimal is converted to
-        // large binary type which is not compatible with iceberg decimal type,
-        // so we need to convert it to decimal type manually.
-        let converted_arrow_data_type = if for_source
-            && matches!(our_field_type, risingwave_common::types::DataType::Decimal)
-        {
-            // RisingWave decimal type cannot specify precision and scale, so we use the default value.
-            ArrowDataType::Decimal128(38, 0)
-        } else {
-            iceberg_to_arrow_type(our_field_type).map_err(|e| anyhow!(e))?
-        };
+        let converted_arrow_data_type = IcebergArrowConvert
+            .to_arrow_field("", our_field_type)
+            .map_err(|e| anyhow!(e))?
+ .data_type() + .clone(); let compatible = match (&converted_arrow_data_type, arrow_field.data_type()) { (ArrowDataType::Decimal128(_, _), ArrowDataType::Decimal128(_, _)) => true, @@ -1080,7 +1072,7 @@ mod test { ArrowField::new("c", ArrowDataType::Int32, false), ]); - try_matches_arrow_schema(&risingwave_schema, &arrow_schema, false).unwrap(); + try_matches_arrow_schema(&risingwave_schema, &arrow_schema).unwrap(); let risingwave_schema = Schema::new(vec![ Field::with_name(DataType::Int32, "d"), @@ -1094,7 +1086,7 @@ mod test { ArrowField::new("d", ArrowDataType::Int32, false), ArrowField::new("c", ArrowDataType::Int32, false), ]); - try_matches_arrow_schema(&risingwave_schema, &arrow_schema, false).unwrap(); + try_matches_arrow_schema(&risingwave_schema, &arrow_schema).unwrap(); } #[test] diff --git a/src/connector/src/source/pulsar/source/reader.rs b/src/connector/src/source/pulsar/source/reader.rs index ce808b96200cf..967874b62c335 100644 --- a/src/connector/src/source/pulsar/source/reader.rs +++ b/src/connector/src/source/pulsar/source/reader.rs @@ -27,7 +27,8 @@ use itertools::Itertools; use pulsar::consumer::InitialPosition; use pulsar::message::proto::MessageIdData; use pulsar::{Consumer, ConsumerBuilder, ConsumerOptions, Pulsar, SubType, TokioExecutor}; -use risingwave_common::array::{DataChunk, StreamChunk}; +use risingwave_common::array::arrow::{FromArrow, IcebergArrowConvert}; +use risingwave_common::array::StreamChunk; use risingwave_common::catalog::ROWID_PREFIX; use risingwave_common::{bail, ensure}; use thiserror_ext::AsReport; @@ -508,7 +509,8 @@ impl PulsarIcebergReader { offsets.push(offset); } - let data_chunk = DataChunk::try_from(&record_batch.project(&field_indices)?) + let data_chunk = IcebergArrowConvert + .from_record_batch(&record_batch.project(&field_indices)?) .context("failed to convert arrow record batch to data chunk")?; let stream_chunk = StreamChunk::from(data_chunk); diff --git a/src/expr/core/src/expr/expr_udf.rs b/src/expr/core/src/expr/expr_udf.rs index ed7d597cce52a..b9103b62649e5 100644 --- a/src/expr/core/src/expr/expr_udf.rs +++ b/src/expr/core/src/expr/expr_udf.rs @@ -13,13 +13,12 @@ // limitations under the License. 
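In the UDF expression path below, the argument schema is now built once from the declared argument types and reused for every batch. A minimal sketch of that step, assuming the `to_arrow_field` hook shown in this patch (the empty field name mirrors the diff; everything else is illustrative):

```rust
// Sketch: build the Arrow schema for a UDF's arguments at expression
// build time. `to_arrow_field` maps a RisingWave DataType to an Arrow
// Field and is the hook dialect-specific converters can override.
use std::sync::Arc;

use arrow_schema::{Fields, Schema};
use risingwave_common::array::arrow::{ToArrow, UdfArrowConvert};
use risingwave_common::array::ArrayResult;
use risingwave_common::types::DataType;

fn build_arg_schema(arg_types: &[DataType]) -> ArrayResult<Arc<Schema>> {
    let fields: Fields = arg_types
        .iter()
        .map(|t| UdfArrowConvert.to_arrow_field("", t))
        .collect::<ArrayResult<_>>()?;
    Ok(Arc::new(Schema::new(fields)))
}
```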
 use std::collections::HashMap;
-use std::convert::TryFrom;
 use std::sync::atomic::{AtomicU8, Ordering};
 use std::sync::{Arc, LazyLock, Weak};
 use std::time::Duration;
 
 use anyhow::{Context, Error};
-use arrow_schema::{Field, Fields, Schema};
+use arrow_schema::{Fields, Schema, SchemaRef};
 use arrow_udf_js::{CallMode as JsCallMode, Runtime as JsRuntime};
 #[cfg(feature = "embedded-deno-udf")]
 use arrow_udf_js_deno::{CallMode as DenoCallMode, Runtime as DenoRuntime};
@@ -29,14 +28,14 @@ use arrow_udf_wasm::Runtime as WasmRuntime;
 use await_tree::InstrumentAwait;
 use cfg_or_panic::cfg_or_panic;
 use moka::sync::Cache;
-use risingwave_common::array::{ArrayError, ArrayRef, DataChunk};
+use risingwave_common::array::arrow::{FromArrow, ToArrow, UdfArrowConvert};
+use risingwave_common::array::{ArrayRef, DataChunk};
 use risingwave_common::row::OwnedRow;
 use risingwave_common::types::{DataType, Datum};
 use risingwave_expr::expr_context::FRAGMENT_ID;
 use risingwave_pb::expr::ExprNode;
 use risingwave_udf::metrics::GLOBAL_METRICS;
 use risingwave_udf::ArrowFlightUdfClient;
-use thiserror_ext::AsReport;
 
 use super::{BoxedExpression, Build};
 use crate::expr::Expression;
@@ -47,8 +46,7 @@ pub struct UserDefinedFunction {
     children: Vec<BoxedExpression>,
     arg_types: Vec<DataType>,
     return_type: DataType,
-    #[expect(dead_code)]
-    arg_schema: Arc<Schema>,
+    arg_schema: SchemaRef,
     imp: UdfImpl,
     identifier: String,
     span: await_tree::Span,
@@ -117,7 +115,7 @@ impl Expression for UserDefinedFunction {
 
 impl UserDefinedFunction {
     async fn eval_inner(&self, input: &DataChunk) -> Result<ArrayRef> {
         // this will drop invisible rows
-        let arrow_input = arrow_array::RecordBatch::try_from(input)?;
+        let arrow_input = UdfArrowConvert.to_record_batch(self.arg_schema.clone(), input)?;
 
         // metrics
         let metrics = &*GLOBAL_METRICS;
@@ -225,7 +223,7 @@ impl UserDefinedFunction {
             );
         }
 
-        let output = DataChunk::try_from(&arrow_output)?;
+        let output = UdfArrowConvert.from_record_batch(&arrow_output)?;
         let output = output.uncompact(input.visibility().clone());
 
         let Some(array) = output.columns().first() else {
@@ -251,6 +249,11 @@ impl Build for UserDefinedFunction {
         let return_type = DataType::from(prost.get_return_type().unwrap());
         let udf = prost.get_rex_node().unwrap().as_udf().unwrap();
 
+        let arrow_return_type = UdfArrowConvert
+            .to_arrow_field("", &return_type)?
+ .data_type() + .clone(); + #[cfg(not(feature = "embedded-deno-udf"))] let runtime = "quickjs"; @@ -280,7 +283,7 @@ impl Build for UserDefinedFunction { ); rt.add_function( identifier, - arrow_schema::DataType::try_from(&return_type)?, + arrow_return_type, JsCallMode::CalledOnNullInput, &body, )?; @@ -321,7 +324,7 @@ impl Build for UserDefinedFunction { futures::executor::block_on(rt.add_function( identifier, - arrow_schema::DataType::try_from(&return_type)?, + arrow_return_type, DenoCallMode::CalledOnNullInput, &body, ))?; @@ -334,7 +337,7 @@ impl Build for UserDefinedFunction { let body = udf.get_body()?; rt.add_function( identifier, - arrow_schema::DataType::try_from(&return_type)?, + arrow_return_type, PythonCallMode::CalledOnNullInput, body, )?; @@ -352,15 +355,7 @@ impl Build for UserDefinedFunction { let arg_schema = Arc::new(Schema::new( udf.arg_types .iter() - .map::, _>(|t| { - Ok(Field::new( - "", - DataType::from(t).try_into().map_err(|e: ArrayError| { - risingwave_udf::Error::unsupported(e.to_report_string()) - })?, - true, - )) - }) + .map(|t| UdfArrowConvert.to_arrow_field("", &DataType::from(t))) .try_collect::()?, )); diff --git a/src/expr/core/src/table_function/mod.rs b/src/expr/core/src/table_function/mod.rs index 9d2d747cc4ed5..f1593cc4015c8 100644 --- a/src/expr/core/src/table_function/mod.rs +++ b/src/expr/core/src/table_function/mod.rs @@ -16,7 +16,6 @@ use either::Either; use futures_async_stream::try_stream; use futures_util::stream::BoxStream; use futures_util::StreamExt; -use itertools::Itertools; use risingwave_common::array::{Array, ArrayBuilder, ArrayImpl, ArrayRef, DataChunk}; use risingwave_common::types::{DataType, DatumRef}; use risingwave_pb::expr::project_set_select_item::SelectItem; @@ -129,6 +128,7 @@ pub fn build( chunk_size: usize, children: Vec, ) -> Result { + use itertools::Itertools; let args = children.iter().map(|t| t.return_type()).collect_vec(); let desc = crate::sig::FUNCTION_REGISTRY .get(func, &args, &return_type) diff --git a/src/expr/core/src/table_function/user_defined.rs b/src/expr/core/src/table_function/user_defined.rs index b65ee5e77758b..e5334ecb47932 100644 --- a/src/expr/core/src/table_function/user_defined.rs +++ b/src/expr/core/src/table_function/user_defined.rs @@ -16,7 +16,7 @@ use std::sync::Arc; use anyhow::Context; use arrow_array::RecordBatch; -use arrow_schema::{Field, Fields, Schema, SchemaRef}; +use arrow_schema::{Fields, Schema, SchemaRef}; use arrow_udf_js::{CallMode as JsCallMode, Runtime as JsRuntime}; #[cfg(feature = "embedded-deno-udf")] use arrow_udf_js_deno::{CallMode as DenoCallMode, Runtime as DenoRuntime}; @@ -24,9 +24,9 @@ use arrow_udf_js_deno::{CallMode as DenoCallMode, Runtime as DenoRuntime}; use arrow_udf_python::{CallMode as PythonCallMode, Runtime as PythonRuntime}; use cfg_or_panic::cfg_or_panic; use futures_util::stream; -use risingwave_common::array::{ArrayError, DataChunk, I32Array}; +use risingwave_common::array::arrow::{FromArrow, ToArrow, UdfArrowConvert}; +use risingwave_common::array::{DataChunk, I32Array}; use risingwave_common::bail; -use thiserror_ext::AsReport; use super::*; use crate::expr::expr_udf::UdfImpl; @@ -34,7 +34,6 @@ use crate::expr::expr_udf::UdfImpl; #[derive(Debug)] pub struct UserDefinedTableFunction { children: Vec, - #[allow(dead_code)] arg_schema: SchemaRef, return_type: DataType, client: UdfImpl, @@ -109,9 +108,9 @@ impl UserDefinedTableFunction { let direct_input = DataChunk::new(columns, input.visibility().clone()); // compact the input chunk and record the row 
mapping - let visible_rows = direct_input.visibility().iter_ones().collect_vec(); - let compacted_input = direct_input.compact_cow(); - let arrow_input = RecordBatch::try_from(compacted_input.as_ref())?; + let visible_rows = direct_input.visibility().iter_ones().collect::>(); + // this will drop invisible rows + let arrow_input = UdfArrowConvert.to_record_batch(self.arg_schema.clone(), input)?; // call UDTF #[for_await] @@ -119,7 +118,7 @@ impl UserDefinedTableFunction { .client .call_table_function(&self.identifier, arrow_input) { - let output = DataChunk::try_from(&res?)?; + let output = UdfArrowConvert.from_record_batch(&res?)?; self.check_output(&output)?; // we send the compacted input to UDF, so we need to map the row indices back to the @@ -182,21 +181,18 @@ pub fn new_user_defined(prost: &PbTableFunction, chunk_size: usize) -> Result, _>(|t| { - Ok(Field::new( - "", - DataType::from(t).try_into().map_err(|e: ArrayError| { - risingwave_udf::Error::unsupported(e.to_report_string()) - })?, - true, - )) - }) - .try_collect::<_, Fields, _>()?, + .map(|t| UdfArrowConvert.to_arrow_field("", &DataType::from(t))) + .try_collect::()?, )); let identifier = udtf.get_identifier()?; let return_type = DataType::from(prost.get_return_type()?); + let arrow_return_type = UdfArrowConvert + .to_arrow_field("", &return_type)? + .data_type() + .clone(); + #[cfg(not(feature = "embedded-deno-udf"))] let runtime = "quickjs"; @@ -224,7 +220,7 @@ pub fn new_user_defined(prost: &PbTableFunction, chunk_size: usize) -> Result Result Result Result { @@ -91,15 +92,20 @@ fn build(return_type: DataType, mut children: Vec) -> Result) -> Result) -> Result arrow_schema::Field { - arrow_schema::Field::new("", data_type, true) + fn to_field(data_type: &DataType) -> Result { + Ok(UdfArrowConvert.to_arrow_field("", data_type)?) 
     }
 
     let args = arrow_schema::Schema::new(
         arg_types
             .iter()
-            .map::<Result<_>, _>(|t| Ok(to_field(t.try_into()?)))
+            .map(|t| to_field(t))
             .try_collect::<_, Fields, _>()?,
     );
     let returns = arrow_schema::Schema::new(match kind {
-        Kind::Scalar(_) => vec![to_field(return_type.clone().try_into()?)],
+        Kind::Scalar(_) => vec![to_field(&return_type)?],
         Kind::Table(_) => vec![
             arrow_schema::Field::new("row_index", arrow_schema::DataType::Int32, true),
-            to_field(return_type.clone().try_into()?),
+            to_field(&return_type)?,
         ],
         _ => unreachable!(),
     });
diff --git a/src/frontend/src/handler/create_sink.rs b/src/frontend/src/handler/create_sink.rs
index bed409de178f1..50f6914eee252 100644
--- a/src/frontend/src/handler/create_sink.rs
+++ b/src/frontend/src/handler/create_sink.rs
@@ -22,8 +22,9 @@ use either::Either;
 use itertools::Itertools;
 use maplit::{convert_args, hashmap};
 use pgwire::pg_response::{PgResponse, StatementType};
+use risingwave_common::array::arrow::{FromArrow, IcebergArrowConvert};
 use risingwave_common::catalog::{ConnectionId, DatabaseId, SchemaId, TableId, UserId};
-use risingwave_common::types::{DataType, Datum};
+use risingwave_common::types::Datum;
 use risingwave_common::util::value_encoding::DatumFromProtoExt;
 use risingwave_common::{bail, catalog};
 use risingwave_connector::sink::catalog::{SinkCatalog, SinkFormatDesc, SinkType};
@@ -351,14 +352,14 @@ async fn get_partition_compute_info_for_iceberg(
         })
         .collect::<Result<Vec<_>>>()?;
 
-    let DataType::Struct(partition_type) = arrow_type.into() else {
+    let ArrowDataType::Struct(partition_type) = arrow_type else {
         return Err(RwError::from(ErrorCode::SinkError(
             "Partition type of iceberg should be a struct type".into(),
         )));
     };
 
     Ok(Some(PartitionComputeInfo::Iceberg(IcebergPartitionInfo {
-        partition_type,
+        partition_type: IcebergArrowConvert.from_fields(&partition_type)?,
         partition_fields,
     })))
 }
diff --git a/src/frontend/src/handler/create_source.rs b/src/frontend/src/handler/create_source.rs
index 0830cdb5392de..6a19cb13fad31 100644
--- a/src/frontend/src/handler/create_source.rs
+++ b/src/frontend/src/handler/create_source.rs
@@ -21,6 +21,7 @@ use either::Either;
 use itertools::Itertools;
 use maplit::{convert_args, hashmap};
 use pgwire::pg_response::{PgResponse, StatementType};
+use risingwave_common::array::arrow::{FromArrow, IcebergArrowConvert};
 use risingwave_common::bail_not_implemented;
 use risingwave_common::catalog::{
     is_column_ids_dedup, ColumnCatalog, ColumnDesc, ColumnId, Schema, TableId,
@@ -1219,11 +1220,10 @@ pub async fn extract_iceberg_columns(
             .iter()
             .enumerate()
             .map(|(i, field)| {
-                let data_type = field.data_type().clone();
                 let column_desc = ColumnDesc::named(
                     field.name(),
                     ColumnId::new((i as u32).try_into().unwrap()),
-                    data_type.into(),
+                    IcebergArrowConvert.from_field(field).unwrap(),
                 );
                 ColumnCatalog {
                     column_desc,
@@ -1288,11 +1288,7 @@ pub async fn check_iceberg_source(
         .collect::<Vec<_>>();
     let new_iceberg_schema = arrow_schema::Schema::new(new_iceberg_field);
 
-    risingwave_connector::sink::iceberg::try_matches_arrow_schema(
-        &schema,
-        &new_iceberg_schema,
-        true,
-    )?;
+    risingwave_connector::sink::iceberg::try_matches_arrow_schema(&schema, &new_iceberg_schema)?;
 
     Ok(())
 }

From 7ef80fed81c45f1f31c012f8c7706265dc9b0c9f Mon Sep 17 00:00:00 2001
From: Runji Wang
Date: Mon, 6 May 2024 22:45:20 +0800
Subject: [PATCH 03/27] fix unit test

Signed-off-by: Runji Wang
---
 src/common/src/array/arrow/arrow_impl.rs | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/common/src/array/arrow/arrow_impl.rs b/src/common/src/array/arrow/arrow_impl.rs
index 9a48f4583b6c2..23267e86ea4a0 100644
--- a/src/common/src/array/arrow/arrow_impl.rs
+++ b/src/common/src/array/arrow/arrow_impl.rs
@@ -233,12 +233,12 @@ pub trait ToArrow {
         };
         let values = self.to_array(field.data_type(), array.values())?;
         let offsets = OffsetBuffer::new(array.offsets().iter().map(|&o| o as i32).collect());
-        let nulls = array.null_bitmap().into();
+        let nulls = (!array.null_bitmap().all()).then(|| array.null_bitmap().into());
         Ok(Arc::new(arrow_array::ListArray::new(
             field.clone(),
             offsets,
             values,
-            Some(nulls),
+            nulls,
         )))
     }
 

From c218388d1bc80b2816c02b717ae52e5a43192647 Mon Sep 17 00:00:00 2001
From: Runji Wang
Date: Mon, 6 May 2024 22:46:01 +0800
Subject: [PATCH 04/27] fix clippy

Signed-off-by: Runji Wang
---
 src/frontend/src/handler/create_function.rs | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/frontend/src/handler/create_function.rs b/src/frontend/src/handler/create_function.rs
index 6c3f9e8b3a1ec..a23632fbd62c9 100644
--- a/src/frontend/src/handler/create_function.rs
+++ b/src/frontend/src/handler/create_function.rs
@@ -177,7 +177,7 @@ pub async fn handle_create_function(
     let args = arrow_schema::Schema::new(
         arg_types
             .iter()
-            .map(|t| to_field(t))
+            .map(to_field)
             .try_collect::<_, Fields, _>()?,
     );
     let returns = arrow_schema::Schema::new(match kind {

From f44db241144e051fe4d288d3d2493df0840ea6c2 Mon Sep 17 00:00:00 2001
From: Runji Wang
Date: Mon, 6 May 2024 22:50:30 +0800
Subject: [PATCH 05/27] add missing data type

Signed-off-by: Runji Wang
---
 src/common/src/array/arrow/arrow_impl.rs | 11 +++++++++++
 1 file changed, 11 insertions(+)

diff --git a/src/common/src/array/arrow/arrow_impl.rs b/src/common/src/array/arrow/arrow_impl.rs
index 23267e86ea4a0..5c1f8ac45fba3 100644
--- a/src/common/src/array/arrow/arrow_impl.rs
+++ b/src/common/src/array/arrow/arrow_impl.rs
@@ -494,6 +494,7 @@ pub trait FromArrow {
     /// Converts Arrow `Array` to RisingWave `ArrayImpl`.
     fn from_array(&self, array: &arrow_array::ArrayRef) -> Result<ArrayImpl> {
         use arrow_schema::DataType::*;
+        use arrow_schema::IntervalUnit::*;
         use arrow_schema::TimeUnit::*;
         match array.data_type() {
             Boolean => self.from_bool_array(array.as_any().downcast_ref().unwrap()),
@@ -508,6 +509,9 @@ pub trait FromArrow {
             Timestamp(Microsecond, _) => {
                 self.from_timestampus_array(array.as_any().downcast_ref().unwrap())
             }
+            Interval(MonthDayNano) => {
+                self.from_interval_array(array.as_any().downcast_ref().unwrap())
+            }
             Utf8 => self.from_utf8_array(array.as_any().downcast_ref().unwrap()),
             Binary => self.from_binary_array(array.as_any().downcast_ref().unwrap()),
             LargeUtf8 => self.from_large_utf8_array(array.as_any().downcast_ref().unwrap()),
@@ -575,6 +579,13 @@ pub trait FromArrow {
         Ok(ArrayImpl::Timestamp(array.into()))
     }
 
+    fn from_interval_array(
+        &self,
+        array: &arrow_array::IntervalMonthDayNanoArray,
+    ) -> Result<ArrayImpl> {
+        Ok(ArrayImpl::Interval(array.into()))
+    }
+
     fn from_utf8_array(&self, array: &arrow_array::StringArray) -> Result<ArrayImpl> {
         Ok(ArrayImpl::Utf8(array.into()))
     }

From 5341ee404489b7f02c365b01603767a592b81f60 Mon Sep 17 00:00:00 2001
From: Runji Wang
Date: Mon, 6 May 2024 23:10:06 +0800
Subject: [PATCH 06/27] fix unit test

Signed-off-by: Runji Wang
---
 src/common/src/array/arrow/arrow_deltalake.rs | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/common/src/array/arrow/arrow_deltalake.rs b/src/common/src/array/arrow/arrow_deltalake.rs
index 50eb5b86dca1b..c9f4052e2036d 100644
--- a/src/common/src/array/arrow/arrow_deltalake.rs
+++ b/src/common/src/array/arrow/arrow_deltalake.rs
@@ -126,7 +126,7 @@ mod test {
             arrow_schema::DataType::List(Arc::new(Field::new(
                 "test",
                 arrow_schema::DataType::Decimal128(10, 0),
-                false,
+                true,
             ))),
             false,
         )]);

From 3c6e87ac5fd9bcfe6332470cd6772d5d9d4dd0f0 Mon Sep 17 00:00:00 2001
From: Runji Wang
Date: Tue, 7 May 2024 12:06:56 +0800
Subject: [PATCH 07/27] fix udtf

Signed-off-by: Runji Wang
---
 src/expr/core/src/table_function/user_defined.rs | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/src/expr/core/src/table_function/user_defined.rs b/src/expr/core/src/table_function/user_defined.rs
index e5334ecb47932..4362ff27b57b9 100644
--- a/src/expr/core/src/table_function/user_defined.rs
+++ b/src/expr/core/src/table_function/user_defined.rs
@@ -110,7 +110,8 @@ impl UserDefinedTableFunction {
         // compact the input chunk and record the row mapping
         let visible_rows = direct_input.visibility().iter_ones().collect::<Vec<_>>();
         // this will drop invisible rows
-        let arrow_input = UdfArrowConvert.to_record_batch(self.arg_schema.clone(), input)?;
+        let arrow_input =
+            UdfArrowConvert.to_record_batch(self.arg_schema.clone(), &direct_input)?;
 
         // call UDTF
         #[for_await]

From 9ff46254a681383d1b44c3f6535b81fb7a000b17 Mon Sep 17 00:00:00 2001
From: Runji Wang
Date: Mon, 29 Apr 2024 15:02:59 +0800
Subject: [PATCH 08/27] remove java udf sdk

Signed-off-by: Runji Wang
---
 java/pom.xml                                  |   4 +-
 java/udf/CHANGELOG.md                         |  39 --
 java/udf/README.md                            | 274 ----------
 java/udf/pom.xml                              |  58 --
 .../risingwave/functions/DataTypeHint.java    |  23 -
 .../risingwave/functions/PeriodDuration.java  |  29 -
 .../risingwave/functions/ScalarFunction.java  |  53 --
 .../functions/ScalarFunctionBatch.java        |  61 ---
 .../risingwave/functions/TableFunction.java   |  60 ---
 .../functions/TableFunctionBatch.java         |  87 ---
 .../com/risingwave/functions/TypeUtils.java   | 505 ------------------
 .../com/risingwave/functions/UdfProducer.java | 108 ----
 .../com/risingwave/functions/UdfServer.java   |  81 ---
.../functions/UserDefinedFunction.java | 23 - .../functions/UserDefinedFunctionBatch.java | 87 --- .../risingwave/functions/TestUdfServer.java | 286 ---------- .../com/risingwave/functions/UdfClient.java | 51 -- 17 files changed, 1 insertion(+), 1828 deletions(-) delete mode 100644 java/udf/CHANGELOG.md delete mode 100644 java/udf/README.md delete mode 100644 java/udf/pom.xml delete mode 100644 java/udf/src/main/java/com/risingwave/functions/DataTypeHint.java delete mode 100644 java/udf/src/main/java/com/risingwave/functions/PeriodDuration.java delete mode 100644 java/udf/src/main/java/com/risingwave/functions/ScalarFunction.java delete mode 100644 java/udf/src/main/java/com/risingwave/functions/ScalarFunctionBatch.java delete mode 100644 java/udf/src/main/java/com/risingwave/functions/TableFunction.java delete mode 100644 java/udf/src/main/java/com/risingwave/functions/TableFunctionBatch.java delete mode 100644 java/udf/src/main/java/com/risingwave/functions/TypeUtils.java delete mode 100644 java/udf/src/main/java/com/risingwave/functions/UdfProducer.java delete mode 100644 java/udf/src/main/java/com/risingwave/functions/UdfServer.java delete mode 100644 java/udf/src/main/java/com/risingwave/functions/UserDefinedFunction.java delete mode 100644 java/udf/src/main/java/com/risingwave/functions/UserDefinedFunctionBatch.java delete mode 100644 java/udf/src/test/java/com/risingwave/functions/TestUdfServer.java delete mode 100644 java/udf/src/test/java/com/risingwave/functions/UdfClient.java diff --git a/java/pom.xml b/java/pom.xml index 922c62ead69e5..f1ee457ef3b84 100644 --- a/java/pom.xml +++ b/java/pom.xml @@ -37,8 +37,6 @@ proto - udf - udf-example java-binding common-utils java-binding-integration-test @@ -572,4 +570,4 @@ https://s01.oss.sonatype.org/service/local/staging/deploy/maven2/ - + \ No newline at end of file diff --git a/java/udf/CHANGELOG.md b/java/udf/CHANGELOG.md deleted file mode 100644 index fb1f055783225..0000000000000 --- a/java/udf/CHANGELOG.md +++ /dev/null @@ -1,39 +0,0 @@ -# Changelog - -All notable changes to this project will be documented in this file. - -The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), -and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). - -## [Unreleased] - -## [0.1.3] - 2023-12-06 - -### Fixed - -- Fix decimal type output. - -## [0.1.2] - 2023-12-04 - -### Fixed - -- Fix index-out-of-bound error when string or string list is large. -- Fix memory leak. - -## [0.1.1] - 2023-12-03 - -### Added - -- Support struct in struct and struct[] in struct. - -### Changed - -- Bump Arrow version to 14. - -### Fixed - -- Fix unconstrained decimal type. - -## [0.1.0] - 2023-09-01 - -- Initial release. \ No newline at end of file diff --git a/java/udf/README.md b/java/udf/README.md deleted file mode 100644 index 200b897b8b890..0000000000000 --- a/java/udf/README.md +++ /dev/null @@ -1,274 +0,0 @@ -# RisingWave Java UDF SDK - -This library provides a Java SDK for creating user-defined functions (UDF) in RisingWave. - -## Introduction - -RisingWave supports user-defined functions implemented as external functions. -With the RisingWave Java UDF SDK, users can define custom UDFs using Java and start a Java process as a UDF server. -RisingWave can then remotely access the UDF server to execute the defined functions. 
- -## Installation - -To install the RisingWave Java UDF SDK: - -```sh -git clone https://github.com/risingwavelabs/risingwave.git -cd risingwave/java/udf -mvn install -``` - -Or you can add the following dependency to your `pom.xml` file: - -```xml - - - com.risingwave - risingwave-udf - 0.1.0 - - -``` - - -## Creating a New Project - -> NOTE: You can also start from the [udf-example](../udf-example) project without creating the project from scratch. - -To create a new project using the RisingWave Java UDF SDK, follow these steps: - -```sh -mvn archetype:generate -DgroupId=com.example -DartifactId=udf-example -DarchetypeArtifactId=maven-archetype-quickstart -DarchetypeVersion=1.4 -DinteractiveMode=false -``` - -Configure your `pom.xml` file as follows: - -```xml - - - 4.0.0 - com.example - udf-example - 1.0-SNAPSHOT - - - - com.risingwave - risingwave-udf - 0.1.0 - - - -``` - -The `--add-opens` flag must be added when running unit tests through Maven: - -```xml - - - - org.apache.maven.plugins - maven-surefire-plugin - 3.0.0 - - --add-opens=java.base/java.nio=ALL-UNNAMED - - - - -``` - -## Scalar Functions - -A user-defined scalar function maps zero, one, or multiple scalar values to a new scalar value. - -In order to define a scalar function, one has to create a new class that implements the `ScalarFunction` -interface in `com.risingwave.functions` and implement exactly one evaluation method named `eval(...)`. -This method must be declared public and non-static. - -Any [data type](#data-types) listed in the data types section can be used as a parameter or return type of an evaluation method. - -Here's an example of a scalar function that calculates the greatest common divisor (GCD) of two integers: - -```java -import com.risingwave.functions.ScalarFunction; - -public class Gcd implements ScalarFunction { - public int eval(int a, int b) { - while (b != 0) { - int temp = b; - b = a % b; - a = temp; - } - return a; - } -} -``` - -> **NOTE:** Differences with Flink -> 1. The `ScalarFunction` is an interface instead of an abstract class. -> 2. Multiple overloaded `eval` methods are not supported. -> 3. Variable arguments such as `eval(Integer...)` are not supported. - -## Table Functions - -A user-defined table function maps zero, one, or multiple scalar values to one or multiple -rows (structured types). - -In order to define a table function, one has to create a new class that implements the `TableFunction` -interface in `com.risingwave.functions` and implement exactly one evaluation method named `eval(...)`. -This method must be declared public and non-static. - -The return type must be an `Iterator` of any [data type](#data-types) listed in the data types section. -Similar to scalar functions, input and output data types are automatically extracted using reflection. -This includes the generic argument T of the return value for determining an output data type. - -Here's an example of a table function that generates a series of integers: - -```java -import com.risingwave.functions.TableFunction; - -public class Series implements TableFunction { - public Iterator eval(int n) { - return java.util.stream.IntStream.range(0, n).iterator(); - } -} -``` - -> **NOTE:** Differences with Flink -> 1. The `TableFunction` is an interface instead of an abstract class. It has no generic arguments. -> 2. Instead of calling `collect` to emit a row, the `eval` method returns an `Iterator` of the output rows. -> 3. Multiple overloaded `eval` methods are not supported. -> 4. 
Variable arguments such as `eval(Integer...)` are not supported. -> 5. In SQL, table functions can be used in the `FROM` clause directly. `JOIN LATERAL TABLE` is not supported. - -## UDF Server - -To create a UDF server and register functions: - -```java -import com.risingwave.functions.UdfServer; - -public class App { - public static void main(String[] args) { - try (var server = new UdfServer("0.0.0.0", 8815)) { - // register functions - server.addFunction("gcd", new Gcd()); - server.addFunction("series", new Series()); - // start the server - server.start(); - server.awaitTermination(); - } catch (Exception e) { - e.printStackTrace(); - } - } -} -``` - -To run the UDF server, execute the following command: - -```sh -_JAVA_OPTIONS="--add-opens=java.base/java.nio=ALL-UNNAMED" mvn exec:java -Dexec.mainClass="com.example.App" -``` - -## Creating Functions in RisingWave - -```sql -create function gcd(int, int) returns int -as gcd using link 'http://localhost:8815'; - -create function series(int) returns table (x int) -as series using link 'http://localhost:8815'; -``` - -For more detailed information and examples, please refer to the official RisingWave [documentation](https://www.risingwave.dev/docs/current/user-defined-functions/#4-declare-your-functions-in-risingwave). - -## Using Functions in RisingWave - -Once the user-defined functions are created in RisingWave, you can use them in SQL queries just like any built-in functions. Here are a few examples: - -```sql -select gcd(25, 15); - -select * from series(10); -``` - -## Data Types - -The RisingWave Java UDF SDK supports the following data types: - -| SQL Type | Java Type | Notes | -| ---------------- | --------------------------------------- | ------------------ | -| BOOLEAN | boolean, Boolean | | -| SMALLINT | short, Short | | -| INT | int, Integer | | -| BIGINT | long, Long | | -| REAL | float, Float | | -| DOUBLE PRECISION | double, Double | | -| DECIMAL | BigDecimal | | -| DATE | java.time.LocalDate | | -| TIME | java.time.LocalTime | | -| TIMESTAMP | java.time.LocalDateTime | | -| INTERVAL | com.risingwave.functions.PeriodDuration | | -| VARCHAR | String | | -| BYTEA | byte[] | | -| JSONB | String | Use `@DataTypeHint("JSONB") String` as the type. See [example](#jsonb). | -| T[] | T'[] | `T` can be any of the above SQL types. `T'` should be the corresponding Java type.| -| STRUCT<> | user-defined class | Define a data class as the type. See [example](#struct-type). | -| ...others | | Not supported yet. | - -### JSONB - -```java -import com.google.gson.Gson; - -// Returns the i-th element of a JSON array. -public class JsonbAccess implements ScalarFunction { - static Gson gson = new Gson(); - - public @DataTypeHint("JSONB") String eval(@DataTypeHint("JSONB") String json, int index) { - if (json == null) - return null; - var array = gson.fromJson(json, Object[].class); - if (index >= array.length || index < 0) - return null; - var obj = array[index]; - return gson.toJson(obj); - } -} -``` - -```sql -create function jsonb_access(jsonb, int) returns jsonb -as jsonb_access using link 'http://localhost:8815'; -``` - -### Struct Type - -```java -// Split a socket address into host and port. 
-public static class IpPort implements ScalarFunction { - public static class SocketAddr { - public String host; - public short port; - } - - public SocketAddr eval(String addr) { - var socketAddr = new SocketAddr(); - var parts = addr.split(":"); - socketAddr.host = parts[0]; - socketAddr.port = Short.parseShort(parts[1]); - return socketAddr; - } -} -``` - -```sql -create function ip_port(varchar) returns struct -as ip_port using link 'http://localhost:8815'; -``` - -## Full Example - -You can checkout [udf-example](../udf-example) and use it as a template to create your own UDFs. diff --git a/java/udf/pom.xml b/java/udf/pom.xml deleted file mode 100644 index f747603ca8429..0000000000000 --- a/java/udf/pom.xml +++ /dev/null @@ -1,58 +0,0 @@ - - 4.0.0 - - com.risingwave - risingwave-udf - jar - 0.1.3-SNAPSHOT - - - risingwave-java-root - com.risingwave - 0.1.0-SNAPSHOT - ../pom.xml - - - RisingWave Java UDF SDK - https://docs.risingwave.com/docs/current/udf-java - - - - org.junit.jupiter - junit-jupiter-engine - 5.9.1 - test - - - org.apache.arrow - arrow-vector - 14.0.0 - - - org.apache.arrow - flight-core - 14.0.0 - - - org.slf4j - slf4j-api - 2.0.7 - - - org.slf4j - slf4j-simple - 2.0.7 - - - - - - kr.motd.maven - os-maven-plugin - 1.7.0 - - - - \ No newline at end of file diff --git a/java/udf/src/main/java/com/risingwave/functions/DataTypeHint.java b/java/udf/src/main/java/com/risingwave/functions/DataTypeHint.java deleted file mode 100644 index 7baf0fe4c6115..0000000000000 --- a/java/udf/src/main/java/com/risingwave/functions/DataTypeHint.java +++ /dev/null @@ -1,23 +0,0 @@ -// Copyright 2024 RisingWave Labs -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package com.risingwave.functions; - -import java.lang.annotation.*; - -@Retention(RetentionPolicy.RUNTIME) -@Target({ElementType.METHOD, ElementType.FIELD, ElementType.PARAMETER}) -public @interface DataTypeHint { - String value(); -} diff --git a/java/udf/src/main/java/com/risingwave/functions/PeriodDuration.java b/java/udf/src/main/java/com/risingwave/functions/PeriodDuration.java deleted file mode 100644 index 6d704100f6f35..0000000000000 --- a/java/udf/src/main/java/com/risingwave/functions/PeriodDuration.java +++ /dev/null @@ -1,29 +0,0 @@ -// Copyright 2024 RisingWave Labs -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package com.risingwave.functions; - -import java.time.Duration; -import java.time.Period; - -/** Combination of Period and Duration. 
*/ -public class PeriodDuration extends org.apache.arrow.vector.PeriodDuration { - public PeriodDuration(Period period, Duration duration) { - super(period, duration); - } - - PeriodDuration(org.apache.arrow.vector.PeriodDuration base) { - super(base.getPeriod(), base.getDuration()); - } -} diff --git a/java/udf/src/main/java/com/risingwave/functions/ScalarFunction.java b/java/udf/src/main/java/com/risingwave/functions/ScalarFunction.java deleted file mode 100644 index 5f3fcaf287330..0000000000000 --- a/java/udf/src/main/java/com/risingwave/functions/ScalarFunction.java +++ /dev/null @@ -1,53 +0,0 @@ -// Copyright 2024 RisingWave Labs -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package com.risingwave.functions; - -/** - * Base interface for a user-defined scalar function. A user-defined scalar function maps zero, one, - * or multiple scalar values to a new scalar value. - * - *
- * <p>The behavior of a {@link ScalarFunction} can be defined by implementing a custom evaluation
- * method. An evaluation method must be declared publicly, not static, and named <code>eval</code>.
- * Multiple overloaded methods named <code>eval</code> are not supported yet.
- *
- * <p>By default, input and output data types are automatically extracted using reflection.
- *
- * <p>The following examples show how to specify a scalar function:
- *
- * <pre>{@code
- * // a function that accepts two INT arguments and computes a sum
- * class SumFunction implements ScalarFunction {
- *     public Integer eval(Integer a, Integer b) {
- *         return a + b;
- *     }
- * }
- *
- * // a function that returns a struct type
- * class StructFunction implements ScalarFunction {
- *     public static class KeyValue {
- *         public String key;
- *         public int value;
- *     }
- *
- *     public KeyValue eval(int a) {
- *         KeyValue kv = new KeyValue();
- *         kv.key = String.valueOf(a);
- *         kv.value = a;
- *         return kv;
- *     }
- * }
- * }</pre>
- */ -public interface ScalarFunction extends UserDefinedFunction {} diff --git a/java/udf/src/main/java/com/risingwave/functions/ScalarFunctionBatch.java b/java/udf/src/main/java/com/risingwave/functions/ScalarFunctionBatch.java deleted file mode 100644 index 5d837d3b370f9..0000000000000 --- a/java/udf/src/main/java/com/risingwave/functions/ScalarFunctionBatch.java +++ /dev/null @@ -1,61 +0,0 @@ -// Copyright 2024 RisingWave Labs -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package com.risingwave.functions; - -import java.lang.invoke.MethodHandle; -import java.util.Collections; -import java.util.Iterator; -import java.util.function.Function; -import org.apache.arrow.memory.BufferAllocator; -import org.apache.arrow.vector.VectorSchemaRoot; - -/** Batch-processing wrapper over a user-defined scalar function. */ -class ScalarFunctionBatch extends UserDefinedFunctionBatch { - ScalarFunction function; - MethodHandle methodHandle; - Function[] processInputs; - - ScalarFunctionBatch(ScalarFunction function) { - this.function = function; - var method = Reflection.getEvalMethod(function); - this.methodHandle = Reflection.getMethodHandle(method); - this.inputSchema = TypeUtils.methodToInputSchema(method); - this.outputSchema = TypeUtils.methodToOutputSchema(method); - this.processInputs = TypeUtils.methodToProcessInputs(method); - } - - @Override - Iterator evalBatch(VectorSchemaRoot batch, BufferAllocator allocator) { - var row = new Object[batch.getSchema().getFields().size() + 1]; - row[0] = this.function; - var outputValues = new Object[batch.getRowCount()]; - for (int i = 0; i < batch.getRowCount(); i++) { - for (int j = 0; j < row.length - 1; j++) { - var val = batch.getVector(j).getObject(i); - row[j + 1] = this.processInputs[j].apply(val); - } - try { - outputValues[i] = this.methodHandle.invokeWithArguments(row); - } catch (Throwable e) { - throw new RuntimeException(e); - } - } - var outputVector = - TypeUtils.createVector( - this.outputSchema.getFields().get(0), allocator, outputValues); - var outputBatch = VectorSchemaRoot.of(outputVector); - return Collections.singleton(outputBatch).iterator(); - } -} diff --git a/java/udf/src/main/java/com/risingwave/functions/TableFunction.java b/java/udf/src/main/java/com/risingwave/functions/TableFunction.java deleted file mode 100644 index ec5b9d214553f..0000000000000 --- a/java/udf/src/main/java/com/risingwave/functions/TableFunction.java +++ /dev/null @@ -1,60 +0,0 @@ -// Copyright 2024 RisingWave Labs -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-// See the License for the specific language governing permissions and -// limitations under the License. - -package com.risingwave.functions; - -/** - * Base interface for a user-defined table function. A user-defined table function maps zero, one, - * or multiple scalar values to zero, one, or multiple rows (or structured types). If an output - * record consists of only one field, the structured record can be omitted, and a scalar value can - * be emitted that will be implicitly wrapped into a row by the runtime. - * - *
- * <p>The behavior of a {@link TableFunction} can be defined by implementing a custom evaluation
- * method. An evaluation method must be declared publicly, not static, and named <code>eval</code>.
- * The return type must be an <code>Iterator</code>. Multiple overloaded methods named <code>eval</code> are not
- * supported yet.
- *
- * <p>By default, input and output data types are automatically extracted using reflection.
- *
- * <p>The following examples show how to specify a table function:
- *
- * <pre>{@code
- * // a function that accepts an INT arguments and emits the range from 0 to the
- * // given number.
- * class Series implements TableFunction {
- *     public Iterator<Integer> eval(int n) {
- *         return java.util.stream.IntStream.range(0, n).iterator();
- *     }
- * }
- *
- * // a function that accepts an String arguments and emits the words of the
- * // given string.
- * class Split implements TableFunction {
- *     public static class Row {
- *         public String word;
- *         public int length;
- *     }
- *
- *     public Iterator<Row> eval(String str) {
- *         return Stream.of(str.split(" ")).map(s -> {
- *             Row row = new Row();
- *             row.word = s;
- *             row.length = s.length();
- *             return row;
- *         }).iterator();
- *     }
- * }
- * }</pre>
- */ -public interface TableFunction extends UserDefinedFunction {} diff --git a/java/udf/src/main/java/com/risingwave/functions/TableFunctionBatch.java b/java/udf/src/main/java/com/risingwave/functions/TableFunctionBatch.java deleted file mode 100644 index a0e0608e60210..0000000000000 --- a/java/udf/src/main/java/com/risingwave/functions/TableFunctionBatch.java +++ /dev/null @@ -1,87 +0,0 @@ -// Copyright 2024 RisingWave Labs -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package com.risingwave.functions; - -import java.lang.invoke.MethodHandle; -import java.util.ArrayList; -import java.util.Iterator; -import java.util.function.Function; -import org.apache.arrow.memory.BufferAllocator; -import org.apache.arrow.vector.VectorSchemaRoot; - -/** Batch-processing wrapper over a user-defined table function. */ -class TableFunctionBatch extends UserDefinedFunctionBatch { - TableFunction function; - MethodHandle methodHandle; - Function[] processInputs; - int chunkSize = 1024; - - TableFunctionBatch(TableFunction function) { - this.function = function; - var method = Reflection.getEvalMethod(function); - this.methodHandle = Reflection.getMethodHandle(method); - this.inputSchema = TypeUtils.methodToInputSchema(method); - this.outputSchema = TypeUtils.tableFunctionToOutputSchema(method); - this.processInputs = TypeUtils.methodToProcessInputs(method); - } - - @Override - Iterator evalBatch(VectorSchemaRoot batch, BufferAllocator allocator) { - var outputs = new ArrayList(); - var row = new Object[batch.getSchema().getFields().size() + 1]; - row[0] = this.function; - var indexes = new ArrayList(); - var values = new ArrayList(); - Runnable buildChunk = - () -> { - var fields = this.outputSchema.getFields(); - var indexVector = - TypeUtils.createVector(fields.get(0), allocator, indexes.toArray()); - var valueVector = - TypeUtils.createVector(fields.get(1), allocator, values.toArray()); - indexes.clear(); - values.clear(); - var outputBatch = VectorSchemaRoot.of(indexVector, valueVector); - outputs.add(outputBatch); - }; - for (int i = 0; i < batch.getRowCount(); i++) { - // prepare input row - for (int j = 0; j < row.length - 1; j++) { - var val = batch.getVector(j).getObject(i); - row[j + 1] = this.processInputs[j].apply(val); - } - // call function - Iterator iterator; - try { - iterator = (Iterator) this.methodHandle.invokeWithArguments(row); - } catch (Throwable e) { - throw new RuntimeException(e); - } - // push values - while (iterator.hasNext()) { - indexes.add(i); - values.add(iterator.next()); - // check if we need to flush - if (indexes.size() >= this.chunkSize) { - buildChunk.run(); - } - } - } - if (indexes.size() > 0) { - buildChunk.run(); - } - return outputs.iterator(); - } -} diff --git a/java/udf/src/main/java/com/risingwave/functions/TypeUtils.java b/java/udf/src/main/java/com/risingwave/functions/TypeUtils.java deleted file mode 100644 index 06c2f79858c40..0000000000000 --- a/java/udf/src/main/java/com/risingwave/functions/TypeUtils.java +++ /dev/null @@ -1,505 +0,0 @@ 
-// Copyright 2024 RisingWave Labs -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package com.risingwave.functions; - -import java.lang.invoke.MethodHandles; -import java.lang.reflect.Array; -import java.lang.reflect.InvocationTargetException; -import java.lang.reflect.Method; -import java.lang.reflect.ParameterizedType; -import java.math.BigDecimal; -import java.time.LocalDate; -import java.time.LocalDateTime; -import java.time.LocalTime; -import java.util.AbstractMap; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.Iterator; -import java.util.List; -import java.util.function.Function; -import java.util.stream.Collectors; -import org.apache.arrow.memory.BufferAllocator; -import org.apache.arrow.vector.*; -import org.apache.arrow.vector.complex.ListVector; -import org.apache.arrow.vector.complex.StructVector; -import org.apache.arrow.vector.types.*; -import org.apache.arrow.vector.types.pojo.*; - -class TypeUtils { - /** Convert a string to an Arrow type. */ - static Field stringToField(String typeStr, String name) { - typeStr = typeStr.toUpperCase(); - if (typeStr.equals("BOOLEAN") || typeStr.equals("BOOL")) { - return Field.nullable(name, new ArrowType.Bool()); - } else if (typeStr.equals("SMALLINT") || typeStr.equals("INT2")) { - return Field.nullable(name, new ArrowType.Int(16, true)); - } else if (typeStr.equals("INT") || typeStr.equals("INTEGER") || typeStr.equals("INT4")) { - return Field.nullable(name, new ArrowType.Int(32, true)); - } else if (typeStr.equals("BIGINT") || typeStr.equals("INT8")) { - return Field.nullable(name, new ArrowType.Int(64, true)); - } else if (typeStr.equals("FLOAT4") || typeStr.equals("REAL")) { - return Field.nullable(name, new ArrowType.FloatingPoint(FloatingPointPrecision.SINGLE)); - } else if (typeStr.equals("FLOAT8") || typeStr.equals("DOUBLE PRECISION")) { - return Field.nullable(name, new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE)); - } else if (typeStr.equals("DECIMAL") || typeStr.equals("NUMERIC")) { - return Field.nullable(name, new ArrowType.LargeBinary()); - } else if (typeStr.equals("DATE")) { - return Field.nullable(name, new ArrowType.Date(DateUnit.DAY)); - } else if (typeStr.equals("TIME") || typeStr.equals("TIME WITHOUT TIME ZONE")) { - return Field.nullable(name, new ArrowType.Time(TimeUnit.MICROSECOND, 64)); - } else if (typeStr.equals("TIMESTAMP") || typeStr.equals("TIMESTAMP WITHOUT TIME ZONE")) { - return Field.nullable(name, new ArrowType.Timestamp(TimeUnit.MICROSECOND, null)); - } else if (typeStr.startsWith("INTERVAL")) { - return Field.nullable(name, new ArrowType.Interval(IntervalUnit.MONTH_DAY_NANO)); - } else if (typeStr.equals("VARCHAR")) { - return Field.nullable(name, new ArrowType.Utf8()); - } else if (typeStr.equals("JSONB")) { - return Field.nullable(name, new ArrowType.LargeUtf8()); - } else if (typeStr.equals("BYTEA")) { - return Field.nullable(name, new ArrowType.Binary()); - } else if (typeStr.endsWith("[]")) { - Field innerField = stringToField(typeStr.substring(0, 
typeStr.length() - 2), ""); - return new Field( - name, FieldType.nullable(new ArrowType.List()), Arrays.asList(innerField)); - } else if (typeStr.startsWith("STRUCT")) { - // extract "STRUCT" - var typeList = typeStr.substring(7, typeStr.length() - 1); - var fields = - Arrays.stream(typeList.split(",")) - .map(s -> stringToField(s.trim(), "")) - .collect(Collectors.toList()); - return new Field(name, FieldType.nullable(new ArrowType.Struct()), fields); - } else { - throw new IllegalArgumentException("Unsupported type: " + typeStr); - } - } - - /** - * Convert a Java class to an Arrow type. - * - * @param param The Java class. - * @param hint An optional DataTypeHint annotation. - * @param name The name of the field. - * @return The Arrow type. - */ - static Field classToField(Class param, DataTypeHint hint, String name) { - if (hint != null) { - return stringToField(hint.value(), name); - } else if (param == Boolean.class || param == boolean.class) { - return Field.nullable(name, new ArrowType.Bool()); - } else if (param == Short.class || param == short.class) { - return Field.nullable(name, new ArrowType.Int(16, true)); - } else if (param == Integer.class || param == int.class) { - return Field.nullable(name, new ArrowType.Int(32, true)); - } else if (param == Long.class || param == long.class) { - return Field.nullable(name, new ArrowType.Int(64, true)); - } else if (param == Float.class || param == float.class) { - return Field.nullable(name, new ArrowType.FloatingPoint(FloatingPointPrecision.SINGLE)); - } else if (param == Double.class || param == double.class) { - return Field.nullable(name, new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE)); - } else if (param == BigDecimal.class) { - return Field.nullable(name, new ArrowType.LargeBinary()); - } else if (param == LocalDate.class) { - return Field.nullable(name, new ArrowType.Date(DateUnit.DAY)); - } else if (param == LocalTime.class) { - return Field.nullable(name, new ArrowType.Time(TimeUnit.MICROSECOND, 64)); - } else if (param == LocalDateTime.class) { - return Field.nullable(name, new ArrowType.Timestamp(TimeUnit.MICROSECOND, null)); - } else if (param == PeriodDuration.class) { - return Field.nullable(name, new ArrowType.Interval(IntervalUnit.MONTH_DAY_NANO)); - } else if (param == String.class) { - return Field.nullable(name, new ArrowType.Utf8()); - } else if (param == byte[].class) { - return Field.nullable(name, new ArrowType.Binary()); - } else if (param.isArray()) { - var innerField = classToField(param.getComponentType(), null, ""); - return new Field( - name, FieldType.nullable(new ArrowType.List()), Arrays.asList(innerField)); - } else { - // struct type - var fields = new ArrayList(); - for (var field : param.getDeclaredFields()) { - var subhint = field.getAnnotation(DataTypeHint.class); - fields.add(classToField(field.getType(), subhint, field.getName())); - } - return new Field(name, FieldType.nullable(new ArrowType.Struct()), fields); - // TODO: more types - // throw new IllegalArgumentException("Unsupported type: " + param); - } - } - - /** Get the input schema from a Java method. */ - static Schema methodToInputSchema(Method method) { - var fields = new ArrayList(); - for (var param : method.getParameters()) { - var hint = param.getAnnotation(DataTypeHint.class); - fields.add(classToField(param.getType(), hint, param.getName())); - } - return new Schema(fields); - } - - /** Get the output schema of a scalar function from a Java method. 
*/
-    static Schema methodToOutputSchema(Method method) {
-        var type = method.getReturnType();
-        var hint = method.getAnnotation(DataTypeHint.class);
-        return new Schema(Arrays.asList(classToField(type, hint, "")));
-    }
-
-    /** Get the output schema of a table function from a Java method. */
-    static Schema tableFunctionToOutputSchema(Method method) {
-        var hint = method.getAnnotation(DataTypeHint.class);
-        var type = method.getReturnType();
-        if (!Iterator.class.isAssignableFrom(type)) {
-            throw new IllegalArgumentException("Table function must return Iterator");
-        }
-        var typeArguments =
-                ((ParameterizedType) method.getGenericReturnType()).getActualTypeArguments();
-        type = (Class<?>) typeArguments[0];
-        var rowIndex = Field.nullable("row_index", new ArrowType.Int(32, true));
-        return new Schema(Arrays.asList(rowIndex, classToField(type, hint, "")));
-    }
-
-    /** Return functions to process input values from a Java method. */
-    static Function<Object, Object>[] methodToProcessInputs(Method method) {
-        var schema = methodToInputSchema(method);
-        var params = method.getParameters();
-        @SuppressWarnings("unchecked")
-        Function<Object, Object>[] funcs = new Function[schema.getFields().size()];
-        for (int i = 0; i < schema.getFields().size(); i++) {
-            funcs[i] = processFunc(schema.getFields().get(i), params[i].getType());
-        }
-        return funcs;
-    }
-
-    /** Create an Arrow vector from an array of values. */
-    static FieldVector createVector(Field field, BufferAllocator allocator, Object[] values) {
-        var vector = field.createVector(allocator);
-        fillVector(vector, values);
-        return vector;
-    }
-
-    /** Fill an Arrow vector with an array of values. */
-    static void fillVector(FieldVector fieldVector, Object[] values) {
-        if (fieldVector instanceof BitVector) {
-            var vector = (BitVector) fieldVector;
-            vector.allocateNew(values.length);
-            for (int i = 0; i < values.length; i++) {
-                if (values[i] != null) {
-                    vector.set(i, (boolean) values[i] ? 
1 : 0); - } - } - } else if (fieldVector instanceof SmallIntVector) { - var vector = (SmallIntVector) fieldVector; - vector.allocateNew(values.length); - for (int i = 0; i < values.length; i++) { - if (values[i] != null) { - vector.set(i, (short) values[i]); - } - } - } else if (fieldVector instanceof IntVector) { - var vector = (IntVector) fieldVector; - vector.allocateNew(values.length); - for (int i = 0; i < values.length; i++) { - if (values[i] != null) { - vector.set(i, (int) values[i]); - } - } - } else if (fieldVector instanceof BigIntVector) { - var vector = (BigIntVector) fieldVector; - vector.allocateNew(values.length); - for (int i = 0; i < values.length; i++) { - if (values[i] != null) { - vector.set(i, (long) values[i]); - } - } - } else if (fieldVector instanceof Float4Vector) { - var vector = (Float4Vector) fieldVector; - vector.allocateNew(values.length); - for (int i = 0; i < values.length; i++) { - if (values[i] != null) { - vector.set(i, (float) values[i]); - } - } - } else if (fieldVector instanceof Float8Vector) { - var vector = (Float8Vector) fieldVector; - vector.allocateNew(values.length); - for (int i = 0; i < values.length; i++) { - if (values[i] != null) { - vector.set(i, (double) values[i]); - } - } - } else if (fieldVector instanceof LargeVarBinaryVector) { - var vector = (LargeVarBinaryVector) fieldVector; - vector.allocateNew(values.length); - for (int i = 0; i < values.length; i++) { - if (values[i] != null) { - // use `toPlainString` to avoid scientific notation - vector.set(i, ((BigDecimal) values[i]).toPlainString().getBytes()); - } - } - } else if (fieldVector instanceof DateDayVector) { - var vector = (DateDayVector) fieldVector; - vector.allocateNew(values.length); - for (int i = 0; i < values.length; i++) { - if (values[i] != null) { - vector.set(i, (int) ((LocalDate) values[i]).toEpochDay()); - } - } - } else if (fieldVector instanceof TimeMicroVector) { - var vector = (TimeMicroVector) fieldVector; - vector.allocateNew(values.length); - for (int i = 0; i < values.length; i++) { - if (values[i] != null) { - vector.set(i, ((LocalTime) values[i]).toNanoOfDay() / 1000); - } - } - } else if (fieldVector instanceof TimeStampMicroVector) { - var vector = (TimeStampMicroVector) fieldVector; - vector.allocateNew(values.length); - for (int i = 0; i < values.length; i++) { - if (values[i] != null) { - vector.set(i, timestampToMicros((LocalDateTime) values[i])); - } - } - } else if (fieldVector instanceof IntervalMonthDayNanoVector) { - var vector = (IntervalMonthDayNanoVector) fieldVector; - vector.allocateNew(values.length); - for (int i = 0; i < values.length; i++) { - if (values[i] != null) { - var pd = (PeriodDuration) values[i]; - var months = (int) pd.getPeriod().toTotalMonths(); - var days = pd.getPeriod().getDays(); - var nanos = pd.getDuration().toNanos(); - vector.set(i, months, days, nanos); - } - } - } else if (fieldVector instanceof VarCharVector) { - var vector = (VarCharVector) fieldVector; - int totalBytes = 0; - for (int i = 0; i < values.length; i++) { - if (values[i] != null) { - totalBytes += ((String) values[i]).length(); - } - } - vector.allocateNew(totalBytes, values.length); - for (int i = 0; i < values.length; i++) { - if (values[i] != null) { - vector.set(i, ((String) values[i]).getBytes()); - } - } - } else if (fieldVector instanceof LargeVarCharVector) { - var vector = (LargeVarCharVector) fieldVector; - int totalBytes = 0; - for (int i = 0; i < values.length; i++) { - if (values[i] != null) { - totalBytes += ((String) 
values[i]).length();
-                }
-            }
-            vector.allocateNew(totalBytes, values.length);
-            for (int i = 0; i < values.length; i++) {
-                if (values[i] != null) {
-                    vector.set(i, ((String) values[i]).getBytes());
-                }
-            }
-        } else if (fieldVector instanceof VarBinaryVector) {
-            var vector = (VarBinaryVector) fieldVector;
-            int totalBytes = 0;
-            for (int i = 0; i < values.length; i++) {
-                if (values[i] != null) {
-                    totalBytes += ((byte[]) values[i]).length;
-                }
-            }
-            vector.allocateNew(totalBytes, values.length);
-            for (int i = 0; i < values.length; i++) {
-                if (values[i] != null) {
-                    vector.set(i, (byte[]) values[i]);
-                }
-            }
-        } else if (fieldVector instanceof ListVector) {
-            var vector = (ListVector) fieldVector;
-            vector.allocateNew();
-            // flatten the `values`
-            var flattenLength = 0;
-            for (int i = 0; i < values.length; i++) {
-                if (values[i] == null) {
-                    continue;
-                }
-                var len = Array.getLength(values[i]);
-                vector.startNewValue(i);
-                vector.endValue(i, len);
-                flattenLength += len;
-            }
-            var flattenValues = new Object[flattenLength];
-            var ii = 0;
-            for (var list : values) {
-                if (list == null) {
-                    continue;
-                }
-                var length = Array.getLength(list);
-                for (int i = 0; i < length; i++) {
-                    flattenValues[ii++] = Array.get(list, i);
-                }
-            }
-            // fill the inner vector
-            fillVector(vector.getDataVector(), flattenValues);
-        } else if (fieldVector instanceof StructVector) {
-            var vector = (StructVector) fieldVector;
-            vector.allocateNew();
-            var lookup = MethodHandles.lookup();
-            // get class of the first non-null value
-            Class<?> valueClass = null;
-            for (int i = 0; i < values.length; i++) {
-                if (values[i] != null) {
-                    valueClass = values[i].getClass();
-                    break;
-                }
-            }
-            for (var field : vector.getField().getChildren()) {
-                // extract field from values
-                var subvalues = new Object[values.length];
-                if (valueClass != null) {
-                    try {
-                        var javaField = valueClass.getDeclaredField(field.getName());
-                        var varHandle = lookup.unreflectVarHandle(javaField);
-                        for (int i = 0; i < values.length; i++) {
-                            if (values[i] != null) {
-                                subvalues[i] = varHandle.get(values[i]);
-                            }
-                        }
-                    } catch (NoSuchFieldException | IllegalAccessException e) {
-                        throw new RuntimeException(e);
-                    }
-                }
-                var subvector = vector.getChild(field.getName());
-                fillVector(subvector, subvalues);
-            }
-            for (int i = 0; i < values.length; i++) {
-                if (values[i] != null) {
-                    vector.setIndexDefined(i);
-                }
-            }
-        } else {
-            throw new IllegalArgumentException("Unsupported type: " + fieldVector.getClass());
-        }
-        fieldVector.setValueCount(values.length);
-    }
-
-    static long timestampToMicros(LocalDateTime timestamp) {
-        var date = timestamp.toLocalDate().toEpochDay();
-        var time = timestamp.toLocalTime().toNanoOfDay();
-        return date * 24 * 3600 * 1000 * 1000 + time / 1000;
-    }
-
-    /** Return a function that converts an object read from the input array to the correct type. */
-    static Function<Object, Object> processFunc(Field field, Class<?> targetClass) {
-        var inner = processFunc0(field, targetClass);
-        return obj -> obj == null ? 
null : inner.apply(obj);
-    }
-
-    static Function<Object, Object> processFunc0(Field field, Class<?> targetClass) {
-        if (field.getType() instanceof ArrowType.Utf8 && targetClass == String.class) {
-            // object is org.apache.arrow.vector.util.Text
-            return obj -> obj.toString();
-        } else if (field.getType() instanceof ArrowType.LargeUtf8 && targetClass == String.class) {
-            // object is org.apache.arrow.vector.util.Text
-            return obj -> obj.toString();
-        } else if (field.getType() instanceof ArrowType.LargeBinary
-                && targetClass == BigDecimal.class) {
-            // object is byte[]
-            return obj -> new BigDecimal(new String((byte[]) obj));
-        } else if (field.getType() instanceof ArrowType.Date && targetClass == LocalDate.class) {
-            // object is Integer
-            return obj -> LocalDate.ofEpochDay((int) obj);
-        } else if (field.getType() instanceof ArrowType.Time && targetClass == LocalTime.class) {
-            // object is Long
-            return obj -> LocalTime.ofNanoOfDay((long) obj * 1000);
-        } else if (field.getType() instanceof ArrowType.Interval
-                && targetClass == PeriodDuration.class) {
-            // object is arrow PeriodDuration
-            return obj -> new PeriodDuration((org.apache.arrow.vector.PeriodDuration) obj);
-        } else if (field.getType() instanceof ArrowType.List) {
-            // object is List
-            var subfield = field.getChildren().get(0);
-            var subfunc = processFunc(subfield, targetClass.getComponentType());
-            if (subfield.getType() instanceof ArrowType.Bool) {
-                return obj -> ((List<?>) obj).stream().map(subfunc).toArray(Boolean[]::new);
-            } else if (subfield.getType().equals(new ArrowType.Int(16, true))) {
-                return obj -> ((List<?>) obj).stream().map(subfunc).toArray(Short[]::new);
-            } else if (subfield.getType().equals(new ArrowType.Int(32, true))) {
-                return obj -> ((List<?>) obj).stream().map(subfunc).toArray(Integer[]::new);
-            } else if (subfield.getType().equals(new ArrowType.Int(64, true))) {
-                return obj -> ((List<?>) obj).stream().map(subfunc).toArray(Long[]::new);
-            } else if (subfield.getType()
-                    .equals(new ArrowType.FloatingPoint(FloatingPointPrecision.SINGLE))) {
-                return obj -> ((List<?>) obj).stream().map(subfunc).toArray(Float[]::new);
-            } else if (subfield.getType()
-                    .equals(new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE))) {
-                return obj -> ((List<?>) obj).stream().map(subfunc).toArray(Double[]::new);
-            } else if (subfield.getType() instanceof ArrowType.LargeBinary) {
-                return obj -> ((List<?>) obj).stream().map(subfunc).toArray(BigDecimal[]::new);
-            } else if (subfield.getType() instanceof ArrowType.Date) {
-                return obj -> ((List<?>) obj).stream().map(subfunc).toArray(LocalDate[]::new);
-            } else if (subfield.getType() instanceof ArrowType.Time) {
-                return obj -> ((List<?>) obj).stream().map(subfunc).toArray(LocalTime[]::new);
-            } else if (subfield.getType() instanceof ArrowType.Timestamp) {
-                return obj -> ((List<?>) obj).stream().map(subfunc).toArray(LocalDateTime[]::new);
-            } else if (subfield.getType() instanceof ArrowType.Interval) {
-                return obj -> ((List<?>) obj).stream().map(subfunc).toArray(PeriodDuration[]::new);
-            } else if (subfield.getType() instanceof ArrowType.Utf8) {
-                return obj -> ((List<?>) obj).stream().map(subfunc).toArray(String[]::new);
-            } else if (subfield.getType() instanceof ArrowType.LargeUtf8) {
-                return obj -> ((List<?>) obj).stream().map(subfunc).toArray(String[]::new);
-            } else if (subfield.getType() instanceof ArrowType.Binary) {
-                return obj -> ((List<?>) obj).stream().map(subfunc).toArray(byte[][]::new);
-            } else if (subfield.getType() instanceof ArrowType.Struct) {
-                return obj -> {
-                    var list = (List<?>) obj;
-                    Object array = 
Array.newInstance(targetClass.getComponentType(), list.size());
-                    for (int i = 0; i < list.size(); i++) {
-                        Array.set(array, i, subfunc.apply(list.get(i)));
-                    }
-                    return array;
-                };
-            }
-            throw new IllegalArgumentException("Unsupported type: " + subfield.getType());
-        } else if (field.getType() instanceof ArrowType.Struct) {
-            // object is org.apache.arrow.vector.util.JsonStringHashMap
-            var subfields = field.getChildren();
-            @SuppressWarnings("unchecked")
-            Function<Object, Object>[] subfunc = new Function[subfields.size()];
-            for (int i = 0; i < subfields.size(); i++) {
-                subfunc[i] = processFunc(subfields.get(i), targetClass.getFields()[i].getType());
-            }
-            return obj -> {
-                var map = (AbstractMap<?, ?>) obj;
-                try {
-                    var row = targetClass.getDeclaredConstructor().newInstance();
-                    for (int i = 0; i < subfields.size(); i++) {
-                        var field0 = targetClass.getFields()[i];
-                        var val = subfunc[i].apply(map.get(field0.getName()));
-                        field0.set(row, val);
-                    }
-                    return row;
-                } catch (InstantiationException
-                        | IllegalAccessException
-                        | InvocationTargetException
-                        | NoSuchMethodException e) {
-                    throw new RuntimeException(e);
-                }
-            };
-        }
-        return Function.identity();
-    }
-}
diff --git a/java/udf/src/main/java/com/risingwave/functions/UdfProducer.java b/java/udf/src/main/java/com/risingwave/functions/UdfProducer.java
deleted file mode 100644
index 692d898acaf8a..0000000000000
--- a/java/udf/src/main/java/com/risingwave/functions/UdfProducer.java
+++ /dev/null
@@ -1,108 +0,0 @@
-// Copyright 2024 RisingWave Labs
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License. 
-
-package com.risingwave.functions;
-
-import java.util.ArrayList;
-import java.util.Collections;
-import java.util.HashMap;
-import org.apache.arrow.flight.*;
-import org.apache.arrow.memory.BufferAllocator;
-import org.apache.arrow.vector.VectorLoader;
-import org.apache.arrow.vector.VectorSchemaRoot;
-import org.apache.arrow.vector.VectorUnloader;
-import org.apache.arrow.vector.types.pojo.Field;
-import org.apache.arrow.vector.types.pojo.Schema;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
-
-class UdfProducer extends NoOpFlightProducer {
-
-    private BufferAllocator allocator;
-    private HashMap<String, UserDefinedFunctionBatch> functions = new HashMap<>();
-    private static final Logger logger = LoggerFactory.getLogger(UdfServer.class);
-
-    UdfProducer(BufferAllocator allocator) {
-        this.allocator = allocator;
-    }
-
-    void addFunction(String name, UserDefinedFunction function) throws IllegalArgumentException {
-        UserDefinedFunctionBatch udf;
-        if (function instanceof ScalarFunction) {
-            udf = new ScalarFunctionBatch((ScalarFunction) function);
-        } else if (function instanceof TableFunction) {
-            udf = new TableFunctionBatch((TableFunction) function);
-        } else {
-            throw new IllegalArgumentException(
-                    "Unknown function type: " + function.getClass().getName());
-        }
-        if (functions.containsKey(name)) {
-            throw new IllegalArgumentException("Function already exists: " + name);
-        }
-        functions.put(name, udf);
-    }
-
-    @Override
-    public FlightInfo getFlightInfo(CallContext context, FlightDescriptor descriptor) {
-        try {
-            var functionName = descriptor.getPath().get(0);
-            var udf = functions.get(functionName);
-            if (udf == null) {
-                throw new IllegalArgumentException("Unknown function: " + functionName);
-            }
-            var fields = new ArrayList<Field>();
-            fields.addAll(udf.getInputSchema().getFields());
-            fields.addAll(udf.getOutputSchema().getFields());
-            var fullSchema = new Schema(fields);
-            var inputLen = udf.getInputSchema().getFields().size();
-
-            return new FlightInfo(fullSchema, descriptor, Collections.emptyList(), 0, inputLen);
-        } catch (Exception e) {
-            logger.error("Error occurred during getFlightInfo", e);
-            throw e;
-        }
-    }
-
-    @Override
-    public void doExchange(CallContext context, FlightStream reader, ServerStreamListener writer) {
-        try (var allocator = this.allocator.newChildAllocator("exchange", 0, Long.MAX_VALUE)) {
-            var functionName = reader.getDescriptor().getPath().get(0);
-            logger.debug("call function: " + functionName);
-
-            var udf = this.functions.get(functionName);
-            try (var root = VectorSchemaRoot.create(udf.getOutputSchema(), allocator)) {
-                var loader = new VectorLoader(root);
-                writer.start(root);
-                while (reader.next()) {
-                    try (var input = reader.getRoot()) {
-                        var outputBatches = udf.evalBatch(input, allocator);
-                        while (outputBatches.hasNext()) {
-                            try (var outputRoot = outputBatches.next()) {
-                                var unloader = new VectorUnloader(outputRoot);
-                                try (var outputBatch = unloader.getRecordBatch()) {
-                                    loader.load(outputBatch);
-                                }
-                            }
-                            writer.putNext();
-                        }
-                    }
-                }
-                writer.completed();
-            }
-        } catch (Exception e) {
-            logger.error("Error occurred during UDF execution", e);
-            writer.error(e);
-        }
-    }
-}
diff --git a/java/udf/src/main/java/com/risingwave/functions/UdfServer.java b/java/udf/src/main/java/com/risingwave/functions/UdfServer.java
deleted file mode 100644
index 66f2a8d3bb0dd..0000000000000
--- a/java/udf/src/main/java/com/risingwave/functions/UdfServer.java
+++ /dev/null
@@ -1,81 +0,0 @@
-// Copyright 2024 RisingWave Labs
-//
-// Licensed under the Apache License, Version 2.0 (the 
"License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package com.risingwave.functions; - -import java.io.IOException; -import org.apache.arrow.flight.*; -import org.apache.arrow.memory.RootAllocator; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -/** A server that exposes user-defined functions over Apache Arrow Flight. */ -public class UdfServer implements AutoCloseable { - - private FlightServer server; - private UdfProducer producer; - private static final Logger logger = LoggerFactory.getLogger(UdfServer.class); - - public UdfServer(String host, int port) { - var location = Location.forGrpcInsecure(host, port); - var allocator = new RootAllocator(); - this.producer = new UdfProducer(allocator); - this.server = FlightServer.builder(allocator, location, this.producer).build(); - } - - /** - * Add a user-defined function to the server. - * - * @param name the name of the function - * @param udf the function to add - * @throws IllegalArgumentException if a function with the same name already exists - */ - public void addFunction(String name, UserDefinedFunction udf) throws IllegalArgumentException { - logger.info("added function: " + name); - this.producer.addFunction(name, udf); - } - - /** - * Start the server. - * - * @throws IOException if the server fails to start - */ - public void start() throws IOException { - this.server.start(); - logger.info("listening on " + this.server.getLocation().toSocketAddress()); - } - - /** - * Get the port the server is listening on. - * - * @return the port number - */ - public int getPort() { - return this.server.getPort(); - } - - /** - * Wait for the server to terminate. - * - * @throws InterruptedException if the thread is interrupted while waiting - */ - public void awaitTermination() throws InterruptedException { - this.server.awaitTermination(); - } - - /** Close the server. */ - public void close() throws InterruptedException { - this.server.close(); - } -} diff --git a/java/udf/src/main/java/com/risingwave/functions/UserDefinedFunction.java b/java/udf/src/main/java/com/risingwave/functions/UserDefinedFunction.java deleted file mode 100644 index 3db6f1714cd83..0000000000000 --- a/java/udf/src/main/java/com/risingwave/functions/UserDefinedFunction.java +++ /dev/null @@ -1,23 +0,0 @@ -// Copyright 2024 RisingWave Labs -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package com.risingwave.functions; - -/** - * Base interface for all user-defined functions. 
- *
- * @see ScalarFunction
- * @see TableFunction
- */
-public interface UserDefinedFunction {}
diff --git a/java/udf/src/main/java/com/risingwave/functions/UserDefinedFunctionBatch.java b/java/udf/src/main/java/com/risingwave/functions/UserDefinedFunctionBatch.java
deleted file mode 100644
index e2c513a7954ad..0000000000000
--- a/java/udf/src/main/java/com/risingwave/functions/UserDefinedFunctionBatch.java
+++ /dev/null
@@ -1,87 +0,0 @@
-// Copyright 2024 RisingWave Labs
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package com.risingwave.functions;
-
-import java.lang.invoke.MethodHandle;
-import java.lang.invoke.MethodHandles;
-import java.lang.reflect.Method;
-import java.lang.reflect.Modifier;
-import java.util.ArrayList;
-import java.util.Iterator;
-import org.apache.arrow.memory.BufferAllocator;
-import org.apache.arrow.vector.VectorSchemaRoot;
-import org.apache.arrow.vector.types.pojo.Schema;
-
-/** Base class for a batch-processing user-defined function. */
-abstract class UserDefinedFunctionBatch {
-    protected Schema inputSchema;
-    protected Schema outputSchema;
-
-    /** Get the input schema of the function. */
-    Schema getInputSchema() {
-        return inputSchema;
-    }
-
-    /** Get the output schema of the function. */
-    Schema getOutputSchema() {
-        return outputSchema;
-    }
-
-    /**
-     * Evaluate the function by processing a batch of input data.
-     *
-     * @param batch the input data batch to process
-     * @param allocator the allocator to use for allocating output data
-     * @return an iterator over the output data batches
-     */
-    abstract Iterator<VectorSchemaRoot> evalBatch(
-            VectorSchemaRoot batch, BufferAllocator allocator);
-}
-
-/** Utility class for reflection. */
-class Reflection {
-    /** Get the method named eval. */
-    static Method getEvalMethod(UserDefinedFunction obj) {
-        var methods = new ArrayList<Method>();
-        for (Method method : obj.getClass().getDeclaredMethods()) {
-            if (method.getName().equals("eval")) {
-                methods.add(method);
-            }
-        }
-        if (methods.size() != 1) {
-            throw new IllegalArgumentException(
-                    "Exactly one eval method must be defined for class "
-                            + obj.getClass().getName());
-        }
-        var method = methods.get(0);
-        if (Modifier.isStatic(method.getModifiers())) {
-            throw new IllegalArgumentException(
-                    "The eval method should not be static for class " + obj.getClass().getName());
-        }
-        return method;
-    }
-
-    /** Get the method handle of the given method. 
*/ - static MethodHandle getMethodHandle(Method method) { - var lookup = MethodHandles.lookup(); - try { - return lookup.unreflect(method); - } catch (IllegalAccessException e) { - throw new IllegalArgumentException( - "The eval method must be public for class " - + method.getDeclaringClass().getName()); - } - } -} diff --git a/java/udf/src/test/java/com/risingwave/functions/TestUdfServer.java b/java/udf/src/test/java/com/risingwave/functions/TestUdfServer.java deleted file mode 100644 index 5722efa1dd702..0000000000000 --- a/java/udf/src/test/java/com/risingwave/functions/TestUdfServer.java +++ /dev/null @@ -1,286 +0,0 @@ -// Copyright 2024 RisingWave Labs -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package com.risingwave.functions; - -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertTrue; - -import java.io.IOException; -import java.math.BigDecimal; -import java.time.LocalDate; -import java.time.LocalDateTime; -import java.time.LocalTime; -import java.util.Iterator; -import java.util.stream.IntStream; -import org.apache.arrow.memory.BufferAllocator; -import org.apache.arrow.memory.RootAllocator; -import org.apache.arrow.vector.*; -import org.apache.arrow.vector.complex.StructVector; -import org.apache.arrow.vector.types.Types.MinorType; -import org.apache.arrow.vector.types.pojo.ArrowType; -import org.apache.arrow.vector.types.pojo.FieldType; -import org.junit.jupiter.api.AfterAll; -import org.junit.jupiter.api.BeforeAll; -import org.junit.jupiter.api.Test; - -/** Unit test for UDF server. 
*/ -public class TestUdfServer { - private static UdfClient client; - private static UdfServer server; - private static BufferAllocator allocator = new RootAllocator(); - - @BeforeAll - public static void setup() throws IOException { - server = new UdfServer("localhost", 0); - server.addFunction("gcd", new Gcd()); - server.addFunction("return_all", new ReturnAll()); - server.addFunction("series", new Series()); - server.start(); - - client = new UdfClient("localhost", server.getPort()); - } - - @AfterAll - public static void teardown() throws InterruptedException { - client.close(); - server.close(); - } - - public static class Gcd implements ScalarFunction { - public int eval(int a, int b) { - while (b != 0) { - int temp = b; - b = a % b; - a = temp; - } - return a; - } - } - - @Test - public void gcd() throws Exception { - var c0 = new IntVector("", allocator); - c0.allocateNew(1); - c0.set(0, 15); - c0.setValueCount(1); - - var c1 = new IntVector("", allocator); - c1.allocateNew(1); - c1.set(0, 12); - c1.setValueCount(1); - - var input = VectorSchemaRoot.of(c0, c1); - - try (var stream = client.call("gcd", input)) { - var output = stream.getRoot(); - assertTrue(stream.next()); - assertEquals("3", output.contentToTSVString().trim()); - } - } - - public static class ReturnAll implements ScalarFunction { - public static class Row { - public Boolean bool; - public Short i16; - public Integer i32; - public Long i64; - public Float f32; - public Double f64; - public BigDecimal decimal; - public LocalDate date; - public LocalTime time; - public LocalDateTime timestamp; - public PeriodDuration interval; - public String str; - public byte[] bytes; - public @DataTypeHint("JSONB") String jsonb; - public Struct struct; - } - - public static class Struct { - public Integer f1; - public Integer f2; - - public String toString() { - return String.format("(%d, %d)", f1, f2); - } - } - - public Row eval( - Boolean bool, - Short i16, - Integer i32, - Long i64, - Float f32, - Double f64, - BigDecimal decimal, - LocalDate date, - LocalTime time, - LocalDateTime timestamp, - PeriodDuration interval, - String str, - byte[] bytes, - @DataTypeHint("JSONB") String jsonb, - Struct struct) { - var row = new Row(); - row.bool = bool; - row.i16 = i16; - row.i32 = i32; - row.i64 = i64; - row.f32 = f32; - row.f64 = f64; - row.decimal = decimal; - row.date = date; - row.time = time; - row.timestamp = timestamp; - row.interval = interval; - row.str = str; - row.bytes = bytes; - row.jsonb = jsonb; - row.struct = struct; - return row; - } - } - - @Test - public void all_types() throws Exception { - var c0 = new BitVector("", allocator); - c0.allocateNew(2); - c0.set(0, 1); - c0.setValueCount(2); - - var c1 = new SmallIntVector("", allocator); - c1.allocateNew(2); - c1.set(0, 1); - c1.setValueCount(2); - - var c2 = new IntVector("", allocator); - c2.allocateNew(2); - c2.set(0, 1); - c2.setValueCount(2); - - var c3 = new BigIntVector("", allocator); - c3.allocateNew(2); - c3.set(0, 1); - c3.setValueCount(2); - - var c4 = new Float4Vector("", allocator); - c4.allocateNew(2); - c4.set(0, 1); - c4.setValueCount(2); - - var c5 = new Float8Vector("", allocator); - c5.allocateNew(2); - c5.set(0, 1); - c5.setValueCount(2); - - var c6 = new LargeVarBinaryVector("", allocator); - c6.allocateNew(2); - c6.set(0, "1.234".getBytes()); - c6.setValueCount(2); - - var c7 = new DateDayVector("", allocator); - c7.allocateNew(2); - c7.set(0, (int) LocalDate.of(2023, 1, 1).toEpochDay()); - c7.setValueCount(2); - - var c8 = new 
TimeMicroVector("", allocator);
-        c8.allocateNew(2);
-        c8.set(0, LocalTime.of(1, 2, 3).toNanoOfDay() / 1000);
-        c8.setValueCount(2);
-
-        var c9 = new TimeStampMicroVector("", allocator);
-        c9.allocateNew(2);
-        var ts = LocalDateTime.of(2023, 1, 1, 1, 2, 3);
-        c9.set(
-                0,
-                ts.toLocalDate().toEpochDay() * 24 * 3600 * 1000000
-                        + ts.toLocalTime().toNanoOfDay() / 1000);
-        c9.setValueCount(2);
-
-        var c10 =
-                new IntervalMonthDayNanoVector(
-                        "",
-                        FieldType.nullable(MinorType.INTERVALMONTHDAYNANO.getType()),
-                        allocator);
-        c10.allocateNew(2);
-        c10.set(0, 1000, 2000, 3000);
-        c10.setValueCount(2);
-
-        var c11 = new VarCharVector("", allocator);
-        c11.allocateNew(2);
-        c11.set(0, "string".getBytes());
-        c11.setValueCount(2);
-
-        var c12 = new VarBinaryVector("", allocator);
-        c12.allocateNew(2);
-        c12.set(0, "bytes".getBytes());
-        c12.setValueCount(2);
-
-        var c13 = new LargeVarCharVector("", allocator);
-        c13.allocateNew(2);
-        c13.set(0, "{ key: 1 }".getBytes());
-        c13.setValueCount(2);
-
-        var c14 =
-                new StructVector(
-                        "", allocator, FieldType.nullable(ArrowType.Struct.INSTANCE), null);
-        c14.allocateNew();
-        var f1 = c14.addOrGet("f1", FieldType.nullable(MinorType.INT.getType()), IntVector.class);
-        var f2 = c14.addOrGet("f2", FieldType.nullable(MinorType.INT.getType()), IntVector.class);
-        f1.allocateNew(2);
-        f2.allocateNew(2);
-        f1.set(0, 1);
-        f2.set(0, 2);
-        c14.setIndexDefined(0);
-        c14.setValueCount(2);
-
-        var input =
-                VectorSchemaRoot.of(
-                        c0, c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12, c13, c14);
-
-        try (var stream = client.call("return_all", input)) {
-            var output = stream.getRoot();
-            assertTrue(stream.next());
-            assertEquals(
-                    "{\"bool\":true,\"i16\":1,\"i32\":1,\"i64\":1,\"f32\":1.0,\"f64\":1.0,\"decimal\":\"MS4yMzQ=\",\"date\":19358,\"time\":3723000000,\"timestamp\":[2023,1,1,1,2,3],\"interval\":{\"period\":\"P1000M2000D\",\"duration\":0.000003000},\"str\":\"string\",\"bytes\":\"Ynl0ZXM=\",\"jsonb\":\"{ key: 1 }\",\"struct\":{\"f1\":1,\"f2\":2}}\n{}",
                    output.contentToTSVString().trim());
-        }
-    }
-
-    public static class Series implements TableFunction {
-        public Iterator<Integer> eval(int n) {
-            return IntStream.range(0, n).iterator();
-        }
-    }
-
-    @Test
-    public void series() throws Exception {
-        var c0 = new IntVector("", allocator);
-        c0.allocateNew(3);
-        c0.set(0, 0);
-        c0.set(1, 1);
-        c0.set(2, 2);
-        c0.setValueCount(3);
-
-        var input = VectorSchemaRoot.of(c0);
-
-        try (var stream = client.call("series", input)) {
-            var output = stream.getRoot();
-            assertTrue(stream.next());
-            assertEquals("row_index\t\n1\t0\n2\t0\n2\t1\n", output.contentToTSVString());
-        }
-    }
-}
diff --git a/java/udf/src/test/java/com/risingwave/functions/UdfClient.java b/java/udf/src/test/java/com/risingwave/functions/UdfClient.java
deleted file mode 100644
index 12728bf64fbec..0000000000000
--- a/java/udf/src/test/java/com/risingwave/functions/UdfClient.java
+++ /dev/null
@@ -1,51 +0,0 @@
-// Copyright 2024 RisingWave Labs
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License. 
- -package com.risingwave.functions; - -import org.apache.arrow.flight.*; -import org.apache.arrow.memory.RootAllocator; -import org.apache.arrow.vector.VectorSchemaRoot; - -public class UdfClient implements AutoCloseable { - - private FlightClient client; - - public UdfClient(String host, int port) { - var allocator = new RootAllocator(); - var location = Location.forGrpcInsecure(host, port); - this.client = FlightClient.builder(allocator, location).build(); - } - - public void close() throws InterruptedException { - this.client.close(); - } - - public FlightInfo getFlightInfo(String functionName) { - var descriptor = FlightDescriptor.command(functionName.getBytes()); - return client.getInfo(descriptor); - } - - public FlightStream call(String functionName, VectorSchemaRoot root) { - var descriptor = FlightDescriptor.path(functionName); - var readerWriter = client.doExchange(descriptor); - var writer = readerWriter.getWriter(); - var reader = readerWriter.getReader(); - - writer.start(root); - writer.putNext(); - writer.completed(); - return reader; - } -} From 19b036cd93c27fab9a3f82f0c7183adfc2ac7a91 Mon Sep 17 00:00:00 2001 From: Runji Wang Date: Mon, 29 Apr 2024 15:29:10 +0800 Subject: [PATCH 09/27] move java udf example to e2e_test Signed-off-by: Runji Wang --- {java/udf-example => e2e_test/udf/java}/README.md | 0 {java/udf-example => e2e_test/udf/java}/pom.xml | 12 ++---------- .../java}/src/main/java/com/example/UdfExample.java | 0 3 files changed, 2 insertions(+), 10 deletions(-) rename {java/udf-example => e2e_test/udf/java}/README.md (100%) rename {java/udf-example => e2e_test/udf/java}/pom.xml (86%) rename {java/udf-example => e2e_test/udf/java}/src/main/java/com/example/UdfExample.java (100%) diff --git a/java/udf-example/README.md b/e2e_test/udf/java/README.md similarity index 100% rename from java/udf-example/README.md rename to e2e_test/udf/java/README.md diff --git a/java/udf-example/pom.xml b/e2e_test/udf/java/pom.xml similarity index 86% rename from java/udf-example/pom.xml rename to e2e_test/udf/java/pom.xml index 8bf51cd108128..9c2351f8ce1f9 100644 --- a/java/udf-example/pom.xml +++ b/e2e_test/udf/java/pom.xml @@ -5,17 +5,9 @@ xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd"> 4.0.0 - - - com.risingwave - risingwave-java-root - 0.1.0-SNAPSHOT - ../pom.xml - - com.risingwave risingwave-udf-example - 0.1.1-SNAPSHOT + 0.1.0-SNAPSHOT udf-example https://docs.risingwave.com/docs/current/udf-java @@ -31,7 +23,7 @@ com.risingwave risingwave-udf - 0.1.3-SNAPSHOT + 0.2.0-SNAPSHOT com.google.code.gson diff --git a/java/udf-example/src/main/java/com/example/UdfExample.java b/e2e_test/udf/java/src/main/java/com/example/UdfExample.java similarity index 100% rename from java/udf-example/src/main/java/com/example/UdfExample.java rename to e2e_test/udf/java/src/main/java/com/example/UdfExample.java From 2526b40adb918faf74520826092e0b532f69ab4a Mon Sep 17 00:00:00 2001 From: Runji Wang Date: Mon, 29 Apr 2024 15:34:09 +0800 Subject: [PATCH 10/27] remove python udf sdk Signed-off-by: Runji Wang --- src/expr/udf/python/.gitignore | 2 - src/expr/udf/python/CHANGELOG.md | 37 -- src/expr/udf/python/README.md | 112 ---- src/expr/udf/python/publish.md | 19 - src/expr/udf/python/pyproject.toml | 20 - src/expr/udf/python/risingwave/__init__.py | 13 - src/expr/udf/python/risingwave/test_udf.py | 240 -------- src/expr/udf/python/risingwave/udf.py | 552 ------------------ .../udf/python/risingwave/udf/health_check.py | 40 -- 9 files changed, 1035 
deletions(-) delete mode 100644 src/expr/udf/python/.gitignore delete mode 100644 src/expr/udf/python/CHANGELOG.md delete mode 100644 src/expr/udf/python/README.md delete mode 100644 src/expr/udf/python/publish.md delete mode 100644 src/expr/udf/python/pyproject.toml delete mode 100644 src/expr/udf/python/risingwave/__init__.py delete mode 100644 src/expr/udf/python/risingwave/test_udf.py delete mode 100644 src/expr/udf/python/risingwave/udf.py delete mode 100644 src/expr/udf/python/risingwave/udf/health_check.py diff --git a/src/expr/udf/python/.gitignore b/src/expr/udf/python/.gitignore deleted file mode 100644 index 75b18b1dc1919..0000000000000 --- a/src/expr/udf/python/.gitignore +++ /dev/null @@ -1,2 +0,0 @@ -/dist -/risingwave.egg-info diff --git a/src/expr/udf/python/CHANGELOG.md b/src/expr/udf/python/CHANGELOG.md deleted file mode 100644 index a20411e69d83e..0000000000000 --- a/src/expr/udf/python/CHANGELOG.md +++ /dev/null @@ -1,37 +0,0 @@ -# Changelog - -All notable changes to this project will be documented in this file. - -The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), -and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). - -## [Unreleased] - -## [0.1.1] - 2023-12-06 - -### Fixed - -- Fix decimal type output. - -## [0.1.0] - 2023-12-01 - -### Fixed - -- Fix unconstrained decimal type. - -## [0.0.12] - 2023-11-28 - -### Changed - -- Change the default struct field name to `f{i}`. - -### Fixed - -- Fix parsing nested struct type. - - -## [0.0.11] - 2023-11-06 - -### Fixed - -- Hook SIGTERM to stop the UDF server gracefully. diff --git a/src/expr/udf/python/README.md b/src/expr/udf/python/README.md deleted file mode 100644 index d1655be05350b..0000000000000 --- a/src/expr/udf/python/README.md +++ /dev/null @@ -1,112 +0,0 @@ -# RisingWave Python UDF SDK - -This library provides a Python SDK for creating user-defined functions (UDF) in [RisingWave](https://www.risingwave.com/). - -For a detailed guide on how to use Python UDF in RisingWave, please refer to [this doc](https://docs.risingwave.com/docs/current/udf-python/). - -## Introduction - -RisingWave supports user-defined functions implemented as external functions. -With the RisingWave Python UDF SDK, users can define custom UDFs using Python and start a Python process as a UDF server. -RisingWave can then remotely access the UDF server to execute the defined functions. 
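Under the hood, the SDK exposes functions over the Apache Arrow Flight `do_exchange` protocol (the server class in `udf.py` below notes this explicitly), so a running UDF server can also be exercised without the SDK. Here is a minimal client-side sketch using `pyarrow.flight`, assuming a `gcd` function has been registered and the server is listening on `localhost:8815`; it mirrors the deleted `test_udf.py` client further down rather than any public SDK API:

```python
import pyarrow as pa
import pyarrow.flight as flight

# Connect to a running UDF server (assumed to listen on localhost:8815).
client = flight.FlightClient(("localhost", 8815))

# Each function is addressed by its registered name via a Flight path descriptor.
writer, reader = client.do_exchange(flight.FlightDescriptor.for_path(b"gcd"))

# Send one batch of arguments, then read back one batch of results.
batch = pa.RecordBatch.from_arrays(
    [pa.array([15], type=pa.int32()), pa.array([12], type=pa.int32())],
    names=["x", "y"],
)
with writer:
    writer.begin(schema=batch.schema)
    writer.write_batch(batch)
    writer.done_writing()
    result = reader.read_chunk()

print(result.data.column(0))  # expected: [3], since gcd(15, 12) = 3
```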
-
-## Installation
-
-```sh
-pip install risingwave
-```
-
-## Usage
-
-Define functions in a Python file:
-
-```python
-# udf.py
-from risingwave.udf import udf, udtf, UdfServer
-import struct
-import socket
-
-# Define a scalar function
-@udf(input_types=['INT', 'INT'], result_type='INT')
-def gcd(x, y):
-    while y != 0:
-        (x, y) = (y, x % y)
-    return x
-
-# Define a scalar function that returns multiple values (within a struct)
-@udf(input_types=['BYTEA'], result_type='STRUCT<VARCHAR, VARCHAR, SMALLINT, SMALLINT>')
-def extract_tcp_info(tcp_packet: bytes):
-    src_addr, dst_addr = struct.unpack('!4s4s', tcp_packet[12:20])
-    src_port, dst_port = struct.unpack('!HH', tcp_packet[20:24])
-    src_addr = socket.inet_ntoa(src_addr)
-    dst_addr = socket.inet_ntoa(dst_addr)
-    return src_addr, dst_addr, src_port, dst_port
-
-# Define a table function
-@udtf(input_types='INT', result_types='INT')
-def series(n):
-    for i in range(n):
-        yield i
-
-# Start a UDF server
-if __name__ == '__main__':
-    server = UdfServer(location="0.0.0.0:8815")
-    server.add_function(gcd)
-    server.add_function(series)
-    server.serve()
-```
-
-Start the UDF server:
-
-```sh
-python3 udf.py
-```
-
-To create functions in RisingWave, use the following syntax:
-
-```sql
-create function <name> ( <arg_type> [, ...] )
-    [ returns <ret_type> | returns table ( <column_name> <column_type> [, ...] ) ]
-    as <name_defined_in_server> using link '<udf_server_address>';
-```
-
-- The `as` parameter specifies the function name defined in the UDF server.
-- The `link` parameter specifies the address of the UDF server.
-
-For example:
-
-```sql
-create function gcd(int, int) returns int
-as gcd using link 'http://localhost:8815';
-
-create function series(int) returns table (x int)
-as series using link 'http://localhost:8815';
-
-select gcd(25, 15);
-
-select * from series(10);
-```
-
-## Data Types
-
-The RisingWave Python UDF SDK supports the following data types:
-
-| SQL Type         | Python Type                    | Notes              |
-| ---------------- | ------------------------------ | ------------------ |
-| BOOLEAN          | bool                           |                    |
-| SMALLINT         | int                            |                    |
-| INT              | int                            |                    |
-| BIGINT           | int                            |                    |
-| REAL             | float                          |                    |
-| DOUBLE PRECISION | float                          |                    |
-| DECIMAL          | decimal.Decimal                |                    |
-| DATE             | datetime.date                  |                    |
-| TIME             | datetime.time                  |                    |
-| TIMESTAMP        | datetime.datetime              |                    |
-| INTERVAL         | MonthDayNano / (int, int, int) | Fields can be obtained by `months()`, `days()` and `nanoseconds()` from `MonthDayNano` |
-| VARCHAR          | str                            |                    |
-| BYTEA            | bytes                          |                    |
-| JSONB            | any                            |                    |
-| T[]              | list[T]                        |                    |
-| STRUCT<>         | tuple                          |                    |
-| ...others        |                                | Not supported yet. |
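For illustration, a sketch of how two rows of this table combine in practice; the function below is hypothetical, not part of the SDK or its docs. A `JSONB` argument arrives as an already-parsed Python value, and a `VARCHAR[]` result can be returned as a plain Python list of `str`:

```python
from risingwave.udf import udf

# JSONB input is handed to the function as a parsed Python value
# (dict, list, str, ...); a VARCHAR[] result is just a list of str.
@udf(input_types=['JSONB'], result_type='VARCHAR[]')
def json_keys(obj):
    return list(obj.keys())
```

Registered with `server.add_function(json_keys)`, this would then be declared in RisingWave as `create function json_keys(jsonb) returns varchar[] ...`.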
diff --git a/src/expr/udf/python/publish.md b/src/expr/udf/python/publish.md
deleted file mode 100644
index 0bc22d713906f..0000000000000
--- a/src/expr/udf/python/publish.md
+++ /dev/null
@@ -1,19 +0,0 @@
-# How to publish this library
-
-Install the build tool:
-
-```sh
-pip3 install build
-```
-
-Build the library:
-
-```sh
-python3 -m build
-```
-
-Upload the library to PyPI:
-
-```sh
-twine upload dist/*
-```
diff --git a/src/expr/udf/python/pyproject.toml b/src/expr/udf/python/pyproject.toml
deleted file mode 100644
index b535355168363..0000000000000
--- a/src/expr/udf/python/pyproject.toml
+++ /dev/null
@@ -1,20 +0,0 @@
-[build-system]
-requires = ["setuptools", "wheel"]
-build-backend = "setuptools.build_meta"
-
-[project]
-name = "risingwave"
-version = "0.1.1"
-authors = [{ name = "RisingWave Labs" }]
-description = "RisingWave Python API"
-readme = "README.md"
-license = { text = "Apache Software License" }
-classifiers = [
-    "Programming Language :: Python",
-    "License :: OSI Approved :: Apache Software License",
-]
-requires-python = ">=3.8"
-dependencies = ["pyarrow"]
-
-[project.optional-dependencies]
-test = ["pytest"]
diff --git a/src/expr/udf/python/risingwave/__init__.py b/src/expr/udf/python/risingwave/__init__.py
deleted file mode 100644
index 3d60f2f96d025..0000000000000
--- a/src/expr/udf/python/risingwave/__init__.py
+++ /dev/null
@@ -1,13 +0,0 @@
-# Copyright 2024 RisingWave Labs
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
diff --git a/src/expr/udf/python/risingwave/test_udf.py b/src/expr/udf/python/risingwave/test_udf.py
deleted file mode 100644
index e3c2029d3d1f9..0000000000000
--- a/src/expr/udf/python/risingwave/test_udf.py
+++ /dev/null
@@ -1,240 +0,0 @@
-# Copyright 2024 RisingWave Labs
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License. 
- -from multiprocessing import Process -import pytest -from risingwave.udf import udf, UdfServer, _to_data_type -import pyarrow as pa -import pyarrow.flight as flight -import time -import datetime -from typing import Any - - -def flight_server(): - server = UdfServer(location="localhost:8815") - server.add_function(add) - server.add_function(wait) - server.add_function(wait_concurrent) - server.add_function(return_all) - return server - - -def flight_client(): - client = flight.FlightClient(("localhost", 8815)) - return client - - -# Define a scalar function -@udf(input_types=["INT", "INT"], result_type="INT") -def add(x, y): - return x + y - - -@udf(input_types=["INT"], result_type="INT") -def wait(x): - time.sleep(0.01) - return 0 - - -@udf(input_types=["INT"], result_type="INT", io_threads=32) -def wait_concurrent(x): - time.sleep(0.01) - return 0 - - -@udf( - input_types=[ - "BOOLEAN", - "SMALLINT", - "INT", - "BIGINT", - "FLOAT4", - "FLOAT8", - "DECIMAL", - "DATE", - "TIME", - "TIMESTAMP", - "INTERVAL", - "VARCHAR", - "BYTEA", - "JSONB", - ], - result_type="""struct< - BOOLEAN, - SMALLINT, - INT, - BIGINT, - FLOAT4, - FLOAT8, - DECIMAL, - DATE, - TIME, - TIMESTAMP, - INTERVAL, - VARCHAR, - BYTEA, - JSONB - >""", -) -def return_all( - bool, - i16, - i32, - i64, - f32, - f64, - decimal, - date, - time, - timestamp, - interval, - varchar, - bytea, - jsonb, -): - return ( - bool, - i16, - i32, - i64, - f32, - f64, - decimal, - date, - time, - timestamp, - interval, - varchar, - bytea, - jsonb, - ) - - -def test_simple(): - LEN = 64 - data = pa.Table.from_arrays( - [pa.array(range(0, LEN)), pa.array(range(0, LEN))], names=["x", "y"] - ) - - batches = data.to_batches(max_chunksize=512) - - with flight_client() as client, flight_server() as server: - flight_info = flight.FlightDescriptor.for_path(b"add") - writer, reader = client.do_exchange(descriptor=flight_info) - with writer: - writer.begin(schema=data.schema) - for batch in batches: - writer.write_batch(batch) - writer.done_writing() - - chunk = reader.read_chunk() - assert len(chunk.data) == LEN - assert chunk.data.column("output").equals( - pa.array(range(0, LEN * 2, 2), type=pa.int32()) - ) - - -def test_io_concurrency(): - LEN = 64 - data = pa.Table.from_arrays([pa.array(range(0, LEN))], names=["x"]) - batches = data.to_batches(max_chunksize=512) - - with flight_client() as client, flight_server() as server: - # Single-threaded function takes a long time - flight_info = flight.FlightDescriptor.for_path(b"wait") - writer, reader = client.do_exchange(descriptor=flight_info) - with writer: - writer.begin(schema=data.schema) - for batch in batches: - writer.write_batch(batch) - writer.done_writing() - start_time = time.time() - - total_len = 0 - for chunk in reader: - total_len += len(chunk.data) - - assert total_len == LEN - - elapsed_time = time.time() - start_time # ~0.64s - assert elapsed_time > 0.5 - - # Multi-threaded I/O bound function will take a much shorter time - flight_info = flight.FlightDescriptor.for_path(b"wait_concurrent") - writer, reader = client.do_exchange(descriptor=flight_info) - with writer: - writer.begin(schema=data.schema) - for batch in batches: - writer.write_batch(batch) - writer.done_writing() - start_time = time.time() - - total_len = 0 - for chunk in reader: - total_len += len(chunk.data) - - assert total_len == LEN - - elapsed_time = time.time() - start_time - assert elapsed_time < 0.25 - - -def test_all_types(): - arrays = [ - pa.array([True], type=pa.bool_()), - pa.array([1], type=pa.int16()), - 
pa.array([1], type=pa.int32()), - pa.array([1], type=pa.int64()), - pa.array([1], type=pa.float32()), - pa.array([1], type=pa.float64()), - pa.array(["12345678901234567890.1234567890"], type=pa.large_binary()), - pa.array([datetime.date(2023, 6, 1)], type=pa.date32()), - pa.array([datetime.time(1, 2, 3, 456789)], type=pa.time64("us")), - pa.array( - [datetime.datetime(2023, 6, 1, 1, 2, 3, 456789)], - type=pa.timestamp("us"), - ), - pa.array([(1, 2, 3)], type=pa.month_day_nano_interval()), - pa.array(["string"], type=pa.string()), - pa.array(["bytes"], type=pa.binary()), - pa.array(['{ "key": 1 }'], type=pa.large_string()), - ] - batch = pa.RecordBatch.from_arrays(arrays, names=["" for _ in arrays]) - - with flight_client() as client, flight_server() as server: - flight_info = flight.FlightDescriptor.for_path(b"return_all") - writer, reader = client.do_exchange(descriptor=flight_info) - with writer: - writer.begin(schema=batch.schema) - writer.write_batch(batch) - writer.done_writing() - - chunk = reader.read_chunk() - assert [v.as_py() for _, v in chunk.data.column(0)[0].items()] == [ - True, - 1, - 1, - 1, - 1.0, - 1.0, - b"12345678901234567890.1234567890", - datetime.date(2023, 6, 1), - datetime.time(1, 2, 3, 456789), - datetime.datetime(2023, 6, 1, 1, 2, 3, 456789), - (1, 2, 3), - "string", - b"bytes", - '{"key": 1}', - ] diff --git a/src/expr/udf/python/risingwave/udf.py b/src/expr/udf/python/risingwave/udf.py deleted file mode 100644 index aad53e25e0c98..0000000000000 --- a/src/expr/udf/python/risingwave/udf.py +++ /dev/null @@ -1,552 +0,0 @@ -# Copyright 2024 RisingWave Labs -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import * -import pyarrow as pa -import pyarrow.flight -import pyarrow.parquet -import inspect -import traceback -import json -from concurrent.futures import ThreadPoolExecutor -import concurrent -from decimal import Decimal -import signal - - -class UserDefinedFunction: - """ - Base interface for user-defined function. - """ - - _name: str - _input_schema: pa.Schema - _result_schema: pa.Schema - _io_threads: Optional[int] - _executor: Optional[ThreadPoolExecutor] - - def eval_batch(self, batch: pa.RecordBatch) -> Iterator[pa.RecordBatch]: - """ - Apply the function on a batch of inputs. - """ - return iter([]) - - -class ScalarFunction(UserDefinedFunction): - """ - Base interface for user-defined scalar function. A user-defined scalar functions maps zero, one, - or multiple scalar values to a new scalar value. - """ - - def __init__(self, *args, **kwargs): - self._io_threads = kwargs.pop("io_threads") - self._executor = ( - ThreadPoolExecutor(max_workers=self._io_threads) - if self._io_threads is not None - else None - ) - super().__init__(*args, **kwargs) - - def eval(self, *args) -> Any: - """ - Method which defines the logic of the scalar function. 
- """ - pass - - def eval_batch(self, batch: pa.RecordBatch) -> Iterator[pa.RecordBatch]: - # parse value from json string for jsonb columns - inputs = [[v.as_py() for v in array] for array in batch] - inputs = [ - _process_func(pa.list_(type), False)(array) - for array, type in zip(inputs, self._input_schema.types) - ] - if self._executor is not None: - # evaluate the function for each row - tasks = [ - self._executor.submit(self._func, *[col[i] for col in inputs]) - for i in range(batch.num_rows) - ] - column = [ - future.result() for future in concurrent.futures.as_completed(tasks) - ] - else: - # evaluate the function for each row - column = [ - self.eval(*[col[i] for col in inputs]) for i in range(batch.num_rows) - ] - - column = _process_func(pa.list_(self._result_schema.types[0]), True)(column) - - array = pa.array(column, type=self._result_schema.types[0]) - yield pa.RecordBatch.from_arrays([array], schema=self._result_schema) - - -def _process_func(type: pa.DataType, output: bool) -> Callable: - """Return a function to process input or output value.""" - if pa.types.is_list(type): - func = _process_func(type.value_type, output) - return lambda array: [(func(v) if v is not None else None) for v in array] - - if pa.types.is_struct(type): - funcs = [_process_func(field.type, output) for field in type] - if output: - return lambda tup: tuple( - (func(v) if v is not None else None) for v, func in zip(tup, funcs) - ) - else: - # the input value of struct type is a dict - # we convert it into tuple here - return lambda map: tuple( - (func(v) if v is not None else None) - for v, func in zip(map.values(), funcs) - ) - - if type.equals(JSONB): - if output: - return lambda v: json.dumps(v) - else: - return lambda v: json.loads(v) - - if type.equals(UNCONSTRAINED_DECIMAL): - if output: - - def decimal_to_str(v): - if not isinstance(v, Decimal): - raise ValueError(f"Expected Decimal, got {v}") - # use `f` format to avoid scientific notation, e.g. `1e10` - return format(v, "f").encode("utf-8") - - return decimal_to_str - else: - return lambda v: Decimal(v.decode("utf-8")) - - return lambda v: v - - -class TableFunction(UserDefinedFunction): - """ - Base interface for user-defined table function. A user-defined table functions maps zero, one, - or multiple scalar values to a new table value. - """ - - BATCH_SIZE = 1024 - - def eval(self, *args) -> Iterator: - """ - Method which defines the logic of the table function. 
- """ - yield - - def eval_batch(self, batch: pa.RecordBatch) -> Iterator[pa.RecordBatch]: - class RecordBatchBuilder: - """A utility class for constructing Arrow RecordBatch by row.""" - - schema: pa.Schema - columns: List[List] - - def __init__(self, schema: pa.Schema): - self.schema = schema - self.columns = [[] for _ in self.schema.types] - - def len(self) -> int: - """Returns the number of rows in the RecordBatch being built.""" - return len(self.columns[0]) - - def append(self, index: int, value: Any): - """Appends a new row to the RecordBatch being built.""" - self.columns[0].append(index) - self.columns[1].append(value) - - def build(self) -> pa.RecordBatch: - """Builds the RecordBatch from the accumulated data and clears the state.""" - # Convert the columns to arrow arrays - arrays = [ - pa.array(col, type) - for col, type in zip(self.columns, self.schema.types) - ] - # Reset columns - self.columns = [[] for _ in self.schema.types] - return pa.RecordBatch.from_arrays(arrays, schema=self.schema) - - builder = RecordBatchBuilder(self._result_schema) - - # Iterate through rows in the input RecordBatch - for row_index in range(batch.num_rows): - row = tuple(column[row_index].as_py() for column in batch) - for result in self.eval(*row): - builder.append(row_index, result) - if builder.len() == self.BATCH_SIZE: - yield builder.build() - if builder.len() != 0: - yield builder.build() - - -class UserDefinedScalarFunctionWrapper(ScalarFunction): - """ - Base Wrapper for Python user-defined scalar function. - """ - - _func: Callable - - def __init__(self, func, input_types, result_type, name=None, io_threads=None): - self._func = func - self._input_schema = pa.schema( - zip( - inspect.getfullargspec(func)[0], - [_to_data_type(t) for t in _to_list(input_types)], - ) - ) - self._result_schema = pa.schema([("output", _to_data_type(result_type))]) - self._name = name or ( - func.__name__ if hasattr(func, "__name__") else func.__class__.__name__ - ) - super().__init__(io_threads=io_threads) - - def __call__(self, *args): - return self._func(*args) - - def eval(self, *args): - return self._func(*args) - - -class UserDefinedTableFunctionWrapper(TableFunction): - """ - Base Wrapper for Python user-defined table function. - """ - - _func: Callable - - def __init__(self, func, input_types, result_types, name=None): - self._func = func - self._name = name or ( - func.__name__ if hasattr(func, "__name__") else func.__class__.__name__ - ) - self._input_schema = pa.schema( - zip( - inspect.getfullargspec(func)[0], - [_to_data_type(t) for t in _to_list(input_types)], - ) - ) - self._result_schema = pa.schema( - [ - ("row_index", pa.int32()), - ( - self._name, - pa.struct([("", _to_data_type(t)) for t in result_types]) - if isinstance(result_types, list) - else _to_data_type(result_types), - ), - ] - ) - - def __call__(self, *args): - return self._func(*args) - - def eval(self, *args): - return self._func(*args) - - -def _to_list(x): - if isinstance(x, list): - return x - else: - return [x] - - -def udf( - input_types: Union[List[Union[str, pa.DataType]], Union[str, pa.DataType]], - result_type: Union[str, pa.DataType], - name: Optional[str] = None, - io_threads: Optional[int] = None, -) -> Callable: - """ - Annotation for creating a user-defined scalar function. - - Parameters: - - input_types: A list of strings or Arrow data types that specifies the input data types. - - result_type: A string or an Arrow data type that specifies the return value type. 
- - name: An optional string specifying the function name. If not provided, the original name will be used. - - io_threads: Number of I/O threads used per data chunk for I/O bound functions. - - Example: - ``` - @udf(input_types=['INT', 'INT'], result_type='INT') - def gcd(x, y): - while y != 0: - (x, y) = (y, x % y) - return x - ``` - - I/O bound Example: - ``` - @udf(input_types=['INT'], result_type='INT', io_threads=64) - def external_api(x): - response = requests.get(my_endpoint + '?param=' + x) - return response["data"] - ``` - """ - - if io_threads is not None and io_threads > 1: - return lambda f: UserDefinedScalarFunctionWrapper( - f, input_types, result_type, name, io_threads=io_threads - ) - else: - return lambda f: UserDefinedScalarFunctionWrapper( - f, input_types, result_type, name - ) - - -def udtf( - input_types: Union[List[Union[str, pa.DataType]], Union[str, pa.DataType]], - result_types: Union[List[Union[str, pa.DataType]], Union[str, pa.DataType]], - name: Optional[str] = None, -) -> Callable: - """ - Annotation for creating a user-defined table function. - - Parameters: - - input_types: A list of strings or Arrow data types that specifies the input data types. - - result_types A list of strings or Arrow data types that specifies the return value types. - - name: An optional string specifying the function name. If not provided, the original name will be used. - - Example: - ``` - @udtf(input_types='INT', result_types='INT') - def series(n): - for i in range(n): - yield i - ``` - """ - - return lambda f: UserDefinedTableFunctionWrapper(f, input_types, result_types, name) - - -class UdfServer(pa.flight.FlightServerBase): - """ - A server that provides user-defined functions to clients. - - Example: - ``` - server = UdfServer(location="0.0.0.0:8815") - server.add_function(my_udf) - server.serve() - ``` - """ - - # UDF server based on Apache Arrow Flight protocol. 
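-    # Each function is registered under its name as a Flight "path":
-    # `get_flight_info` reports the function's schema, and `do_exchange`
-    # streams input RecordBatches in and output RecordBatches back.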
- # Reference: https://arrow.apache.org/cookbook/py/flight.html#simple-parquet-storage-service-with-arrow-flight - - _location: str - _functions: Dict[str, UserDefinedFunction] - - def __init__(self, location="0.0.0.0:8815", **kwargs): - super(UdfServer, self).__init__("grpc://" + location, **kwargs) - self._location = location - self._functions = {} - - def get_flight_info(self, context, descriptor): - """Return the result schema of a function.""" - udf = self._functions[descriptor.path[0].decode("utf-8")] - # return the concatenation of input and output schema - full_schema = pa.schema(list(udf._input_schema) + list(udf._result_schema)) - # we use `total_records` to indicate the number of input arguments - return pa.flight.FlightInfo( - schema=full_schema, - descriptor=descriptor, - endpoints=[], - total_records=len(udf._input_schema), - total_bytes=0, - ) - - def add_function(self, udf: UserDefinedFunction): - """Add a function to the server.""" - name = udf._name - if name in self._functions: - raise ValueError("Function already exists: " + name) - - input_types = ",".join( - [_data_type_to_string(t) for t in udf._input_schema.types] - ) - if isinstance(udf, TableFunction): - output_type = udf._result_schema.types[-1] - if isinstance(output_type, pa.StructType): - output_type = ",".join( - f"field_{i} {_data_type_to_string(field.type)}" - for i, field in enumerate(output_type) - ) - output_type = f"TABLE({output_type})" - else: - output_type = _data_type_to_string(output_type) - output_type = f"TABLE(output {output_type})" - else: - output_type = _data_type_to_string(udf._result_schema.types[-1]) - - sql = f"CREATE FUNCTION {name}({input_types}) RETURNS {output_type} AS '{name}' USING LINK 'http://{self._location}';" - print(f"added function: {name}, corresponding SQL:\n{sql}\n") - self._functions[name] = udf - - def do_exchange(self, context, descriptor, reader, writer): - """Call a function from the client.""" - udf = self._functions[descriptor.path[0].decode("utf-8")] - writer.begin(udf._result_schema) - try: - for batch in reader: - # print(pa.Table.from_batches([batch.data])) - for output_batch in udf.eval_batch(batch.data): - writer.write_batch(output_batch) - except Exception as e: - print(traceback.print_exc()) - raise e - - def serve(self): - """ - Block until the server shuts down. - - This method only returns if shutdown() is called or a signal (SIGINT, SIGTERM) received. - """ - print( - "Note: You can use arbitrary function names and struct field names in CREATE FUNCTION statements." - f"\n\nlistening on {self._location}" - ) - signal.signal(signal.SIGTERM, lambda s, f: self.shutdown()) - super(UdfServer, self).serve() - - -def _to_data_type(t: Union[str, pa.DataType]) -> pa.DataType: - """ - Convert a SQL data type string or `pyarrow.DataType` to `pyarrow.DataType`. - """ - if isinstance(t, str): - return _string_to_data_type(t) - else: - return t - - -# we use `large_binary` to represent unconstrained decimal type -UNCONSTRAINED_DECIMAL = pa.large_binary() -JSONB = pa.large_string() - - -def _string_to_data_type(type_str: str): - """ - Convert a SQL data type string to `pyarrow.DataType`. 
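-
-    For example, `"VARCHAR[]"` maps to `pa.list_(pa.string())`, and
-    `"STRUCT<INT, VARCHAR>"` is parsed recursively into a struct with
-    anonymous fields.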
- """ - type_str = type_str.upper() - if type_str.endswith("[]"): - return pa.list_(_string_to_data_type(type_str[:-2])) - elif type_str in ("BOOLEAN", "BOOL"): - return pa.bool_() - elif type_str in ("SMALLINT", "INT2"): - return pa.int16() - elif type_str in ("INT", "INTEGER", "INT4"): - return pa.int32() - elif type_str in ("BIGINT", "INT8"): - return pa.int64() - elif type_str in ("FLOAT4", "REAL"): - return pa.float32() - elif type_str in ("FLOAT8", "DOUBLE PRECISION"): - return pa.float64() - elif type_str.startswith("DECIMAL") or type_str.startswith("NUMERIC"): - if type_str == "DECIMAL" or type_str == "NUMERIC": - return UNCONSTRAINED_DECIMAL - rest = type_str[8:-1] # remove "DECIMAL(" and ")" - if "," in rest: - precision, scale = rest.split(",") - return pa.decimal128(int(precision), int(scale)) - else: - return pa.decimal128(int(rest), 0) - elif type_str in ("DATE"): - return pa.date32() - elif type_str in ("TIME", "TIME WITHOUT TIME ZONE"): - return pa.time64("us") - elif type_str in ("TIMESTAMP", "TIMESTAMP WITHOUT TIME ZONE"): - return pa.timestamp("us") - elif type_str.startswith("INTERVAL"): - return pa.month_day_nano_interval() - elif type_str in ("VARCHAR"): - return pa.string() - elif type_str in ("JSONB"): - return JSONB - elif type_str in ("BYTEA"): - return pa.binary() - elif type_str.startswith("STRUCT"): - # extract 'STRUCT, ...>' - type_list = type_str[7:-1] # strip "STRUCT<>" - fields = [] - elements = [] - start = 0 - depth = 0 - for i, c in enumerate(type_list): - if c == "<": - depth += 1 - elif c == ">": - depth -= 1 - elif c == "," and depth == 0: - type_str = type_list[start:i].strip() - fields.append(pa.field("", _string_to_data_type(type_str))) - start = i + 1 - type_str = type_list[start:].strip() - fields.append(pa.field("", _string_to_data_type(type_str))) - return pa.struct(fields) - - raise ValueError(f"Unsupported type: {type_str}") - - -def _data_type_to_string(t: pa.DataType) -> str: - """ - Convert a `pyarrow.DataType` to a SQL data type string. - """ - if isinstance(t, pa.ListType): - return _data_type_to_string(t.value_type) + "[]" - elif t.equals(pa.bool_()): - return "BOOLEAN" - elif t.equals(pa.int16()): - return "SMALLINT" - elif t.equals(pa.int32()): - return "INT" - elif t.equals(pa.int64()): - return "BIGINT" - elif t.equals(pa.float32()): - return "FLOAT4" - elif t.equals(pa.float64()): - return "FLOAT8" - elif t.equals(UNCONSTRAINED_DECIMAL): - return "DECIMAL" - elif pa.types.is_decimal(t): - return f"DECIMAL({t.precision},{t.scale})" - elif t.equals(pa.date32()): - return "DATE" - elif t.equals(pa.time64("us")): - return "TIME" - elif t.equals(pa.timestamp("us")): - return "TIMESTAMP" - elif t.equals(pa.month_day_nano_interval()): - return "INTERVAL" - elif t.equals(pa.string()): - return "VARCHAR" - elif t.equals(JSONB): - return "JSONB" - elif t.equals(pa.binary()): - return "BYTEA" - elif isinstance(t, pa.StructType): - return ( - "STRUCT<" - + ",".join( - f"f{i+1} {_data_type_to_string(field.type)}" - for i, field in enumerate(t) - ) - + ">" - ) - else: - raise ValueError(f"Unsupported type: {t}") diff --git a/src/expr/udf/python/risingwave/udf/health_check.py b/src/expr/udf/python/risingwave/udf/health_check.py deleted file mode 100644 index ad2d38681a6cc..0000000000000 --- a/src/expr/udf/python/risingwave/udf/health_check.py +++ /dev/null @@ -1,40 +0,0 @@ -# Copyright 2024 RisingWave Labs -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
diff --git a/src/expr/udf/python/risingwave/udf/health_check.py b/src/expr/udf/python/risingwave/udf/health_check.py
deleted file mode 100644
index ad2d38681a6cc..0000000000000
--- a/src/expr/udf/python/risingwave/udf/health_check.py
+++ /dev/null
@@ -1,40 +0,0 @@
-# Copyright 2024 RisingWave Labs
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from pyarrow.flight import FlightClient
-import sys
-
-
-def check_udf_service_available(addr: str) -> bool:
-    """Check if the UDF service is available at the given address."""
-    try:
-        client = FlightClient(f"grpc://{addr}")
-        client.wait_for_available()
-        return True
-    except Exception as e:
-        print(f"Error connecting to RisingWave UDF service: {str(e)}")
-        return False
-
-
-if __name__ == "__main__":
-    if len(sys.argv) != 2:
-        print("usage: python3 health_check.py <server_address>")
-        sys.exit(1)
-
-    server_address = sys.argv[1]
-    if check_udf_service_available(server_address):
-        print("OK")
-    else:
-        print("unavailable")
-        exit(-1)

From 8ea1b1b7e7ca417ebbe837ab7f963a57af2460a9 Mon Sep 17 00:00:00 2001
From: Runji Wang
Date: Mon, 29 Apr 2024 17:08:47 +0800
Subject: [PATCH 11/27] remove risingwave_udf crate

Signed-off-by: Runji Wang

---
 Cargo.toml                      |   1 -
 src/expr/udf/Cargo.toml         |  35 -----
 src/expr/udf/README-js.md       |  83 ----------
 src/expr/udf/README.md          | 118 ---------------
 src/expr/udf/examples/client.rs |  76 ----------
 src/expr/udf/src/error.rs       |  67 --------
 src/expr/udf/src/external.rs    | 260 --------------------------------
 src/expr/udf/src/lib.rs         |  24 ---
 src/expr/udf/src/metrics.rs     | 111 --------------
 src/frontend/Cargo.toml         |   1 -
 10 files changed, 776 deletions(-)
 delete mode 100644 src/expr/udf/Cargo.toml
 delete mode 100644 src/expr/udf/README-js.md
 delete mode 100644 src/expr/udf/README.md
 delete mode 100644 src/expr/udf/examples/client.rs
 delete mode 100644 src/expr/udf/src/error.rs
 delete mode 100644 src/expr/udf/src/external.rs
 delete mode 100644 src/expr/udf/src/lib.rs
 delete mode 100644 src/expr/udf/src/metrics.rs

diff --git a/Cargo.toml b/Cargo.toml
index ce1a66c94bdaa..9970cf4e07681 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -20,7 +20,6 @@ members = [
     "src/expr/core",
     "src/expr/impl",
     "src/expr/macro",
-    "src/expr/udf",
    "src/frontend",
     "src/frontend/macro",
     "src/frontend/planner_test",
diff --git a/src/expr/udf/Cargo.toml b/src/expr/udf/Cargo.toml
deleted file mode 100644
index b17ad7acadfc1..0000000000000
--- a/src/expr/udf/Cargo.toml
+++ /dev/null
@@ -1,35 +0,0 @@
-[package]
-name = "risingwave_udf"
-version = "0.1.0"
-edition = "2021"
-# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
-
-[package.metadata.cargo-machete]
-ignored = ["workspace-hack"]
-
-[package.metadata.cargo-udeps.ignore]
-normal = ["workspace-hack"]
-
-[dependencies]
-arrow-array = { workspace = true }
-arrow-flight = { workspace = true }
-arrow-schema = { workspace = true }
-arrow-select = { workspace = true }
-cfg-or-panic = "0.2"
-futures = "0.3"
-futures-util = "0.3.28"
-ginepro = "0.7.0"
-prometheus = "0.13"
-risingwave_common = { workspace = true }
-static_assertions = "1"
-thiserror = "1"
-thiserror-ext = { workspace = true }
-tokio = { version = "0.2", package = "madsim-tokio", features = [
-    "rt",
-    "macros",
-] }
-tonic = { workspace = true }
-tracing = "0.1"
-
-[lints]
-workspace = true
diff --git a/src/expr/udf/README-js.md b/src/expr/udf/README-js.md
deleted file mode 100644
index 902bce4ef52ee..0000000000000
--- a/src/expr/udf/README-js.md
+++ 
/dev/null @@ -1,83 +0,0 @@ -# Use UDFs in JavaScript - -This article provides a step-by-step guide for defining JavaScript functions in RisingWave. - -JavaScript code is inlined in `CREATE FUNCTION` statement and then run on the embedded QuickJS virtual machine in RisingWave. It does not support access to external networks and is limited to computational tasks only. -Compared to other languages, JavaScript UDFs offer the easiest way to define UDFs in RisingWave. - -## Define your functions - -You can use the `CREATE FUNCTION` statement to create JavaScript UDFs. The syntax is as follows: - -```sql -CREATE FUNCTION function_name ( arg_name arg_type [, ...] ) - [ RETURNS return_type | RETURNS TABLE ( column_name column_type [, ...] ) ] - LANGUAGE javascript - AS [ $$ function_body $$ | 'function_body' ]; -``` - -The argument names you define can be used in the function body. For example: - -```sql -CREATE FUNCTION gcd(a int, b int) RETURNS int LANGUAGE javascript AS $$ - if(a == null || b == null) { - return null; - } - while (b != 0) { - let t = b; - b = a % b; - a = t; - } - return a; -$$; -``` - -The correspondence between SQL types and JavaScript types can be found in the [appendix table](#appendix-type-mapping). You need to ensure that the type of the return value is either `null` or consistent with the type in the `RETURNS` clause. - -If the function you define returns a table, you need to use the `yield` statement to return the data of each row. For example: - -```sql -CREATE FUNCTION series(n int) RETURNS TABLE (x int) LANGUAGE javascript AS $$ - for(let i = 0; i < n; i++) { - yield i; - } -$$; -``` - -## Use your functions - -Once the UDFs are created in RisingWave, you can use them in SQL queries just like any built-in functions. For example: - -```sql -SELECT gcd(25, 15); -SELECT * from series(5); -``` - -## Appendix: Type Mapping - -The following table shows the type mapping between SQL and JavaScript: - -| SQL Type | JS Type | Note | -| --------------------- | ------------- | --------------------- | -| boolean | boolean | | -| smallint | number | | -| int | number | | -| bigint | number | | -| real | number | | -| double precision | number | | -| decimal | BigDecimal | | -| date | | not supported yet | -| time | | not supported yet | -| timestamp | | not supported yet | -| timestamptz | | not supported yet | -| interval | | not supported yet | -| varchar | string | | -| bytea | Uint8Array | | -| jsonb | null, boolean, number, string, array or object | `JSON.parse(string)` | -| smallint[] | Int16Array | | -| int[] | Int32Array | | -| bigint[] | BigInt64Array | | -| real[] | Float32Array | | -| double precision[] | Float64Array | | -| others[] | array | | -| struct<..> | object | | diff --git a/src/expr/udf/README.md b/src/expr/udf/README.md deleted file mode 100644 index d9428cc547249..0000000000000 --- a/src/expr/udf/README.md +++ /dev/null @@ -1,118 +0,0 @@ -# Use UDFs in Rust - -This article provides a step-by-step guide for defining Rust functions in RisingWave. - -Rust functions are compiled into WebAssembly modules and then run on the embedded WebAssembly virtual machine in RisingWave. Compared to Python and Java, Rust UDFs offer **higher performance** (near native) and are **managed by the RisingWave kernel**, eliminating the need for additional maintenance. 
However, since they run embedded in the kernel, for security reasons, Rust UDFs currently **do not support access to external networks and are limited to computational tasks only**, with restricted CPU and memory resources. Therefore, we recommend using Rust UDFs for **computationally intensive tasks**, such as packet parsing and format conversion.
-
-## Prerequisites
-
-- Ensure that you have [Rust toolchain](https://rustup.rs) (stable channel) installed on your computer.
-- Ensure that the Rust standard library for `wasm32-wasi` target is installed:
-  ```shell
-  rustup target add wasm32-wasi
-  ```
-
-## 1. Create a project
-
-Create a Rust project named `udf`:
-
-```shell
-cargo new --lib udf
-cd udf
-```
-
-Add the following lines to `Cargo.toml`:
-
-```toml
-[lib]
-crate-type = ["cdylib"]
-
-[dependencies]
-arrow-udf = "0.1"
-```
-
-## 2. Define your functions
-
-In `src/lib.rs`, define your functions using the `function` macro:
-
-```rust
-use arrow_udf::function;
-
-// define a scalar function
-#[function("gcd(int, int) -> int")]
-fn gcd(mut x: i32, mut y: i32) -> i32 {
-    while y != 0 {
-        (x, y) = (y, x % y);
-    }
-    x
-}
-
-// define a table function
-#[function("series(int) -> setof int")]
-fn series(n: i32) -> impl Iterator<Item = i32> {
-    0..n
-}
-```
-
-You can find more usage information in the [documentation](https://docs.rs/arrow_udf/0.1.0/arrow_udf/attr.function.html) and more examples in the [tests](https://github.com/risingwavelabs/arrow-udf/blob/main/arrow-udf/tests/tests.rs).
-
-Currently we only support a limited set of data types. `timestamptz` and complex array types are not supported yet.
-
-## 3. Build the project
-
-Build your functions into a WebAssembly module:
-
-```shell
-cargo build --release --target wasm32-wasi
-```
-
-You can find the generated WASM module at `target/wasm32-wasi/release/udf.wasm`.
-
-Optional: It is recommended to strip the binary to reduce its size:
-
-```shell
-# Install wasm-tools
-cargo install wasm-tools
-
-# Strip the binary
-wasm-tools strip ./target/wasm32-wasi/release/udf.wasm > udf.wasm
-```
-
-## 4. Declare your functions in RisingWave
-
-In RisingWave, use the `CREATE FUNCTION` command to declare the functions you defined.
-
-There are two ways to load the WASM module:
-
-1. The WASM binary can be embedded in the SQL statement using the base64 encoding.
-You can use the following shell script to encode the binary and generate the SQL statement:
-    ```shell
-    encoded=$(base64 -i udf.wasm)
-    sql="CREATE FUNCTION gcd(int, int) RETURNS int LANGUAGE wasm USING BASE64 '$encoded';"
-    echo "$sql" > create_function.sql
-    ```
-    When created successfully, the WASM binary will be automatically uploaded to the object store.
-
-2. The WASM binary can be loaded from the object store.
-    ```sql
-    CREATE FUNCTION gcd(int, int) RETURNS int
-    LANGUAGE wasm USING LINK 's3://bucket/path/to/udf.wasm';
-
-    CREATE FUNCTION series(int) RETURNS TABLE (x int)
-    LANGUAGE wasm USING LINK 's3://bucket/path/to/udf.wasm';
-    ```
-
-    Or if you run RisingWave locally, you can use the local file system:
-    ```sql
-    CREATE FUNCTION gcd(int, int) RETURNS int
-    LANGUAGE wasm USING LINK 'fs://path/to/udf.wasm';
-    ```
-
-## 5. Use your functions in RisingWave
-
-Once the UDFs are created in RisingWave, you can use them in SQL queries just like any built-in functions.
For example:
-
-```sql
-SELECT gcd(25, 15);
-SELECT series(5);
-```
diff --git a/src/expr/udf/examples/client.rs b/src/expr/udf/examples/client.rs
deleted file mode 100644
index 92f93ae13614e..0000000000000
--- a/src/expr/udf/examples/client.rs
+++ /dev/null
@@ -1,76 +0,0 @@
-// Copyright 2024 RisingWave Labs
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//     http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-use std::sync::Arc;
-
-use arrow_array::{Int32Array, RecordBatch};
-use arrow_schema::{DataType, Field, Schema};
-use risingwave_udf::ArrowFlightUdfClient;
-
-#[tokio::main]
-async fn main() {
-    let addr = "http://localhost:8815";
-    let client = ArrowFlightUdfClient::connect(addr).await.unwrap();
-
-    // build `RecordBatch` to send (equivalent to our `DataChunk`)
-    let array1 = Arc::new(Int32Array::from_iter(vec![1, 6, 10]));
-    let array2 = Arc::new(Int32Array::from_iter(vec![3, 4, 15]));
-    let array3 = Arc::new(Int32Array::from_iter(vec![6, 8, 3]));
-    let input2_schema = Schema::new(vec![
-        Field::new("a", DataType::Int32, true),
-        Field::new("b", DataType::Int32, true),
-    ]);
-    let input3_schema = Schema::new(vec![
-        Field::new("a", DataType::Int32, true),
-        Field::new("b", DataType::Int32, true),
-        Field::new("c", DataType::Int32, true),
-    ]);
-    let output_schema = Schema::new(vec![Field::new("x", DataType::Int32, true)]);
-
-    // check function
-    client
-        .check("gcd", &input2_schema, &output_schema)
-        .await
-        .unwrap();
-    client
-        .check("gcd3", &input3_schema, &output_schema)
-        .await
-        .unwrap();
-
-    let input2 = RecordBatch::try_new(
-        Arc::new(input2_schema),
-        vec![array1.clone(), array2.clone()],
-    )
-    .unwrap();
-
-    let output = client
-        .call("gcd", input2)
-        .await
-        .expect("failed to call function");
-
-    println!("{:?}", output);
-
-    let input3 = RecordBatch::try_new(
-        Arc::new(input3_schema),
-        vec![array1.clone(), array2.clone(), array3.clone()],
-    )
-    .unwrap();
-
-    let output = client
-        .call("gcd3", input3)
-        .await
-        .expect("failed to call function");
-
-    println!("{:?}", output);
-}
diff --git a/src/expr/udf/src/error.rs b/src/expr/udf/src/error.rs
deleted file mode 100644
index fc6733052b137..0000000000000
--- a/src/expr/udf/src/error.rs
+++ /dev/null
@@ -1,67 +0,0 @@
-// Copyright 2024 RisingWave Labs
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//     http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-use arrow_flight::error::FlightError;
-use thiserror::Error;
-use thiserror_ext::{Box, Construct};
-
-/// A specialized `Result` type for UDF operations.
-pub type Result<T> = std::result::Result<T, Error>;
-
-/// The error type for UDF operations.
-#[derive(Error, Debug, Box, Construct)]
-#[thiserror_ext(newtype(name = Error))]
-pub enum ErrorInner {
-    #[error("failed to send requests to UDF service: {0}")]
-    Tonic(#[from] tonic::Status),
-
-    #[error("failed to call UDF: {0}")]
-    Flight(#[from] FlightError),
-
-    #[error("type mismatch: {0}")]
-    TypeMismatch(String),
-
-    #[error("arrow error: {0}")]
-    Arrow(#[from] arrow_schema::ArrowError),
-
-    #[error("UDF unsupported: {0}")]
-    // TODO(error-handling): should prefer use error types than strings.
-    Unsupported(String),
-
-    #[error("UDF service returned no data")]
-    NoReturned,
-
-    #[error("Flight service error: {0}")]
-    ServiceError(String),
-}
-
-impl Error {
-    /// Returns true if the error is caused by a connection error.
-    pub fn is_connection_error(&self) -> bool {
-        match self.inner() {
-            // Connection refused
-            ErrorInner::Tonic(status) if status.code() == tonic::Code::Unavailable => true,
-            _ => false,
-        }
-    }
-
-    pub fn is_tonic_error(&self) -> bool {
-        matches!(
-            self.inner(),
-            ErrorInner::Tonic(_) | ErrorInner::Flight(FlightError::Tonic(_))
-        )
-    }
-}
-
-static_assertions::const_assert_eq!(std::mem::size_of::<Error>(), 8);
diff --git a/src/expr/udf/src/external.rs b/src/expr/udf/src/external.rs
deleted file mode 100644
index 7560638b03985..0000000000000
--- a/src/expr/udf/src/external.rs
+++ /dev/null
@@ -1,260 +0,0 @@
-// Copyright 2024 RisingWave Labs
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//     http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-use std::str::FromStr;
-use std::time::Duration;
-
-use arrow_array::RecordBatch;
-use arrow_flight::decode::FlightRecordBatchStream;
-use arrow_flight::encode::FlightDataEncoderBuilder;
-use arrow_flight::error::FlightError;
-use arrow_flight::flight_service_client::FlightServiceClient;
-use arrow_flight::{FlightData, FlightDescriptor};
-use arrow_schema::Schema;
-use cfg_or_panic::cfg_or_panic;
-use futures_util::{stream, FutureExt, Stream, StreamExt, TryStreamExt};
-use ginepro::{LoadBalancedChannel, ResolutionStrategy};
-use risingwave_common::util::addr::HostAddr;
-use thiserror_ext::AsReport;
-use tokio::time::Duration as TokioDuration;
-use tonic::transport::Channel;
-
-use crate::metrics::GLOBAL_METRICS;
-use crate::{Error, Result};
-
-// Interval between two successive probes of the UDF DNS.
-const DNS_PROBE_INTERVAL_SECS: u64 = 5;
-// Timeout duration for performing an eager DNS resolution.
-const EAGER_DNS_RESOLVE_TIMEOUT_SECS: u64 = 5;
-const REQUEST_TIMEOUT_SECS: u64 = 5;
-const CONNECT_TIMEOUT_SECS: u64 = 5;
-
-/// Client for external function service based on Arrow Flight.
-#[derive(Debug)]
-pub struct ArrowFlightUdfClient {
-    client: FlightServiceClient<Channel>,
-    addr: String,
-}
-
-// TODO: support UDF in simulation
-#[cfg_or_panic(not(madsim))]
-impl ArrowFlightUdfClient {
-    /// Connect to a UDF service.
-    pub async fn connect(addr: &str) -> Result<Self> {
-        Self::connect_inner(
-            addr,
-            ResolutionStrategy::Eager {
-                timeout: TokioDuration::from_secs(EAGER_DNS_RESOLVE_TIMEOUT_SECS),
-            },
-        )
-        .await
-    }
-
-    /// Connect to a UDF service lazily (i.e. only when the first request is sent).
-    pub fn connect_lazy(addr: &str) -> Result<Self> {
-        Self::connect_inner(addr, ResolutionStrategy::Lazy)
-            .now_or_never()
-            .unwrap()
-    }
-
-    async fn connect_inner(
-        mut addr: &str,
-        resolution_strategy: ResolutionStrategy,
-    ) -> Result<Self> {
-        if addr.starts_with("http://") {
-            addr = addr.strip_prefix("http://").unwrap();
-        }
-        if addr.starts_with("https://") {
-            addr = addr.strip_prefix("https://").unwrap();
-        }
-        let host_addr = HostAddr::from_str(addr).map_err(|e| {
-            Error::service_error(format!("invalid address: {}, err: {}", addr, e.as_report()))
-        })?;
-        let channel = LoadBalancedChannel::builder((host_addr.host.clone(), host_addr.port))
-            .dns_probe_interval(std::time::Duration::from_secs(DNS_PROBE_INTERVAL_SECS))
-            .timeout(Duration::from_secs(REQUEST_TIMEOUT_SECS))
-            .connect_timeout(Duration::from_secs(CONNECT_TIMEOUT_SECS))
-            .resolution_strategy(resolution_strategy)
-            .channel()
-            .await
-            .map_err(|e| {
-                Error::service_error(format!(
-                    "failed to create LoadBalancedChannel, address: {}, err: {}",
-                    host_addr,
-                    e.as_report()
-                ))
-            })?;
-        let client = FlightServiceClient::new(channel.into());
-        Ok(Self {
-            client,
-            addr: addr.into(),
-        })
-    }
-
-    /// Check if the function is available and the schema matches.
-    pub async fn check(&self, id: &str, args: &Schema, returns: &Schema) -> Result<()> {
-        let descriptor = FlightDescriptor::new_path(vec![id.into()]);
-
-        let response = self.client.clone().get_flight_info(descriptor).await?;
-
-        // check schema
-        let info = response.into_inner();
-        let input_num = info.total_records as usize;
-        let full_schema = Schema::try_from(info).map_err(|e| {
-            FlightError::DecodeError(format!("Error decoding schema: {}", e.as_report()))
-        })?;
-        if input_num > full_schema.fields.len() {
-            return Err(Error::service_error(format!(
-                "function {:?} schema info is not consistent: input_num: {}, total_fields: {}",
-                id,
-                input_num,
-                full_schema.fields.len()
-            )));
-        }
-
-        let (input_fields, return_fields) = full_schema.fields.split_at(input_num);
-        let actual_input_types: Vec<_> = input_fields.iter().map(|f| f.data_type()).collect();
-        let actual_result_types: Vec<_> = return_fields.iter().map(|f| f.data_type()).collect();
-        let expect_input_types: Vec<_> = args.fields.iter().map(|f| f.data_type()).collect();
-        let expect_result_types: Vec<_> = returns.fields.iter().map(|f| f.data_type()).collect();
-        if !data_types_match(&expect_input_types, &actual_input_types) {
-            return Err(Error::type_mismatch(format!(
-                "function: {:?}, expect arguments: {:?}, actual: {:?}",
-                id, expect_input_types, actual_input_types
-            )));
-        }
-        if !data_types_match(&expect_result_types, &actual_result_types) {
-            return Err(Error::type_mismatch(format!(
-                "function: {:?}, expect return: {:?}, actual: {:?}",
-                id, expect_result_types, actual_result_types
-            )));
-        }
-        Ok(())
-    }
-
-    /// Call a function.
-    pub async fn call(&self, id: &str, input: RecordBatch) -> Result<RecordBatch> {
-        self.call_internal(id, input).await
-    }
-
-    async fn call_internal(&self, id: &str, input: RecordBatch) -> Result<RecordBatch> {
-        let mut output_stream = self
-            .call_stream_internal(id, stream::once(async { input }))
-            .await?;
-        let mut batches = vec![];
-        while let Some(batch) = output_stream.next().await {
-            batches.push(batch?);
-        }
-        Ok(arrow_select::concat::concat_batches(
-            output_stream.schema().ok_or_else(Error::no_returned)?,
-            batches.iter(),
-        )?)
-    }
-
-    /// Call a function, retry up to 5 times / 3s if connection is broken.
-    pub async fn call_with_retry(&self, id: &str, input: RecordBatch) -> Result<RecordBatch> {
-        let mut backoff = Duration::from_millis(100);
-        for i in 0..5 {
-            match self.call(id, input.clone()).await {
-                Err(err) if err.is_connection_error() && i != 4 => {
-                    tracing::error!(error = %err.as_report(), "UDF connection error. retry...");
-                }
-                ret => return ret,
-            }
-            tokio::time::sleep(backoff).await;
-            backoff *= 2;
-        }
-        unreachable!()
-    }
-
-    /// Always retry on connection error
-    pub async fn call_with_always_retry_on_network_error(
-        &self,
-        id: &str,
-        input: RecordBatch,
-        fragment_id: &str,
-    ) -> Result<RecordBatch> {
-        let mut backoff = Duration::from_millis(100);
-        let metrics = &*GLOBAL_METRICS;
-        let labels: &[&str; 4] = &[&self.addr, "external", id, fragment_id];
-        loop {
-            match self.call(id, input.clone()).await {
-                Err(err) if err.is_tonic_error() => {
-                    tracing::error!(error = %err.as_report(), "UDF tonic error. retry...");
-                }
-                ret => {
-                    if ret.is_err() {
-                        tracing::error!(error = %ret.as_ref().unwrap_err().as_report(), "UDF error. exiting...");
-                    }
-                    return ret;
-                }
-            }
-            metrics.udf_retry_count.with_label_values(labels).inc();
-            tokio::time::sleep(backoff).await;
-            backoff *= 2;
-        }
-    }
-
-    /// Call a function with streaming input and output.
-    #[panic_return = "Result<stream::Empty<_>>"]
-    pub async fn call_stream(
-        &self,
-        id: &str,
-        inputs: impl Stream<Item = RecordBatch> + Send + 'static,
-    ) -> Result<impl Stream<Item = Result<RecordBatch>> + Send + 'static> {
-        Ok(self
-            .call_stream_internal(id, inputs)
-            .await?
-            .map_err(|e| e.into()))
-    }
-
-    async fn call_stream_internal(
-        &self,
-        id: &str,
-        inputs: impl Stream<Item = RecordBatch> + Send + 'static,
-    ) -> Result<FlightRecordBatchStream> {
-        let descriptor = FlightDescriptor::new_path(vec![id.into()]);
-        let flight_data_stream =
-            FlightDataEncoderBuilder::new()
-                .build(inputs.map(Ok))
-                .map(move |res| FlightData {
-                    // TODO: fill descriptor only for the first message
-                    flight_descriptor: Some(descriptor.clone()),
-                    ..res.unwrap()
-                });
-
-        // call `do_exchange` on Flight server
-        let response = self.client.clone().do_exchange(flight_data_stream).await?;
-
-        // decode response
-        let stream = response.into_inner();
-        Ok(FlightRecordBatchStream::new_from_flight_data(
-            // convert tonic::Status to FlightError
-            stream.map_err(|e| e.into()),
-        ))
-    }
-
-    pub fn get_addr(&self) -> &str {
-        &self.addr
-    }
-}
-
-/// Check if two lists of data types match, ignoring field names.
-fn data_types_match(a: &[&arrow_schema::DataType], b: &[&arrow_schema::DataType]) -> bool {
-    if a.len() != b.len() {
-        return false;
-    }
-    #[allow(clippy::disallowed_methods)]
-    a.iter().zip(b.iter()).all(|(a, b)| a.equals_datatype(b))
-}
diff --git a/src/expr/udf/src/lib.rs b/src/expr/udf/src/lib.rs
deleted file mode 100644
index ddd8cf1bdeab9..0000000000000
--- a/src/expr/udf/src/lib.rs
+++ /dev/null
@@ -1,24 +0,0 @@
-// Copyright 2024 RisingWave Labs
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//     http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
- -#![feature(error_generic_member_access)] -#![feature(lazy_cell)] - -mod error; -mod external; -pub mod metrics; - -pub use error::{Error, Result}; -pub use external::ArrowFlightUdfClient; -pub use metrics::GLOBAL_METRICS; diff --git a/src/expr/udf/src/metrics.rs b/src/expr/udf/src/metrics.rs deleted file mode 100644 index 50ef1b068307d..0000000000000 --- a/src/expr/udf/src/metrics.rs +++ /dev/null @@ -1,111 +0,0 @@ -// Copyright 2024 RisingWave Labs -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -use std::sync::LazyLock; - -use prometheus::{ - exponential_buckets, register_histogram_vec_with_registry, - register_int_counter_vec_with_registry, HistogramVec, IntCounterVec, Registry, -}; -use risingwave_common::monitor::GLOBAL_METRICS_REGISTRY; - -/// Monitor metrics for UDF. -#[derive(Debug, Clone)] -pub struct Metrics { - /// Number of successful UDF calls. - pub udf_success_count: IntCounterVec, - /// Number of failed UDF calls. - pub udf_failure_count: IntCounterVec, - /// Total number of retried UDF calls. - pub udf_retry_count: IntCounterVec, - /// Input chunk rows of UDF calls. - pub udf_input_chunk_rows: HistogramVec, - /// The latency of UDF calls in seconds. - pub udf_latency: HistogramVec, - /// Total number of input rows of UDF calls. - pub udf_input_rows: IntCounterVec, - /// Total number of input bytes of UDF calls. - pub udf_input_bytes: IntCounterVec, -} - -/// Global UDF metrics. 
-pub static GLOBAL_METRICS: LazyLock = - LazyLock::new(|| Metrics::new(&GLOBAL_METRICS_REGISTRY)); - -impl Metrics { - fn new(registry: &Registry) -> Self { - let labels = &["link", "language", "name", "fragment_id"]; - let udf_success_count = register_int_counter_vec_with_registry!( - "udf_success_count", - "Total number of successful UDF calls", - labels, - registry - ) - .unwrap(); - let udf_failure_count = register_int_counter_vec_with_registry!( - "udf_failure_count", - "Total number of failed UDF calls", - labels, - registry - ) - .unwrap(); - let udf_retry_count = register_int_counter_vec_with_registry!( - "udf_retry_count", - "Total number of retried UDF calls", - labels, - registry - ) - .unwrap(); - let udf_input_chunk_rows = register_histogram_vec_with_registry!( - "udf_input_chunk_rows", - "Input chunk rows of UDF calls", - labels, - exponential_buckets(1.0, 2.0, 10).unwrap(), // 1 to 1024 - registry - ) - .unwrap(); - let udf_latency = register_histogram_vec_with_registry!( - "udf_latency", - "The latency(s) of UDF calls", - labels, - exponential_buckets(0.000001, 2.0, 30).unwrap(), // 1us to 1000s - registry - ) - .unwrap(); - let udf_input_rows = register_int_counter_vec_with_registry!( - "udf_input_rows", - "Total number of input rows of UDF calls", - labels, - registry - ) - .unwrap(); - let udf_input_bytes = register_int_counter_vec_with_registry!( - "udf_input_bytes", - "Total number of input bytes of UDF calls", - labels, - registry - ) - .unwrap(); - - Metrics { - udf_success_count, - udf_failure_count, - udf_retry_count, - udf_input_chunk_rows, - udf_latency, - udf_input_rows, - udf_input_bytes, - } - } -} diff --git a/src/frontend/Cargo.toml b/src/frontend/Cargo.toml index 3cba7afe82660..def7b4743033f 100644 --- a/src/frontend/Cargo.toml +++ b/src/frontend/Cargo.toml @@ -72,7 +72,6 @@ risingwave_pb = { workspace = true } risingwave_rpc_client = { workspace = true } risingwave_sqlparser = { workspace = true } risingwave_storage = { workspace = true } -risingwave_udf = { workspace = true } risingwave_variables = { workspace = true } rw_futures_util = { workspace = true } serde = { version = "1", features = ["derive"] } From a94ee09e235f3e76e43f302d2983ab3f76c27e74 Mon Sep 17 00:00:00 2001 From: Runji Wang Date: Mon, 29 Apr 2024 17:26:36 +0800 Subject: [PATCH 12/27] migrate to arrow-udf-flight Signed-off-by: Runji Wang --- Cargo.lock | 315 +++++++++--------- Cargo.toml | 9 +- e2e_test/udf/external_udf.slt | 2 +- .../src/main/java/com/example/UdfExample.java | 2 +- src/expr/core/Cargo.toml | 6 +- src/expr/core/src/error.rs | 8 +- src/expr/core/src/expr/expr_udf.rs | 221 ++++++++++-- .../core/src/table_function/user_defined.rs | 19 +- src/frontend/Cargo.toml | 1 + src/frontend/src/handler/create_function.rs | 38 ++- 10 files changed, 391 insertions(+), 230 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index cfd22c41c3d7a..3610b36d80637 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -716,11 +716,26 @@ dependencies = [ "regex-syntax 0.8.2", ] +[[package]] +name = "arrow-udf-flight" +version = "0.1.0" +dependencies = [ + "arrow-array 50.0.0", + "arrow-flight", + "arrow-schema 50.0.0", + "arrow-select 50.0.0", + "futures-util", + "thiserror", + "tokio", + "tonic 0.10.2", + "tracing", +] + [[package]] name = "arrow-udf-js" -version = "0.1.2" +version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "252b6355ad1e57eb6454b705c51652de55aa22eb018cdb95be0dbf62ee3ec78f" +checksum = 
"0519711e77180c5fe9891b81d912d937864894c77932b5df52169966f4a948bb" dependencies = [ "anyhow", "arrow-array 50.0.0", @@ -732,7 +747,7 @@ dependencies = [ [[package]] name = "arrow-udf-js-deno" version = "0.0.1" -source = "git+https://github.com/risingwavelabs/arrow-udf.git?rev=23fe0dd#23fe0dd41616f4646f9139e22a335518e6cc9a47" +source = "git+https://github.com/risingwavelabs/arrow-udf.git?rev=76c995d#76c995d31f66785c39c7d70196c8ba0f1a61ad60" dependencies = [ "anyhow", "arrow-array 50.0.0", @@ -754,7 +769,7 @@ dependencies = [ [[package]] name = "arrow-udf-js-deno-runtime" version = "0.0.1" -source = "git+https://github.com/risingwavelabs/arrow-udf.git?rev=23fe0dd#23fe0dd41616f4646f9139e22a335518e6cc9a47" +source = "git+https://github.com/risingwavelabs/arrow-udf.git?rev=76c995d#76c995d31f66785c39c7d70196c8ba0f1a61ad60" dependencies = [ "anyhow", "deno_ast", @@ -782,7 +797,8 @@ dependencies = [ [[package]] name = "arrow-udf-python" version = "0.1.0" -source = "git+https://github.com/risingwavelabs/arrow-udf.git?rev=6c32f71#6c32f710b5948147f8214797fc334a4a3cadef0d" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "41eaaa010b9cf07bedda6f1dafa050496e96fff7ae4b9602fb77c25c24c64cb7" dependencies = [ "anyhow", "arrow-array 50.0.0", @@ -796,9 +812,9 @@ dependencies = [ [[package]] name = "arrow-udf-wasm" -version = "0.2.1" +version = "0.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "59a51355b8ca4de8ae028e5efb45c248dad4568cde6707f23b89f9b86a907f36" +checksum = "eb829e25925161d93617d4b053bae03fe51e708f2cce088d85df856011d4f369" dependencies = [ "anyhow", "arrow-array 50.0.0", @@ -1640,7 +1656,7 @@ dependencies = [ "cfg-if", "libc", "miniz_oxide", - "object", + "object 0.32.1", "rustc-demangle", ] @@ -2719,18 +2735,18 @@ dependencies = [ [[package]] name = "cranelift-bforest" -version = "0.106.1" +version = "0.107.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5b3775cc6cc00c90d29eebea55feedb2b0168e23f5415bab7859c4004d7323d1" +checksum = "79b27922a6879b5b5361d0a084cb0b1941bf109a98540addcb932da13b68bed4" dependencies = [ "cranelift-entity", ] [[package]] name = "cranelift-codegen" -version = "0.106.1" +version = "0.107.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "637f3184ba5bfa48d425bad1d2e4faf5fcf619f5e0ca107edc6dc02f589d4d74" +checksum = "304c455b28bf56372729acb356afbb55d622f2b0f2f7837aa5e57c138acaac4d" dependencies = [ "bumpalo", "cranelift-bforest", @@ -2749,33 +2765,33 @@ dependencies = [ [[package]] name = "cranelift-codegen-meta" -version = "0.106.1" +version = "0.107.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e4b35b8240462341d94d31aab807cad704683988708261aecae3d57db48b7212" +checksum = "1653c56b99591d07f67c5ca7f9f25888948af3f4b97186bff838d687d666f613" dependencies = [ "cranelift-codegen-shared", ] [[package]] name = "cranelift-codegen-shared" -version = "0.106.1" +version = "0.107.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8f3cd1555aa9df1d6d8375732de41b4cb0d787006948d55b6d004d521e9efeb0" +checksum = "f5b6a9cf6b6eb820ee3f973a0db313c05dc12d370f37b4fe9630286e1672573f" [[package]] name = "cranelift-control" -version = "0.106.1" +version = "0.107.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "14b31a562a10e98ab148fa146801e20665c5f9eda4fce9b2c5a3836575887d74" +checksum = "d9d06e6bf30075fb6bed9e034ec046475093392eea1aff90eb5c44c4a033d19a" dependencies = [ 
"arbitrary", ] [[package]] name = "cranelift-entity" -version = "0.106.1" +version = "0.107.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "af1e0467700a3f4fccf5feddbaebdf8b0eb82535b06a9600c4bc5df40872e75d" +checksum = "29be04f931b73cdb9694874a295027471817f26f26d2f0ebe5454153176b6e3a" dependencies = [ "serde", "serde_derive", @@ -2783,9 +2799,9 @@ dependencies = [ [[package]] name = "cranelift-frontend" -version = "0.106.1" +version = "0.107.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6cb918ee2c23939262efd1b99d76a21212ac7bd35129582133e21a22a6ff0467" +checksum = "a07fd7393041d7faa2f37426f5dc7fc04003b70988810e8c063beefeff1cd8f9" dependencies = [ "cranelift-codegen", "log", @@ -2795,15 +2811,15 @@ dependencies = [ [[package]] name = "cranelift-isle" -version = "0.106.1" +version = "0.107.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "966e4cfb23cf6d7f1d285d53a912baaffc5f06bcd9c9b0a2d8c66a184fae534b" +checksum = "f341d7938caa6dff8149dac05bb2b53fc680323826b83b4cf175ab9f5139a3c9" [[package]] name = "cranelift-native" -version = "0.106.1" +version = "0.107.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bea803aadfc4aabdfae7c3870f1b1f6dd4332f4091859e9758ef5fca6bf8cc87" +checksum = "82af6066e6448d26eeabb7aa26a43f7ff79f8217b06bade4ee6ef230aecc8880" dependencies = [ "cranelift-codegen", "libc", @@ -2812,9 +2828,9 @@ dependencies = [ [[package]] name = "cranelift-wasm" -version = "0.106.1" +version = "0.107.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "11d18a3572cd897555bba3621e568029417d8f5cc26aeede2d7cb0bad6afd916" +checksum = "2766fab7284a914a7f17f90ebe865c86453225fb8637ac31f123f5028fee69cd" dependencies = [ "cranelift-codegen", "cranelift-entity", @@ -5495,9 +5511,9 @@ dependencies = [ [[package]] name = "ginepro" -version = "0.7.0" +version = "0.7.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eedbff62a689be48f58f32571dbf3d60c4a73b39740141dfe7ac942536ea27f7" +checksum = "3b00ef897d4082727a53ea1111cd19bfa4ccdc476a5eb9f49087047113a43891" dependencies = [ "anyhow", "async-trait", @@ -6954,20 +6970,11 @@ dependencies = [ "pkg-config", ] -[[package]] -name = "mach" -version = "0.3.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b823e83b2affd8f40a9ee8c29dbc56404c1e34cd2710921f2801e2cf29527afa" -dependencies = [ - "libc", -] - [[package]] name = "mach2" -version = "0.4.1" +version = "0.4.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6d0d1830bcd151a6fc4aea1369af235b36c1528fe976b8ff678683c9995eade8" +checksum = "19b955cdeb2a02b9117f121ce63aa52d08ade45de53e48fe6a38b39c10f6f709" dependencies = [ "libc", ] @@ -7773,6 +7780,15 @@ name = "object" version = "0.32.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9cf5f9dd3933bd50a9e1f149ec995f39ae2c496d31fd772c1fd45ebc27e902b0" +dependencies = [ + "memchr", +] + +[[package]] +name = "object" +version = "0.33.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d8dd6c0cdf9429bce006e1362bfce61fa1bfd8c898a643ed8d2b471934701d3d" dependencies = [ "crc32fast", "hashbrown 0.14.3", @@ -9323,9 +9339,9 @@ dependencies = [ [[package]] name = "pyo3" -version = "0.20.3" +version = "0.21.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "53bdbb96d49157e65d45cc287af5f32ffadd5f4761438b527b055fb0d4bb8233" +checksum = 
"a5e00b96a521718e08e03b1a622f01c8a8deb50719335de3f60b3b3950f069d8" dependencies = [ "cfg-if", "indoc", @@ -9341,9 +9357,9 @@ dependencies = [ [[package]] name = "pyo3-build-config" -version = "0.20.3" +version = "0.21.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "deaa5745de3f5231ce10517a1f5dd97d53e5a2fd77aa6b5842292085831d48d7" +checksum = "7883df5835fafdad87c0d888b266c8ec0f4c9ca48a5bed6bbb592e8dedee1b50" dependencies = [ "once_cell", "target-lexicon", @@ -9351,9 +9367,9 @@ dependencies = [ [[package]] name = "pyo3-ffi" -version = "0.20.3" +version = "0.21.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "62b42531d03e08d4ef1f6e85a2ed422eb678b8cd62b762e53891c05faf0d4afa" +checksum = "01be5843dc60b916ab4dad1dca6d20b9b4e6ddc8e15f50c47fe6d85f1fb97403" dependencies = [ "libc", "pyo3-build-config", @@ -9361,9 +9377,9 @@ dependencies = [ [[package]] name = "pyo3-macros" -version = "0.20.3" +version = "0.21.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7305c720fa01b8055ec95e484a6eca7a83c841267f0dd5280f0c8b8551d2c158" +checksum = "77b34069fc0682e11b31dbd10321cbf94808394c56fd996796ce45217dfac53c" dependencies = [ "proc-macro2", "pyo3-macros-backend", @@ -9373,9 +9389,9 @@ dependencies = [ [[package]] name = "pyo3-macros-backend" -version = "0.20.3" +version = "0.21.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7c7e9b68bb9c3149c5b0cade5d07f953d6d125eb4337723c4ccdb665f1f96185" +checksum = "08260721f32db5e1a5beae69a55553f56b99bd0e1c3e6e0a5e8851a9d0f5a85c" dependencies = [ "heck 0.4.1", "proc-macro2", @@ -10670,7 +10686,9 @@ version = "1.9.0-alpha" dependencies = [ "anyhow", "arrow-array 50.0.0", + "arrow-flight", "arrow-schema 50.0.0", + "arrow-udf-flight", "arrow-udf-js", "arrow-udf-js-deno", "arrow-udf-python", @@ -10690,6 +10708,7 @@ dependencies = [ "futures", "futures-async-stream", "futures-util", + "ginepro", "itertools 0.12.1", "linkme", "madsim-tokio", @@ -10699,15 +10718,16 @@ dependencies = [ "openssl", "parse-display", "paste", + "prometheus", "risingwave_common", "risingwave_common_estimate_size", "risingwave_expr_macro", "risingwave_pb", - "risingwave_udf", "smallvec", "static_assertions", "thiserror", "thiserror-ext", + "tonic 0.10.2", "tracing", "workspace-hack", "zstd 0.13.0", @@ -10773,6 +10793,7 @@ dependencies = [ "anyhow", "arc-swap", "arrow-schema 50.0.0", + "arrow-udf-flight", "arrow-udf-wasm", "assert_matches", "async-recursion", @@ -10832,7 +10853,6 @@ dependencies = [ "risingwave_rpc_client", "risingwave_sqlparser", "risingwave_storage", - "risingwave_udf", "risingwave_variables", "rw_futures_util", "serde", @@ -11596,28 +11616,6 @@ dependencies = [ "workspace-hack", ] -[[package]] -name = "risingwave_udf" -version = "0.1.0" -dependencies = [ - "arrow-array 50.0.0", - "arrow-flight", - "arrow-schema 50.0.0", - "arrow-select 50.0.0", - "cfg-or-panic", - "futures", - "futures-util", - "ginepro", - "madsim-tokio", - "madsim-tonic", - "prometheus", - "risingwave_common", - "static_assertions", - "thiserror", - "thiserror-ext", - "tracing", -] - [[package]] name = "risingwave_variables" version = "1.9.0-alpha" @@ -15316,9 +15314,9 @@ checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" [[package]] name = "wasi-common" -version = "19.0.1" +version = "20.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6b53dfacdeacca15ee2a48a4aa0ec6a6d0da737676e465770c0585f79c04e638" +checksum = 
"63255d85e10627b07325d7cf4e5fe5a40fa4ff183569a0a67931be26d50ede07" dependencies = [ "anyhow", "bitflags 2.5.0", @@ -15414,9 +15412,18 @@ checksum = "ca6ad05a4870b2bf5fe995117d3728437bd27d7cd5f06f13c17443ef369775a1" [[package]] name = "wasm-encoder" -version = "0.201.0" +version = "0.202.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b9c7d2731df60006819b013f64ccc2019691deccf6e11a1804bc850cd6748f1a" +checksum = "bfd106365a7f5f7aa3c1916a98cbb3ad477f5ff96ddb130285a91c6e7429e67a" +dependencies = [ + "leb128", +] + +[[package]] +name = "wasm-encoder" +version = "0.206.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d759312e1137f199096d80a70be685899cd7d3d09c572836bb2e9b69b4dc3b1e" dependencies = [ "leb128", ] @@ -15449,9 +15456,9 @@ dependencies = [ [[package]] name = "wasmparser" -version = "0.201.0" +version = "0.202.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "84e5df6dba6c0d7fafc63a450f1738451ed7a0b52295d83e868218fa286bf708" +checksum = "d6998515d3cf3f8b980ef7c11b29a9b1017d4cf86b99ae93b546992df9931413" dependencies = [ "bitflags 2.5.0", "indexmap 2.0.0", @@ -15460,9 +15467,9 @@ dependencies = [ [[package]] name = "wasmprinter" -version = "0.201.0" +version = "0.202.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a67e66da702706ba08729a78e3c0079085f6bfcb1a62e4799e97bbf728c2c265" +checksum = "ab1cc9508685eef9502e787f4d4123745f5651a1e29aec047645d3cac1e2da7a" dependencies = [ "anyhow", "wasmparser", @@ -15470,9 +15477,9 @@ dependencies = [ [[package]] name = "wasmtime" -version = "19.0.1" +version = "20.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "516be5b58a8f75d39b01378516dcb0ff7b9bc39c7f1f10eec5b338d4916cf988" +checksum = "5a5990663c28d81015ddbb02a068ac1bf396a4ea296eba7125b2dfc7c00cb52e" dependencies = [ "addr2line", "anyhow", @@ -15487,7 +15494,7 @@ dependencies = [ "ittapi", "libc", "log", - "object", + "object 0.33.0", "once_cell", "paste", "rayon", @@ -15497,7 +15504,7 @@ dependencies = [ "serde_derive", "serde_json", "target-lexicon", - "wasm-encoder", + "wasm-encoder 0.202.0", "wasmparser", "wasmtime-cache", "wasmtime-component-macro", @@ -15516,18 +15523,18 @@ dependencies = [ [[package]] name = "wasmtime-asm-macros" -version = "19.0.1" +version = "20.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e8d22d88a92d69385f18143c946884bf6aaa9ec206ce54c85a2d320c1362b009" +checksum = "625ee94c72004f3ea0228989c9506596e469517d7d0ed66f7300d1067bdf1ca9" dependencies = [ "cfg-if", ] [[package]] name = "wasmtime-cache" -version = "19.0.1" +version = "20.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "068728a840223b56c964507550da671372e7e5c2f3a7856012b57482e3e979a7" +checksum = "98534bf28de232299e83eab33984a7a6c40c69534d6bd0ea216150b63d41a83a" dependencies = [ "anyhow", "base64 0.21.7", @@ -15545,9 +15552,9 @@ dependencies = [ [[package]] name = "wasmtime-component-macro" -version = "19.0.1" +version = "20.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "631244bac89c57ebe7283209d86fe175ad5929328e75f61bf9141895cafbf52d" +checksum = "64f84414a25ee3a624c8b77550f3fe7b5d8145bd3405ca58886ee6900abb6dc2" dependencies = [ "anyhow", "proc-macro2", @@ -15560,15 +15567,15 @@ dependencies = [ [[package]] name = "wasmtime-component-util" -version = "19.0.1" +version = "20.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" 
-checksum = "82ad496ba0558f7602da5e9d4c201f35f7aefcca70f973ec916f3f0d0787ef74" +checksum = "78580bdb4e04c7da3bf98088559ca1d29382668536e4d5c7f2f966d79c390307" [[package]] name = "wasmtime-cranelift" -version = "19.0.1" +version = "20.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "961ab5ee4b17e627001b18069ee89ef906edbbd3f84955515f6aad5ab6d82299" +checksum = "b60df0ee08c6a536c765f69e9e8205273435b66d02dd401e938769a2622a6c1a" dependencies = [ "anyhow", "cfg-if", @@ -15580,36 +15587,19 @@ dependencies = [ "cranelift-wasm", "gimli", "log", - "object", + "object 0.33.0", "target-lexicon", "thiserror", "wasmparser", - "wasmtime-cranelift-shared", "wasmtime-environ", "wasmtime-versioned-export-macros", ] -[[package]] -name = "wasmtime-cranelift-shared" -version = "19.0.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bc4db94596be14cd1f85844ce85470bf68acf235143098b9d9bf72b49e47b917" -dependencies = [ - "anyhow", - "cranelift-codegen", - "cranelift-control", - "cranelift-native", - "gimli", - "object", - "target-lexicon", - "wasmtime-environ", -] - [[package]] name = "wasmtime-environ" -version = "19.0.1" +version = "20.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "420b13858ef27dfd116f1fdb0513e9593a307a632ade2ea58334b639a3d8d24e" +checksum = "64ffc1613db69ee47c96738861534f9a405e422a5aa00224fbf5d410b03fb445" dependencies = [ "anyhow", "bincode 1.3.3", @@ -15618,13 +15608,13 @@ dependencies = [ "gimli", "indexmap 2.0.0", "log", - "object", + "object 0.33.0", "rustc-demangle", "serde", "serde_derive", "target-lexicon", "thiserror", - "wasm-encoder", + "wasm-encoder 0.202.0", "wasmparser", "wasmprinter", "wasmtime-component-util", @@ -15633,9 +15623,9 @@ dependencies = [ [[package]] name = "wasmtime-fiber" -version = "19.0.1" +version = "20.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5d37ff0e11a023019e34fe839c74a1c00880b989f4446176b6cc6da3b58e3ef2" +checksum = "f043514a23792761c5765f8ba61a4aa7d67f260c0c37494caabceb41d8ae81de" dependencies = [ "anyhow", "cc", @@ -15648,11 +15638,11 @@ dependencies = [ [[package]] name = "wasmtime-jit-debug" -version = "19.0.1" +version = "20.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7b849f19ad1d4a8133ff05b82c438144f17fb49b08e5f7995f8c1e25cf35f390" +checksum = "9c0ca2ad8f5d2b37f507ef1c935687a690e84e9f325f5a2af9639440b43c1f0e" dependencies = [ - "object", + "object 0.33.0", "once_cell", "rustix 0.38.31", "wasmtime-versioned-export-macros", @@ -15660,9 +15650,9 @@ dependencies = [ [[package]] name = "wasmtime-jit-icache-coherence" -version = "19.0.1" +version = "20.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "59c48eb4223d6556ffbf3decb146d0da124f1fd043f41c98b705252cb6a5c186" +checksum = "7a9f93a3289057b26dc75eb84d6e60d7694f7d169c7c09597495de6e016a13ff" dependencies = [ "cfg-if", "libc", @@ -15671,9 +15661,9 @@ dependencies = [ [[package]] name = "wasmtime-runtime" -version = "19.0.1" +version = "20.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7fefac2cb5f5a6f365234a3584bf40bd2e45e7f6cd90a689d9b2afbb9881978f" +checksum = "c6332a2b0af4224c3ea57c857ad39acd2780ccc2b0c99ba1baa01864d90d7c94" dependencies = [ "anyhow", "cc", @@ -15682,34 +15672,34 @@ dependencies = [ "indexmap 2.0.0", "libc", "log", - "mach", + "mach2", "memfd", "memoffset", "paste", "psm", "rustix 0.38.31", "sptr", - "wasm-encoder", + "wasm-encoder 0.202.0", 
"wasmtime-asm-macros", "wasmtime-environ", "wasmtime-fiber", "wasmtime-jit-debug", + "wasmtime-slab", "wasmtime-versioned-export-macros", - "wasmtime-wmemcheck", "windows-sys 0.52.0", ] [[package]] name = "wasmtime-slab" -version = "19.0.1" +version = "20.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "52d7b97b92df126fdbe994a53d2215828ec5ed5087535e6d4703b1fbd299f0e3" +checksum = "8b3655075824a374c536a2b2cc9283bb765fcdf3d58b58587862c48571ad81ef" [[package]] name = "wasmtime-types" -version = "19.0.1" +version = "20.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "509c88abb830819b259c49e2d4e4f22b555db066ba08ded0b76b071a2aa53ddf" +checksum = "b98cf64a242b0b9257604181ca28b28a5fcaa4c9ea1d396f76d1d2d1c5b40eef" dependencies = [ "cranelift-entity", "serde", @@ -15720,9 +15710,9 @@ dependencies = [ [[package]] name = "wasmtime-versioned-export-macros" -version = "19.0.1" +version = "20.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f1d81c092a61ca1667013e2eb08fed7c6c53e496dbbaa32d5685dc5152b0a772" +checksum = "8561d9e2920db2a175213d557d71c2ac7695831ab472bbfafb9060cd1034684f" dependencies = [ "proc-macro2", "quote", @@ -15731,26 +15721,26 @@ dependencies = [ [[package]] name = "wasmtime-winch" -version = "19.0.1" +version = "20.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a0958907880e37a2d3974f5b3574c23bf70aaf1fc6c1f716625bb50dac776f1a" +checksum = "a06b573d14ac846a0fb8c541d8fca6a64acf9a1d176176982472274ab1d2fa5d" dependencies = [ "anyhow", "cranelift-codegen", "gimli", - "object", + "object 0.33.0", "target-lexicon", "wasmparser", - "wasmtime-cranelift-shared", + "wasmtime-cranelift", "wasmtime-environ", "winch-codegen", ] [[package]] name = "wasmtime-wit-bindgen" -version = "19.0.1" +version = "20.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a593ddefd2f80617df6bea084b2e422d8969e924bc209642a794d57518f59587" +checksum = "595bc7bb3b0ff4aa00fab718c323ea552c3034d77abc821a35112552f2ea487a" dependencies = [ "anyhow", "heck 0.4.1", @@ -15758,12 +15748,6 @@ dependencies = [ "wit-parser", ] -[[package]] -name = "wasmtime-wmemcheck" -version = "19.0.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b77212b6874bbc86d220bb1d28632d0c11c6afe996c3e1ddcf746b1a6b4919b9" - [[package]] name = "wast" version = "35.0.2" @@ -15775,24 +15759,24 @@ dependencies = [ [[package]] name = "wast" -version = "201.0.0" +version = "206.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1ef6e1ef34d7da3e2b374fd2b1a9c0227aff6cad596e1b24df9b58d0f6222faa" +checksum = "68586953ee4960b1f5d84ebf26df3b628b17e6173bc088e0acfbce431469795a" dependencies = [ "bumpalo", "leb128", "memchr", "unicode-width", - "wasm-encoder", + "wasm-encoder 0.206.0", ] [[package]] name = "wat" -version = "1.201.0" +version = "1.206.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "453d5b37a45b98dee4f4cb68015fc73634d7883bbef1c65e6e9c78d454cf3f32" +checksum = "da4c6f2606276c6e991aebf441b2fc92c517807393f039992a3e0ad873efe4ad" dependencies = [ - "wast 201.0.0", + "wast 206.0.0", ] [[package]] @@ -15880,9 +15864,9 @@ checksum = "653f141f39ec16bba3c5abe400a0c60da7468261cc2cbf36805022876bc721a8" [[package]] name = "wiggle" -version = "19.0.1" +version = "20.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"f093d8afdb09efaf2ed1037468bd4614308a762d215b6cafd60a7712993a8ffa" +checksum = "1b6552dda951239e219c329e5a768393664e8d120c5e0818487ac2633f173b1f" dependencies = [ "anyhow", "async-trait", @@ -15895,9 +15879,9 @@ dependencies = [ [[package]] name = "wiggle-generate" -version = "19.0.1" +version = "20.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "47c7bccd5172ce8d853242f723e42c84b8c131b24fb07a1570f9045d99258616" +checksum = "da64cb31e0bfe8b1d2d13956ef9fd5c77545756a1a6ef0e6cfd44e8f1f207aed" dependencies = [ "anyhow", "heck 0.4.1", @@ -15910,9 +15894,9 @@ dependencies = [ [[package]] name = "wiggle-macro" -version = "19.0.1" +version = "20.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a69d087dee85991096fc0c6eaf4dcf4e17cd16a0594c33b8ab9e2d345234ef75" +checksum = "900b2416ef2ff2903ded6cf55d4a941fed601bf56a8c4874856d7a77c1891994" dependencies = [ "proc-macro2", "quote", @@ -15953,9 +15937,9 @@ checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" [[package]] name = "winch-codegen" -version = "0.17.1" +version = "0.18.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e72a6a7034793b874b85e428fd6d7b3ccccb98c326e33af3aa40cdf50d0c33da" +checksum = "fb23450977f9d4a23c02439cf6899340b2d68887b19465c5682740d9cc37d52e" dependencies = [ "anyhow", "cranelift-codegen", @@ -15964,6 +15948,7 @@ dependencies = [ "smallvec", "target-lexicon", "wasmparser", + "wasmtime-cranelift", "wasmtime-environ", ] @@ -16252,9 +16237,9 @@ dependencies = [ [[package]] name = "wit-parser" -version = "0.201.0" +version = "0.202.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "196d3ecfc4b759a8573bf86a9b3f8996b304b3732e4c7de81655f875f6efdca6" +checksum = "744237b488352f4f27bca05a10acb79474415951c450e52ebd0da784c1df2bcc" dependencies = [ "anyhow", "id-arena", diff --git a/Cargo.toml b/Cargo.toml index 9970cf4e07681..9474456921f2d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -138,10 +138,11 @@ arrow-flight = "50" arrow-select = "50" arrow-ord = "50" arrow-row = "50" -arrow-udf-js = "0.1" -arrow-udf-js-deno = { git = "https://github.com/risingwavelabs/arrow-udf.git", rev = "23fe0dd" } -arrow-udf-wasm = { version = "0.2.1", features = ["build"] } -arrow-udf-python = { git = "https://github.com/risingwavelabs/arrow-udf.git", rev = "6c32f71" } +arrow-udf-js = "0.2" +arrow-udf-js-deno = { git = "https://github.com/risingwavelabs/arrow-udf.git", rev = "76c995d" } +arrow-udf-wasm = { version = "0.2.2", features = ["build"] } +arrow-udf-python = "0.1" +arrow-udf-flight = { path = "../arrow-udf/arrow-udf-flight" } arrow-array-deltalake = { package = "arrow-array", version = "48.0.1" } arrow-buffer-deltalake = { package = "arrow-buffer", version = "48.0.1" } arrow-cast-deltalake = { package = "arrow-cast", version = "48.0.1" } diff --git a/e2e_test/udf/external_udf.slt b/e2e_test/udf/external_udf.slt index 096a605709d67..7a38506f81563 100644 --- a/e2e_test/udf/external_udf.slt +++ b/e2e_test/udf/external_udf.slt @@ -1,7 +1,7 @@ # Before running this test: # python3 e2e_test/udf/test.py # or: -# cd java/udf-example && mvn package && java -jar target/risingwave-udf-example.jar +# cd e2e_test/udf/java && mvn package && java -jar target/risingwave-udf-example.jar # Create a function. 
statement ok diff --git a/e2e_test/udf/java/src/main/java/com/example/UdfExample.java b/e2e_test/udf/java/src/main/java/com/example/UdfExample.java index 883dc5035514c..1702e244bf1ff 100644 --- a/e2e_test/udf/java/src/main/java/com/example/UdfExample.java +++ b/e2e_test/udf/java/src/main/java/com/example/UdfExample.java @@ -33,7 +33,7 @@ public class UdfExample { public static void main(String[] args) throws IOException { - try (var server = new UdfServer("0.0.0.0", 8815)) { + try (var server = new UdfServer("localhost", 8815)) { server.addFunction("int_42", new Int42()); server.addFunction("float_to_decimal", new FloatToDecimal()); server.addFunction("sleep", new Sleep()); diff --git a/src/expr/core/Cargo.toml b/src/expr/core/Cargo.toml index 3f5ca590026db..c811d81b34658 100644 --- a/src/expr/core/Cargo.toml +++ b/src/expr/core/Cargo.toml @@ -22,7 +22,9 @@ embedded-python-udf = ["arrow-udf-python"] [dependencies] anyhow = "1" arrow-array = { workspace = true } +arrow-flight = "50" arrow-schema = { workspace = true } +arrow-udf-flight = { workspace = true } arrow-udf-js = { workspace = true } arrow-udf-js-deno = { workspace = true, optional = true } arrow-udf-python = { workspace = true, optional = true } @@ -44,6 +46,7 @@ enum-as-inner = "0.6" futures = "0.3" futures-async-stream = { workspace = true } futures-util = "0.3" +ginepro = "0.7" itertools = { workspace = true } linkme = { version = "0.3", features = ["used_linker"] } md5 = "0.7" @@ -52,11 +55,11 @@ num-traits = "0.2" openssl = { version = "0.10", features = ["vendored"] } parse-display = "0.9" paste = "1" +prometheus = "0.13" risingwave_common = { workspace = true } risingwave_common_estimate_size = { workspace = true } risingwave_expr_macro = { path = "../macro" } risingwave_pb = { workspace = true } -risingwave_udf = { workspace = true } smallvec = "1" static_assertions = "1" thiserror = "1" @@ -65,6 +68,7 @@ tokio = { version = "0.2", package = "madsim-tokio", features = [ "rt-multi-thread", "macros", ] } +tonic = "0.10" tracing = "0.1" zstd = { version = "0.13", default-features = false } diff --git a/src/expr/core/src/error.rs b/src/expr/core/src/error.rs index 6688824093d2d..efc3c20526f13 100644 --- a/src/expr/core/src/error.rs +++ b/src/expr/core/src/error.rs @@ -99,7 +99,7 @@ pub enum ExprError { Udf( #[from] #[backtrace] - risingwave_udf::Error, + Box, ), #[error("not a constant")] @@ -152,6 +152,12 @@ impl From for ExprError { } } +impl From for ExprError { + fn from(err: arrow_udf_flight::Error) -> Self { + Self::Udf(Box::new(err)) + } +} + /// A collection of multiple errors. 
#[derive(Error, Debug)] pub struct MultiExprError(Box<[ExprError]>); diff --git a/src/expr/core/src/expr/expr_udf.rs b/src/expr/core/src/expr/expr_udf.rs index ed7d597cce52a..4968c51d0dc90 100644 --- a/src/expr/core/src/expr/expr_udf.rs +++ b/src/expr/core/src/expr/expr_udf.rs @@ -19,7 +19,9 @@ use std::sync::{Arc, LazyLock, Weak}; use std::time::Duration; use anyhow::{Context, Error}; +use arrow_array::RecordBatch; use arrow_schema::{Field, Fields, Schema}; +use arrow_udf_flight::Client as FlightClient; use arrow_udf_js::{CallMode as JsCallMode, Runtime as JsRuntime}; #[cfg(feature = "embedded-deno-udf")] use arrow_udf_js_deno::{CallMode as DenoCallMode, Runtime as DenoRuntime}; @@ -29,13 +31,16 @@ use arrow_udf_wasm::Runtime as WasmRuntime; use await_tree::InstrumentAwait; use cfg_or_panic::cfg_or_panic; use moka::sync::Cache; -use risingwave_common::array::{ArrayError, ArrayRef, DataChunk}; +use prometheus::{ + exponential_buckets, register_histogram_vec_with_registry, + register_int_counter_vec_with_registry, HistogramVec, IntCounter, IntCounterVec, Registry, +}; +use risingwave_common::array::{ArrayRef, DataChunk}; +use risingwave_common::monitor::GLOBAL_METRICS_REGISTRY; use risingwave_common::row::OwnedRow; use risingwave_common::types::{DataType, Datum}; use risingwave_expr::expr_context::FRAGMENT_ID; use risingwave_pb::expr::ExprNode; -use risingwave_udf::metrics::GLOBAL_METRICS; -use risingwave_udf::ArrowFlightUdfClient; use thiserror_ext::AsReport; use super::{BoxedExpression, Build}; @@ -51,6 +56,7 @@ pub struct UserDefinedFunction { arg_schema: Arc, imp: UdfImpl, identifier: String, + link: Option, span: await_tree::Span, /// Number of remaining successful calls until retry is enabled. /// This parameter is designed to prevent continuous retry on every call, which would increase delay. 
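
The `disable_retry_count` mechanism described above is easiest to follow in isolation. The following is a minimal stand-alone sketch of the same countdown idea — `RetryGate` and its method names are invented for illustration and are not part of this patch: after a connection error, retry is disabled for the next `INITIAL_RETRY_COUNT` calls, so a flapping UDF server does not add retry latency to every request.

    use std::sync::atomic::{AtomicU8, Ordering};

    const INITIAL_RETRY_COUNT: u8 = 16;

    /// Hypothetical stand-alone version of the retry gate kept in `UserDefinedFunction`.
    struct RetryGate {
        /// Number of remaining successful calls until retry is enabled again.
        disable_retry_count: AtomicU8,
    }

    impl RetryGate {
        /// Take the retrying code path only when the countdown has expired.
        fn should_retry(&self) -> bool {
            self.disable_retry_count.load(Ordering::Relaxed) == 0
        }

        /// On a connection error, disable retry for the next few calls.
        fn on_connection_error(&self) {
            self.disable_retry_count
                .store(INITIAL_RETRY_COUNT, Ordering::Relaxed);
        }

        /// On a successful call, count down toward re-enabling retry.
        fn on_success(&self) {
            // `checked_sub` returns `None` at zero, leaving the value unchanged.
            let _ = self
                .disable_retry_count
                .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |n| n.checked_sub(1));
        }
    }

The design choice here is to pay the retry cost only while the service looks healthy: one observed connection error buys sixteen retry-free calls, bounding the extra latency a dead server can impose.
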
@@ -70,7 +76,7 @@ const INITIAL_RETRY_COUNT: u8 = 16;
 
 #[derive(Debug)]
 pub enum UdfImpl {
-    External(Arc<ArrowFlightUdfClient>),
+    External(Arc<FlightClient>),
     Wasm(Arc<WasmRuntime>),
     JavaScript(JsRuntime),
     #[cfg(feature = "embedded-python-udf")]
@@ -125,10 +131,6 @@ impl UserDefinedFunction {
         let fragment_id = FRAGMENT_ID::try_with(ToOwned::to_owned)
             .unwrap_or(0)
             .to_string();
-        let addr = match &self.imp {
-            UdfImpl::External(client) => client.get_addr(),
-            _ => "",
-        };
         let language = match &self.imp {
             UdfImpl::Wasm(_) => "wasm",
             UdfImpl::JavaScript(_) => "javascript(quickjs)",
@@ -138,7 +140,12 @@ impl UserDefinedFunction {
             UdfImpl::Deno(_) => "javascript(deno)",
             UdfImpl::External(_) => "external",
         };
-        let labels: &[&str; 4] = &[addr, language, &self.identifier, fragment_id.as_str()];
+        let labels: &[&str; 4] = &[
+            self.link.as_deref().unwrap_or(""),
+            language,
+            &self.identifier,
+            fragment_id.as_str(),
+        ];
         metrics
             .udf_input_chunk_rows
             .with_label_values(labels)
@@ -166,28 +173,27 @@ impl UserDefinedFunction {
             UdfImpl::External(client) => {
                 let disable_retry_count = self.disable_retry_count.load(Ordering::Relaxed);
                 let result = if self.always_retry_on_network_error {
-                    client
-                        .call_with_always_retry_on_network_error(
-                            &self.identifier,
-                            arrow_input,
-                            &fragment_id,
-                        )
-                        .instrument_await(self.span.clone())
-                        .await
+                    call_with_always_retry_on_network_error(
+                        &client,
+                        &self.identifier,
+                        &arrow_input,
+                        &metrics.udf_retry_count.with_label_values(labels),
+                    )
+                    .instrument_await(self.span.clone())
+                    .await
                 } else {
                     let result = if disable_retry_count != 0 {
                         client
-                            .call(&self.identifier, arrow_input)
+                            .call(&self.identifier, &arrow_input)
                            .instrument_await(self.span.clone())
                             .await
                     } else {
-                        client
-                            .call_with_retry(&self.identifier, arrow_input)
+                        call_with_retry(&client, &self.identifier, &arrow_input)
                             .instrument_await(self.span.clone())
                             .await
                     };
                     let disable_retry_count = self.disable_retry_count.load(Ordering::Relaxed);
-                    let connection_error = matches!(&result, Err(e) if e.is_connection_error());
+                    let connection_error = matches!(&result, Err(e) if is_connection_error(e));
                     if connection_error && disable_retry_count != INITIAL_RETRY_COUNT {
                         // reset count on connection error
                         self.disable_retry_count
@@ -243,6 +249,52 @@ impl UserDefinedFunction {
     }
 }
 
+/// Call a function, retrying up to 5 times with exponential backoff if the connection is broken.
+async fn call_with_retry(
+    client: &FlightClient,
+    id: &str,
+    input: &RecordBatch,
+) -> Result<RecordBatch, arrow_udf_flight::Error> {
+    let mut backoff = Duration::from_millis(100);
+    for i in 0..5 {
+        match client.call(id, input).await {
+            Err(err) if is_connection_error(&err) && i != 4 => {
+                tracing::error!(error = %err.as_report(), "UDF connection error. retry...");
+            }
+            ret => return ret,
+        }
+        tokio::time::sleep(backoff).await;
+        backoff *= 2;
+    }
+    unreachable!()
+}
+
+/// Call a function, retrying indefinitely with exponential backoff on network (tonic) errors.
+async fn call_with_always_retry_on_network_error(
+    client: &FlightClient,
+    id: &str,
+    input: &RecordBatch,
+    retry_count: &IntCounter,
+) -> Result<RecordBatch, arrow_udf_flight::Error> {
+    let mut backoff = Duration::from_millis(100);
+    loop {
+        match client.call(id, input).await {
+            Err(err) if is_tonic_error(&err) => {
+                tracing::error!(error = %err.as_report(), "UDF tonic error. retry...");
+            }
+            ret => {
+                if ret.is_err() {
+                    tracing::error!(error = %ret.as_ref().unwrap_err().as_report(), "UDF error. 
exiting..."); + } + return ret; + } + } + retry_count.inc(); + tokio::time::sleep(backoff).await; + backoff *= 2; + } +} + impl Build for UserDefinedFunction { fn build( prost: &ExprNode, @@ -352,15 +404,7 @@ impl Build for UserDefinedFunction { let arg_schema = Arc::new(Schema::new( udf.arg_types .iter() - .map::, _>(|t| { - Ok(Field::new( - "", - DataType::from(t).try_into().map_err(|e: ArrayError| { - risingwave_udf::Error::unsupported(e.to_report_string()) - })?, - true, - )) - }) + .map::, _>(|t| Ok(Field::new("", DataType::from(t).try_into()?, true))) .try_collect::()?, )); @@ -371,6 +415,7 @@ impl Build for UserDefinedFunction { arg_schema, imp, identifier: identifier.clone(), + link: udf.link.clone(), span: format!("udf_call({})", identifier).into(), disable_retry_count: AtomicU8::new(0), always_retry_on_network_error: udf.always_retry_on_network_error, @@ -382,8 +427,8 @@ impl Build for UserDefinedFunction { /// Get or create a client for the given UDF service. /// /// There is a global cache for clients, so that we can reuse the same client for the same service. -pub(crate) fn get_or_create_flight_client(link: &str) -> Result> { - static CLIENTS: LazyLock>>> = +pub(crate) fn get_or_create_flight_client(link: &str) -> Result> { + static CLIENTS: LazyLock>>> = LazyLock::new(Default::default); let mut clients = CLIENTS.lock().unwrap(); if let Some(client) = clients.get(link).and_then(|c| c.upgrade()) { @@ -391,8 +436,10 @@ pub(crate) fn get_or_create_flight_client(link: &str) -> Result Result> { RUNTIMES.insert(md5, runtime.clone()); Ok(runtime) } + +/// Returns true if the arrow flight error is caused by a connection error. +fn is_connection_error(err: &arrow_udf_flight::Error) -> bool { + match err { + // Connection refused + arrow_udf_flight::Error::Tonic(status) if status.code() == tonic::Code::Unavailable => true, + _ => false, + } +} + +fn is_tonic_error(err: &arrow_udf_flight::Error) -> bool { + match err { + arrow_udf_flight::Error::Tonic(_) + | arrow_udf_flight::Error::Flight(arrow_flight::error::FlightError::Tonic(_)) => true, + _ => false, + } +} + +/// Monitor metrics for UDF. +#[derive(Debug, Clone)] +struct Metrics { + /// Number of successful UDF calls. + udf_success_count: IntCounterVec, + /// Number of failed UDF calls. + udf_failure_count: IntCounterVec, + /// Total number of retried UDF calls. + udf_retry_count: IntCounterVec, + /// Input chunk rows of UDF calls. + udf_input_chunk_rows: HistogramVec, + /// The latency of UDF calls in seconds. + udf_latency: HistogramVec, + /// Total number of input rows of UDF calls. + udf_input_rows: IntCounterVec, + /// Total number of input bytes of UDF calls. + udf_input_bytes: IntCounterVec, +} + +/// Global UDF metrics. 
+static GLOBAL_METRICS: LazyLock = LazyLock::new(|| Metrics::new(&GLOBAL_METRICS_REGISTRY)); + +impl Metrics { + fn new(registry: &Registry) -> Self { + let labels = &["link", "language", "name", "fragment_id"]; + let udf_success_count = register_int_counter_vec_with_registry!( + "udf_success_count", + "Total number of successful UDF calls", + labels, + registry + ) + .unwrap(); + let udf_failure_count = register_int_counter_vec_with_registry!( + "udf_failure_count", + "Total number of failed UDF calls", + labels, + registry + ) + .unwrap(); + let udf_retry_count = register_int_counter_vec_with_registry!( + "udf_retry_count", + "Total number of retried UDF calls", + labels, + registry + ) + .unwrap(); + let udf_input_chunk_rows = register_histogram_vec_with_registry!( + "udf_input_chunk_rows", + "Input chunk rows of UDF calls", + labels, + exponential_buckets(1.0, 2.0, 10).unwrap(), // 1 to 1024 + registry + ) + .unwrap(); + let udf_latency = register_histogram_vec_with_registry!( + "udf_latency", + "The latency(s) of UDF calls", + labels, + exponential_buckets(0.000001, 2.0, 30).unwrap(), // 1us to 1000s + registry + ) + .unwrap(); + let udf_input_rows = register_int_counter_vec_with_registry!( + "udf_input_rows", + "Total number of input rows of UDF calls", + labels, + registry + ) + .unwrap(); + let udf_input_bytes = register_int_counter_vec_with_registry!( + "udf_input_bytes", + "Total number of input bytes of UDF calls", + labels, + registry + ) + .unwrap(); + + Metrics { + udf_success_count, + udf_failure_count, + udf_retry_count, + udf_input_chunk_rows, + udf_latency, + udf_input_rows, + udf_input_bytes, + } + } +} diff --git a/src/expr/core/src/table_function/user_defined.rs b/src/expr/core/src/table_function/user_defined.rs index b65ee5e77758b..a015529b3bbac 100644 --- a/src/expr/core/src/table_function/user_defined.rs +++ b/src/expr/core/src/table_function/user_defined.rs @@ -23,10 +23,8 @@ use arrow_udf_js_deno::{CallMode as DenoCallMode, Runtime as DenoRuntime}; #[cfg(feature = "embedded-python-udf")] use arrow_udf_python::{CallMode as PythonCallMode, Runtime as PythonRuntime}; use cfg_or_panic::cfg_or_panic; -use futures_util::stream; -use risingwave_common::array::{ArrayError, DataChunk, I32Array}; +use risingwave_common::array::{DataChunk, I32Array}; use risingwave_common::bail; -use thiserror_ext::AsReport; use super::*; use crate::expr::expr_udf::UdfImpl; @@ -62,10 +60,7 @@ impl UdfImpl { match self { UdfImpl::External(client) => { #[for_await] - for res in client - .call_stream(identifier, stream::once(async { input })) - .await? - { + for res in client.call_table_function(identifier, &input).await? 
{
                     yield res?;
                 }
             }
@@ -182,15 +177,7 @@ pub fn new_user_defined(prost: &PbTableFunction, chunk_size: usize) -> Result<BoxedTableFunction> {
     let arg_schema = Arc::new(Schema::new(
         udtf.arg_types
             .iter()
-            .map::<Result<Field>, _>(|t| {
-                Ok(Field::new(
-                    "",
-                    DataType::from(t).try_into().map_err(|e: ArrayError| {
-                        risingwave_udf::Error::unsupported(e.to_report_string())
-                    })?,
-                    true,
-                ))
-            })
+            .map::<Result<Field>, _>(|t| Ok(Field::new("", DataType::from(t).try_into()?, true)))
             .try_collect::<_, Fields, _>()?,
     ));
 
diff --git a/src/frontend/Cargo.toml b/src/frontend/Cargo.toml
index def7b4743033f..6a09101dc3e59 100644
--- a/src/frontend/Cargo.toml
+++ b/src/frontend/Cargo.toml
@@ -18,6 +18,7 @@ normal = ["workspace-hack"]
 anyhow = "1"
 arc-swap = "1"
 arrow-schema = { workspace = true }
+arrow-udf-flight = { workspace = true }
 arrow-udf-wasm = { workspace = true }
 async-recursion = "1.1.0"
 async-trait = "0.1"
diff --git a/src/frontend/src/handler/create_function.rs b/src/frontend/src/handler/create_function.rs
index 428a5612e1770..b06b0b2fbe854 100644
--- a/src/frontend/src/handler/create_function.rs
+++ b/src/frontend/src/handler/create_function.rs
@@ -14,6 +14,7 @@
 
 use anyhow::{anyhow, Context};
 use arrow_schema::Fields;
+use arrow_udf_flight::Client as FlightClient;
 use bytes::Bytes;
 use itertools::Itertools;
 use pgwire::pg_response::StatementType;
@@ -23,7 +24,6 @@ use risingwave_expr::expr::get_or_create_wasm_runtime;
 use risingwave_pb::catalog::function::{Kind, ScalarFunction, TableFunction};
 use risingwave_pb::catalog::Function;
 use risingwave_sqlparser::ast::{CreateFunctionBody, ObjectName, OperateFunctionArg};
-use risingwave_udf::ArrowFlightUdfClient;
 
 use super::*;
 use crate::catalog::CatalogError;
@@ -166,9 +166,7 @@ pub async fn handle_create_function(
 
             // check UDF server
             {
-                let client = ArrowFlightUdfClient::connect(&l)
-                    .await
-                    .map_err(|e| anyhow!(e))?;
+                let client = FlightClient::connect(&l).await.map_err(|e| anyhow!(e))?;
                 /// A helper function to create a unnamed field from data type.
                 fn to_field(data_type: arrow_schema::DataType) -> arrow_schema::Field {
                     arrow_schema::Field::new("", data_type, true)
@@ -182,15 +180,29 @@ pub async fn handle_create_function(
                 let returns = arrow_schema::Schema::new(match kind {
                     Kind::Scalar(_) => vec![to_field(return_type.clone().try_into()?)],
                     Kind::Table(_) => vec![
-                        arrow_schema::Field::new("row_index", arrow_schema::DataType::Int32, true),
+                        arrow_schema::Field::new("row", arrow_schema::DataType::Int32, true),
                         to_field(return_type.clone().try_into()?),
                     ],
                     _ => unreachable!(),
                 });
-                client
-                    .check(&identifier, &args, &returns)
+                let function = client
+                    .get(&identifier)
                     .await
                     .context("failed to check UDF signature")?;
+                if !data_types_match(&function.args, &args) {
+                    return Err(ErrorCode::InvalidParameterValue(format!(
+                        "argument type mismatch, expect: {:?}, actual: {:?}",
+                        args, function.args,
+                    ))
+                    .into());
+                }
+                if !data_types_match(&function.returns, &returns) {
+                    return Err(ErrorCode::InvalidParameterValue(format!(
+                        "return type mismatch, expect: {:?}, actual: {:?}",
+                        returns, function.returns,
+                    ))
+                    .into());
+                }
             }
             link = Some(l);
         }
@@ -481,3 +493,15 @@ fn datatype_name(ty: &DataType) -> String {
         ),
     }
 }
+
+/// Check if two lists of data types match, ignoring field names.
+fn data_types_match(a: &arrow_schema::Schema, b: &arrow_schema::Schema) -> bool {
+    if a.fields().len() != b.fields().len() {
+        return false;
+    }
+    #[allow(clippy::disallowed_methods)]
+    a.fields()
+        .iter()
+        .zip(b.fields())
+        .all(|(a, b)| a.data_type().equals_datatype(b.data_type()))
+}

From b1f3cf66686cb3f14cd57554e6c3288057856384 Mon Sep 17 00:00:00 2001
From: Runji Wang
Date: Tue, 7 May 2024 15:56:07 +0800
Subject: [PATCH 13/27] introduce NewUdfArrowConvert

Signed-off-by: Runji Wang
---
 src/common/src/array/arrow/arrow_impl.rs      | 91 +++++++++++++++----
 .../arrow/{arrow_default.rs => arrow_udf.rs}  | 69 +++++++++++++-
 src/common/src/array/arrow/mod.rs             |  4 +-
 src/expr/impl/src/scalar/external/iceberg.rs  |  7 +-
 4 files changed, 146 insertions(+), 25 deletions(-)
 rename src/common/src/array/arrow/{arrow_default.rs => arrow_udf.rs} (61%)

diff --git a/src/common/src/array/arrow/arrow_impl.rs b/src/common/src/array/arrow/arrow_impl.rs
index 5c1f8ac45fba3..514d3b299769c 100644
--- a/src/common/src/array/arrow/arrow_impl.rs
+++ b/src/common/src/array/arrow/arrow_impl.rs
@@ -201,20 +201,20 @@ pub trait ToArrow {
         Ok(Arc::new(arrow_array::BinaryArray::from(array)))
     }
 
-    // Decimal values are stored as ASCII text representation in a large binary array.
+    // Decimal values are stored as ASCII text representation in a string array.
     #[inline]
     fn decimal_to_arrow(
         &self,
         _data_type: &arrow_schema::DataType,
         array: &DecimalArray,
     ) -> Result<arrow_array::ArrayRef> {
-        Ok(Arc::new(arrow_array::LargeBinaryArray::from(array)))
+        Ok(Arc::new(arrow_array::StringArray::from(array)))
     }
 
-    // JSON values are stored as text representation in a large string array.
+    // JSON values are stored as text representation in a string array.
     #[inline]
     fn jsonb_to_arrow(&self, array: &JsonbArray) -> Result<arrow_array::ArrayRef> {
-        Ok(Arc::new(arrow_array::LargeStringArray::from(array)))
+        Ok(Arc::new(arrow_array::StringArray::from(array)))
     }
 
     #[inline]
@@ -366,7 +366,8 @@ pub trait ToArrow {
 
     #[inline]
     fn jsonb_type_to_arrow(&self, name: &str) -> arrow_schema::Field {
-        arrow_schema::Field::new(name, arrow_schema::DataType::LargeUtf8, true)
+        arrow_schema::Field::new(name, arrow_schema::DataType::Utf8, true)
+            .with_metadata([("ARROW:extension:name".into(), "arrowudf.json".into())].into())
     }
 
     #[inline]
@@ -376,7 +377,8 @@
 
     #[inline]
     fn decimal_type_to_arrow(&self, name: &str) -> arrow_schema::Field {
-        arrow_schema::Field::new(name, arrow_schema::DataType::LargeBinary, true)
+        arrow_schema::Field::new(name, arrow_schema::DataType::Utf8, true)
+            .with_metadata([("ARROW:extension:name".into(), "arrowudf.decimal".into())].into())
     }
 
     #[inline]
@@ -414,8 +416,8 @@ pub trait FromArrow {
     /// Converts Arrow `RecordBatch` to RisingWave `DataChunk`.
     fn from_record_batch(&self, batch: &arrow_array::RecordBatch) -> Result<DataChunk> {
         let mut columns = Vec::with_capacity(batch.num_columns());
-        for array in batch.columns() {
-            let column = Arc::new(self.from_array(array)?);
+        for (array, field) in batch.columns().iter().zip_eq_fast(batch.schema().fields()) {
+            let column = Arc::new(self.from_array(field, array)?);
             columns.push(column);
         }
         Ok(DataChunk::new(columns, batch.num_rows()))
@@ -472,30 +474,44 @@ pub trait FromArrow {
 
     /// Converts Arrow `LargeUtf8` type to RisingWave data type.
     fn from_large_utf8(&self) -> Result<DataType> {
-        Ok(DataType::Jsonb)
+        Ok(DataType::Varchar)
     }
 
     /// Converts Arrow `LargeBinary` type to RisingWave data type.
fn from_large_binary(&self) -> Result { - Ok(DataType::Decimal) + Ok(DataType::Bytea) } /// Converts Arrow extension type to RisingWave `DataType`. fn from_extension_type( &self, type_name: &str, - _physical_type: &arrow_schema::DataType, + physical_type: &arrow_schema::DataType, ) -> Result { - Err(ArrayError::from_arrow(format!( - "unsupported extension type: {type_name:?}" - ))) + match (type_name, physical_type) { + ("arrowudf.decimal", arrow_schema::DataType::Utf8) => Ok(DataType::Decimal), + ("arrowudf.json", arrow_schema::DataType::Utf8) => Ok(DataType::Jsonb), + _ => Err(ArrayError::from_arrow(format!( + "unsupported extension type: {type_name:?}" + ))), + } } /// Converts Arrow `Array` to RisingWave `ArrayImpl`. - fn from_array(&self, array: &arrow_array::ArrayRef) -> Result { + fn from_array( + &self, + field: &arrow_schema::Field, + array: &arrow_array::ArrayRef, + ) -> Result { use arrow_schema::DataType::*; use arrow_schema::IntervalUnit::*; use arrow_schema::TimeUnit::*; + + // extension type + if let Some(type_name) = field.metadata().get("ARROW:extension:name") { + return self.from_extension_array(type_name, array); + } + match array.data_type() { Boolean => self.from_bool_array(array.as_any().downcast_ref().unwrap()), Int16 => self.from_int16_array(array.as_any().downcast_ref().unwrap()), @@ -524,6 +540,37 @@ pub trait FromArrow { } } + /// Converts Arrow extension array to RisingWave `ArrayImpl`. + fn from_extension_array( + &self, + type_name: &str, + array: &arrow_array::ArrayRef, + ) -> Result { + match type_name { + "arrowudf.decimal" => { + let array: &arrow_array::StringArray = + array.as_any().downcast_ref().ok_or_else(|| { + ArrayError::from_arrow( + "expected string array for `arrowudf.decimal`".to_string(), + ) + })?; + Ok(ArrayImpl::Decimal(array.try_into()?)) + } + "arrowudf.json" => { + let array: &arrow_array::StringArray = + array.as_any().downcast_ref().ok_or_else(|| { + ArrayError::from_arrow( + "expected string array for `arrowudf.json`".to_string(), + ) + })?; + Ok(ArrayImpl::Jsonb(array.try_into()?)) + } + _ => Err(ArrayError::from_arrow(format!( + "unsupported extension type: {type_name:?}" + ))), + } + } + fn from_bool_array(&self, array: &arrow_array::BooleanArray) -> Result { Ok(ArrayImpl::Bool(array.into())) } @@ -598,20 +645,23 @@ pub trait FromArrow { &self, array: &arrow_array::LargeStringArray, ) -> Result { - Ok(ArrayImpl::Jsonb(array.try_into()?)) + Ok(ArrayImpl::Utf8(array.into())) } fn from_large_binary_array( &self, array: &arrow_array::LargeBinaryArray, ) -> Result { - Ok(ArrayImpl::Decimal(array.try_into()?)) + Ok(ArrayImpl::Bytea(array.into())) } fn from_list_array(&self, array: &arrow_array::ListArray) -> Result { use arrow_array::Array; + let arrow_schema::DataType::List(field) = array.data_type() else { + panic!("nested field types cannot be determined."); + }; Ok(ArrayImpl::List(ListArray { - value: Box::new(self.from_array(array.values())?), + value: Box::new(self.from_array(field, array.values())?), bitmap: match array.nulls() { Some(nulls) => nulls.iter().collect(), None => Bitmap::ones(array.len()), @@ -630,7 +680,8 @@ pub trait FromArrow { array .columns() .iter() - .map(|a| self.from_array(a).map(Arc::new)) + .zip_eq_fast(fields) + .map(|(array, field)| self.from_array(field, array).map(Arc::new)) .try_collect()?, (0..array.len()).map(|i| array.is_valid(i)).collect(), ))) @@ -703,7 +754,9 @@ converts!(I64Array, arrow_array::Int64Array); converts!(F32Array, arrow_array::Float32Array, @map); converts!(F64Array, 
arrow_array::Float64Array, @map); converts!(BytesArray, arrow_array::BinaryArray); +converts!(BytesArray, arrow_array::LargeBinaryArray); converts!(Utf8Array, arrow_array::StringArray); +converts!(Utf8Array, arrow_array::LargeStringArray); converts!(DateArray, arrow_array::Date32Array, @map); converts!(TimeArray, arrow_array::Time64MicrosecondArray, @map); converts!(TimestampArray, arrow_array::TimestampMicrosecondArray, @map); diff --git a/src/common/src/array/arrow/arrow_default.rs b/src/common/src/array/arrow/arrow_udf.rs similarity index 61% rename from src/common/src/array/arrow/arrow_default.rs rename to src/common/src/array/arrow/arrow_udf.rs index b2867d4fdf583..e2f9e39ad385a 100644 --- a/src/common/src/array/arrow/arrow_default.rs +++ b/src/common/src/array/arrow/arrow_udf.rs @@ -18,17 +18,78 @@ //! //! The corresponding version of arrow is currently used by `udf` and `iceberg` sink. +use std::sync::Arc; + pub use arrow_impl::{FromArrow, ToArrow}; use {arrow_array, arrow_buffer, arrow_cast, arrow_schema}; +use crate::array::{ArrayError, ArrayImpl, DataType, DecimalArray, JsonbArray}; + #[expect(clippy::duplicate_mod)] #[path = "./arrow_impl.rs"] mod arrow_impl; +/// Arrow conversion for the current version of UDF. This is in use but will be deprecated soon. +/// +/// In the current version of UDF protocol, decimal and jsonb types are mapped to Arrow `LargeBinary` and `LargeUtf8` types. pub struct UdfArrowConvert; -impl ToArrow for UdfArrowConvert {} -impl FromArrow for UdfArrowConvert {} +impl ToArrow for UdfArrowConvert { + // Decimal values are stored as ASCII text representation in a large binary array. + fn decimal_to_arrow( + &self, + _data_type: &arrow_schema::DataType, + array: &DecimalArray, + ) -> Result { + Ok(Arc::new(arrow_array::LargeBinaryArray::from(array))) + } + + // JSON values are stored as text representation in a large string array. + fn jsonb_to_arrow(&self, array: &JsonbArray) -> Result { + Ok(Arc::new(arrow_array::LargeStringArray::from(array))) + } + + fn jsonb_type_to_arrow(&self, name: &str) -> arrow_schema::Field { + arrow_schema::Field::new(name, arrow_schema::DataType::LargeUtf8, true) + } + + fn decimal_type_to_arrow(&self, name: &str) -> arrow_schema::Field { + arrow_schema::Field::new(name, arrow_schema::DataType::LargeBinary, true) + } +} + +impl FromArrow for UdfArrowConvert { + fn from_large_utf8(&self) -> Result { + Ok(DataType::Jsonb) + } + + fn from_large_binary(&self) -> Result { + Ok(DataType::Decimal) + } + + fn from_large_utf8_array( + &self, + array: &arrow_array::LargeStringArray, + ) -> Result { + Ok(ArrayImpl::Jsonb(array.try_into()?)) + } + + fn from_large_binary_array( + &self, + array: &arrow_array::LargeBinaryArray, + ) -> Result { + Ok(ArrayImpl::Decimal(array.try_into()?)) + } +} + +/// Arrow conversion for the next version of UDF. This is unused for now. +/// +/// In the next version of UDF protocol, decimal and jsonb types will be mapped to Arrow extension types. +/// See . 
+pub struct NewUdfArrowConvert; + +impl ToArrow for NewUdfArrowConvert {} +impl FromArrow for NewUdfArrowConvert {} #[cfg(test)] mod tests { @@ -108,7 +169,9 @@ mod tests { let array = ListArray::from_iter([None, Some(vec![0, -127, 127, 50]), Some(vec![0; 0])]); let data_type = arrow_schema::DataType::new_list(arrow_schema::DataType::Int32, true); let arrow = UdfArrowConvert.list_to_arrow(&data_type, &array).unwrap(); - let rw_array = UdfArrowConvert.from_array(&arrow).unwrap(); + let rw_array = UdfArrowConvert + .from_list_array(arrow.as_any().downcast_ref().unwrap()) + .unwrap(); assert_eq!(rw_array.as_list(), &array); } } diff --git a/src/common/src/array/arrow/mod.rs b/src/common/src/array/arrow/mod.rs index cb726721c867b..67490b22315a1 100644 --- a/src/common/src/array/arrow/mod.rs +++ b/src/common/src/array/arrow/mod.rs @@ -12,10 +12,10 @@ // See the License for the specific language governing permissions and // limitations under the License. -mod arrow_default; mod arrow_deltalake; mod arrow_iceberg; +mod arrow_udf; -pub use arrow_default::{FromArrow, ToArrow, UdfArrowConvert}; pub use arrow_deltalake::DeltaLakeConvert; pub use arrow_iceberg::IcebergArrowConvert; +pub use arrow_udf::{FromArrow, ToArrow, UdfArrowConvert}; diff --git a/src/expr/impl/src/scalar/external/iceberg.rs b/src/expr/impl/src/scalar/external/iceberg.rs index 2194d8b1355be..ea39ea7ef989d 100644 --- a/src/expr/impl/src/scalar/external/iceberg.rs +++ b/src/expr/impl/src/scalar/external/iceberg.rs @@ -35,6 +35,7 @@ pub struct IcebergTransform { child: BoxedExpression, transform: BoxedTransformFunction, input_arrow_type: arrow_schema::DataType, + output_arrow_field: arrow_schema::Field, return_type: DataType, } @@ -61,7 +62,9 @@ impl risingwave_expr::expr::Expression for IcebergTransform { // Transform let res_array = self.transform.transform(arrow_array).unwrap(); // Convert back to array ref and return it - Ok(Arc::new(IcebergArrowConvert.from_array(&res_array)?)) + Ok(Arc::new( + IcebergArrowConvert.from_array(&self.output_arrow_field, &res_array)?, + )) } async fn eval_row(&self, _row: &OwnedRow) -> Result { @@ -96,6 +99,7 @@ fn build(return_type: DataType, mut children: Vec) -> Result) -> Result Date: Tue, 7 May 2024 16:51:22 +0800 Subject: [PATCH 14/27] fix arrow conversion for the new UDF protocol Signed-off-by: Runji Wang --- e2e_test/udf/wasm/Cargo.toml | 2 +- src/common/src/array/arrow/arrow_udf.rs | 89 +++++++++++++------ src/expr/core/src/expr/expr_udf.rs | 56 ++++++++---- .../core/src/table_function/user_defined.rs | 68 +++++++++----- src/frontend/src/handler/create_function.rs | 35 +++++--- 5 files changed, 166 insertions(+), 84 deletions(-) diff --git a/e2e_test/udf/wasm/Cargo.toml b/e2e_test/udf/wasm/Cargo.toml index 250bd8132ca53..54c7da45b1af8 100644 --- a/e2e_test/udf/wasm/Cargo.toml +++ b/e2e_test/udf/wasm/Cargo.toml @@ -8,7 +8,7 @@ edition = "2021" crate-type = ["cdylib"] [dependencies] -arrow-udf = "0.2" +arrow-udf = "0.3" genawaiter = "0.99" rust_decimal = "1" serde_json = "1" diff --git a/src/common/src/array/arrow/arrow_udf.rs b/src/common/src/array/arrow/arrow_udf.rs index e2f9e39ad385a..5a44ef1439619 100644 --- a/src/common/src/array/arrow/arrow_udf.rs +++ b/src/common/src/array/arrow/arrow_udf.rs @@ -29,68 +29,99 @@ use crate::array::{ArrayError, ArrayImpl, DataType, DecimalArray, JsonbArray}; #[path = "./arrow_impl.rs"] mod arrow_impl; -/// Arrow conversion for the current version of UDF. This is in use but will be deprecated soon. 
-/// -/// In the current version of UDF protocol, decimal and jsonb types are mapped to Arrow `LargeBinary` and `LargeUtf8` types. -pub struct UdfArrowConvert; +/// Arrow conversion for UDF. +#[derive(Default, Debug)] +pub struct UdfArrowConvert { + /// Whether the UDF talks in legacy mode. + /// + /// If true, decimal and jsonb types are mapped to Arrow `LargeBinary` and `LargeUtf8` types. + /// Otherwise, they are mapped to Arrow extension types. + /// See . + pub legacy: bool, +} impl ToArrow for UdfArrowConvert { - // Decimal values are stored as ASCII text representation in a large binary array. fn decimal_to_arrow( &self, _data_type: &arrow_schema::DataType, array: &DecimalArray, ) -> Result { - Ok(Arc::new(arrow_array::LargeBinaryArray::from(array))) + if self.legacy { + // Decimal values are stored as ASCII text representation in a large binary array. + Ok(Arc::new(arrow_array::LargeBinaryArray::from(array))) + } else { + Ok(Arc::new(arrow_array::StringArray::from(array))) + } } - // JSON values are stored as text representation in a large string array. fn jsonb_to_arrow(&self, array: &JsonbArray) -> Result { - Ok(Arc::new(arrow_array::LargeStringArray::from(array))) + if self.legacy { + // JSON values are stored as text representation in a large string array. + Ok(Arc::new(arrow_array::LargeStringArray::from(array))) + } else { + Ok(Arc::new(arrow_array::StringArray::from(array))) + } } fn jsonb_type_to_arrow(&self, name: &str) -> arrow_schema::Field { - arrow_schema::Field::new(name, arrow_schema::DataType::LargeUtf8, true) + if self.legacy { + arrow_schema::Field::new(name, arrow_schema::DataType::LargeUtf8, true) + } else { + arrow_schema::Field::new(name, arrow_schema::DataType::Utf8, true) + .with_metadata([("ARROW:extension:name".into(), "arrowudf.json".into())].into()) + } } fn decimal_type_to_arrow(&self, name: &str) -> arrow_schema::Field { - arrow_schema::Field::new(name, arrow_schema::DataType::LargeBinary, true) + if self.legacy { + arrow_schema::Field::new(name, arrow_schema::DataType::LargeBinary, true) + } else { + arrow_schema::Field::new(name, arrow_schema::DataType::Utf8, true) + .with_metadata([("ARROW:extension:name".into(), "arrowudf.decimal".into())].into()) + } } } impl FromArrow for UdfArrowConvert { fn from_large_utf8(&self) -> Result { - Ok(DataType::Jsonb) + if self.legacy { + Ok(DataType::Jsonb) + } else { + Ok(DataType::Varchar) + } } fn from_large_binary(&self) -> Result { - Ok(DataType::Decimal) + if self.legacy { + Ok(DataType::Decimal) + } else { + Ok(DataType::Bytea) + } } fn from_large_utf8_array( &self, array: &arrow_array::LargeStringArray, ) -> Result { - Ok(ArrayImpl::Jsonb(array.try_into()?)) + if self.legacy { + Ok(ArrayImpl::Jsonb(array.try_into()?)) + } else { + Ok(ArrayImpl::Utf8(array.into())) + } } fn from_large_binary_array( &self, array: &arrow_array::LargeBinaryArray, ) -> Result { - Ok(ArrayImpl::Decimal(array.try_into()?)) + if self.legacy { + Ok(ArrayImpl::Decimal(array.try_into()?)) + } else { + Ok(ArrayImpl::Bytea(array.into())) + } } } -/// Arrow conversion for the next version of UDF. This is unused for now. -/// -/// In the next version of UDF protocol, decimal and jsonb types will be mapped to Arrow extension types. -/// See . -pub struct NewUdfArrowConvert; - -impl ToArrow for NewUdfArrowConvert {} -impl FromArrow for NewUdfArrowConvert {} - #[cfg(test)] mod tests { use std::sync::Arc; @@ -104,7 +135,7 @@ mod tests { // Empty array - risingwave to arrow conversion. 
let test_arr = StructArray::new(StructType::empty(), vec![], Bitmap::ones(0)); assert_eq!( - UdfArrowConvert + UdfArrowConvert::default() .struct_to_arrow( &arrow_schema::DataType::Struct(arrow_schema::Fields::empty()), &test_arr @@ -117,7 +148,7 @@ mod tests { // Empty array - arrow to risingwave conversion. let test_arr_2 = arrow_array::StructArray::from(vec![]); assert_eq!( - UdfArrowConvert + UdfArrowConvert::default() .from_struct_array(&test_arr_2) .unwrap() .len(), @@ -146,7 +177,7 @@ mod tests { ), ]) .unwrap(); - let actual_risingwave_struct_array = UdfArrowConvert + let actual_risingwave_struct_array = UdfArrowConvert::default() .from_struct_array(&test_arrow_struct_array) .unwrap() .into_struct(); @@ -168,8 +199,10 @@ mod tests { fn list() { let array = ListArray::from_iter([None, Some(vec![0, -127, 127, 50]), Some(vec![0; 0])]); let data_type = arrow_schema::DataType::new_list(arrow_schema::DataType::Int32, true); - let arrow = UdfArrowConvert.list_to_arrow(&data_type, &array).unwrap(); - let rw_array = UdfArrowConvert + let arrow = UdfArrowConvert::default() + .list_to_arrow(&data_type, &array) + .unwrap(); + let rw_array = UdfArrowConvert::default() .from_list_array(arrow.as_any().downcast_ref().unwrap()) .unwrap(); assert_eq!(rw_array.as_list(), &array); diff --git a/src/expr/core/src/expr/expr_udf.rs b/src/expr/core/src/expr/expr_udf.rs index 42dd82196ef19..b9b1e290fae08 100644 --- a/src/expr/core/src/expr/expr_udf.rs +++ b/src/expr/core/src/expr/expr_udf.rs @@ -19,8 +19,8 @@ use std::time::Duration; use anyhow::{Context, Error}; use arrow_array::RecordBatch; -use arrow_udf_flight::Client as FlightClient; use arrow_schema::{Fields, Schema, SchemaRef}; +use arrow_udf_flight::Client as FlightClient; use arrow_udf_js::{CallMode as JsCallMode, Runtime as JsRuntime}; #[cfg(feature = "embedded-deno-udf")] use arrow_udf_js_deno::{CallMode as DenoCallMode, Runtime as DenoRuntime}; @@ -56,6 +56,7 @@ pub struct UserDefinedFunction { imp: UdfImpl, identifier: String, link: Option, + arrow_convert: UdfArrowConvert, span: await_tree::Span, /// Number of remaining successful calls until retry is enabled. /// This parameter is designed to prevent continuous retry on every call, which would increase delay. @@ -122,7 +123,9 @@ impl Expression for UserDefinedFunction { impl UserDefinedFunction { async fn eval_inner(&self, input: &DataChunk) -> Result { // this will drop invisible rows - let arrow_input = UdfArrowConvert.to_record_batch(self.arg_schema.clone(), input)?; + let arrow_input = self + .arrow_convert + .to_record_batch(self.arg_schema.clone(), input)?; // metrics let metrics = &*GLOBAL_METRICS; @@ -230,7 +233,7 @@ impl UserDefinedFunction { ); } - let output = UdfArrowConvert.from_record_batch(&arrow_output)?; + let output = self.arrow_convert.from_record_batch(&arrow_output)?; let output = output.uncompact(input.visibility().clone()); let Some(array) = output.columns().first() else { @@ -301,11 +304,7 @@ impl Build for UserDefinedFunction { ) -> Result { let return_type = DataType::from(prost.get_return_type().unwrap()); let udf = prost.get_rex_node().unwrap().as_udf().unwrap(); - - let arrow_return_type = UdfArrowConvert - .to_arrow_field("", &return_type)? 
- .data_type() - .clone(); + let mut arrow_convert = UdfArrowConvert::default(); #[cfg(not(feature = "embedded-deno-udf"))] let runtime = "quickjs"; @@ -324,6 +323,10 @@ impl Build for UserDefinedFunction { let wasm_binary = zstd::stream::decode_all(compressed_wasm_binary.as_slice()) .context("failed to decompress wasm binary")?; let runtime = get_or_create_wasm_runtime(&wasm_binary)?; + // backward compatibility + if runtime.abi_version().0 <= 2 { + arrow_convert = UdfArrowConvert { legacy: true }; + } UdfImpl::Wasm(runtime) } "javascript" if runtime != "deno" => { @@ -336,7 +339,10 @@ impl Build for UserDefinedFunction { ); rt.add_function( identifier, - arrow_return_type, + arrow_convert + .to_arrow_field("", &return_type)? + .data_type() + .clone(), JsCallMode::CalledOnNullInput, &body, )?; @@ -375,12 +381,17 @@ impl Build for UserDefinedFunction { ) }; - futures::executor::block_on(rt.add_function( - identifier, - arrow_return_type, - DenoCallMode::CalledOnNullInput, - &body, - ))?; + futures::executor::block_on( + rt.add_function( + identifier, + arrow_convert + .to_arrow_field("", &return_type)? + .data_type() + .clone(), + DenoCallMode::CalledOnNullInput, + &body, + ), + )?; UdfImpl::Deno(rt) } @@ -390,7 +401,10 @@ impl Build for UserDefinedFunction { let body = udf.get_body()?; rt.add_function( identifier, - arrow_return_type, + arrow_convert + .to_arrow_field("", &return_type)? + .data_type() + .clone(), PythonCallMode::CalledOnNullInput, body, )?; @@ -399,7 +413,12 @@ impl Build for UserDefinedFunction { #[cfg(not(madsim))] _ => { let link = udf.get_link()?; - UdfImpl::External(get_or_create_flight_client(link)?) + let client = crate::expr::expr_udf::get_or_create_flight_client(link)?; + // backward compatibility + if client.protocol_version() == 1 { + arrow_convert = UdfArrowConvert { legacy: true }; + } + UdfImpl::External(client) } #[cfg(madsim)] l => panic!("UDF language {l:?} is not supported on madsim"), @@ -408,7 +427,7 @@ impl Build for UserDefinedFunction { let arg_schema = Arc::new(Schema::new( udf.arg_types .iter() - .map(|t| UdfArrowConvert.to_arrow_field("", &DataType::from(t))) + .map(|t| arrow_convert.to_arrow_field("", &DataType::from(t))) .try_collect::()?, )); @@ -420,6 +439,7 @@ impl Build for UserDefinedFunction { imp, identifier: identifier.clone(), link: udf.link.clone(), + arrow_convert, span: format!("udf_call({})", identifier).into(), disable_retry_count: AtomicU8::new(0), always_retry_on_network_error: udf.always_retry_on_network_error, diff --git a/src/expr/core/src/table_function/user_defined.rs b/src/expr/core/src/table_function/user_defined.rs index 04c1f345d5408..c1dc09bafe861 100644 --- a/src/expr/core/src/table_function/user_defined.rs +++ b/src/expr/core/src/table_function/user_defined.rs @@ -37,6 +37,7 @@ pub struct UserDefinedTableFunction { return_type: DataType, client: UdfImpl, identifier: String, + arrow_convert: UdfArrowConvert, #[allow(dead_code)] chunk_size: usize, } @@ -106,8 +107,9 @@ impl UserDefinedTableFunction { // compact the input chunk and record the row mapping let visible_rows = direct_input.visibility().iter_ones().collect::>(); // this will drop invisible rows - let arrow_input = - UdfArrowConvert.to_record_batch(self.arg_schema.clone(), &direct_input)?; + let arrow_input = self + .arrow_convert + .to_record_batch(self.arg_schema.clone(), &direct_input)?; // call UDTF #[for_await] @@ -115,7 +117,7 @@ impl UserDefinedTableFunction { .client .call_table_function(&self.identifier, arrow_input) { - let output = 
UdfArrowConvert.from_record_batch(&res?)?; + let output = self.arrow_convert.from_record_batch(&res?)?; self.check_output(&output)?; // we send the compacted input to UDF, so we need to map the row indices back to the @@ -175,21 +177,9 @@ pub fn new_user_defined(prost: &PbTableFunction, chunk_size: usize) -> Result()?, - )); - let identifier = udtf.get_identifier()?; let return_type = DataType::from(prost.get_return_type()?); - let arrow_return_type = UdfArrowConvert - .to_arrow_field("", &return_type)? - .data_type() - .clone(); - #[cfg(not(feature = "embedded-deno-udf"))] let runtime = "quickjs"; @@ -199,12 +189,18 @@ pub fn new_user_defined(prost: &PbTableFunction, chunk_size: usize) -> Result "quickjs", }; + let mut arrow_convert = UdfArrowConvert::default(); + let client = match udtf.language.as_str() { "wasm" | "rust" => { let compressed_wasm_binary = udtf.get_compressed_binary()?; let wasm_binary = zstd::stream::decode_all(compressed_wasm_binary.as_slice()) .context("failed to decompress wasm binary")?; let runtime = crate::expr::expr_udf::get_or_create_wasm_runtime(&wasm_binary)?; + // backward compatibility + if runtime.abi_version().0 <= 2 { + arrow_convert = UdfArrowConvert { legacy: true }; + } UdfImpl::Wasm(runtime) } "javascript" if runtime != "deno" => { @@ -217,7 +213,10 @@ pub fn new_user_defined(prost: &PbTableFunction, chunk_size: usize) -> Result Result Result Result { let link = udtf.get_link()?; - UdfImpl::External(crate::expr::expr_udf::get_or_create_flight_client(link)?) + let client = crate::expr::expr_udf::get_or_create_flight_client(link)?; + // backward compatibility + if client.protocol_version() == 1 { + arrow_convert = UdfArrowConvert { legacy: true }; + } + UdfImpl::External(client) } }; + let arg_schema = Arc::new(Schema::new( + udtf.arg_types + .iter() + .map(|t| arrow_convert.to_arrow_field("", &DataType::from(t))) + .try_collect::()?, + )); + Ok(UserDefinedTableFunction { children: prost.args.iter().map(expr_build_from_prost).try_collect()?, return_type, arg_schema, client, identifier: identifier.clone(), + arrow_convert, chunk_size, } .boxed()) diff --git a/src/frontend/src/handler/create_function.rs b/src/frontend/src/handler/create_function.rs index bac8200c77641..88172293cb458 100644 --- a/src/frontend/src/handler/create_function.rs +++ b/src/frontend/src/handler/create_function.rs @@ -168,10 +168,11 @@ pub async fn handle_create_function( // check UDF server { let client = FlightClient::connect(&l).await.map_err(|e| anyhow!(e))?; - /// A helper function to create a unnamed field from data type. - fn to_field(data_type: &DataType) -> Result { - Ok(UdfArrowConvert.to_arrow_field("", data_type)?) - } + let convert = UdfArrowConvert { + legacy: client.protocol_version() == 1, + }; + // A helper function to create a unnamed field from data type. 
+ let to_field = |data_type| convert.to_arrow_field("", data_type); let args = arrow_schema::Schema::new( arg_types .iter() @@ -288,6 +289,7 @@ pub async fn handle_create_function( let wasm_binary = tokio::task::spawn_blocking(move || { let mut opts = arrow_udf_wasm::build::BuildOpts::default(); + opts.arrow_udf_version = Some("0.3".to_string()); opts.script = script; // use a fixed tempdir to reuse the build cache opts.tempdir = Some(std::env::temp_dir().join("risingwave-rust-udf")); @@ -321,6 +323,13 @@ pub async fn handle_create_function( } }; let runtime = get_or_create_wasm_runtime(&wasm_binary)?; + if runtime.abi_version().0 <= 2 { + return Err(ErrorCode::InvalidParameterValue( + "legacy arrow-udf is no longer supported. please update arrow-udf to 0.3+" + .to_string(), + ) + .into()); + } let identifier_v1 = wasm_identifier_v1( &function_name, &arg_types, @@ -469,13 +478,13 @@ fn wasm_identifier_v1( fn datatype_name(ty: &DataType) -> String { match ty { DataType::Boolean => "boolean".to_string(), - DataType::Int16 => "int2".to_string(), - DataType::Int32 => "int4".to_string(), - DataType::Int64 => "int8".to_string(), - DataType::Float32 => "float4".to_string(), - DataType::Float64 => "float8".to_string(), - DataType::Date => "date".to_string(), - DataType::Time => "time".to_string(), + DataType::Int16 => "int16".to_string(), + DataType::Int32 => "int32".to_string(), + DataType::Int64 => "int64".to_string(), + DataType::Float32 => "float32".to_string(), + DataType::Float64 => "float64".to_string(), + DataType::Date => "date32".to_string(), + DataType::Time => "time64".to_string(), DataType::Timestamp => "timestamp".to_string(), DataType::Timestamptz => "timestamptz".to_string(), DataType::Interval => "interval".to_string(), @@ -483,8 +492,8 @@ fn datatype_name(ty: &DataType) -> String { DataType::Jsonb => "json".to_string(), DataType::Serial => "serial".to_string(), DataType::Int256 => "int256".to_string(), - DataType::Bytea => "bytea".to_string(), - DataType::Varchar => "varchar".to_string(), + DataType::Bytea => "binary".to_string(), + DataType::Varchar => "string".to_string(), DataType::List(inner) => format!("{}[]", datatype_name(inner)), DataType::Struct(s) => format!( "struct<{}>", From 7c8bce9bd5e483a47f383a8d26634251797a3b4d Mon Sep 17 00:00:00 2001 From: Runji Wang Date: Tue, 7 May 2024 18:38:35 +0800 Subject: [PATCH 15/27] update python udf test Signed-off-by: Runji Wang --- e2e_test/udf/requirements.txt | 3 +- e2e_test/udf/test.py | 170 +++++++++++++++++++++++----------- 2 files changed, 120 insertions(+), 53 deletions(-) diff --git a/e2e_test/udf/requirements.txt b/e2e_test/udf/requirements.txt index 8642e2b1ec254..36688db1ed1ee 100644 --- a/e2e_test/udf/requirements.txt +++ b/e2e_test/udf/requirements.txt @@ -1,2 +1,3 @@ flask -waitress \ No newline at end of file +waitress +arrow_udf==0.2.1 \ No newline at end of file diff --git a/e2e_test/udf/test.py b/e2e_test/udf/test.py index 6195476a80004..4443a81a6e74d 100644 --- a/e2e_test/udf/test.py +++ b/e2e_test/udf/test.py @@ -19,9 +19,7 @@ from typing import Iterator, List, Optional, Tuple, Any from decimal import Decimal -sys.path.append("src/expr/udf/python") # noqa - -from risingwave.udf import udf, udtf, UdfServer +from arrow_udf import udf, udtf, UdfServer @udf(input_types=[], result_type="INT") @@ -47,13 +45,21 @@ def gcd3(x: int, y: int, z: int) -> int: return gcd(gcd(x, y), z) -@udf(input_types=["BYTEA"], result_type="STRUCT") +@udf( + input_types=["BYTEA"], + result_type="STRUCT", +) def extract_tcp_info(tcp_packet: 
bytes): src_addr, dst_addr = struct.unpack("!4s4s", tcp_packet[12:20]) src_port, dst_port = struct.unpack("!HH", tcp_packet[20:24]) src_addr = socket.inet_ntoa(src_addr) dst_addr = socket.inet_ntoa(dst_addr) - return src_addr, dst_addr, src_port, dst_port + return { + "src_addr": src_addr, + "dst_addr": dst_addr, + "src_port": src_port, + "dst_port": dst_port, + } @udtf(input_types="INT", result_types="INT") @@ -84,7 +90,7 @@ def hex_to_dec(hex: Optional[str]) -> Optional[Decimal]: return dec -@udf(input_types=["FLOAT8"], result_type="DECIMAL") +@udf(input_types=["FLOAT64"], result_type="DECIMAL") def float_to_decimal(f: float) -> Decimal: return Decimal(f) @@ -120,21 +126,49 @@ def jsonb_array_identity(list: List[Any]) -> List[Any]: return list -@udf(input_types="STRUCT", result_type="STRUCT") +@udf( + input_types="STRUCT", + result_type="STRUCT", +) def jsonb_array_struct_identity(v: Tuple[List[Any], int]) -> Tuple[List[Any], int]: return v -ALL_TYPES = "BOOLEAN,SMALLINT,INT,BIGINT,FLOAT4,FLOAT8,DECIMAL,DATE,TIME,TIMESTAMP,INTERVAL,VARCHAR,BYTEA,JSONB".split( - "," -) + [ - "STRUCT" -] - - @udf( - input_types=ALL_TYPES, - result_type=f"struct<{','.join(ALL_TYPES)}>", + input_types=[ + "boolean", + "int16", + "int32", + "int64", + "float32", + "float64", + "decimal", + "date32", + "time64", + "timestamp", + "interval", + "string", + "binary", + "json", + "struct", + ], + result_type="""struct< + boolean: boolean, + int16: int16, + int32: int32, + int64: int64, + float32: float32, + float64: float64, + decimal: decimal, + date32: date32, + time64: time64, + timestamp: timestamp, + interval: interval, + string: string, + binary: binary, + json: json, + struct: struct, + >""", ) def return_all( bool, @@ -153,28 +187,60 @@ def return_all( jsonb, struct, ): - return ( - bool, - i16, - i32, - i64, - f32, - f64, - decimal, - date, - time, - timestamp, - interval, - varchar, - bytea, - jsonb, - struct, - ) + return { + "boolean": bool, + "int16": i16, + "int32": i32, + "int64": i64, + "float32": f32, + "float64": f64, + "decimal": decimal, + "date32": date, + "time64": time, + "timestamp": timestamp, + "interval": interval, + "string": varchar, + "binary": bytea, + "json": jsonb, + "struct": struct, + } @udf( - input_types=[t + "[]" for t in ALL_TYPES], - result_type=f"struct<{','.join(t + '[]' for t in ALL_TYPES)}>", + input_types=[ + "boolean[]", + "int16[]", + "int32[]", + "int64[]", + "float32[]", + "float64[]", + "decimal[]", + "date32[]", + "time64[]", + "timestamp[]", + "interval[]", + "string[]", + "binary[]", + "json[]", + "struct[]", + ], + result_type="""struct< + boolean: boolean[], + int16: int16[], + int32: int32[], + int64: int64[], + float32: float32[], + float64: float64[], + decimal: decimal[], + date32: date32[], + time64: time64[], + timestamp: timestamp[], + interval: interval[], + string: string[], + binary: binary[], + json: json[], + struct: struct[], + >""", ) def return_all_arrays( bool, @@ -193,23 +259,23 @@ def return_all_arrays( jsonb, struct, ): - return ( - bool, - i16, - i32, - i64, - f32, - f64, - decimal, - date, - time, - timestamp, - interval, - varchar, - bytea, - jsonb, - struct, - ) + return { + "boolean": bool, + "int16": i16, + "int32": i32, + "int64": i64, + "float32": f32, + "float64": f64, + "decimal": decimal, + "date32": date, + "time64": time, + "timestamp": timestamp, + "interval": interval, + "string": varchar, + "binary": bytea, + "json": jsonb, + "struct": struct, + } if __name__ == "__main__": From b66a4f6be37e13e829f3056c60815c37f24d1d43 Mon 
Sep 17 00:00:00 2001 From: Runji Wang Date: Tue, 7 May 2024 18:45:19 +0800 Subject: [PATCH 16/27] use arrow-udf-flight v0.1 Signed-off-by: Runji Wang --- Cargo.lock | 2 ++ Cargo.toml | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/Cargo.lock b/Cargo.lock index 3610b36d80637..6721ee10bbd05 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -719,6 +719,8 @@ dependencies = [ [[package]] name = "arrow-udf-flight" version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4adb3a066bd22fb520bc3d040d9d59ee54f320c21faeb6df815ea20445c80c54" dependencies = [ "arrow-array 50.0.0", "arrow-flight", diff --git a/Cargo.toml b/Cargo.toml index 9474456921f2d..5c86ea9a5e703 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -142,7 +142,7 @@ arrow-udf-js = "0.2" arrow-udf-js-deno = { git = "https://github.com/risingwavelabs/arrow-udf.git", rev = "76c995d" } arrow-udf-wasm = { version = "0.2.2", features = ["build"] } arrow-udf-python = "0.1" -arrow-udf-flight = { path = "../arrow-udf/arrow-udf-flight" } +arrow-udf-flight = "0.1" arrow-array-deltalake = { package = "arrow-array", version = "48.0.1" } arrow-buffer-deltalake = { package = "arrow-buffer", version = "48.0.1" } arrow-cast-deltalake = { package = "arrow-cast", version = "48.0.1" } From 61667fcb1a45e7753ecc7cf5a536935e069913be Mon Sep 17 00:00:00 2001 From: Runji Wang Date: Tue, 7 May 2024 18:56:01 +0800 Subject: [PATCH 17/27] revert dns resolution Signed-off-by: Runji Wang --- src/expr/core/src/expr/expr_udf.rs | 48 +++++++++++++++++++++++++++++- 1 file changed, 47 insertions(+), 1 deletion(-) diff --git a/src/expr/core/src/expr/expr_udf.rs b/src/expr/core/src/expr/expr_udf.rs index b9b1e290fae08..6faf087a90fc9 100644 --- a/src/expr/core/src/expr/expr_udf.rs +++ b/src/expr/core/src/expr/expr_udf.rs @@ -29,6 +29,7 @@ use arrow_udf_python::{CallMode as PythonCallMode, Runtime as PythonRuntime}; use arrow_udf_wasm::Runtime as WasmRuntime; use await_tree::InstrumentAwait; use cfg_or_panic::cfg_or_panic; +use ginepro::{LoadBalancedChannel, ResolutionStrategy}; use moka::sync::Cache; use prometheus::{ exponential_buckets, register_histogram_vec_with_registry, @@ -39,6 +40,7 @@ use risingwave_common::array::{ArrayRef, DataChunk}; use risingwave_common::monitor::GLOBAL_METRICS_REGISTRY; use risingwave_common::row::OwnedRow; use risingwave_common::types::{DataType, Datum}; +use risingwave_common::util::addr::HostAddr; use risingwave_expr::expr_context::FRAGMENT_ID; use risingwave_pb::expr::ExprNode; use thiserror_ext::AsReport; @@ -461,13 +463,57 @@ pub(crate) fn get_or_create_flight_client(link: &str) -> Result + }) })?); clients.insert(link.to_owned(), Arc::downgrade(&client)); Ok(client) } } +/// Connect to a UDF service and return a tonic `Channel`. +async fn connect_tonic(mut addr: &str) -> Result { + // Interval between two successive probes of the UDF DNS. + const DNS_PROBE_INTERVAL_SECS: u64 = 5; + // Timeout duration for performing an eager DNS resolution. 
+    const EAGER_DNS_RESOLVE_TIMEOUT_SECS: u64 = 5;
+    const REQUEST_TIMEOUT_SECS: u64 = 5;
+    const CONNECT_TIMEOUT_SECS: u64 = 5;
+
+    if addr.starts_with("http://") {
+        addr = addr.strip_prefix("http://").unwrap();
+    }
+    if addr.starts_with("https://") {
+        addr = addr.strip_prefix("https://").unwrap();
+    }
+    let host_addr = addr.parse::<HostAddr>().map_err(|e| {
+        arrow_udf_flight::Error::Service(format!(
+            "invalid address: {}, err: {}",
+            addr,
+            e.as_report()
+        ))
+    })?;
+    let channel = LoadBalancedChannel::builder((host_addr.host.clone(), host_addr.port))
+        .dns_probe_interval(std::time::Duration::from_secs(DNS_PROBE_INTERVAL_SECS))
+        .timeout(Duration::from_secs(REQUEST_TIMEOUT_SECS))
+        .connect_timeout(Duration::from_secs(CONNECT_TIMEOUT_SECS))
+        .resolution_strategy(ResolutionStrategy::Eager {
+            timeout: tokio::time::Duration::from_secs(EAGER_DNS_RESOLVE_TIMEOUT_SECS),
+        })
+        .channel()
+        .await
+        .map_err(|e| {
+            arrow_udf_flight::Error::Service(format!(
+                "failed to create LoadBalancedChannel, address: {}, err: {}",
+                host_addr,
+                e.as_report()
+            ))
+        })?;
+    Ok(channel.into())
+}
+
 /// Get or create a wasm runtime.
 ///
 /// Runtimes returned by this function are cached inside for at least 60 seconds.

From d94df069195e705301a01a5268958ed381afae6b Mon Sep 17 00:00:00 2001
From: Runji Wang
Date: Tue, 7 May 2024 18:58:18 +0800
Subject: [PATCH 18/27] fix clippy

Signed-off-by: Runji Wang
---
 src/expr/core/src/expr/expr_udf.rs | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/src/expr/core/src/expr/expr_udf.rs b/src/expr/core/src/expr/expr_udf.rs
index 6faf087a90fc9..7d9bc8ece3954 100644
--- a/src/expr/core/src/expr/expr_udf.rs
+++ b/src/expr/core/src/expr/expr_udf.rs
@@ -178,7 +178,7 @@ impl UserDefinedFunction {
                 let disable_retry_count = self.disable_retry_count.load(Ordering::Relaxed);
                 let result = if self.always_retry_on_network_error {
                     call_with_always_retry_on_network_error(
-                        &client,
+                        client,
                         &self.identifier,
                         &arrow_input,
                         &metrics.udf_retry_count.with_label_values(labels),
@@ -192,7 +192,7 @@
                             .instrument_await(self.span.clone())
                             .await
                     } else {
-                        call_with_retry(&client, &self.identifier, &arrow_input)
+                        call_with_retry(client, &self.identifier, &arrow_input)
                             .instrument_await(self.span.clone())
                             .await
                     };
@@ -546,11 +546,11 @@ fn is_connection_error(err: &arrow_udf_flight::Error) -> bool {
 }
 
 fn is_tonic_error(err: &arrow_udf_flight::Error) -> bool {
-    match err {
+    matches!(
+        err,
         arrow_udf_flight::Error::Tonic(_)
-        | arrow_udf_flight::Error::Flight(arrow_flight::error::FlightError::Tonic(_)) => true,
-        _ => false,
-    }
+            | arrow_udf_flight::Error::Flight(arrow_flight::error::FlightError::Tonic(_))
+    )
 }
 
 /// Monitor metrics for UDF.
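
Patches 12 through 18 leave the external-UDF retry policy split across `call_with_retry`, `call_with_always_retry_on_network_error`, and the two error classifiers above. The sketch below restates the resulting decision in one place; `RetryDecision` and `classify` are invented names for illustration and do not appear in the patch, while `is_connection_error` and `is_tonic_error` are the helpers introduced above.

    use arrow_udf_flight::Error;

    /// Possible outcomes for a failed external UDF call.
    enum RetryDecision {
        /// `call_with_always_retry_on_network_error`: backoff, no retry bound.
        RetryUnbounded,
        /// `call_with_retry`: backoff, at most 5 attempts.
        RetryBounded,
        /// Deterministic UDF failure: surface the error to the caller.
        Fail,
    }

    fn classify(err: &Error, always_retry_on_network_error: bool) -> RetryDecision {
        if always_retry_on_network_error && is_tonic_error(err) {
            // Any transport-level error is retried forever in this mode.
            RetryDecision::RetryUnbounded
        } else if is_connection_error(err) {
            // Only "connection refused" style errors get the bounded retry path.
            RetryDecision::RetryBounded
        } else {
            RetryDecision::Fail
        }
    }

Note the asymmetry: the default path retries only on `tonic::Code::Unavailable`, because any broader class of errors could include non-idempotent UDF failures that must not be replayed silently.
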
From c7bbe2f4f659b44f32405a824c29e6c0b6b41dc4 Mon Sep 17 00:00:00 2001 From: Runji Wang Date: Wed, 8 May 2024 15:00:55 +0800 Subject: [PATCH 19/27] remove python udf sdk unit test from ci Signed-off-by: Runji Wang --- ci/scripts/run-unit-test.sh | 5 ----- 1 file changed, 5 deletions(-) diff --git a/ci/scripts/run-unit-test.sh b/ci/scripts/run-unit-test.sh index d9a723a34fa19..394cdb1a78261 100755 --- a/ci/scripts/run-unit-test.sh +++ b/ci/scripts/run-unit-test.sh @@ -5,11 +5,6 @@ set -euo pipefail REPO_ROOT=${PWD} -echo "+++ Run python UDF SDK unit tests" -cd "${REPO_ROOT}"/src/expr/udf/python -python3 -m pytest -cd "${REPO_ROOT}" - echo "+++ Run unit tests" # use tee to disable progress bar NEXTEST_PROFILE=ci cargo nextest run --features failpoints,sync_point --workspace --exclude risingwave_simulation From 5a786adc68ab6e6e9aa118415229ed858f82afb4 Mon Sep 17 00:00:00 2001 From: Runji Wang Date: Wed, 8 May 2024 15:07:12 +0800 Subject: [PATCH 20/27] fix udf test in ci Signed-off-by: Runji Wang --- ci/scripts/build-other.sh | 10 +++++++--- ci/scripts/run-e2e-test.sh | 6 +++--- java/dev.md | 6 ------ 3 files changed, 10 insertions(+), 12 deletions(-) diff --git a/ci/scripts/build-other.sh b/ci/scripts/build-other.sh index 2311e5164fe74..65c50462f97a0 100755 --- a/ci/scripts/build-other.sh +++ b/ci/scripts/build-other.sh @@ -16,9 +16,13 @@ cd java mvn -B package -Dmaven.test.skip=true mvn -B install -Dmaven.test.skip=true --pl java-binding-integration-test --am mvn dependency:copy-dependencies --no-transfer-progress --pl java-binding-integration-test -mvn -B test --pl udf cd .. +echo "--- Build Java UDF" +cd e2e_test/udf/java +mvn -B package +cd ../../.. + echo "--- Build rust binary for java binding integration test" cargo build -p risingwave_java_binding --bin data-chunk-payload-generator --bin data-chunk-payload-convert-generator @@ -30,9 +34,9 @@ tar --zstd -cf java-binding-integration-test.tar.zst bin java/java-binding-integ echo "--- Upload Java artifacts" cp java/connector-node/assembly/target/risingwave-connector-1.0.0.tar.gz ./risingwave-connector.tar.gz -cp java/udf-example/target/risingwave-udf-example.jar ./risingwave-udf-example.jar +cp e2e_test/udf/java/target/risingwave-udf-example.jar ./udf.jar cp e2e_test/udf/wasm/target/wasm32-wasi/release/udf.wasm udf.wasm buildkite-agent artifact upload ./risingwave-connector.tar.gz -buildkite-agent artifact upload ./risingwave-udf-example.jar buildkite-agent artifact upload ./java-binding-integration-test.tar.zst +buildkite-agent artifact upload ./udf.jar buildkite-agent artifact upload ./udf.wasm diff --git a/ci/scripts/run-e2e-test.sh b/ci/scripts/run-e2e-test.sh index 5ce0b55f27e9e..31a2f7c4baaa3 100755 --- a/ci/scripts/run-e2e-test.sh +++ b/ci/scripts/run-e2e-test.sh @@ -70,7 +70,7 @@ download-and-decompress-artifact e2e_test_generated ./ download-and-decompress-artifact risingwave_e2e_extended_mode_test-"$profile" target/debug/ mkdir -p e2e_test/udf/wasm/target/wasm32-wasi/release/ buildkite-agent artifact download udf.wasm e2e_test/udf/wasm/target/wasm32-wasi/release/ -buildkite-agent artifact download risingwave-udf-example.jar ./ +buildkite-agent artifact download udf.jar ./ mv target/debug/risingwave_e2e_extended_mode_test-"$profile" target/debug/risingwave_e2e_extended_mode_test chmod +x ./target/debug/risingwave_e2e_extended_mode_test @@ -117,13 +117,13 @@ sqllogictest -p 4566 -d dev './e2e_test/udf/always_retry_python.slt' # sqllogictest -p 4566 -d dev './e2e_test/udf/retry_python.slt' echo "--- e2e, $mode, external java udf" 
-java -jar risingwave-udf-example.jar & +java -jar udf.jar & sleep 1 sqllogictest -p 4566 -d dev './e2e_test/udf/external_udf.slt' pkill java echo "--- e2e, $mode, embedded udf" -python3 -m pip install --break-system-packages flask waitress +python3 -m pip install --break-system-packages flask waitress arrow-udf==0.2.1 sqllogictest -p 4566 -d dev './e2e_test/udf/wasm_udf.slt' sqllogictest -p 4566 -d dev './e2e_test/udf/rust_udf.slt' sqllogictest -p 4566 -d dev './e2e_test/udf/js_udf.slt' diff --git a/java/dev.md b/java/dev.md index ac20c30fe69fa..148fde173baad 100644 --- a/java/dev.md +++ b/java/dev.md @@ -56,9 +56,3 @@ Config with the following. It may work. "java.format.settings.profile": "Android" } ``` - -## Deploy UDF Library to Maven - -```sh -mvn clean deploy --pl udf --am -``` \ No newline at end of file From f73fb82ad668ffa2051c236351a89094ca313cab Mon Sep 17 00:00:00 2001 From: Runji Wang Date: Wed, 8 May 2024 16:59:48 +0800 Subject: [PATCH 21/27] use java sdk from maven repository Signed-off-by: Runji Wang --- e2e_test/udf/java/pom.xml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/e2e_test/udf/java/pom.xml b/e2e_test/udf/java/pom.xml index 9c2351f8ce1f9..7ecd7c54dca17 100644 --- a/e2e_test/udf/java/pom.xml +++ b/e2e_test/udf/java/pom.xml @@ -23,7 +23,7 @@ com.risingwave risingwave-udf - 0.2.0-SNAPSHOT + 0.2.0 com.google.code.gson From f54390675468d8218881ca9eb0c24b0fafb8747c Mon Sep 17 00:00:00 2001 From: Runji Wang Date: Thu, 9 May 2024 10:49:46 +0800 Subject: [PATCH 22/27] fix e2e test Signed-off-by: Runji Wang --- ci/scripts/run-e2e-test.sh | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/ci/scripts/run-e2e-test.sh b/ci/scripts/run-e2e-test.sh index 31a2f7c4baaa3..044193b712724 100755 --- a/ci/scripts/run-e2e-test.sh +++ b/ci/scripts/run-e2e-test.sh @@ -105,6 +105,7 @@ echo "--- e2e, $mode, Apache Superset" sqllogictest -p 4566 -d dev './e2e_test/superset/*.slt' --junit "batch-${profile}" echo "--- e2e, $mode, external python udf" +python3 -m pip install --break-system-packages arrow-udf==0.2.1 python3 e2e_test/udf/test.py & sleep 1 sqllogictest -p 4566 -d dev './e2e_test/udf/external_udf.slt' @@ -123,7 +124,7 @@ sqllogictest -p 4566 -d dev './e2e_test/udf/external_udf.slt' pkill java echo "--- e2e, $mode, embedded udf" -python3 -m pip install --break-system-packages flask waitress arrow-udf==0.2.1 +python3 -m pip install --break-system-packages flask waitress sqllogictest -p 4566 -d dev './e2e_test/udf/wasm_udf.slt' sqllogictest -p 4566 -d dev './e2e_test/udf/rust_udf.slt' sqllogictest -p 4566 -d dev './e2e_test/udf/js_udf.slt' From e38e2b831e2706d17af958a31e0fc3a433de4b85 Mon Sep 17 00:00:00 2001 From: Runji Wang Date: Thu, 9 May 2024 14:48:40 +0800 Subject: [PATCH 23/27] fix decimal output Signed-off-by: Runji Wang --- src/expr/core/src/expr/expr_udf.rs | 27 ++++++------------- .../core/src/table_function/user_defined.rs | 27 ++++++------------- 2 files changed, 16 insertions(+), 38 deletions(-) diff --git a/src/expr/core/src/expr/expr_udf.rs b/src/expr/core/src/expr/expr_udf.rs index 7d9bc8ece3954..db63046addee6 100644 --- a/src/expr/core/src/expr/expr_udf.rs +++ b/src/expr/core/src/expr/expr_udf.rs @@ -341,10 +341,7 @@ impl Build for UserDefinedFunction { ); rt.add_function( identifier, - arrow_convert - .to_arrow_field("", &return_type)? 
- .data_type() - .clone(), + arrow_convert.to_arrow_field("", &return_type)?, JsCallMode::CalledOnNullInput, &body, )?; @@ -383,17 +380,12 @@ impl Build for UserDefinedFunction { ) }; - futures::executor::block_on( - rt.add_function( - identifier, - arrow_convert - .to_arrow_field("", &return_type)? - .data_type() - .clone(), - DenoCallMode::CalledOnNullInput, - &body, - ), - )?; + futures::executor::block_on(rt.add_function( + identifier, + arrow_convert.to_arrow_field("", &return_type)?, + DenoCallMode::CalledOnNullInput, + &body, + ))?; UdfImpl::Deno(rt) } @@ -403,10 +395,7 @@ impl Build for UserDefinedFunction { let body = udf.get_body()?; rt.add_function( identifier, - arrow_convert - .to_arrow_field("", &return_type)? - .data_type() - .clone(), + arrow_convert.to_arrow_field("", &return_type)?, PythonCallMode::CalledOnNullInput, body, )?; diff --git a/src/expr/core/src/table_function/user_defined.rs b/src/expr/core/src/table_function/user_defined.rs index c1dc09bafe861..6054dafb62b1b 100644 --- a/src/expr/core/src/table_function/user_defined.rs +++ b/src/expr/core/src/table_function/user_defined.rs @@ -213,10 +213,7 @@ pub fn new_user_defined(prost: &PbTableFunction, chunk_size: usize) -> Result Result Result Date: Thu, 9 May 2024 14:57:41 +0800 Subject: [PATCH 24/27] add link for backward compatibility Signed-off-by: Runji Wang --- src/expr/core/src/expr/expr_udf.rs | 2 ++ src/expr/core/src/table_function/user_defined.rs | 1 + 2 files changed, 3 insertions(+) diff --git a/src/expr/core/src/expr/expr_udf.rs b/src/expr/core/src/expr/expr_udf.rs index db63046addee6..a53204212a474 100644 --- a/src/expr/core/src/expr/expr_udf.rs +++ b/src/expr/core/src/expr/expr_udf.rs @@ -326,6 +326,7 @@ impl Build for UserDefinedFunction { .context("failed to decompress wasm binary")?; let runtime = get_or_create_wasm_runtime(&wasm_binary)?; // backward compatibility + // see for details if runtime.abi_version().0 <= 2 { arrow_convert = UdfArrowConvert { legacy: true }; } @@ -406,6 +407,7 @@ impl Build for UserDefinedFunction { let link = udf.get_link()?; let client = crate::expr::expr_udf::get_or_create_flight_client(link)?; // backward compatibility + // see for details if client.protocol_version() == 1 { arrow_convert = UdfArrowConvert { legacy: true }; } diff --git a/src/expr/core/src/table_function/user_defined.rs b/src/expr/core/src/table_function/user_defined.rs index 6054dafb62b1b..bf8354df0ea4b 100644 --- a/src/expr/core/src/table_function/user_defined.rs +++ b/src/expr/core/src/table_function/user_defined.rs @@ -274,6 +274,7 @@ pub fn new_user_defined(prost: &PbTableFunction, chunk_size: usize) -> Result for details if client.protocol_version() == 1 { arrow_convert = UdfArrowConvert { legacy: true }; } From bdc67707e4123b9f735dc16ad7da3cf7f66d10c1 Mon Sep 17 00:00:00 2001 From: Runji Wang Date: Fri, 10 May 2024 12:20:47 +0800 Subject: [PATCH 25/27] fix: handle error column Signed-off-by: Runji Wang --- src/common/src/array/data_chunk.rs | 1 + src/expr/core/src/error.rs | 10 ++++++++++ src/expr/core/src/expr/expr_udf.rs | 20 ++++++++++++++++++-- 3 files changed, 29 insertions(+), 2 deletions(-) diff --git a/src/common/src/array/data_chunk.rs b/src/common/src/array/data_chunk.rs index cdb012a3185cb..6d5b2247979ca 100644 --- a/src/common/src/array/data_chunk.rs +++ b/src/common/src/array/data_chunk.rs @@ -249,6 +249,7 @@ impl DataChunk { Self::new(columns, Bitmap::ones(cardinality)) } + /// Scatter a compacted chunk to a new chunk with the given visibility. 
pub fn uncompact(self, vis: Bitmap) -> Self { let mut uncompact_builders: Vec<_> = self .columns diff --git a/src/expr/core/src/error.rs b/src/expr/core/src/error.rs index efc3c20526f13..08562b3a973b7 100644 --- a/src/expr/core/src/error.rs +++ b/src/expr/core/src/error.rs @@ -119,6 +119,10 @@ pub enum ExprError { #[error("error in cryptography: {0}")] Cryptography(Box), + + /// Function error message returned by UDF. + #[error("{0}")] + Custom(String), } #[derive(Debug)] @@ -184,6 +188,12 @@ impl From> for MultiExprError { } } +impl FromIterator for MultiExprError { + fn from_iter>(iter: T) -> Self { + Self(iter.into_iter().collect()) + } +} + impl IntoIterator for MultiExprError { type IntoIter = std::vec::IntoIter; type Item = ExprError; diff --git a/src/expr/core/src/expr/expr_udf.rs b/src/expr/core/src/expr/expr_udf.rs index a53204212a474..220fdc2742cd6 100644 --- a/src/expr/core/src/expr/expr_udf.rs +++ b/src/expr/core/src/expr/expr_udf.rs @@ -36,7 +36,7 @@ use prometheus::{ register_int_counter_vec_with_registry, HistogramVec, IntCounter, IntCounterVec, Registry, }; use risingwave_common::array::arrow::{FromArrow, ToArrow, UdfArrowConvert}; -use risingwave_common::array::{ArrayRef, DataChunk}; +use risingwave_common::array::{Array, ArrayRef, DataChunk}; use risingwave_common::monitor::GLOBAL_METRICS_REGISTRY; use risingwave_common::row::OwnedRow; use risingwave_common::types::{DataType, Datum}; @@ -47,7 +47,7 @@ use thiserror_ext::AsReport; use super::{BoxedExpression, Build}; use crate::expr::Expression; -use crate::{bail, Result}; +use crate::{bail, ExprError, Result}; #[derive(Debug)] pub struct UserDefinedFunction { @@ -249,6 +249,22 @@ impl UserDefinedFunction { ); } + // handle optional error column + if let Some(errors) = output.columns().get(1) { + if errors.data_type() != DataType::Varchar { + bail!( + "UDF returned errors column with invalid type: {:?}", + errors.data_type() + ); + } + let errors = errors + .as_utf8() + .iter() + .filter_map(|msg| msg.map(|s| ExprError::Custom(s.into()))) + .collect(); + return Err(crate::ExprError::Multiple(array.clone(), errors)); + } + Ok(array.clone()) } } From 2970441deee375e6220c64a0e6c2daf269152891 Mon Sep 17 00:00:00 2001 From: Runji Wang Date: Fri, 10 May 2024 14:26:48 +0800 Subject: [PATCH 26/27] update arrow-udf-js-deno to fix decimal output Signed-off-by: Runji Wang --- Cargo.lock | 4 ++-- Cargo.toml | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index a017d176f8fe2..60707a229c53a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -749,7 +749,7 @@ dependencies = [ [[package]] name = "arrow-udf-js-deno" version = "0.0.1" -source = "git+https://github.com/risingwavelabs/arrow-udf.git?rev=76c995d#76c995d31f66785c39c7d70196c8ba0f1a61ad60" +source = "git+https://github.com/risingwavelabs/arrow-udf.git?rev=fa36365#fa3636559de986aa592da6e8b3fbfac7bdd4bb78" dependencies = [ "anyhow", "arrow-array 50.0.0", @@ -771,7 +771,7 @@ dependencies = [ [[package]] name = "arrow-udf-js-deno-runtime" version = "0.0.1" -source = "git+https://github.com/risingwavelabs/arrow-udf.git?rev=76c995d#76c995d31f66785c39c7d70196c8ba0f1a61ad60" +source = "git+https://github.com/risingwavelabs/arrow-udf.git?rev=fa36365#fa3636559de986aa592da6e8b3fbfac7bdd4bb78" dependencies = [ "anyhow", "deno_ast", diff --git a/Cargo.toml b/Cargo.toml index 5c86ea9a5e703..19675c3d08d9b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -139,7 +139,7 @@ arrow-select = "50" arrow-ord = "50" arrow-row = "50" arrow-udf-js = "0.2" -arrow-udf-js-deno = 
{ git = "https://github.com/risingwavelabs/arrow-udf.git", rev = "76c995d" } +arrow-udf-js-deno = { git = "https://github.com/risingwavelabs/arrow-udf.git", rev = "fa36365" } arrow-udf-wasm = { version = "0.2.2", features = ["build"] } arrow-udf-python = "0.1" arrow-udf-flight = "0.1" From a35ac1cc7f598f3bc53d774a359334dcb677adf8 Mon Sep 17 00:00:00 2001 From: Runji Wang Date: Fri, 10 May 2024 16:27:32 +0800 Subject: [PATCH 27/27] fix error ui slt Signed-off-by: Runji Wang --- e2e_test/error_ui/simple/main.slt | 6 ++++-- src/expr/core/src/expr/expr_udf.rs | 14 +++++++------- src/expr/core/src/expr/mod.rs | 2 +- src/frontend/src/handler/create_function.rs | 7 +++---- 4 files changed, 15 insertions(+), 14 deletions(-) diff --git a/e2e_test/error_ui/simple/main.slt b/e2e_test/error_ui/simple/main.slt index 8ef82e1f0d1c7..6bcbbde608cf8 100644 --- a/e2e_test/error_ui/simple/main.slt +++ b/e2e_test/error_ui/simple/main.slt @@ -13,8 +13,10 @@ create function int_42() returns int as int_42 using link '555.0.0.1:8815'; ---- db error: ERROR: Failed to run the query -Caused by: - Flight service error: invalid address: 555.0.0.1:8815, err: failed to parse address: http://555.0.0.1:8815: invalid IPv4 address +Caused by these errors (recent errors listed first): + 1: Expr error + 2: UDF error + 3: Flight service error: invalid address: 555.0.0.1:8815, err: failed to parse address: http://555.0.0.1:8815: invalid IPv4 address statement error diff --git a/src/expr/core/src/expr/expr_udf.rs b/src/expr/core/src/expr/expr_udf.rs index 220fdc2742cd6..54d3006dc3033 100644 --- a/src/expr/core/src/expr/expr_udf.rs +++ b/src/expr/core/src/expr/expr_udf.rs @@ -421,7 +421,7 @@ impl Build for UserDefinedFunction { #[cfg(not(madsim))] _ => { let link = udf.get_link()?; - let client = crate::expr::expr_udf::get_or_create_flight_client(link)?; + let client = get_or_create_flight_client(link)?; // backward compatibility // see for details if client.protocol_version() == 1 { @@ -456,11 +456,11 @@ impl Build for UserDefinedFunction { } } -#[cfg(not(madsim))] +#[cfg_or_panic(not(madsim))] /// Get or create a client for the given UDF service. /// /// There is a global cache for clients, so that we can reuse the same client for the same service. 
-pub(crate) fn get_or_create_flight_client(link: &str) -> Result> { +pub fn get_or_create_flight_client(link: &str) -> Result> { static CLIENTS: LazyLock>>> = LazyLock::new(Default::default); let mut clients = CLIENTS.lock().unwrap(); @@ -489,11 +489,11 @@ async fn connect_tonic(mut addr: &str) -> Result { const REQUEST_TIMEOUT_SECS: u64 = 5; const CONNECT_TIMEOUT_SECS: u64 = 5; - if addr.starts_with("http://") { - addr = addr.strip_prefix("http://").unwrap(); + if let Some(s) = addr.strip_prefix("http://") { + addr = s; } - if addr.starts_with("https://") { - addr = addr.strip_prefix("https://").unwrap(); + if let Some(s) = addr.strip_prefix("https://") { + addr = s; } let host_addr = addr.parse::().map_err(|e| { arrow_udf_flight::Error::Service(format!( diff --git a/src/expr/core/src/expr/mod.rs b/src/expr/core/src/expr/mod.rs index 6dbb3906f5618..9188ced21d111 100644 --- a/src/expr/core/src/expr/mod.rs +++ b/src/expr/core/src/expr/mod.rs @@ -51,7 +51,7 @@ use risingwave_common::types::{DataType, Datum}; pub use self::build::*; pub use self::expr_input_ref::InputRefExpression; pub use self::expr_literal::LiteralExpression; -pub use self::expr_udf::get_or_create_wasm_runtime; +pub use self::expr_udf::{get_or_create_flight_client, get_or_create_wasm_runtime}; pub use self::value::{ValueImpl, ValueRef}; pub use self::wrapper::*; pub use super::{ExprError, Result}; diff --git a/src/frontend/src/handler/create_function.rs b/src/frontend/src/handler/create_function.rs index 88172293cb458..471145a12a3e4 100644 --- a/src/frontend/src/handler/create_function.rs +++ b/src/frontend/src/handler/create_function.rs @@ -12,16 +12,15 @@ // See the License for the specific language governing permissions and // limitations under the License. -use anyhow::{anyhow, Context}; +use anyhow::Context; use arrow_schema::Fields; -use arrow_udf_flight::Client as FlightClient; use bytes::Bytes; use itertools::Itertools; use pgwire::pg_response::StatementType; use risingwave_common::array::arrow::{ToArrow, UdfArrowConvert}; use risingwave_common::catalog::FunctionId; use risingwave_common::types::DataType; -use risingwave_expr::expr::get_or_create_wasm_runtime; +use risingwave_expr::expr::{get_or_create_flight_client, get_or_create_wasm_runtime}; use risingwave_pb::catalog::function::{Kind, ScalarFunction, TableFunction}; use risingwave_pb::catalog::Function; use risingwave_sqlparser::ast::{CreateFunctionBody, ObjectName, OperateFunctionArg}; @@ -167,7 +166,7 @@ pub async fn handle_create_function( // check UDF server { - let client = FlightClient::connect(&l).await.map_err(|e| anyhow!(e))?; + let client = get_or_create_flight_client(&l)?; let convert = UdfArrowConvert { legacy: client.protocol_version() == 1, };
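A note on patch 23: `rt.add_function` now receives the whole `arrow_schema::Field` produced by `to_arrow_field` instead of only its `DataType`. This matters because a field can carry extension metadata marking types such as decimal, which the bare wire `DataType` cannot express. A sketch of what such a field might look like (the metadata key/value below is an assumption about the arrow-udf extension convention, not taken from this patch):

use std::collections::HashMap;

use arrow_schema::{DataType, Field};

fn main() {
    // Decimals travel as strings on the wire; the metadata is what tells the
    // UDF runtime to produce decimal output rather than plain text.
    let field = Field::new("result", DataType::Utf8, true).with_metadata(HashMap::from([(
        "ARROW:extension:name".to_owned(),
        "arrowudf.decimal".to_owned(),
    )]));
    assert_eq!(field.metadata().len(), 1);
}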
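A note on patch 25: a UDF may return an optional second column of error messages alongside the result column; a NULL entry means the row succeeded, and any non-NULL string becomes an `ExprError::Custom` collected into `ExprError::Multiple`. A sketch, using arrow-rs directly, of an output batch following this convention (the column names are illustrative; the handler only checks the second column's position and `Varchar` type):

use std::sync::Arc;

use arrow_array::{ArrayRef, Int32Array, RecordBatch, StringArray};
use arrow_schema::{DataType, Field, Schema};

fn main() {
    let schema = Arc::new(Schema::new(vec![
        Field::new("result", DataType::Int32, true),
        Field::new("error", DataType::Utf8, true),
    ]));
    let batch = RecordBatch::try_new(
        schema,
        vec![
            // Row results; the failed row is NULL here.
            Arc::new(Int32Array::from(vec![Some(1), None])) as ArrayRef,
            // NULL means the row succeeded; a string is that row's error message.
            Arc::new(StringArray::from(vec![None, Some("division by zero")])),
        ],
    )
    .unwrap();
    assert_eq!(batch.num_columns(), 2);
}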
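Finally, a note on the client cache that patch 27 exposes to the frontend: `get_or_create_flight_client` stores only `Weak` handles, so a client is dropped once the last expression using it goes away, while concurrent lookups for the same link share one connection. A generic sketch of this pattern (names and the `String` stand-in for the client are illustrative):

use std::collections::HashMap;
use std::sync::{Arc, Mutex, Weak};

/// Return the cached value for `key`, or build and cache a new one.
fn get_or_create(
    cache: &Mutex<HashMap<String, Weak<String>>>,
    key: &str,
    make: impl FnOnce() -> String,
) -> Arc<String> {
    let mut map = cache.lock().unwrap();
    // `upgrade` fails once every strong reference has been dropped,
    // in which case the stale entry is simply overwritten below.
    if let Some(value) = map.get(key).and_then(Weak::upgrade) {
        return value;
    }
    let value = Arc::new(make());
    map.insert(key.to_owned(), Arc::downgrade(&value));
    value
}

fn main() {
    let cache = Mutex::new(HashMap::new());
    let a = get_or_create(&cache, "localhost:8815", || "client".to_owned());
    let b = get_or_create(&cache, "localhost:8815", || unreachable!("cache hit expected"));
    assert!(Arc::ptr_eq(&a, &b));
}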