From 3c24444a21db23e6a4eaf33ab74f2766b610dc94 Mon Sep 17 00:00:00 2001 From: Runji Wang Date: Wed, 3 Apr 2024 15:03:27 +0900 Subject: [PATCH] fix(udf): make field name case-insensitive in Rust UDF (#16096) Signed-off-by: Runji Wang --- e2e_test/udf/rust_udf.slt | 27 ++++++++++++++++++++ src/frontend/src/handler/create_function.rs | 28 ++++++++++++++++++++- 2 files changed, 54 insertions(+), 1 deletion(-) diff --git a/e2e_test/udf/rust_udf.slt b/e2e_test/udf/rust_udf.slt index a9ba88367822e..8fd1f82ee1f48 100644 --- a/e2e_test/udf/rust_udf.slt +++ b/e2e_test/udf/rust_udf.slt @@ -104,6 +104,33 @@ statement ok drop function key_value; +statement ok +create function key_values(varchar) returns table (Key varchar, Value varchar) language rust as $$ + #[derive(StructType)] + struct KeyValue<'a> { + // note that field names are case-insensitive + key: &'a str, + value: &'a str, + } + #[function("key_values(varchar) -> setof struct KeyValue")] + fn key_values(kv: &str) -> impl Iterator> { + kv.split(',').filter_map(|kv| { + kv.split_once('=') + .map(|(key, value)| KeyValue { key, value }) + }) + } +$$; + +query T +select * from key_values('a=1,b=2'); +---- +a 1 +b 2 + +statement ok +drop function key_values; + + statement ok create function series(n int) returns table (x int) language rust as $$ fn series(n: i32) -> impl Iterator { diff --git a/src/frontend/src/handler/create_function.rs b/src/frontend/src/handler/create_function.rs index 95e2900c1db84..49a5041b985dd 100644 --- a/src/frontend/src/handler/create_function.rs +++ b/src/frontend/src/handler/create_function.rs @@ -325,8 +325,34 @@ fn find_wasm_identifier_v2( runtime: &arrow_udf_wasm::Runtime, inlined_signature: &str, ) -> Result { + // Inline types in function signature. + // + // # Example + // + // ```text + // types = { "KeyValue": "key:varchar,value:varchar" } + // input = "keyvalue(varchar, varchar) -> struct KeyValue" + // output = "keyvalue(varchar, varchar) -> struct" + // ``` + let inline_types = |s: &str| -> String { + let mut inlined = s.to_string(); + // iteratively replace `struct Xxx` with `struct<...>` until no replacement is made. + loop { + let replaced = inlined.clone(); + for (k, v) in runtime.types() { + inlined = inlined.replace(&format!("struct {k}"), &format!("struct<{v}>")); + } + if replaced == inlined { + return inlined; + } + } + }; + // Function signature in arrow-udf is case sensitive. + // However, SQL identifiers are usually case insensitive and stored in lowercase. + // So we should convert the signature to lowercase before comparison. let identifier = runtime - .find_function_by_inlined_signature(inlined_signature) + .functions() + .find(|f| inline_types(f).to_lowercase() == inlined_signature) .ok_or_else(|| { ErrorCode::InvalidParameterValue(format!( "function not found in wasm binary: \"{}\"\nHINT: available functions:\n {}",