diff --git a/e2e_test/error_ui/simple/main.slt b/e2e_test/error_ui/simple/main.slt index 8ef82e1f0d1c7..6bcbbde608cf8 100644 --- a/e2e_test/error_ui/simple/main.slt +++ b/e2e_test/error_ui/simple/main.slt @@ -13,8 +13,10 @@ create function int_42() returns int as int_42 using link '555.0.0.1:8815'; ---- db error: ERROR: Failed to run the query -Caused by: - Flight service error: invalid address: 555.0.0.1:8815, err: failed to parse address: http://555.0.0.1:8815: invalid IPv4 address +Caused by these errors (recent errors listed first): + 1: Expr error + 2: UDF error + 3: Flight service error: invalid address: 555.0.0.1:8815, err: failed to parse address: http://555.0.0.1:8815: invalid IPv4 address statement error diff --git a/src/expr/core/src/expr/expr_udf.rs b/src/expr/core/src/expr/expr_udf.rs index 220fdc2742cd6..54d3006dc3033 100644 --- a/src/expr/core/src/expr/expr_udf.rs +++ b/src/expr/core/src/expr/expr_udf.rs @@ -421,7 +421,7 @@ impl Build for UserDefinedFunction { #[cfg(not(madsim))] _ => { let link = udf.get_link()?; - let client = crate::expr::expr_udf::get_or_create_flight_client(link)?; + let client = get_or_create_flight_client(link)?; // backward compatibility // see for details if client.protocol_version() == 1 { @@ -456,11 +456,11 @@ impl Build for UserDefinedFunction { } } -#[cfg(not(madsim))] +#[cfg_or_panic(not(madsim))] /// Get or create a client for the given UDF service. /// /// There is a global cache for clients, so that we can reuse the same client for the same service. -pub(crate) fn get_or_create_flight_client(link: &str) -> Result> { +pub fn get_or_create_flight_client(link: &str) -> Result> { static CLIENTS: LazyLock>>> = LazyLock::new(Default::default); let mut clients = CLIENTS.lock().unwrap(); @@ -489,11 +489,11 @@ async fn connect_tonic(mut addr: &str) -> Result { const REQUEST_TIMEOUT_SECS: u64 = 5; const CONNECT_TIMEOUT_SECS: u64 = 5; - if addr.starts_with("http://") { - addr = addr.strip_prefix("http://").unwrap(); + if let Some(s) = addr.strip_prefix("http://") { + addr = s; } - if addr.starts_with("https://") { - addr = addr.strip_prefix("https://").unwrap(); + if let Some(s) = addr.strip_prefix("https://") { + addr = s; } let host_addr = addr.parse::().map_err(|e| { arrow_udf_flight::Error::Service(format!( diff --git a/src/expr/core/src/expr/mod.rs b/src/expr/core/src/expr/mod.rs index 6dbb3906f5618..9188ced21d111 100644 --- a/src/expr/core/src/expr/mod.rs +++ b/src/expr/core/src/expr/mod.rs @@ -51,7 +51,7 @@ use risingwave_common::types::{DataType, Datum}; pub use self::build::*; pub use self::expr_input_ref::InputRefExpression; pub use self::expr_literal::LiteralExpression; -pub use self::expr_udf::get_or_create_wasm_runtime; +pub use self::expr_udf::{get_or_create_flight_client, get_or_create_wasm_runtime}; pub use self::value::{ValueImpl, ValueRef}; pub use self::wrapper::*; pub use super::{ExprError, Result}; diff --git a/src/frontend/src/handler/create_function.rs b/src/frontend/src/handler/create_function.rs index 88172293cb458..471145a12a3e4 100644 --- a/src/frontend/src/handler/create_function.rs +++ b/src/frontend/src/handler/create_function.rs @@ -12,16 +12,15 @@ // See the License for the specific language governing permissions and // limitations under the License. -use anyhow::{anyhow, Context}; +use anyhow::Context; use arrow_schema::Fields; -use arrow_udf_flight::Client as FlightClient; use bytes::Bytes; use itertools::Itertools; use pgwire::pg_response::StatementType; use risingwave_common::array::arrow::{ToArrow, UdfArrowConvert}; use risingwave_common::catalog::FunctionId; use risingwave_common::types::DataType; -use risingwave_expr::expr::get_or_create_wasm_runtime; +use risingwave_expr::expr::{get_or_create_flight_client, get_or_create_wasm_runtime}; use risingwave_pb::catalog::function::{Kind, ScalarFunction, TableFunction}; use risingwave_pb::catalog::Function; use risingwave_sqlparser::ast::{CreateFunctionBody, ObjectName, OperateFunctionArg}; @@ -167,7 +166,7 @@ pub async fn handle_create_function( // check UDF server { - let client = FlightClient::connect(&l).await.map_err(|e| anyhow!(e))?; + let client = get_or_create_flight_client(&l)?; let convert = UdfArrowConvert { legacy: client.protocol_version() == 1, };