From eafc3795668349879fb46112d7b17dc82364f09b Mon Sep 17 00:00:00 2001 From: William Wen Date: Fri, 22 Sep 2023 14:13:57 +0800 Subject: [PATCH 1/8] feat(java-binding): generate jni method signature with macro --- Cargo.lock | 2 + src/java_binding/Cargo.toml | 1 + src/java_binding/make-java-binding.toml | 4 +- src/java_binding/src/lib.rs | 16 +- src/jni_core/Cargo.toml | 1 + src/jni_core/src/jvm_runtime.rs | 201 +++----------------- src/jni_core/src/lib.rs | 34 +++- src/jni_core/src/macros.rs | 238 ++++++++++++++++++++++++ 8 files changed, 310 insertions(+), 187 deletions(-) create mode 100644 src/jni_core/src/macros.rs diff --git a/Cargo.lock b/Cargo.lock index cd3441db7b2e2..60757fdf9d104 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -7201,6 +7201,7 @@ dependencies = [ name = "risingwave_java_binding" version = "0.1.0" dependencies = [ + "jni", "prost", "risingwave_common", "risingwave_expr", @@ -7219,6 +7220,7 @@ dependencies = [ "itertools 0.11.0", "jni", "madsim-tokio", + "paste", "prost", "risingwave_common", "risingwave_expr", diff --git a/src/java_binding/Cargo.toml b/src/java_binding/Cargo.toml index d8d90693f44a6..2da34aa53af77 100644 --- a/src/java_binding/Cargo.toml +++ b/src/java_binding/Cargo.toml @@ -10,6 +10,7 @@ ignored = ["workspace-hack"] normal = ["workspace-hack"] [dependencies] +jni = "0.21.1" prost = "0.11" risingwave_common = { workspace = true } risingwave_jni_core = { workspace = true } diff --git a/src/java_binding/make-java-binding.toml b/src/java_binding/make-java-binding.toml index 3be65ec2158a6..957cec9c762f5 100644 --- a/src/java_binding/make-java-binding.toml +++ b/src/java_binding/make-java-binding.toml @@ -15,7 +15,7 @@ script = ''' #!/usr/bin/env bash set -ex cd java -mvn install --no-transfer-progress --pl java-binding-integration-test --am -DskipTests=true +mvn install --no-transfer-progress --pl java-binding-integration-test --am -DskipTests=true -Dmaven.javadoc.skip=true mvn dependency:copy-dependencies --no-transfer-progress --pl java-binding-integration-test ''' @@ -109,7 +109,7 @@ RISINGWAVE_ROOT=$(git rev-parse --show-toplevel) cd ${RISINGWAVE_ROOT}/java -mvn install --pl java-binding-benchmark --am -DskipTests=true +mvn install --pl java-binding-benchmark --am -DskipTests=true -Dmaven.javadoc.skip=true mvn dependency:copy-dependencies --pl java-binding-benchmark diff --git a/src/java_binding/src/lib.rs b/src/java_binding/src/lib.rs index aa7e564ed1ace..6edf4d29ce557 100644 --- a/src/java_binding/src/lib.rs +++ b/src/java_binding/src/lib.rs @@ -12,4 +12,18 @@ // See the License for the specific language governing permissions and // limitations under the License. -pub use risingwave_jni_core::*; +#![feature(result_option_inspect)] + +use std::ffi::c_void; + +use jni::sys::{jint, JNI_VERSION_1_2}; +use jni::JavaVM; +use risingwave_jni_core::register_native_method_for_jvm; + +#[no_mangle] +#[allow(non_snake_case)] +pub extern "system" fn JNI_OnLoad(jvm: JavaVM, _reserved: *mut c_void) -> jint { + let _ = register_native_method_for_jvm(&jvm) + .inspect_err(|_e| eprintln!("unable to register native method")); + JNI_VERSION_1_2 +} diff --git a/src/jni_core/Cargo.toml b/src/jni_core/Cargo.toml index c8bba371c8dea..40195108b1039 100644 --- a/src/jni_core/Cargo.toml +++ b/src/jni_core/Cargo.toml @@ -14,6 +14,7 @@ bytes = "1" futures = { version = "0.3", default-features = false, features = ["alloc"] } itertools = "0.11" jni = "0.21.1" +paste = "1" prost = "0.11" risingwave_common = { workspace = true } risingwave_hummock_sdk = { workspace = true } diff --git a/src/jni_core/src/jvm_runtime.rs b/src/jni_core/src/jvm_runtime.rs index 5559abd3ffa3f..ec632fc612508 100644 --- a/src/jni_core/src/jvm_runtime.rs +++ b/src/jni_core/src/jvm_runtime.rs @@ -91,12 +91,12 @@ pub static JVM: LazyLock> = LazyLock::new(|| { tracing::info!("initialize JVM successfully"); - register_native_method_for_jvm(&jvm); + register_native_method_for_jvm(&jvm).unwrap(); Ok(jvm) }); -fn register_native_method_for_jvm(jvm: &JavaVM) { +pub fn register_native_method_for_jvm(jvm: &JavaVM) -> Result<(), jni::errors::Error> { let mut env = jvm .attach_current_thread() .inspect_err(|e| tracing::error!("jvm attach thread error: {:?}", e)) @@ -106,179 +106,30 @@ fn register_native_method_for_jvm(jvm: &JavaVM) { .find_class("com/risingwave/java/binding/Binding") .inspect_err(|e| tracing::error!("jvm find class error: {:?}", e)) .unwrap(); - env.register_native_methods( - binding_class, - &[ - NativeMethod { - name: JNIString::from("vnodeCount"), - sig: JNIString::from("()I"), - fn_ptr: crate::Java_com_risingwave_java_binding_Binding_vnodeCount as *mut c_void, - }, - #[cfg(not(madsim))] - NativeMethod { - name: JNIString::from("hummockIteratorNew"), - sig: JNIString::from("([B)J"), - fn_ptr: crate::Java_com_risingwave_java_binding_Binding_hummockIteratorNew - as *mut c_void, - }, - #[cfg(not(madsim))] - NativeMethod { - name: JNIString::from("hummockIteratorNext"), - sig: JNIString::from("(J)J"), - fn_ptr: crate::Java_com_risingwave_java_binding_Binding_hummockIteratorNext - as *mut c_void, - }, - NativeMethod { - name: JNIString::from("hummockIteratorClose"), - sig: JNIString::from("(J)V"), - fn_ptr: crate::Java_com_risingwave_java_binding_Binding_hummockIteratorClose - as *mut c_void, - }, - NativeMethod { - name: JNIString::from("rowGetKey"), - sig: JNIString::from("(J)[B"), - fn_ptr: crate::Java_com_risingwave_java_binding_Binding_rowGetKey as *mut c_void, - }, - NativeMethod { - name: JNIString::from("rowGetOp"), - sig: JNIString::from("(J)I"), - fn_ptr: crate::Java_com_risingwave_java_binding_Binding_rowGetOp as *mut c_void, - }, - NativeMethod { - name: JNIString::from("rowIsNull"), - sig: JNIString::from("(JI)Z"), - fn_ptr: crate::Java_com_risingwave_java_binding_Binding_rowIsNull as *mut c_void, - }, - NativeMethod { - name: JNIString::from("rowGetInt16Value"), - sig: JNIString::from("(JI)S"), - fn_ptr: crate::Java_com_risingwave_java_binding_Binding_rowGetInt16Value - as *mut c_void, - }, - NativeMethod { - name: JNIString::from("rowGetInt32Value"), - sig: JNIString::from("(JI)I"), - fn_ptr: crate::Java_com_risingwave_java_binding_Binding_rowGetInt32Value - as *mut c_void, - }, - NativeMethod { - name: JNIString::from("rowGetInt64Value"), - sig: JNIString::from("(JI)J"), - fn_ptr: crate::Java_com_risingwave_java_binding_Binding_rowGetInt64Value - as *mut c_void, - }, - NativeMethod { - name: JNIString::from("rowGetFloatValue"), - sig: JNIString::from("(JI)F"), - fn_ptr: crate::Java_com_risingwave_java_binding_Binding_rowGetFloatValue - as *mut c_void, - }, - NativeMethod { - name: JNIString::from("rowGetDoubleValue"), - sig: JNIString::from("(JI)D"), - fn_ptr: crate::Java_com_risingwave_java_binding_Binding_rowGetDoubleValue - as *mut c_void, - }, - NativeMethod { - name: JNIString::from("rowGetBooleanValue"), - sig: JNIString::from("(JI)Z"), - fn_ptr: crate::Java_com_risingwave_java_binding_Binding_rowGetBooleanValue - as *mut c_void, - }, - NativeMethod { - name: JNIString::from("rowGetStringValue"), - sig: JNIString::from("(JI)Ljava/lang/String;"), - fn_ptr: crate::Java_com_risingwave_java_binding_Binding_rowGetStringValue - as *mut c_void, - }, - NativeMethod { - name: JNIString::from("rowGetTimestampValue"), - sig: JNIString::from("(JI)Ljava/sql/Timestamp;"), - fn_ptr: crate::Java_com_risingwave_java_binding_Binding_rowGetTimestampValue - as *mut c_void, - }, - NativeMethod { - name: JNIString::from("rowGetDecimalValue"), - sig: JNIString::from("(JI)Ljava/math/BigDecimal;"), - fn_ptr: crate::Java_com_risingwave_java_binding_Binding_rowGetDecimalValue - as *mut c_void, - }, - NativeMethod { - name: JNIString::from("rowGetTimeValue"), - sig: JNIString::from("(JI)Ljava/sql/Time;"), - fn_ptr: crate::Java_com_risingwave_java_binding_Binding_rowGetTimeValue - as *mut c_void, - }, - NativeMethod { - name: JNIString::from("rowGetDateValue"), - sig: JNIString::from("(JI)Ljava/sql/Date;"), - fn_ptr: crate::Java_com_risingwave_java_binding_Binding_rowGetDateValue - as *mut c_void, - }, - NativeMethod { - name: JNIString::from("rowGetIntervalValue"), - sig: JNIString::from("(JI)Ljava/lang/String;"), - fn_ptr: crate::Java_com_risingwave_java_binding_Binding_rowGetIntervalValue - as *mut c_void, - }, - NativeMethod { - name: JNIString::from("rowGetJsonbValue"), - sig: JNIString::from("(JI)Ljava/lang/String;"), - fn_ptr: crate::Java_com_risingwave_java_binding_Binding_rowGetJsonbValue - as *mut c_void, - }, - NativeMethod { - name: JNIString::from("rowGetByteaValue"), - sig: JNIString::from("(JI)[B"), - fn_ptr: crate::Java_com_risingwave_java_binding_Binding_rowGetByteaValue - as *mut c_void, - }, - NativeMethod { - name: JNIString::from("rowGetArrayValue"), - sig: JNIString::from("(JILjava/lang/Class;)Ljava/lang/Object;"), - fn_ptr: crate::Java_com_risingwave_java_binding_Binding_rowGetArrayValue - as *mut c_void, - }, - NativeMethod { - name: JNIString::from("rowClose"), - sig: JNIString::from("(J)V"), - fn_ptr: crate::Java_com_risingwave_java_binding_Binding_rowClose as *mut c_void, - }, - NativeMethod { - name: JNIString::from("streamChunkIteratorNew"), - sig: JNIString::from("([B)J"), - fn_ptr: crate::Java_com_risingwave_java_binding_Binding_streamChunkIteratorNew - as *mut c_void, - }, - NativeMethod { - name: JNIString::from("streamChunkIteratorNext"), - sig: JNIString::from("(J)J"), - fn_ptr: crate::Java_com_risingwave_java_binding_Binding_streamChunkIteratorNext - as *mut c_void, - }, - NativeMethod { - name: JNIString::from("streamChunkIteratorClose"), - sig: JNIString::from("(J)V"), - fn_ptr: crate::Java_com_risingwave_java_binding_Binding_streamChunkIteratorClose - as *mut c_void, - }, - NativeMethod { - name: JNIString::from("streamChunkIteratorFromPretty"), - sig: JNIString::from("(Ljava/lang/String;)J"), - fn_ptr: - crate::Java_com_risingwave_java_binding_Binding_streamChunkIteratorFromPretty - as *mut c_void, - }, - NativeMethod { - name: JNIString::from("sendCdcSourceMsgToChannel"), - sig: JNIString::from("(J[B)Z"), - fn_ptr: crate::Java_com_risingwave_java_binding_Binding_sendCdcSourceMsgToChannel - as *mut c_void, - }, - ], - ) - .inspect_err(|e| tracing::error!("jvm register native methods error: {:?}", e)) - .unwrap(); + use crate::*; + macro_rules! gen_native_method_array { + () => {{ + $crate::for_all_native_methods! {gen_native_method_array} + }}; + ({$({ $func_name:ident, $sig:expr }),*}) => { + [ + $( + { + let fn_ptr = paste::paste! {[ ]} as *mut c_void; + NativeMethod { + name: JNIString::from(stringify! {$func_name}), + sig: JNIString::from($sig), + fn_ptr, + } + }, + )* + + ] + } + } + env.register_native_methods(binding_class, &gen_native_method_array!()) + .inspect_err(|e| tracing::error!("jvm register native methods error: {:?}", e))?; tracing::info!("register native methods for jvm successfully"); + Ok(()) } diff --git a/src/jni_core/src/lib.rs b/src/jni_core/src/lib.rs index 6fa6f2f10e991..86ac65b8b4f0c 100644 --- a/src/jni_core/src/lib.rs +++ b/src/jni_core/src/lib.rs @@ -20,6 +20,7 @@ pub mod hummock_iterator; pub mod jvm_runtime; +mod macros; pub mod stream_chunk_iterator; use std::backtrace::Backtrace; @@ -51,6 +52,7 @@ use thiserror::Error; use tokio::runtime::Runtime; use tokio::sync::mpsc::Sender; +pub use crate::jvm_runtime::register_native_method_for_jvm; use crate::stream_chunk_iterator::{StreamChunkIterator, StreamChunkRow}; pub type GetEventStreamJniSender = Sender; @@ -296,30 +298,44 @@ pub extern "system" fn Java_com_risingwave_java_binding_Binding_vnodeCount( VirtualNode::COUNT as jint } -#[cfg(not(madsim))] #[no_mangle] pub extern "system" fn Java_com_risingwave_java_binding_Binding_hummockIteratorNew<'a>( env: EnvParam<'a>, read_plan: JByteArray<'a>, ) -> Pointer<'static, HummockJavaBindingIterator> { execute_and_catch(env, move |env| { - let read_plan = Message::decode(to_guarded_slice(&read_plan, env)?.deref())?; - let iter = RUNTIME.block_on(HummockJavaBindingIterator::new(read_plan))?; - Ok(iter.into()) + #[cfg(madsim)] + { + unreachable!() + } + + #[cfg(not(madsim))] + { + let read_plan = Message::decode(to_guarded_slice(&read_plan, env)?.deref())?; + let iter = RUNTIME.block_on(HummockJavaBindingIterator::new(read_plan))?; + Ok(iter.into()) + } }) } -#[cfg(not(madsim))] #[no_mangle] pub extern "system" fn Java_com_risingwave_java_binding_Binding_hummockIteratorNext<'a>( env: EnvParam<'a>, mut pointer: Pointer<'a, HummockJavaBindingIterator>, ) -> Pointer<'static, JavaBindingRow> { execute_and_catch(env, move |_env| { - let iter = pointer.as_mut(); - match RUNTIME.block_on(iter.next())? { - None => Ok(Pointer::null()), - Some(row) => Ok(JavaBindingRow::with_keyed(row, iter.class_cache.clone()).into()), + #[cfg(madsim)] + { + unreachable!() + } + + #[cfg(not(madsim))] + { + let iter = pointer.as_mut(); + match RUNTIME.block_on(iter.next())? { + None => Ok(Pointer::null()), + Some(row) => Ok(JavaBindingRow::with_keyed(row, iter.class_cache.clone()).into()), + } } }) } diff --git a/src/jni_core/src/macros.rs b/src/jni_core/src/macros.rs new file mode 100644 index 0000000000000..8d64e7ce8b422 --- /dev/null +++ b/src/jni_core/src/macros.rs @@ -0,0 +1,238 @@ +// Copyright 2023 RisingWave Labs +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#[macro_export] +macro_rules! gen_class_name { + ($last:ident) => { + stringify! {$last} + }; + ($first:ident . $($rest:ident).+) => { + concat! {stringify! {$first}, "/", gen_class_name! {$($rest).+} } + } +} + +#[macro_export] +macro_rules! gen_jni_sig_inner { + ($(public)? static native $($rest:tt)*) => { + gen_jni_sig_inner! { $($rest)* } + }; + ($($ret:tt).+ $func_name:ident($($args:tt)*)) => { + concat! {"(", gen_jni_sig_inner!{$($args)*}, ")", gen_jni_sig_inner! {$($ret).+} } + }; + (boolean) => { + "Z" + }; + (byte) => { + "B" + }; + (char) => { + "C" + }; + (short) => { + "S" + }; + (int) => { + "I" + }; + (long) => { + "J" + }; + (float) => { + "F" + }; + (double) => { + "D" + }; + (void) => { + "V" + }; + (String) => { + gen_jni_sig_inner! { java.lang.String } + }; + (Object) => { + gen_jni_sig_inner! { java.lang.Object } + }; + (Class) => { + gen_jni_sig_inner! { java.lang.Class } + }; + ($($class_part:ident).+) => { + concat! {"L", gen_class_name! {$($class_part).+}, ";"} + }; + ($($class_part:ident).+ $(.)? [] $($param_name:ident)? $(,$($rest:tt)*)?) => { + concat! { "[", gen_jni_sig_inner! {$($class_part).+}, gen_jni_sig_inner! {$($($rest)*)?}} + }; + ($($class_part:ident).+ $($param_name:ident)? $(,$($rest:tt)*)?) => { + concat! { gen_jni_sig_inner! {$($class_part).+}, gen_jni_sig_inner! {$($($rest)*)?}} + }; + () => { + "" + }; + ($($invalid:tt)*) => { + compile_error!(concat!("unsupported type {{", stringify!($($invalid)*), "}}")) + }; +} + +#[macro_export] +macro_rules! gen_jni_sig { + ($($input:tt)*) => {{ + // this macro only provide with a expression context + gen_jni_sig_inner! {$($input)*} + }} +} + +#[macro_export] +macro_rules! for_all_plain_native_methods { + ($macro:path $(,$args:tt)*) => { + $macro! { + { + public static native int vnodeCount(); + static native long hummockIteratorNew(byte[] readPlan); + static native long hummockIteratorNext(long pointer); + static native void hummockIteratorClose(long pointer); + static native byte.[] rowGetKey(long pointer); + static native int rowGetOp(long pointer); + static native boolean rowIsNull(long pointer, int index); + static native short rowGetInt16Value(long pointer, int index); + static native int rowGetInt32Value(long pointer, int index); + static native long rowGetInt64Value(long pointer, int index); + static native float rowGetFloatValue(long pointer, int index); + static native double rowGetDoubleValue(long pointer, int index); + static native boolean rowGetBooleanValue(long pointer, int index); + static native String rowGetStringValue(long pointer, int index); + static native java.sql.Timestamp rowGetTimestampValue(long pointer, int index); + static native java.math.BigDecimal rowGetDecimalValue(long pointer, int index); + static native java.sql.Time rowGetTimeValue(long pointer, int index); + static native java.sql.Date rowGetDateValue(long pointer, int index); + static native String rowGetIntervalValue(long pointer, int index); + static native String rowGetJsonbValue(long pointer, int index); + static native byte.[] rowGetByteaValue(long pointer, int index); + static native Object rowGetArrayValue(long pointer, int index, Class clazz); + static native void rowClose(long pointer); + static native long streamChunkIteratorNew(byte[] streamChunkPayload); + static native long streamChunkIteratorNext(long pointer); + static native void streamChunkIteratorClose(long pointer); + static native long streamChunkIteratorFromPretty(String str); + public static native boolean sendCdcSourceMsgToChannel(long channelPtr, byte[] msg); + } + $(,$args)* + } + }; +} + +#[macro_export] +macro_rules! for_all_native_methods { + ( + { + $($(public)? static native $($ret:tt).+ $func_name:ident($($args:tt)*);)* + }, + $macro:path + $(,$extra_args:tt)* + ) => { + $macro! { + { + $( + { $func_name, {gen_jni_sig! {$($ret).+ $func_name($($args)*)}}} + ),* + } + $(,$extra_args)* + } + }; + ($macro:path $(,$args:tt)*) => { + $crate::for_all_plain_native_methods! { + $crate::for_all_native_methods, + $macro + $(,$args)* + } + }; +} + +#[cfg(test)] +mod tests { + #[test] + fn test_gen_jni_sig() { + assert_eq!(gen_jni_sig!(int), "I"); + assert_eq!(gen_jni_sig!(boolean f(int, short, byte[])), "(IS[B)Z"); + assert_eq!( + gen_jni_sig!(boolean f(int, short, byte[], java.lang.String)), + "(IS[BLjava/lang/String;)Z" + ); + assert_eq!( + gen_jni_sig!(boolean f(int, java.lang.String)), + "(ILjava/lang/String;)Z" + ); + assert_eq!(gen_jni_sig!(public static native int vnodeCount()), "()I"); + assert_eq!( + gen_jni_sig!(long hummockIteratorNew(byte[] readPlan)), + "([B)J" + ); + assert_eq!(gen_jni_sig!(long hummockIteratorNext(long pointer)), "(J)J"); + assert_eq!( + gen_jni_sig!(void hummockIteratorClose(long pointer)), + "(J)V" + ); + assert_eq!(gen_jni_sig!(byte.[] rowGetKey(long pointer)), "(J)[B"); + assert_eq!( + gen_jni_sig!(java.sql.Timestamp rowGetTimestampValue(long pointer, int index)), + "(JI)Ljava/sql/Timestamp;" + ); + assert_eq!( + gen_jni_sig!(String rowGetStringValue(long pointer, int index)), + "(JI)Ljava/lang/String;" + ); + assert_eq!( + gen_jni_sig!(static native Object rowGetArrayValue(long pointer, int index, Class clazz)), + "(JILjava/lang/Class;)Ljava/lang/Object;" + ); + } + + #[test] + fn test_for_all_gen() { + macro_rules! gen_array { + (test) => {{ + for_all_native_methods! { + { + public static native int vnodeCount(); + static native long hummockIteratorNew(byte[] readPlan); + public static native byte.[] rowGetKey(long pointer); + }, + gen_array + } + }}; + (all) => {{ + for_all_native_methods! { + gen_array + } + }}; + ({$({ $func_name:ident, $sig:expr }),*}) => {{ + [ + $( + (stringify! {$func_name}, $sig), + )* + ] + }}; +} + let sig: [(_, _); 3] = gen_array!(test); + assert_eq!( + sig, + [ + ("vnodeCount", "()I"), + ("hummockIteratorNew", "([B)J"), + ("rowGetKey", "(J)[B") + ] + ); + + let sig = gen_array!(all); + assert!(!sig.is_empty()); + } +} From 605f3bfc3d4c99a47b0d67eeb8c74f76b46e3c53 Mon Sep 17 00:00:00 2001 From: William Wen Date: Fri, 22 Sep 2023 15:39:08 +0800 Subject: [PATCH 2/8] pass args types and return type in for_all_native_methods --- src/jni_core/src/jvm_runtime.rs | 6 +++--- src/jni_core/src/macros.rs | 8 ++++---- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/src/jni_core/src/jvm_runtime.rs b/src/jni_core/src/jvm_runtime.rs index ec632fc612508..bac9f8901c444 100644 --- a/src/jni_core/src/jvm_runtime.rs +++ b/src/jni_core/src/jvm_runtime.rs @@ -111,19 +111,19 @@ pub fn register_native_method_for_jvm(jvm: &JavaVM) -> Result<(), jni::errors::E () => {{ $crate::for_all_native_methods! {gen_native_method_array} }}; - ({$({ $func_name:ident, $sig:expr }),*}) => { + ({$({ $func_name:ident, {$($ret:tt).+}, {$($args:tt)*} }),*}) => { [ $( { let fn_ptr = paste::paste! {[ ]} as *mut c_void; + let sig = $crate::gen_jni_sig! { $($ret).+ ($($args)*)}; NativeMethod { name: JNIString::from(stringify! {$func_name}), - sig: JNIString::from($sig), + sig: JNIString::from(sig), fn_ptr, } }, )* - ] } } diff --git a/src/jni_core/src/macros.rs b/src/jni_core/src/macros.rs index 8d64e7ce8b422..29bb750c7f326 100644 --- a/src/jni_core/src/macros.rs +++ b/src/jni_core/src/macros.rs @@ -27,7 +27,7 @@ macro_rules! gen_jni_sig_inner { ($(public)? static native $($rest:tt)*) => { gen_jni_sig_inner! { $($rest)* } }; - ($($ret:tt).+ $func_name:ident($($args:tt)*)) => { + ($($ret:tt).+ $($func_name:ident)? ($($args:tt)*)) => { concat! {"(", gen_jni_sig_inner!{$($args)*}, ")", gen_jni_sig_inner! {$($ret).+} } }; (boolean) => { @@ -142,7 +142,7 @@ macro_rules! for_all_native_methods { $macro! { { $( - { $func_name, {gen_jni_sig! {$($ret).+ $func_name($($args)*)}}} + { $func_name, {$($ret).+}, {$($args)*}} ),* } $(,$extra_args)* @@ -214,10 +214,10 @@ mod tests { gen_array } }}; - ({$({ $func_name:ident, $sig:expr }),*}) => {{ + ({$({ $func_name:ident, {$($ret:tt).+}, {$($args:tt)*} }),*}) => {{ [ $( - (stringify! {$func_name}, $sig), + (stringify! {$func_name}, gen_jni_sig! { $($ret).+ ($($args)*)}), )* ] }}; From d207f572a42d052cbba606959ff35854bee88126 Mon Sep 17 00:00:00 2001 From: William Wen Date: Sat, 23 Sep 2023 14:06:45 +0800 Subject: [PATCH 3/8] avoid dot in array return type --- src/jni_core/src/jvm_runtime.rs | 4 +- src/jni_core/src/macros.rs | 152 ++++++++++++++++++++++++++++---- 2 files changed, 138 insertions(+), 18 deletions(-) diff --git a/src/jni_core/src/jvm_runtime.rs b/src/jni_core/src/jvm_runtime.rs index bac9f8901c444..62630528961dc 100644 --- a/src/jni_core/src/jvm_runtime.rs +++ b/src/jni_core/src/jvm_runtime.rs @@ -111,12 +111,12 @@ pub fn register_native_method_for_jvm(jvm: &JavaVM) -> Result<(), jni::errors::E () => {{ $crate::for_all_native_methods! {gen_native_method_array} }}; - ({$({ $func_name:ident, {$($ret:tt).+}, {$($args:tt)*} }),*}) => { + ({$({ $func_name:ident, {$($ret:tt)+}, {$($args:tt)*} }),*}) => { [ $( { let fn_ptr = paste::paste! {[ ]} as *mut c_void; - let sig = $crate::gen_jni_sig! { $($ret).+ ($($args)*)}; + let sig = $crate::gen_jni_sig! { $($ret)+ ($($args)*)}; NativeMethod { name: JNIString::from(stringify! {$func_name}), sig: JNIString::from(sig), diff --git a/src/jni_core/src/macros.rs b/src/jni_core/src/macros.rs index 29bb750c7f326..bdb5c60ec3f82 100644 --- a/src/jni_core/src/macros.rs +++ b/src/jni_core/src/macros.rs @@ -27,9 +27,12 @@ macro_rules! gen_jni_sig_inner { ($(public)? static native $($rest:tt)*) => { gen_jni_sig_inner! { $($rest)* } }; - ($($ret:tt).+ $($func_name:ident)? ($($args:tt)*)) => { + ($($ret:ident).+ $($func_name:ident)? ($($args:tt)*)) => { concat! {"(", gen_jni_sig_inner!{$($args)*}, ")", gen_jni_sig_inner! {$($ret).+} } }; + ($($ret:ident).+ [] $($func_name:ident)? ($($args:tt)*)) => { + concat! {"(", gen_jni_sig_inner!{$($args)*}, ")", gen_jni_sig_inner! {$($ret).+ []} } + }; (boolean) => { "Z" }; @@ -72,6 +75,9 @@ macro_rules! gen_jni_sig_inner { ($($class_part:ident).+ $(.)? [] $($param_name:ident)? $(,$($rest:tt)*)?) => { concat! { "[", gen_jni_sig_inner! {$($class_part).+}, gen_jni_sig_inner! {$($($rest)*)?}} }; + (Class $(< ? >)? $($param_name:ident)? $(,$($rest:tt)*)?) => { + concat! { gen_jni_sig_inner! { Class }, gen_jni_sig_inner! {$($($rest)*)?}} + }; ($($class_part:ident).+ $($param_name:ident)? $(,$($rest:tt)*)?) => { concat! { gen_jni_sig_inner! {$($class_part).+}, gen_jni_sig_inner! {$($($rest)*)?}} }; @@ -97,32 +103,69 @@ macro_rules! for_all_plain_native_methods { $macro! { { public static native int vnodeCount(); + + // hummock iterator method + // Return a pointer to the iterator static native long hummockIteratorNew(byte[] readPlan); + + // return a pointer to the next row static native long hummockIteratorNext(long pointer); + + // Since the underlying rust does not have garbage collection, we will have to manually call + // close on the iterator to release the iterator instance pointed by the pointer. static native void hummockIteratorClose(long pointer); - static native byte.[] rowGetKey(long pointer); + + // row method + static native byte[] rowGetKey(long pointer); + static native int rowGetOp(long pointer); + static native boolean rowIsNull(long pointer, int index); + static native short rowGetInt16Value(long pointer, int index); + static native int rowGetInt32Value(long pointer, int index); + static native long rowGetInt64Value(long pointer, int index); + static native float rowGetFloatValue(long pointer, int index); + static native double rowGetDoubleValue(long pointer, int index); + static native boolean rowGetBooleanValue(long pointer, int index); + static native String rowGetStringValue(long pointer, int index); + static native java.sql.Timestamp rowGetTimestampValue(long pointer, int index); + static native java.math.BigDecimal rowGetDecimalValue(long pointer, int index); + static native java.sql.Time rowGetTimeValue(long pointer, int index); + static native java.sql.Date rowGetDateValue(long pointer, int index); + static native String rowGetIntervalValue(long pointer, int index); + static native String rowGetJsonbValue(long pointer, int index); - static native byte.[] rowGetByteaValue(long pointer, int index); - static native Object rowGetArrayValue(long pointer, int index, Class clazz); + + static native byte[] rowGetByteaValue(long pointer, int index); + + // TODO: object or object array? + static native Object rowGetArrayValue(long pointer, int index, Class clazz); + + // Since the underlying rust does not have garbage collection, we will have to manually call + // close on the row to release the row instance pointed by the pointer. static native void rowClose(long pointer); + + // stream chunk iterator method static native long streamChunkIteratorNew(byte[] streamChunkPayload); + static native long streamChunkIteratorNext(long pointer); + static native void streamChunkIteratorClose(long pointer); + static native long streamChunkIteratorFromPretty(String str); + public static native boolean sendCdcSourceMsgToChannel(long channelPtr, byte[] msg); } $(,$args)* @@ -130,31 +173,108 @@ macro_rules! for_all_plain_native_methods { }; } +#[macro_export] +macro_rules! for_single_native_method { + ( + {$($ret:tt).+ $func_name:ident ($($args:tt)*)}, + $macro:path + $(,$extra_args:tt)* + ) => { + $macro! { + $func_name, + {$($ret).+}, + {$($args)*} + } + }; + ( + {$($ret:tt).+ [] $func_name:ident ($($args:tt)*)}, + $macro:path + $(,$extra_args:tt)* + ) => { + $macro! { + $func_name, + {$($ret).+ []}, + {$($args)*} + } + }; +} + #[macro_export] macro_rules! for_all_native_methods { + ( + {$($input:tt)*}, + $macro:path + $(,$extra_args:tt)* + ) => {{ + $crate::for_all_native_methods! { + {$($input)*}, + {}, + $macro + $(,$extra_args)* + } + }}; ( { - $($(public)? static native $($ret:tt).+ $func_name:ident($($args:tt)*);)* + $(public)? static native $($ret:tt).+ $func_name:ident($($args:tt)*); $($rest:tt)* + }, + { + $({$prev_func_name:ident, {$($prev_ret:tt)*}, {$($prev_args:tt)*}})* + }, + $macro:path + $(,$extra_args:tt)* + ) => { + $crate::for_all_native_methods! { + {$($rest)*}, + { + $({$prev_func_name, {$($prev_ret)*}, {$($prev_args)*}})* + {$func_name, {$($ret).+}, {$($args)*}} + }, + $macro + $(,$extra_args)* + } + }; + ( + { + $(public)? static native $($ret:tt).+ [] $func_name:ident($($args:tt)*); $($rest:tt)* + }, + { + $({$prev_func_name:ident, {$($prev_ret:tt)*}, {$($prev_args:tt)*}})* + }, + $macro:path + $(,$extra_args:tt)* + ) => { + $crate::for_all_native_methods! { + {$($rest)*}, + { + $({$prev_func_name, {$($prev_ret)*}, {$($prev_args)*}})* + {$func_name, {$($ret).+ []}, {$($args)*}} + }, + $macro + $(,$extra_args)* + } + }; + ( + {}, + { + $({$func_name:ident, {$($ret:tt)*}, {$($args:tt)*}})* }, $macro:path $(,$extra_args:tt)* ) => { $macro! { { - $( - { $func_name, {$($ret).+}, {$($args)*}} - ),* + $({$func_name, {$($ret)*}, {$($args)*}}),* } $(,$extra_args)* } }; - ($macro:path $(,$args:tt)*) => { + ($macro:path $(,$args:tt)*) => {{ $crate::for_all_plain_native_methods! { $crate::for_all_native_methods, $macro $(,$args)* } - }; + }}; } #[cfg(test)] @@ -181,7 +301,7 @@ mod tests { gen_jni_sig!(void hummockIteratorClose(long pointer)), "(J)V" ); - assert_eq!(gen_jni_sig!(byte.[] rowGetKey(long pointer)), "(J)[B"); + assert_eq!(gen_jni_sig!(byte[] rowGetKey(long pointer)), "(J)[B"); assert_eq!( gen_jni_sig!(java.sql.Timestamp rowGetTimestampValue(long pointer, int index)), "(JI)Ljava/sql/Timestamp;" @@ -204,7 +324,7 @@ mod tests { { public static native int vnodeCount(); static native long hummockIteratorNew(byte[] readPlan); - public static native byte.[] rowGetKey(long pointer); + public static native byte[] rowGetKey(long pointer); }, gen_array } @@ -214,14 +334,14 @@ mod tests { gen_array } }}; - ({$({ $func_name:ident, {$($ret:tt).+}, {$($args:tt)*} }),*}) => {{ + ({$({ $func_name:ident, {$($ret:tt)+}, {$($args:tt)*} }),*}) => {{ [ $( - (stringify! {$func_name}, gen_jni_sig! { $($ret).+ ($($args)*)}), + (stringify! {$func_name}, gen_jni_sig! { $($ret)+ ($($args)*)}), )* ] - }}; -} + }}; + } let sig: [(_, _); 3] = gen_array!(test); assert_eq!( sig, From ae836982b4cd9b53f0f7ee1c4bdc79840fc84da2 Mon Sep 17 00:00:00 2001 From: William Wen Date: Sat, 23 Sep 2023 15:43:08 +0800 Subject: [PATCH 4/8] feat(java-binding): store java binding row in iter --- .../connector/api/sink/ArraySinkRow.java | 3 - .../connector/api/sink/SinkRow.java | 2 +- .../com/risingwave/connector/FileSink.java | 41 ++- .../deserializer/StreamChunkDeserializer.java | 17 - .../risingwave/connector/DeltaLakeSink.java | 39 ++- .../java/com/risingwave/connector/EsSink.java | 7 +- .../AppendOnlyIcebergSinkWriter.java | 103 +++--- .../connector/UpsertIcebergSinkWriter.java | 95 +++--- .../com/risingwave/connector/JDBCSink.java | 25 +- .../java/binding/StreamchunkBenchmark.java | 11 +- .../java/binding/HummockReadDemo.java | 11 +- .../java/binding/StreamChunkDemo.java | 11 +- .../com/risingwave/java/binding/BaseRow.java | 44 +-- .../com/risingwave/java/binding/Binding.java | 61 ++-- .../java/binding/HummockIterator.java | 8 +- .../com/risingwave/java/binding/KeyedRow.java | 2 +- .../java/binding/StreamChunkIterator.java | 10 +- .../java/binding/StreamChunkRow.java | 2 +- src/java_binding/make-java-binding.toml | 2 +- src/jni_core/src/hummock_iterator.rs | 32 +- src/jni_core/src/lib.rs | 301 +++++++++--------- src/jni_core/src/macros.rs | 61 ++-- src/jni_core/src/stream_chunk_iterator.rs | 48 +-- 23 files changed, 407 insertions(+), 529 deletions(-) diff --git a/java/connector-node/connector-api/src/main/java/com/risingwave/connector/api/sink/ArraySinkRow.java b/java/connector-node/connector-api/src/main/java/com/risingwave/connector/api/sink/ArraySinkRow.java index e443a7d3e286e..9140558c41cd6 100644 --- a/java/connector-node/connector-api/src/main/java/com/risingwave/connector/api/sink/ArraySinkRow.java +++ b/java/connector-node/connector-api/src/main/java/com/risingwave/connector/api/sink/ArraySinkRow.java @@ -39,7 +39,4 @@ public Data.Op getOp() { public int size() { return values.length; } - - @Override - public void close() throws Exception {} } diff --git a/java/connector-node/connector-api/src/main/java/com/risingwave/connector/api/sink/SinkRow.java b/java/connector-node/connector-api/src/main/java/com/risingwave/connector/api/sink/SinkRow.java index 0ae0aa3facf7e..dcddfc07479b6 100644 --- a/java/connector-node/connector-api/src/main/java/com/risingwave/connector/api/sink/SinkRow.java +++ b/java/connector-node/connector-api/src/main/java/com/risingwave/connector/api/sink/SinkRow.java @@ -16,7 +16,7 @@ import com.risingwave.proto.Data; -public interface SinkRow extends AutoCloseable { +public interface SinkRow { Object get(int index); Data.Op getOp(); diff --git a/java/connector-node/risingwave-connector-service/src/main/java/com/risingwave/connector/FileSink.java b/java/connector-node/risingwave-connector-service/src/main/java/com/risingwave/connector/FileSink.java index 5ea2db204a06c..0959b389e55ca 100644 --- a/java/connector-node/risingwave-connector-service/src/main/java/com/risingwave/connector/FileSink.java +++ b/java/connector-node/risingwave-connector-service/src/main/java/com/risingwave/connector/FileSink.java @@ -60,28 +60,25 @@ public FileSink(FileSinkConfig config, TableSchema tableSchema) { @Override public void write(Iterator rows) { while (rows.hasNext()) { - try (SinkRow row = rows.next()) { - switch (row.getOp()) { - case INSERT: - String buf = - new Gson() - .toJson( - IntStream.range(0, row.size()) - .mapToObj(row::get) - .toArray()); - try { - sinkWriter.write(buf + System.lineSeparator()); - } catch (IOException e) { - throw INTERNAL.withCause(e).asRuntimeException(); - } - break; - default: - throw UNIMPLEMENTED - .withDescription("unsupported operation: " + row.getOp()) - .asRuntimeException(); - } - } catch (Exception e) { - throw new RuntimeException(e); + SinkRow row = rows.next(); + switch (row.getOp()) { + case INSERT: + String buf = + new Gson() + .toJson( + IntStream.range(0, row.size()) + .mapToObj(row::get) + .toArray()); + try { + sinkWriter.write(buf + System.lineSeparator()); + } catch (IOException e) { + throw INTERNAL.withCause(e).asRuntimeException(); + } + break; + default: + throw UNIMPLEMENTED + .withDescription("unsupported operation: " + row.getOp()) + .asRuntimeException(); } } } diff --git a/java/connector-node/risingwave-connector-service/src/main/java/com/risingwave/connector/deserializer/StreamChunkDeserializer.java b/java/connector-node/risingwave-connector-service/src/main/java/com/risingwave/connector/deserializer/StreamChunkDeserializer.java index a8175bb2f738d..ab9a9068fabb9 100644 --- a/java/connector-node/risingwave-connector-service/src/main/java/com/risingwave/connector/deserializer/StreamChunkDeserializer.java +++ b/java/connector-node/risingwave-connector-service/src/main/java/com/risingwave/connector/deserializer/StreamChunkDeserializer.java @@ -251,14 +251,12 @@ public CloseableIterator deserialize( static class StreamChunkRowWrapper implements SinkRow { - private boolean isClosed; private final StreamChunkRow inner; private final ValueGetter[] valueGetters; StreamChunkRowWrapper(StreamChunkRow inner, ValueGetter[] valueGetters) { this.inner = inner; this.valueGetters = valueGetters; - this.isClosed = false; } @Override @@ -275,14 +273,6 @@ public Data.Op getOp() { public int size() { return valueGetters.length; } - - @Override - public void close() { - if (!isClosed) { - this.isClosed = true; - inner.close(); - } - } } static class StreamChunkIteratorWrapper implements CloseableIterator { @@ -299,13 +289,6 @@ public StreamChunkIteratorWrapper(StreamChunkIterator iter, ValueGetter[] valueG @Override public void close() { iter.close(); - try { - if (row != null) { - row.close(); - } - } catch (Exception e) { - throw new RuntimeException(e); - } } @Override diff --git a/java/connector-node/risingwave-sink-deltalake/src/main/java/com/risingwave/connector/DeltaLakeSink.java b/java/connector-node/risingwave-sink-deltalake/src/main/java/com/risingwave/connector/DeltaLakeSink.java index 1b3a7c28d97a9..413edeb10df81 100644 --- a/java/connector-node/risingwave-sink-deltalake/src/main/java/com/risingwave/connector/DeltaLakeSink.java +++ b/java/connector-node/risingwave-sink-deltalake/src/main/java/com/risingwave/connector/DeltaLakeSink.java @@ -75,27 +75,24 @@ public void write(Iterator rows) { } } while (rows.hasNext()) { - try (SinkRow row = rows.next()) { - switch (row.getOp()) { - case INSERT: - GenericRecord record = new GenericData.Record(this.sinkSchema); - for (int i = 0; i < this.sinkSchema.getFields().size(); i++) { - record.put(i, row.get(i)); - } - try { - this.parquetWriter.write(record); - this.numOutputRows += 1; - } catch (IOException ioException) { - throw INTERNAL.withCause(ioException).asRuntimeException(); - } - break; - default: - throw UNIMPLEMENTED - .withDescription("unsupported operation: " + row.getOp()) - .asRuntimeException(); - } - } catch (Exception e) { - throw new RuntimeException(e); + SinkRow row = rows.next(); + switch (row.getOp()) { + case INSERT: + GenericRecord record = new GenericData.Record(this.sinkSchema); + for (int i = 0; i < this.sinkSchema.getFields().size(); i++) { + record.put(i, row.get(i)); + } + try { + this.parquetWriter.write(record); + this.numOutputRows += 1; + } catch (IOException ioException) { + throw INTERNAL.withCause(ioException).asRuntimeException(); + } + break; + default: + throw UNIMPLEMENTED + .withDescription("unsupported operation: " + row.getOp()) + .asRuntimeException(); } } } diff --git a/java/connector-node/risingwave-sink-es-7/src/main/java/com/risingwave/connector/EsSink.java b/java/connector-node/risingwave-sink-es-7/src/main/java/com/risingwave/connector/EsSink.java index 7c1727f4a82f3..e2903970128f8 100644 --- a/java/connector-node/risingwave-sink-es-7/src/main/java/com/risingwave/connector/EsSink.java +++ b/java/connector-node/risingwave-sink-es-7/src/main/java/com/risingwave/connector/EsSink.java @@ -254,11 +254,8 @@ private void writeRow(SinkRow row) { @Override public void write(Iterator rows) { while (rows.hasNext()) { - try (SinkRow row = rows.next()) { - writeRow(row); - } catch (Exception e) { - throw new RuntimeException(e); - } + SinkRow row = rows.next(); + writeRow(row); } } diff --git a/java/connector-node/risingwave-sink-iceberg/src/main/java/com/risingwave/connector/AppendOnlyIcebergSinkWriter.java b/java/connector-node/risingwave-sink-iceberg/src/main/java/com/risingwave/connector/AppendOnlyIcebergSinkWriter.java index 6a6aad0a460e0..6b60eedd23d37 100644 --- a/java/connector-node/risingwave-sink-iceberg/src/main/java/com/risingwave/connector/AppendOnlyIcebergSinkWriter.java +++ b/java/connector-node/risingwave-sink-iceberg/src/main/java/com/risingwave/connector/AppendOnlyIcebergSinkWriter.java @@ -55,62 +55,57 @@ public AppendOnlyIcebergSinkWriter( @Override public void write(Iterator rows) { while (rows.hasNext()) { - try (SinkRow row = rows.next()) { - switch (row.getOp()) { - case INSERT: - Record record = GenericRecord.create(rowSchema); - if (row.size() != tableSchema.getColumnNames().length) { - throw INTERNAL.withDescription("row values do not match table schema") + SinkRow row = rows.next(); + switch (row.getOp()) { + case INSERT: + Record record = GenericRecord.create(rowSchema); + if (row.size() != tableSchema.getColumnNames().length) { + throw INTERNAL.withDescription("row values do not match table schema") + .asRuntimeException(); + } + for (int i = 0; i < rowSchema.columns().size(); i++) { + record.set(i, row.get(i)); + } + PartitionKey partitionKey = + new PartitionKey(icebergTable.spec(), icebergTable.schema()); + partitionKey.partition(record); + DataWriter dataWriter; + if (dataWriterMap.containsKey(partitionKey)) { + dataWriter = dataWriterMap.get(partitionKey); + } else { + try { + String filename = fileFormat.addExtension(UUID.randomUUID().toString()); + OutputFile outputFile = + icebergTable + .io() + .newOutputFile( + icebergTable.location() + + "/data/" + + icebergTable + .spec() + .partitionToPath(partitionKey) + + "/" + + filename); + dataWriter = + Parquet.writeData(outputFile) + .schema(rowSchema) + .withSpec(icebergTable.spec()) + .withPartition(partitionKey) + .createWriterFunc(GenericParquetWriter::buildWriter) + .overwrite() + .build(); + } catch (Exception e) { + throw INTERNAL.withDescription("failed to create dataWriter") .asRuntimeException(); } - for (int i = 0; i < rowSchema.columns().size(); i++) { - record.set(i, row.get(i)); - } - PartitionKey partitionKey = - new PartitionKey(icebergTable.spec(), icebergTable.schema()); - partitionKey.partition(record); - DataWriter dataWriter; - if (dataWriterMap.containsKey(partitionKey)) { - dataWriter = dataWriterMap.get(partitionKey); - } else { - try { - String filename = - fileFormat.addExtension(UUID.randomUUID().toString()); - OutputFile outputFile = - icebergTable - .io() - .newOutputFile( - icebergTable.location() - + "/data/" - + icebergTable - .spec() - .partitionToPath( - partitionKey) - + "/" - + filename); - dataWriter = - Parquet.writeData(outputFile) - .schema(rowSchema) - .withSpec(icebergTable.spec()) - .withPartition(partitionKey) - .createWriterFunc(GenericParquetWriter::buildWriter) - .overwrite() - .build(); - } catch (Exception e) { - throw INTERNAL.withDescription("failed to create dataWriter") - .asRuntimeException(); - } - dataWriterMap.put(partitionKey, dataWriter); - } - dataWriter.write(record); - break; - default: - throw UNIMPLEMENTED - .withDescription("unsupported operation: " + row.getOp()) - .asRuntimeException(); - } - } catch (Exception e) { - throw new RuntimeException(e); + dataWriterMap.put(partitionKey, dataWriter); + } + dataWriter.write(record); + break; + default: + throw UNIMPLEMENTED + .withDescription("unsupported operation: " + row.getOp()) + .asRuntimeException(); } } } diff --git a/java/connector-node/risingwave-sink-iceberg/src/main/java/com/risingwave/connector/UpsertIcebergSinkWriter.java b/java/connector-node/risingwave-sink-iceberg/src/main/java/com/risingwave/connector/UpsertIcebergSinkWriter.java index 10fca804acf64..e1d649f028bf8 100644 --- a/java/connector-node/risingwave-sink-iceberg/src/main/java/com/risingwave/connector/UpsertIcebergSinkWriter.java +++ b/java/connector-node/risingwave-sink-iceberg/src/main/java/com/risingwave/connector/UpsertIcebergSinkWriter.java @@ -142,57 +142,52 @@ private List> getKeyFromRow(SinkRow row) { @Override public void write(Iterator rows) { while (rows.hasNext()) { - try (SinkRow row = rows.next()) { - if (row.size() != tableSchema.getColumnNames().length) { - throw Status.FAILED_PRECONDITION - .withDescription("row values do not match table schema") - .asRuntimeException(); - } - Record record = newRecord(rowSchema, row); - PartitionKey partitionKey = - new PartitionKey(icebergTable.spec(), icebergTable.schema()); - partitionKey.partition(record); - SinkRowMap sinkRowMap; - if (sinkRowMapByPartition.containsKey(partitionKey)) { - sinkRowMap = sinkRowMapByPartition.get(partitionKey); - } else { - sinkRowMap = new SinkRowMap(); - sinkRowMapByPartition.put(partitionKey, sinkRowMap); - } - switch (row.getOp()) { - case INSERT: - sinkRowMap.insert(getKeyFromRow(row), newRecord(rowSchema, row)); - break; - case DELETE: - sinkRowMap.delete(getKeyFromRow(row), newRecord(deleteRowSchema, row)); - break; - case UPDATE_DELETE: - if (updateBufferExists) { - throw Status.FAILED_PRECONDITION - .withDescription( - "an UPDATE_INSERT should precede an UPDATE_DELETE") - .asRuntimeException(); - } - sinkRowMap.delete(getKeyFromRow(row), newRecord(deleteRowSchema, row)); - updateBufferExists = true; - break; - case UPDATE_INSERT: - if (!updateBufferExists) { - throw Status.FAILED_PRECONDITION - .withDescription( - "an UPDATE_INSERT should precede an UPDATE_DELETE") - .asRuntimeException(); - } - sinkRowMap.insert(getKeyFromRow(row), newRecord(rowSchema, row)); - updateBufferExists = false; - break; - default: - throw UNIMPLEMENTED - .withDescription("unsupported operation: " + row.getOp()) + SinkRow row = rows.next(); + if (row.size() != tableSchema.getColumnNames().length) { + throw Status.FAILED_PRECONDITION + .withDescription("row values do not match table schema") + .asRuntimeException(); + } + Record record = newRecord(rowSchema, row); + PartitionKey partitionKey = + new PartitionKey(icebergTable.spec(), icebergTable.schema()); + partitionKey.partition(record); + SinkRowMap sinkRowMap; + if (sinkRowMapByPartition.containsKey(partitionKey)) { + sinkRowMap = sinkRowMapByPartition.get(partitionKey); + } else { + sinkRowMap = new SinkRowMap(); + sinkRowMapByPartition.put(partitionKey, sinkRowMap); + } + switch (row.getOp()) { + case INSERT: + sinkRowMap.insert(getKeyFromRow(row), newRecord(rowSchema, row)); + break; + case DELETE: + sinkRowMap.delete(getKeyFromRow(row), newRecord(deleteRowSchema, row)); + break; + case UPDATE_DELETE: + if (updateBufferExists) { + throw Status.FAILED_PRECONDITION + .withDescription("an UPDATE_INSERT should precede an UPDATE_DELETE") .asRuntimeException(); - } - } catch (Exception e) { - throw new RuntimeException(e); + } + sinkRowMap.delete(getKeyFromRow(row), newRecord(deleteRowSchema, row)); + updateBufferExists = true; + break; + case UPDATE_INSERT: + if (!updateBufferExists) { + throw Status.FAILED_PRECONDITION + .withDescription("an UPDATE_INSERT should precede an UPDATE_DELETE") + .asRuntimeException(); + } + sinkRowMap.insert(getKeyFromRow(row), newRecord(rowSchema, row)); + updateBufferExists = false; + break; + default: + throw UNIMPLEMENTED + .withDescription("unsupported operation: " + row.getOp()) + .asRuntimeException(); } } } diff --git a/java/connector-node/risingwave-sink-jdbc/src/main/java/com/risingwave/connector/JDBCSink.java b/java/connector-node/risingwave-sink-jdbc/src/main/java/com/risingwave/connector/JDBCSink.java index fe23c7db5d846..db09572c30db8 100644 --- a/java/connector-node/risingwave-sink-jdbc/src/main/java/com/risingwave/connector/JDBCSink.java +++ b/java/connector-node/risingwave-sink-jdbc/src/main/java/com/risingwave/connector/JDBCSink.java @@ -219,22 +219,19 @@ public void write(Iterator rows) { PreparedStatement insertStatement = null; while (rows.hasNext()) { - try (SinkRow row = rows.next()) { - if (row.getOp() == Data.Op.UPDATE_DELETE) { - updateFlag = true; - continue; - } - if (config.isUpsertSink()) { - if (row.getOp() == Data.Op.DELETE) { - deleteStatement = prepareDeleteStatement(row); - } else { - upsertStatement = prepareUpsertStatement(row); - } + SinkRow row = rows.next(); + if (row.getOp() == Data.Op.UPDATE_DELETE) { + updateFlag = true; + continue; + } + if (config.isUpsertSink()) { + if (row.getOp() == Data.Op.DELETE) { + deleteStatement = prepareDeleteStatement(row); } else { - insertStatement = prepareInsertStatement(row); + upsertStatement = prepareUpsertStatement(row); } - } catch (Exception e) { - throw new RuntimeException(e); + } else { + insertStatement = prepareInsertStatement(row); } } diff --git a/java/java-binding-benchmark/src/main/java/com/risingwave/java/binding/StreamchunkBenchmark.java b/java/java-binding-benchmark/src/main/java/com/risingwave/java/binding/StreamchunkBenchmark.java index 8741044f7b34e..88d70ac59ce0e 100644 --- a/java/java-binding-benchmark/src/main/java/com/risingwave/java/binding/StreamchunkBenchmark.java +++ b/java/java-binding-benchmark/src/main/java/com/risingwave/java/binding/StreamchunkBenchmark.java @@ -57,13 +57,12 @@ public void getValue(StreamChunkRow row) { public void streamchunkTest() { int count = 0; while (true) { - try (StreamChunkRow row = iter.next()) { - if (row == null) { - break; - } - count += 1; - getValue(row); + StreamChunkRow row = iter.next(); + if (row == null) { + break; } + count += 1; + getValue(row); } if (count != loopTime) { throw new RuntimeException( diff --git a/java/java-binding-integration-test/src/main/java/com/risingwave/java/binding/HummockReadDemo.java b/java/java-binding-integration-test/src/main/java/com/risingwave/java/binding/HummockReadDemo.java index 9f4038cf3f9a3..f1996bb96f43d 100644 --- a/java/java-binding-integration-test/src/main/java/com/risingwave/java/binding/HummockReadDemo.java +++ b/java/java-binding-integration-test/src/main/java/com/risingwave/java/binding/HummockReadDemo.java @@ -72,13 +72,12 @@ public static void main(String[] args) { try (HummockIterator iter = new HummockIterator(readPlan)) { int count = 0; while (true) { - try (KeyedRow row = iter.next()) { - if (row == null) { - break; - } - count += 1; - validateRow(row); + KeyedRow row = iter.next(); + if (row == null) { + break; } + count += 1; + validateRow(row); } int expectedCount = 30000; if (count != expectedCount) { diff --git a/java/java-binding-integration-test/src/main/java/com/risingwave/java/binding/StreamChunkDemo.java b/java/java-binding-integration-test/src/main/java/com/risingwave/java/binding/StreamChunkDemo.java index 0cc6977de2f0c..ad59a74e4c20c 100644 --- a/java/java-binding-integration-test/src/main/java/com/risingwave/java/binding/StreamChunkDemo.java +++ b/java/java-binding-integration-test/src/main/java/com/risingwave/java/binding/StreamChunkDemo.java @@ -25,13 +25,12 @@ public static void main(String[] args) throws IOException { try (StreamChunkIterator iter = new StreamChunkIterator(payload)) { int count = 0; while (true) { - try (StreamChunkRow row = iter.next()) { - if (row == null) { - break; - } - count += 1; - validateRow(row); + StreamChunkRow row = iter.next(); + if (row == null) { + break; } + count += 1; + validateRow(row); } int expectedCount = 30000; if (count != expectedCount) { diff --git a/java/java-binding/src/main/java/com/risingwave/java/binding/BaseRow.java b/java/java-binding/src/main/java/com/risingwave/java/binding/BaseRow.java index a12978d92e995..d9fb28115b68c 100644 --- a/java/java-binding/src/main/java/com/risingwave/java/binding/BaseRow.java +++ b/java/java-binding/src/main/java/com/risingwave/java/binding/BaseRow.java @@ -14,75 +14,73 @@ package com.risingwave.java.binding; -public class BaseRow implements AutoCloseable { +public class BaseRow { protected final long pointer; - private boolean isClosed; protected BaseRow(long pointer) { this.pointer = pointer; - this.isClosed = false; } public boolean isNull(int index) { - return Binding.rowIsNull(pointer, index); + return Binding.iteratorIsNull(pointer, index); } public short getShort(int index) { - return Binding.rowGetInt16Value(pointer, index); + return Binding.iteratorGetInt16Value(pointer, index); } public int getInt(int index) { - return Binding.rowGetInt32Value(pointer, index); + return Binding.iteratorGetInt32Value(pointer, index); } public long getLong(int index) { - return Binding.rowGetInt64Value(pointer, index); + return Binding.iteratorGetInt64Value(pointer, index); } public float getFloat(int index) { - return Binding.rowGetFloatValue(pointer, index); + return Binding.iteratorGetFloatValue(pointer, index); } public double getDouble(int index) { - return Binding.rowGetDoubleValue(pointer, index); + return Binding.iteratorGetDoubleValue(pointer, index); } public boolean getBoolean(int index) { - return Binding.rowGetBooleanValue(pointer, index); + return Binding.iteratorGetBooleanValue(pointer, index); } public String getString(int index) { - return Binding.rowGetStringValue(pointer, index); + return Binding.iteratorGetStringValue(pointer, index); } public java.sql.Timestamp getTimestamp(int index) { - return Binding.rowGetTimestampValue(pointer, index); + return Binding.iteratorGetTimestampValue(pointer, index); } public java.sql.Time getTime(int index) { - return Binding.rowGetTimeValue(pointer, index); + return Binding.iteratorGetTimeValue(pointer, index); } public java.math.BigDecimal getDecimal(int index) { - return Binding.rowGetDecimalValue(pointer, index); + return Binding.iteratorGetDecimalValue(pointer, index); } public java.sql.Date getDate(int index) { - return Binding.rowGetDateValue(pointer, index); + return Binding.iteratorGetDateValue(pointer, index); } // string representation of interval: "2 mons 3 days 00:00:00.000004" or "P1Y2M3DT4H5M6.789123S" public String getInterval(int index) { - return Binding.rowGetIntervalValue(pointer, index); + return Binding.iteratorGetIntervalValue(pointer, index); } // string representation of jsonb: '{"key": "value"}' public String getJsonb(int index) { - return Binding.rowGetJsonbValue(pointer, index); + return Binding.iteratorGetJsonbValue(pointer, index); } public byte[] getBytea(int index) { - return Binding.rowGetByteaValue(pointer, index); + return Binding.iteratorGetByteaValue(pointer, index); } /** @@ -92,16 +90,8 @@ public byte[] getBytea(int index) { * Object[] elements) */ public Object[] getArray(int index, Class clazz) { - var val = Binding.rowGetArrayValue(pointer, index, clazz); + var val = Binding.iteratorGetArrayValue(pointer, index, clazz); assert (val instanceof Object[]); return (Object[]) val; } - - @Override - public void close() { - if (!isClosed) { - isClosed = true; - Binding.rowClose(pointer); - } - } } diff --git a/java/java-binding/src/main/java/com/risingwave/java/binding/Binding.java b/java/java-binding/src/main/java/com/risingwave/java/binding/Binding.java index 4a79033b147a8..282efc1b99f08 100644 --- a/java/java-binding/src/main/java/com/risingwave/java/binding/Binding.java +++ b/java/java-binding/src/main/java/com/risingwave/java/binding/Binding.java @@ -30,65 +30,52 @@ public class Binding { // hummock iterator method // Return a pointer to the iterator - static native long hummockIteratorNew(byte[] readPlan); + static native long iteratorNewHummock(byte[] readPlan); - // return a pointer to the next row - static native long hummockIteratorNext(long pointer); + static native boolean iteratorNext(long pointer); - // Since the underlying rust does not have garbage collection, we will have to manually call - // close on the iterator to release the iterator instance pointed by the pointer. - static native void hummockIteratorClose(long pointer); + static native void iteratorClose(long pointer); - // row method - static native byte[] rowGetKey(long pointer); + static native long iteratorNewFromStreamChunkPayload(byte[] streamChunkPayload); - static native int rowGetOp(long pointer); + static native long iteratorNewFromStreamChunkPretty(String str); - static native boolean rowIsNull(long pointer, int index); + static native byte[] iteratorGetKey(long pointer); - static native short rowGetInt16Value(long pointer, int index); + static native int iteratorGetOp(long pointer); - static native int rowGetInt32Value(long pointer, int index); + static native boolean iteratorIsNull(long pointer, int index); - static native long rowGetInt64Value(long pointer, int index); + static native short iteratorGetInt16Value(long pointer, int index); - static native float rowGetFloatValue(long pointer, int index); + static native int iteratorGetInt32Value(long pointer, int index); - static native double rowGetDoubleValue(long pointer, int index); + static native long iteratorGetInt64Value(long pointer, int index); - static native boolean rowGetBooleanValue(long pointer, int index); + static native float iteratorGetFloatValue(long pointer, int index); - static native String rowGetStringValue(long pointer, int index); + static native double iteratorGetDoubleValue(long pointer, int index); - static native java.sql.Timestamp rowGetTimestampValue(long pointer, int index); + static native boolean iteratorGetBooleanValue(long pointer, int index); - static native java.math.BigDecimal rowGetDecimalValue(long pointer, int index); + static native String iteratorGetStringValue(long pointer, int index); - static native java.sql.Time rowGetTimeValue(long pointer, int index); + static native java.sql.Timestamp iteratorGetTimestampValue(long pointer, int index); - static native java.sql.Date rowGetDateValue(long pointer, int index); + static native java.math.BigDecimal iteratorGetDecimalValue(long pointer, int index); - static native String rowGetIntervalValue(long pointer, int index); + static native java.sql.Time iteratorGetTimeValue(long pointer, int index); - static native String rowGetJsonbValue(long pointer, int index); + static native java.sql.Date iteratorGetDateValue(long pointer, int index); - static native byte[] rowGetByteaValue(long pointer, int index); + static native String iteratorGetIntervalValue(long pointer, int index); - // TODO: object or object array? - static native Object rowGetArrayValue(long pointer, int index, Class clazz); - - // Since the underlying rust does not have garbage collection, we will have to manually call - // close on the row to release the row instance pointed by the pointer. - static native void rowClose(long pointer); - - // stream chunk iterator method - static native long streamChunkIteratorNew(byte[] streamChunkPayload); + static native String iteratorGetJsonbValue(long pointer, int index); - static native long streamChunkIteratorNext(long pointer); + static native byte[] iteratorGetByteaValue(long pointer, int index); - static native void streamChunkIteratorClose(long pointer); - - static native long streamChunkIteratorFromPretty(String str); + // TODO: object or object array? + static native Object iteratorGetArrayValue(long pointer, int index, Class clazz); public static native boolean sendCdcSourceMsgToChannel(long channelPtr, byte[] msg); } diff --git a/java/java-binding/src/main/java/com/risingwave/java/binding/HummockIterator.java b/java/java-binding/src/main/java/com/risingwave/java/binding/HummockIterator.java index ced034fd649d9..cf88068ddf615 100644 --- a/java/java-binding/src/main/java/com/risingwave/java/binding/HummockIterator.java +++ b/java/java-binding/src/main/java/com/risingwave/java/binding/HummockIterator.java @@ -21,13 +21,13 @@ public class HummockIterator implements AutoCloseable { private boolean isClosed; public HummockIterator(ReadPlan readPlan) { - this.pointer = Binding.hummockIteratorNew(readPlan.toByteArray()); + this.pointer = Binding.iteratorNewHummock(readPlan.toByteArray()); this.isClosed = false; } public KeyedRow next() { - long pointer = Binding.hummockIteratorNext(this.pointer); - if (pointer == 0) { + boolean hasNext = Binding.iteratorNext(this.pointer); + if (!hasNext) { return null; } return new KeyedRow(pointer); @@ -37,7 +37,7 @@ public KeyedRow next() { public void close() { if (!isClosed) { isClosed = true; - Binding.hummockIteratorClose(pointer); + Binding.iteratorClose(pointer); } } } diff --git a/java/java-binding/src/main/java/com/risingwave/java/binding/KeyedRow.java b/java/java-binding/src/main/java/com/risingwave/java/binding/KeyedRow.java index 6bbfdaafebabc..8f1e0b0117ac4 100644 --- a/java/java-binding/src/main/java/com/risingwave/java/binding/KeyedRow.java +++ b/java/java-binding/src/main/java/com/risingwave/java/binding/KeyedRow.java @@ -20,6 +20,6 @@ public KeyedRow(long pointer) { } public byte[] getKey() { - return Binding.rowGetKey(pointer); + return Binding.iteratorGetKey(pointer); } } diff --git a/java/java-binding/src/main/java/com/risingwave/java/binding/StreamChunkIterator.java b/java/java-binding/src/main/java/com/risingwave/java/binding/StreamChunkIterator.java index 89693befff700..5b300872bed51 100644 --- a/java/java-binding/src/main/java/com/risingwave/java/binding/StreamChunkIterator.java +++ b/java/java-binding/src/main/java/com/risingwave/java/binding/StreamChunkIterator.java @@ -19,7 +19,7 @@ public class StreamChunkIterator implements AutoCloseable { private boolean isClosed; public StreamChunkIterator(byte[] streamChunkPayload) { - this.pointer = Binding.streamChunkIteratorNew(streamChunkPayload); + this.pointer = Binding.iteratorNewFromStreamChunkPayload(streamChunkPayload); this.isClosed = false; } @@ -30,13 +30,13 @@ public StreamChunkIterator(byte[] streamChunkPayload) { * 40" */ public StreamChunkIterator(String str) { - this.pointer = Binding.streamChunkIteratorFromPretty(str); + this.pointer = Binding.iteratorNewFromStreamChunkPretty(str); this.isClosed = false; } public StreamChunkRow next() { - long pointer = Binding.streamChunkIteratorNext(this.pointer); - if (pointer == 0) { + boolean hasNext = Binding.iteratorNext(this.pointer); + if (!hasNext) { return null; } return new StreamChunkRow(pointer); @@ -46,7 +46,7 @@ public StreamChunkRow next() { public void close() { if (!isClosed) { isClosed = true; - Binding.streamChunkIteratorClose(pointer); + Binding.iteratorClose(pointer); } } } diff --git a/java/java-binding/src/main/java/com/risingwave/java/binding/StreamChunkRow.java b/java/java-binding/src/main/java/com/risingwave/java/binding/StreamChunkRow.java index 401d3d98f766d..2825d62a0b0ca 100644 --- a/java/java-binding/src/main/java/com/risingwave/java/binding/StreamChunkRow.java +++ b/java/java-binding/src/main/java/com/risingwave/java/binding/StreamChunkRow.java @@ -22,6 +22,6 @@ public StreamChunkRow(long pointer) { } public Data.Op getOp() { - return Data.Op.forNumber(Binding.rowGetOp(pointer)); + return Data.Op.forNumber(Binding.iteratorGetOp(pointer)); } } diff --git a/src/java_binding/make-java-binding.toml b/src/java_binding/make-java-binding.toml index 957cec9c762f5..9390785754a4e 100644 --- a/src/java_binding/make-java-binding.toml +++ b/src/java_binding/make-java-binding.toml @@ -113,6 +113,6 @@ mvn install --pl java-binding-benchmark --am -DskipTests=true -Dmaven.javadoc.sk mvn dependency:copy-dependencies --pl java-binding-benchmark -java -cp "java-binding-benchmark/target/dependency/*:java-binding-benchmark/target/java-binding-benchmark-1.0-SNAPSHOT.jar" \ +java -cp "java-binding-benchmark/target/dependency/*:java-binding-benchmark/target/java-binding-benchmark-0.1.0-SNAPSHOT.jar" \ com.risingwave.java.binding.BenchmarkRunner ''' diff --git a/src/jni_core/src/hummock_iterator.rs b/src/jni_core/src/hummock_iterator.rs index 7395a0f82273f..5917d08582998 100644 --- a/src/jni_core/src/hummock_iterator.rs +++ b/src/jni_core/src/hummock_iterator.rs @@ -48,22 +48,6 @@ fn select_all_vnode_stream( pub struct HummockJavaBindingIterator { row_serde: EitherSerde, stream: SelectAllIterStream, - pub class_cache: Arc, -} - -pub struct KeyedRow { - key: Bytes, - row: OwnedRow, -} - -impl KeyedRow { - pub fn key(&self) -> &[u8] { - self.key.as_ref() - } - - pub fn row(&self) -> &OwnedRow { - &self.row - } } impl HummockJavaBindingIterator { @@ -136,24 +120,20 @@ impl HummockJavaBindingIterator { .into() }; - Ok(Self { - row_serde, - stream, - class_cache: Default::default(), - }) + Ok(Self { row_serde, stream }) } - pub async fn next(&mut self) -> StorageResult> { + pub async fn next(&mut self) -> StorageResult> { let item = self.stream.try_next().await?; Ok(match item { - Some((key, value)) => Some(KeyedRow { - key: key.user_key.table_key.0, - row: OwnedRow::new( + Some((key, value)) => Some(( + key.user_key.table_key.0, + OwnedRow::new( self.row_serde .deserialize(&value) .map_err(StorageError::DeserializeRow)?, ), - }), + )), None => None, }) } diff --git a/src/jni_core/src/lib.rs b/src/jni_core/src/lib.rs index 86ac65b8b4f0c..2113553a8e06a 100644 --- a/src/jni_core/src/lib.rs +++ b/src/jni_core/src/lib.rs @@ -27,9 +27,9 @@ use std::backtrace::Backtrace; use std::marker::PhantomData; use std::ops::{Deref, DerefMut}; use std::slice::from_raw_parts; -use std::sync::{Arc, LazyLock, OnceLock}; +use std::sync::{LazyLock, OnceLock}; -use hummock_iterator::{HummockJavaBindingIterator, KeyedRow}; +use bytes::Bytes; use jni::objects::{ AutoElements, GlobalRef, JByteArray, JClass, JMethodID, JObject, JStaticMethodID, JString, JValue, JValueGen, JValueOwned, ReleaseMode, @@ -47,13 +47,15 @@ use risingwave_common::test_prelude::StreamChunkTestExt; use risingwave_common::types::ScalarRefImpl; use risingwave_common::util::panic::rw_catch_unwind; use risingwave_pb::connector_service::GetEventStreamResponse; +use risingwave_pb::data::Op; use risingwave_storage::error::StorageError; use thiserror::Error; use tokio::runtime::Runtime; use tokio::sync::mpsc::Sender; +use crate::hummock_iterator::HummockJavaBindingIterator; pub use crate::jvm_runtime::register_native_method_for_jvm; -use crate::stream_chunk_iterator::{StreamChunkIterator, StreamChunkRow}; +use crate::stream_chunk_iterator::{into_iter, StreamChunkRowIterator}; pub type GetEventStreamJniSender = Sender; static RUNTIME: LazyLock = LazyLock::new(|| tokio::runtime::Runtime::new().unwrap()); @@ -144,15 +146,6 @@ impl From for Pointer<'static, T> { } } -impl Pointer<'static, T> { - fn null() -> Self { - Pointer { - pointer: 0, - _phantom: PhantomData, - } - } -} - impl<'a, T> Pointer<'a, T> { fn as_ref(&self) -> &'a T { debug_assert!(self.pointer != 0); @@ -229,12 +222,8 @@ where } } -pub enum JavaBindingRowInner { - Keyed(KeyedRow), - StreamChunk(StreamChunkRow), -} #[derive(Default)] -pub struct JavaClassMethodCache { +struct JavaClassMethodCache { big_decimal_ctor: OnceLock<(GlobalRef, JMethodID)>, timestamp_ctor: OnceLock<(GlobalRef, JMethodID)>, @@ -242,67 +231,65 @@ pub struct JavaClassMethodCache { time_ctor: OnceLock<(GlobalRef, JStaticMethodID)>, } -pub struct JavaBindingRow { - inner: JavaBindingRowInner, - class_cache: Arc, +enum JavaBindingIteratorInner { + Hummock(HummockJavaBindingIterator), + StreamChunk(StreamChunkRowIterator), } -impl JavaBindingRow { - fn with_stream_chunk( - underlying: StreamChunkRow, - class_cache: Arc, - ) -> Self { - Self { - inner: JavaBindingRowInner::StreamChunk(underlying), - class_cache, - } - } +enum RowExtra { + Op(Op), + Key(Bytes), +} - fn with_keyed(underlying: KeyedRow, class_cache: Arc) -> Self { - Self { - inner: JavaBindingRowInner::Keyed(underlying), - class_cache, +impl RowExtra { + fn as_op(&self) -> Op { + match self { + RowExtra::Op(op) => *op, + RowExtra::Key(_) => unreachable!("should be op"), } } - fn as_keyed(&self) -> &KeyedRow { - match &self.inner { - JavaBindingRowInner::Keyed(r) => r, - _ => unreachable!("can only call as_keyed for KeyedRow"), + fn as_key(&self) -> &Bytes { + match self { + RowExtra::Key(key) => key, + RowExtra::Op(_) => unreachable!("should be key"), } } +} - fn as_stream_chunk(&self) -> &StreamChunkRow { - match &self.inner { - JavaBindingRowInner::StreamChunk(r) => r, - _ => unreachable!("can only call as_stream_chunk for StreamChunkRow"), - } - } +struct RowCursor { + row: OwnedRow, + extra: RowExtra, +} + +struct JavaBindingIterator { + inner: JavaBindingIteratorInner, + cursor: Option, + class_cache: JavaClassMethodCache, } -impl Deref for JavaBindingRow { +impl Deref for JavaBindingIterator { type Target = OwnedRow; fn deref(&self) -> &Self::Target { - match &self.inner { - JavaBindingRowInner::Keyed(r) => r.row(), - JavaBindingRowInner::StreamChunk(r) => r.row(), - } + &self + .cursor + .as_ref() + .expect("should exist when call row methods") + .row } } #[no_mangle] -pub extern "system" fn Java_com_risingwave_java_binding_Binding_vnodeCount( - _env: EnvParam<'_>, -) -> jint { +extern "system" fn Java_com_risingwave_java_binding_Binding_vnodeCount(_env: EnvParam<'_>) -> jint { VirtualNode::COUNT as jint } #[no_mangle] -pub extern "system" fn Java_com_risingwave_java_binding_Binding_hummockIteratorNew<'a>( +extern "system" fn Java_com_risingwave_java_binding_Binding_iteratorNewHummock<'a>( env: EnvParam<'a>, read_plan: JByteArray<'a>, -) -> Pointer<'static, HummockJavaBindingIterator> { +) -> Pointer<'static, JavaBindingIterator> { execute_and_catch(env, move |env| { #[cfg(madsim)] { @@ -313,16 +300,21 @@ pub extern "system" fn Java_com_risingwave_java_binding_Binding_hummockIteratorN { let read_plan = Message::decode(to_guarded_slice(&read_plan, env)?.deref())?; let iter = RUNTIME.block_on(HummockJavaBindingIterator::new(read_plan))?; + let iter = JavaBindingIterator { + inner: JavaBindingIteratorInner::Hummock(iter), + cursor: None, + class_cache: Default::default(), + }; Ok(iter.into()) } }) } #[no_mangle] -pub extern "system" fn Java_com_risingwave_java_binding_Binding_hummockIteratorNext<'a>( +extern "system" fn Java_com_risingwave_java_binding_Binding_iteratorNext<'a>( env: EnvParam<'a>, - mut pointer: Pointer<'a, HummockJavaBindingIterator>, -) -> Pointer<'static, JavaBindingRow> { + mut pointer: Pointer<'a, JavaBindingIterator>, +) -> jboolean { execute_and_catch(env, move |_env| { #[cfg(madsim)] { @@ -332,101 +324,130 @@ pub extern "system" fn Java_com_risingwave_java_binding_Binding_hummockIteratorN #[cfg(not(madsim))] { let iter = pointer.as_mut(); - match RUNTIME.block_on(iter.next())? { - None => Ok(Pointer::null()), - Some(row) => Ok(JavaBindingRow::with_keyed(row, iter.class_cache.clone()).into()), + match &mut iter.inner { + JavaBindingIteratorInner::Hummock(ref mut hummock_iter) => { + match RUNTIME.block_on(hummock_iter.next())? { + None => { + iter.cursor = None; + Ok(JNI_FALSE) + } + Some((key, row)) => { + iter.cursor = Some(RowCursor { + row, + extra: RowExtra::Key(key), + }); + Ok(JNI_TRUE) + } + } + } + JavaBindingIteratorInner::StreamChunk(ref mut stream_chunk_iter) => { + match stream_chunk_iter.next() { + None => { + iter.cursor = None; + Ok(JNI_FALSE) + } + Some((op, row)) => { + iter.cursor = Some(RowCursor { + row, + extra: RowExtra::Op(op), + }); + Ok(JNI_TRUE) + } + } + } } } }) } #[no_mangle] -pub extern "system" fn Java_com_risingwave_java_binding_Binding_hummockIteratorClose( - _env: EnvParam<'_>, - pointer: Pointer<'_, HummockJavaBindingIterator>, +extern "system" fn Java_com_risingwave_java_binding_Binding_iteratorClose<'a>( + _env: EnvParam<'a>, + pointer: Pointer<'a, JavaBindingIterator>, ) { - pointer.drop(); + pointer.drop() } #[no_mangle] -pub extern "system" fn Java_com_risingwave_java_binding_Binding_streamChunkIteratorNew<'a>( +extern "system" fn Java_com_risingwave_java_binding_Binding_iteratorNewFromStreamChunkPayload< + 'a, +>( env: EnvParam<'a>, stream_chunk_payload: JByteArray<'a>, -) -> Pointer<'static, StreamChunkIterator> { +) -> Pointer<'static, JavaBindingIterator> { execute_and_catch(env, move |env| { let prost_stream_chumk = Message::decode(to_guarded_slice(&stream_chunk_payload, env)?.deref())?; - let iter = StreamChunkIterator::new(StreamChunk::from_protobuf(&prost_stream_chumk)?); + let iter = into_iter(StreamChunk::from_protobuf(&prost_stream_chumk)?); + let iter = JavaBindingIterator { + inner: JavaBindingIteratorInner::StreamChunk(iter), + cursor: None, + class_cache: Default::default(), + }; Ok(iter.into()) }) } #[no_mangle] -pub extern "system" fn Java_com_risingwave_java_binding_Binding_streamChunkIteratorFromPretty< - 'a, ->( +extern "system" fn Java_com_risingwave_java_binding_Binding_iteratorNewFromStreamChunkPretty<'a>( env: EnvParam<'a>, str: JString<'a>, -) -> Pointer<'static, StreamChunkIterator> { +) -> Pointer<'static, JavaBindingIterator> { execute_and_catch(env, move |env: &mut EnvParam<'_>| { - let iter = StreamChunkIterator::new(StreamChunk::from_pretty( + let iter = into_iter(StreamChunk::from_pretty( env.get_string(&str) .expect("cannot get java string") .to_str() .unwrap(), )); + let iter = JavaBindingIterator { + inner: JavaBindingIteratorInner::StreamChunk(iter), + cursor: None, + class_cache: Default::default(), + }; Ok(iter.into()) }) } #[no_mangle] -pub extern "system" fn Java_com_risingwave_java_binding_Binding_streamChunkIteratorNext<'a>( - env: EnvParam<'a>, - mut pointer: Pointer<'a, StreamChunkIterator>, -) -> Pointer<'static, JavaBindingRow> { - execute_and_catch(env, move |_env| { - let iter = pointer.as_mut(); - match iter.next() { - None => Ok(Pointer::null()), - Some(row) => { - Ok(JavaBindingRow::with_stream_chunk(row, iter.class_cache.clone()).into()) - } - } - }) -} - -#[no_mangle] -pub extern "system" fn Java_com_risingwave_java_binding_Binding_streamChunkIteratorClose( - _env: EnvParam<'_>, - pointer: Pointer<'_, StreamChunkIterator>, -) { - pointer.drop(); -} - -#[no_mangle] -pub extern "system" fn Java_com_risingwave_java_binding_Binding_rowGetKey<'a>( +extern "system" fn Java_com_risingwave_java_binding_Binding_iteratorGetKey<'a>( env: EnvParam<'a>, - pointer: Pointer<'a, JavaBindingRow>, + pointer: Pointer<'a, JavaBindingIterator>, ) -> JByteArray<'a> { execute_and_catch(env, move |env: &mut EnvParam<'_>| { - Ok(env.byte_array_from_slice(pointer.as_ref().as_keyed().key())?) + Ok(env.byte_array_from_slice( + pointer + .as_ref() + .cursor + .as_ref() + .expect("should exists when call get key") + .extra + .as_key() + .as_ref(), + )?) }) } #[no_mangle] -pub extern "system" fn Java_com_risingwave_java_binding_Binding_rowGetOp<'a>( +extern "system" fn Java_com_risingwave_java_binding_Binding_iteratorGetOp<'a>( env: EnvParam<'a>, - pointer: Pointer<'a, JavaBindingRow>, + pointer: Pointer<'a, JavaBindingIterator>, ) -> jint { execute_and_catch(env, move |_env| { - Ok(pointer.as_ref().as_stream_chunk().op() as jint) + Ok(pointer + .as_ref() + .cursor + .as_ref() + .expect("should exist when call get op") + .extra + .as_op() as jint) }) } #[no_mangle] -pub extern "system" fn Java_com_risingwave_java_binding_Binding_rowIsNull<'a>( +extern "system" fn Java_com_risingwave_java_binding_Binding_iteratorIsNull<'a>( env: EnvParam<'a>, - pointer: Pointer<'a, JavaBindingRow>, + pointer: Pointer<'a, JavaBindingIterator>, idx: jint, ) -> jboolean { execute_and_catch(env, move |_env| { @@ -435,9 +456,9 @@ pub extern "system" fn Java_com_risingwave_java_binding_Binding_rowIsNull<'a>( } #[no_mangle] -pub extern "system" fn Java_com_risingwave_java_binding_Binding_rowGetInt16Value<'a>( +extern "system" fn Java_com_risingwave_java_binding_Binding_iteratorGetInt16Value<'a>( env: EnvParam<'a>, - pointer: Pointer<'a, JavaBindingRow>, + pointer: Pointer<'a, JavaBindingIterator>, idx: jint, ) -> jshort { execute_and_catch(env, move |_env| { @@ -450,9 +471,9 @@ pub extern "system" fn Java_com_risingwave_java_binding_Binding_rowGetInt16Value } #[no_mangle] -pub extern "system" fn Java_com_risingwave_java_binding_Binding_rowGetInt32Value<'a>( +extern "system" fn Java_com_risingwave_java_binding_Binding_iteratorGetInt32Value<'a>( env: EnvParam<'a>, - pointer: Pointer<'a, JavaBindingRow>, + pointer: Pointer<'a, JavaBindingIterator>, idx: jint, ) -> jint { execute_and_catch(env, move |_env| { @@ -465,9 +486,9 @@ pub extern "system" fn Java_com_risingwave_java_binding_Binding_rowGetInt32Value } #[no_mangle] -pub extern "system" fn Java_com_risingwave_java_binding_Binding_rowGetInt64Value<'a>( +extern "system" fn Java_com_risingwave_java_binding_Binding_iteratorGetInt64Value<'a>( env: EnvParam<'a>, - pointer: Pointer<'a, JavaBindingRow>, + pointer: Pointer<'a, JavaBindingIterator>, idx: jint, ) -> jlong { execute_and_catch(env, move |_env| { @@ -480,9 +501,9 @@ pub extern "system" fn Java_com_risingwave_java_binding_Binding_rowGetInt64Value } #[no_mangle] -pub extern "system" fn Java_com_risingwave_java_binding_Binding_rowGetFloatValue<'a>( +extern "system" fn Java_com_risingwave_java_binding_Binding_iteratorGetFloatValue<'a>( env: EnvParam<'a>, - pointer: Pointer<'a, JavaBindingRow>, + pointer: Pointer<'a, JavaBindingIterator>, idx: jint, ) -> jfloat { execute_and_catch(env, move |_env| { @@ -496,9 +517,9 @@ pub extern "system" fn Java_com_risingwave_java_binding_Binding_rowGetFloatValue } #[no_mangle] -pub extern "system" fn Java_com_risingwave_java_binding_Binding_rowGetDoubleValue<'a>( +extern "system" fn Java_com_risingwave_java_binding_Binding_iteratorGetDoubleValue<'a>( env: EnvParam<'a>, - pointer: Pointer<'a, JavaBindingRow>, + pointer: Pointer<'a, JavaBindingIterator>, idx: jint, ) -> jdouble { execute_and_catch(env, move |_env| { @@ -512,9 +533,9 @@ pub extern "system" fn Java_com_risingwave_java_binding_Binding_rowGetDoubleValu } #[no_mangle] -pub extern "system" fn Java_com_risingwave_java_binding_Binding_rowGetBooleanValue<'a>( +extern "system" fn Java_com_risingwave_java_binding_Binding_iteratorGetBooleanValue<'a>( env: EnvParam<'a>, - pointer: Pointer<'a, JavaBindingRow>, + pointer: Pointer<'a, JavaBindingIterator>, idx: jint, ) -> jboolean { execute_and_catch(env, move |_env| { @@ -523,9 +544,9 @@ pub extern "system" fn Java_com_risingwave_java_binding_Binding_rowGetBooleanVal } #[no_mangle] -pub extern "system" fn Java_com_risingwave_java_binding_Binding_rowGetStringValue<'a>( +extern "system" fn Java_com_risingwave_java_binding_Binding_iteratorGetStringValue<'a>( env: EnvParam<'a>, - pointer: Pointer<'a, JavaBindingRow>, + pointer: Pointer<'a, JavaBindingIterator>, idx: jint, ) -> JString<'a> { execute_and_catch(env, move |env: &mut EnvParam<'a>| { @@ -534,9 +555,9 @@ pub extern "system" fn Java_com_risingwave_java_binding_Binding_rowGetStringValu } #[no_mangle] -pub extern "system" fn Java_com_risingwave_java_binding_Binding_rowGetIntervalValue<'a>( +extern "system" fn Java_com_risingwave_java_binding_Binding_iteratorGetIntervalValue<'a>( env: EnvParam<'a>, - pointer: Pointer<'a, JavaBindingRow>, + pointer: Pointer<'a, JavaBindingIterator>, idx: jint, ) -> JString<'a> { execute_and_catch(env, move |env: &mut EnvParam<'a>| { @@ -551,9 +572,9 @@ pub extern "system" fn Java_com_risingwave_java_binding_Binding_rowGetIntervalVa } #[no_mangle] -pub extern "system" fn Java_com_risingwave_java_binding_Binding_rowGetJsonbValue<'a>( +extern "system" fn Java_com_risingwave_java_binding_Binding_iteratorGetJsonbValue<'a>( env: EnvParam<'a>, - pointer: Pointer<'a, JavaBindingRow>, + pointer: Pointer<'a, JavaBindingIterator>, idx: jint, ) -> JString<'a> { execute_and_catch(env, move |env: &mut EnvParam<'_>| { @@ -568,9 +589,9 @@ pub extern "system" fn Java_com_risingwave_java_binding_Binding_rowGetJsonbValue } #[no_mangle] -pub extern "system" fn Java_com_risingwave_java_binding_Binding_rowGetTimestampValue<'a>( +extern "system" fn Java_com_risingwave_java_binding_Binding_iteratorGetTimestampValue<'a>( env: EnvParam<'a>, - pointer: Pointer<'a, JavaBindingRow>, + pointer: Pointer<'a, JavaBindingIterator>, idx: jint, ) -> JObject<'a> { execute_and_catch(env, move |env: &mut EnvParam<'_>| { @@ -600,9 +621,9 @@ pub extern "system" fn Java_com_risingwave_java_binding_Binding_rowGetTimestampV } #[no_mangle] -pub extern "system" fn Java_com_risingwave_java_binding_Binding_rowGetDecimalValue<'a>( +extern "system" fn Java_com_risingwave_java_binding_Binding_iteratorGetDecimalValue<'a>( env: EnvParam<'a>, - pointer: Pointer<'a, JavaBindingRow>, + pointer: Pointer<'a, JavaBindingIterator>, idx: jint, ) -> JObject<'a> { execute_and_catch(env, move |env: &mut EnvParam<'_>| { @@ -637,9 +658,9 @@ pub extern "system" fn Java_com_risingwave_java_binding_Binding_rowGetDecimalVal } #[no_mangle] -pub extern "system" fn Java_com_risingwave_java_binding_Binding_rowGetDateValue<'a>( +extern "system" fn Java_com_risingwave_java_binding_Binding_iteratorGetDateValue<'a>( env: EnvParam<'a>, - pointer: Pointer<'a, JavaBindingRow>, + pointer: Pointer<'a, JavaBindingIterator>, idx: jint, ) -> JObject<'a> { execute_and_catch(env, move |env: &mut EnvParam<'_>| { @@ -683,9 +704,9 @@ pub extern "system" fn Java_com_risingwave_java_binding_Binding_rowGetDateValue< } #[no_mangle] -pub extern "system" fn Java_com_risingwave_java_binding_Binding_rowGetTimeValue<'a>( +extern "system" fn Java_com_risingwave_java_binding_Binding_iteratorGetTimeValue<'a>( env: EnvParam<'a>, - pointer: Pointer<'a, JavaBindingRow>, + pointer: Pointer<'a, JavaBindingIterator>, idx: jint, ) -> JObject<'a> { execute_and_catch(env, move |env: &mut EnvParam<'_>| { @@ -729,9 +750,9 @@ pub extern "system" fn Java_com_risingwave_java_binding_Binding_rowGetTimeValue< } #[no_mangle] -pub extern "system" fn Java_com_risingwave_java_binding_Binding_rowGetByteaValue<'a>( +extern "system" fn Java_com_risingwave_java_binding_Binding_iteratorGetByteaValue<'a>( env: EnvParam<'a>, - pointer: Pointer<'a, JavaBindingRow>, + pointer: Pointer<'a, JavaBindingIterator>, idx: jint, ) -> JByteArray<'a> { execute_and_catch(env, move |env: &mut EnvParam<'_>| { @@ -745,9 +766,9 @@ pub extern "system" fn Java_com_risingwave_java_binding_Binding_rowGetByteaValue } #[no_mangle] -pub extern "system" fn Java_com_risingwave_java_binding_Binding_rowGetArrayValue<'a>( +extern "system" fn Java_com_risingwave_java_binding_Binding_iteratorGetArrayValue<'a>( env: EnvParam<'a>, - pointer: Pointer<'a, JavaBindingRow>, + pointer: Pointer<'a, JavaBindingIterator>, idx: jint, class: JClass<'a>, ) -> JObject<'a> { @@ -835,20 +856,12 @@ pub extern "system" fn Java_com_risingwave_java_binding_Binding_rowGetArrayValue }) } -#[no_mangle] -pub extern "system" fn Java_com_risingwave_java_binding_Binding_rowClose<'a>( - _env: EnvParam<'a>, - pointer: Pointer<'a, JavaBindingRow>, -) { - pointer.drop() -} - /// Send messages to the channel received by `CdcSplitReader`. /// If msg is null, just check whether the channel is closed. /// Return true if sending is successful, otherwise, return false so that caller can stop /// gracefully. #[no_mangle] -pub extern "system" fn Java_com_risingwave_java_binding_Binding_sendCdcSourceMsgToChannel<'a>( +extern "system" fn Java_com_risingwave_java_binding_Binding_sendCdcSourceMsgToChannel<'a>( env: EnvParam<'a>, channel: Pointer<'a, GetEventStreamJniSender>, msg: JByteArray<'a>, @@ -886,7 +899,7 @@ mod tests { use risingwave_expr::vector_op::cast::literal_parsing; /// make sure that the [`ScalarRefImpl::Int64`] received by - /// [`Java_com_risingwave_java_binding_Binding_rowGetTimestampValue`] + /// [`Java_com_risingwave_java_binding_Binding_iteratorGetTimestampValue`] /// is of type [`DataType::Timestamptz`] stored in microseconds #[test] fn test_timestamptz_to_i64() { diff --git a/src/jni_core/src/macros.rs b/src/jni_core/src/macros.rs index bdb5c60ec3f82..9dd83b945e8ad 100644 --- a/src/jni_core/src/macros.rs +++ b/src/jni_core/src/macros.rs @@ -106,65 +106,52 @@ macro_rules! for_all_plain_native_methods { // hummock iterator method // Return a pointer to the iterator - static native long hummockIteratorNew(byte[] readPlan); + static native long iteratorNewHummock(byte[] readPlan); - // return a pointer to the next row - static native long hummockIteratorNext(long pointer); + static native boolean iteratorNext(long pointer); - // Since the underlying rust does not have garbage collection, we will have to manually call - // close on the iterator to release the iterator instance pointed by the pointer. - static native void hummockIteratorClose(long pointer); + static native void iteratorClose(long pointer); - // row method - static native byte[] rowGetKey(long pointer); + static native long iteratorNewFromStreamChunkPayload(byte[] streamChunkPayload); - static native int rowGetOp(long pointer); + static native long iteratorNewFromStreamChunkPretty(String str); - static native boolean rowIsNull(long pointer, int index); + static native byte[] iteratorGetKey(long pointer); - static native short rowGetInt16Value(long pointer, int index); + static native int iteratorGetOp(long pointer); - static native int rowGetInt32Value(long pointer, int index); + static native boolean iteratorIsNull(long pointer, int index); - static native long rowGetInt64Value(long pointer, int index); + static native short iteratorGetInt16Value(long pointer, int index); - static native float rowGetFloatValue(long pointer, int index); + static native int iteratorGetInt32Value(long pointer, int index); - static native double rowGetDoubleValue(long pointer, int index); + static native long iteratorGetInt64Value(long pointer, int index); - static native boolean rowGetBooleanValue(long pointer, int index); + static native float iteratorGetFloatValue(long pointer, int index); - static native String rowGetStringValue(long pointer, int index); + static native double iteratorGetDoubleValue(long pointer, int index); - static native java.sql.Timestamp rowGetTimestampValue(long pointer, int index); + static native boolean iteratorGetBooleanValue(long pointer, int index); - static native java.math.BigDecimal rowGetDecimalValue(long pointer, int index); + static native String iteratorGetStringValue(long pointer, int index); - static native java.sql.Time rowGetTimeValue(long pointer, int index); + static native java.sql.Timestamp iteratorGetTimestampValue(long pointer, int index); - static native java.sql.Date rowGetDateValue(long pointer, int index); + static native java.math.BigDecimal iteratorGetDecimalValue(long pointer, int index); - static native String rowGetIntervalValue(long pointer, int index); + static native java.sql.Time iteratorGetTimeValue(long pointer, int index); - static native String rowGetJsonbValue(long pointer, int index); + static native java.sql.Date iteratorGetDateValue(long pointer, int index); - static native byte[] rowGetByteaValue(long pointer, int index); + static native String iteratorGetIntervalValue(long pointer, int index); - // TODO: object or object array? - static native Object rowGetArrayValue(long pointer, int index, Class clazz); - - // Since the underlying rust does not have garbage collection, we will have to manually call - // close on the row to release the row instance pointed by the pointer. - static native void rowClose(long pointer); - - // stream chunk iterator method - static native long streamChunkIteratorNew(byte[] streamChunkPayload); + static native String iteratorGetJsonbValue(long pointer, int index); - static native long streamChunkIteratorNext(long pointer); + static native byte[] iteratorGetByteaValue(long pointer, int index); - static native void streamChunkIteratorClose(long pointer); - - static native long streamChunkIteratorFromPretty(String str); + // TODO: object or object array? + static native Object iteratorGetArrayValue(long pointer, int index, Class clazz); public static native boolean sendCdcSourceMsgToChannel(long channelPtr, byte[] msg); } diff --git a/src/jni_core/src/stream_chunk_iterator.rs b/src/jni_core/src/stream_chunk_iterator.rs index d62117a0aa108..49d096d30339e 100644 --- a/src/jni_core/src/stream_chunk_iterator.rs +++ b/src/jni_core/src/stream_chunk_iterator.rs @@ -12,51 +12,17 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::sync::Arc; - use itertools::Itertools; use risingwave_common::array::StreamChunk; use risingwave_common::row::{OwnedRow, Row}; use risingwave_pb::data::Op; -pub struct StreamChunkRow { - op: Op, - row: OwnedRow, -} - -impl StreamChunkRow { - pub fn op(&self) -> Op { - self.op - } - - pub fn row(&self) -> &OwnedRow { - &self.row - } -} - -type StreamChunkRowIterator = impl Iterator + 'static; - -pub struct StreamChunkIterator { - iter: StreamChunkRowIterator, - pub class_cache: Arc, -} - -impl StreamChunkIterator { - pub(crate) fn new(stream_chunk: StreamChunk) -> Self { - Self { - iter: stream_chunk - .rows() - .map(|(op, row_ref)| StreamChunkRow { - op: op.to_protobuf(), - row: row_ref.to_owned_row(), - }) - .collect_vec() - .into_iter(), - class_cache: Default::default(), - } - } +pub(crate) type StreamChunkRowIterator = impl Iterator + 'static; - pub(crate) fn next(&mut self) -> Option { - self.iter.next() - } +pub(crate) fn into_iter(stream_chunk: StreamChunk) -> StreamChunkRowIterator { + stream_chunk + .rows() + .map(|(op, row_ref)| (op.to_protobuf(), row_ref.to_owned_row())) + .collect_vec() + .into_iter() } From 26e6306d32ac3fcf766aff4eaf6e63ca1c134b57 Mon Sep 17 00:00:00 2001 From: William Wen Date: Sun, 24 Sep 2023 19:27:23 +0800 Subject: [PATCH 5/8] chore(java-binding): refine stream chunk benchmark --- .../java/binding/ArrayListBenchmark.java | 4 +-- .../java/binding/StreamchunkBenchmark.java | 36 ++++++++++++------- src/java_binding/make-java-binding.toml | 9 ++--- 3 files changed, 28 insertions(+), 21 deletions(-) diff --git a/java/java-binding-benchmark/src/main/java/com/risingwave/java/binding/ArrayListBenchmark.java b/java/java-binding-benchmark/src/main/java/com/risingwave/java/binding/ArrayListBenchmark.java index 6540033371d34..f0edca505ca4a 100644 --- a/java/java-binding-benchmark/src/main/java/com/risingwave/java/binding/ArrayListBenchmark.java +++ b/java/java-binding-benchmark/src/main/java/com/risingwave/java/binding/ArrayListBenchmark.java @@ -20,8 +20,8 @@ import java.util.concurrent.TimeUnit; import org.openjdk.jmh.annotations.*; -@Warmup(iterations = 10, time = 1, timeUnit = TimeUnit.MILLISECONDS) -@Measurement(iterations = 20, time = 1, timeUnit = TimeUnit.MILLISECONDS) +@Warmup(iterations = 2, time = 1, timeUnit = TimeUnit.MILLISECONDS, batchSize = 10) +@Measurement(iterations = 5, time = 1, timeUnit = TimeUnit.MILLISECONDS, batchSize = 10) @Fork(value = 1) @BenchmarkMode(org.openjdk.jmh.annotations.Mode.AverageTime) @OutputTimeUnit(TimeUnit.MICROSECONDS) diff --git a/java/java-binding-benchmark/src/main/java/com/risingwave/java/binding/StreamchunkBenchmark.java b/java/java-binding-benchmark/src/main/java/com/risingwave/java/binding/StreamchunkBenchmark.java index 8741044f7b34e..628d1405c8d81 100644 --- a/java/java-binding-benchmark/src/main/java/com/risingwave/java/binding/StreamchunkBenchmark.java +++ b/java/java-binding-benchmark/src/main/java/com/risingwave/java/binding/StreamchunkBenchmark.java @@ -16,31 +16,37 @@ package com.risingwave.java.binding; +import java.util.ArrayList; +import java.util.Iterator; import java.util.concurrent.TimeUnit; import org.openjdk.jmh.annotations.*; -@Warmup(iterations = 10, time = 1, timeUnit = TimeUnit.MILLISECONDS) -@Measurement(iterations = 20, time = 1, timeUnit = TimeUnit.MILLISECONDS) +@Warmup(iterations = 2, time = 1, timeUnit = TimeUnit.MILLISECONDS, batchSize = 10) +@Measurement(iterations = 5, time = 1, timeUnit = TimeUnit.MILLISECONDS, batchSize = 10) @Fork(value = 1) @BenchmarkMode(org.openjdk.jmh.annotations.Mode.AverageTime) -@OutputTimeUnit(TimeUnit.MICROSECONDS) +@OutputTimeUnit(TimeUnit.MILLISECONDS) @State(org.openjdk.jmh.annotations.Scope.Benchmark) public class StreamchunkBenchmark { @Param({"100", "1000", "10000"}) - static int loopTime; + int loopTime; - String str; - StreamChunkIterator iter; + Iterator iterOfIter; - @Setup(Level.Invocation) + @Setup(Level.Iteration) public void setup() { - str = "i i I f F B i"; - for (int i = 0; i < loopTime; i++) { - String b = i % 2 == 0 ? "f" : "t"; - String n = i % 2 == 0 ? "." : "1"; - str += String.format("\n + %d %d %d %d.0 %d.0 %s %s", i, i, i, i, i, b, n); + var iterList = new ArrayList(); + for (int iterI = 0; iterI < 10; iterI++) { + String str = "i i I f F B i"; + for (int i = 0; i < loopTime; i++) { + String b = i % 2 == 0 ? "f" : "t"; + String n = i % 2 == 0 ? "." : "1"; + str += String.format("\n + %d %d %d %d.0 %d.0 %s %s", i, i, i, i, i, b, n); + } + var iter = new StreamChunkIterator(str); + iterList.add(iter); } - iter = new StreamChunkIterator(str); + iterOfIter = iterList.iterator(); } public void getValue(StreamChunkRow row) { @@ -55,6 +61,10 @@ public void getValue(StreamChunkRow row) { @Benchmark public void streamchunkTest() { + if (!iterOfIter.hasNext()) { + throw new RuntimeException("too few prepared iter"); + } + var iter = iterOfIter.next(); int count = 0; while (true) { try (StreamChunkRow row = iter.next()) { diff --git a/src/java_binding/make-java-binding.toml b/src/java_binding/make-java-binding.toml index 3be65ec2158a6..de4f405487eab 100644 --- a/src/java_binding/make-java-binding.toml +++ b/src/java_binding/make-java-binding.toml @@ -15,7 +15,7 @@ script = ''' #!/usr/bin/env bash set -ex cd java -mvn install --no-transfer-progress --pl java-binding-integration-test --am -DskipTests=true +mvn install --no-transfer-progress --pl java-binding-integration-test --am -DskipTests=true -Dmaven.javadoc.skip mvn dependency:copy-dependencies --no-transfer-progress --pl java-binding-integration-test ''' @@ -98,9 +98,6 @@ cd ${RISINGWAVE_ROOT}/java [tasks.run-java-binding-stream-chunk-benchmark] category = "RiseDev - Java Binding" description = "Run the java binding stream chunk benchmark" -dependencies = [ - "build-java-binding", -] script = ''' #!/usr/bin/env bash set -ex @@ -109,10 +106,10 @@ RISINGWAVE_ROOT=$(git rev-parse --show-toplevel) cd ${RISINGWAVE_ROOT}/java -mvn install --pl java-binding-benchmark --am -DskipTests=true +mvn install --pl java-binding-benchmark --am -DskipTests=true -Dmaven.javadoc.skip mvn dependency:copy-dependencies --pl java-binding-benchmark -java -cp "java-binding-benchmark/target/dependency/*:java-binding-benchmark/target/java-binding-benchmark-1.0-SNAPSHOT.jar" \ +java -cp "java-binding-benchmark/target/dependency/*:java-binding-benchmark/target/java-binding-benchmark-0.1.0-SNAPSHOT.jar" \ com.risingwave.java.binding.BenchmarkRunner ''' From 7cb86cc9c778558741a987bdca15a918827a5659 Mon Sep 17 00:00:00 2001 From: William Wen Date: Sun, 24 Sep 2023 19:38:20 +0800 Subject: [PATCH 6/8] refine array list bench --- .../risingwave/java/binding/ArrayListBenchmark.java | 11 +---------- 1 file changed, 1 insertion(+), 10 deletions(-) diff --git a/java/java-binding-benchmark/src/main/java/com/risingwave/java/binding/ArrayListBenchmark.java b/java/java-binding-benchmark/src/main/java/com/risingwave/java/binding/ArrayListBenchmark.java index f0edca505ca4a..c05cf23d2c582 100644 --- a/java/java-binding-benchmark/src/main/java/com/risingwave/java/binding/ArrayListBenchmark.java +++ b/java/java-binding-benchmark/src/main/java/com/risingwave/java/binding/ArrayListBenchmark.java @@ -30,8 +30,6 @@ public class ArrayListBenchmark { @Param({"100", "1000", "10000"}) static int loopTime; - ArrayList> data = new ArrayList<>(); - public ArrayList getRow(int index) { short v1 = (short) index; int v2 = (int) index; @@ -61,17 +59,10 @@ public void getValue(ArrayList rowData) { Integer mayNull = (Integer) rowData.get(6); } - @Setup - public void setup() { - for (int i = 0; i < loopTime; i++) { - data.add(getRow(i)); - } - } - @Benchmark public void arrayListTest() { for (int i = 0; i < loopTime; i++) { - getValue(data.get(i)); + getValue(getRow(i)); } } } From 83e2ab8953d6ed952c4d5ee40286bacbd486f3e8 Mon Sep 17 00:00:00 2001 From: William Wen Date: Mon, 25 Sep 2023 17:14:48 +0800 Subject: [PATCH 7/8] use cfg_or_panic --- Cargo.lock | 1 + src/jni_core/Cargo.toml | 1 + src/jni_core/src/lib.rs | 33 ++++++++++----------------------- 3 files changed, 12 insertions(+), 23 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 0f70660226977..4c1203236a92e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -7207,6 +7207,7 @@ name = "risingwave_jni_core" version = "0.1.0" dependencies = [ "bytes", + "cfg-or-panic", "futures", "itertools 0.11.0", "jni", diff --git a/src/jni_core/Cargo.toml b/src/jni_core/Cargo.toml index 40195108b1039..bc70ef8a73114 100644 --- a/src/jni_core/Cargo.toml +++ b/src/jni_core/Cargo.toml @@ -11,6 +11,7 @@ normal = ["workspace-hack"] [dependencies] bytes = "1" +cfg-or-panic = "0.2" futures = { version = "0.3", default-features = false, features = ["alloc"] } itertools = "0.11" jni = "0.21.1" diff --git a/src/jni_core/src/lib.rs b/src/jni_core/src/lib.rs index 86ac65b8b4f0c..9ebce58b426a7 100644 --- a/src/jni_core/src/lib.rs +++ b/src/jni_core/src/lib.rs @@ -29,6 +29,7 @@ use std::ops::{Deref, DerefMut}; use std::slice::from_raw_parts; use std::sync::{Arc, LazyLock, OnceLock}; +use cfg_or_panic::cfg_or_panic; use hummock_iterator::{HummockJavaBindingIterator, KeyedRow}; use jni::objects::{ AutoElements, GlobalRef, JByteArray, JClass, JMethodID, JObject, JStaticMethodID, JString, @@ -298,44 +299,30 @@ pub extern "system" fn Java_com_risingwave_java_binding_Binding_vnodeCount( VirtualNode::COUNT as jint } +#[cfg_or_panic(not(madsim))] #[no_mangle] pub extern "system" fn Java_com_risingwave_java_binding_Binding_hummockIteratorNew<'a>( env: EnvParam<'a>, read_plan: JByteArray<'a>, ) -> Pointer<'static, HummockJavaBindingIterator> { execute_and_catch(env, move |env| { - #[cfg(madsim)] - { - unreachable!() - } - - #[cfg(not(madsim))] - { - let read_plan = Message::decode(to_guarded_slice(&read_plan, env)?.deref())?; - let iter = RUNTIME.block_on(HummockJavaBindingIterator::new(read_plan))?; - Ok(iter.into()) - } + let read_plan = Message::decode(to_guarded_slice(&read_plan, env)?.deref())?; + let iter = RUNTIME.block_on(HummockJavaBindingIterator::new(read_plan))?; + Ok(iter.into()) }) } +#[cfg_or_panic(not(madsim))] #[no_mangle] pub extern "system" fn Java_com_risingwave_java_binding_Binding_hummockIteratorNext<'a>( env: EnvParam<'a>, mut pointer: Pointer<'a, HummockJavaBindingIterator>, ) -> Pointer<'static, JavaBindingRow> { execute_and_catch(env, move |_env| { - #[cfg(madsim)] - { - unreachable!() - } - - #[cfg(not(madsim))] - { - let iter = pointer.as_mut(); - match RUNTIME.block_on(iter.next())? { - None => Ok(Pointer::null()), - Some(row) => Ok(JavaBindingRow::with_keyed(row, iter.class_cache.clone()).into()), - } + let iter = pointer.as_mut(); + match RUNTIME.block_on(iter.next())? { + None => Ok(Pointer::null()), + Some(row) => Ok(JavaBindingRow::with_keyed(row, iter.class_cache.clone()).into()), } }) } From 5dcaeb5bc8f6a09b71ddeb0a6062ff436da716b1 Mon Sep 17 00:00:00 2001 From: William Wen Date: Tue, 26 Sep 2023 15:37:13 +0800 Subject: [PATCH 8/8] fix compile and set utf8 encoding --- .../src/main/java/com/risingwave/connector/EsSink.java | 6 +++++- java/pom.xml | 2 ++ 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/java/connector-node/risingwave-sink-es-7/src/main/java/com/risingwave/connector/EsSink.java b/java/connector-node/risingwave-sink-es-7/src/main/java/com/risingwave/connector/EsSink.java index 11ec5b4451c62..f9c266f0af117 100644 --- a/java/connector-node/risingwave-sink-es-7/src/main/java/com/risingwave/connector/EsSink.java +++ b/java/connector-node/risingwave-sink-es-7/src/main/java/com/risingwave/connector/EsSink.java @@ -277,7 +277,11 @@ private void writeRow(SinkRow row) throws JsonMappingException, JsonProcessingEx public void write(Iterator rows) { while (rows.hasNext()) { SinkRow row = rows.next(); - writeRow(row); + try { + writeRow(row); + } catch (Exception ex) { + throw new RuntimeException(ex); + } } } diff --git a/java/pom.xml b/java/pom.xml index 28d7a688a5aef..e72e831b798e5 100644 --- a/java/pom.xml +++ b/java/pom.xml @@ -60,6 +60,8 @@ 11 11 1.0.0 + UTF-8 + UTF-8 3.21.1 1.53.0 2.10