Skip to content

Commit

Permalink
fix merge issue
Browse files Browse the repository at this point in the history
  • Loading branch information
xxchan committed Oct 25, 2023
1 parent c644361 commit 56e6fc4
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 47 deletions.
69 changes: 25 additions & 44 deletions src/expr/core/src/expr/expr_udf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ impl Expression for UdfExpression {
self.eval_inner(columns, vis).await
}

#[cfg_or_panic(not(madsim))]
async fn eval_row(&self, input: &OwnedRow) -> Result<Datum> {
let mut columns = Vec::with_capacity(self.children.len());
for child in &self.children {
Expand Down Expand Up @@ -146,14 +147,21 @@ impl UdfExpression {
vis.len(),
);
}

let data_chunk =
DataChunk::try_from(&output).expect("failed to convert UDF output to DataChunk");
let output = data_chunk.uncompact(vis.clone());

let Some(array) = output.columns().get(0) else {
bail!("UDF returned no columns");
};
if !array.data_type().equals_datatype(&self.return_type) {
bail!(
"UDF returned {:?}, but expected {:?}",
array.data_type(),
self.return_type,
);
}

Ok(array.clone())
}
}
Expand All @@ -167,21 +175,6 @@ impl Build for UdfExpression {
let return_type = DataType::from(prost.get_return_type().unwrap());
let udf = prost.get_rex_node().unwrap().as_udf().unwrap();

let arg_schema = Arc::new(Schema::new(
udf.arg_types
.iter()
.map::<Result<_>, _>(|t| {
Ok(Field::new(
"",
DataType::from(t)
.try_into()
.map_err(risingwave_udf::Error::Unsupported)?,
true,
))
})
.try_collect::<Fields>()?,
));

let imp = match &udf.extra {
None | Some(PbExtra::External(PbExternalUdfExtra {})) => UdfImpl::External {
client: get_or_create_flight_client(&udf.link)?,
Expand All @@ -203,6 +196,21 @@ impl Build for UdfExpression {
}
};

let arg_schema = Arc::new(Schema::new(
udf.arg_types
.iter()
.map::<Result<_>, _>(|t| {
Ok(Field::new(
"",
DataType::from(t)
.try_into()
.map_err(risingwave_udf::Error::Unsupported)?,
true,
))
})
.try_collect::<Fields>()?,
));

Ok(Self {
children: udf.children.iter().map(build_child).try_collect()?,
arg_types: udf.arg_types.iter().map(|t| t.into()).collect(),
Expand All @@ -227,35 +235,8 @@ pub(crate) fn get_or_create_flight_client(link: &str) -> Result<Arc<ArrowFlightU
Ok(client)
} else {
// create new client
let client = Arc::new(tokio::task::block_in_place(|| {
tokio::runtime::Handle::current().block_on(ArrowFlightUdfClient::connect(link))
})?);
let client = Arc::new(ArrowFlightUdfClient::connect_lazy(link)?);
clients.insert(link.into(), Arc::downgrade(&client));
Ok(client)
}
}

#[cfg(madsim)]
#[async_trait::async_trait]
impl Expression for UdfExpression {
fn return_type(&self) -> DataType {
self.return_type.clone()
}

async fn eval(&self, input: &DataChunk) -> Result<ArrayRef> {
panic!("UDF is not supported in simulation yet");
}

async fn eval_row(&self, input: &OwnedRow) -> Result<Datum> {
panic!("UDF is not supported in simulation yet");
}
}

#[cfg(madsim)]
impl<'a> TryFrom<&'a ExprNode> for UdfExpression {
type Error = ExprError;

fn try_from(prost: &'a ExprNode) -> Result<Self> {
panic!("UDF is not supported in simulation yet");
}
}
5 changes: 2 additions & 3 deletions src/expr/core/src/table_function/user_defined.rs
Original file line number Diff line number Diff line change
Expand Up @@ -132,9 +132,6 @@ pub fn new_user_defined(prost: &PbTableFunction, chunk_size: usize) -> Result<Bo
bail!("expect UDTF");
};

// connect to UDF service
let client = crate::expr::expr_udf::get_or_create_flight_client(&udtf.link)?;

let arg_schema = Arc::new(Schema::new(
udtf.arg_types
.iter()
Expand All @@ -149,6 +146,8 @@ pub fn new_user_defined(prost: &PbTableFunction, chunk_size: usize) -> Result<Bo
})
.try_collect::<_, Fields, _>()?,
));
// connect to UDF service
let client = crate::expr::expr_udf::get_or_create_flight_client(&udtf.link)?;

Ok(UserDefinedTableFunction {
children: prost.args.iter().map(expr_build_from_prost).try_collect()?,
Expand Down

0 comments on commit 56e6fc4

Please sign in to comment.