Skip to content

Commit

Permalink
refactor(udf): use cfg_or_panic for UDF implementations
Browse files Browse the repository at this point in the history
Signed-off-by: Bugen Zhao <[email protected]>
  • Loading branch information
BugenZhao committed Sep 19, 2023
1 parent a789c61 commit 2637f86
Show file tree
Hide file tree
Showing 6 changed files with 14 additions and 79 deletions.
2 changes: 2 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions src/expr/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ arrow-schema = { workspace = true }
async-trait = "0.1"
auto_enums = "0.8"
await-tree = { workspace = true }
cfg-or-panic = "0.1"
chrono = { version = "0.4", default-features = false, features = ["clock", "std"] }
chrono-tz = { version = "0.8", features = ["case-insensitive"] }
ctor = "0.2"
Expand Down
31 changes: 4 additions & 27 deletions src/expr/src/expr/expr_udf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ use std::sync::{Arc, LazyLock, Mutex, Weak};

use arrow_schema::{Field, Fields, Schema, SchemaRef};
use await_tree::InstrumentAwait;
use cfg_or_panic::cfg_or_panic;
use risingwave_common::array::{ArrayImpl, ArrayRef, DataChunk};
use risingwave_common::row::OwnedRow;
use risingwave_common::types::{DataType, Datum};
Expand All @@ -39,13 +40,13 @@ pub struct UdfExpression {
span: await_tree::Span,
}

#[cfg(not(madsim))]
#[async_trait::async_trait]
impl Expression for UdfExpression {
fn return_type(&self) -> DataType {
self.return_type.clone()
}

#[cfg_or_panic(not(madsim))]
async fn eval(&self, input: &DataChunk) -> Result<ArrayRef> {
let vis = input.vis().to_bitmap();
let mut columns = Vec::with_capacity(self.children.len());
Expand All @@ -56,6 +57,7 @@ impl Expression for UdfExpression {
self.eval_inner(columns, vis).await
}

#[cfg_or_panic(not(madsim))]
async fn eval_row(&self, input: &OwnedRow) -> Result<Datum> {
let mut columns = Vec::with_capacity(self.children.len());
for child in &self.children {
Expand Down Expand Up @@ -114,7 +116,7 @@ impl UdfExpression {
}
}

#[cfg(not(madsim))]
#[cfg_or_panic(not(madsim))]
impl<'a> TryFrom<&'a ExprNode> for UdfExpression {
type Error = ExprError;

Expand Down Expand Up @@ -171,28 +173,3 @@ pub(crate) fn get_or_create_client(link: &str) -> Result<Arc<ArrowFlightUdfClien
Ok(client)
}
}

#[cfg(madsim)]
#[async_trait::async_trait]
impl Expression for UdfExpression {
fn return_type(&self) -> DataType {
self.return_type.clone()
}

async fn eval(&self, input: &DataChunk) -> Result<ArrayRef> {
panic!("UDF is not supported in simulation yet");
}

async fn eval_row(&self, input: &OwnedRow) -> Result<Datum> {
panic!("UDF is not supported in simulation yet");
}
}

#[cfg(madsim)]
impl<'a> TryFrom<&'a ExprNode> for UdfExpression {
type Error = ExprError;

fn try_from(prost: &'a ExprNode) -> Result<Self> {
panic!("UDF is not supported in simulation yet");
}
}
25 changes: 3 additions & 22 deletions src/expr/src/table_function/user_defined.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ use std::sync::Arc;

use arrow_array::RecordBatch;
use arrow_schema::{Field, Fields, Schema, SchemaRef};
use cfg_or_panic::cfg_or_panic;
use futures_util::stream;
use risingwave_common::array::{DataChunk, I32Array};
use risingwave_common::bail;
Expand All @@ -35,13 +36,13 @@ pub struct UserDefinedTableFunction {
chunk_size: usize,
}

#[cfg(not(madsim))]
#[async_trait::async_trait]
impl TableFunction for UserDefinedTableFunction {
fn return_type(&self) -> DataType {
self.return_type.clone()
}

#[cfg_or_panic(not(madsim))]
async fn eval<'a>(&'a self, input: &'a DataChunk) -> BoxStream<'a, Result<DataChunk>> {
self.eval_inner(input)
}
Expand Down Expand Up @@ -124,7 +125,7 @@ impl UserDefinedTableFunction {
}
}

#[cfg(not(madsim))]
#[cfg_or_panic(not(madsim))]
pub fn new_user_defined(prost: &PbTableFunction, chunk_size: usize) -> Result<BoxedTableFunction> {
let Some(udtf) = &prost.udtf else {
bail!("expect UDTF");
Expand Down Expand Up @@ -157,23 +158,3 @@ pub fn new_user_defined(prost: &PbTableFunction, chunk_size: usize) -> Result<Bo
}
.boxed())
}

#[cfg(madsim)]
#[async_trait::async_trait]
impl TableFunction for UserDefinedTableFunction {
fn return_type(&self) -> DataType {
panic!("UDF is not supported in simulation yet");
}

async fn eval<'a>(&'a self, input: &'a DataChunk) -> BoxStream<'a, Result<DataChunk>> {
panic!("UDF is not supported in simulation yet");
}
}

#[cfg(madsim)]
pub fn new_user_defined(
_prost: &PbTableFunction,
_chunk_size: usize,
) -> Result<BoxedTableFunction> {
panic!("UDF is not supported in simulation yet");
}
1 change: 1 addition & 0 deletions src/udf/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ arrow-array = { workspace = true }
arrow-flight = { workspace = true }
arrow-schema = { workspace = true }
arrow-select = { workspace = true }
cfg-or-panic = "0.1.1"
futures-util = "0.3.28"
static_assertions = "1"
thiserror = "1"
Expand Down
33 changes: 3 additions & 30 deletions src/udf/src/external.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ use arrow_flight::error::FlightError;
use arrow_flight::flight_service_client::FlightServiceClient;
use arrow_flight::{FlightData, FlightDescriptor};
use arrow_schema::Schema;
use cfg_or_panic::cfg_or_panic;
use futures_util::{stream, Stream, StreamExt, TryStreamExt};
use tonic::transport::Channel;

Expand All @@ -30,7 +31,8 @@ pub struct ArrowFlightUdfClient {
client: FlightServiceClient<Channel>,
}

#[cfg(not(madsim))]
// TODO: support UDF in simulation
#[cfg_or_panic(not(madsim))]
impl ArrowFlightUdfClient {
/// Connect to a UDF service.
pub async fn connect(addr: &str) -> Result<Self> {
Expand Down Expand Up @@ -129,35 +131,6 @@ impl ArrowFlightUdfClient {
}
}

// TODO: support UDF in simulation
#[cfg(madsim)]
impl ArrowFlightUdfClient {
/// Connect to a UDF service.
pub async fn connect(_addr: &str) -> Result<Self> {
panic!("UDF is not supported in simulation yet")
}

/// Check if the function is available.
pub async fn check(&self, _id: &str, _args: &Schema, _returns: &Schema) -> Result<()> {
panic!("UDF is not supported in simulation yet")
}

/// Call a function.
pub async fn call(&self, _id: &str, _input: RecordBatch) -> Result<RecordBatch> {
panic!("UDF is not supported in simulation yet")
}

/// Call a function with streaming input and output.
pub async fn call_stream(
&self,
_id: &str,
_inputs: impl Stream<Item = RecordBatch> + Send + 'static,
) -> Result<impl Stream<Item = Result<RecordBatch>> + Send + 'static> {
panic!("UDF is not supported in simulation yet");
Ok(stream::empty())
}
}

/// Check if two list of data types match, ignoring field names.
fn data_types_match(a: &[&arrow_schema::DataType], b: &[&arrow_schema::DataType]) -> bool {
if a.len() != b.len() {
Expand Down

0 comments on commit 2637f86

Please sign in to comment.