From 228b9e8549de31ada6bd12b5a7ee4d166513f271 Mon Sep 17 00:00:00 2001 From: Runji Wang Date: Sat, 11 May 2024 11:48:39 +0800 Subject: [PATCH] refactor(udf): switch to the latest arrow-udf versions (#16619) Signed-off-by: Runji Wang --- Cargo.lock | 317 +++++----- Cargo.toml | 10 +- ci/scripts/build-other.sh | 10 +- ci/scripts/run-e2e-test.sh | 5 +- ci/scripts/run-unit-test.sh | 5 - e2e_test/error_ui/simple/main.slt | 6 +- e2e_test/udf/external_udf.slt | 2 +- .../udf/java}/README.md | 0 .../udf-example => e2e_test/udf/java}/pom.xml | 12 +- .../src/main/java/com/example/UdfExample.java | 2 +- e2e_test/udf/requirements.txt | 3 +- e2e_test/udf/test.py | 170 ++++-- e2e_test/udf/wasm/Cargo.toml | 2 +- java/dev.md | 6 - java/pom.xml | 4 +- java/udf/CHANGELOG.md | 39 -- java/udf/README.md | 274 --------- java/udf/pom.xml | 58 -- .../risingwave/functions/DataTypeHint.java | 23 - .../risingwave/functions/PeriodDuration.java | 29 - .../risingwave/functions/ScalarFunction.java | 53 -- .../functions/ScalarFunctionBatch.java | 61 -- .../risingwave/functions/TableFunction.java | 60 -- .../functions/TableFunctionBatch.java | 87 --- .../com/risingwave/functions/TypeUtils.java | 505 ---------------- .../com/risingwave/functions/UdfProducer.java | 108 ---- .../com/risingwave/functions/UdfServer.java | 81 --- .../functions/UserDefinedFunction.java | 23 - .../functions/UserDefinedFunctionBatch.java | 87 --- .../risingwave/functions/TestUdfServer.java | 286 --------- .../com/risingwave/functions/UdfClient.java | 51 -- src/common/src/array/arrow/arrow_udf.rs | 89 ++- src/common/src/array/data_chunk.rs | 1 + src/expr/core/Cargo.toml | 6 +- src/expr/core/src/error.rs | 18 +- src/expr/core/src/expr/expr_udf.rs | 313 ++++++++-- src/expr/core/src/expr/mod.rs | 2 +- .../core/src/table_function/user_defined.rs | 54 +- src/expr/udf/Cargo.toml | 35 -- src/expr/udf/README-js.md | 83 --- src/expr/udf/README.md | 118 ---- src/expr/udf/examples/client.rs | 76 --- src/expr/udf/python/.gitignore | 2 - src/expr/udf/python/CHANGELOG.md | 37 -- src/expr/udf/python/README.md | 112 ---- src/expr/udf/python/publish.md | 19 - src/expr/udf/python/pyproject.toml | 20 - src/expr/udf/python/risingwave/__init__.py | 13 - src/expr/udf/python/risingwave/test_udf.py | 240 -------- src/expr/udf/python/risingwave/udf.py | 552 ------------------ .../udf/python/risingwave/udf/health_check.py | 40 -- src/expr/udf/src/error.rs | 67 --- src/expr/udf/src/external.rs | 260 --------- src/expr/udf/src/lib.rs | 24 - src/expr/udf/src/metrics.rs | 111 ---- src/frontend/Cargo.toml | 2 +- src/frontend/src/handler/create_function.rs | 76 ++- 57 files changed, 741 insertions(+), 4008 deletions(-) rename {java/udf-example => e2e_test/udf/java}/README.md (100%) rename {java/udf-example => e2e_test/udf/java}/pom.xml (86%) rename {java/udf-example => e2e_test/udf/java}/src/main/java/com/example/UdfExample.java (99%) delete mode 100644 java/udf/CHANGELOG.md delete mode 100644 java/udf/README.md delete mode 100644 java/udf/pom.xml delete mode 100644 java/udf/src/main/java/com/risingwave/functions/DataTypeHint.java delete mode 100644 java/udf/src/main/java/com/risingwave/functions/PeriodDuration.java delete mode 100644 java/udf/src/main/java/com/risingwave/functions/ScalarFunction.java delete mode 100644 java/udf/src/main/java/com/risingwave/functions/ScalarFunctionBatch.java delete mode 100644 java/udf/src/main/java/com/risingwave/functions/TableFunction.java delete mode 100644 java/udf/src/main/java/com/risingwave/functions/TableFunctionBatch.java delete mode 
100644 java/udf/src/main/java/com/risingwave/functions/TypeUtils.java delete mode 100644 java/udf/src/main/java/com/risingwave/functions/UdfProducer.java delete mode 100644 java/udf/src/main/java/com/risingwave/functions/UdfServer.java delete mode 100644 java/udf/src/main/java/com/risingwave/functions/UserDefinedFunction.java delete mode 100644 java/udf/src/main/java/com/risingwave/functions/UserDefinedFunctionBatch.java delete mode 100644 java/udf/src/test/java/com/risingwave/functions/TestUdfServer.java delete mode 100644 java/udf/src/test/java/com/risingwave/functions/UdfClient.java delete mode 100644 src/expr/udf/Cargo.toml delete mode 100644 src/expr/udf/README-js.md delete mode 100644 src/expr/udf/README.md delete mode 100644 src/expr/udf/examples/client.rs delete mode 100644 src/expr/udf/python/.gitignore delete mode 100644 src/expr/udf/python/CHANGELOG.md delete mode 100644 src/expr/udf/python/README.md delete mode 100644 src/expr/udf/python/publish.md delete mode 100644 src/expr/udf/python/pyproject.toml delete mode 100644 src/expr/udf/python/risingwave/__init__.py delete mode 100644 src/expr/udf/python/risingwave/test_udf.py delete mode 100644 src/expr/udf/python/risingwave/udf.py delete mode 100644 src/expr/udf/python/risingwave/udf/health_check.py delete mode 100644 src/expr/udf/src/error.rs delete mode 100644 src/expr/udf/src/external.rs delete mode 100644 src/expr/udf/src/lib.rs delete mode 100644 src/expr/udf/src/metrics.rs diff --git a/Cargo.lock b/Cargo.lock index fd0a73af29b2..f90c1d41690f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -716,11 +716,28 @@ dependencies = [ "regex-syntax 0.8.2", ] +[[package]] +name = "arrow-udf-flight" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4adb3a066bd22fb520bc3d040d9d59ee54f320c21faeb6df815ea20445c80c54" +dependencies = [ + "arrow-array 50.0.0", + "arrow-flight", + "arrow-schema 50.0.0", + "arrow-select 50.0.0", + "futures-util", + "thiserror", + "tokio", + "tonic 0.10.2", + "tracing", +] + [[package]] name = "arrow-udf-js" -version = "0.1.2" +version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "252b6355ad1e57eb6454b705c51652de55aa22eb018cdb95be0dbf62ee3ec78f" +checksum = "0519711e77180c5fe9891b81d912d937864894c77932b5df52169966f4a948bb" dependencies = [ "anyhow", "arrow-array 50.0.0", @@ -732,7 +749,7 @@ dependencies = [ [[package]] name = "arrow-udf-js-deno" version = "0.0.1" -source = "git+https://github.com/risingwavelabs/arrow-udf.git?rev=23fe0dd#23fe0dd41616f4646f9139e22a335518e6cc9a47" +source = "git+https://github.com/risingwavelabs/arrow-udf.git?rev=fa36365#fa3636559de986aa592da6e8b3fbfac7bdd4bb78" dependencies = [ "anyhow", "arrow-array 50.0.0", @@ -754,7 +771,7 @@ dependencies = [ [[package]] name = "arrow-udf-js-deno-runtime" version = "0.0.1" -source = "git+https://github.com/risingwavelabs/arrow-udf.git?rev=23fe0dd#23fe0dd41616f4646f9139e22a335518e6cc9a47" +source = "git+https://github.com/risingwavelabs/arrow-udf.git?rev=fa36365#fa3636559de986aa592da6e8b3fbfac7bdd4bb78" dependencies = [ "anyhow", "deno_ast", @@ -782,7 +799,8 @@ dependencies = [ [[package]] name = "arrow-udf-python" version = "0.1.0" -source = "git+https://github.com/risingwavelabs/arrow-udf.git?rev=6c32f71#6c32f710b5948147f8214797fc334a4a3cadef0d" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "41eaaa010b9cf07bedda6f1dafa050496e96fff7ae4b9602fb77c25c24c64cb7" dependencies = [ "anyhow", "arrow-array 50.0.0", @@ -796,9 +814,9 @@ 
dependencies = [ [[package]] name = "arrow-udf-wasm" -version = "0.2.1" +version = "0.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "59a51355b8ca4de8ae028e5efb45c248dad4568cde6707f23b89f9b86a907f36" +checksum = "eb829e25925161d93617d4b053bae03fe51e708f2cce088d85df856011d4f369" dependencies = [ "anyhow", "arrow-array 50.0.0", @@ -1658,7 +1676,7 @@ dependencies = [ "cfg-if", "libc", "miniz_oxide", - "object", + "object 0.32.1", "rustc-demangle", ] @@ -2728,18 +2746,18 @@ dependencies = [ [[package]] name = "cranelift-bforest" -version = "0.106.1" +version = "0.107.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5b3775cc6cc00c90d29eebea55feedb2b0168e23f5415bab7859c4004d7323d1" +checksum = "79b27922a6879b5b5361d0a084cb0b1941bf109a98540addcb932da13b68bed4" dependencies = [ "cranelift-entity", ] [[package]] name = "cranelift-codegen" -version = "0.106.1" +version = "0.107.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "637f3184ba5bfa48d425bad1d2e4faf5fcf619f5e0ca107edc6dc02f589d4d74" +checksum = "304c455b28bf56372729acb356afbb55d622f2b0f2f7837aa5e57c138acaac4d" dependencies = [ "bumpalo", "cranelift-bforest", @@ -2758,33 +2776,33 @@ dependencies = [ [[package]] name = "cranelift-codegen-meta" -version = "0.106.1" +version = "0.107.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e4b35b8240462341d94d31aab807cad704683988708261aecae3d57db48b7212" +checksum = "1653c56b99591d07f67c5ca7f9f25888948af3f4b97186bff838d687d666f613" dependencies = [ "cranelift-codegen-shared", ] [[package]] name = "cranelift-codegen-shared" -version = "0.106.1" +version = "0.107.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8f3cd1555aa9df1d6d8375732de41b4cb0d787006948d55b6d004d521e9efeb0" +checksum = "f5b6a9cf6b6eb820ee3f973a0db313c05dc12d370f37b4fe9630286e1672573f" [[package]] name = "cranelift-control" -version = "0.106.1" +version = "0.107.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "14b31a562a10e98ab148fa146801e20665c5f9eda4fce9b2c5a3836575887d74" +checksum = "d9d06e6bf30075fb6bed9e034ec046475093392eea1aff90eb5c44c4a033d19a" dependencies = [ "arbitrary", ] [[package]] name = "cranelift-entity" -version = "0.106.1" +version = "0.107.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "af1e0467700a3f4fccf5feddbaebdf8b0eb82535b06a9600c4bc5df40872e75d" +checksum = "29be04f931b73cdb9694874a295027471817f26f26d2f0ebe5454153176b6e3a" dependencies = [ "serde", "serde_derive", @@ -2792,9 +2810,9 @@ dependencies = [ [[package]] name = "cranelift-frontend" -version = "0.106.1" +version = "0.107.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6cb918ee2c23939262efd1b99d76a21212ac7bd35129582133e21a22a6ff0467" +checksum = "a07fd7393041d7faa2f37426f5dc7fc04003b70988810e8c063beefeff1cd8f9" dependencies = [ "cranelift-codegen", "log", @@ -2804,15 +2822,15 @@ dependencies = [ [[package]] name = "cranelift-isle" -version = "0.106.1" +version = "0.107.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "966e4cfb23cf6d7f1d285d53a912baaffc5f06bcd9c9b0a2d8c66a184fae534b" +checksum = "f341d7938caa6dff8149dac05bb2b53fc680323826b83b4cf175ab9f5139a3c9" [[package]] name = "cranelift-native" -version = "0.106.1" +version = "0.107.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"bea803aadfc4aabdfae7c3870f1b1f6dd4332f4091859e9758ef5fca6bf8cc87" +checksum = "82af6066e6448d26eeabb7aa26a43f7ff79f8217b06bade4ee6ef230aecc8880" dependencies = [ "cranelift-codegen", "libc", @@ -2821,9 +2839,9 @@ dependencies = [ [[package]] name = "cranelift-wasm" -version = "0.106.1" +version = "0.107.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "11d18a3572cd897555bba3621e568029417d8f5cc26aeede2d7cb0bad6afd916" +checksum = "2766fab7284a914a7f17f90ebe865c86453225fb8637ac31f123f5028fee69cd" dependencies = [ "cranelift-codegen", "cranelift-entity", @@ -5478,9 +5496,9 @@ dependencies = [ [[package]] name = "ginepro" -version = "0.7.0" +version = "0.7.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eedbff62a689be48f58f32571dbf3d60c4a73b39740141dfe7ac942536ea27f7" +checksum = "3b00ef897d4082727a53ea1111cd19bfa4ccdc476a5eb9f49087047113a43891" dependencies = [ "anyhow", "async-trait", @@ -6937,20 +6955,11 @@ dependencies = [ "pkg-config", ] -[[package]] -name = "mach" -version = "0.3.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b823e83b2affd8f40a9ee8c29dbc56404c1e34cd2710921f2801e2cf29527afa" -dependencies = [ - "libc", -] - [[package]] name = "mach2" -version = "0.4.1" +version = "0.4.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6d0d1830bcd151a6fc4aea1369af235b36c1528fe976b8ff678683c9995eade8" +checksum = "19b955cdeb2a02b9117f121ce63aa52d08ade45de53e48fe6a38b39c10f6f709" dependencies = [ "libc", ] @@ -7756,6 +7765,15 @@ name = "object" version = "0.32.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9cf5f9dd3933bd50a9e1f149ec995f39ae2c496d31fd772c1fd45ebc27e902b0" +dependencies = [ + "memchr", +] + +[[package]] +name = "object" +version = "0.33.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d8dd6c0cdf9429bce006e1362bfce61fa1bfd8c898a643ed8d2b471934701d3d" dependencies = [ "crc32fast", "hashbrown 0.14.3", @@ -9306,9 +9324,9 @@ dependencies = [ [[package]] name = "pyo3" -version = "0.20.3" +version = "0.21.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "53bdbb96d49157e65d45cc287af5f32ffadd5f4761438b527b055fb0d4bb8233" +checksum = "a5e00b96a521718e08e03b1a622f01c8a8deb50719335de3f60b3b3950f069d8" dependencies = [ "cfg-if", "indoc", @@ -9324,9 +9342,9 @@ dependencies = [ [[package]] name = "pyo3-build-config" -version = "0.20.3" +version = "0.21.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "deaa5745de3f5231ce10517a1f5dd97d53e5a2fd77aa6b5842292085831d48d7" +checksum = "7883df5835fafdad87c0d888b266c8ec0f4c9ca48a5bed6bbb592e8dedee1b50" dependencies = [ "once_cell", "target-lexicon", @@ -9334,9 +9352,9 @@ dependencies = [ [[package]] name = "pyo3-ffi" -version = "0.20.3" +version = "0.21.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "62b42531d03e08d4ef1f6e85a2ed422eb678b8cd62b762e53891c05faf0d4afa" +checksum = "01be5843dc60b916ab4dad1dca6d20b9b4e6ddc8e15f50c47fe6d85f1fb97403" dependencies = [ "libc", "pyo3-build-config", @@ -9344,9 +9362,9 @@ dependencies = [ [[package]] name = "pyo3-macros" -version = "0.20.3" +version = "0.21.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7305c720fa01b8055ec95e484a6eca7a83c841267f0dd5280f0c8b8551d2c158" +checksum = "77b34069fc0682e11b31dbd10321cbf94808394c56fd996796ce45217dfac53c" dependencies = [ "proc-macro2", 
"pyo3-macros-backend", @@ -9356,9 +9374,9 @@ dependencies = [ [[package]] name = "pyo3-macros-backend" -version = "0.20.3" +version = "0.21.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7c7e9b68bb9c3149c5b0cade5d07f953d6d125eb4337723c4ccdb665f1f96185" +checksum = "08260721f32db5e1a5beae69a55553f56b99bd0e1c3e6e0a5e8851a9d0f5a85c" dependencies = [ "heck 0.4.1", "proc-macro2", @@ -10656,7 +10674,9 @@ version = "1.9.0-alpha" dependencies = [ "anyhow", "arrow-array 50.0.0", + "arrow-flight", "arrow-schema 50.0.0", + "arrow-udf-flight", "arrow-udf-js", "arrow-udf-js-deno", "arrow-udf-python", @@ -10676,6 +10696,7 @@ dependencies = [ "futures", "futures-async-stream", "futures-util", + "ginepro", "itertools 0.12.1", "linkme", "madsim-tokio", @@ -10685,15 +10706,16 @@ dependencies = [ "openssl", "parse-display", "paste", + "prometheus", "risingwave_common", "risingwave_common_estimate_size", "risingwave_expr_macro", "risingwave_pb", - "risingwave_udf", "smallvec", "static_assertions", "thiserror", "thiserror-ext", + "tonic 0.10.2", "tracing", "workspace-hack", "zstd 0.13.0", @@ -10759,6 +10781,7 @@ dependencies = [ "anyhow", "arc-swap", "arrow-schema 50.0.0", + "arrow-udf-flight", "arrow-udf-wasm", "assert_matches", "async-recursion", @@ -10818,7 +10841,6 @@ dependencies = [ "risingwave_rpc_client", "risingwave_sqlparser", "risingwave_storage", - "risingwave_udf", "risingwave_variables", "rw_futures_util", "serde", @@ -11588,28 +11610,6 @@ dependencies = [ "workspace-hack", ] -[[package]] -name = "risingwave_udf" -version = "0.1.0" -dependencies = [ - "arrow-array 50.0.0", - "arrow-flight", - "arrow-schema 50.0.0", - "arrow-select 50.0.0", - "cfg-or-panic", - "futures", - "futures-util", - "ginepro", - "madsim-tokio", - "madsim-tonic", - "prometheus", - "risingwave_common", - "static_assertions", - "thiserror", - "thiserror-ext", - "tracing", -] - [[package]] name = "risingwave_variables" version = "1.9.0-alpha" @@ -15308,9 +15308,9 @@ checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" [[package]] name = "wasi-common" -version = "19.0.1" +version = "20.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6b53dfacdeacca15ee2a48a4aa0ec6a6d0da737676e465770c0585f79c04e638" +checksum = "63255d85e10627b07325d7cf4e5fe5a40fa4ff183569a0a67931be26d50ede07" dependencies = [ "anyhow", "bitflags 2.5.0", @@ -15406,9 +15406,18 @@ checksum = "ca6ad05a4870b2bf5fe995117d3728437bd27d7cd5f06f13c17443ef369775a1" [[package]] name = "wasm-encoder" -version = "0.201.0" +version = "0.202.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b9c7d2731df60006819b013f64ccc2019691deccf6e11a1804bc850cd6748f1a" +checksum = "bfd106365a7f5f7aa3c1916a98cbb3ad477f5ff96ddb130285a91c6e7429e67a" +dependencies = [ + "leb128", +] + +[[package]] +name = "wasm-encoder" +version = "0.206.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d759312e1137f199096d80a70be685899cd7d3d09c572836bb2e9b69b4dc3b1e" dependencies = [ "leb128", ] @@ -15441,9 +15450,9 @@ dependencies = [ [[package]] name = "wasmparser" -version = "0.201.0" +version = "0.202.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "84e5df6dba6c0d7fafc63a450f1738451ed7a0b52295d83e868218fa286bf708" +checksum = "d6998515d3cf3f8b980ef7c11b29a9b1017d4cf86b99ae93b546992df9931413" dependencies = [ "bitflags 2.5.0", "indexmap 2.0.0", @@ -15452,9 +15461,9 @@ dependencies = [ [[package]] name = "wasmprinter" 
-version = "0.201.0" +version = "0.202.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a67e66da702706ba08729a78e3c0079085f6bfcb1a62e4799e97bbf728c2c265" +checksum = "ab1cc9508685eef9502e787f4d4123745f5651a1e29aec047645d3cac1e2da7a" dependencies = [ "anyhow", "wasmparser", @@ -15462,9 +15471,9 @@ dependencies = [ [[package]] name = "wasmtime" -version = "19.0.1" +version = "20.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "516be5b58a8f75d39b01378516dcb0ff7b9bc39c7f1f10eec5b338d4916cf988" +checksum = "5a5990663c28d81015ddbb02a068ac1bf396a4ea296eba7125b2dfc7c00cb52e" dependencies = [ "addr2line", "anyhow", @@ -15479,7 +15488,7 @@ dependencies = [ "ittapi", "libc", "log", - "object", + "object 0.33.0", "once_cell", "paste", "rayon", @@ -15489,7 +15498,7 @@ dependencies = [ "serde_derive", "serde_json", "target-lexicon", - "wasm-encoder", + "wasm-encoder 0.202.0", "wasmparser", "wasmtime-cache", "wasmtime-component-macro", @@ -15508,18 +15517,18 @@ dependencies = [ [[package]] name = "wasmtime-asm-macros" -version = "19.0.1" +version = "20.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e8d22d88a92d69385f18143c946884bf6aaa9ec206ce54c85a2d320c1362b009" +checksum = "625ee94c72004f3ea0228989c9506596e469517d7d0ed66f7300d1067bdf1ca9" dependencies = [ "cfg-if", ] [[package]] name = "wasmtime-cache" -version = "19.0.1" +version = "20.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "068728a840223b56c964507550da671372e7e5c2f3a7856012b57482e3e979a7" +checksum = "98534bf28de232299e83eab33984a7a6c40c69534d6bd0ea216150b63d41a83a" dependencies = [ "anyhow", "base64 0.21.7", @@ -15537,9 +15546,9 @@ dependencies = [ [[package]] name = "wasmtime-component-macro" -version = "19.0.1" +version = "20.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "631244bac89c57ebe7283209d86fe175ad5929328e75f61bf9141895cafbf52d" +checksum = "64f84414a25ee3a624c8b77550f3fe7b5d8145bd3405ca58886ee6900abb6dc2" dependencies = [ "anyhow", "proc-macro2", @@ -15552,15 +15561,15 @@ dependencies = [ [[package]] name = "wasmtime-component-util" -version = "19.0.1" +version = "20.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "82ad496ba0558f7602da5e9d4c201f35f7aefcca70f973ec916f3f0d0787ef74" +checksum = "78580bdb4e04c7da3bf98088559ca1d29382668536e4d5c7f2f966d79c390307" [[package]] name = "wasmtime-cranelift" -version = "19.0.1" +version = "20.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "961ab5ee4b17e627001b18069ee89ef906edbbd3f84955515f6aad5ab6d82299" +checksum = "b60df0ee08c6a536c765f69e9e8205273435b66d02dd401e938769a2622a6c1a" dependencies = [ "anyhow", "cfg-if", @@ -15572,36 +15581,19 @@ dependencies = [ "cranelift-wasm", "gimli", "log", - "object", + "object 0.33.0", "target-lexicon", "thiserror", "wasmparser", - "wasmtime-cranelift-shared", "wasmtime-environ", "wasmtime-versioned-export-macros", ] -[[package]] -name = "wasmtime-cranelift-shared" -version = "19.0.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bc4db94596be14cd1f85844ce85470bf68acf235143098b9d9bf72b49e47b917" -dependencies = [ - "anyhow", - "cranelift-codegen", - "cranelift-control", - "cranelift-native", - "gimli", - "object", - "target-lexicon", - "wasmtime-environ", -] - [[package]] name = "wasmtime-environ" -version = "19.0.1" +version = "20.0.0" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "420b13858ef27dfd116f1fdb0513e9593a307a632ade2ea58334b639a3d8d24e" +checksum = "64ffc1613db69ee47c96738861534f9a405e422a5aa00224fbf5d410b03fb445" dependencies = [ "anyhow", "bincode 1.3.3", @@ -15610,13 +15602,13 @@ dependencies = [ "gimli", "indexmap 2.0.0", "log", - "object", + "object 0.33.0", "rustc-demangle", "serde", "serde_derive", "target-lexicon", "thiserror", - "wasm-encoder", + "wasm-encoder 0.202.0", "wasmparser", "wasmprinter", "wasmtime-component-util", @@ -15625,9 +15617,9 @@ dependencies = [ [[package]] name = "wasmtime-fiber" -version = "19.0.1" +version = "20.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5d37ff0e11a023019e34fe839c74a1c00880b989f4446176b6cc6da3b58e3ef2" +checksum = "f043514a23792761c5765f8ba61a4aa7d67f260c0c37494caabceb41d8ae81de" dependencies = [ "anyhow", "cc", @@ -15640,11 +15632,11 @@ dependencies = [ [[package]] name = "wasmtime-jit-debug" -version = "19.0.1" +version = "20.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7b849f19ad1d4a8133ff05b82c438144f17fb49b08e5f7995f8c1e25cf35f390" +checksum = "9c0ca2ad8f5d2b37f507ef1c935687a690e84e9f325f5a2af9639440b43c1f0e" dependencies = [ - "object", + "object 0.33.0", "once_cell", "rustix 0.38.31", "wasmtime-versioned-export-macros", @@ -15652,9 +15644,9 @@ dependencies = [ [[package]] name = "wasmtime-jit-icache-coherence" -version = "19.0.1" +version = "20.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "59c48eb4223d6556ffbf3decb146d0da124f1fd043f41c98b705252cb6a5c186" +checksum = "7a9f93a3289057b26dc75eb84d6e60d7694f7d169c7c09597495de6e016a13ff" dependencies = [ "cfg-if", "libc", @@ -15663,9 +15655,9 @@ dependencies = [ [[package]] name = "wasmtime-runtime" -version = "19.0.1" +version = "20.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7fefac2cb5f5a6f365234a3584bf40bd2e45e7f6cd90a689d9b2afbb9881978f" +checksum = "c6332a2b0af4224c3ea57c857ad39acd2780ccc2b0c99ba1baa01864d90d7c94" dependencies = [ "anyhow", "cc", @@ -15674,34 +15666,34 @@ dependencies = [ "indexmap 2.0.0", "libc", "log", - "mach", + "mach2", "memfd", "memoffset", "paste", "psm", "rustix 0.38.31", "sptr", - "wasm-encoder", + "wasm-encoder 0.202.0", "wasmtime-asm-macros", "wasmtime-environ", "wasmtime-fiber", "wasmtime-jit-debug", + "wasmtime-slab", "wasmtime-versioned-export-macros", - "wasmtime-wmemcheck", "windows-sys 0.52.0", ] [[package]] name = "wasmtime-slab" -version = "19.0.1" +version = "20.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "52d7b97b92df126fdbe994a53d2215828ec5ed5087535e6d4703b1fbd299f0e3" +checksum = "8b3655075824a374c536a2b2cc9283bb765fcdf3d58b58587862c48571ad81ef" [[package]] name = "wasmtime-types" -version = "19.0.1" +version = "20.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "509c88abb830819b259c49e2d4e4f22b555db066ba08ded0b76b071a2aa53ddf" +checksum = "b98cf64a242b0b9257604181ca28b28a5fcaa4c9ea1d396f76d1d2d1c5b40eef" dependencies = [ "cranelift-entity", "serde", @@ -15712,9 +15704,9 @@ dependencies = [ [[package]] name = "wasmtime-versioned-export-macros" -version = "19.0.1" +version = "20.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f1d81c092a61ca1667013e2eb08fed7c6c53e496dbbaa32d5685dc5152b0a772" +checksum = "8561d9e2920db2a175213d557d71c2ac7695831ab472bbfafb9060cd1034684f" dependencies = [ 
"proc-macro2", "quote", @@ -15723,26 +15715,26 @@ dependencies = [ [[package]] name = "wasmtime-winch" -version = "19.0.1" +version = "20.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a0958907880e37a2d3974f5b3574c23bf70aaf1fc6c1f716625bb50dac776f1a" +checksum = "a06b573d14ac846a0fb8c541d8fca6a64acf9a1d176176982472274ab1d2fa5d" dependencies = [ "anyhow", "cranelift-codegen", "gimli", - "object", + "object 0.33.0", "target-lexicon", "wasmparser", - "wasmtime-cranelift-shared", + "wasmtime-cranelift", "wasmtime-environ", "winch-codegen", ] [[package]] name = "wasmtime-wit-bindgen" -version = "19.0.1" +version = "20.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a593ddefd2f80617df6bea084b2e422d8969e924bc209642a794d57518f59587" +checksum = "595bc7bb3b0ff4aa00fab718c323ea552c3034d77abc821a35112552f2ea487a" dependencies = [ "anyhow", "heck 0.4.1", @@ -15750,12 +15742,6 @@ dependencies = [ "wit-parser", ] -[[package]] -name = "wasmtime-wmemcheck" -version = "19.0.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b77212b6874bbc86d220bb1d28632d0c11c6afe996c3e1ddcf746b1a6b4919b9" - [[package]] name = "wast" version = "35.0.2" @@ -15767,24 +15753,24 @@ dependencies = [ [[package]] name = "wast" -version = "201.0.0" +version = "206.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1ef6e1ef34d7da3e2b374fd2b1a9c0227aff6cad596e1b24df9b58d0f6222faa" +checksum = "68586953ee4960b1f5d84ebf26df3b628b17e6173bc088e0acfbce431469795a" dependencies = [ "bumpalo", "leb128", "memchr", "unicode-width", - "wasm-encoder", + "wasm-encoder 0.206.0", ] [[package]] name = "wat" -version = "1.201.0" +version = "1.206.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "453d5b37a45b98dee4f4cb68015fc73634d7883bbef1c65e6e9c78d454cf3f32" +checksum = "da4c6f2606276c6e991aebf441b2fc92c517807393f039992a3e0ad873efe4ad" dependencies = [ - "wast 201.0.0", + "wast 206.0.0", ] [[package]] @@ -15872,9 +15858,9 @@ checksum = "653f141f39ec16bba3c5abe400a0c60da7468261cc2cbf36805022876bc721a8" [[package]] name = "wiggle" -version = "19.0.1" +version = "20.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f093d8afdb09efaf2ed1037468bd4614308a762d215b6cafd60a7712993a8ffa" +checksum = "1b6552dda951239e219c329e5a768393664e8d120c5e0818487ac2633f173b1f" dependencies = [ "anyhow", "async-trait", @@ -15887,9 +15873,9 @@ dependencies = [ [[package]] name = "wiggle-generate" -version = "19.0.1" +version = "20.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "47c7bccd5172ce8d853242f723e42c84b8c131b24fb07a1570f9045d99258616" +checksum = "da64cb31e0bfe8b1d2d13956ef9fd5c77545756a1a6ef0e6cfd44e8f1f207aed" dependencies = [ "anyhow", "heck 0.4.1", @@ -15902,9 +15888,9 @@ dependencies = [ [[package]] name = "wiggle-macro" -version = "19.0.1" +version = "20.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a69d087dee85991096fc0c6eaf4dcf4e17cd16a0594c33b8ab9e2d345234ef75" +checksum = "900b2416ef2ff2903ded6cf55d4a941fed601bf56a8c4874856d7a77c1891994" dependencies = [ "proc-macro2", "quote", @@ -15945,9 +15931,9 @@ checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" [[package]] name = "winch-codegen" -version = "0.17.1" +version = "0.18.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"e72a6a7034793b874b85e428fd6d7b3ccccb98c326e33af3aa40cdf50d0c33da" +checksum = "fb23450977f9d4a23c02439cf6899340b2d68887b19465c5682740d9cc37d52e" dependencies = [ "anyhow", "cranelift-codegen", @@ -15956,6 +15942,7 @@ dependencies = [ "smallvec", "target-lexicon", "wasmparser", + "wasmtime-cranelift", "wasmtime-environ", ] @@ -16244,9 +16231,9 @@ dependencies = [ [[package]] name = "wit-parser" -version = "0.201.0" +version = "0.202.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "196d3ecfc4b759a8573bf86a9b3f8996b304b3732e4c7de81655f875f6efdca6" +checksum = "744237b488352f4f27bca05a10acb79474415951c450e52ebd0da784c1df2bcc" dependencies = [ "anyhow", "id-arena", diff --git a/Cargo.toml b/Cargo.toml index 69b2988cf0a4..d75b5b75f741 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -20,7 +20,6 @@ members = [ "src/expr/core", "src/expr/impl", "src/expr/macro", - "src/expr/udf", "src/frontend", "src/frontend/macro", "src/frontend/planner_test", @@ -139,10 +138,11 @@ arrow-flight = "50" arrow-select = "50" arrow-ord = "50" arrow-row = "50" -arrow-udf-js = "0.1" -arrow-udf-js-deno = { git = "https://github.com/risingwavelabs/arrow-udf.git", rev = "23fe0dd" } -arrow-udf-wasm = { version = "0.2.1", features = ["build"] } -arrow-udf-python = { git = "https://github.com/risingwavelabs/arrow-udf.git", rev = "6c32f71" } +arrow-udf-js = "0.2" +arrow-udf-js-deno = { git = "https://github.com/risingwavelabs/arrow-udf.git", rev = "fa36365" } +arrow-udf-wasm = { version = "0.2.2", features = ["build"] } +arrow-udf-python = "0.1" +arrow-udf-flight = "0.1" arrow-array-deltalake = { package = "arrow-array", version = "48.0.1" } arrow-buffer-deltalake = { package = "arrow-buffer", version = "48.0.1" } arrow-cast-deltalake = { package = "arrow-cast", version = "48.0.1" } diff --git a/ci/scripts/build-other.sh b/ci/scripts/build-other.sh index 2311e5164fe7..65c50462f97a 100755 --- a/ci/scripts/build-other.sh +++ b/ci/scripts/build-other.sh @@ -16,9 +16,13 @@ cd java mvn -B package -Dmaven.test.skip=true mvn -B install -Dmaven.test.skip=true --pl java-binding-integration-test --am mvn dependency:copy-dependencies --no-transfer-progress --pl java-binding-integration-test -mvn -B test --pl udf cd .. +echo "--- Build Java UDF" +cd e2e_test/udf/java +mvn -B package +cd ../../.. 
+ echo "--- Build rust binary for java binding integration test" cargo build -p risingwave_java_binding --bin data-chunk-payload-generator --bin data-chunk-payload-convert-generator @@ -30,9 +34,9 @@ tar --zstd -cf java-binding-integration-test.tar.zst bin java/java-binding-integ echo "--- Upload Java artifacts" cp java/connector-node/assembly/target/risingwave-connector-1.0.0.tar.gz ./risingwave-connector.tar.gz -cp java/udf-example/target/risingwave-udf-example.jar ./risingwave-udf-example.jar +cp e2e_test/udf/java/target/risingwave-udf-example.jar ./udf.jar cp e2e_test/udf/wasm/target/wasm32-wasi/release/udf.wasm udf.wasm buildkite-agent artifact upload ./risingwave-connector.tar.gz -buildkite-agent artifact upload ./risingwave-udf-example.jar buildkite-agent artifact upload ./java-binding-integration-test.tar.zst +buildkite-agent artifact upload ./udf.jar buildkite-agent artifact upload ./udf.wasm diff --git a/ci/scripts/run-e2e-test.sh b/ci/scripts/run-e2e-test.sh index 5ce0b55f27e9..044193b71272 100755 --- a/ci/scripts/run-e2e-test.sh +++ b/ci/scripts/run-e2e-test.sh @@ -70,7 +70,7 @@ download-and-decompress-artifact e2e_test_generated ./ download-and-decompress-artifact risingwave_e2e_extended_mode_test-"$profile" target/debug/ mkdir -p e2e_test/udf/wasm/target/wasm32-wasi/release/ buildkite-agent artifact download udf.wasm e2e_test/udf/wasm/target/wasm32-wasi/release/ -buildkite-agent artifact download risingwave-udf-example.jar ./ +buildkite-agent artifact download udf.jar ./ mv target/debug/risingwave_e2e_extended_mode_test-"$profile" target/debug/risingwave_e2e_extended_mode_test chmod +x ./target/debug/risingwave_e2e_extended_mode_test @@ -105,6 +105,7 @@ echo "--- e2e, $mode, Apache Superset" sqllogictest -p 4566 -d dev './e2e_test/superset/*.slt' --junit "batch-${profile}" echo "--- e2e, $mode, external python udf" +python3 -m pip install --break-system-packages arrow-udf==0.2.1 python3 e2e_test/udf/test.py & sleep 1 sqllogictest -p 4566 -d dev './e2e_test/udf/external_udf.slt' @@ -117,7 +118,7 @@ sqllogictest -p 4566 -d dev './e2e_test/udf/always_retry_python.slt' # sqllogictest -p 4566 -d dev './e2e_test/udf/retry_python.slt' echo "--- e2e, $mode, external java udf" -java -jar risingwave-udf-example.jar & +java -jar udf.jar & sleep 1 sqllogictest -p 4566 -d dev './e2e_test/udf/external_udf.slt' pkill java diff --git a/ci/scripts/run-unit-test.sh b/ci/scripts/run-unit-test.sh index d9a723a34fa1..394cdb1a7826 100755 --- a/ci/scripts/run-unit-test.sh +++ b/ci/scripts/run-unit-test.sh @@ -5,11 +5,6 @@ set -euo pipefail REPO_ROOT=${PWD} -echo "+++ Run python UDF SDK unit tests" -cd "${REPO_ROOT}"/src/expr/udf/python -python3 -m pytest -cd "${REPO_ROOT}" - echo "+++ Run unit tests" # use tee to disable progress bar NEXTEST_PROFILE=ci cargo nextest run --features failpoints,sync_point --workspace --exclude risingwave_simulation diff --git a/e2e_test/error_ui/simple/main.slt b/e2e_test/error_ui/simple/main.slt index 8ef82e1f0d1c..6bcbbde608cf 100644 --- a/e2e_test/error_ui/simple/main.slt +++ b/e2e_test/error_ui/simple/main.slt @@ -13,8 +13,10 @@ create function int_42() returns int as int_42 using link '555.0.0.1:8815'; ---- db error: ERROR: Failed to run the query -Caused by: - Flight service error: invalid address: 555.0.0.1:8815, err: failed to parse address: http://555.0.0.1:8815: invalid IPv4 address +Caused by these errors (recent errors listed first): + 1: Expr error + 2: UDF error + 3: Flight service error: invalid address: 555.0.0.1:8815, err: failed to parse address: 
http://555.0.0.1:8815: invalid IPv4 address statement error diff --git a/e2e_test/udf/external_udf.slt b/e2e_test/udf/external_udf.slt index 096a605709d6..7a38506f8156 100644 --- a/e2e_test/udf/external_udf.slt +++ b/e2e_test/udf/external_udf.slt @@ -1,7 +1,7 @@ # Before running this test: # python3 e2e_test/udf/test.py # or: -# cd java/udf-example && mvn package && java -jar target/risingwave-udf-example.jar +# cd e2e_test/udf/java && mvn package && java -jar target/risingwave-udf-example.jar # Create a function. statement ok diff --git a/java/udf-example/README.md b/e2e_test/udf/java/README.md similarity index 100% rename from java/udf-example/README.md rename to e2e_test/udf/java/README.md diff --git a/java/udf-example/pom.xml b/e2e_test/udf/java/pom.xml similarity index 86% rename from java/udf-example/pom.xml rename to e2e_test/udf/java/pom.xml index 8bf51cd10812..7ecd7c54dca1 100644 --- a/java/udf-example/pom.xml +++ b/e2e_test/udf/java/pom.xml @@ -5,17 +5,9 @@ xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd"> 4.0.0 - - - com.risingwave - risingwave-java-root - 0.1.0-SNAPSHOT - ../pom.xml - - com.risingwave risingwave-udf-example - 0.1.1-SNAPSHOT + 0.1.0-SNAPSHOT udf-example https://docs.risingwave.com/docs/current/udf-java @@ -31,7 +23,7 @@ com.risingwave risingwave-udf - 0.1.3-SNAPSHOT + 0.2.0 com.google.code.gson diff --git a/java/udf-example/src/main/java/com/example/UdfExample.java b/e2e_test/udf/java/src/main/java/com/example/UdfExample.java similarity index 99% rename from java/udf-example/src/main/java/com/example/UdfExample.java rename to e2e_test/udf/java/src/main/java/com/example/UdfExample.java index 883dc5035514..1702e244bf1f 100644 --- a/java/udf-example/src/main/java/com/example/UdfExample.java +++ b/e2e_test/udf/java/src/main/java/com/example/UdfExample.java @@ -33,7 +33,7 @@ public class UdfExample { public static void main(String[] args) throws IOException { - try (var server = new UdfServer("0.0.0.0", 8815)) { + try (var server = new UdfServer("localhost", 8815)) { server.addFunction("int_42", new Int42()); server.addFunction("float_to_decimal", new FloatToDecimal()); server.addFunction("sleep", new Sleep()); diff --git a/e2e_test/udf/requirements.txt b/e2e_test/udf/requirements.txt index 8642e2b1ec25..36688db1ed1e 100644 --- a/e2e_test/udf/requirements.txt +++ b/e2e_test/udf/requirements.txt @@ -1,2 +1,3 @@ flask -waitress \ No newline at end of file +waitress +arrow_udf==0.2.1 \ No newline at end of file diff --git a/e2e_test/udf/test.py b/e2e_test/udf/test.py index 6195476a8000..4443a81a6e74 100644 --- a/e2e_test/udf/test.py +++ b/e2e_test/udf/test.py @@ -19,9 +19,7 @@ from typing import Iterator, List, Optional, Tuple, Any from decimal import Decimal -sys.path.append("src/expr/udf/python") # noqa - -from risingwave.udf import udf, udtf, UdfServer +from arrow_udf import udf, udtf, UdfServer @udf(input_types=[], result_type="INT") @@ -47,13 +45,21 @@ def gcd3(x: int, y: int, z: int) -> int: return gcd(gcd(x, y), z) -@udf(input_types=["BYTEA"], result_type="STRUCT") +@udf( + input_types=["BYTEA"], + result_type="STRUCT", +) def extract_tcp_info(tcp_packet: bytes): src_addr, dst_addr = struct.unpack("!4s4s", tcp_packet[12:20]) src_port, dst_port = struct.unpack("!HH", tcp_packet[20:24]) src_addr = socket.inet_ntoa(src_addr) dst_addr = socket.inet_ntoa(dst_addr) - return src_addr, dst_addr, src_port, dst_port + return { + "src_addr": src_addr, + "dst_addr": dst_addr, + "src_port": src_port, + "dst_port": dst_port, + 
} @udtf(input_types="INT", result_types="INT") @@ -84,7 +90,7 @@ def hex_to_dec(hex: Optional[str]) -> Optional[Decimal]: return dec -@udf(input_types=["FLOAT8"], result_type="DECIMAL") +@udf(input_types=["FLOAT64"], result_type="DECIMAL") def float_to_decimal(f: float) -> Decimal: return Decimal(f) @@ -120,21 +126,49 @@ def jsonb_array_identity(list: List[Any]) -> List[Any]: return list -@udf(input_types="STRUCT", result_type="STRUCT") +@udf( + input_types="STRUCT", + result_type="STRUCT", +) def jsonb_array_struct_identity(v: Tuple[List[Any], int]) -> Tuple[List[Any], int]: return v -ALL_TYPES = "BOOLEAN,SMALLINT,INT,BIGINT,FLOAT4,FLOAT8,DECIMAL,DATE,TIME,TIMESTAMP,INTERVAL,VARCHAR,BYTEA,JSONB".split( - "," -) + [ - "STRUCT" -] - - @udf( - input_types=ALL_TYPES, - result_type=f"struct<{','.join(ALL_TYPES)}>", + input_types=[ + "boolean", + "int16", + "int32", + "int64", + "float32", + "float64", + "decimal", + "date32", + "time64", + "timestamp", + "interval", + "string", + "binary", + "json", + "struct", + ], + result_type="""struct< + boolean: boolean, + int16: int16, + int32: int32, + int64: int64, + float32: float32, + float64: float64, + decimal: decimal, + date32: date32, + time64: time64, + timestamp: timestamp, + interval: interval, + string: string, + binary: binary, + json: json, + struct: struct, + >""", ) def return_all( bool, @@ -153,28 +187,60 @@ def return_all( jsonb, struct, ): - return ( - bool, - i16, - i32, - i64, - f32, - f64, - decimal, - date, - time, - timestamp, - interval, - varchar, - bytea, - jsonb, - struct, - ) + return { + "boolean": bool, + "int16": i16, + "int32": i32, + "int64": i64, + "float32": f32, + "float64": f64, + "decimal": decimal, + "date32": date, + "time64": time, + "timestamp": timestamp, + "interval": interval, + "string": varchar, + "binary": bytea, + "json": jsonb, + "struct": struct, + } @udf( - input_types=[t + "[]" for t in ALL_TYPES], - result_type=f"struct<{','.join(t + '[]' for t in ALL_TYPES)}>", + input_types=[ + "boolean[]", + "int16[]", + "int32[]", + "int64[]", + "float32[]", + "float64[]", + "decimal[]", + "date32[]", + "time64[]", + "timestamp[]", + "interval[]", + "string[]", + "binary[]", + "json[]", + "struct[]", + ], + result_type="""struct< + boolean: boolean[], + int16: int16[], + int32: int32[], + int64: int64[], + float32: float32[], + float64: float64[], + decimal: decimal[], + date32: date32[], + time64: time64[], + timestamp: timestamp[], + interval: interval[], + string: string[], + binary: binary[], + json: json[], + struct: struct[], + >""", ) def return_all_arrays( bool, @@ -193,23 +259,23 @@ def return_all_arrays( jsonb, struct, ): - return ( - bool, - i16, - i32, - i64, - f32, - f64, - decimal, - date, - time, - timestamp, - interval, - varchar, - bytea, - jsonb, - struct, - ) + return { + "boolean": bool, + "int16": i16, + "int32": i32, + "int64": i64, + "float32": f32, + "float64": f64, + "decimal": decimal, + "date32": date, + "time64": time, + "timestamp": timestamp, + "interval": interval, + "string": varchar, + "binary": bytea, + "json": jsonb, + "struct": struct, + } if __name__ == "__main__": diff --git a/e2e_test/udf/wasm/Cargo.toml b/e2e_test/udf/wasm/Cargo.toml index 250bd8132ca5..54c7da45b1af 100644 --- a/e2e_test/udf/wasm/Cargo.toml +++ b/e2e_test/udf/wasm/Cargo.toml @@ -8,7 +8,7 @@ edition = "2021" crate-type = ["cdylib"] [dependencies] -arrow-udf = "0.2" +arrow-udf = "0.3" genawaiter = "0.99" rust_decimal = "1" serde_json = "1" diff --git a/java/dev.md b/java/dev.md index 
ac20c30fe69f..148fde173baa 100644 --- a/java/dev.md +++ b/java/dev.md @@ -56,9 +56,3 @@ Config with the following. It may work. "java.format.settings.profile": "Android" } ``` - -## Deploy UDF Library to Maven - -```sh -mvn clean deploy --pl udf --am -``` \ No newline at end of file diff --git a/java/pom.xml b/java/pom.xml index 922c62ead69e..f1ee457ef3b8 100644 --- a/java/pom.xml +++ b/java/pom.xml @@ -37,8 +37,6 @@ proto - udf - udf-example java-binding common-utils java-binding-integration-test @@ -572,4 +570,4 @@ https://s01.oss.sonatype.org/service/local/staging/deploy/maven2/ - + \ No newline at end of file diff --git a/java/udf/CHANGELOG.md b/java/udf/CHANGELOG.md deleted file mode 100644 index fb1f05578322..000000000000 --- a/java/udf/CHANGELOG.md +++ /dev/null @@ -1,39 +0,0 @@ -# Changelog - -All notable changes to this project will be documented in this file. - -The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), -and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). - -## [Unreleased] - -## [0.1.3] - 2023-12-06 - -### Fixed - -- Fix decimal type output. - -## [0.1.2] - 2023-12-04 - -### Fixed - -- Fix index-out-of-bound error when string or string list is large. -- Fix memory leak. - -## [0.1.1] - 2023-12-03 - -### Added - -- Support struct in struct and struct[] in struct. - -### Changed - -- Bump Arrow version to 14. - -### Fixed - -- Fix unconstrained decimal type. - -## [0.1.0] - 2023-09-01 - -- Initial release. \ No newline at end of file diff --git a/java/udf/README.md b/java/udf/README.md deleted file mode 100644 index 200b897b8b89..000000000000 --- a/java/udf/README.md +++ /dev/null @@ -1,274 +0,0 @@ -# RisingWave Java UDF SDK - -This library provides a Java SDK for creating user-defined functions (UDF) in RisingWave. - -## Introduction - -RisingWave supports user-defined functions implemented as external functions. -With the RisingWave Java UDF SDK, users can define custom UDFs using Java and start a Java process as a UDF server. -RisingWave can then remotely access the UDF server to execute the defined functions. - -## Installation - -To install the RisingWave Java UDF SDK: - -```sh -git clone https://github.com/risingwavelabs/risingwave.git -cd risingwave/java/udf -mvn install -``` - -Or you can add the following dependency to your `pom.xml` file: - -```xml - - - com.risingwave - risingwave-udf - 0.1.0 - - -``` - - -## Creating a New Project - -> NOTE: You can also start from the [udf-example](../udf-example) project without creating the project from scratch. - -To create a new project using the RisingWave Java UDF SDK, follow these steps: - -```sh -mvn archetype:generate -DgroupId=com.example -DartifactId=udf-example -DarchetypeArtifactId=maven-archetype-quickstart -DarchetypeVersion=1.4 -DinteractiveMode=false -``` - -Configure your `pom.xml` file as follows: - -```xml - - - 4.0.0 - com.example - udf-example - 1.0-SNAPSHOT - - - - com.risingwave - risingwave-udf - 0.1.0 - - - -``` - -The `--add-opens` flag must be added when running unit tests through Maven: - -```xml - - - - org.apache.maven.plugins - maven-surefire-plugin - 3.0.0 - - --add-opens=java.base/java.nio=ALL-UNNAMED - - - - -``` - -## Scalar Functions - -A user-defined scalar function maps zero, one, or multiple scalar values to a new scalar value. 
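Even a zero-argument function fits this model. A minimal sketch, modeled on the `int_42` function that `UdfExample.java` registers elsewhere in this patch (the class shown here is illustrative of the SDK's reflection-based convention, not part of its API):

```java
import com.risingwave.functions.ScalarFunction;

// A scalar function taking zero arguments: always returns 42.
public class Int42 implements ScalarFunction {
    public int eval() {
        return 42;
    }
}
```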
- -In order to define a scalar function, one has to create a new class that implements the `ScalarFunction` -interface in `com.risingwave.functions` and implement exactly one evaluation method named `eval(...)`. -This method must be declared public and non-static. - -Any [data type](#data-types) listed in the data types section can be used as a parameter or return type of an evaluation method. - -Here's an example of a scalar function that calculates the greatest common divisor (GCD) of two integers: - -```java -import com.risingwave.functions.ScalarFunction; - -public class Gcd implements ScalarFunction { - public int eval(int a, int b) { - while (b != 0) { - int temp = b; - b = a % b; - a = temp; - } - return a; - } -} -``` - -> **NOTE:** Differences with Flink -> 1. The `ScalarFunction` is an interface instead of an abstract class. -> 2. Multiple overloaded `eval` methods are not supported. -> 3. Variable arguments such as `eval(Integer...)` are not supported. - -## Table Functions - -A user-defined table function maps zero, one, or multiple scalar values to one or multiple -rows (structured types). - -In order to define a table function, one has to create a new class that implements the `TableFunction` -interface in `com.risingwave.functions` and implement exactly one evaluation method named `eval(...)`. -This method must be declared public and non-static. - -The return type must be an `Iterator` of any [data type](#data-types) listed in the data types section. -Similar to scalar functions, input and output data types are automatically extracted using reflection. -This includes the generic argument T of the return value for determining an output data type. - -Here's an example of a table function that generates a series of integers: - -```java -import com.risingwave.functions.TableFunction; - -public class Series implements TableFunction { - public Iterator eval(int n) { - return java.util.stream.IntStream.range(0, n).iterator(); - } -} -``` - -> **NOTE:** Differences with Flink -> 1. The `TableFunction` is an interface instead of an abstract class. It has no generic arguments. -> 2. Instead of calling `collect` to emit a row, the `eval` method returns an `Iterator` of the output rows. -> 3. Multiple overloaded `eval` methods are not supported. -> 4. Variable arguments such as `eval(Integer...)` are not supported. -> 5. In SQL, table functions can be used in the `FROM` clause directly. `JOIN LATERAL TABLE` is not supported. 
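For output with more than one column, rows can be structured types as described above. A sketch under the same reflection-based rules (the class and field names here are illustrative, not part of the SDK):

```java
import com.risingwave.functions.TableFunction;
import java.util.Iterator;
import java.util.stream.Stream;

// A table function that splits a comma-separated list of "key=value"
// pairs and emits one two-column row per pair.
public class ParsePairs implements TableFunction {
    public static class Pair {
        public String key;
        public String value;
    }

    public Iterator<Pair> eval(String input) {
        return Stream.of(input.split(","))
                .map(part -> {
                    String[] kv = part.split("=", 2);
                    Pair pair = new Pair();
                    pair.key = kv[0];
                    pair.value = kv.length > 1 ? kv[1] : null;
                    return pair;
                })
                .iterator();
    }
}
```

Each public field of `Pair` would become one column of the emitted row, mirroring how the struct examples elsewhere in this README are mapped.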
- -## UDF Server - -To create a UDF server and register functions: - -```java -import com.risingwave.functions.UdfServer; - -public class App { - public static void main(String[] args) { - try (var server = new UdfServer("0.0.0.0", 8815)) { - // register functions - server.addFunction("gcd", new Gcd()); - server.addFunction("series", new Series()); - // start the server - server.start(); - server.awaitTermination(); - } catch (Exception e) { - e.printStackTrace(); - } - } -} -``` - -To run the UDF server, execute the following command: - -```sh -_JAVA_OPTIONS="--add-opens=java.base/java.nio=ALL-UNNAMED" mvn exec:java -Dexec.mainClass="com.example.App" -``` - -## Creating Functions in RisingWave - -```sql -create function gcd(int, int) returns int -as gcd using link 'http://localhost:8815'; - -create function series(int) returns table (x int) -as series using link 'http://localhost:8815'; -``` - -For more detailed information and examples, please refer to the official RisingWave [documentation](https://www.risingwave.dev/docs/current/user-defined-functions/#4-declare-your-functions-in-risingwave). - -## Using Functions in RisingWave - -Once the user-defined functions are created in RisingWave, you can use them in SQL queries just like any built-in functions. Here are a few examples: - -```sql -select gcd(25, 15); - -select * from series(10); -``` - -## Data Types - -The RisingWave Java UDF SDK supports the following data types: - -| SQL Type | Java Type | Notes | -| ---------------- | --------------------------------------- | ------------------ | -| BOOLEAN | boolean, Boolean | | -| SMALLINT | short, Short | | -| INT | int, Integer | | -| BIGINT | long, Long | | -| REAL | float, Float | | -| DOUBLE PRECISION | double, Double | | -| DECIMAL | BigDecimal | | -| DATE | java.time.LocalDate | | -| TIME | java.time.LocalTime | | -| TIMESTAMP | java.time.LocalDateTime | | -| INTERVAL | com.risingwave.functions.PeriodDuration | | -| VARCHAR | String | | -| BYTEA | byte[] | | -| JSONB | String | Use `@DataTypeHint("JSONB") String` as the type. See [example](#jsonb). | -| T[] | T'[] | `T` can be any of the above SQL types. `T'` should be the corresponding Java type.| -| STRUCT<> | user-defined class | Define a data class as the type. See [example](#struct-type). | -| ...others | | Not supported yet. | - -### JSONB - -```java -import com.google.gson.Gson; - -// Returns the i-th element of a JSON array. -public class JsonbAccess implements ScalarFunction { - static Gson gson = new Gson(); - - public @DataTypeHint("JSONB") String eval(@DataTypeHint("JSONB") String json, int index) { - if (json == null) - return null; - var array = gson.fromJson(json, Object[].class); - if (index >= array.length || index < 0) - return null; - var obj = array[index]; - return gson.toJson(obj); - } -} -``` - -```sql -create function jsonb_access(jsonb, int) returns jsonb -as jsonb_access using link 'http://localhost:8815'; -``` - -### Struct Type - -```java -// Split a socket address into host and port. 
-public static class IpPort implements ScalarFunction { - public static class SocketAddr { - public String host; - public short port; - } - - public SocketAddr eval(String addr) { - var socketAddr = new SocketAddr(); - var parts = addr.split(":"); - socketAddr.host = parts[0]; - socketAddr.port = Short.parseShort(parts[1]); - return socketAddr; - } -} -``` - -```sql -create function ip_port(varchar) returns struct -as ip_port using link 'http://localhost:8815'; -``` - -## Full Example - -You can checkout [udf-example](../udf-example) and use it as a template to create your own UDFs. diff --git a/java/udf/pom.xml b/java/udf/pom.xml deleted file mode 100644 index f747603ca842..000000000000 --- a/java/udf/pom.xml +++ /dev/null @@ -1,58 +0,0 @@ - - 4.0.0 - - com.risingwave - risingwave-udf - jar - 0.1.3-SNAPSHOT - - - risingwave-java-root - com.risingwave - 0.1.0-SNAPSHOT - ../pom.xml - - - RisingWave Java UDF SDK - https://docs.risingwave.com/docs/current/udf-java - - - - org.junit.jupiter - junit-jupiter-engine - 5.9.1 - test - - - org.apache.arrow - arrow-vector - 14.0.0 - - - org.apache.arrow - flight-core - 14.0.0 - - - org.slf4j - slf4j-api - 2.0.7 - - - org.slf4j - slf4j-simple - 2.0.7 - - - - - - kr.motd.maven - os-maven-plugin - 1.7.0 - - - - \ No newline at end of file diff --git a/java/udf/src/main/java/com/risingwave/functions/DataTypeHint.java b/java/udf/src/main/java/com/risingwave/functions/DataTypeHint.java deleted file mode 100644 index 7baf0fe4c611..000000000000 --- a/java/udf/src/main/java/com/risingwave/functions/DataTypeHint.java +++ /dev/null @@ -1,23 +0,0 @@ -// Copyright 2024 RisingWave Labs -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package com.risingwave.functions; - -import java.lang.annotation.*; - -@Retention(RetentionPolicy.RUNTIME) -@Target({ElementType.METHOD, ElementType.FIELD, ElementType.PARAMETER}) -public @interface DataTypeHint { - String value(); -} diff --git a/java/udf/src/main/java/com/risingwave/functions/PeriodDuration.java b/java/udf/src/main/java/com/risingwave/functions/PeriodDuration.java deleted file mode 100644 index 6d704100f6f3..000000000000 --- a/java/udf/src/main/java/com/risingwave/functions/PeriodDuration.java +++ /dev/null @@ -1,29 +0,0 @@ -// Copyright 2024 RisingWave Labs -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package com.risingwave.functions; - -import java.time.Duration; -import java.time.Period; - -/** Combination of Period and Duration. 
*/ -public class PeriodDuration extends org.apache.arrow.vector.PeriodDuration { - public PeriodDuration(Period period, Duration duration) { - super(period, duration); - } - - PeriodDuration(org.apache.arrow.vector.PeriodDuration base) { - super(base.getPeriod(), base.getDuration()); - } -} diff --git a/java/udf/src/main/java/com/risingwave/functions/ScalarFunction.java b/java/udf/src/main/java/com/risingwave/functions/ScalarFunction.java deleted file mode 100644 index 5f3fcaf28733..000000000000 --- a/java/udf/src/main/java/com/risingwave/functions/ScalarFunction.java +++ /dev/null @@ -1,53 +0,0 @@ -// Copyright 2024 RisingWave Labs -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package com.risingwave.functions; - -/** - * Base interface for a user-defined scalar function. A user-defined scalar function maps zero, one, - * or multiple scalar values to a new scalar value. - * - *

<p>The behavior of a {@link ScalarFunction} can be defined by implementing a custom evaluation - * method. An evaluation method must be declared publicly, not static, and named eval. - * Multiple overloaded methods named eval are not supported yet. - * - *

<p>By default, input and output data types are automatically extracted using reflection. - * - *

<p>The following examples show how to specify a scalar function: - * - *

<pre>{@code
- * // a function that accepts two INT arguments and computes a sum
- * class SumFunction implements ScalarFunction {
- *     public Integer eval(Integer a, Integer b) {
- *         return a + b;
- *     }
- * }
- *
- * // a function that returns a struct type
- * class StructFunction implements ScalarFunction {
- *     public static class KeyValue {
- *         public String key;
- *         public int value;
- *     }
- *
- *     public KeyValue eval(int a) {
- *         KeyValue kv = new KeyValue();
- *         kv.key = String.valueOf(a);
- *         kv.value = a;
- *         return kv;
- *     }
- * }
- * }</pre>
- */ -public interface ScalarFunction extends UserDefinedFunction {} diff --git a/java/udf/src/main/java/com/risingwave/functions/ScalarFunctionBatch.java b/java/udf/src/main/java/com/risingwave/functions/ScalarFunctionBatch.java deleted file mode 100644 index 5d837d3b370f..000000000000 --- a/java/udf/src/main/java/com/risingwave/functions/ScalarFunctionBatch.java +++ /dev/null @@ -1,61 +0,0 @@ -// Copyright 2024 RisingWave Labs -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package com.risingwave.functions; - -import java.lang.invoke.MethodHandle; -import java.util.Collections; -import java.util.Iterator; -import java.util.function.Function; -import org.apache.arrow.memory.BufferAllocator; -import org.apache.arrow.vector.VectorSchemaRoot; - -/** Batch-processing wrapper over a user-defined scalar function. */ -class ScalarFunctionBatch extends UserDefinedFunctionBatch { - ScalarFunction function; - MethodHandle methodHandle; - Function[] processInputs; - - ScalarFunctionBatch(ScalarFunction function) { - this.function = function; - var method = Reflection.getEvalMethod(function); - this.methodHandle = Reflection.getMethodHandle(method); - this.inputSchema = TypeUtils.methodToInputSchema(method); - this.outputSchema = TypeUtils.methodToOutputSchema(method); - this.processInputs = TypeUtils.methodToProcessInputs(method); - } - - @Override - Iterator evalBatch(VectorSchemaRoot batch, BufferAllocator allocator) { - var row = new Object[batch.getSchema().getFields().size() + 1]; - row[0] = this.function; - var outputValues = new Object[batch.getRowCount()]; - for (int i = 0; i < batch.getRowCount(); i++) { - for (int j = 0; j < row.length - 1; j++) { - var val = batch.getVector(j).getObject(i); - row[j + 1] = this.processInputs[j].apply(val); - } - try { - outputValues[i] = this.methodHandle.invokeWithArguments(row); - } catch (Throwable e) { - throw new RuntimeException(e); - } - } - var outputVector = - TypeUtils.createVector( - this.outputSchema.getFields().get(0), allocator, outputValues); - var outputBatch = VectorSchemaRoot.of(outputVector); - return Collections.singleton(outputBatch).iterator(); - } -} diff --git a/java/udf/src/main/java/com/risingwave/functions/TableFunction.java b/java/udf/src/main/java/com/risingwave/functions/TableFunction.java deleted file mode 100644 index ec5b9d214553..000000000000 --- a/java/udf/src/main/java/com/risingwave/functions/TableFunction.java +++ /dev/null @@ -1,60 +0,0 @@ -// Copyright 2024 RisingWave Labs -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-// See the License for the specific language governing permissions and -// limitations under the License. - -package com.risingwave.functions; - -/** - * Base interface for a user-defined table function. A user-defined table function maps zero, one, - * or multiple scalar values to zero, one, or multiple rows (or structured types). If an output - * record consists of only one field, the structured record can be omitted, and a scalar value can - * be emitted that will be implicitly wrapped into a row by the runtime. - * - *

<p>The behavior of a {@link TableFunction} can be defined by implementing a custom evaluation - * method. An evaluation method must be declared public, non-static, and named eval. - * The return type must be an Iterator. Multiple overloaded methods named eval are not - * supported yet. - * - *

<p>By default, input and output data types are automatically extracted using reflection. - * - *

<p>The following examples show how to specify a table function: - * - *

- * <pre>{@code
- * // a function that accepts an INT argument and emits the integers
- * // from 0 (inclusive) to the given number (exclusive).
- * class Series implements TableFunction {
- *     public Iterator<Integer> eval(int n) {
- *         return java.util.stream.IntStream.range(0, n).iterator();
- *     }
- * }
- *
- * // a function that accepts a String argument and emits the words of the
- * // given string.
- * class Split implements TableFunction {
- *     public static class Row {
- *         public String word;
- *         public int length;
- *     }
- *
- *     public Iterator<Row> eval(String str) {
- *         return Stream.of(str.split(" ")).map(s -> {
- *             Row row = new Row();
- *             row.word = s;
- *             row.length = s.length();
- *             return row;
- *         }).iterator();
- *     }
- * }
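- *
- * // A sketch of serving the two functions above over Arrow Flight; the host
- * // and port are illustrative:
- * try (var server = new UdfServer("0.0.0.0", 8815)) {
- *     server.addFunction("series", new Series());
- *     server.addFunction("split", new Split());
- *     server.start();
- *     server.awaitTermination();
- * }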
- * }</pre>
- */ -public interface TableFunction extends UserDefinedFunction {} diff --git a/java/udf/src/main/java/com/risingwave/functions/TableFunctionBatch.java b/java/udf/src/main/java/com/risingwave/functions/TableFunctionBatch.java deleted file mode 100644 index a0e0608e6021..000000000000 --- a/java/udf/src/main/java/com/risingwave/functions/TableFunctionBatch.java +++ /dev/null @@ -1,87 +0,0 @@ -// Copyright 2024 RisingWave Labs -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package com.risingwave.functions; - -import java.lang.invoke.MethodHandle; -import java.util.ArrayList; -import java.util.Iterator; -import java.util.function.Function; -import org.apache.arrow.memory.BufferAllocator; -import org.apache.arrow.vector.VectorSchemaRoot; - -/** Batch-processing wrapper over a user-defined table function. */ -class TableFunctionBatch extends UserDefinedFunctionBatch { - TableFunction function; - MethodHandle methodHandle; - Function[] processInputs; - int chunkSize = 1024; - - TableFunctionBatch(TableFunction function) { - this.function = function; - var method = Reflection.getEvalMethod(function); - this.methodHandle = Reflection.getMethodHandle(method); - this.inputSchema = TypeUtils.methodToInputSchema(method); - this.outputSchema = TypeUtils.tableFunctionToOutputSchema(method); - this.processInputs = TypeUtils.methodToProcessInputs(method); - } - - @Override - Iterator evalBatch(VectorSchemaRoot batch, BufferAllocator allocator) { - var outputs = new ArrayList(); - var row = new Object[batch.getSchema().getFields().size() + 1]; - row[0] = this.function; - var indexes = new ArrayList(); - var values = new ArrayList(); - Runnable buildChunk = - () -> { - var fields = this.outputSchema.getFields(); - var indexVector = - TypeUtils.createVector(fields.get(0), allocator, indexes.toArray()); - var valueVector = - TypeUtils.createVector(fields.get(1), allocator, values.toArray()); - indexes.clear(); - values.clear(); - var outputBatch = VectorSchemaRoot.of(indexVector, valueVector); - outputs.add(outputBatch); - }; - for (int i = 0; i < batch.getRowCount(); i++) { - // prepare input row - for (int j = 0; j < row.length - 1; j++) { - var val = batch.getVector(j).getObject(i); - row[j + 1] = this.processInputs[j].apply(val); - } - // call function - Iterator iterator; - try { - iterator = (Iterator) this.methodHandle.invokeWithArguments(row); - } catch (Throwable e) { - throw new RuntimeException(e); - } - // push values - while (iterator.hasNext()) { - indexes.add(i); - values.add(iterator.next()); - // check if we need to flush - if (indexes.size() >= this.chunkSize) { - buildChunk.run(); - } - } - } - if (indexes.size() > 0) { - buildChunk.run(); - } - return outputs.iterator(); - } -} diff --git a/java/udf/src/main/java/com/risingwave/functions/TypeUtils.java b/java/udf/src/main/java/com/risingwave/functions/TypeUtils.java deleted file mode 100644 index 06c2f79858c4..000000000000 --- a/java/udf/src/main/java/com/risingwave/functions/TypeUtils.java +++ /dev/null @@ -1,505 +0,0 @@ -// 
Copyright 2024 RisingWave Labs -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package com.risingwave.functions; - -import java.lang.invoke.MethodHandles; -import java.lang.reflect.Array; -import java.lang.reflect.InvocationTargetException; -import java.lang.reflect.Method; -import java.lang.reflect.ParameterizedType; -import java.math.BigDecimal; -import java.time.LocalDate; -import java.time.LocalDateTime; -import java.time.LocalTime; -import java.util.AbstractMap; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.Iterator; -import java.util.List; -import java.util.function.Function; -import java.util.stream.Collectors; -import org.apache.arrow.memory.BufferAllocator; -import org.apache.arrow.vector.*; -import org.apache.arrow.vector.complex.ListVector; -import org.apache.arrow.vector.complex.StructVector; -import org.apache.arrow.vector.types.*; -import org.apache.arrow.vector.types.pojo.*; - -class TypeUtils { - /** Convert a string to an Arrow type. */ - static Field stringToField(String typeStr, String name) { - typeStr = typeStr.toUpperCase(); - if (typeStr.equals("BOOLEAN") || typeStr.equals("BOOL")) { - return Field.nullable(name, new ArrowType.Bool()); - } else if (typeStr.equals("SMALLINT") || typeStr.equals("INT2")) { - return Field.nullable(name, new ArrowType.Int(16, true)); - } else if (typeStr.equals("INT") || typeStr.equals("INTEGER") || typeStr.equals("INT4")) { - return Field.nullable(name, new ArrowType.Int(32, true)); - } else if (typeStr.equals("BIGINT") || typeStr.equals("INT8")) { - return Field.nullable(name, new ArrowType.Int(64, true)); - } else if (typeStr.equals("FLOAT4") || typeStr.equals("REAL")) { - return Field.nullable(name, new ArrowType.FloatingPoint(FloatingPointPrecision.SINGLE)); - } else if (typeStr.equals("FLOAT8") || typeStr.equals("DOUBLE PRECISION")) { - return Field.nullable(name, new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE)); - } else if (typeStr.equals("DECIMAL") || typeStr.equals("NUMERIC")) { - return Field.nullable(name, new ArrowType.LargeBinary()); - } else if (typeStr.equals("DATE")) { - return Field.nullable(name, new ArrowType.Date(DateUnit.DAY)); - } else if (typeStr.equals("TIME") || typeStr.equals("TIME WITHOUT TIME ZONE")) { - return Field.nullable(name, new ArrowType.Time(TimeUnit.MICROSECOND, 64)); - } else if (typeStr.equals("TIMESTAMP") || typeStr.equals("TIMESTAMP WITHOUT TIME ZONE")) { - return Field.nullable(name, new ArrowType.Timestamp(TimeUnit.MICROSECOND, null)); - } else if (typeStr.startsWith("INTERVAL")) { - return Field.nullable(name, new ArrowType.Interval(IntervalUnit.MONTH_DAY_NANO)); - } else if (typeStr.equals("VARCHAR")) { - return Field.nullable(name, new ArrowType.Utf8()); - } else if (typeStr.equals("JSONB")) { - return Field.nullable(name, new ArrowType.LargeUtf8()); - } else if (typeStr.equals("BYTEA")) { - return Field.nullable(name, new ArrowType.Binary()); - } else if (typeStr.endsWith("[]")) { - Field innerField = stringToField(typeStr.substring(0, 
typeStr.length() - 2), ""); - return new Field( - name, FieldType.nullable(new ArrowType.List()), Arrays.asList(innerField)); - } else if (typeStr.startsWith("STRUCT")) { - // extract "STRUCT" - var typeList = typeStr.substring(7, typeStr.length() - 1); - var fields = - Arrays.stream(typeList.split(",")) - .map(s -> stringToField(s.trim(), "")) - .collect(Collectors.toList()); - return new Field(name, FieldType.nullable(new ArrowType.Struct()), fields); - } else { - throw new IllegalArgumentException("Unsupported type: " + typeStr); - } - } - - /** - * Convert a Java class to an Arrow type. - * - * @param param The Java class. - * @param hint An optional DataTypeHint annotation. - * @param name The name of the field. - * @return The Arrow type. - */ - static Field classToField(Class param, DataTypeHint hint, String name) { - if (hint != null) { - return stringToField(hint.value(), name); - } else if (param == Boolean.class || param == boolean.class) { - return Field.nullable(name, new ArrowType.Bool()); - } else if (param == Short.class || param == short.class) { - return Field.nullable(name, new ArrowType.Int(16, true)); - } else if (param == Integer.class || param == int.class) { - return Field.nullable(name, new ArrowType.Int(32, true)); - } else if (param == Long.class || param == long.class) { - return Field.nullable(name, new ArrowType.Int(64, true)); - } else if (param == Float.class || param == float.class) { - return Field.nullable(name, new ArrowType.FloatingPoint(FloatingPointPrecision.SINGLE)); - } else if (param == Double.class || param == double.class) { - return Field.nullable(name, new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE)); - } else if (param == BigDecimal.class) { - return Field.nullable(name, new ArrowType.LargeBinary()); - } else if (param == LocalDate.class) { - return Field.nullable(name, new ArrowType.Date(DateUnit.DAY)); - } else if (param == LocalTime.class) { - return Field.nullable(name, new ArrowType.Time(TimeUnit.MICROSECOND, 64)); - } else if (param == LocalDateTime.class) { - return Field.nullable(name, new ArrowType.Timestamp(TimeUnit.MICROSECOND, null)); - } else if (param == PeriodDuration.class) { - return Field.nullable(name, new ArrowType.Interval(IntervalUnit.MONTH_DAY_NANO)); - } else if (param == String.class) { - return Field.nullable(name, new ArrowType.Utf8()); - } else if (param == byte[].class) { - return Field.nullable(name, new ArrowType.Binary()); - } else if (param.isArray()) { - var innerField = classToField(param.getComponentType(), null, ""); - return new Field( - name, FieldType.nullable(new ArrowType.List()), Arrays.asList(innerField)); - } else { - // struct type - var fields = new ArrayList(); - for (var field : param.getDeclaredFields()) { - var subhint = field.getAnnotation(DataTypeHint.class); - fields.add(classToField(field.getType(), subhint, field.getName())); - } - return new Field(name, FieldType.nullable(new ArrowType.Struct()), fields); - // TODO: more types - // throw new IllegalArgumentException("Unsupported type: " + param); - } - } - - /** Get the input schema from a Java method. */ - static Schema methodToInputSchema(Method method) { - var fields = new ArrayList(); - for (var param : method.getParameters()) { - var hint = param.getAnnotation(DataTypeHint.class); - fields.add(classToField(param.getType(), hint, param.getName())); - } - return new Schema(fields); - } - - /** Get the output schema of a scalar function from a Java method. 
*/ - static Schema methodToOutputSchema(Method method) { - var type = method.getReturnType(); - var hint = method.getAnnotation(DataTypeHint.class); - return new Schema(Arrays.asList(classToField(type, hint, ""))); - } - - /** Get the output schema of a table function from a Java class. */ - static Schema tableFunctionToOutputSchema(Method method) { - var hint = method.getAnnotation(DataTypeHint.class); - var type = method.getReturnType(); - if (!Iterator.class.isAssignableFrom(type)) { - throw new IllegalArgumentException("Table function must return Iterator"); - } - var typeArguments = - ((ParameterizedType) method.getGenericReturnType()).getActualTypeArguments(); - type = (Class) typeArguments[0]; - var rowIndex = Field.nullable("row_index", new ArrowType.Int(32, true)); - return new Schema(Arrays.asList(rowIndex, classToField(type, hint, ""))); - } - - /** Return functions to process input values from a Java method. */ - static Function[] methodToProcessInputs(Method method) { - var schema = methodToInputSchema(method); - var params = method.getParameters(); - @SuppressWarnings("unchecked") - Function[] funcs = new Function[schema.getFields().size()]; - for (int i = 0; i < schema.getFields().size(); i++) { - funcs[i] = processFunc(schema.getFields().get(i), params[i].getType()); - } - return funcs; - } - - /** Create an Arrow vector from an array of values. */ - static FieldVector createVector(Field field, BufferAllocator allocator, Object[] values) { - var vector = field.createVector(allocator); - fillVector(vector, values); - return vector; - } - - /** Fill an Arrow vector with an array of values. */ - static void fillVector(FieldVector fieldVector, Object[] values) { - if (fieldVector instanceof BitVector) { - var vector = (BitVector) fieldVector; - vector.allocateNew(values.length); - for (int i = 0; i < values.length; i++) { - if (values[i] != null) { - vector.set(i, (boolean) values[i] ? 
1 : 0); - } - } - } else if (fieldVector instanceof SmallIntVector) { - var vector = (SmallIntVector) fieldVector; - vector.allocateNew(values.length); - for (int i = 0; i < values.length; i++) { - if (values[i] != null) { - vector.set(i, (short) values[i]); - } - } - } else if (fieldVector instanceof IntVector) { - var vector = (IntVector) fieldVector; - vector.allocateNew(values.length); - for (int i = 0; i < values.length; i++) { - if (values[i] != null) { - vector.set(i, (int) values[i]); - } - } - } else if (fieldVector instanceof BigIntVector) { - var vector = (BigIntVector) fieldVector; - vector.allocateNew(values.length); - for (int i = 0; i < values.length; i++) { - if (values[i] != null) { - vector.set(i, (long) values[i]); - } - } - } else if (fieldVector instanceof Float4Vector) { - var vector = (Float4Vector) fieldVector; - vector.allocateNew(values.length); - for (int i = 0; i < values.length; i++) { - if (values[i] != null) { - vector.set(i, (float) values[i]); - } - } - } else if (fieldVector instanceof Float8Vector) { - var vector = (Float8Vector) fieldVector; - vector.allocateNew(values.length); - for (int i = 0; i < values.length; i++) { - if (values[i] != null) { - vector.set(i, (double) values[i]); - } - } - } else if (fieldVector instanceof LargeVarBinaryVector) { - var vector = (LargeVarBinaryVector) fieldVector; - vector.allocateNew(values.length); - for (int i = 0; i < values.length; i++) { - if (values[i] != null) { - // use `toPlainString` to avoid scientific notation - vector.set(i, ((BigDecimal) values[i]).toPlainString().getBytes()); - } - } - } else if (fieldVector instanceof DateDayVector) { - var vector = (DateDayVector) fieldVector; - vector.allocateNew(values.length); - for (int i = 0; i < values.length; i++) { - if (values[i] != null) { - vector.set(i, (int) ((LocalDate) values[i]).toEpochDay()); - } - } - } else if (fieldVector instanceof TimeMicroVector) { - var vector = (TimeMicroVector) fieldVector; - vector.allocateNew(values.length); - for (int i = 0; i < values.length; i++) { - if (values[i] != null) { - vector.set(i, ((LocalTime) values[i]).toNanoOfDay() / 1000); - } - } - } else if (fieldVector instanceof TimeStampMicroVector) { - var vector = (TimeStampMicroVector) fieldVector; - vector.allocateNew(values.length); - for (int i = 0; i < values.length; i++) { - if (values[i] != null) { - vector.set(i, timestampToMicros((LocalDateTime) values[i])); - } - } - } else if (fieldVector instanceof IntervalMonthDayNanoVector) { - var vector = (IntervalMonthDayNanoVector) fieldVector; - vector.allocateNew(values.length); - for (int i = 0; i < values.length; i++) { - if (values[i] != null) { - var pd = (PeriodDuration) values[i]; - var months = (int) pd.getPeriod().toTotalMonths(); - var days = pd.getPeriod().getDays(); - var nanos = pd.getDuration().toNanos(); - vector.set(i, months, days, nanos); - } - } - } else if (fieldVector instanceof VarCharVector) { - var vector = (VarCharVector) fieldVector; - int totalBytes = 0; - for (int i = 0; i < values.length; i++) { - if (values[i] != null) { - totalBytes += ((String) values[i]).length(); - } - } - vector.allocateNew(totalBytes, values.length); - for (int i = 0; i < values.length; i++) { - if (values[i] != null) { - vector.set(i, ((String) values[i]).getBytes()); - } - } - } else if (fieldVector instanceof LargeVarCharVector) { - var vector = (LargeVarCharVector) fieldVector; - int totalBytes = 0; - for (int i = 0; i < values.length; i++) { - if (values[i] != null) { - totalBytes += ((String) 
values[i]).length(); - } - } - vector.allocateNew(totalBytes, values.length); - for (int i = 0; i < values.length; i++) { - if (values[i] != null) { - vector.set(i, ((String) values[i]).getBytes()); - } - } - } else if (fieldVector instanceof VarBinaryVector) { - var vector = (VarBinaryVector) fieldVector; - int totalBytes = 0; - for (int i = 0; i < values.length; i++) { - if (values[i] != null) { - totalBytes += ((byte[]) values[i]).length; - } - } - vector.allocateNew(totalBytes, values.length); - for (int i = 0; i < values.length; i++) { - if (values[i] != null) { - vector.set(i, (byte[]) values[i]); - } - } - } else if (fieldVector instanceof ListVector) { - var vector = (ListVector) fieldVector; - vector.allocateNew(); - // flatten the `values` - var flattenLength = 0; - for (int i = 0; i < values.length; i++) { - if (values[i] == null) { - continue; - } - var len = Array.getLength(values[i]); - vector.startNewValue(i); - vector.endValue(i, len); - flattenLength += len; - } - var flattenValues = new Object[flattenLength]; - var ii = 0; - for (var list : values) { - if (list == null) { - continue; - } - var length = Array.getLength(list); - for (int i = 0; i < length; i++) { - flattenValues[ii++] = Array.get(list, i); - } - } - // fill the inner vector - fillVector(vector.getDataVector(), flattenValues); - } else if (fieldVector instanceof StructVector) { - var vector = (StructVector) fieldVector; - vector.allocateNew(); - var lookup = MethodHandles.lookup(); - // get class of the first non-null value - Class valueClass = null; - for (int i = 0; i < values.length; i++) { - if (values[i] != null) { - valueClass = values[i].getClass(); - break; - } - } - for (var field : vector.getField().getChildren()) { - // extract field from values - var subvalues = new Object[values.length]; - if (valueClass != null) { - try { - var javaField = valueClass.getDeclaredField(field.getName()); - var varHandle = lookup.unreflectVarHandle(javaField); - for (int i = 0; i < values.length; i++) { - if (values[i] != null) { - subvalues[i] = varHandle.get(values[i]); - } - } - } catch (NoSuchFieldException | IllegalAccessException e) { - throw new RuntimeException(e); - } - } - var subvector = vector.getChild(field.getName()); - fillVector(subvector, subvalues); - } - for (int i = 0; i < values.length; i++) { - if (values[i] != null) { - vector.setIndexDefined(i); - } - } - } else { - throw new IllegalArgumentException("Unsupported type: " + fieldVector.getClass()); - } - fieldVector.setValueCount(values.length); - } - - static long timestampToMicros(LocalDateTime timestamp) { - var date = timestamp.toLocalDate().toEpochDay(); - var time = timestamp.toLocalTime().toNanoOfDay(); - return date * 24 * 3600 * 1000 * 1000 + time / 1000; - } - - /** Return a function that converts the object get from input array to the correct type. */ - static Function processFunc(Field field, Class targetClass) { - var inner = processFunc0(field, targetClass); - return obj -> obj == null ? 
null : inner.apply(obj); - } - - static Function processFunc0(Field field, Class targetClass) { - if (field.getType() instanceof ArrowType.Utf8 && targetClass == String.class) { - // object is org.apache.arrow.vector.util.Text - return obj -> obj.toString(); - } else if (field.getType() instanceof ArrowType.LargeUtf8 && targetClass == String.class) { - // object is org.apache.arrow.vector.util.Text - return obj -> obj.toString(); - } else if (field.getType() instanceof ArrowType.LargeBinary - && targetClass == BigDecimal.class) { - // object is byte[] - return obj -> new BigDecimal(new String((byte[]) obj)); - } else if (field.getType() instanceof ArrowType.Date && targetClass == LocalDate.class) { - // object is Integer - return obj -> LocalDate.ofEpochDay((int) obj); - } else if (field.getType() instanceof ArrowType.Time && targetClass == LocalTime.class) { - // object is Long - return obj -> LocalTime.ofNanoOfDay((long) obj * 1000); - } else if (field.getType() instanceof ArrowType.Interval - && targetClass == PeriodDuration.class) { - // object is arrow PeriodDuration - return obj -> new PeriodDuration((org.apache.arrow.vector.PeriodDuration) obj); - } else if (field.getType() instanceof ArrowType.List) { - // object is List - var subfield = field.getChildren().get(0); - var subfunc = processFunc(subfield, targetClass.getComponentType()); - if (subfield.getType() instanceof ArrowType.Bool) { - return obj -> ((List) obj).stream().map(subfunc).toArray(Boolean[]::new); - } else if (subfield.getType().equals(new ArrowType.Int(16, true))) { - return obj -> ((List) obj).stream().map(subfunc).toArray(Short[]::new); - } else if (subfield.getType().equals(new ArrowType.Int(32, true))) { - return obj -> ((List) obj).stream().map(subfunc).toArray(Integer[]::new); - } else if (subfield.getType().equals(new ArrowType.Int(64, true))) { - return obj -> ((List) obj).stream().map(subfunc).toArray(Long[]::new); - } else if (subfield.getType() - .equals(new ArrowType.FloatingPoint(FloatingPointPrecision.SINGLE))) { - return obj -> ((List) obj).stream().map(subfunc).toArray(Float[]::new); - } else if (subfield.getType() - .equals(new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE))) { - return obj -> ((List) obj).stream().map(subfunc).toArray(Double[]::new); - } else if (subfield.getType() instanceof ArrowType.LargeBinary) { - return obj -> ((List) obj).stream().map(subfunc).toArray(BigDecimal[]::new); - } else if (subfield.getType() instanceof ArrowType.Date) { - return obj -> ((List) obj).stream().map(subfunc).toArray(LocalDate[]::new); - } else if (subfield.getType() instanceof ArrowType.Time) { - return obj -> ((List) obj).stream().map(subfunc).toArray(LocalTime[]::new); - } else if (subfield.getType() instanceof ArrowType.Timestamp) { - return obj -> ((List) obj).stream().map(subfunc).toArray(LocalDateTime[]::new); - } else if (subfield.getType() instanceof ArrowType.Interval) { - return obj -> ((List) obj).stream().map(subfunc).toArray(PeriodDuration[]::new); - } else if (subfield.getType() instanceof ArrowType.Utf8) { - return obj -> ((List) obj).stream().map(subfunc).toArray(String[]::new); - } else if (subfield.getType() instanceof ArrowType.LargeUtf8) { - return obj -> ((List) obj).stream().map(subfunc).toArray(String[]::new); - } else if (subfield.getType() instanceof ArrowType.Binary) { - return obj -> ((List) obj).stream().map(subfunc).toArray(byte[][]::new); - } else if (subfield.getType() instanceof ArrowType.Struct) { - return obj -> { - var list = (List) obj; - Object array = 
Array.newInstance(targetClass.getComponentType(), list.size()); - for (int i = 0; i < list.size(); i++) { - Array.set(array, i, subfunc.apply(list.get(i))); - } - return array; - }; - } - throw new IllegalArgumentException("Unsupported type: " + subfield.getType()); - } else if (field.getType() instanceof ArrowType.Struct) { - // object is org.apache.arrow.vector.util.JsonStringHashMap - var subfields = field.getChildren(); - @SuppressWarnings("unchecked") - Function[] subfunc = new Function[subfields.size()]; - for (int i = 0; i < subfields.size(); i++) { - subfunc[i] = processFunc(subfields.get(i), targetClass.getFields()[i].getType()); - } - return obj -> { - var map = (AbstractMap) obj; - try { - var row = targetClass.getDeclaredConstructor().newInstance(); - for (int i = 0; i < subfields.size(); i++) { - var field0 = targetClass.getFields()[i]; - var val = subfunc[i].apply(map.get(field0.getName())); - field0.set(row, val); - } - return row; - } catch (InstantiationException - | IllegalAccessException - | InvocationTargetException - | NoSuchMethodException e) { - throw new RuntimeException(e); - } - }; - } - return Function.identity(); - } -} diff --git a/java/udf/src/main/java/com/risingwave/functions/UdfProducer.java b/java/udf/src/main/java/com/risingwave/functions/UdfProducer.java deleted file mode 100644 index 692d898acaf8..000000000000 --- a/java/udf/src/main/java/com/risingwave/functions/UdfProducer.java +++ /dev/null @@ -1,108 +0,0 @@ -// Copyright 2024 RisingWave Labs -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
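-
-// UdfProducer is the Flight request handler behind UdfServer: getFlightInfo
-// reports a registered function's input and output schema, and doExchange
-// streams each input batch through the function's evalBatch.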
- -package com.risingwave.functions; - -import java.util.ArrayList; -import java.util.Collections; -import java.util.HashMap; -import org.apache.arrow.flight.*; -import org.apache.arrow.memory.BufferAllocator; -import org.apache.arrow.vector.VectorLoader; -import org.apache.arrow.vector.VectorSchemaRoot; -import org.apache.arrow.vector.VectorUnloader; -import org.apache.arrow.vector.types.pojo.Field; -import org.apache.arrow.vector.types.pojo.Schema; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -class UdfProducer extends NoOpFlightProducer { - - private BufferAllocator allocator; - private HashMap functions = new HashMap<>(); - private static final Logger logger = LoggerFactory.getLogger(UdfServer.class); - - UdfProducer(BufferAllocator allocator) { - this.allocator = allocator; - } - - void addFunction(String name, UserDefinedFunction function) throws IllegalArgumentException { - UserDefinedFunctionBatch udf; - if (function instanceof ScalarFunction) { - udf = new ScalarFunctionBatch((ScalarFunction) function); - } else if (function instanceof TableFunction) { - udf = new TableFunctionBatch((TableFunction) function); - } else { - throw new IllegalArgumentException( - "Unknown function type: " + function.getClass().getName()); - } - if (functions.containsKey(name)) { - throw new IllegalArgumentException("Function already exists: " + name); - } - functions.put(name, udf); - } - - @Override - public FlightInfo getFlightInfo(CallContext context, FlightDescriptor descriptor) { - try { - var functionName = descriptor.getPath().get(0); - var udf = functions.get(functionName); - if (udf == null) { - throw new IllegalArgumentException("Unknown function: " + functionName); - } - var fields = new ArrayList(); - fields.addAll(udf.getInputSchema().getFields()); - fields.addAll(udf.getOutputSchema().getFields()); - var fullSchema = new Schema(fields); - var inputLen = udf.getInputSchema().getFields().size(); - - return new FlightInfo(fullSchema, descriptor, Collections.emptyList(), 0, inputLen); - } catch (Exception e) { - logger.error("Error occurred during getFlightInfo", e); - throw e; - } - } - - @Override - public void doExchange(CallContext context, FlightStream reader, ServerStreamListener writer) { - try (var allocator = this.allocator.newChildAllocator("exchange", 0, Long.MAX_VALUE)) { - var functionName = reader.getDescriptor().getPath().get(0); - logger.debug("call function: " + functionName); - - var udf = this.functions.get(functionName); - try (var root = VectorSchemaRoot.create(udf.getOutputSchema(), allocator)) { - var loader = new VectorLoader(root); - writer.start(root); - while (reader.next()) { - try (var input = reader.getRoot()) { - var outputBatches = udf.evalBatch(input, allocator); - while (outputBatches.hasNext()) { - try (var outputRoot = outputBatches.next()) { - var unloader = new VectorUnloader(outputRoot); - try (var outputBatch = unloader.getRecordBatch()) { - loader.load(outputBatch); - } - } - writer.putNext(); - } - } - } - writer.completed(); - } - } catch (Exception e) { - logger.error("Error occurred during UDF execution", e); - writer.error(e); - } - } -} diff --git a/java/udf/src/main/java/com/risingwave/functions/UdfServer.java b/java/udf/src/main/java/com/risingwave/functions/UdfServer.java deleted file mode 100644 index 66f2a8d3bb0d..000000000000 --- a/java/udf/src/main/java/com/risingwave/functions/UdfServer.java +++ /dev/null @@ -1,81 +0,0 @@ -// Copyright 2024 RisingWave Labs -// -// Licensed under the Apache License, Version 2.0 (the 
"License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package com.risingwave.functions; - -import java.io.IOException; -import org.apache.arrow.flight.*; -import org.apache.arrow.memory.RootAllocator; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -/** A server that exposes user-defined functions over Apache Arrow Flight. */ -public class UdfServer implements AutoCloseable { - - private FlightServer server; - private UdfProducer producer; - private static final Logger logger = LoggerFactory.getLogger(UdfServer.class); - - public UdfServer(String host, int port) { - var location = Location.forGrpcInsecure(host, port); - var allocator = new RootAllocator(); - this.producer = new UdfProducer(allocator); - this.server = FlightServer.builder(allocator, location, this.producer).build(); - } - - /** - * Add a user-defined function to the server. - * - * @param name the name of the function - * @param udf the function to add - * @throws IllegalArgumentException if a function with the same name already exists - */ - public void addFunction(String name, UserDefinedFunction udf) throws IllegalArgumentException { - logger.info("added function: " + name); - this.producer.addFunction(name, udf); - } - - /** - * Start the server. - * - * @throws IOException if the server fails to start - */ - public void start() throws IOException { - this.server.start(); - logger.info("listening on " + this.server.getLocation().toSocketAddress()); - } - - /** - * Get the port the server is listening on. - * - * @return the port number - */ - public int getPort() { - return this.server.getPort(); - } - - /** - * Wait for the server to terminate. - * - * @throws InterruptedException if the thread is interrupted while waiting - */ - public void awaitTermination() throws InterruptedException { - this.server.awaitTermination(); - } - - /** Close the server. */ - public void close() throws InterruptedException { - this.server.close(); - } -} diff --git a/java/udf/src/main/java/com/risingwave/functions/UserDefinedFunction.java b/java/udf/src/main/java/com/risingwave/functions/UserDefinedFunction.java deleted file mode 100644 index 3db6f1714cd8..000000000000 --- a/java/udf/src/main/java/com/risingwave/functions/UserDefinedFunction.java +++ /dev/null @@ -1,23 +0,0 @@ -// Copyright 2024 RisingWave Labs -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package com.risingwave.functions; - -/** - * Base interface for all user-defined functions. 
- * - * @see ScalarFunction - * @see TableFunction - */ -public interface UserDefinedFunction {} diff --git a/java/udf/src/main/java/com/risingwave/functions/UserDefinedFunctionBatch.java b/java/udf/src/main/java/com/risingwave/functions/UserDefinedFunctionBatch.java deleted file mode 100644 index e2c513a7954a..000000000000 --- a/java/udf/src/main/java/com/risingwave/functions/UserDefinedFunctionBatch.java +++ /dev/null @@ -1,87 +0,0 @@ -// Copyright 2024 RisingWave Labs -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package com.risingwave.functions; - -import java.lang.invoke.MethodHandle; -import java.lang.invoke.MethodHandles; -import java.lang.reflect.Method; -import java.lang.reflect.Modifier; -import java.util.ArrayList; -import java.util.Iterator; -import org.apache.arrow.memory.BufferAllocator; -import org.apache.arrow.vector.VectorSchemaRoot; -import org.apache.arrow.vector.types.pojo.Schema; - -/** Base class for a batch-processing user-defined function. */ -abstract class UserDefinedFunctionBatch { - protected Schema inputSchema; - protected Schema outputSchema; - - /** Get the input schema of the function. */ - Schema getInputSchema() { - return inputSchema; - } - - /** Get the output schema of the function. */ - Schema getOutputSchema() { - return outputSchema; - } - - /** - * Evaluate the function by processing a batch of input data. - * - * @param batch the input data batch to process - * @param allocator the allocator to use for allocating output data - * @return an iterator over the output data batches - */ - abstract Iterator evalBatch( - VectorSchemaRoot batch, BufferAllocator allocator); -} - -/** Utility class for reflection. */ -class Reflection { - /** Get the method named eval. */ - static Method getEvalMethod(UserDefinedFunction obj) { - var methods = new ArrayList(); - for (Method method : obj.getClass().getDeclaredMethods()) { - if (method.getName().equals("eval")) { - methods.add(method); - } - } - if (methods.size() != 1) { - throw new IllegalArgumentException( - "Exactly one eval method must be defined for class " - + obj.getClass().getName()); - } - var method = methods.get(0); - if (Modifier.isStatic(method.getModifiers())) { - throw new IllegalArgumentException( - "The eval method should not be static for class " + obj.getClass().getName()); - } - return method; - } - - /** Get the method handle of the given method. 
*/ - static MethodHandle getMethodHandle(Method method) { - var lookup = MethodHandles.lookup(); - try { - return lookup.unreflect(method); - } catch (IllegalAccessException e) { - throw new IllegalArgumentException( - "The eval method must be public for class " - + method.getDeclaringClass().getName()); - } - } -} diff --git a/java/udf/src/test/java/com/risingwave/functions/TestUdfServer.java b/java/udf/src/test/java/com/risingwave/functions/TestUdfServer.java deleted file mode 100644 index 5722efa1dd70..000000000000 --- a/java/udf/src/test/java/com/risingwave/functions/TestUdfServer.java +++ /dev/null @@ -1,286 +0,0 @@ -// Copyright 2024 RisingWave Labs -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package com.risingwave.functions; - -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertTrue; - -import java.io.IOException; -import java.math.BigDecimal; -import java.time.LocalDate; -import java.time.LocalDateTime; -import java.time.LocalTime; -import java.util.Iterator; -import java.util.stream.IntStream; -import org.apache.arrow.memory.BufferAllocator; -import org.apache.arrow.memory.RootAllocator; -import org.apache.arrow.vector.*; -import org.apache.arrow.vector.complex.StructVector; -import org.apache.arrow.vector.types.Types.MinorType; -import org.apache.arrow.vector.types.pojo.ArrowType; -import org.apache.arrow.vector.types.pojo.FieldType; -import org.junit.jupiter.api.AfterAll; -import org.junit.jupiter.api.BeforeAll; -import org.junit.jupiter.api.Test; - -/** Unit test for UDF server. 
*/ -public class TestUdfServer { - private static UdfClient client; - private static UdfServer server; - private static BufferAllocator allocator = new RootAllocator(); - - @BeforeAll - public static void setup() throws IOException { - server = new UdfServer("localhost", 0); - server.addFunction("gcd", new Gcd()); - server.addFunction("return_all", new ReturnAll()); - server.addFunction("series", new Series()); - server.start(); - - client = new UdfClient("localhost", server.getPort()); - } - - @AfterAll - public static void teardown() throws InterruptedException { - client.close(); - server.close(); - } - - public static class Gcd implements ScalarFunction { - public int eval(int a, int b) { - while (b != 0) { - int temp = b; - b = a % b; - a = temp; - } - return a; - } - } - - @Test - public void gcd() throws Exception { - var c0 = new IntVector("", allocator); - c0.allocateNew(1); - c0.set(0, 15); - c0.setValueCount(1); - - var c1 = new IntVector("", allocator); - c1.allocateNew(1); - c1.set(0, 12); - c1.setValueCount(1); - - var input = VectorSchemaRoot.of(c0, c1); - - try (var stream = client.call("gcd", input)) { - var output = stream.getRoot(); - assertTrue(stream.next()); - assertEquals("3", output.contentToTSVString().trim()); - } - } - - public static class ReturnAll implements ScalarFunction { - public static class Row { - public Boolean bool; - public Short i16; - public Integer i32; - public Long i64; - public Float f32; - public Double f64; - public BigDecimal decimal; - public LocalDate date; - public LocalTime time; - public LocalDateTime timestamp; - public PeriodDuration interval; - public String str; - public byte[] bytes; - public @DataTypeHint("JSONB") String jsonb; - public Struct struct; - } - - public static class Struct { - public Integer f1; - public Integer f2; - - public String toString() { - return String.format("(%d, %d)", f1, f2); - } - } - - public Row eval( - Boolean bool, - Short i16, - Integer i32, - Long i64, - Float f32, - Double f64, - BigDecimal decimal, - LocalDate date, - LocalTime time, - LocalDateTime timestamp, - PeriodDuration interval, - String str, - byte[] bytes, - @DataTypeHint("JSONB") String jsonb, - Struct struct) { - var row = new Row(); - row.bool = bool; - row.i16 = i16; - row.i32 = i32; - row.i64 = i64; - row.f32 = f32; - row.f64 = f64; - row.decimal = decimal; - row.date = date; - row.time = time; - row.timestamp = timestamp; - row.interval = interval; - row.str = str; - row.bytes = bytes; - row.jsonb = jsonb; - row.struct = struct; - return row; - } - } - - @Test - public void all_types() throws Exception { - var c0 = new BitVector("", allocator); - c0.allocateNew(2); - c0.set(0, 1); - c0.setValueCount(2); - - var c1 = new SmallIntVector("", allocator); - c1.allocateNew(2); - c1.set(0, 1); - c1.setValueCount(2); - - var c2 = new IntVector("", allocator); - c2.allocateNew(2); - c2.set(0, 1); - c2.setValueCount(2); - - var c3 = new BigIntVector("", allocator); - c3.allocateNew(2); - c3.set(0, 1); - c3.setValueCount(2); - - var c4 = new Float4Vector("", allocator); - c4.allocateNew(2); - c4.set(0, 1); - c4.setValueCount(2); - - var c5 = new Float8Vector("", allocator); - c5.allocateNew(2); - c5.set(0, 1); - c5.setValueCount(2); - - var c6 = new LargeVarBinaryVector("", allocator); - c6.allocateNew(2); - c6.set(0, "1.234".getBytes()); - c6.setValueCount(2); - - var c7 = new DateDayVector("", allocator); - c7.allocateNew(2); - c7.set(0, (int) LocalDate.of(2023, 1, 1).toEpochDay()); - c7.setValueCount(2); - - var c8 = new 
TimeMicroVector("", allocator); - c8.allocateNew(2); - c8.set(0, LocalTime.of(1, 2, 3).toNanoOfDay() / 1000); - c8.setValueCount(2); - - var c9 = new TimeStampMicroVector("", allocator); - c9.allocateNew(2); - var ts = LocalDateTime.of(2023, 1, 1, 1, 2, 3); - c9.set( - 0, - ts.toLocalDate().toEpochDay() * 24 * 3600 * 1000000 - + ts.toLocalTime().toNanoOfDay() / 1000); - c9.setValueCount(2); - - var c10 = - new IntervalMonthDayNanoVector( - "", - FieldType.nullable(MinorType.INTERVALMONTHDAYNANO.getType()), - allocator); - c10.allocateNew(2); - c10.set(0, 1000, 2000, 3000); - c10.setValueCount(2); - - var c11 = new VarCharVector("", allocator); - c11.allocateNew(2); - c11.set(0, "string".getBytes()); - c11.setValueCount(2); - - var c12 = new VarBinaryVector("", allocator); - c12.allocateNew(2); - c12.set(0, "bytes".getBytes()); - c12.setValueCount(2); - - var c13 = new LargeVarCharVector("", allocator); - c13.allocateNew(2); - c13.set(0, "{ key: 1 }".getBytes()); - c13.setValueCount(2); - - var c14 = - new StructVector( - "", allocator, FieldType.nullable(ArrowType.Struct.INSTANCE), null); - c14.allocateNew(); - var f1 = c14.addOrGet("f1", FieldType.nullable(MinorType.INT.getType()), IntVector.class); - var f2 = c14.addOrGet("f2", FieldType.nullable(MinorType.INT.getType()), IntVector.class); - f1.allocateNew(2); - f2.allocateNew(2); - f1.set(0, 1); - f2.set(0, 2); - c14.setIndexDefined(0); - c14.setValueCount(2); - - var input = - VectorSchemaRoot.of( - c0, c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12, c13, c14); - - try (var stream = client.call("return_all", input)) { - var output = stream.getRoot(); - assertTrue(stream.next()); - assertEquals( - "{\"bool\":true,\"i16\":1,\"i32\":1,\"i64\":1,\"f32\":1.0,\"f64\":1.0,\"decimal\":\"MS4yMzQ=\",\"date\":19358,\"time\":3723000000,\"timestamp\":[2023,1,1,1,2,3],\"interval\":{\"period\":\"P1000M2000D\",\"duration\":0.000003000},\"str\":\"string\",\"bytes\":\"Ynl0ZXM=\",\"jsonb\":\"{ key: 1 }\",\"struct\":{\"f1\":1,\"f2\":2}}\n{}", - output.contentToTSVString().trim()); - } - } - - public static class Series implements TableFunction { - public Iterator eval(int n) { - return IntStream.range(0, n).iterator(); - } - } - - @Test - public void series() throws Exception { - var c0 = new IntVector("", allocator); - c0.allocateNew(3); - c0.set(0, 0); - c0.set(1, 1); - c0.set(2, 2); - c0.setValueCount(3); - - var input = VectorSchemaRoot.of(c0); - - try (var stream = client.call("series", input)) { - var output = stream.getRoot(); - assertTrue(stream.next()); - assertEquals("row_index\t\n1\t0\n2\t0\n2\t1\n", output.contentToTSVString()); - } - } -} diff --git a/java/udf/src/test/java/com/risingwave/functions/UdfClient.java b/java/udf/src/test/java/com/risingwave/functions/UdfClient.java deleted file mode 100644 index 12728bf64fbe..000000000000 --- a/java/udf/src/test/java/com/risingwave/functions/UdfClient.java +++ /dev/null @@ -1,51 +0,0 @@ -// Copyright 2024 RisingWave Labs -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
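-
-// Minimal Flight client used by the tests above: call() opens a doExchange
-// stream, writes a single input batch, and returns the reader over the
-// function's output batches.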
- -package com.risingwave.functions; - -import org.apache.arrow.flight.*; -import org.apache.arrow.memory.RootAllocator; -import org.apache.arrow.vector.VectorSchemaRoot; - -public class UdfClient implements AutoCloseable { - - private FlightClient client; - - public UdfClient(String host, int port) { - var allocator = new RootAllocator(); - var location = Location.forGrpcInsecure(host, port); - this.client = FlightClient.builder(allocator, location).build(); - } - - public void close() throws InterruptedException { - this.client.close(); - } - - public FlightInfo getFlightInfo(String functionName) { - var descriptor = FlightDescriptor.command(functionName.getBytes()); - return client.getInfo(descriptor); - } - - public FlightStream call(String functionName, VectorSchemaRoot root) { - var descriptor = FlightDescriptor.path(functionName); - var readerWriter = client.doExchange(descriptor); - var writer = readerWriter.getWriter(); - var reader = readerWriter.getReader(); - - writer.start(root); - writer.putNext(); - writer.completed(); - return reader; - } -} diff --git a/src/common/src/array/arrow/arrow_udf.rs b/src/common/src/array/arrow/arrow_udf.rs index e2f9e39ad385..5a44ef143961 100644 --- a/src/common/src/array/arrow/arrow_udf.rs +++ b/src/common/src/array/arrow/arrow_udf.rs @@ -29,68 +29,99 @@ use crate::array::{ArrayError, ArrayImpl, DataType, DecimalArray, JsonbArray}; #[path = "./arrow_impl.rs"] mod arrow_impl; -/// Arrow conversion for the current version of UDF. This is in use but will be deprecated soon. -/// -/// In the current version of UDF protocol, decimal and jsonb types are mapped to Arrow `LargeBinary` and `LargeUtf8` types. -pub struct UdfArrowConvert; +/// Arrow conversion for UDF. +#[derive(Default, Debug)] +pub struct UdfArrowConvert { + /// Whether the UDF talks in legacy mode. + /// + /// If true, decimal and jsonb types are mapped to Arrow `LargeBinary` and `LargeUtf8` types. + /// Otherwise, they are mapped to Arrow extension types. + /// See . + pub legacy: bool, +} impl ToArrow for UdfArrowConvert { - // Decimal values are stored as ASCII text representation in a large binary array. fn decimal_to_arrow( &self, _data_type: &arrow_schema::DataType, array: &DecimalArray, ) -> Result { - Ok(Arc::new(arrow_array::LargeBinaryArray::from(array))) + if self.legacy { + // Decimal values are stored as ASCII text representation in a large binary array. + Ok(Arc::new(arrow_array::LargeBinaryArray::from(array))) + } else { + Ok(Arc::new(arrow_array::StringArray::from(array))) + } } - // JSON values are stored as text representation in a large string array. fn jsonb_to_arrow(&self, array: &JsonbArray) -> Result { - Ok(Arc::new(arrow_array::LargeStringArray::from(array))) + if self.legacy { + // JSON values are stored as text representation in a large string array. 
+ Ok(Arc::new(arrow_array::LargeStringArray::from(array))) + } else { + Ok(Arc::new(arrow_array::StringArray::from(array))) + } } fn jsonb_type_to_arrow(&self, name: &str) -> arrow_schema::Field { - arrow_schema::Field::new(name, arrow_schema::DataType::LargeUtf8, true) + if self.legacy { + arrow_schema::Field::new(name, arrow_schema::DataType::LargeUtf8, true) + } else { + arrow_schema::Field::new(name, arrow_schema::DataType::Utf8, true) + .with_metadata([("ARROW:extension:name".into(), "arrowudf.json".into())].into()) + } } fn decimal_type_to_arrow(&self, name: &str) -> arrow_schema::Field { - arrow_schema::Field::new(name, arrow_schema::DataType::LargeBinary, true) + if self.legacy { + arrow_schema::Field::new(name, arrow_schema::DataType::LargeBinary, true) + } else { + arrow_schema::Field::new(name, arrow_schema::DataType::Utf8, true) + .with_metadata([("ARROW:extension:name".into(), "arrowudf.decimal".into())].into()) + } } } impl FromArrow for UdfArrowConvert { fn from_large_utf8(&self) -> Result { - Ok(DataType::Jsonb) + if self.legacy { + Ok(DataType::Jsonb) + } else { + Ok(DataType::Varchar) + } } fn from_large_binary(&self) -> Result { - Ok(DataType::Decimal) + if self.legacy { + Ok(DataType::Decimal) + } else { + Ok(DataType::Bytea) + } } fn from_large_utf8_array( &self, array: &arrow_array::LargeStringArray, ) -> Result { - Ok(ArrayImpl::Jsonb(array.try_into()?)) + if self.legacy { + Ok(ArrayImpl::Jsonb(array.try_into()?)) + } else { + Ok(ArrayImpl::Utf8(array.into())) + } } fn from_large_binary_array( &self, array: &arrow_array::LargeBinaryArray, ) -> Result { - Ok(ArrayImpl::Decimal(array.try_into()?)) + if self.legacy { + Ok(ArrayImpl::Decimal(array.try_into()?)) + } else { + Ok(ArrayImpl::Bytea(array.into())) + } } } -/// Arrow conversion for the next version of UDF. This is unused for now. -/// -/// In the next version of UDF protocol, decimal and jsonb types will be mapped to Arrow extension types. -/// See . -pub struct NewUdfArrowConvert; - -impl ToArrow for NewUdfArrowConvert {} -impl FromArrow for NewUdfArrowConvert {} - #[cfg(test)] mod tests { use std::sync::Arc; @@ -104,7 +135,7 @@ mod tests { // Empty array - risingwave to arrow conversion. let test_arr = StructArray::new(StructType::empty(), vec![], Bitmap::ones(0)); assert_eq!( - UdfArrowConvert + UdfArrowConvert::default() .struct_to_arrow( &arrow_schema::DataType::Struct(arrow_schema::Fields::empty()), &test_arr @@ -117,7 +148,7 @@ mod tests { // Empty array - arrow to risingwave conversion. 
let test_arr_2 = arrow_array::StructArray::from(vec![]); assert_eq!( - UdfArrowConvert + UdfArrowConvert::default() .from_struct_array(&test_arr_2) .unwrap() .len(), @@ -146,7 +177,7 @@ mod tests { ), ]) .unwrap(); - let actual_risingwave_struct_array = UdfArrowConvert + let actual_risingwave_struct_array = UdfArrowConvert::default() .from_struct_array(&test_arrow_struct_array) .unwrap() .into_struct(); @@ -168,8 +199,10 @@ mod tests { fn list() { let array = ListArray::from_iter([None, Some(vec![0, -127, 127, 50]), Some(vec![0; 0])]); let data_type = arrow_schema::DataType::new_list(arrow_schema::DataType::Int32, true); - let arrow = UdfArrowConvert.list_to_arrow(&data_type, &array).unwrap(); - let rw_array = UdfArrowConvert + let arrow = UdfArrowConvert::default() + .list_to_arrow(&data_type, &array) + .unwrap(); + let rw_array = UdfArrowConvert::default() .from_list_array(arrow.as_any().downcast_ref().unwrap()) .unwrap(); assert_eq!(rw_array.as_list(), &array); diff --git a/src/common/src/array/data_chunk.rs b/src/common/src/array/data_chunk.rs index cdb012a3185c..6d5b2247979c 100644 --- a/src/common/src/array/data_chunk.rs +++ b/src/common/src/array/data_chunk.rs @@ -249,6 +249,7 @@ impl DataChunk { Self::new(columns, Bitmap::ones(cardinality)) } + /// Scatter a compacted chunk to a new chunk with the given visibility. pub fn uncompact(self, vis: Bitmap) -> Self { let mut uncompact_builders: Vec<_> = self .columns diff --git a/src/expr/core/Cargo.toml b/src/expr/core/Cargo.toml index 3f5ca590026d..c811d81b3465 100644 --- a/src/expr/core/Cargo.toml +++ b/src/expr/core/Cargo.toml @@ -22,7 +22,9 @@ embedded-python-udf = ["arrow-udf-python"] [dependencies] anyhow = "1" arrow-array = { workspace = true } +arrow-flight = "50" arrow-schema = { workspace = true } +arrow-udf-flight = { workspace = true } arrow-udf-js = { workspace = true } arrow-udf-js-deno = { workspace = true, optional = true } arrow-udf-python = { workspace = true, optional = true } @@ -44,6 +46,7 @@ enum-as-inner = "0.6" futures = "0.3" futures-async-stream = { workspace = true } futures-util = "0.3" +ginepro = "0.7" itertools = { workspace = true } linkme = { version = "0.3", features = ["used_linker"] } md5 = "0.7" @@ -52,11 +55,11 @@ num-traits = "0.2" openssl = { version = "0.10", features = ["vendored"] } parse-display = "0.9" paste = "1" +prometheus = "0.13" risingwave_common = { workspace = true } risingwave_common_estimate_size = { workspace = true } risingwave_expr_macro = { path = "../macro" } risingwave_pb = { workspace = true } -risingwave_udf = { workspace = true } smallvec = "1" static_assertions = "1" thiserror = "1" @@ -65,6 +68,7 @@ tokio = { version = "0.2", package = "madsim-tokio", features = [ "rt-multi-thread", "macros", ] } +tonic = "0.10" tracing = "0.1" zstd = { version = "0.13", default-features = false } diff --git a/src/expr/core/src/error.rs b/src/expr/core/src/error.rs index 6688824093d2..08562b3a973b 100644 --- a/src/expr/core/src/error.rs +++ b/src/expr/core/src/error.rs @@ -99,7 +99,7 @@ pub enum ExprError { Udf( #[from] #[backtrace] - risingwave_udf::Error, + Box, ), #[error("not a constant")] @@ -119,6 +119,10 @@ pub enum ExprError { #[error("error in cryptography: {0}")] Cryptography(Box), + + /// Function error message returned by UDF. + #[error("{0}")] + Custom(String), } #[derive(Debug)] @@ -152,6 +156,12 @@ impl From for ExprError { } } +impl From for ExprError { + fn from(err: arrow_udf_flight::Error) -> Self { + Self::Udf(Box::new(err)) + } +} + /// A collection of multiple errors. 
#[derive(Error, Debug)] pub struct MultiExprError(Box<[ExprError]>); @@ -178,6 +188,12 @@ impl From<Vec<ExprError>> for MultiExprError { } } +impl FromIterator<ExprError> for MultiExprError { + fn from_iter<T: IntoIterator<Item = ExprError>>(iter: T) -> Self { + Self(iter.into_iter().collect()) + } +} + impl IntoIterator for MultiExprError { type IntoIter = std::vec::IntoIter<ExprError>; type Item = ExprError; diff --git a/src/expr/core/src/expr/expr_udf.rs b/src/expr/core/src/expr/expr_udf.rs index b9103b62649e..54d3006dc303 100644 --- a/src/expr/core/src/expr/expr_udf.rs +++ b/src/expr/core/src/expr/expr_udf.rs @@ -18,7 +18,9 @@ use std::sync::{Arc, LazyLock, Weak}; use std::time::Duration; use anyhow::{Context, Error}; +use arrow_array::RecordBatch; use arrow_schema::{Fields, Schema, SchemaRef}; +use arrow_udf_flight::Client as FlightClient; use arrow_udf_js::{CallMode as JsCallMode, Runtime as JsRuntime}; #[cfg(feature = "embedded-deno-udf")] use arrow_udf_js_deno::{CallMode as DenoCallMode, Runtime as DenoRuntime}; @@ -27,19 +29,25 @@ use arrow_udf_python::{CallMode as PythonCallMode, Runtime as PythonRuntime}; use arrow_udf_wasm::Runtime as WasmRuntime; use await_tree::InstrumentAwait; use cfg_or_panic::cfg_or_panic; +use ginepro::{LoadBalancedChannel, ResolutionStrategy}; use moka::sync::Cache; +use prometheus::{ + exponential_buckets, register_histogram_vec_with_registry, + register_int_counter_vec_with_registry, HistogramVec, IntCounter, IntCounterVec, Registry, +}; use risingwave_common::array::arrow::{FromArrow, ToArrow, UdfArrowConvert}; -use risingwave_common::array::{ArrayRef, DataChunk}; +use risingwave_common::array::{Array, ArrayRef, DataChunk}; +use risingwave_common::monitor::GLOBAL_METRICS_REGISTRY; use risingwave_common::row::OwnedRow; use risingwave_common::types::{DataType, Datum}; +use risingwave_common::util::addr::HostAddr; use risingwave_expr::expr_context::FRAGMENT_ID; use risingwave_pb::expr::ExprNode; -use risingwave_udf::metrics::GLOBAL_METRICS; -use risingwave_udf::ArrowFlightUdfClient; +use thiserror_ext::AsReport; use super::{BoxedExpression, Build}; use crate::expr::Expression; -use crate::{bail, Result}; +use crate::{bail, ExprError, Result}; #[derive(Debug)] pub struct UserDefinedFunction { @@ -49,6 +57,8 @@ pub struct UserDefinedFunction { arg_schema: SchemaRef, imp: UdfImpl, identifier: String, + link: Option<String>, + arrow_convert: UdfArrowConvert, span: await_tree::Span, /// Number of remaining successful calls until retry is enabled. /// This parameter is designed to prevent continuous retry on every call, which would increase delay.
@@ -68,7 +78,7 @@ const INITIAL_RETRY_COUNT: u8 = 16; #[derive(Debug)] pub enum UdfImpl { - External(Arc<ArrowFlightUdfClient>), + External(Arc<FlightClient>), Wasm(Arc<WasmRuntime>), JavaScript(JsRuntime), #[cfg(feature = "embedded-python-udf")] @@ -115,7 +125,9 @@ impl Expression for UserDefinedFunction { impl UserDefinedFunction { async fn eval_inner(&self, input: &DataChunk) -> Result<ArrayRef> { // this will drop invisible rows - let arrow_input = UdfArrowConvert.to_record_batch(self.arg_schema.clone(), input)?; + let arrow_input = self + .arrow_convert + .to_record_batch(self.arg_schema.clone(), input)?; // metrics let metrics = &*GLOBAL_METRICS; @@ -123,10 +135,6 @@ impl UserDefinedFunction { let fragment_id = FRAGMENT_ID::try_with(ToOwned::to_owned) .unwrap_or(0) .to_string(); - let addr = match &self.imp { - UdfImpl::External(client) => client.get_addr(), - _ => "", - }; let language = match &self.imp { UdfImpl::Wasm(_) => "wasm", UdfImpl::JavaScript(_) => "javascript(quickjs)", @@ -136,7 +144,12 @@ impl UserDefinedFunction { UdfImpl::Deno(_) => "javascript(deno)", UdfImpl::External(_) => "external", }; - let labels: &[&str; 4] = &[addr, language, &self.identifier, fragment_id.as_str()]; + let labels: &[&str; 4] = &[ + self.link.as_deref().unwrap_or(""), + language, + &self.identifier, + fragment_id.as_str(), + ]; metrics .udf_input_chunk_rows .with_label_values(labels) @@ -164,28 +177,27 @@ impl UserDefinedFunction { UdfImpl::External(client) => { let disable_retry_count = self.disable_retry_count.load(Ordering::Relaxed); let result = if self.always_retry_on_network_error { - client - .call_with_always_retry_on_network_error( - &self.identifier, - arrow_input, - &fragment_id, - ) - .instrument_await(self.span.clone()) - .await + call_with_always_retry_on_network_error( + client, + &self.identifier, + &arrow_input, + &metrics.udf_retry_count.with_label_values(labels), + ) + .instrument_await(self.span.clone()) + .await } else { let result = if disable_retry_count != 0 { client - .call(&self.identifier, arrow_input) + .call(&self.identifier, &arrow_input) .instrument_await(self.span.clone()) .await } else { - client - .call_with_retry(&self.identifier, arrow_input) + call_with_retry(client, &self.identifier, &arrow_input) .instrument_await(self.span.clone()) .await }; let disable_retry_count = self.disable_retry_count.load(Ordering::Relaxed); - let connection_error = matches!(&result, Err(e) if e.is_connection_error()); + let connection_error = matches!(&result, Err(e) if is_connection_error(e)); if connection_error && disable_retry_count != INITIAL_RETRY_COUNT { // reset count on connection error self.disable_retry_count @@ -223,7 +235,7 @@ impl UserDefinedFunction { ); } - let output = UdfArrowConvert.from_record_batch(&arrow_output)?; + let output = self.arrow_convert.from_record_batch(&arrow_output)?; let output = output.uncompact(input.visibility().clone()); let Some(array) = output.columns().first() else { @@ -237,10 +249,72 @@ impl UserDefinedFunction { ); } + // handle optional error column + if let Some(errors) = output.columns().get(1) { + if errors.data_type() != DataType::Varchar { + bail!( + "UDF returned errors column with invalid type: {:?}", + errors.data_type() + ); + } + let errors = errors + .as_utf8() + .iter() + .filter_map(|msg| msg.map(|s| ExprError::Custom(s.into()))) + .collect(); + return Err(crate::ExprError::Multiple(array.clone(), errors)); + } + Ok(array.clone()) } }
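Under the new protocol a UDF server may attach an optional second Utf8 column carrying one error message per row; a null message means the row succeeded. A standalone `arrow-array` sketch of consuming such a batch; the two-column layout is hypothetical but mirrors the check above:

```rust
use std::sync::Arc;

use arrow_array::{Array, ArrayRef, Int32Array, RecordBatch, StringArray};
use arrow_schema::{DataType, Field, Schema};

fn main() {
    // Hypothetical UDF output: column 0 = results, column 1 = per-row error messages.
    let schema = Arc::new(Schema::new(vec![
        Field::new("output", DataType::Int32, true),
        Field::new("error", DataType::Utf8, true),
    ]));
    let batch = RecordBatch::try_new(
        schema,
        vec![
            Arc::new(Int32Array::from(vec![Some(2), None, Some(6)])) as ArrayRef,
            Arc::new(StringArray::from(vec![None, Some("division by zero"), None])) as ArrayRef,
        ],
    )
    .unwrap();

    // A null message means success; a non-null message means that row failed.
    let errors: Vec<String> = batch
        .column(1)
        .as_any()
        .downcast_ref::<StringArray>()
        .unwrap()
        .iter()
        .flatten()
        .map(|s| s.to_owned())
        .collect();
    assert_eq!(errors, ["division by zero"]);
}
```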
+/// Call a function, retry up to 5 times / 3s if connection is broken. +async fn call_with_retry( + client: &FlightClient, + id: &str, + input: &RecordBatch, +) -> Result<RecordBatch, arrow_udf_flight::Error> { + let mut backoff = Duration::from_millis(100); + for i in 0..5 { + match client.call(id, input).await { + Err(err) if is_connection_error(&err) && i != 4 => { + tracing::error!(error = %err.as_report(), "UDF connection error. retry..."); + } + ret => return ret, + } + tokio::time::sleep(backoff).await; + backoff *= 2; + } + unreachable!() +} + +/// Always retry on connection error +async fn call_with_always_retry_on_network_error( + client: &FlightClient, + id: &str, + input: &RecordBatch, + retry_count: &IntCounter, +) -> Result<RecordBatch, arrow_udf_flight::Error> { + let mut backoff = Duration::from_millis(100); + loop { + match client.call(id, input).await { + Err(err) if is_tonic_error(&err) => { + tracing::error!(error = %err.as_report(), "UDF tonic error. retry..."); + } + ret => { + if ret.is_err() { + tracing::error!(error = %ret.as_ref().unwrap_err().as_report(), "UDF error. exiting..."); + } + return ret; + } + } + retry_count.inc(); + tokio::time::sleep(backoff).await; + backoff *= 2; + } +} + impl Build for UserDefinedFunction { fn build( prost: &ExprNode, @@ -248,11 +322,7 @@ ) -> Result<Self> { let return_type = DataType::from(prost.get_return_type().unwrap()); let udf = prost.get_rex_node().unwrap().as_udf().unwrap(); - - let arrow_return_type = UdfArrowConvert - .to_arrow_field("", &return_type)? - .data_type() - .clone(); + let mut arrow_convert = UdfArrowConvert::default(); #[cfg(not(feature = "embedded-deno-udf"))] let runtime = "quickjs"; @@ -271,6 +341,11 @@ let wasm_binary = zstd::stream::decode_all(compressed_wasm_binary.as_slice()) .context("failed to decompress wasm binary")?; let runtime = get_or_create_wasm_runtime(&wasm_binary)?; + // backward compatibility + // see for details + if runtime.abi_version().0 <= 2 { + arrow_convert = UdfArrowConvert { legacy: true }; + } UdfImpl::Wasm(runtime) } "javascript" if runtime != "deno" => { @@ -283,7 +358,7 @@ ); rt.add_function( identifier, - arrow_return_type, + arrow_convert.to_arrow_field("", &return_type)?, JsCallMode::CalledOnNullInput, &body, )?; @@ -324,7 +399,7 @@ futures::executor::block_on(rt.add_function( identifier, - arrow_return_type, + arrow_convert.to_arrow_field("", &return_type)?, DenoCallMode::CalledOnNullInput, &body, ))?; @@ -337,7 +412,7 @@ let body = udf.get_body()?; rt.add_function( identifier, - arrow_return_type, + arrow_convert.to_arrow_field("", &return_type)?, PythonCallMode::CalledOnNullInput, body, )?; @@ -346,7 +421,13 @@ #[cfg(not(madsim))] _ => { let link = udf.get_link()?; - UdfImpl::External(get_or_create_flight_client(link)?)
+ let client = get_or_create_flight_client(link)?; + // backward compatibility + // see for details + if client.protocol_version() == 1 { + arrow_convert = UdfArrowConvert { legacy: true }; + } + UdfImpl::External(client) } #[cfg(madsim)] l => panic!("UDF language {l:?} is not supported on madsim"), @@ -355,7 +436,7 @@ let arg_schema = Arc::new(Schema::new( udf.arg_types .iter() - .map(|t| UdfArrowConvert.to_arrow_field("", &DataType::from(t))) + .map(|t| arrow_convert.to_arrow_field("", &DataType::from(t))) .try_collect::<Fields>()?, )); @@ -366,6 +447,8 @@ arg_schema, imp, identifier: identifier.clone(), + link: udf.link.clone(), + arrow_convert, span: format!("udf_call({})", identifier).into(), disable_retry_count: AtomicU8::new(0), always_retry_on_network_error: udf.always_retry_on_network_error, @@ -373,12 +456,12 @@ } } -#[cfg(not(madsim))] +#[cfg_or_panic(not(madsim))] /// Get or create a client for the given UDF service. /// /// There is a global cache for clients, so that we can reuse the same client for the same service. -pub(crate) fn get_or_create_flight_client(link: &str) -> Result<Arc<ArrowFlightUdfClient>> { - static CLIENTS: LazyLock<Mutex<HashMap<String, Weak<ArrowFlightUdfClient>>>> = +pub fn get_or_create_flight_client(link: &str) -> Result<Arc<FlightClient>> { + static CLIENTS: LazyLock<Mutex<HashMap<String, Weak<FlightClient>>>> = LazyLock::new(Default::default); let mut clients = CLIENTS.lock().unwrap(); if let Some(client) = clients.get(link).and_then(|c| c.upgrade()) { @@ -386,12 +469,58 @@ pub(crate) fn get_or_create_flight_client(link: &str) -> Result<Arc<ArrowFlightUdfClient>> { + }) + })?); + clients.insert(link.to_owned(), Arc::downgrade(&client)); Ok(client) } } +/// Connect to a UDF service and return a tonic `Channel`. +async fn connect_tonic(mut addr: &str) -> Result<Channel> { + // Interval between two successive probes of the UDF DNS. + const DNS_PROBE_INTERVAL_SECS: u64 = 5; + // Timeout duration for performing an eager DNS resolution. + const EAGER_DNS_RESOLVE_TIMEOUT_SECS: u64 = 5; + const REQUEST_TIMEOUT_SECS: u64 = 5; + const CONNECT_TIMEOUT_SECS: u64 = 5; + + if let Some(s) = addr.strip_prefix("http://") { + addr = s; + } + if let Some(s) = addr.strip_prefix("https://") { + addr = s; + } + let host_addr = addr.parse::<HostAddr>().map_err(|e| { + arrow_udf_flight::Error::Service(format!( + "invalid address: {}, err: {}", + addr, + e.as_report() + )) + })?; + let channel = LoadBalancedChannel::builder((host_addr.host.clone(), host_addr.port)) + .dns_probe_interval(std::time::Duration::from_secs(DNS_PROBE_INTERVAL_SECS)) + .timeout(Duration::from_secs(REQUEST_TIMEOUT_SECS)) + .connect_timeout(Duration::from_secs(CONNECT_TIMEOUT_SECS)) + .resolution_strategy(ResolutionStrategy::Eager { + timeout: tokio::time::Duration::from_secs(EAGER_DNS_RESOLVE_TIMEOUT_SECS), + }) + .channel() + .await + .map_err(|e| { + arrow_udf_flight::Error::Service(format!( + "failed to create LoadBalancedChannel, address: {}, err: {}", + host_addr, + e.as_report() + )) + })?; + Ok(channel.into()) +} + /// Get or create a wasm runtime. /// /// Runtimes returned by this function are cached inside for at least 60 seconds. @@ -413,3 +542,109 @@ pub fn get_or_create_wasm_runtime(binary: &[u8]) -> Result<Arc<WasmRuntime>> { RUNTIMES.insert(md5, runtime.clone()); Ok(runtime) } +
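`connect_tonic` reproduces what the deleted `risingwave_udf` client did: resolve the address eagerly, keep probing DNS, and hand a load-balanced channel to the Flight client. The same `ginepro` builder can be exercised standalone; host and port are placeholders, and error handling is reduced to a `String`:

```rust
use std::time::Duration;

use ginepro::{LoadBalancedChannel, ResolutionStrategy};
use tonic::transport::Channel;

// Same builder calls as `connect_tonic`, with the timeout constants inlined.
async fn connect(host: &str, port: u16) -> Result<Channel, String> {
    let channel = LoadBalancedChannel::builder((host.to_owned(), port))
        .dns_probe_interval(Duration::from_secs(5))
        .timeout(Duration::from_secs(5))
        .connect_timeout(Duration::from_secs(5))
        .resolution_strategy(ResolutionStrategy::Eager {
            timeout: Duration::from_secs(5),
        })
        .channel()
        .await
        .map_err(|e| format!("failed to create LoadBalancedChannel: {e}"))?;
    Ok(channel.into())
}

#[tokio::main]
async fn main() {
    // "udf.example.com" is a placeholder UDF service address.
    match connect("udf.example.com", 8815).await {
        Ok(_) => println!("connected"),
        Err(e) => eprintln!("{e}"),
    }
}
```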
+/// Returns true if the arrow flight error is caused by a connection error. +fn is_connection_error(err: &arrow_udf_flight::Error) -> bool { + match err { + // Connection refused + arrow_udf_flight::Error::Tonic(status) if status.code() == tonic::Code::Unavailable => true, + _ => false, + } +} + +fn is_tonic_error(err: &arrow_udf_flight::Error) -> bool { + matches!( + err, + arrow_udf_flight::Error::Tonic(_) + | arrow_udf_flight::Error::Flight(arrow_flight::error::FlightError::Tonic(_)) + ) +} + +/// Monitor metrics for UDF. +#[derive(Debug, Clone)] +struct Metrics { + /// Number of successful UDF calls. + udf_success_count: IntCounterVec, + /// Number of failed UDF calls. + udf_failure_count: IntCounterVec, + /// Total number of retried UDF calls. + udf_retry_count: IntCounterVec, + /// Input chunk rows of UDF calls. + udf_input_chunk_rows: HistogramVec, + /// The latency of UDF calls in seconds. + udf_latency: HistogramVec, + /// Total number of input rows of UDF calls. + udf_input_rows: IntCounterVec, + /// Total number of input bytes of UDF calls. + udf_input_bytes: IntCounterVec, +} + +/// Global UDF metrics. +static GLOBAL_METRICS: LazyLock<Metrics> = LazyLock::new(|| Metrics::new(&GLOBAL_METRICS_REGISTRY)); + +impl Metrics { + fn new(registry: &Registry) -> Self { + let labels = &["link", "language", "name", "fragment_id"]; + let udf_success_count = register_int_counter_vec_with_registry!( + "udf_success_count", + "Total number of successful UDF calls", + labels, + registry + ) + .unwrap(); + let udf_failure_count = register_int_counter_vec_with_registry!( + "udf_failure_count", + "Total number of failed UDF calls", + labels, + registry + ) + .unwrap(); + let udf_retry_count = register_int_counter_vec_with_registry!( + "udf_retry_count", + "Total number of retried UDF calls", + labels, + registry + ) + .unwrap(); + let udf_input_chunk_rows = register_histogram_vec_with_registry!( + "udf_input_chunk_rows", + "Input chunk rows of UDF calls", + labels, + exponential_buckets(1.0, 2.0, 10).unwrap(), // 1 to 1024 + registry + ) + .unwrap(); + let udf_latency = register_histogram_vec_with_registry!( + "udf_latency", + "The latency(s) of UDF calls", + labels, + exponential_buckets(0.000001, 2.0, 30).unwrap(), // 1us to 1000s + registry + ) + .unwrap(); + let udf_input_rows = register_int_counter_vec_with_registry!( + "udf_input_rows", + "Total number of input rows of UDF calls", + labels, + registry + ) + .unwrap(); + let udf_input_bytes = register_int_counter_vec_with_registry!( + "udf_input_bytes", + "Total number of input bytes of UDF calls", + labels, + registry + ) + .unwrap(); + + Metrics { + udf_success_count, + udf_failure_count, + udf_retry_count, + udf_input_chunk_rows, + udf_latency, + udf_input_rows, + udf_input_bytes, + } + } +} diff --git a/src/expr/core/src/expr/mod.rs b/src/expr/core/src/expr/mod.rs index 6dbb3906f561..9188ced21d11 100644 --- a/src/expr/core/src/expr/mod.rs +++ b/src/expr/core/src/expr/mod.rs @@ -51,7 +51,7 @@ use risingwave_common::types::{DataType, Datum}; pub use self::build::*; pub use self::expr_input_ref::InputRefExpression; pub use self::expr_literal::LiteralExpression; -pub use self::expr_udf::get_or_create_wasm_runtime; +pub use self::expr_udf::{get_or_create_flight_client, get_or_create_wasm_runtime}; pub use self::value::{ValueImpl, ValueRef}; pub use self::wrapper::*; pub use super::{ExprError, Result}; diff --git a/src/expr/core/src/table_function/user_defined.rs b/src/expr/core/src/table_function/user_defined.rs index 4362ff27b57b..bf8354df0ea4 100644 --- a/src/expr/core/src/table_function/user_defined.rs +++
b/src/expr/core/src/table_function/user_defined.rs @@ -23,7 +23,6 @@ use arrow_udf_js_deno::{CallMode as DenoCallMode, Runtime as DenoRuntime}; #[cfg(feature = "embedded-python-udf")] use arrow_udf_python::{CallMode as PythonCallMode, Runtime as PythonRuntime}; use cfg_or_panic::cfg_or_panic; -use futures_util::stream; use risingwave_common::array::arrow::{FromArrow, ToArrow, UdfArrowConvert}; use risingwave_common::array::{DataChunk, I32Array}; use risingwave_common::bail; @@ -38,6 +37,7 @@ pub struct UserDefinedTableFunction { return_type: DataType, client: UdfImpl, identifier: String, + arrow_convert: UdfArrowConvert, #[allow(dead_code)] chunk_size: usize, } @@ -61,10 +61,7 @@ impl UdfImpl { match self { UdfImpl::External(client) => { #[for_await] - for res in client - .call_stream(identifier, stream::once(async { input })) - .await? - { + for res in client.call_table_function(identifier, &input).await? { yield res?; } } @@ -110,8 +107,9 @@ impl UserDefinedTableFunction { // compact the input chunk and record the row mapping let visible_rows = direct_input.visibility().iter_ones().collect::<Vec<_>>(); // this will drop invisible rows - let arrow_input = - UdfArrowConvert.to_record_batch(self.arg_schema.clone(), &direct_input)?; + let arrow_input = self + .arrow_convert + .to_record_batch(self.arg_schema.clone(), &direct_input)?; // call UDTF #[for_await] for res in self @@ -119,7 +117,7 @@ .client .call_table_function(&self.identifier, arrow_input) { - let output = UdfArrowConvert.from_record_batch(&res?)?; + let output = self.arrow_convert.from_record_batch(&res?)?; self.check_output(&output)?; // we send the compacted input to UDF, so we need to map the row indices back to the @@ -179,21 +177,9 @@ pub fn new_user_defined(prost: &PbTableFunction, chunk_size: usize) -> Result<BoxedTableFunction> { - let arg_schema = Arc::new(Schema::new( - udtf.arg_types - .iter() - .map(|t| UdfArrowConvert.to_arrow_field("", &DataType::from(t))) - .try_collect::<Fields>()?, - )); - - let identifier = udtf.get_identifier()?; let return_type = DataType::from(prost.get_return_type()?); - let arrow_return_type = UdfArrowConvert - .to_arrow_field("", &return_type)? - .data_type() - .clone(); - #[cfg(not(feature = "embedded-deno-udf"))] let runtime = "quickjs"; @@ -203,12 +189,18 @@ pub fn new_user_defined(prost: &PbTableFunction, chunk_size: usize) -> Result "quickjs", }; + let mut arrow_convert = UdfArrowConvert::default(); + let client = match udtf.language.as_str() { "wasm" | "rust" => { let compressed_wasm_binary = udtf.get_compressed_binary()?; let wasm_binary = zstd::stream::decode_all(compressed_wasm_binary.as_slice()) .context("failed to decompress wasm binary")?; let runtime = crate::expr::expr_udf::get_or_create_wasm_runtime(&wasm_binary)?; + // backward compatibility + if runtime.abi_version().0 <= 2 { + arrow_convert = UdfArrowConvert { legacy: true }; + } UdfImpl::Wasm(runtime) } "javascript" if runtime != "deno" => { @@ -221,7 +213,7 @@ pub fn new_user_defined(prost: &PbTableFunction, chunk_size: usize) -> Result Result Result Result { let link = udtf.get_link()?; - UdfImpl::External(crate::expr::expr_udf::get_or_create_flight_client(link)?)
+ let client = crate::expr::expr_udf::get_or_create_flight_client(link)?; + // backward compatibility + // see for details + if client.protocol_version() == 1 { + arrow_convert = UdfArrowConvert { legacy: true }; + } + UdfImpl::External(client) } }; + let arg_schema = Arc::new(Schema::new( + udtf.arg_types + .iter() + .map(|t| arrow_convert.to_arrow_field("", &DataType::from(t))) + .try_collect::<Fields>()?, + )); + Ok(UserDefinedTableFunction { children: prost.args.iter().map(expr_build_from_prost).try_collect()?, return_type, arg_schema, client, identifier: identifier.clone(), + arrow_convert, chunk_size, } .boxed()) diff --git a/src/expr/udf/Cargo.toml b/src/expr/udf/Cargo.toml deleted file mode 100644 index b17ad7acadfc..000000000000 --- a/src/expr/udf/Cargo.toml +++ /dev/null @@ -1,35 +0,0 @@ -[package] -name = "risingwave_udf" -version = "0.1.0" -edition = "2021" -# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html - -[package.metadata.cargo-machete] -ignored = ["workspace-hack"] - -[package.metadata.cargo-udeps.ignore] -normal = ["workspace-hack"] - -[dependencies] -arrow-array = { workspace = true } -arrow-flight = { workspace = true } -arrow-schema = { workspace = true } -arrow-select = { workspace = true } -cfg-or-panic = "0.2" -futures = "0.3" -futures-util = "0.3.28" -ginepro = "0.7.0" -prometheus = "0.13" -risingwave_common = { workspace = true } -static_assertions = "1" -thiserror = "1" -thiserror-ext = { workspace = true } -tokio = { version = "0.2", package = "madsim-tokio", features = [ - "rt", - "macros", -] } -tonic = { workspace = true } -tracing = "0.1" - -[lints] -workspace = true diff --git a/src/expr/udf/README-js.md b/src/expr/udf/README-js.md deleted file mode 100644 index 902bce4ef52e..000000000000 --- a/src/expr/udf/README-js.md +++ /dev/null @@ -1,83 +0,0 @@ -# Use UDFs in JavaScript - -This article provides a step-by-step guide for defining JavaScript functions in RisingWave. - -JavaScript code is inlined in `CREATE FUNCTION` statement and then run on the embedded QuickJS virtual machine in RisingWave. It does not support access to external networks and is limited to computational tasks only. -Compared to other languages, JavaScript UDFs offer the easiest way to define UDFs in RisingWave. - -## Define your functions - -You can use the `CREATE FUNCTION` statement to create JavaScript UDFs. The syntax is as follows: - -```sql -CREATE FUNCTION function_name ( arg_name arg_type [, ...] ) - [ RETURNS return_type | RETURNS TABLE ( column_name column_type [, ...] ) ] - LANGUAGE javascript - AS [ $$ function_body $$ | 'function_body' ]; -``` - -The argument names you define can be used in the function body. For example: - -```sql -CREATE FUNCTION gcd(a int, b int) RETURNS int LANGUAGE javascript AS $$ - if(a == null || b == null) { - return null; - } - while (b != 0) { - let t = b; - b = a % b; - a = t; - } - return a; -$$; -``` - -The correspondence between SQL types and JavaScript types can be found in the [appendix table](#appendix-type-mapping). You need to ensure that the type of the return value is either `null` or consistent with the type in the `RETURNS` clause. - -If the function you define returns a table, you need to use the `yield` statement to return the data of each row.
For example: - -```sql -CREATE FUNCTION series(n int) RETURNS TABLE (x int) LANGUAGE javascript AS $$ - for(let i = 0; i < n; i++) { - yield i; - } -$$; -``` - -## Use your functions - -Once the UDFs are created in RisingWave, you can use them in SQL queries just like any built-in functions. For example: - -```sql -SELECT gcd(25, 15); -SELECT * from series(5); -``` - -## Appendix: Type Mapping - -The following table shows the type mapping between SQL and JavaScript: - -| SQL Type | JS Type | Note | -| --------------------- | ------------- | --------------------- | -| boolean | boolean | | -| smallint | number | | -| int | number | | -| bigint | number | | -| real | number | | -| double precision | number | | -| decimal | BigDecimal | | -| date | | not supported yet | -| time | | not supported yet | -| timestamp | | not supported yet | -| timestamptz | | not supported yet | -| interval | | not supported yet | -| varchar | string | | -| bytea | Uint8Array | | -| jsonb | null, boolean, number, string, array or object | `JSON.parse(string)` | -| smallint[] | Int16Array | | -| int[] | Int32Array | | -| bigint[] | BigInt64Array | | -| real[] | Float32Array | | -| double precision[] | Float64Array | | -| others[] | array | | -| struct<..> | object | | diff --git a/src/expr/udf/README.md b/src/expr/udf/README.md deleted file mode 100644 index d9428cc54724..000000000000 --- a/src/expr/udf/README.md +++ /dev/null @@ -1,118 +0,0 @@ -# Use UDFs in Rust - -This article provides a step-by-step guide for defining Rust functions in RisingWave. - -Rust functions are compiled into WebAssembly modules and then run on the embedded WebAssembly virtual machine in RisingWave. Compared to Python and Java, Rust UDFs offer **higher performance** (near native) and are **managed by the RisingWave kernel**, eliminating the need for additional maintenance. However, since they run embedded in the kernel, for security reasons, Rust UDFs currently **do not support access to external networks and are limited to computational tasks only**, with restricted CPU and memory resources. Therefore, we recommend using Rust UDFs for **computationally intensive tasks**, such as packet parsing and format conversion. - -## Prerequisites - -- Ensure that you have [Rust toolchain](https://rustup.rs) (stable channel) installed on your computer. - -- Ensure that the Rust standard library for `wasm32-wasi` target is installed: - ```shell - rustup target add wasm32-wasi - ``` - -## 1. Create a project - -Create a Rust project named `udf`: - -```shell -cargo new --lib udf -cd udf -``` - -Add the following lines to `Cargo.toml`: - -```toml -[lib] -crate-type = ["cdylib"] - -[dependencies] -arrow-udf = "0.1" -``` - -## 2. Define your functions - -In `src/lib.rs`, define your functions using the `function` macro: - -```rust -use arrow_udf::function; - -// define a scalar function -#[function("gcd(int, int) -> int")] -fn gcd(mut x: i32, mut y: i32) -> i32 { - while y != 0 { - (x, y) = (y, x % y); - } - x -} - -// define a table function -#[function("series(int) -> setof int")] -fn series(n: i32) -> impl Iterator<Item = i32> { - 0..n -} -``` - -You can find more usages in the [documentation](https://docs.rs/arrow_udf/0.1.0/arrow_udf/attr.function.html) and more examples in the [tests](https://github.com/risingwavelabs/arrow-udf/blob/main/arrow-udf/tests/tests.rs). - -Currently we only support a limited set of data types. `timestamptz` and complex array types are not supported yet. - -## 3.
Build the project - -Build your functions into a WebAssembly module: - -```shell -cargo build --release --target wasm32-wasi -``` - -You can find the generated WASM module at `target/wasm32-wasi/release/udf.wasm`. - -Optional: It is recommended to strip the binary to reduce its size: - -```shell -# Install wasm-tools -cargo install wasm-tools - -# Strip the binary -wasm-tools strip ./target/wasm32-wasi/release/udf.wasm > udf.wasm -``` - -## 4. Declare your functions in RisingWave - -In RisingWave, use the `CREATE FUNCTION` command to declare the functions you defined. - -There are two ways to load the WASM module: - -1. The WASM binary can be embedded in the SQL statement using the base64 encoding. -You can use the following shell script to encode the binary and generate the SQL statement: - ```shell - encoded=$(base64 -i udf.wasm) - sql="CREATE FUNCTION gcd(int, int) RETURNS int LANGUAGE wasm USING BASE64 '$encoded';" - echo "$sql" > create_function.sql - ``` - When created successfully, the WASM binary will be automatically uploaded to the object store. - -2. The WASM binary can be loaded from the object store. - ```sql - CREATE FUNCTION gcd(int, int) RETURNS int - LANGUAGE wasm USING LINK 's3://bucket/path/to/udf.wasm'; - - CREATE FUNCTION series(int) RETURNS TABLE (x int) - LANGUAGE wasm USING LINK 's3://bucket/path/to/udf.wasm'; - ``` - - Or if you run RisingWave locally, you can use the local file system: - ```sql - CREATE FUNCTION gcd(int, int) RETURNS int - LANGUAGE wasm USING LINK 'fs://path/to/udf.wasm'; - ``` - -## 5. Use your functions in RisingWave - -Once the UDFs are created in RisingWave, you can use them in SQL queries just like any built-in functions. For example: - -```sql -SELECT gcd(25, 15); -SELECT series(5); -``` diff --git a/src/expr/udf/examples/client.rs b/src/expr/udf/examples/client.rs deleted file mode 100644 index 92f93ae13614..000000000000 --- a/src/expr/udf/examples/client.rs +++ /dev/null @@ -1,76 +0,0 @@ -// Copyright 2024 RisingWave Labs -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -use std::sync::Arc; - -use arrow_array::{Int32Array, RecordBatch}; -use arrow_schema::{DataType, Field, Schema}; -use risingwave_udf::ArrowFlightUdfClient; - -#[tokio::main] -async fn main() { - let addr = "http://localhost:8815"; - let client = ArrowFlightUdfClient::connect(addr).await.unwrap(); - - // build `RecordBatch` to send (equivalent to our `DataChunk`) - let array1 = Arc::new(Int32Array::from_iter(vec![1, 6, 10])); - let array2 = Arc::new(Int32Array::from_iter(vec![3, 4, 15])); - let array3 = Arc::new(Int32Array::from_iter(vec![6, 8, 3])); - let input2_schema = Schema::new(vec![ - Field::new("a", DataType::Int32, true), - Field::new("b", DataType::Int32, true), - ]); - let input3_schema = Schema::new(vec![ - Field::new("a", DataType::Int32, true), - Field::new("b", DataType::Int32, true), - Field::new("c", DataType::Int32, true), - ]); - let output_schema = Schema::new(vec![Field::new("x", DataType::Int32, true)]); - - // check function - client - .check("gcd", &input2_schema, &output_schema) - .await - .unwrap(); - client - .check("gcd3", &input3_schema, &output_schema) - .await - .unwrap(); - - let input2 = RecordBatch::try_new( - Arc::new(input2_schema), - vec![array1.clone(), array2.clone()], - ) - .unwrap(); - - let output = client - .call("gcd", input2) - .await - .expect("failed to call function"); - - println!("{:?}", output); - - let input3 = RecordBatch::try_new( - Arc::new(input3_schema), - vec![array1.clone(), array2.clone(), array3.clone()], - ) - .unwrap(); - - let output = client - .call("gcd3", input3) - .await - .expect("failed to call function"); - - println!("{:?}", output); -} diff --git a/src/expr/udf/python/.gitignore b/src/expr/udf/python/.gitignore deleted file mode 100644 index 75b18b1dc191..000000000000 --- a/src/expr/udf/python/.gitignore +++ /dev/null @@ -1,2 +0,0 @@ -/dist -/risingwave.egg-info diff --git a/src/expr/udf/python/CHANGELOG.md b/src/expr/udf/python/CHANGELOG.md deleted file mode 100644 index a20411e69d83..000000000000 --- a/src/expr/udf/python/CHANGELOG.md +++ /dev/null @@ -1,37 +0,0 @@ -# Changelog - -All notable changes to this project will be documented in this file. - -The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), -and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). - -## [Unreleased] - -## [0.1.1] - 2023-12-06 - -### Fixed - -- Fix decimal type output. - -## [0.1.0] - 2023-12-01 - -### Fixed - -- Fix unconstrained decimal type. - -## [0.0.12] - 2023-11-28 - -### Changed - -- Change the default struct field name to `f{i}`. - -### Fixed - -- Fix parsing nested struct type. - - -## [0.0.11] - 2023-11-06 - -### Fixed - -- Hook SIGTERM to stop the UDF server gracefully. diff --git a/src/expr/udf/python/README.md b/src/expr/udf/python/README.md deleted file mode 100644 index d1655be05350..000000000000 --- a/src/expr/udf/python/README.md +++ /dev/null @@ -1,112 +0,0 @@ -# RisingWave Python UDF SDK - -This library provides a Python SDK for creating user-defined functions (UDF) in [RisingWave](https://www.risingwave.com/). - -For a detailed guide on how to use Python UDF in RisingWave, please refer to [this doc](https://docs.risingwave.com/docs/current/udf-python/). - -## Introduction - -RisingWave supports user-defined functions implemented as external functions. -With the RisingWave Python UDF SDK, users can define custom UDFs using Python and start a Python process as a UDF server. -RisingWave can then remotely access the UDF server to execute the defined functions. 
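The deleted Rust example above drove the server through the old `risingwave_udf::ArrowFlightUdfClient`; after this patch the same round trip goes through `arrow-udf-flight`. A rough equivalent, assuming the crate exposes a `Client::new(channel)` constructor alongside the `call` method used elsewhere in this diff; the localhost address is a placeholder:

```rust
use std::sync::Arc;

use arrow_array::{ArrayRef, Int32Array, RecordBatch};
use arrow_schema::{DataType, Field, Schema};
use arrow_udf_flight::Client;

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Connect a plain tonic channel to the UDF server (placeholder address).
    let channel = tonic::transport::Channel::from_static("http://localhost:8815")
        .connect()
        .await?;
    let client = Client::new(channel).await?;

    // Two int32 columns, as expected by the `gcd` function from the examples.
    let schema = Arc::new(Schema::new(vec![
        Field::new("x", DataType::Int32, true),
        Field::new("y", DataType::Int32, true),
    ]));
    let input = RecordBatch::try_new(
        schema,
        vec![
            Arc::new(Int32Array::from(vec![25, 6])) as ArrayRef,
            Arc::new(Int32Array::from(vec![15, 8])) as ArrayRef,
        ],
    )?;

    // Same call the rewritten `eval_inner` issues for external UDFs.
    let output = client.call("gcd", &input).await?;
    println!("{output:?}");
    Ok(())
}
```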
- -## Installation - -```sh -pip install risingwave -``` - -## Usage - -Define functions in a Python file: - -```python -# udf.py -from risingwave.udf import udf, udtf, UdfServer -import struct -import socket - -# Define a scalar function -@udf(input_types=['INT', 'INT'], result_type='INT') -def gcd(x, y): - while y != 0: - (x, y) = (y, x % y) - return x - -# Define a scalar function that returns multiple values (within a struct) -@udf(input_types=['BYTEA'], result_type='STRUCT<VARCHAR, VARCHAR, SMALLINT, SMALLINT>') -def extract_tcp_info(tcp_packet: bytes): - src_addr, dst_addr = struct.unpack('!4s4s', tcp_packet[12:20]) - src_port, dst_port = struct.unpack('!HH', tcp_packet[20:24]) - src_addr = socket.inet_ntoa(src_addr) - dst_addr = socket.inet_ntoa(dst_addr) - return src_addr, dst_addr, src_port, dst_port - -# Define a table function -@udtf(input_types='INT', result_types='INT') -def series(n): - for i in range(n): - yield i - -# Start a UDF server -if __name__ == '__main__': - server = UdfServer(location="0.0.0.0:8815") - server.add_function(gcd) - server.add_function(series) - server.serve() -``` - -Start the UDF server: - -```sh -python3 udf.py -``` - -To create functions in RisingWave, use the following syntax: - -```sql -create function <name> ( <arg_type> [, ...] ) - [ returns <ret_type> | returns table ( <column_name> <column_type> [, ...] ) ] - as <name_defined_in_server> using link '<udf_server_address>'; -``` - -- The `as` parameter specifies the function name defined in the UDF server. -- The `link` parameter specifies the address of the UDF server. - -For example: - -```sql -create function gcd(int, int) returns int -as gcd using link 'http://localhost:8815'; - -create function series(int) returns table (x int) -as series using link 'http://localhost:8815'; - -select gcd(25, 15); - -select * from series(10); -``` - -## Data Types - -The RisingWave Python UDF SDK supports the following data types: - -| SQL Type | Python Type | Notes | -| ---------------- | ----------------------------- | ------------------ | -| BOOLEAN | bool | | -| SMALLINT | int | | -| INT | int | | -| BIGINT | int | | -| REAL | float | | -| DOUBLE PRECISION | float | | -| DECIMAL | decimal.Decimal | | -| DATE | datetime.date | | -| TIME | datetime.time | | -| TIMESTAMP | datetime.datetime | | -| INTERVAL | MonthDayNano / (int, int, int) | Fields can be obtained by `months()`, `days()` and `nanoseconds()` from `MonthDayNano` | -| VARCHAR | str | | -| BYTEA | bytes | | -| JSONB | any | | -| T[] | list[T] | | -| STRUCT<> | tuple | | -| ...others | | Not supported yet.
| diff --git a/src/expr/udf/python/publish.md b/src/expr/udf/python/publish.md deleted file mode 100644 index 0bc22d713906..000000000000 --- a/src/expr/udf/python/publish.md +++ /dev/null @@ -1,19 +0,0 @@ -# How to publish this library - -Install the build tool: - -```sh -pip3 install build -``` - -Build the library: - -```sh -python3 -m build -``` - -Upload the library to PyPI: - -```sh -twine upload dist/* -``` diff --git a/src/expr/udf/python/pyproject.toml b/src/expr/udf/python/pyproject.toml deleted file mode 100644 index b53535516836..000000000000 --- a/src/expr/udf/python/pyproject.toml +++ /dev/null @@ -1,20 +0,0 @@ -[build-system] -requires = ["setuptools", "wheel"] -build-backend = "setuptools.build_meta" - -[project] -name = "risingwave" -version = "0.1.1" -authors = [{ name = "RisingWave Labs" }] -description = "RisingWave Python API" -readme = "README.md" -license = { text = "Apache Software License" } -classifiers = [ - "Programming Language :: Python", - "License :: OSI Approved :: Apache Software License", -] -requires-python = ">=3.8" -dependencies = ["pyarrow"] - -[project.optional-dependencies] -test = ["pytest"] diff --git a/src/expr/udf/python/risingwave/__init__.py b/src/expr/udf/python/risingwave/__init__.py deleted file mode 100644 index 3d60f2f96d02..000000000000 --- a/src/expr/udf/python/risingwave/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -# Copyright 2024 RisingWave Labs -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. diff --git a/src/expr/udf/python/risingwave/test_udf.py b/src/expr/udf/python/risingwave/test_udf.py deleted file mode 100644 index e3c2029d3d1f..000000000000 --- a/src/expr/udf/python/risingwave/test_udf.py +++ /dev/null @@ -1,240 +0,0 @@ -# Copyright 2024 RisingWave Labs -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -from multiprocessing import Process -import pytest -from risingwave.udf import udf, UdfServer, _to_data_type -import pyarrow as pa -import pyarrow.flight as flight -import time -import datetime -from typing import Any - - -def flight_server(): - server = UdfServer(location="localhost:8815") - server.add_function(add) - server.add_function(wait) - server.add_function(wait_concurrent) - server.add_function(return_all) - return server - - -def flight_client(): - client = flight.FlightClient(("localhost", 8815)) - return client - - -# Define a scalar function -@udf(input_types=["INT", "INT"], result_type="INT") -def add(x, y): - return x + y - - -@udf(input_types=["INT"], result_type="INT") -def wait(x): - time.sleep(0.01) - return 0 - - -@udf(input_types=["INT"], result_type="INT", io_threads=32) -def wait_concurrent(x): - time.sleep(0.01) - return 0 - - -@udf( - input_types=[ - "BOOLEAN", - "SMALLINT", - "INT", - "BIGINT", - "FLOAT4", - "FLOAT8", - "DECIMAL", - "DATE", - "TIME", - "TIMESTAMP", - "INTERVAL", - "VARCHAR", - "BYTEA", - "JSONB", - ], - result_type="""struct< - BOOLEAN, - SMALLINT, - INT, - BIGINT, - FLOAT4, - FLOAT8, - DECIMAL, - DATE, - TIME, - TIMESTAMP, - INTERVAL, - VARCHAR, - BYTEA, - JSONB - >""", -) -def return_all( - bool, - i16, - i32, - i64, - f32, - f64, - decimal, - date, - time, - timestamp, - interval, - varchar, - bytea, - jsonb, -): - return ( - bool, - i16, - i32, - i64, - f32, - f64, - decimal, - date, - time, - timestamp, - interval, - varchar, - bytea, - jsonb, - ) - - -def test_simple(): - LEN = 64 - data = pa.Table.from_arrays( - [pa.array(range(0, LEN)), pa.array(range(0, LEN))], names=["x", "y"] - ) - - batches = data.to_batches(max_chunksize=512) - - with flight_client() as client, flight_server() as server: - flight_info = flight.FlightDescriptor.for_path(b"add") - writer, reader = client.do_exchange(descriptor=flight_info) - with writer: - writer.begin(schema=data.schema) - for batch in batches: - writer.write_batch(batch) - writer.done_writing() - - chunk = reader.read_chunk() - assert len(chunk.data) == LEN - assert chunk.data.column("output").equals( - pa.array(range(0, LEN * 2, 2), type=pa.int32()) - ) - - -def test_io_concurrency(): - LEN = 64 - data = pa.Table.from_arrays([pa.array(range(0, LEN))], names=["x"]) - batches = data.to_batches(max_chunksize=512) - - with flight_client() as client, flight_server() as server: - # Single-threaded function takes a long time - flight_info = flight.FlightDescriptor.for_path(b"wait") - writer, reader = client.do_exchange(descriptor=flight_info) - with writer: - writer.begin(schema=data.schema) - for batch in batches: - writer.write_batch(batch) - writer.done_writing() - start_time = time.time() - - total_len = 0 - for chunk in reader: - total_len += len(chunk.data) - - assert total_len == LEN - - elapsed_time = time.time() - start_time # ~0.64s - assert elapsed_time > 0.5 - - # Multi-threaded I/O bound function will take a much shorter time - flight_info = flight.FlightDescriptor.for_path(b"wait_concurrent") - writer, reader = client.do_exchange(descriptor=flight_info) - with writer: - writer.begin(schema=data.schema) - for batch in batches: - writer.write_batch(batch) - writer.done_writing() - start_time = time.time() - - total_len = 0 - for chunk in reader: - total_len += len(chunk.data) - - assert total_len == LEN - - elapsed_time = time.time() - start_time - assert elapsed_time < 0.25 - - -def test_all_types(): - arrays = [ - pa.array([True], type=pa.bool_()), - pa.array([1], type=pa.int16()), - 
pa.array([1], type=pa.int32()), - pa.array([1], type=pa.int64()), - pa.array([1], type=pa.float32()), - pa.array([1], type=pa.float64()), - pa.array(["12345678901234567890.1234567890"], type=pa.large_binary()), - pa.array([datetime.date(2023, 6, 1)], type=pa.date32()), - pa.array([datetime.time(1, 2, 3, 456789)], type=pa.time64("us")), - pa.array( - [datetime.datetime(2023, 6, 1, 1, 2, 3, 456789)], - type=pa.timestamp("us"), - ), - pa.array([(1, 2, 3)], type=pa.month_day_nano_interval()), - pa.array(["string"], type=pa.string()), - pa.array(["bytes"], type=pa.binary()), - pa.array(['{ "key": 1 }'], type=pa.large_string()), - ] - batch = pa.RecordBatch.from_arrays(arrays, names=["" for _ in arrays]) - - with flight_client() as client, flight_server() as server: - flight_info = flight.FlightDescriptor.for_path(b"return_all") - writer, reader = client.do_exchange(descriptor=flight_info) - with writer: - writer.begin(schema=batch.schema) - writer.write_batch(batch) - writer.done_writing() - - chunk = reader.read_chunk() - assert [v.as_py() for _, v in chunk.data.column(0)[0].items()] == [ - True, - 1, - 1, - 1, - 1.0, - 1.0, - b"12345678901234567890.1234567890", - datetime.date(2023, 6, 1), - datetime.time(1, 2, 3, 456789), - datetime.datetime(2023, 6, 1, 1, 2, 3, 456789), - (1, 2, 3), - "string", - b"bytes", - '{"key": 1}', - ] diff --git a/src/expr/udf/python/risingwave/udf.py b/src/expr/udf/python/risingwave/udf.py deleted file mode 100644 index aad53e25e0c9..000000000000 --- a/src/expr/udf/python/risingwave/udf.py +++ /dev/null @@ -1,552 +0,0 @@ -# Copyright 2024 RisingWave Labs -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import * -import pyarrow as pa -import pyarrow.flight -import pyarrow.parquet -import inspect -import traceback -import json -from concurrent.futures import ThreadPoolExecutor -import concurrent -from decimal import Decimal -import signal - - -class UserDefinedFunction: - """ - Base interface for user-defined function. - """ - - _name: str - _input_schema: pa.Schema - _result_schema: pa.Schema - _io_threads: Optional[int] - _executor: Optional[ThreadPoolExecutor] - - def eval_batch(self, batch: pa.RecordBatch) -> Iterator[pa.RecordBatch]: - """ - Apply the function on a batch of inputs. - """ - return iter([]) - - -class ScalarFunction(UserDefinedFunction): - """ - Base interface for user-defined scalar function. A user-defined scalar functions maps zero, one, - or multiple scalar values to a new scalar value. - """ - - def __init__(self, *args, **kwargs): - self._io_threads = kwargs.pop("io_threads") - self._executor = ( - ThreadPoolExecutor(max_workers=self._io_threads) - if self._io_threads is not None - else None - ) - super().__init__(*args, **kwargs) - - def eval(self, *args) -> Any: - """ - Method which defines the logic of the scalar function. 
- """ - pass - - def eval_batch(self, batch: pa.RecordBatch) -> Iterator[pa.RecordBatch]: - # parse value from json string for jsonb columns - inputs = [[v.as_py() for v in array] for array in batch] - inputs = [ - _process_func(pa.list_(type), False)(array) - for array, type in zip(inputs, self._input_schema.types) - ] - if self._executor is not None: - # evaluate the function for each row - tasks = [ - self._executor.submit(self._func, *[col[i] for col in inputs]) - for i in range(batch.num_rows) - ] - column = [ - future.result() for future in concurrent.futures.as_completed(tasks) - ] - else: - # evaluate the function for each row - column = [ - self.eval(*[col[i] for col in inputs]) for i in range(batch.num_rows) - ] - - column = _process_func(pa.list_(self._result_schema.types[0]), True)(column) - - array = pa.array(column, type=self._result_schema.types[0]) - yield pa.RecordBatch.from_arrays([array], schema=self._result_schema) - - -def _process_func(type: pa.DataType, output: bool) -> Callable: - """Return a function to process input or output value.""" - if pa.types.is_list(type): - func = _process_func(type.value_type, output) - return lambda array: [(func(v) if v is not None else None) for v in array] - - if pa.types.is_struct(type): - funcs = [_process_func(field.type, output) for field in type] - if output: - return lambda tup: tuple( - (func(v) if v is not None else None) for v, func in zip(tup, funcs) - ) - else: - # the input value of struct type is a dict - # we convert it into tuple here - return lambda map: tuple( - (func(v) if v is not None else None) - for v, func in zip(map.values(), funcs) - ) - - if type.equals(JSONB): - if output: - return lambda v: json.dumps(v) - else: - return lambda v: json.loads(v) - - if type.equals(UNCONSTRAINED_DECIMAL): - if output: - - def decimal_to_str(v): - if not isinstance(v, Decimal): - raise ValueError(f"Expected Decimal, got {v}") - # use `f` format to avoid scientific notation, e.g. `1e10` - return format(v, "f").encode("utf-8") - - return decimal_to_str - else: - return lambda v: Decimal(v.decode("utf-8")) - - return lambda v: v - - -class TableFunction(UserDefinedFunction): - """ - Base interface for user-defined table function. A user-defined table functions maps zero, one, - or multiple scalar values to a new table value. - """ - - BATCH_SIZE = 1024 - - def eval(self, *args) -> Iterator: - """ - Method which defines the logic of the table function. 
- """ - yield - - def eval_batch(self, batch: pa.RecordBatch) -> Iterator[pa.RecordBatch]: - class RecordBatchBuilder: - """A utility class for constructing Arrow RecordBatch by row.""" - - schema: pa.Schema - columns: List[List] - - def __init__(self, schema: pa.Schema): - self.schema = schema - self.columns = [[] for _ in self.schema.types] - - def len(self) -> int: - """Returns the number of rows in the RecordBatch being built.""" - return len(self.columns[0]) - - def append(self, index: int, value: Any): - """Appends a new row to the RecordBatch being built.""" - self.columns[0].append(index) - self.columns[1].append(value) - - def build(self) -> pa.RecordBatch: - """Builds the RecordBatch from the accumulated data and clears the state.""" - # Convert the columns to arrow arrays - arrays = [ - pa.array(col, type) - for col, type in zip(self.columns, self.schema.types) - ] - # Reset columns - self.columns = [[] for _ in self.schema.types] - return pa.RecordBatch.from_arrays(arrays, schema=self.schema) - - builder = RecordBatchBuilder(self._result_schema) - - # Iterate through rows in the input RecordBatch - for row_index in range(batch.num_rows): - row = tuple(column[row_index].as_py() for column in batch) - for result in self.eval(*row): - builder.append(row_index, result) - if builder.len() == self.BATCH_SIZE: - yield builder.build() - if builder.len() != 0: - yield builder.build() - - -class UserDefinedScalarFunctionWrapper(ScalarFunction): - """ - Base Wrapper for Python user-defined scalar function. - """ - - _func: Callable - - def __init__(self, func, input_types, result_type, name=None, io_threads=None): - self._func = func - self._input_schema = pa.schema( - zip( - inspect.getfullargspec(func)[0], - [_to_data_type(t) for t in _to_list(input_types)], - ) - ) - self._result_schema = pa.schema([("output", _to_data_type(result_type))]) - self._name = name or ( - func.__name__ if hasattr(func, "__name__") else func.__class__.__name__ - ) - super().__init__(io_threads=io_threads) - - def __call__(self, *args): - return self._func(*args) - - def eval(self, *args): - return self._func(*args) - - -class UserDefinedTableFunctionWrapper(TableFunction): - """ - Base Wrapper for Python user-defined table function. - """ - - _func: Callable - - def __init__(self, func, input_types, result_types, name=None): - self._func = func - self._name = name or ( - func.__name__ if hasattr(func, "__name__") else func.__class__.__name__ - ) - self._input_schema = pa.schema( - zip( - inspect.getfullargspec(func)[0], - [_to_data_type(t) for t in _to_list(input_types)], - ) - ) - self._result_schema = pa.schema( - [ - ("row_index", pa.int32()), - ( - self._name, - pa.struct([("", _to_data_type(t)) for t in result_types]) - if isinstance(result_types, list) - else _to_data_type(result_types), - ), - ] - ) - - def __call__(self, *args): - return self._func(*args) - - def eval(self, *args): - return self._func(*args) - - -def _to_list(x): - if isinstance(x, list): - return x - else: - return [x] - - -def udf( - input_types: Union[List[Union[str, pa.DataType]], Union[str, pa.DataType]], - result_type: Union[str, pa.DataType], - name: Optional[str] = None, - io_threads: Optional[int] = None, -) -> Callable: - """ - Annotation for creating a user-defined scalar function. - - Parameters: - - input_types: A list of strings or Arrow data types that specifies the input data types. - - result_type: A string or an Arrow data type that specifies the return value type. 
- - name: An optional string specifying the function name. If not provided, the original name will be used. - - io_threads: Number of I/O threads used per data chunk for I/O bound functions. - - Example: - ``` - @udf(input_types=['INT', 'INT'], result_type='INT') - def gcd(x, y): - while y != 0: - (x, y) = (y, x % y) - return x - ``` - - I/O bound Example: - ``` - @udf(input_types=['INT'], result_type='INT', io_threads=64) - def external_api(x): - response = requests.get(my_endpoint + '?param=' + x) - return response["data"] - ``` - """ - - if io_threads is not None and io_threads > 1: - return lambda f: UserDefinedScalarFunctionWrapper( - f, input_types, result_type, name, io_threads=io_threads - ) - else: - return lambda f: UserDefinedScalarFunctionWrapper( - f, input_types, result_type, name - ) - - -def udtf( - input_types: Union[List[Union[str, pa.DataType]], Union[str, pa.DataType]], - result_types: Union[List[Union[str, pa.DataType]], Union[str, pa.DataType]], - name: Optional[str] = None, -) -> Callable: - """ - Annotation for creating a user-defined table function. - - Parameters: - - input_types: A list of strings or Arrow data types that specifies the input data types. - - result_types A list of strings or Arrow data types that specifies the return value types. - - name: An optional string specifying the function name. If not provided, the original name will be used. - - Example: - ``` - @udtf(input_types='INT', result_types='INT') - def series(n): - for i in range(n): - yield i - ``` - """ - - return lambda f: UserDefinedTableFunctionWrapper(f, input_types, result_types, name) - - -class UdfServer(pa.flight.FlightServerBase): - """ - A server that provides user-defined functions to clients. - - Example: - ``` - server = UdfServer(location="0.0.0.0:8815") - server.add_function(my_udf) - server.serve() - ``` - """ - - # UDF server based on Apache Arrow Flight protocol. 
- # Reference: https://arrow.apache.org/cookbook/py/flight.html#simple-parquet-storage-service-with-arrow-flight - - _location: str - _functions: Dict[str, UserDefinedFunction] - - def __init__(self, location="0.0.0.0:8815", **kwargs): - super(UdfServer, self).__init__("grpc://" + location, **kwargs) - self._location = location - self._functions = {} - - def get_flight_info(self, context, descriptor): - """Return the result schema of a function.""" - udf = self._functions[descriptor.path[0].decode("utf-8")] - # return the concatenation of input and output schema - full_schema = pa.schema(list(udf._input_schema) + list(udf._result_schema)) - # we use `total_records` to indicate the number of input arguments - return pa.flight.FlightInfo( - schema=full_schema, - descriptor=descriptor, - endpoints=[], - total_records=len(udf._input_schema), - total_bytes=0, - ) - - def add_function(self, udf: UserDefinedFunction): - """Add a function to the server.""" - name = udf._name - if name in self._functions: - raise ValueError("Function already exists: " + name) - - input_types = ",".join( - [_data_type_to_string(t) for t in udf._input_schema.types] - ) - if isinstance(udf, TableFunction): - output_type = udf._result_schema.types[-1] - if isinstance(output_type, pa.StructType): - output_type = ",".join( - f"field_{i} {_data_type_to_string(field.type)}" - for i, field in enumerate(output_type) - ) - output_type = f"TABLE({output_type})" - else: - output_type = _data_type_to_string(output_type) - output_type = f"TABLE(output {output_type})" - else: - output_type = _data_type_to_string(udf._result_schema.types[-1]) - - sql = f"CREATE FUNCTION {name}({input_types}) RETURNS {output_type} AS '{name}' USING LINK 'http://{self._location}';" - print(f"added function: {name}, corresponding SQL:\n{sql}\n") - self._functions[name] = udf - - def do_exchange(self, context, descriptor, reader, writer): - """Call a function from the client.""" - udf = self._functions[descriptor.path[0].decode("utf-8")] - writer.begin(udf._result_schema) - try: - for batch in reader: - # print(pa.Table.from_batches([batch.data])) - for output_batch in udf.eval_batch(batch.data): - writer.write_batch(output_batch) - except Exception as e: - print(traceback.print_exc()) - raise e - - def serve(self): - """ - Block until the server shuts down. - - This method only returns if shutdown() is called or a signal (SIGINT, SIGTERM) received. - """ - print( - "Note: You can use arbitrary function names and struct field names in CREATE FUNCTION statements." - f"\n\nlistening on {self._location}" - ) - signal.signal(signal.SIGTERM, lambda s, f: self.shutdown()) - super(UdfServer, self).serve() - - -def _to_data_type(t: Union[str, pa.DataType]) -> pa.DataType: - """ - Convert a SQL data type string or `pyarrow.DataType` to `pyarrow.DataType`. - """ - if isinstance(t, str): - return _string_to_data_type(t) - else: - return t - - -# we use `large_binary` to represent unconstrained decimal type -UNCONSTRAINED_DECIMAL = pa.large_binary() -JSONB = pa.large_string() - - -def _string_to_data_type(type_str: str): - """ - Convert a SQL data type string to `pyarrow.DataType`. 
- """ - type_str = type_str.upper() - if type_str.endswith("[]"): - return pa.list_(_string_to_data_type(type_str[:-2])) - elif type_str in ("BOOLEAN", "BOOL"): - return pa.bool_() - elif type_str in ("SMALLINT", "INT2"): - return pa.int16() - elif type_str in ("INT", "INTEGER", "INT4"): - return pa.int32() - elif type_str in ("BIGINT", "INT8"): - return pa.int64() - elif type_str in ("FLOAT4", "REAL"): - return pa.float32() - elif type_str in ("FLOAT8", "DOUBLE PRECISION"): - return pa.float64() - elif type_str.startswith("DECIMAL") or type_str.startswith("NUMERIC"): - if type_str == "DECIMAL" or type_str == "NUMERIC": - return UNCONSTRAINED_DECIMAL - rest = type_str[8:-1] # remove "DECIMAL(" and ")" - if "," in rest: - precision, scale = rest.split(",") - return pa.decimal128(int(precision), int(scale)) - else: - return pa.decimal128(int(rest), 0) - elif type_str in ("DATE"): - return pa.date32() - elif type_str in ("TIME", "TIME WITHOUT TIME ZONE"): - return pa.time64("us") - elif type_str in ("TIMESTAMP", "TIMESTAMP WITHOUT TIME ZONE"): - return pa.timestamp("us") - elif type_str.startswith("INTERVAL"): - return pa.month_day_nano_interval() - elif type_str in ("VARCHAR"): - return pa.string() - elif type_str in ("JSONB"): - return JSONB - elif type_str in ("BYTEA"): - return pa.binary() - elif type_str.startswith("STRUCT"): - # extract 'STRUCT, ...>' - type_list = type_str[7:-1] # strip "STRUCT<>" - fields = [] - elements = [] - start = 0 - depth = 0 - for i, c in enumerate(type_list): - if c == "<": - depth += 1 - elif c == ">": - depth -= 1 - elif c == "," and depth == 0: - type_str = type_list[start:i].strip() - fields.append(pa.field("", _string_to_data_type(type_str))) - start = i + 1 - type_str = type_list[start:].strip() - fields.append(pa.field("", _string_to_data_type(type_str))) - return pa.struct(fields) - - raise ValueError(f"Unsupported type: {type_str}") - - -def _data_type_to_string(t: pa.DataType) -> str: - """ - Convert a `pyarrow.DataType` to a SQL data type string. - """ - if isinstance(t, pa.ListType): - return _data_type_to_string(t.value_type) + "[]" - elif t.equals(pa.bool_()): - return "BOOLEAN" - elif t.equals(pa.int16()): - return "SMALLINT" - elif t.equals(pa.int32()): - return "INT" - elif t.equals(pa.int64()): - return "BIGINT" - elif t.equals(pa.float32()): - return "FLOAT4" - elif t.equals(pa.float64()): - return "FLOAT8" - elif t.equals(UNCONSTRAINED_DECIMAL): - return "DECIMAL" - elif pa.types.is_decimal(t): - return f"DECIMAL({t.precision},{t.scale})" - elif t.equals(pa.date32()): - return "DATE" - elif t.equals(pa.time64("us")): - return "TIME" - elif t.equals(pa.timestamp("us")): - return "TIMESTAMP" - elif t.equals(pa.month_day_nano_interval()): - return "INTERVAL" - elif t.equals(pa.string()): - return "VARCHAR" - elif t.equals(JSONB): - return "JSONB" - elif t.equals(pa.binary()): - return "BYTEA" - elif isinstance(t, pa.StructType): - return ( - "STRUCT<" - + ",".join( - f"f{i+1} {_data_type_to_string(field.type)}" - for i, field in enumerate(t) - ) - + ">" - ) - else: - raise ValueError(f"Unsupported type: {t}") diff --git a/src/expr/udf/python/risingwave/udf/health_check.py b/src/expr/udf/python/risingwave/udf/health_check.py deleted file mode 100644 index ad2d38681a6c..000000000000 --- a/src/expr/udf/python/risingwave/udf/health_check.py +++ /dev/null @@ -1,40 +0,0 @@ -# Copyright 2024 RisingWave Labs -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from pyarrow.flight import FlightClient -import sys - - -def check_udf_service_available(addr: str) -> bool: - """Check if the UDF service is available at the given address.""" - try: - client = FlightClient(f"grpc://{addr}") - client.wait_for_available() - return True - except Exception as e: - print(f"Error connecting to RisingWave UDF service: {str(e)}") - return False - - -if __name__ == "__main__": - if len(sys.argv) != 2: - print("usage: python3 health_check.py <server_address>") - sys.exit(1) - - server_address = sys.argv[1] - if check_udf_service_available(server_address): - print("OK") - else: - print("unavailable") - exit(-1) diff --git a/src/expr/udf/src/error.rs b/src/expr/udf/src/error.rs deleted file mode 100644 index fc6733052b13..000000000000 --- a/src/expr/udf/src/error.rs +++ /dev/null @@ -1,67 +0,0 @@ -// Copyright 2024 RisingWave Labs -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -use arrow_flight::error::FlightError; -use thiserror::Error; -use thiserror_ext::{Box, Construct}; - -/// A specialized `Result` type for UDF operations. -pub type Result<T> = std::result::Result<T, Error>; - -/// The error type for UDF operations. -#[derive(Error, Debug, Box, Construct)] -#[thiserror_ext(newtype(name = Error))] -pub enum ErrorInner { - #[error("failed to send requests to UDF service: {0}")] - Tonic(#[from] tonic::Status), - - #[error("failed to call UDF: {0}")] - Flight(#[from] FlightError), - - #[error("type mismatch: {0}")] - TypeMismatch(String), - - #[error("arrow error: {0}")] - Arrow(#[from] arrow_schema::ArrowError), - - #[error("UDF unsupported: {0}")] - // TODO(error-handling): should prefer use error types than strings. - Unsupported(String), - - #[error("UDF service returned no data")] - NoReturned, - - #[error("Flight service error: {0}")] - ServiceError(String), -} - -impl Error { - /// Returns true if the error is caused by a connection error.
-    pub fn is_connection_error(&self) -> bool {
-        match self.inner() {
-            // Connection refused
-            ErrorInner::Tonic(status) if status.code() == tonic::Code::Unavailable => true,
-            _ => false,
-        }
-    }
-
-    pub fn is_tonic_error(&self) -> bool {
-        matches!(
-            self.inner(),
-            ErrorInner::Tonic(_) | ErrorInner::Flight(FlightError::Tonic(_))
-        )
-    }
-}
-
-static_assertions::const_assert_eq!(std::mem::size_of::<Error>(), 8);
diff --git a/src/expr/udf/src/external.rs b/src/expr/udf/src/external.rs
deleted file mode 100644
index 7560638b0398..000000000000
--- a/src/expr/udf/src/external.rs
+++ /dev/null
@@ -1,260 +0,0 @@
-// Copyright 2024 RisingWave Labs
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-use std::str::FromStr;
-use std::time::Duration;
-
-use arrow_array::RecordBatch;
-use arrow_flight::decode::FlightRecordBatchStream;
-use arrow_flight::encode::FlightDataEncoderBuilder;
-use arrow_flight::error::FlightError;
-use arrow_flight::flight_service_client::FlightServiceClient;
-use arrow_flight::{FlightData, FlightDescriptor};
-use arrow_schema::Schema;
-use cfg_or_panic::cfg_or_panic;
-use futures_util::{stream, FutureExt, Stream, StreamExt, TryStreamExt};
-use ginepro::{LoadBalancedChannel, ResolutionStrategy};
-use risingwave_common::util::addr::HostAddr;
-use thiserror_ext::AsReport;
-use tokio::time::Duration as TokioDuration;
-use tonic::transport::Channel;
-
-use crate::metrics::GLOBAL_METRICS;
-use crate::{Error, Result};
-
-// Interval between two successive probes of the UDF DNS.
-const DNS_PROBE_INTERVAL_SECS: u64 = 5;
-// Timeout duration for performing an eager DNS resolution.
-const EAGER_DNS_RESOLVE_TIMEOUT_SECS: u64 = 5;
-const REQUEST_TIMEOUT_SECS: u64 = 5;
-const CONNECT_TIMEOUT_SECS: u64 = 5;
-
-/// Client for external function service based on Arrow Flight.
-#[derive(Debug)]
-pub struct ArrowFlightUdfClient {
-    client: FlightServiceClient<Channel>,
-    addr: String,
-}
-
-// TODO: support UDF in simulation
-#[cfg_or_panic(not(madsim))]
-impl ArrowFlightUdfClient {
-    /// Connect to a UDF service.
-    pub async fn connect(addr: &str) -> Result<Self> {
-        Self::connect_inner(
-            addr,
-            ResolutionStrategy::Eager {
-                timeout: TokioDuration::from_secs(EAGER_DNS_RESOLVE_TIMEOUT_SECS),
-            },
-        )
-        .await
-    }
-
-    /// Connect to a UDF service lazily (i.e. only when the first request is sent).
-    pub fn connect_lazy(addr: &str) -> Result<Self> {
-        Self::connect_inner(addr, ResolutionStrategy::Lazy)
-            .now_or_never()
-            .unwrap()
-    }
-
-    async fn connect_inner(
-        mut addr: &str,
-        resolution_strategy: ResolutionStrategy,
-    ) -> Result<Self> {
-        if addr.starts_with("http://") {
-            addr = addr.strip_prefix("http://").unwrap();
-        }
-        if addr.starts_with("https://") {
-            addr = addr.strip_prefix("https://").unwrap();
-        }
-        let host_addr = HostAddr::from_str(addr).map_err(|e| {
-            Error::service_error(format!("invalid address: {}, err: {}", addr, e.as_report()))
-        })?;
-        let channel = LoadBalancedChannel::builder((host_addr.host.clone(), host_addr.port))
-            .dns_probe_interval(std::time::Duration::from_secs(DNS_PROBE_INTERVAL_SECS))
-            .timeout(Duration::from_secs(REQUEST_TIMEOUT_SECS))
-            .connect_timeout(Duration::from_secs(CONNECT_TIMEOUT_SECS))
-            .resolution_strategy(resolution_strategy)
-            .channel()
-            .await
-            .map_err(|e| {
-                Error::service_error(format!(
-                    "failed to create LoadBalancedChannel, address: {}, err: {}",
-                    host_addr,
-                    e.as_report()
-                ))
-            })?;
-        let client = FlightServiceClient::new(channel.into());
-        Ok(Self {
-            client,
-            addr: addr.into(),
-        })
-    }
-
-    /// Check if the function is available and the schema is match.
-    pub async fn check(&self, id: &str, args: &Schema, returns: &Schema) -> Result<()> {
-        let descriptor = FlightDescriptor::new_path(vec![id.into()]);
-
-        let response = self.client.clone().get_flight_info(descriptor).await?;
-
-        // check schema
-        let info = response.into_inner();
-        let input_num = info.total_records as usize;
-        let full_schema = Schema::try_from(info).map_err(|e| {
-            FlightError::DecodeError(format!("Error decoding schema: {}", e.as_report()))
-        })?;
-        if input_num > full_schema.fields.len() {
-            return Err(Error::service_error(format!(
-                "function {:?} schema info not consistency: input_num: {}, total_fields: {}",
-                id,
-                input_num,
-                full_schema.fields.len()
-            )));
-        }
-
-        let (input_fields, return_fields) = full_schema.fields.split_at(input_num);
-        let actual_input_types: Vec<_> = input_fields.iter().map(|f| f.data_type()).collect();
-        let actual_result_types: Vec<_> = return_fields.iter().map(|f| f.data_type()).collect();
-        let expect_input_types: Vec<_> = args.fields.iter().map(|f| f.data_type()).collect();
-        let expect_result_types: Vec<_> = returns.fields.iter().map(|f| f.data_type()).collect();
-        if !data_types_match(&expect_input_types, &actual_input_types) {
-            return Err(Error::type_mismatch(format!(
-                "function: {:?}, expect arguments: {:?}, actual: {:?}",
-                id, expect_input_types, actual_input_types
-            )));
-        }
-        if !data_types_match(&expect_result_types, &actual_result_types) {
-            return Err(Error::type_mismatch(format!(
-                "function: {:?}, expect return: {:?}, actual: {:?}",
-                id, expect_result_types, actual_result_types
-            )));
-        }
-        Ok(())
-    }
-
-    /// Call a function.
-    pub async fn call(&self, id: &str, input: RecordBatch) -> Result<RecordBatch> {
-        self.call_internal(id, input).await
-    }
-
-    async fn call_internal(&self, id: &str, input: RecordBatch) -> Result<RecordBatch> {
-        let mut output_stream = self
-            .call_stream_internal(id, stream::once(async { input }))
-            .await?;
-        let mut batches = vec![];
-        while let Some(batch) = output_stream.next().await {
-            batches.push(batch?);
-        }
-        Ok(arrow_select::concat::concat_batches(
-            output_stream.schema().ok_or_else(Error::no_returned)?,
-            batches.iter(),
-        )?)
-    }
-
-    /// Call a function, retry up to 5 times / 3s if connection is broken.
-    pub async fn call_with_retry(&self, id: &str, input: RecordBatch) -> Result<RecordBatch> {
-        let mut backoff = Duration::from_millis(100);
-        for i in 0..5 {
-            match self.call(id, input.clone()).await {
-                Err(err) if err.is_connection_error() && i != 4 => {
-                    tracing::error!(error = %err.as_report(), "UDF connection error. retry...");
-                }
-                ret => return ret,
-            }
-            tokio::time::sleep(backoff).await;
-            backoff *= 2;
-        }
-        unreachable!()
-    }
-
-    /// Always retry on connection error
-    pub async fn call_with_always_retry_on_network_error(
-        &self,
-        id: &str,
-        input: RecordBatch,
-        fragment_id: &str,
-    ) -> Result<RecordBatch> {
-        let mut backoff = Duration::from_millis(100);
-        let metrics = &*GLOBAL_METRICS;
-        let labels: &[&str; 4] = &[&self.addr, "external", id, fragment_id];
-        loop {
-            match self.call(id, input.clone()).await {
-                Err(err) if err.is_tonic_error() => {
-                    tracing::error!(error = %err.as_report(), "UDF tonic error. retry...");
-                }
-                ret => {
-                    if ret.is_err() {
-                        tracing::error!(error = %ret.as_ref().unwrap_err().as_report(), "UDF error. exiting...");
-                    }
-                    return ret;
-                }
-            }
-            metrics.udf_retry_count.with_label_values(labels).inc();
-            tokio::time::sleep(backoff).await;
-            backoff *= 2;
-        }
-    }
-
-    /// Call a function with streaming input and output.
-    #[panic_return = "Result<stream::Empty<_>>"]
-    pub async fn call_stream(
-        &self,
-        id: &str,
-        inputs: impl Stream<Item = RecordBatch> + Send + 'static,
-    ) -> Result<impl Stream<Item = Result<RecordBatch>> + Send + 'static> {
-        Ok(self
-            .call_stream_internal(id, inputs)
-            .await?
-            .map_err(|e| e.into()))
-    }
-
-    async fn call_stream_internal(
-        &self,
-        id: &str,
-        inputs: impl Stream<Item = RecordBatch> + Send + 'static,
-    ) -> Result<FlightRecordBatchStream> {
-        let descriptor = FlightDescriptor::new_path(vec![id.into()]);
-        let flight_data_stream =
-            FlightDataEncoderBuilder::new()
-                .build(inputs.map(Ok))
-                .map(move |res| FlightData {
-                    // TODO: fill descriptor only for the first message
-                    flight_descriptor: Some(descriptor.clone()),
-                    ..res.unwrap()
-                });
-
-        // call `do_exchange` on Flight server
-        let response = self.client.clone().do_exchange(flight_data_stream).await?;
-
-        // decode response
-        let stream = response.into_inner();
-        Ok(FlightRecordBatchStream::new_from_flight_data(
-            // convert tonic::Status to FlightError
-            stream.map_err(|e| e.into()),
-        ))
-    }
-
-    pub fn get_addr(&self) -> &str {
-        &self.addr
-    }
-}
-
-/// Check if two list of data types match, ignoring field names.
-fn data_types_match(a: &[&arrow_schema::DataType], b: &[&arrow_schema::DataType]) -> bool {
-    if a.len() != b.len() {
-        return false;
-    }
-    #[allow(clippy::disallowed_methods)]
-    a.iter().zip(b.iter()).all(|(a, b)| a.equals_datatype(b))
-}
diff --git a/src/expr/udf/src/lib.rs b/src/expr/udf/src/lib.rs
deleted file mode 100644
index ddd8cf1bdeab..000000000000
--- a/src/expr/udf/src/lib.rs
+++ /dev/null
@@ -1,24 +0,0 @@
-// Copyright 2024 RisingWave Labs
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#![feature(error_generic_member_access)]
-#![feature(lazy_cell)]
-
-mod error;
-mod external;
-pub mod metrics;
-
-pub use error::{Error, Result};
-pub use external::ArrowFlightUdfClient;
-pub use metrics::GLOBAL_METRICS;
diff --git a/src/expr/udf/src/metrics.rs b/src/expr/udf/src/metrics.rs
deleted file mode 100644
index 50ef1b068307..000000000000
--- a/src/expr/udf/src/metrics.rs
+++ /dev/null
@@ -1,111 +0,0 @@
-// Copyright 2024 RisingWave Labs
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-use std::sync::LazyLock;
-
-use prometheus::{
-    exponential_buckets, register_histogram_vec_with_registry,
-    register_int_counter_vec_with_registry, HistogramVec, IntCounterVec, Registry,
-};
-use risingwave_common::monitor::GLOBAL_METRICS_REGISTRY;
-
-/// Monitor metrics for UDF.
-#[derive(Debug, Clone)]
-pub struct Metrics {
-    /// Number of successful UDF calls.
-    pub udf_success_count: IntCounterVec,
-    /// Number of failed UDF calls.
-    pub udf_failure_count: IntCounterVec,
-    /// Total number of retried UDF calls.
-    pub udf_retry_count: IntCounterVec,
-    /// Input chunk rows of UDF calls.
-    pub udf_input_chunk_rows: HistogramVec,
-    /// The latency of UDF calls in seconds.
-    pub udf_latency: HistogramVec,
-    /// Total number of input rows of UDF calls.
-    pub udf_input_rows: IntCounterVec,
-    /// Total number of input bytes of UDF calls.
-    pub udf_input_bytes: IntCounterVec,
-}
-
-/// Global UDF metrics.
-pub static GLOBAL_METRICS: LazyLock<Metrics> =
-    LazyLock::new(|| Metrics::new(&GLOBAL_METRICS_REGISTRY));
-
-impl Metrics {
-    fn new(registry: &Registry) -> Self {
-        let labels = &["link", "language", "name", "fragment_id"];
-        let udf_success_count = register_int_counter_vec_with_registry!(
-            "udf_success_count",
-            "Total number of successful UDF calls",
-            labels,
-            registry
-        )
-        .unwrap();
-        let udf_failure_count = register_int_counter_vec_with_registry!(
-            "udf_failure_count",
-            "Total number of failed UDF calls",
-            labels,
-            registry
-        )
-        .unwrap();
-        let udf_retry_count = register_int_counter_vec_with_registry!(
-            "udf_retry_count",
-            "Total number of retried UDF calls",
-            labels,
-            registry
-        )
-        .unwrap();
-        let udf_input_chunk_rows = register_histogram_vec_with_registry!(
-            "udf_input_chunk_rows",
-            "Input chunk rows of UDF calls",
-            labels,
-            exponential_buckets(1.0, 2.0, 10).unwrap(), // 1 to 1024
-            registry
-        )
-        .unwrap();
-        let udf_latency = register_histogram_vec_with_registry!(
-            "udf_latency",
-            "The latency(s) of UDF calls",
-            labels,
-            exponential_buckets(0.000001, 2.0, 30).unwrap(), // 1us to 1000s
-            registry
-        )
-        .unwrap();
-        let udf_input_rows = register_int_counter_vec_with_registry!(
-            "udf_input_rows",
-            "Total number of input rows of UDF calls",
-            labels,
-            registry
-        )
-        .unwrap();
-        let udf_input_bytes = register_int_counter_vec_with_registry!(
-            "udf_input_bytes",
-            "Total number of input bytes of UDF calls",
-            labels,
-            registry
-        )
-        .unwrap();
-
-        Metrics {
-            udf_success_count,
-            udf_failure_count,
-            udf_retry_count,
-            udf_input_chunk_rows,
-            udf_latency,
-            udf_input_rows,
-            udf_input_bytes,
-        }
-    }
-}
diff --git a/src/frontend/Cargo.toml b/src/frontend/Cargo.toml
index 3cba7afe8266..6a09101dc3e5 100644
--- a/src/frontend/Cargo.toml
+++ b/src/frontend/Cargo.toml
@@ -18,6 +18,7 @@ normal = ["workspace-hack"]
 anyhow = "1"
 arc-swap = "1"
 arrow-schema = { workspace = true }
+arrow-udf-flight = { workspace = true }
 arrow-udf-wasm = { workspace = true }
 async-recursion = "1.1.0"
 async-trait = "0.1"
@@ -72,7 +73,6 @@ risingwave_pb = { workspace = true }
 risingwave_rpc_client = { workspace = true }
 risingwave_sqlparser = { workspace = true }
 risingwave_storage = { workspace = true }
-risingwave_udf = { workspace = true }
 risingwave_variables = { workspace = true }
 rw_futures_util = { workspace = true }
 serde = { version = "1", features = ["derive"] }
diff --git a/src/frontend/src/handler/create_function.rs b/src/frontend/src/handler/create_function.rs
index a23632fbd62c..471145a12a3e 100644
--- a/src/frontend/src/handler/create_function.rs
+++ b/src/frontend/src/handler/create_function.rs
@@ -12,7 +12,7 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-use anyhow::{anyhow, Context};
+use anyhow::Context;
 use arrow_schema::Fields;
 use bytes::Bytes;
 use itertools::Itertools;
@@ -20,11 +20,10 @@ use pgwire::pg_response::StatementType;
 use risingwave_common::array::arrow::{ToArrow, UdfArrowConvert};
 use risingwave_common::catalog::FunctionId;
 use risingwave_common::types::DataType;
-use risingwave_expr::expr::get_or_create_wasm_runtime;
+use risingwave_expr::expr::{get_or_create_flight_client, get_or_create_wasm_runtime};
 use risingwave_pb::catalog::function::{Kind, ScalarFunction, TableFunction};
 use risingwave_pb::catalog::Function;
 use risingwave_sqlparser::ast::{CreateFunctionBody, ObjectName, OperateFunctionArg};
-use risingwave_udf::ArrowFlightUdfClient;
 
 use super::*;
 use crate::catalog::CatalogError;
@@ -167,13 +166,12 @@ pub async fn handle_create_function(
 
             // check UDF server
             {
-                let client = ArrowFlightUdfClient::connect(&l)
-                    .await
-                    .map_err(|e| anyhow!(e))?;
-                /// A helper function to create a unnamed field from data type.
-                fn to_field(data_type: &DataType) -> Result<arrow_schema::Field> {
-                    Ok(UdfArrowConvert.to_arrow_field("", data_type)?)
-                }
+                let client = get_or_create_flight_client(&l)?;
+                let convert = UdfArrowConvert {
+                    legacy: client.protocol_version() == 1,
+                };
+                // A helper function to create a unnamed field from data type.
+                let to_field = |data_type| convert.to_arrow_field("", data_type);
                 let args = arrow_schema::Schema::new(
                     arg_types
                         .iter()
@@ -183,15 +181,29 @@ pub async fn handle_create_function(
                 let returns = arrow_schema::Schema::new(match kind {
                     Kind::Scalar(_) => vec![to_field(&return_type)?],
                     Kind::Table(_) => vec![
-                        arrow_schema::Field::new("row_index", arrow_schema::DataType::Int32, true),
+                        arrow_schema::Field::new("row", arrow_schema::DataType::Int32, true),
                         to_field(&return_type)?,
                     ],
                     _ => unreachable!(),
                 });
-                client
-                    .check(&identifier, &args, &returns)
+                let function = client
+                    .get(&identifier)
                     .await
                     .context("failed to check UDF signature")?;
+                if !data_types_match(&function.args, &args) {
+                    return Err(ErrorCode::InvalidParameterValue(format!(
+                        "argument type mismatch, expect: {:?}, actual: {:?}",
+                        args, function.args,
+                    ))
+                    .into());
+                }
+                if !data_types_match(&function.returns, &returns) {
+                    return Err(ErrorCode::InvalidParameterValue(format!(
+                        "return type mismatch, expect: {:?}, actual: {:?}",
+                        returns, function.returns,
+                    ))
+                    .into());
+                }
             }
             link = Some(l);
         }
@@ -276,6 +288,7 @@ pub async fn handle_create_function(
 
                     let wasm_binary = tokio::task::spawn_blocking(move || {
                         let mut opts = arrow_udf_wasm::build::BuildOpts::default();
+                        opts.arrow_udf_version = Some("0.3".to_string());
                         opts.script = script;
                         // use a fixed tempdir to reuse the build cache
                         opts.tempdir = Some(std::env::temp_dir().join("risingwave-rust-udf"));
@@ -309,6 +322,13 @@
                 }
             };
             let runtime = get_or_create_wasm_runtime(&wasm_binary)?;
+            if runtime.abi_version().0 <= 2 {
+                return Err(ErrorCode::InvalidParameterValue(
+                    "legacy arrow-udf is no longer supported. please update arrow-udf to 0.3+"
+                        .to_string(),
+                )
+                .into());
+            }
             let identifier_v1 = wasm_identifier_v1(
                 &function_name,
                 &arg_types,
@@ -457,13 +477,13 @@ fn wasm_identifier_v1(
 fn datatype_name(ty: &DataType) -> String {
     match ty {
         DataType::Boolean => "boolean".to_string(),
-        DataType::Int16 => "int2".to_string(),
-        DataType::Int32 => "int4".to_string(),
-        DataType::Int64 => "int8".to_string(),
-        DataType::Float32 => "float4".to_string(),
-        DataType::Float64 => "float8".to_string(),
-        DataType::Date => "date".to_string(),
-        DataType::Time => "time".to_string(),
+        DataType::Int16 => "int16".to_string(),
+        DataType::Int32 => "int32".to_string(),
+        DataType::Int64 => "int64".to_string(),
+        DataType::Float32 => "float32".to_string(),
+        DataType::Float64 => "float64".to_string(),
+        DataType::Date => "date32".to_string(),
+        DataType::Time => "time64".to_string(),
         DataType::Timestamp => "timestamp".to_string(),
         DataType::Timestamptz => "timestamptz".to_string(),
         DataType::Interval => "interval".to_string(),
@@ -471,8 +491,8 @@ fn datatype_name(ty: &DataType) -> String {
         DataType::Jsonb => "json".to_string(),
         DataType::Serial => "serial".to_string(),
         DataType::Int256 => "int256".to_string(),
-        DataType::Bytea => "bytea".to_string(),
-        DataType::Varchar => "varchar".to_string(),
+        DataType::Bytea => "binary".to_string(),
+        DataType::Varchar => "string".to_string(),
         DataType::List(inner) => format!("{}[]", datatype_name(inner)),
         DataType::Struct(s) => format!(
             "struct<{}>",
@@ -482,3 +502,15 @@
         ),
     }
 }
+
+/// Check if two list of data types match, ignoring field names.
+fn data_types_match(a: &arrow_schema::Schema, b: &arrow_schema::Schema) -> bool {
+    if a.fields().len() != b.fields().len() {
+        return false;
+    }
+    #[allow(clippy::disallowed_methods)]
+    a.fields()
+        .iter()
+        .zip(b.fields())
+        .all(|(a, b)| a.data_type().equals_datatype(b.data_type()))
+}
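
Note: the `data_types_match` helper added above compares two Arrow schemas field by field on types only, so the unnamed fields produced by `to_field` still match a server schema that carries real column names. Below is a minimal standalone sketch of this behavior, not part of the patch; it assumes only the `arrow-schema` crate, and the helper is adapted from the hunk above with the repo-specific clippy attribute dropped.

    use arrow_schema::{DataType, Field, Schema};

    /// Check if two lists of data types match, ignoring field names.
    /// (Adapted from the `data_types_match` added in create_function.rs.)
    fn data_types_match(a: &Schema, b: &Schema) -> bool {
        if a.fields().len() != b.fields().len() {
            return false;
        }
        a.fields()
            .iter()
            .zip(b.fields())
            .all(|(a, b)| a.data_type().equals_datatype(b.data_type()))
    }

    fn main() {
        // Field names differ but types agree: accepted.
        let expect = Schema::new(vec![Field::new("", DataType::Int32, true)]);
        let actual = Schema::new(vec![Field::new("x", DataType::Int32, true)]);
        assert!(data_types_match(&expect, &actual));

        // A differing type is rejected; in handle_create_function this surfaces
        // as the "argument type mismatch" / "return type mismatch" errors.
        let wrong = Schema::new(vec![Field::new("x", DataType::Int64, true)]);
        assert!(!data_types_match(&expect, &wrong));
    }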
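For context on `opts.arrow_udf_version = Some("0.3")` and the `abi_version().0 <= 2` guard: Rust UDFs compiled to WASM must now be built against arrow-udf 0.3+, whose exported identifiers use the arrow-style type names that `datatype_name` now emits (`int16`, `int32`, `float64`, `string`, `binary`, ...). A hypothetical user-side function is sketched below; the `#[function]` signature syntax is assumed from the arrow-udf README and may vary between versions, so treat it as illustrative rather than definitive.

    use arrow_udf::function;

    // Exported under an identifier derived from this signature; the frontend
    // now looks it up using arrow-style type names (see datatype_name above).
    #[function("gcd(int, int) -> int")]
    fn gcd(mut a: i32, mut b: i32) -> i32 {
        while b != 0 {
            (a, b) = (b, a % b);
        }
        a
    }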