From 2904555899385af135796cc98072d1e31e359559 Mon Sep 17 00:00:00 2001 From: Dylan Chen Date: Tue, 22 Aug 2023 14:41:28 +0800 Subject: [PATCH 01/23] test jni rs --- Cargo.lock | 23 ++++++ .../connector/ConnectorService.java | 3 + src/compute/Cargo.toml | 1 + src/compute/src/lib.rs | 77 +++++++++++++++++++ 4 files changed, 104 insertions(+) diff --git a/Cargo.lock b/Cargo.lock index 981a31a412f71..d4001e914048c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3639,6 +3639,16 @@ version = "1.0.9" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "af150ab688ff2122fcef229be89cb50dd66af9e01a4ff320cc137eecc9bacc38" +[[package]] +name = "java-locator" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "90003f2fd9c52f212c21d8520f1128da0080bad6fff16b68fe6e7f2f0c3780c2" +dependencies = [ + "glob", + "lazy_static", +] + [[package]] name = "jni" version = "0.21.1" @@ -3648,7 +3658,9 @@ dependencies = [ "cesu8", "cfg-if", "combine", + "java-locator", "jni-sys", + "libloading", "log", "thiserror", "walkdir", @@ -3801,6 +3813,16 @@ dependencies = [ "rle-decode-fast", ] +[[package]] +name = "libloading" +version = "0.7.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b67380fd3b2fbe7527a606e18729d21c6f3951633d0500574c4dc22d2d638b9f" +dependencies = [ + "cfg-if", + "winapi", +] + [[package]] name = "libm" version = "0.2.7" @@ -6420,6 +6442,7 @@ dependencies = [ "futures-async-stream", "hyper", "itertools 0.11.0", + "jni", "madsim-tokio", "madsim-tonic", "maplit", diff --git a/java/connector-node/risingwave-connector-service/src/main/java/com/risingwave/connector/ConnectorService.java b/java/connector-node/risingwave-connector-service/src/main/java/com/risingwave/connector/ConnectorService.java index 810fd9d0f26f4..4c68ff45b88d0 100644 --- a/java/connector-node/risingwave-connector-service/src/main/java/com/risingwave/connector/ConnectorService.java +++ b/java/connector-node/risingwave-connector-service/src/main/java/com/risingwave/connector/ConnectorService.java @@ -37,6 +37,9 @@ public static void main(String[] args) throws Exception { CommandLineParser parser = new DefaultParser(); CommandLine cmd = parser.parse(options, args); + java.lang.Thread.currentThread() + .setContextClassLoader(java.lang.ClassLoader.getSystemClassLoader()); + // Quoted from the debezium document: // > Your application should always properly stop the engine to ensure graceful and complete // > shutdown and that each source record is sent to the application exactly one time. diff --git a/src/compute/Cargo.toml b/src/compute/Cargo.toml index 70aaf895e7b73..4cca8235d8e20 100644 --- a/src/compute/Cargo.toml +++ b/src/compute/Cargo.toml @@ -23,6 +23,7 @@ either = "1" futures = { version = "0.3", default-features = false, features = ["alloc"] } futures-async-stream = { workspace = true } hyper = "0.14" +jni = { version = "0.21.1", features = ["invocation"] } itertools = "0.11" maplit = "1.0.2" pprof = { version = "0.12", features = ["flamegraph"] } diff --git a/src/compute/src/lib.rs b/src/compute/src/lib.rs index 937e236564f48..da56b0fe93761 100644 --- a/src/compute/src/lib.rs +++ b/src/compute/src/lib.rs @@ -31,6 +31,7 @@ pub mod rpc; pub mod server; pub mod telemetry; +use std::fs; use clap::{Parser, ValueEnum}; use risingwave_common::config::{AsyncStackTraceOption, OverrideConfig}; use risingwave_common::util::resource_util::cpu::total_cpu_available; @@ -186,7 +187,11 @@ fn validate_opts(opts: &ComputeNodeOpts) { } use std::future::Future; +use std::path::Path; use std::pin::Pin; +use jni::{InitArgsBuilder, JavaVM, JNIVersion}; +use jni::objects::{JObject, JValue}; +use jni::sys::jint; use crate::server::compute_node_serve; @@ -218,12 +223,84 @@ pub fn start( let (join_handle_vec, _shutdown_send) = compute_node_serve(listen_addr, advertise_addr, opts, registry).await; + tokio::task::spawn_blocking(move || { + run_jvm(); + }); + + for join_handle in join_handle_vec { join_handle.await.unwrap(); } }) } +fn run_jvm() { + + let dir_path = "/Users/dylan/Desktop/workspace/risingwave/.risingwave/bin/connector-node/libs/"; + + let dir = Path::new(dir_path); + + if !dir.is_dir() { + println!("{} is not a directory", dir_path); + return; + } + + let mut class_vec = vec![]; + + if let Ok(entries) = fs::read_dir(dir) { + for entry in entries { + if let Ok(entry) = entry { + if let Some(name) = entry.path().file_name() { + println!("{:?}", name); + class_vec.push(String::from( dir_path.to_owned() + name.to_str().to_owned().unwrap())); + } + } + } + } else { + println!("failed to read directory {}", dir_path); + } + + // Build the VM properties + let jvm_args = InitArgsBuilder::new() + // Pass the JNI API version (default is 8) + .version(JNIVersion::V8) + // You can additionally pass any JVM options (standard, like a system property, + // or VM-specific). + // Here we enable some extra JNI checks useful during development + // .option("-Xcheck:jni") + .option("-ea") + .option(format!("-Djava.class.path={}", class_vec.join(":")) ) + .option("-agentlib:jdwp=transport=dt_socket,server=y,suspend=n,address=9111") + .build() + .unwrap(); + + // Create a new VM + let jvm = match JavaVM::new(jvm_args) { + Err(err) => { + panic!("{:?}", err) + }, + Ok(jvm) => jvm, + }; + + // Attach the current thread to call into Java — see extra options in + // "Attaching Native Threads" section. + // + // This method returns the guard that will detach the current thread when dropped, + // also freeing any local references created in it + let mut env = jvm.attach_current_thread_as_daemon().unwrap(); + + // Call Java Math#abs(-10) + let x = JValue::from(-10); + let val: jint = env.call_static_method("java/lang/Math", "abs", "(I)I", &[x]).unwrap() + .i().unwrap(); + + assert_eq!(val, 10); + let string_class = env.find_class("java/lang/String").unwrap(); + let jarray = env.new_object_array(0, string_class, JObject::null()).unwrap(); + + let _ = env.call_static_method("com/risingwave/connector/ConnectorService", "main", "([Ljava/lang/String;)V", &[JValue::Object(&jarray)]).inspect_err(|e| eprintln!("{:?}", e)); +} + fn default_total_memory_bytes() -> usize { total_memory_available_bytes() } From 2831bad7c0d163fcedfe88c02cc4677377a4a3cf Mon Sep 17 00:00:00 2001 From: Dylan Chen Date: Wed, 23 Aug 2023 20:57:01 +0800 Subject: [PATCH 02/23] CHANNEL POOL --- Cargo.lock | 2 + java/com_risingwave_java_binding_Binding.h | 80 ++++++++++++ .../source/core/DbzCdcEngineRunner.java | 27 ++++ .../source/core/JniSourceHandler.java | 104 +++++++++++++++ .../source/core/SourceHandlerFactory.java | 20 +++ .../com/risingwave/java/binding/Binding.java | 2 + src/common/Cargo.toml | 1 + src/common/src/jvm_runtime.rs | 67 ++++++++++ src/common/src/lib.rs | 1 + src/compute/src/lib.rs | 80 ++++-------- src/connector/Cargo.toml | 1 + src/connector/src/source/cdc/source/reader.rs | 122 +++++++++++++++++- src/java_binding/src/lib.rs | 69 +++++++++- 13 files changed, 519 insertions(+), 57 deletions(-) create mode 100644 java/connector-node/risingwave-connector-service/src/main/java/com/risingwave/connector/source/core/JniSourceHandler.java create mode 100644 src/common/src/jvm_runtime.rs diff --git a/Cargo.lock b/Cargo.lock index d4001e914048c..e19e3ab2882e2 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -6296,6 +6296,7 @@ dependencies = [ "hytra", "itertools 0.11.0", "itoa", + "jni", "libc", "lru 0.7.6", "mach2", @@ -6502,6 +6503,7 @@ dependencies = [ "google-cloud-pubsub", "icelake", "itertools 0.11.0", + "jni", "madsim-rdkafka", "madsim-tokio", "madsim-tonic", diff --git a/java/com_risingwave_java_binding_Binding.h b/java/com_risingwave_java_binding_Binding.h index bd03892223a6d..0f33228d6797b 100644 --- a/java/com_risingwave_java_binding_Binding.h +++ b/java/com_risingwave_java_binding_Binding.h @@ -119,6 +119,70 @@ JNIEXPORT jboolean JNICALL Java_com_risingwave_java_binding_Binding_rowGetBoolea JNIEXPORT jstring JNICALL Java_com_risingwave_java_binding_Binding_rowGetStringValue (JNIEnv *, jclass, jlong, jint); +/* + * Class: com_risingwave_java_binding_Binding + * Method: rowGetTimestampValue + * Signature: (JI)Ljava/sql/Timestamp; + */ +JNIEXPORT jobject JNICALL Java_com_risingwave_java_binding_Binding_rowGetTimestampValue + (JNIEnv *, jclass, jlong, jint); + +/* + * Class: com_risingwave_java_binding_Binding + * Method: rowGetDecimalValue + * Signature: (JI)Ljava/math/BigDecimal; + */ +JNIEXPORT jobject JNICALL Java_com_risingwave_java_binding_Binding_rowGetDecimalValue + (JNIEnv *, jclass, jlong, jint); + +/* + * Class: com_risingwave_java_binding_Binding + * Method: rowGetTimeValue + * Signature: (JI)Ljava/sql/Time; + */ +JNIEXPORT jobject JNICALL Java_com_risingwave_java_binding_Binding_rowGetTimeValue + (JNIEnv *, jclass, jlong, jint); + +/* + * Class: com_risingwave_java_binding_Binding + * Method: rowGetDateValue + * Signature: (JI)Ljava/sql/Date; + */ +JNIEXPORT jobject JNICALL Java_com_risingwave_java_binding_Binding_rowGetDateValue + (JNIEnv *, jclass, jlong, jint); + +/* + * Class: com_risingwave_java_binding_Binding + * Method: rowGetIntervalValue + * Signature: (JI)Ljava/lang/String; + */ +JNIEXPORT jstring JNICALL Java_com_risingwave_java_binding_Binding_rowGetIntervalValue + (JNIEnv *, jclass, jlong, jint); + +/* + * Class: com_risingwave_java_binding_Binding + * Method: rowGetJsonbValue + * Signature: (JI)Ljava/lang/String; + */ +JNIEXPORT jstring JNICALL Java_com_risingwave_java_binding_Binding_rowGetJsonbValue + (JNIEnv *, jclass, jlong, jint); + +/* + * Class: com_risingwave_java_binding_Binding + * Method: rowGetByteaValue + * Signature: (JI)[B + */ +JNIEXPORT jbyteArray JNICALL Java_com_risingwave_java_binding_Binding_rowGetByteaValue + (JNIEnv *, jclass, jlong, jint); + +/* + * Class: com_risingwave_java_binding_Binding + * Method: rowGetArrayValue + * Signature: (JILjava/lang/Class;)Ljava/lang/Object; + */ +JNIEXPORT jobject JNICALL Java_com_risingwave_java_binding_Binding_rowGetArrayValue + (JNIEnv *, jclass, jlong, jint, jclass); + /* * Class: com_risingwave_java_binding_Binding * Method: rowClose @@ -151,6 +215,22 @@ JNIEXPORT jlong JNICALL Java_com_risingwave_java_binding_Binding_streamChunkIter JNIEXPORT void JNICALL Java_com_risingwave_java_binding_Binding_streamChunkIteratorClose (JNIEnv *, jclass, jlong); +/* + * Class: com_risingwave_java_binding_Binding + * Method: streamChunkIteratorFromPretty + * Signature: (Ljava/lang/String;)J + */ +JNIEXPORT jlong JNICALL Java_com_risingwave_java_binding_Binding_streamChunkIteratorFromPretty + (JNIEnv *, jclass, jstring); + +/* + * Class: com_risingwave_java_binding_Binding + * Method: sendMsgToChannel + * Signature: (ILjava/lang/Object;)V + */ +JNIEXPORT void JNICALL Java_com_risingwave_java_binding_Binding_sendMsgToChannel + (JNIEnv *, jclass, jint, jobject); + #ifdef __cplusplus } #endif diff --git a/java/connector-node/risingwave-connector-service/src/main/java/com/risingwave/connector/source/core/DbzCdcEngineRunner.java b/java/connector-node/risingwave-connector-service/src/main/java/com/risingwave/connector/source/core/DbzCdcEngineRunner.java index e9fef6e869c04..f69f5774f3d84 100644 --- a/java/connector-node/risingwave-connector-service/src/main/java/com/risingwave/connector/source/core/DbzCdcEngineRunner.java +++ b/java/connector-node/risingwave-connector-service/src/main/java/com/risingwave/connector/source/core/DbzCdcEngineRunner.java @@ -70,6 +70,33 @@ public static CdcEngineRunner newCdcEngineRunner( return runner; } + public static CdcEngineRunner newCdcEngineRunnerV2(DbzConnectorConfig config) { + DbzCdcEngineRunner runner = null; + try { + var sourceId = config.getSourceId(); + var engine = + new DbzCdcEngine( + config.getSourceId(), + config.getResolvedDebeziumProps(), + (success, message, error) -> { + if (!success) { + LOG.error( + "engine#{} terminated with error. message: {}", + sourceId, + message, + error); + } else { + LOG.info("engine#{} stopped normally. {}", sourceId, message); + } + }); + + runner = new DbzCdcEngineRunner(engine); + } catch (Exception e) { + LOG.error("failed to create the CDC engine", e); + } + return runner; + } + /** Start to run the cdc engine */ public void start() { if (isRunning()) { diff --git a/java/connector-node/risingwave-connector-service/src/main/java/com/risingwave/connector/source/core/JniSourceHandler.java b/java/connector-node/risingwave-connector-service/src/main/java/com/risingwave/connector/source/core/JniSourceHandler.java new file mode 100644 index 0000000000000..dfcf0f7e75971 --- /dev/null +++ b/java/connector-node/risingwave-connector-service/src/main/java/com/risingwave/connector/source/core/JniSourceHandler.java @@ -0,0 +1,104 @@ +// Copyright 2023 RisingWave Labs +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package com.risingwave.connector.source.core; + +import com.risingwave.connector.api.source.CdcEngineRunner; +import com.risingwave.connector.source.common.DbzConnectorConfig; +import com.risingwave.java.binding.Binding; +import com.risingwave.metrics.ConnectorNodeMetrics; +import io.grpc.Context; +import java.util.concurrent.TimeUnit; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** handler for starting a debezium source connectors */ +public class JniSourceHandler { + static final Logger LOG = LoggerFactory.getLogger(DbzSourceHandler.class); + + private final DbzConnectorConfig config; + + public JniSourceHandler(DbzConnectorConfig config) { + this.config = config; + } + + class OnReadyHandler implements Runnable { + private final CdcEngineRunner runner; + private final int channelId; + + public OnReadyHandler(CdcEngineRunner runner, int channelId) { + this.runner = runner; + this.channelId = channelId; + } + + @Override + public void run() { + while (runner.isRunning()) { + try { + if (Context.current().isCancelled()) { + LOG.info( + "Engine#{}: Connection broken detected, stop the engine", + config.getSourceId()); + runner.stop(); + return; + } + + // check whether the send queue has room for new messages + // Thread will block on the channel to get output from engine + var resp = + runner.getEngine().getOutputChannel().poll(500, TimeUnit.MILLISECONDS); + if (resp != null) { + ConnectorNodeMetrics.incSourceRowsReceived( + config.getSourceType().toString(), + String.valueOf(config.getSourceId()), + resp.getEventsCount()); + LOG.debug( + "Engine#{}: emit one chunk {} events to network ", + config.getSourceId(), + resp.getEventsCount()); + + Binding.sendMsgToChannel(channelId, resp); + } + } catch (Exception e) { + LOG.error("Poll engine output channel fail. ", e); + } + } + } + } + + public void start(int channelId) { + var runner = DbzCdcEngineRunner.newCdcEngineRunnerV2(config); + if (runner == null) { + return; + } + + try { + // Start the engine + runner.start(); + LOG.info("Start consuming events of table {}", config.getSourceId()); + + final OnReadyHandler onReadyHandler = new OnReadyHandler(runner, channelId); + + onReadyHandler.run(); + + } catch (Throwable t) { + LOG.error("Cdc engine failed.", t); + try { + runner.stop(); + } catch (Exception e) { + LOG.warn("Failed to stop Engine#{}", config.getSourceId(), e); + } + } + } +} diff --git a/java/connector-node/risingwave-connector-service/src/main/java/com/risingwave/connector/source/core/SourceHandlerFactory.java b/java/connector-node/risingwave-connector-service/src/main/java/com/risingwave/connector/source/core/SourceHandlerFactory.java index b60bcb4f7da5a..b01c7d642df0a 100644 --- a/java/connector-node/risingwave-connector-service/src/main/java/com/risingwave/connector/source/core/SourceHandlerFactory.java +++ b/java/connector-node/risingwave-connector-service/src/main/java/com/risingwave/connector/source/core/SourceHandlerFactory.java @@ -39,4 +39,24 @@ public static SourceHandler createSourceHandler( source, sourceId, startOffset, mutableUserProps, snapshotDone); return new DbzSourceHandler(config); } + + public static void startJniSourceHandler( + SourceTypeE source, + long sourceId, + String startOffset, + Map userProps, + boolean snapshotDone, + int channelId) { + // For jni.rs + java.lang.Thread.currentThread() + .setContextClassLoader(java.lang.ClassLoader.getSystemClassLoader()); + // userProps extracted from grpc request, underlying implementation is UnmodifiableMap + Map mutableUserProps = new HashMap<>(userProps); + mutableUserProps.put("source.id", Long.toString(sourceId)); + var config = + new DbzConnectorConfig( + source, sourceId, startOffset, mutableUserProps, snapshotDone); + JniSourceHandler hanlder = new JniSourceHandler(config); + hanlder.start(channelId); + } } diff --git a/java/java-binding/src/main/java/com/risingwave/java/binding/Binding.java b/java/java-binding/src/main/java/com/risingwave/java/binding/Binding.java index 3f05768ec74b8..b25ad37c8df15 100644 --- a/java/java-binding/src/main/java/com/risingwave/java/binding/Binding.java +++ b/java/java-binding/src/main/java/com/risingwave/java/binding/Binding.java @@ -84,4 +84,6 @@ public class Binding { static native void streamChunkIteratorClose(long pointer); static native long streamChunkIteratorFromPretty(String str); + + public static native void sendMsgToChannel(int channelId, Object msg); } diff --git a/src/common/Cargo.toml b/src/common/Cargo.toml index d667b846e9f3a..7cac255f5249b 100644 --- a/src/common/Cargo.toml +++ b/src/common/Cargo.toml @@ -98,6 +98,7 @@ tracing-subscriber = "0.3.17" twox-hash = "1" url = "2" uuid = "1.4.1" +jni = { version = "0.21.1", features = ["invocation"] } [target.'cfg(not(madsim))'.dependencies] workspace-hack = { path = "../workspace-hack" } diff --git a/src/common/src/jvm_runtime.rs b/src/common/src/jvm_runtime.rs new file mode 100644 index 0000000000000..792ed74acf4df --- /dev/null +++ b/src/common/src/jvm_runtime.rs @@ -0,0 +1,67 @@ +use core::option::Option::Some; +use core::result::Result::{Err, Ok}; +use std::collections::HashMap; +use risingwave_pb::connector_service::GetEventStreamResponse; +use std::fs; +use std::path::Path; +use std::sync::{Arc, LazyLock, RwLock}; +use std::sync::atomic::AtomicI32; +use jni::{InitArgsBuilder, JavaVM, JNIVersion}; +use tokio::sync::mpsc::UnboundedSender; + +pub static JNI_CHANNEL_POOL: LazyLock>>> = LazyLock::new(|| { + RwLock::new(HashMap::new()) +}); + +pub static CHANNEL_ID_GEN: LazyLock> = LazyLock::new(|| { + Arc::new(AtomicI32::new(0)) +}); + +pub static JVM: LazyLock> = LazyLock::new(|| { + let dir_path = "/Users/dylan/Desktop/workspace/risingwave/.risingwave/bin/connector-node/libs/"; + + let dir = Path::new(dir_path); + + if !dir.is_dir() { + panic!("{} is not a directory", dir_path); + } + + let mut class_vec = vec![]; + + if let Ok(entries) = fs::read_dir(dir) { + for entry in entries { + if let Ok(entry) = entry { + if let Some(name) = entry.path().file_name() { + println!("{:?}", name); + class_vec.push(String::from( dir_path.to_owned() + name.to_str().to_owned().unwrap())); + } + } + } + } else { + println!("failed to read directory {}", dir_path); + } + + // Build the VM properties + let jvm_args = InitArgsBuilder::new() + // Pass the JNI API version (default is 8) + .version(JNIVersion::V8) + // You can additionally pass any JVM options (standard, like a system property, + // or VM-specific). + // Here we enable some extra JNI checks useful during development + // .option("-Xcheck:jni") + .option("-ea") + .option(format!("-Djava.class.path={}", class_vec.join(":")) ) + .option("-agentlib:jdwp=transport=dt_socket,server=y,suspend=n,address=9111") + .build() + .unwrap(); + + // Create a new VM + let jvm = match JavaVM::new(jvm_args) { + Err(err) => { + panic!("{:?}", err) + }, + Ok(jvm) => jvm, + }; + + Arc::new(jvm) +}); diff --git a/src/common/src/lib.rs b/src/common/src/lib.rs index 1eff2e813f1d6..22d046b9377f0 100644 --- a/src/common/src/lib.rs +++ b/src/common/src/lib.rs @@ -72,6 +72,7 @@ pub mod metrics; pub mod test_utils; pub mod types; pub mod vnode_mapping; +pub mod jvm_runtime; pub mod test_prelude { pub use super::array::{DataChunkTestExt, StreamChunkTestExt}; diff --git a/src/compute/src/lib.rs b/src/compute/src/lib.rs index da56b0fe93761..c43475a0b633a 100644 --- a/src/compute/src/lib.rs +++ b/src/compute/src/lib.rs @@ -31,6 +31,7 @@ pub mod rpc; pub mod server; pub mod telemetry; +use std::collections::HashMap; use std::fs; use clap::{Parser, ValueEnum}; use risingwave_common::config::{AsyncStackTraceOption, OverrideConfig}; @@ -187,11 +188,13 @@ fn validate_opts(opts: &ComputeNodeOpts) { } use std::future::Future; +use std::ops::Deref; use std::path::Path; use std::pin::Pin; -use jni::{InitArgsBuilder, JavaVM, JNIVersion}; -use jni::objects::{JObject, JValue}; -use jni::sys::jint; +use jni::{InitArgsBuilder, JavaVM, JNIEnv, JNIVersion}; +use jni::objects::{JObject, JString, JValue}; +use jni::sys::{jint, jobject}; +use risingwave_common::jvm_runtime::JVM; use crate::server::compute_node_serve; @@ -234,67 +237,34 @@ pub fn start( }) } -fn run_jvm() { - - let dir_path = "/Users/dylan/Desktop/workspace/risingwave/.risingwave/bin/connector-node/libs/"; - - let dir = Path::new(dir_path); - - if !dir.is_dir() { - println!("{} is not a directory", dir_path); - return; - } - - let mut class_vec = vec![]; - - if let Ok(entries) = fs::read_dir(dir) { - for entry in entries { - if let Ok(entry) = entry { - if let Some(name) = entry.path().file_name() { - println!("{:?}", name); - class_vec.push(String::from( dir_path.to_owned() + name.to_str().to_owned().unwrap())); - } - } - } - } else { - println!("failed to read directory {}", dir_path); +fn rust_hashmap_to_java_hashmap<'a>(env: &'a mut JNIEnv, hashmap: &HashMap<&str, &str>) -> Result, String> { + let hashmap_class = "java/util/HashMap"; + let hashmap_constructor_signature = "()V"; + let hashmap_put_signature = "(Ljava/lang/Object;Ljava/lang/Object;)Ljava/lang/Object;"; + + let map = env.new_object(hashmap_class, hashmap_constructor_signature, &[]).unwrap(); + for (key, value) in hashmap.iter() { + let key = env.new_string(*key).unwrap(); + let value = env.new_string(*value).unwrap(); + let args = [ + JValue::Object(&key), + JValue::Object(&value), + ]; + env.call_method(&map, "put", hashmap_put_signature, &args).unwrap(); } + Ok(map) +} - // Build the VM properties - let jvm_args = InitArgsBuilder::new() - // Pass the JNI API version (default is 8) - .version(JNIVersion::V8) - // You can additionally pass any JVM options (standard, like a system property, - // or VM-specific). - // Here we enable some extra JNI checks useful during development - // .option("-Xcheck:jni") - .option("-ea") - .option(format!("-Djava.class.path={}", class_vec.join(":")) ) - .option("-agentlib:jdwp=transport=dt_socket,server=y,suspend=n,address=9111") - .build() - .unwrap(); - - // Create a new VM - let jvm = match JavaVM::new(jvm_args) { - Err(err) => { - panic!("{:?}", err) - }, - Ok(jvm) => jvm, - }; - - // Attach the current thread to call into Java — see extra options in - // "Attaching Native Threads" section. - // - // This method returns the guard that will detach the current thread when dropped, - // also freeing any local references created in it - let mut env = jvm.attach_current_thread_as_daemon().unwrap(); +fn run_jvm() { + let mut env = JVM.attach_current_thread_as_daemon().unwrap(); // Call Java Math#abs(-10) let x = JValue::from(-10); let val: jint = env.call_static_method("java/lang/Math", "abs", "(I)I", &[x]).unwrap() .i().unwrap(); assert_eq!(val, 10); + let string_class = env.find_class("java/lang/String").unwrap(); let jarray = env.new_object_array(0, string_class, JObject::null()).unwrap(); diff --git a/src/connector/Cargo.toml b/src/connector/Cargo.toml index d549b9c613164..367fc04f3b32c 100644 --- a/src/connector/Cargo.toml +++ b/src/connector/Cargo.toml @@ -53,6 +53,7 @@ itertools = "0.11" maplit = "1.0.2" moka = { version = "0.11", features = ["future"] } nexmark = { version = "0.2", features = ["serde"] } +jni = { version = "0.21.1", features = ["invocation"] } num-bigint = "0.4" opendal = "0.39" parking_lot = "0.12" diff --git a/src/connector/src/source/cdc/source/reader.rs b/src/connector/src/source/cdc/source/reader.rs index 200c91a8a5051..053fdcbc4c04f 100644 --- a/src/connector/src/source/cdc/source/reader.rs +++ b/src/connector/src/source/cdc/source/reader.rs @@ -12,12 +12,23 @@ // See the License for the specific language governing permissions and // limitations under the License. +use std::collections::HashMap; +use std::fs; +use std::path::Path; use std::str::FromStr; +use std::sync::{Arc, LazyLock, RwLock}; +use std::sync::atomic::{AtomicUsize, Ordering}; use anyhow::{anyhow, Result}; use async_trait::async_trait; use futures::{pin_mut, StreamExt, TryStreamExt}; use futures_async_stream::try_stream; +use jni::{InitArgsBuilder, JavaVM, JNIVersion}; +use jni::objects::{JObject, JValue}; +use jni::sys::jint; +use tokio::sync::mpsc; +use tokio::sync::mpsc::UnboundedSender; +use risingwave_common::jvm_runtime::{CHANNEL_ID_GEN, JNI_CHANNEL_POOL, JVM}; use risingwave_common::util::addr::HostAddr; use risingwave_pb::connector_service::GetEventStreamResponse; @@ -32,6 +43,7 @@ use crate::source::{ impl_common_split_reader_logic!(CdcSplitReader, CdcProperties); + pub struct CdcSplitReader { source_id: u64, start_offset: Option, @@ -96,7 +108,7 @@ impl SplitReader for CdcSplitReader { impl CdcSplitReader { #[try_stream(boxed, ok = Vec, error = anyhow::Error)] - async fn into_data_stream(self) { + async fn ____into_data_stream(self) { let cdc_client = self.source_ctx.connector_client.clone().ok_or_else(|| { anyhow!("connector node endpoint not specified or unable to connect to connector node") })?; @@ -153,4 +165,112 @@ impl CdcSplitReader { } } } + + #[try_stream(boxed, ok = Vec, error = anyhow::Error)] + async fn into_data_stream(self) { + // rewrite the hostname and port for the split + let mut properties = self.conn_props.props.clone(); + + // For citus, we need to rewrite the table.name to capture sharding tables + if self.server_addr.is_some() { + let addr = self.server_addr.unwrap(); + let host_addr = HostAddr::from_str(&addr) + .map_err(|err| anyhow!("invalid server address for cdc split. {}", err))?; + properties.insert("hostname".to_string(), host_addr.host); + properties.insert("port".to_string(), host_addr.port.to_string()); + // rewrite table name with suffix to capture all shards in the split + let mut table_name = properties + .remove("table.name") + .ok_or_else(|| anyhow!("missing field 'table.name'"))?; + table_name.push_str("_[0-9]+"); + properties.insert("table.name".into(), table_name); + } + + let (tx, mut rx) = mpsc::unbounded_channel(); + let channel_id = CHANNEL_ID_GEN.fetch_add(1, Ordering::Relaxed); + { + let mut guard = JNI_CHANNEL_POOL.write().unwrap(); + guard.insert(channel_id, tx); + } + + let source_type = self.conn_props.get_source_type_pb()?; + + // let cdc_stream = cdc_client + // .start_source_stream( + // self.source_id, + // self.conn_props.get_source_type_pb()?, + // self.start_offset, + // properties, + // self.snapshot_done, + // ) + // .await + // .inspect_err(|err| tracing::error!("connector node start stream error: {}", err))?; + // pin_mut!(cdc_stream); + + tokio::task::spawn_blocking(move || { + let mut env = JVM.attach_current_thread_as_daemon().unwrap(); + + env.find_class("com/risingwave/proto/ConnectorServiceProto$SourceType").inspect_err(|e| eprintln!("{:?}", e)).unwrap(); + let source_type_arg = JValue::from(source_type as i32); + let st = env.call_static_method("com/risingwave/proto/ConnectorServiceProto$SourceType", "forNumber", "(I)Lcom/risingwave/proto/ConnectorServiceProto$SourceType;", &[source_type_arg]).inspect_err(|e| eprintln!("{:?}", e)).unwrap(); + let st = env.call_static_method("com/risingwave/connector/api/source/SourceTypeE", "valueOf", "(Lcom/risingwave/proto/ConnectorServiceProto$SourceType;)Lcom/risingwave/connector/api/source/SourceTypeE;", &[(&st).into()]).inspect_err(|e| eprintln!("{:?}", e)).unwrap(); + + let source_id_arg = JValue::from(self.source_id as i64); + + + let source_type = env.find_class("com/risingwave/connector/api/source/SourceTypeE").unwrap(); + let string_class = env.find_class("java/lang/String").unwrap(); + let start_offset = match self.start_offset { + Some(start_offset) => { + let start_offset = env.new_string(start_offset).unwrap(); + env.call_method(start_offset, "toString", "()Ljava/lang/String;", &[]).unwrap() + }, + None => { + jni::objects::JValueGen::Object(JObject::null()) + } + }; + + let mut user_prop = properties; + + let hashmap_class = "java/util/HashMap"; + let hashmap_constructor_signature = "()V"; + let hashmap_put_signature = "(Ljava/lang/Object;Ljava/lang/Object;)Ljava/lang/Object;"; + + let java_map = env.new_object(hashmap_class, hashmap_constructor_signature, &[]).unwrap(); + for (key, value) in user_prop.iter() { + let key = env.new_string(key.to_string()).unwrap(); + let value = env.new_string(value.to_string()).unwrap(); + let args = [ + JValue::Object(&key), + JValue::Object(&value), + ]; + env.call_method(&java_map, "put", hashmap_put_signature, &args).unwrap(); + } + + let snapshot_done = JValue::from(self.snapshot_done); + + let channel_id = JValue::from(channel_id as i32); + + let _ = env.call_static_method( + "com/risingwave/connector/source/core/SourceHandlerFactory", + "startJniSourceHandler", + "(Lcom/risingwave/connector/api/source/SourceTypeE;JLjava/lang/String;Ljava/util/Map;ZI)V", + &[(&st).into(), source_id_arg, (&start_offset).into(), JValue::Object(&java_map), snapshot_done, channel_id]).inspect_err(|e| eprintln!("{:?}", e)).unwrap(); + + println!("call jni cdc start source success"); + }); + + while let Some(GetEventStreamResponse { events, .. }) = rx.recv().await { + if events.is_empty() { + continue; + } + let mut msgs = Vec::with_capacity(events.len()); + for event in events { + msgs.push(SourceMessage::from(event)); + } + yield msgs; + } + } } + + diff --git a/src/java_binding/src/lib.rs b/src/java_binding/src/lib.rs index 6da90e6931ad9..7b48705b0b32b 100644 --- a/src/java_binding/src/lib.rs +++ b/src/java_binding/src/lib.rs @@ -33,7 +33,7 @@ use jni::objects::{ JValue, JValueGen, JValueOwned, ReleaseMode, }; use jni::signature::ReturnType; -use jni::sys::{jboolean, jbyte, jdouble, jfloat, jint, jlong, jshort, jsize, jvalue}; +use jni::sys::{jboolean, jbyte, jdouble, jfloat, jint, jlong, jobject, jshort, jsize, jvalue}; use jni::JNIEnv; use prost::{DecodeError, Message}; use risingwave_common::array::{ArrayError, StreamChunk}; @@ -45,6 +45,9 @@ use risingwave_common::util::panic::rw_catch_unwind; use risingwave_storage::error::StorageError; use thiserror::Error; use tokio::runtime::Runtime; +use risingwave_common::jvm_runtime::JNI_CHANNEL_POOL; +use risingwave_pb::connector_service::{CdcMessage, GetEventStreamResponse}; +use risingwave_pb::hummock::GetEpochResponse; use crate::stream_chunk_iterator::{StreamChunkIterator, StreamChunkRow}; @@ -816,6 +819,70 @@ pub extern "system" fn Java_com_risingwave_java_binding_Binding_rowClose<'a>( pointer.drop() } +#[no_mangle] +pub extern "system" fn Java_com_risingwave_java_binding_Binding_sendMsgToChannel<'a>( + mut env: EnvParam<'a>, + channel_id: jint, + mut msg: JObject<'a>, +) { + let guard = JNI_CHANNEL_POOL.read().unwrap(); + println!("JNI_CHANNEL_POOL len = {}", guard.len()); + let channel = guard.get(&channel_id).unwrap(); + + let source_id = env.env.call_method(&mut msg, "getSourceId", "()L", &[]).unwrap(); + let source_id = source_id.j().unwrap(); + + let events_list = env.env.call_method(&mut msg, "getEventsList", "()Ljava/util/List;", &[]).unwrap(); + let mut events_list = match events_list { + JValueGen::Object(obj) => obj, + _ => unreachable!() + }; + + + let size = env.env.call_method(&mut events_list, "size", "()I", &[]).unwrap().i().unwrap(); + let mut events = Vec::with_capacity(size as usize); + for i in 0..size { + let java_element = env.call_method(&mut events_list, "get", "(I)Ljava/lang/Object;", &[JValue::from(i as i32)]).unwrap(); + let mut java_element = match java_element { + JValueGen::Object(obj) => obj, + _ => unreachable!() + }; + let payload = env.call_method(&mut java_element, "getPayload", "()Ljava/lang/String;", &[]).unwrap(); + let mut payload = match payload { + JValueGen::Object(obj) => obj, + _ => unreachable!() + }; + let payload: String = env.get_string(&JString::from(payload)).unwrap().into(); + + let partition = env.call_method(&mut java_element, "getPartition", "()Ljava/lang/String;", &[]).unwrap(); + let mut partition = match partition { + JValueGen::Object(obj) => obj, + _ => unreachable!() + }; + let partition: String = env.get_string(&JString::from(partition)).unwrap().into(); + + let offset = env.call_method(&mut java_element, "getOffset", "()Ljava/lang/String;", &[]).unwrap(); + let mut offset = match offset { + JValueGen::Object(obj) => obj, + _ => unreachable!() + }; + let offset: String = env.get_string(&JString::from(offset)).unwrap().into(); + + println!("channel_id = {:?}, source_id = {:?}, payload = {:?}, partition = {:?}, offset = {:?}", channel_id, source_id, payload, partition, offset); + events.push(CdcMessage { + payload, + partition, + offset, + }) + } + let get_event_stream_response = GetEventStreamResponse { + source_id: source_id as u64, + events, + }; + let _ = channel.send(get_event_stream_response); +} + + #[cfg(test)] mod tests { use risingwave_common::types::{DataType, Timestamptz}; From 8e3fc67a1d74211cda5d6e3f2f53a65b9c426620 Mon Sep 17 00:00:00 2001 From: Dylan Chen Date: Thu, 24 Aug 2023 20:13:12 +0800 Subject: [PATCH 03/23] support jni create cdc source --- Cargo.lock | 1 + Cargo.toml | 1 + java/com_risingwave_java_binding_Binding.h | 4 +- .../source/core/JniSourceHandler.java | 39 ++-- .../source/core/SourceHandlerFactory.java | 4 +- .../com/risingwave/java/binding/Binding.java | 6 +- src/common/src/jvm_runtime.rs | 22 ++- src/compute/Cargo.toml | 1 + src/compute/src/lib.rs | 169 +++++++++++++++++- src/connector/src/source/cdc/source/reader.rs | 51 +++--- src/java_binding/src/lib.rs | 26 +-- 11 files changed, 253 insertions(+), 71 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index e19e3ab2882e2..a028b9ca6c117 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -6455,6 +6455,7 @@ dependencies = [ "risingwave_common_service", "risingwave_connector", "risingwave_hummock_sdk", + "risingwave_java_binding", "risingwave_pb", "risingwave_rpc_client", "risingwave_source", diff --git a/Cargo.toml b/Cargo.toml index dc685bdd20f76..29009fa696316 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -108,6 +108,7 @@ risingwave_stream = { path = "./src/stream" } risingwave_test_runner = { path = "./src/test_runner" } risingwave_udf = { path = "./src/udf" } risingwave_variables = { path = "./src/utils/variables" } +risingwave_java_binding = { path = "./src/java_binding" } [profile.dev] lto = 'off' diff --git a/java/com_risingwave_java_binding_Binding.h b/java/com_risingwave_java_binding_Binding.h index 0f33228d6797b..c2c235921756f 100644 --- a/java/com_risingwave_java_binding_Binding.h +++ b/java/com_risingwave_java_binding_Binding.h @@ -226,10 +226,10 @@ JNIEXPORT jlong JNICALL Java_com_risingwave_java_binding_Binding_streamChunkIter /* * Class: com_risingwave_java_binding_Binding * Method: sendMsgToChannel - * Signature: (ILjava/lang/Object;)V + * Signature: (JLjava/lang/Object;)V */ JNIEXPORT void JNICALL Java_com_risingwave_java_binding_Binding_sendMsgToChannel - (JNIEnv *, jclass, jint, jobject); + (JNIEnv *, jclass, jlong, jobject); #ifdef __cplusplus } diff --git a/java/connector-node/risingwave-connector-service/src/main/java/com/risingwave/connector/source/core/JniSourceHandler.java b/java/connector-node/risingwave-connector-service/src/main/java/com/risingwave/connector/source/core/JniSourceHandler.java index dfcf0f7e75971..2a88001184d6e 100644 --- a/java/connector-node/risingwave-connector-service/src/main/java/com/risingwave/connector/source/core/JniSourceHandler.java +++ b/java/connector-node/risingwave-connector-service/src/main/java/com/risingwave/connector/source/core/JniSourceHandler.java @@ -18,7 +18,6 @@ import com.risingwave.connector.source.common.DbzConnectorConfig; import com.risingwave.java.binding.Binding; import com.risingwave.metrics.ConnectorNodeMetrics; -import io.grpc.Context; import java.util.concurrent.TimeUnit; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -35,49 +34,57 @@ public JniSourceHandler(DbzConnectorConfig config) { class OnReadyHandler implements Runnable { private final CdcEngineRunner runner; - private final int channelId; + private final long channelPtr; - public OnReadyHandler(CdcEngineRunner runner, int channelId) { + public OnReadyHandler(CdcEngineRunner runner, long channelPtr) { this.runner = runner; - this.channelId = channelId; + this.channelPtr = channelPtr; } @Override public void run() { while (runner.isRunning()) { try { - if (Context.current().isCancelled()) { - LOG.info( - "Engine#{}: Connection broken detected, stop the engine", - config.getSourceId()); - runner.stop(); - return; - } + LOG.info("Engine#{}: loop step 1 ", config.getSourceId()); + // if (Context.current().isCancelled()) { + // LOG.info( + // "Engine#{}: Connection broken detected, stop + // the engine", + // config.getSourceId()); + // runner.stop(); + // return; + // } + + LOG.info("Engine#{}: loop step 2 ", config.getSourceId()); // check whether the send queue has room for new messages // Thread will block on the channel to get output from engine var resp = runner.getEngine().getOutputChannel().poll(500, TimeUnit.MILLISECONDS); + LOG.info("Engine#{}: loop step 3 ", config.getSourceId()); if (resp != null) { ConnectorNodeMetrics.incSourceRowsReceived( config.getSourceType().toString(), String.valueOf(config.getSourceId()), resp.getEventsCount()); - LOG.debug( + LOG.info( "Engine#{}: emit one chunk {} events to network ", config.getSourceId(), resp.getEventsCount()); - Binding.sendMsgToChannel(channelId, resp); + Binding.sendMsgToChannel(channelPtr, resp); + Thread.sleep(10000); } - } catch (Exception e) { + + LOG.info("Engine#{}: loop step 4 ", config.getSourceId()); + } catch (Throwable e) { LOG.error("Poll engine output channel fail. ", e); } } } } - public void start(int channelId) { + public void start(long channelPtr) { var runner = DbzCdcEngineRunner.newCdcEngineRunnerV2(config); if (runner == null) { return; @@ -88,7 +95,7 @@ public void start(int channelId) { runner.start(); LOG.info("Start consuming events of table {}", config.getSourceId()); - final OnReadyHandler onReadyHandler = new OnReadyHandler(runner, channelId); + final OnReadyHandler onReadyHandler = new OnReadyHandler(runner, channelPtr); onReadyHandler.run(); diff --git a/java/connector-node/risingwave-connector-service/src/main/java/com/risingwave/connector/source/core/SourceHandlerFactory.java b/java/connector-node/risingwave-connector-service/src/main/java/com/risingwave/connector/source/core/SourceHandlerFactory.java index b01c7d642df0a..8974861d29e18 100644 --- a/java/connector-node/risingwave-connector-service/src/main/java/com/risingwave/connector/source/core/SourceHandlerFactory.java +++ b/java/connector-node/risingwave-connector-service/src/main/java/com/risingwave/connector/source/core/SourceHandlerFactory.java @@ -46,7 +46,7 @@ public static void startJniSourceHandler( String startOffset, Map userProps, boolean snapshotDone, - int channelId) { + long channelPtr) { // For jni.rs java.lang.Thread.currentThread() .setContextClassLoader(java.lang.ClassLoader.getSystemClassLoader()); @@ -57,6 +57,6 @@ public static void startJniSourceHandler( new DbzConnectorConfig( source, sourceId, startOffset, mutableUserProps, snapshotDone); JniSourceHandler hanlder = new JniSourceHandler(config); - hanlder.start(channelId); + hanlder.start(channelPtr); } } diff --git a/java/java-binding/src/main/java/com/risingwave/java/binding/Binding.java b/java/java-binding/src/main/java/com/risingwave/java/binding/Binding.java index b25ad37c8df15..93f10a1f829f8 100644 --- a/java/java-binding/src/main/java/com/risingwave/java/binding/Binding.java +++ b/java/java-binding/src/main/java/com/risingwave/java/binding/Binding.java @@ -14,11 +14,11 @@ package com.risingwave.java.binding; -import io.questdb.jar.jni.JarJniLoader; +// import io.questdb.jar.jni.JarJniLoader; public class Binding { static { - JarJniLoader.loadLib(Binding.class, "/risingwave/jni", "risingwave_java_binding"); + // JarJniLoader.loadLib(Binding.class, "/risingwave/jni", "risingwave_java_binding"); } public static native int vnodeCount(); @@ -85,5 +85,5 @@ public class Binding { static native long streamChunkIteratorFromPretty(String str); - public static native void sendMsgToChannel(int channelId, Object msg); + public static native void sendMsgToChannel(long channelPtr, Object msg); } diff --git a/src/common/src/jvm_runtime.rs b/src/common/src/jvm_runtime.rs index 792ed74acf4df..94f07bdf53c39 100644 --- a/src/common/src/jvm_runtime.rs +++ b/src/common/src/jvm_runtime.rs @@ -7,15 +7,7 @@ use std::path::Path; use std::sync::{Arc, LazyLock, RwLock}; use std::sync::atomic::AtomicI32; use jni::{InitArgsBuilder, JavaVM, JNIVersion}; -use tokio::sync::mpsc::UnboundedSender; - -pub static JNI_CHANNEL_POOL: LazyLock>>> = LazyLock::new(|| { - RwLock::new(HashMap::new()) -}); - -pub static CHANNEL_ID_GEN: LazyLock> = LazyLock::new(|| { - Arc::new(AtomicI32::new(0)) -}); +use tokio::sync::mpsc::{Sender, UnboundedSender}; pub static JVM: LazyLock> = LazyLock::new(|| { let dir_path = "/Users/dylan/Desktop/workspace/risingwave/.risingwave/bin/connector-node/libs/"; @@ -65,3 +57,15 @@ pub static JVM: LazyLock> = LazyLock::new(|| { Arc::new(jvm) }); + + +pub struct MyPtr { + pub ptr: Sender, + pub num: u64, +} + +impl Drop for MyPtr { + fn drop(&mut self) { + println!("drop MyPtr, num = {}", self.num); + } +} \ No newline at end of file diff --git a/src/compute/Cargo.toml b/src/compute/Cargo.toml index 4cca8235d8e20..607cb394f4a7e 100644 --- a/src/compute/Cargo.toml +++ b/src/compute/Cargo.toml @@ -38,6 +38,7 @@ risingwave_rpc_client = { workspace = true } risingwave_source = { workspace = true } risingwave_storage = { workspace = true } risingwave_stream = { workspace = true } +risingwave_java_binding = { workspace = true } serde = { version = "1", features = ["derive"] } serde_json = "1" diff --git a/src/compute/src/lib.rs b/src/compute/src/lib.rs index c43475a0b633a..fafb5c4fcc1d3 100644 --- a/src/compute/src/lib.rs +++ b/src/compute/src/lib.rs @@ -32,6 +32,7 @@ pub mod server; pub mod telemetry; use std::collections::HashMap; +use std::ffi::c_void; use std::fs; use clap::{Parser, ValueEnum}; use risingwave_common::config::{AsyncStackTraceOption, OverrideConfig}; @@ -188,13 +189,16 @@ fn validate_opts(opts: &ComputeNodeOpts) { } use std::future::Future; -use std::ops::Deref; +use std::marker::PhantomData; +use std::ops::{Deref, DerefMut}; use std::path::Path; use std::pin::Pin; -use jni::{InitArgsBuilder, JavaVM, JNIEnv, JNIVersion}; -use jni::objects::{JObject, JString, JValue}; -use jni::sys::{jint, jobject}; -use risingwave_common::jvm_runtime::JVM; +use jni::{InitArgsBuilder, JavaVM, JNIEnv, JNIVersion, NativeMethod}; +use jni::objects::{JClass, JObject, JString, JValue, JValueGen}; +use jni::strings::JNIString; +use jni::sys::{jint, jlong, jobject}; +use risingwave_common::jvm_runtime::{JVM, MyPtr}; +use risingwave_pb::connector_service::{CdcMessage, GetEventStreamResponse}; use crate::server::compute_node_serve; @@ -255,6 +259,148 @@ fn rust_hashmap_to_java_hashmap<'a>(env: &'a mut JNIEnv, hashmap: &HashMap<&str, Ok(map) } +#[repr(C)] +pub struct EnvParam<'a> { + env: JNIEnv<'a>, + class: JClass<'a>, +} + +impl<'a> Deref for EnvParam<'a> { + type Target = JNIEnv<'a>; + + fn deref(&self) -> &Self::Target { + &self.env + } +} + +impl<'a> DerefMut for EnvParam<'a> { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.env + } +} + +impl<'a> EnvParam<'a> { + pub fn get_class(&self) -> &JClass<'a> { + &self.class + } +} + +#[repr(transparent)] +pub struct Pointer<'a, T> { + pointer: jlong, + _phantom: PhantomData<&'a T>, +} + +impl<'a, T> Default for Pointer<'a, T> { + fn default() -> Self { + Self { + pointer: 0, + _phantom: Default::default(), + } + } +} + +impl From for Pointer<'static, T> { + fn from(value: T) -> Self { + Pointer { + pointer: Box::into_raw(Box::new(value)) as jlong, + _phantom: PhantomData, + } + } +} + +impl Pointer<'static, T> { + fn null() -> Self { + Pointer { + pointer: 0, + _phantom: PhantomData, + } + } +} + +impl<'a, T> Pointer<'a, T> { + fn as_ref(&self) -> &'a T { + debug_assert!(self.pointer != 0); + unsafe { &*(self.pointer as *const T) } + } + + fn as_mut(&mut self) -> &'a mut T { + debug_assert!(self.pointer != 0); + unsafe { &mut *(self.pointer as *mut T) } + } + + fn drop(self) { + debug_assert!(self.pointer != 0); + unsafe { drop(Box::from_raw(self.pointer as *mut T)) } + } +} + +#[no_mangle] +pub extern "system" fn Java_com_risingwave_java_binding_Binding_sendMsgToChannel<'a>( + mut env: EnvParam<'a>, + channel: Pointer<'a, MyPtr>, + mut msg: JObject<'a>, +) { + + println!("channel_ptr = {}, num = {}", channel.pointer, channel.as_ref().num); + // let channel: &mut UnboundedSender = unsafe { &mut *(channel_ptr.pointer as *mut UnboundedSender) }; + + let source_id = env.env.call_method(&mut msg, "getSourceId", "()J", &[]).unwrap(); + let source_id = source_id.j().unwrap(); + + let events_list = env.env.call_method(&mut msg, "getEventsList", "()Ljava/util/List;", &[]).unwrap(); + let mut events_list = match events_list { + JValueGen::Object(obj) => obj, + _ => unreachable!() + }; + + + let size = env.env.call_method(&mut events_list, "size", "()I", &[]).unwrap().i().unwrap(); + let mut events = Vec::with_capacity(size as usize); + for i in 0..size { + let java_element = env.call_method(&mut events_list, "get", "(I)Ljava/lang/Object;", &[JValue::from(i as i32)]).unwrap(); + let mut java_element = match java_element { + JValueGen::Object(obj) => obj, + _ => unreachable!() + }; + let payload = env.call_method(&mut java_element, "getPayload", "()Ljava/lang/String;", &[]).unwrap(); + let mut payload = match payload { + JValueGen::Object(obj) => obj, + _ => unreachable!() + }; + let payload: String = env.get_string(&JString::from(payload)).unwrap().into(); + + let partition = env.call_method(&mut java_element, "getPartition", "()Ljava/lang/String;", &[]).unwrap(); + let mut partition = match partition { + JValueGen::Object(obj) => obj, + _ => unreachable!() + }; + let partition: String = env.get_string(&JString::from(partition)).unwrap().into(); + + let offset = env.call_method(&mut java_element, "getOffset", "()Ljava/lang/String;", &[]).unwrap(); + let mut offset = match offset { + JValueGen::Object(obj) => obj, + _ => unreachable!() + }; + let offset: String = env.get_string(&JString::from(offset)).unwrap().into(); + + println!("source_id = {:?}, payload = {:?}, partition = {:?}, offset = {:?}", source_id, payload, partition, offset); + events.push(CdcMessage { + payload, + partition, + offset, + }) + } + let get_event_stream_response = GetEventStreamResponse { + source_id: source_id as u64, + events, + }; + println!("before send"); + let _ = channel.as_ref().ptr.blocking_send(get_event_stream_response).inspect_err(|e| eprintln!("{:?}", e)).unwrap(); + println!("send successfully"); +} + + fn run_jvm() { let mut env = JVM.attach_current_thread_as_daemon().unwrap(); @@ -268,6 +414,19 @@ fn run_jvm() { let string_class = env.find_class("java/lang/String").unwrap(); let jarray = env.new_object_array(0, string_class, JObject::null()).unwrap(); + let fn_ptr = Java_com_risingwave_java_binding_Binding_sendMsgToChannel as *const fn ( + EnvParam<'static>, + Pointer<'static, MyPtr>, + JObject<'static> + ); + + let binding_class = env.find_class("com/risingwave/java/binding/Binding").unwrap(); + env.register_native_methods(binding_class, &[NativeMethod { + name: JNIString::from("sendMsgToChannel"), + sig: JNIString::from("(JLjava/lang/Object;)V"), + fn_ptr: fn_ptr as *mut c_void, + }]).unwrap(); + let _ = env.call_static_method("com/risingwave/connector/ConnectorService", "main", "([Ljava/lang/String;)V", &[JValue::Object(&jarray)]).inspect_err(|e| eprintln!("{:?}", e)); } diff --git a/src/connector/src/source/cdc/source/reader.rs b/src/connector/src/source/cdc/source/reader.rs index 053fdcbc4c04f..c3a4e337840ae 100644 --- a/src/connector/src/source/cdc/source/reader.rs +++ b/src/connector/src/source/cdc/source/reader.rs @@ -14,10 +14,12 @@ use std::collections::HashMap; use std::fs; +use std::mem::forget; use std::path::Path; use std::str::FromStr; use std::sync::{Arc, LazyLock, RwLock}; use std::sync::atomic::{AtomicUsize, Ordering}; +use std::time::Duration; use anyhow::{anyhow, Result}; use async_trait::async_trait; @@ -27,8 +29,8 @@ use jni::{InitArgsBuilder, JavaVM, JNIVersion}; use jni::objects::{JObject, JValue}; use jni::sys::jint; use tokio::sync::mpsc; -use tokio::sync::mpsc::UnboundedSender; -use risingwave_common::jvm_runtime::{CHANNEL_ID_GEN, JNI_CHANNEL_POOL, JVM}; +use tokio::time::sleep; +use risingwave_common::jvm_runtime::{JVM, MyPtr}; use risingwave_common::util::addr::HostAddr; use risingwave_pb::connector_service::GetEventStreamResponse; @@ -186,26 +188,15 @@ impl CdcSplitReader { properties.insert("table.name".into(), table_name); } - let (tx, mut rx) = mpsc::unbounded_channel(); - let channel_id = CHANNEL_ID_GEN.fetch_add(1, Ordering::Relaxed); - { - let mut guard = JNI_CHANNEL_POOL.write().unwrap(); - guard.insert(channel_id, tx); - } + let (tx, mut rx) = mpsc::channel(1024); + + let tx: Box = Box::new(MyPtr { + ptr: tx, + num: 123456, + }); let source_type = self.conn_props.get_source_type_pb()?; - // let cdc_stream = cdc_client - // .start_source_stream( - // self.source_id, - // self.conn_props.get_source_type_pb()?, - // self.start_offset, - // properties, - // self.snapshot_done, - // ) - // .await - // .inspect_err(|err| tracing::error!("connector node start stream error: {}", err))?; - // pin_mut!(cdc_stream); tokio::task::spawn_blocking(move || { let mut env = JVM.attach_current_thread_as_daemon().unwrap(); @@ -249,18 +240,34 @@ impl CdcSplitReader { let snapshot_done = JValue::from(self.snapshot_done); - let channel_id = JValue::from(channel_id as i32); + let channel_ptr = Box::into_raw(tx) as i64; + println!("channel_ptr = {}", channel_ptr); + let channel_ptr = JValue::from(channel_ptr); let _ = env.call_static_method( "com/risingwave/connector/source/core/SourceHandlerFactory", "startJniSourceHandler", - "(Lcom/risingwave/connector/api/source/SourceTypeE;JLjava/lang/String;Ljava/util/Map;ZI)V", - &[(&st).into(), source_id_arg, (&start_offset).into(), JValue::Object(&java_map), snapshot_done, channel_id]).inspect_err(|e| eprintln!("{:?}", e)).unwrap(); + "(Lcom/risingwave/connector/api/source/SourceTypeE;JLjava/lang/String;Ljava/util/Map;ZJ)V", + &[(&st).into(), source_id_arg, (&start_offset).into(), JValue::Object(&java_map), snapshot_done, channel_ptr]).inspect_err(|e| eprintln!("{:?}", e)).unwrap(); println!("call jni cdc start source success"); }); + // loop { + // let GetEventStreamResponse { events, .. } = rx.recv().unwrap(); + // println!("recieve events {:?}", events.len()); + // if events.is_empty() { + // continue; + // } + // let mut msgs = Vec::with_capacity(events.len()); + // for event in events { + // msgs.push(SourceMessage::from(event)); + // } + // yield msgs; + // } + while let Some(GetEventStreamResponse { events, .. }) = rx.recv().await { + println!("recieve events {:?}", events.len()); if events.is_empty() { continue; } diff --git a/src/java_binding/src/lib.rs b/src/java_binding/src/lib.rs index 7b48705b0b32b..aab24b770631b 100644 --- a/src/java_binding/src/lib.rs +++ b/src/java_binding/src/lib.rs @@ -17,9 +17,10 @@ #![feature(lazy_cell)] #![feature(once_cell_try)] #![feature(type_alias_impl_trait)] +#![feature(result_option_inspect)] -mod hummock_iterator; -mod stream_chunk_iterator; +pub mod hummock_iterator; +pub mod stream_chunk_iterator; use std::backtrace::Backtrace; use std::marker::PhantomData; @@ -45,9 +46,9 @@ use risingwave_common::util::panic::rw_catch_unwind; use risingwave_storage::error::StorageError; use thiserror::Error; use tokio::runtime::Runtime; -use risingwave_common::jvm_runtime::JNI_CHANNEL_POOL; +use tokio::sync::mpsc::UnboundedSender; +use risingwave_common::jvm_runtime::MyPtr; use risingwave_pb::connector_service::{CdcMessage, GetEventStreamResponse}; -use risingwave_pb::hummock::GetEpochResponse; use crate::stream_chunk_iterator::{StreamChunkIterator, StreamChunkRow}; @@ -822,14 +823,14 @@ pub extern "system" fn Java_com_risingwave_java_binding_Binding_rowClose<'a>( #[no_mangle] pub extern "system" fn Java_com_risingwave_java_binding_Binding_sendMsgToChannel<'a>( mut env: EnvParam<'a>, - channel_id: jint, + channel: Pointer<'a, MyPtr>, mut msg: JObject<'a>, ) { - let guard = JNI_CHANNEL_POOL.read().unwrap(); - println!("JNI_CHANNEL_POOL len = {}", guard.len()); - let channel = guard.get(&channel_id).unwrap(); - let source_id = env.env.call_method(&mut msg, "getSourceId", "()L", &[]).unwrap(); + println!("channel_ptr = {}, num = {}", channel.pointer, channel.as_ref().num); + // let channel: &mut UnboundedSender = unsafe { &mut *(channel_ptr.pointer as *mut UnboundedSender) }; + + let source_id = env.env.call_method(&mut msg, "getSourceId", "()J", &[]).unwrap(); let source_id = source_id.j().unwrap(); let events_list = env.env.call_method(&mut msg, "getEventsList", "()Ljava/util/List;", &[]).unwrap(); @@ -868,7 +869,7 @@ pub extern "system" fn Java_com_risingwave_java_binding_Binding_sendMsgToChannel }; let offset: String = env.get_string(&JString::from(offset)).unwrap().into(); - println!("channel_id = {:?}, source_id = {:?}, payload = {:?}, partition = {:?}, offset = {:?}", channel_id, source_id, payload, partition, offset); + println!("source_id = {:?}, payload = {:?}, partition = {:?}, offset = {:?}", source_id, payload, partition, offset); events.push(CdcMessage { payload, partition, @@ -879,10 +880,11 @@ pub extern "system" fn Java_com_risingwave_java_binding_Binding_sendMsgToChannel source_id: source_id as u64, events, }; - let _ = channel.send(get_event_stream_response); + println!("before send"); + let _ = channel.as_ref().ptr.blocking_send(get_event_stream_response).inspect_err(|e| eprintln!("{:?}", e)).unwrap(); + println!("send successfully"); } - #[cfg(test)] mod tests { use risingwave_common::types::{DataType, Timestamptz}; From bea170fc12bf9d24f006d28ca96ba4bd1c2e3232 Mon Sep 17 00:00:00 2001 From: Dylan Chen Date: Fri, 25 Aug 2023 14:40:45 +0800 Subject: [PATCH 04/23] fmt --- .../source/core/JniSourceHandler.java | 15 --------- src/common/src/jvm_runtime.rs | 8 ++--- src/compute/src/lib.rs | 31 +++--------------- src/connector/src/source/cdc/source/reader.rs | 32 +------------------ src/java_binding/src/lib.rs | 9 +++--- 5 files changed, 13 insertions(+), 82 deletions(-) diff --git a/java/connector-node/risingwave-connector-service/src/main/java/com/risingwave/connector/source/core/JniSourceHandler.java b/java/connector-node/risingwave-connector-service/src/main/java/com/risingwave/connector/source/core/JniSourceHandler.java index 2a88001184d6e..0864d25cd6eaf 100644 --- a/java/connector-node/risingwave-connector-service/src/main/java/com/risingwave/connector/source/core/JniSourceHandler.java +++ b/java/connector-node/risingwave-connector-service/src/main/java/com/risingwave/connector/source/core/JniSourceHandler.java @@ -45,23 +45,10 @@ public OnReadyHandler(CdcEngineRunner runner, long channelPtr) { public void run() { while (runner.isRunning()) { try { - LOG.info("Engine#{}: loop step 1 ", config.getSourceId()); - // if (Context.current().isCancelled()) { - // LOG.info( - // "Engine#{}: Connection broken detected, stop - // the engine", - // config.getSourceId()); - // runner.stop(); - // return; - // } - - LOG.info("Engine#{}: loop step 2 ", config.getSourceId()); - // check whether the send queue has room for new messages // Thread will block on the channel to get output from engine var resp = runner.getEngine().getOutputChannel().poll(500, TimeUnit.MILLISECONDS); - LOG.info("Engine#{}: loop step 3 ", config.getSourceId()); if (resp != null) { ConnectorNodeMetrics.incSourceRowsReceived( config.getSourceType().toString(), @@ -75,8 +62,6 @@ public void run() { Binding.sendMsgToChannel(channelPtr, resp); Thread.sleep(10000); } - - LOG.info("Engine#{}: loop step 4 ", config.getSourceId()); } catch (Throwable e) { LOG.error("Poll engine output channel fail. ", e); } diff --git a/src/common/src/jvm_runtime.rs b/src/common/src/jvm_runtime.rs index 94f07bdf53c39..5a33a5783e4b7 100644 --- a/src/common/src/jvm_runtime.rs +++ b/src/common/src/jvm_runtime.rs @@ -1,16 +1,14 @@ use core::option::Option::Some; use core::result::Result::{Err, Ok}; -use std::collections::HashMap; use risingwave_pb::connector_service::GetEventStreamResponse; use std::fs; use std::path::Path; -use std::sync::{Arc, LazyLock, RwLock}; -use std::sync::atomic::AtomicI32; +use std::sync::{Arc, LazyLock}; use jni::{InitArgsBuilder, JavaVM, JNIVersion}; -use tokio::sync::mpsc::{Sender, UnboundedSender}; +use tokio::sync::mpsc::Sender; pub static JVM: LazyLock> = LazyLock::new(|| { - let dir_path = "/Users/dylan/Desktop/workspace/risingwave/.risingwave/bin/connector-node/libs/"; + let dir_path = ".risingwave/bin/connector-node/libs/"; let dir = Path::new(dir_path); diff --git a/src/compute/src/lib.rs b/src/compute/src/lib.rs index fafb5c4fcc1d3..9859190732fbc 100644 --- a/src/compute/src/lib.rs +++ b/src/compute/src/lib.rs @@ -31,9 +31,7 @@ pub mod rpc; pub mod server; pub mod telemetry; -use std::collections::HashMap; use std::ffi::c_void; -use std::fs; use clap::{Parser, ValueEnum}; use risingwave_common::config::{AsyncStackTraceOption, OverrideConfig}; use risingwave_common::util::resource_util::cpu::total_cpu_available; @@ -191,12 +189,11 @@ fn validate_opts(opts: &ComputeNodeOpts) { use std::future::Future; use std::marker::PhantomData; use std::ops::{Deref, DerefMut}; -use std::path::Path; use std::pin::Pin; -use jni::{InitArgsBuilder, JavaVM, JNIEnv, JNIVersion, NativeMethod}; +use jni::{JNIEnv, NativeMethod}; use jni::objects::{JClass, JObject, JString, JValue, JValueGen}; use jni::strings::JNIString; -use jni::sys::{jint, jlong, jobject}; +use jni::sys::{jint, jlong}; use risingwave_common::jvm_runtime::{JVM, MyPtr}; use risingwave_pb::connector_service::{CdcMessage, GetEventStreamResponse}; @@ -241,24 +238,6 @@ pub fn start( }) } -fn rust_hashmap_to_java_hashmap<'a>(env: &'a mut JNIEnv, hashmap: &HashMap<&str, &str>) -> Result, String> { - let hashmap_class = "java/util/HashMap"; - let hashmap_constructor_signature = "()V"; - let hashmap_put_signature = "(Ljava/lang/Object;Ljava/lang/Object;)Ljava/lang/Object;"; - - let map = env.new_object(hashmap_class, hashmap_constructor_signature, &[]).unwrap(); - for (key, value) in hashmap.iter() { - let key = env.new_string(*key).unwrap(); - let value = env.new_string(*value).unwrap(); - let args = [ - JValue::Object(&key), - JValue::Object(&value), - ]; - env.call_method(&map, "put", hashmap_put_signature, &args).unwrap(); - } - Ok(map) -} - #[repr(C)] pub struct EnvParam<'a> { env: JNIEnv<'a>, @@ -364,21 +343,21 @@ pub extern "system" fn Java_com_risingwave_java_binding_Binding_sendMsgToChannel _ => unreachable!() }; let payload = env.call_method(&mut java_element, "getPayload", "()Ljava/lang/String;", &[]).unwrap(); - let mut payload = match payload { + let payload = match payload { JValueGen::Object(obj) => obj, _ => unreachable!() }; let payload: String = env.get_string(&JString::from(payload)).unwrap().into(); let partition = env.call_method(&mut java_element, "getPartition", "()Ljava/lang/String;", &[]).unwrap(); - let mut partition = match partition { + let partition = match partition { JValueGen::Object(obj) => obj, _ => unreachable!() }; let partition: String = env.get_string(&JString::from(partition)).unwrap().into(); let offset = env.call_method(&mut java_element, "getOffset", "()Ljava/lang/String;", &[]).unwrap(); - let mut offset = match offset { + let offset = match offset { JValueGen::Object(obj) => obj, _ => unreachable!() }; diff --git a/src/connector/src/source/cdc/source/reader.rs b/src/connector/src/source/cdc/source/reader.rs index c3a4e337840ae..8972ca07d929a 100644 --- a/src/connector/src/source/cdc/source/reader.rs +++ b/src/connector/src/source/cdc/source/reader.rs @@ -12,24 +12,14 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::collections::HashMap; -use std::fs; -use std::mem::forget; -use std::path::Path; use std::str::FromStr; -use std::sync::{Arc, LazyLock, RwLock}; -use std::sync::atomic::{AtomicUsize, Ordering}; -use std::time::Duration; use anyhow::{anyhow, Result}; use async_trait::async_trait; use futures::{pin_mut, StreamExt, TryStreamExt}; use futures_async_stream::try_stream; -use jni::{InitArgsBuilder, JavaVM, JNIVersion}; use jni::objects::{JObject, JValue}; -use jni::sys::jint; use tokio::sync::mpsc; -use tokio::time::sleep; use risingwave_common::jvm_runtime::{JVM, MyPtr}; use risingwave_common::util::addr::HostAddr; use risingwave_pb::connector_service::GetEventStreamResponse; @@ -208,9 +198,6 @@ impl CdcSplitReader { let source_id_arg = JValue::from(self.source_id as i64); - - let source_type = env.find_class("com/risingwave/connector/api/source/SourceTypeE").unwrap(); - let string_class = env.find_class("java/lang/String").unwrap(); let start_offset = match self.start_offset { Some(start_offset) => { let start_offset = env.new_string(start_offset).unwrap(); @@ -221,14 +208,12 @@ impl CdcSplitReader { } }; - let mut user_prop = properties; - let hashmap_class = "java/util/HashMap"; let hashmap_constructor_signature = "()V"; let hashmap_put_signature = "(Ljava/lang/Object;Ljava/lang/Object;)Ljava/lang/Object;"; let java_map = env.new_object(hashmap_class, hashmap_constructor_signature, &[]).unwrap(); - for (key, value) in user_prop.iter() { + for (key, value) in properties.iter() { let key = env.new_string(key.to_string()).unwrap(); let value = env.new_string(value.to_string()).unwrap(); let args = [ @@ -249,23 +234,8 @@ impl CdcSplitReader { "startJniSourceHandler", "(Lcom/risingwave/connector/api/source/SourceTypeE;JLjava/lang/String;Ljava/util/Map;ZJ)V", &[(&st).into(), source_id_arg, (&start_offset).into(), JValue::Object(&java_map), snapshot_done, channel_ptr]).inspect_err(|e| eprintln!("{:?}", e)).unwrap(); - - println!("call jni cdc start source success"); }); - // loop { - // let GetEventStreamResponse { events, .. } = rx.recv().unwrap(); - // println!("recieve events {:?}", events.len()); - // if events.is_empty() { - // continue; - // } - // let mut msgs = Vec::with_capacity(events.len()); - // for event in events { - // msgs.push(SourceMessage::from(event)); - // } - // yield msgs; - // } - while let Some(GetEventStreamResponse { events, .. }) = rx.recv().await { println!("recieve events {:?}", events.len()); if events.is_empty() { diff --git a/src/java_binding/src/lib.rs b/src/java_binding/src/lib.rs index aab24b770631b..241031daf8214 100644 --- a/src/java_binding/src/lib.rs +++ b/src/java_binding/src/lib.rs @@ -34,7 +34,7 @@ use jni::objects::{ JValue, JValueGen, JValueOwned, ReleaseMode, }; use jni::signature::ReturnType; -use jni::sys::{jboolean, jbyte, jdouble, jfloat, jint, jlong, jobject, jshort, jsize, jvalue}; +use jni::sys::{jboolean, jbyte, jdouble, jfloat, jint, jlong, jshort, jsize, jvalue}; use jni::JNIEnv; use prost::{DecodeError, Message}; use risingwave_common::array::{ArrayError, StreamChunk}; @@ -46,7 +46,6 @@ use risingwave_common::util::panic::rw_catch_unwind; use risingwave_storage::error::StorageError; use thiserror::Error; use tokio::runtime::Runtime; -use tokio::sync::mpsc::UnboundedSender; use risingwave_common::jvm_runtime::MyPtr; use risingwave_pb::connector_service::{CdcMessage, GetEventStreamResponse}; @@ -849,21 +848,21 @@ pub extern "system" fn Java_com_risingwave_java_binding_Binding_sendMsgToChannel _ => unreachable!() }; let payload = env.call_method(&mut java_element, "getPayload", "()Ljava/lang/String;", &[]).unwrap(); - let mut payload = match payload { + let payload = match payload { JValueGen::Object(obj) => obj, _ => unreachable!() }; let payload: String = env.get_string(&JString::from(payload)).unwrap().into(); let partition = env.call_method(&mut java_element, "getPartition", "()Ljava/lang/String;", &[]).unwrap(); - let mut partition = match partition { + let partition = match partition { JValueGen::Object(obj) => obj, _ => unreachable!() }; let partition: String = env.get_string(&JString::from(partition)).unwrap().into(); let offset = env.call_method(&mut java_element, "getOffset", "()Ljava/lang/String;", &[]).unwrap(); - let mut offset = match offset { + let offset = match offset { JValueGen::Object(obj) => obj, _ => unreachable!() }; From 768776df417698aab5940572de265b7774c9f1a6 Mon Sep 17 00:00:00 2001 From: Dylan Chen Date: Fri, 25 Aug 2023 15:02:45 +0800 Subject: [PATCH 05/23] fmt --- .../source/core/JniSourceHandler.java | 1 - src/common/src/jvm_runtime.rs | 12 +----------- src/compute/src/lib.rs | 19 ++++--------------- src/connector/src/source/cdc/source/reader.rs | 7 ++----- src/java_binding/src/lib.rs | 10 +++------- 5 files changed, 10 insertions(+), 39 deletions(-) diff --git a/java/connector-node/risingwave-connector-service/src/main/java/com/risingwave/connector/source/core/JniSourceHandler.java b/java/connector-node/risingwave-connector-service/src/main/java/com/risingwave/connector/source/core/JniSourceHandler.java index 0864d25cd6eaf..c21487e7777a4 100644 --- a/java/connector-node/risingwave-connector-service/src/main/java/com/risingwave/connector/source/core/JniSourceHandler.java +++ b/java/connector-node/risingwave-connector-service/src/main/java/com/risingwave/connector/source/core/JniSourceHandler.java @@ -60,7 +60,6 @@ public void run() { resp.getEventsCount()); Binding.sendMsgToChannel(channelPtr, resp); - Thread.sleep(10000); } } catch (Throwable e) { LOG.error("Poll engine output channel fail. ", e); diff --git a/src/common/src/jvm_runtime.rs b/src/common/src/jvm_runtime.rs index 5a33a5783e4b7..faa45ddac3018 100644 --- a/src/common/src/jvm_runtime.rs +++ b/src/common/src/jvm_runtime.rs @@ -56,14 +56,4 @@ pub static JVM: LazyLock> = LazyLock::new(|| { Arc::new(jvm) }); - -pub struct MyPtr { - pub ptr: Sender, - pub num: u64, -} - -impl Drop for MyPtr { - fn drop(&mut self) { - println!("drop MyPtr, num = {}", self.num); - } -} \ No newline at end of file +pub type MyJniSender = Sender; diff --git a/src/compute/src/lib.rs b/src/compute/src/lib.rs index 9859190732fbc..4bd12d6a28b28 100644 --- a/src/compute/src/lib.rs +++ b/src/compute/src/lib.rs @@ -194,7 +194,7 @@ use jni::{JNIEnv, NativeMethod}; use jni::objects::{JClass, JObject, JString, JValue, JValueGen}; use jni::strings::JNIString; use jni::sys::{jint, jlong}; -use risingwave_common::jvm_runtime::{JVM, MyPtr}; +use risingwave_common::jvm_runtime::{JVM, MyJniSender}; use risingwave_pb::connector_service::{CdcMessage, GetEventStreamResponse}; use crate::server::compute_node_serve; @@ -317,13 +317,9 @@ impl<'a, T> Pointer<'a, T> { #[no_mangle] pub extern "system" fn Java_com_risingwave_java_binding_Binding_sendMsgToChannel<'a>( mut env: EnvParam<'a>, - channel: Pointer<'a, MyPtr>, + channel: Pointer<'a, MyJniSender>, mut msg: JObject<'a>, ) { - - println!("channel_ptr = {}, num = {}", channel.pointer, channel.as_ref().num); - // let channel: &mut UnboundedSender = unsafe { &mut *(channel_ptr.pointer as *mut UnboundedSender) }; - let source_id = env.env.call_method(&mut msg, "getSourceId", "()J", &[]).unwrap(); let source_id = source_id.j().unwrap(); @@ -375,7 +371,7 @@ pub extern "system" fn Java_com_risingwave_java_binding_Binding_sendMsgToChannel events, }; println!("before send"); - let _ = channel.as_ref().ptr.blocking_send(get_event_stream_response).inspect_err(|e| eprintln!("{:?}", e)).unwrap(); + let _ = channel.as_ref().blocking_send(get_event_stream_response).inspect_err(|e| eprintln!("{:?}", e)).unwrap(); println!("send successfully"); } @@ -383,19 +379,12 @@ pub extern "system" fn Java_com_risingwave_java_binding_Binding_sendMsgToChannel fn run_jvm() { let mut env = JVM.attach_current_thread_as_daemon().unwrap(); - // Call Java Math#abs(-10) - let x = JValue::from(-10); - let val: jint = env.call_static_method("java/lang/Math", "abs", "(I)I", &[x]).unwrap() - .i().unwrap(); - - assert_eq!(val, 10); - let string_class = env.find_class("java/lang/String").unwrap(); let jarray = env.new_object_array(0, string_class, JObject::null()).unwrap(); let fn_ptr = Java_com_risingwave_java_binding_Binding_sendMsgToChannel as *const fn ( EnvParam<'static>, - Pointer<'static, MyPtr>, + Pointer<'static, MyJniSender>, JObject<'static> ); diff --git a/src/connector/src/source/cdc/source/reader.rs b/src/connector/src/source/cdc/source/reader.rs index 8972ca07d929a..a887935e2af8c 100644 --- a/src/connector/src/source/cdc/source/reader.rs +++ b/src/connector/src/source/cdc/source/reader.rs @@ -20,7 +20,7 @@ use futures::{pin_mut, StreamExt, TryStreamExt}; use futures_async_stream::try_stream; use jni::objects::{JObject, JValue}; use tokio::sync::mpsc; -use risingwave_common::jvm_runtime::{JVM, MyPtr}; +use risingwave_common::jvm_runtime::{JVM, MyJniSender}; use risingwave_common::util::addr::HostAddr; use risingwave_pb::connector_service::GetEventStreamResponse; @@ -180,10 +180,7 @@ impl CdcSplitReader { let (tx, mut rx) = mpsc::channel(1024); - let tx: Box = Box::new(MyPtr { - ptr: tx, - num: 123456, - }); + let tx: Box = Box::new(tx); let source_type = self.conn_props.get_source_type_pb()?; diff --git a/src/java_binding/src/lib.rs b/src/java_binding/src/lib.rs index 241031daf8214..4570471195b2d 100644 --- a/src/java_binding/src/lib.rs +++ b/src/java_binding/src/lib.rs @@ -46,7 +46,7 @@ use risingwave_common::util::panic::rw_catch_unwind; use risingwave_storage::error::StorageError; use thiserror::Error; use tokio::runtime::Runtime; -use risingwave_common::jvm_runtime::MyPtr; +use risingwave_common::jvm_runtime::MyJniSender; use risingwave_pb::connector_service::{CdcMessage, GetEventStreamResponse}; use crate::stream_chunk_iterator::{StreamChunkIterator, StreamChunkRow}; @@ -822,13 +822,9 @@ pub extern "system" fn Java_com_risingwave_java_binding_Binding_rowClose<'a>( #[no_mangle] pub extern "system" fn Java_com_risingwave_java_binding_Binding_sendMsgToChannel<'a>( mut env: EnvParam<'a>, - channel: Pointer<'a, MyPtr>, + channel: Pointer<'a, MyJniSender>, mut msg: JObject<'a>, ) { - - println!("channel_ptr = {}, num = {}", channel.pointer, channel.as_ref().num); - // let channel: &mut UnboundedSender = unsafe { &mut *(channel_ptr.pointer as *mut UnboundedSender) }; - let source_id = env.env.call_method(&mut msg, "getSourceId", "()J", &[]).unwrap(); let source_id = source_id.j().unwrap(); @@ -880,7 +876,7 @@ pub extern "system" fn Java_com_risingwave_java_binding_Binding_sendMsgToChannel events, }; println!("before send"); - let _ = channel.as_ref().ptr.blocking_send(get_event_stream_response).inspect_err(|e| eprintln!("{:?}", e)).unwrap(); + let _ = channel.as_ref().blocking_send(get_event_stream_response).inspect_err(|e| eprintln!("{:?}", e)).unwrap(); println!("send successfully"); } From 03664e9a00779936ad07c91b3c53a33814a90c53 Mon Sep 17 00:00:00 2001 From: Dylan Chen Date: Fri, 25 Aug 2023 15:25:59 +0800 Subject: [PATCH 06/23] fmt --- src/compute/src/lib.rs | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/src/compute/src/lib.rs b/src/compute/src/lib.rs index 4bd12d6a28b28..0f378c189a104 100644 --- a/src/compute/src/lib.rs +++ b/src/compute/src/lib.rs @@ -188,8 +188,11 @@ fn validate_opts(opts: &ComputeNodeOpts) { use std::future::Future; use std::marker::PhantomData; +use std::mem::transmute; use std::ops::{Deref, DerefMut}; use std::pin::Pin; +use std::ptr; +use std::ptr::null; use jni::{JNIEnv, NativeMethod}; use jni::objects::{JClass, JObject, JString, JValue, JValueGen}; use jni::strings::JNIString; @@ -382,17 +385,13 @@ fn run_jvm() { let string_class = env.find_class("java/lang/String").unwrap(); let jarray = env.new_object_array(0, string_class, JObject::null()).unwrap(); - let fn_ptr = Java_com_risingwave_java_binding_Binding_sendMsgToChannel as *const fn ( - EnvParam<'static>, - Pointer<'static, MyJniSender>, - JObject<'static> - ); + let fn_ptr = Java_com_risingwave_java_binding_Binding_sendMsgToChannel as *mut c_void; let binding_class = env.find_class("com/risingwave/java/binding/Binding").unwrap(); env.register_native_methods(binding_class, &[NativeMethod { name: JNIString::from("sendMsgToChannel"), sig: JNIString::from("(JLjava/lang/Object;)V"), - fn_ptr: fn_ptr as *mut c_void, + fn_ptr, }]).unwrap(); let _ = env.call_static_method("com/risingwave/connector/ConnectorService", "main", "([Ljava/lang/String;)V", &[JValue::Object(&jarray)]).inspect_err(|e| eprintln!("{:?}", e)); From 37e530e3e73828184322b9568cfa28674170bc1a Mon Sep 17 00:00:00 2001 From: Dylan Chen Date: Fri, 25 Aug 2023 18:48:22 +0800 Subject: [PATCH 07/23] remove cdylib --- src/compute/src/lib.rs | 344 ++++++++++++++++++++---------------- src/java_binding/Cargo.toml | 2 - src/java_binding/src/lib.rs | 5 + 3 files changed, 200 insertions(+), 151 deletions(-) diff --git a/src/compute/src/lib.rs b/src/compute/src/lib.rs index 0f378c189a104..ae5b2e443c1d9 100644 --- a/src/compute/src/lib.rs +++ b/src/compute/src/lib.rs @@ -32,11 +32,18 @@ pub mod server; pub mod telemetry; use std::ffi::c_void; +use std::future::Future; +use std::pin::Pin; use clap::{Parser, ValueEnum}; +use jni::NativeMethod; +use jni::objects::{JObject, JValue}; +use jni::strings::JNIString; use risingwave_common::config::{AsyncStackTraceOption, OverrideConfig}; use risingwave_common::util::resource_util::cpu::total_cpu_available; use risingwave_common::util::resource_util::memory::total_memory_available_bytes; use serde::{Deserialize, Serialize}; +use risingwave_common::jvm_runtime::JVM; +use risingwave_java_binding::run_this_func_to_get_valid_ptr_from_java_binding; /// Command-line arguments for compute-node. #[derive(Parser, Clone, Debug, OverrideConfig)] @@ -186,20 +193,6 @@ fn validate_opts(opts: &ComputeNodeOpts) { } } -use std::future::Future; -use std::marker::PhantomData; -use std::mem::transmute; -use std::ops::{Deref, DerefMut}; -use std::pin::Pin; -use std::ptr; -use std::ptr::null; -use jni::{JNIEnv, NativeMethod}; -use jni::objects::{JClass, JObject, JString, JValue, JValueGen}; -use jni::strings::JNIString; -use jni::sys::{jint, jlong}; -use risingwave_common::jvm_runtime::{JVM, MyJniSender}; -use risingwave_pb::connector_service::{CdcMessage, GetEventStreamResponse}; - use crate::server::compute_node_serve; /// Start compute node @@ -241,158 +234,211 @@ pub fn start( }) } -#[repr(C)] -pub struct EnvParam<'a> { - env: JNIEnv<'a>, - class: JClass<'a>, -} +fn run_jvm() { + let mut env = JVM.attach_current_thread_as_daemon().unwrap(); + let string_class = env.find_class("java/lang/String").unwrap(); + let jarray = env.new_object_array(0, string_class, JObject::null()).unwrap(); -impl<'a> Deref for EnvParam<'a> { - type Target = JNIEnv<'a>; + run_this_func_to_get_valid_ptr_from_java_binding(); - fn deref(&self) -> &Self::Target { - &self.env - } -} + let binding_class = env.find_class("com/risingwave/java/binding/Binding").unwrap(); + env.register_native_methods(binding_class, &[ + NativeMethod { + name: JNIString::from("vnodeCount"), + sig: JNIString::from("()I"), + fn_ptr: risingwave_java_binding::Java_com_risingwave_java_binding_Binding_vnodeCount as *mut c_void, + }, -impl<'a> DerefMut for EnvParam<'a> { - fn deref_mut(&mut self) -> &mut Self::Target { - &mut self.env - } -} -impl<'a> EnvParam<'a> { - pub fn get_class(&self) -> &JClass<'a> { - &self.class - } -} + NativeMethod { + name: JNIString::from("hummockIteratorNew"), + sig: JNIString::from("([B)J"), + fn_ptr: risingwave_java_binding::Java_com_risingwave_java_binding_Binding_hummockIteratorNew as *mut c_void, + }, -#[repr(transparent)] -pub struct Pointer<'a, T> { - pointer: jlong, - _phantom: PhantomData<&'a T>, -} -impl<'a, T> Default for Pointer<'a, T> { - fn default() -> Self { - Self { - pointer: 0, - _phantom: Default::default(), - } - } -} + NativeMethod { + name: JNIString::from("hummockIteratorNext"), + sig: JNIString::from("(J)J"), + fn_ptr: risingwave_java_binding::Java_com_risingwave_java_binding_Binding_hummockIteratorNext as *mut c_void, + }, -impl From for Pointer<'static, T> { - fn from(value: T) -> Self { - Pointer { - pointer: Box::into_raw(Box::new(value)) as jlong, - _phantom: PhantomData, - } - } -} -impl Pointer<'static, T> { - fn null() -> Self { - Pointer { - pointer: 0, - _phantom: PhantomData, - } - } -} + NativeMethod { + name: JNIString::from("hummockIteratorClose"), + sig: JNIString::from("(J)V"), + fn_ptr: risingwave_java_binding::Java_com_risingwave_java_binding_Binding_hummockIteratorClose as *mut c_void, + }, -impl<'a, T> Pointer<'a, T> { - fn as_ref(&self) -> &'a T { - debug_assert!(self.pointer != 0); - unsafe { &*(self.pointer as *const T) } - } - fn as_mut(&mut self) -> &'a mut T { - debug_assert!(self.pointer != 0); - unsafe { &mut *(self.pointer as *mut T) } - } + NativeMethod { + name: JNIString::from("rowGetKey"), + sig: JNIString::from("(J)[B"), + fn_ptr: risingwave_java_binding::Java_com_risingwave_java_binding_Binding_rowGetKey as *mut c_void, + }, - fn drop(self) { - debug_assert!(self.pointer != 0); - unsafe { drop(Box::from_raw(self.pointer as *mut T)) } - } -} -#[no_mangle] -pub extern "system" fn Java_com_risingwave_java_binding_Binding_sendMsgToChannel<'a>( - mut env: EnvParam<'a>, - channel: Pointer<'a, MyJniSender>, - mut msg: JObject<'a>, -) { - let source_id = env.env.call_method(&mut msg, "getSourceId", "()J", &[]).unwrap(); - let source_id = source_id.j().unwrap(); - - let events_list = env.env.call_method(&mut msg, "getEventsList", "()Ljava/util/List;", &[]).unwrap(); - let mut events_list = match events_list { - JValueGen::Object(obj) => obj, - _ => unreachable!() - }; - - - let size = env.env.call_method(&mut events_list, "size", "()I", &[]).unwrap().i().unwrap(); - let mut events = Vec::with_capacity(size as usize); - for i in 0..size { - let java_element = env.call_method(&mut events_list, "get", "(I)Ljava/lang/Object;", &[JValue::from(i as i32)]).unwrap(); - let mut java_element = match java_element { - JValueGen::Object(obj) => obj, - _ => unreachable!() - }; - let payload = env.call_method(&mut java_element, "getPayload", "()Ljava/lang/String;", &[]).unwrap(); - let payload = match payload { - JValueGen::Object(obj) => obj, - _ => unreachable!() - }; - let payload: String = env.get_string(&JString::from(payload)).unwrap().into(); - - let partition = env.call_method(&mut java_element, "getPartition", "()Ljava/lang/String;", &[]).unwrap(); - let partition = match partition { - JValueGen::Object(obj) => obj, - _ => unreachable!() - }; - let partition: String = env.get_string(&JString::from(partition)).unwrap().into(); - - let offset = env.call_method(&mut java_element, "getOffset", "()Ljava/lang/String;", &[]).unwrap(); - let offset = match offset { - JValueGen::Object(obj) => obj, - _ => unreachable!() - }; - let offset: String = env.get_string(&JString::from(offset)).unwrap().into(); - - println!("source_id = {:?}, payload = {:?}, partition = {:?}, offset = {:?}", source_id, payload, partition, offset); - events.push(CdcMessage { - payload, - partition, - offset, - }) - } - let get_event_stream_response = GetEventStreamResponse { - source_id: source_id as u64, - events, - }; - println!("before send"); - let _ = channel.as_ref().blocking_send(get_event_stream_response).inspect_err(|e| eprintln!("{:?}", e)).unwrap(); - println!("send successfully"); -} + NativeMethod { + name: JNIString::from("rowGetOp"), + sig: JNIString::from("(J)I"), + fn_ptr: risingwave_java_binding::Java_com_risingwave_java_binding_Binding_rowGetOp as *mut c_void, + }, + NativeMethod { + name: JNIString::from("rowIsNull"), + sig: JNIString::from("(JI)Z"), + fn_ptr: risingwave_java_binding::Java_com_risingwave_java_binding_Binding_rowIsNull as *mut c_void, + }, -fn run_jvm() { - let mut env = JVM.attach_current_thread_as_daemon().unwrap(); - let string_class = env.find_class("java/lang/String").unwrap(); - let jarray = env.new_object_array(0, string_class, JObject::null()).unwrap(); - let fn_ptr = Java_com_risingwave_java_binding_Binding_sendMsgToChannel as *mut c_void; + NativeMethod { + name: JNIString::from("rowGetInt16Value"), + sig: JNIString::from("(JI)S"), + fn_ptr: risingwave_java_binding::Java_com_risingwave_java_binding_Binding_rowGetInt16Value as *mut c_void, + }, - let binding_class = env.find_class("com/risingwave/java/binding/Binding").unwrap(); - env.register_native_methods(binding_class, &[NativeMethod { - name: JNIString::from("sendMsgToChannel"), - sig: JNIString::from("(JLjava/lang/Object;)V"), - fn_ptr, - }]).unwrap(); + + NativeMethod { + name: JNIString::from("rowGetInt32Value"), + sig: JNIString::from("(JI)I"), + fn_ptr: risingwave_java_binding::Java_com_risingwave_java_binding_Binding_rowGetInt32Value as *mut c_void, + }, + + + NativeMethod { + name: JNIString::from("rowGetInt64Value"), + sig: JNIString::from("(JI)J"), + fn_ptr: risingwave_java_binding::Java_com_risingwave_java_binding_Binding_rowGetInt64Value as *mut c_void, + }, + + + NativeMethod { + name: JNIString::from("rowGetFloatValue"), + sig: JNIString::from("(JI)F"), + fn_ptr: risingwave_java_binding::Java_com_risingwave_java_binding_Binding_rowGetFloatValue as *mut c_void, + }, + + + NativeMethod { + name: JNIString::from("rowGetDoubleValue"), + sig: JNIString::from("(JI)D"), + fn_ptr: risingwave_java_binding::Java_com_risingwave_java_binding_Binding_rowGetDoubleValue as *mut c_void, + }, + + + NativeMethod { + name: JNIString::from("rowGetBooleanValue"), + sig: JNIString::from("(JI)Z"), + fn_ptr: risingwave_java_binding::Java_com_risingwave_java_binding_Binding_rowGetBooleanValue as *mut c_void, + }, + + + NativeMethod { + name: JNIString::from("rowGetStringValue"), + sig: JNIString::from("(JI)Ljava/lang/String;"), + fn_ptr: risingwave_java_binding::Java_com_risingwave_java_binding_Binding_rowGetStringValue as *mut c_void, + }, + + + NativeMethod { + name: JNIString::from("rowGetTimestampValue"), + sig: JNIString::from("(JI)Ljava/sql/Timestamp;"), + fn_ptr: risingwave_java_binding::Java_com_risingwave_java_binding_Binding_rowGetTimestampValue as *mut c_void, + }, + + + NativeMethod { + name: JNIString::from("rowGetDecimalValue"), + sig: JNIString::from("(JI)Ljava/math/BigDecimal;"), + fn_ptr: risingwave_java_binding::Java_com_risingwave_java_binding_Binding_rowGetDecimalValue as *mut c_void, + }, + + + NativeMethod { + name: JNIString::from("rowGetTimeValue"), + sig: JNIString::from("(JI)Ljava/sql/Time;"), + fn_ptr: risingwave_java_binding::Java_com_risingwave_java_binding_Binding_rowGetTimeValue as *mut c_void, + }, + + + NativeMethod { + name: JNIString::from("rowGetDateValue"), + sig: JNIString::from("(JI)Ljava/sql/Date;"), + fn_ptr: risingwave_java_binding::Java_com_risingwave_java_binding_Binding_rowGetDateValue as *mut c_void, + }, + + + NativeMethod { + name: JNIString::from("rowGetIntervalValue"), + sig: JNIString::from("(JI)Ljava/lang/String;"), + fn_ptr: risingwave_java_binding::Java_com_risingwave_java_binding_Binding_rowGetIntervalValue as *mut c_void, + }, + + + NativeMethod { + name: JNIString::from("rowGetJsonbValue"), + sig: JNIString::from("(JI)Ljava/lang/String;"), + fn_ptr: risingwave_java_binding::Java_com_risingwave_java_binding_Binding_rowGetJsonbValue as *mut c_void, + }, + + + NativeMethod { + name: JNIString::from("rowGetByteaValue"), + sig: JNIString::from("(JI)[B"), + fn_ptr: risingwave_java_binding::Java_com_risingwave_java_binding_Binding_rowGetByteaValue as *mut c_void, + }, + + + NativeMethod { + name: JNIString::from("rowGetArrayValue"), + sig: JNIString::from("(JILjava/lang/Class;)Ljava/lang/Object;"), + fn_ptr: risingwave_java_binding::Java_com_risingwave_java_binding_Binding_rowGetArrayValue as *mut c_void, + }, + + + NativeMethod { + name: JNIString::from("rowClose"), + sig: JNIString::from("(J)V"), + fn_ptr: risingwave_java_binding::Java_com_risingwave_java_binding_Binding_rowClose as *mut c_void, + }, + + + NativeMethod { + name: JNIString::from("streamChunkIteratorNew"), + sig: JNIString::from("([B)J"), + fn_ptr: risingwave_java_binding::Java_com_risingwave_java_binding_Binding_streamChunkIteratorNew as *mut c_void, + }, + + + NativeMethod { + name: JNIString::from("streamChunkIteratorNext"), + sig: JNIString::from("(J)J"), + fn_ptr: risingwave_java_binding::Java_com_risingwave_java_binding_Binding_streamChunkIteratorNext as *mut c_void, + }, + + + NativeMethod { + name: JNIString::from("streamChunkIteratorClose"), + sig: JNIString::from("(J)V"), + fn_ptr: risingwave_java_binding::Java_com_risingwave_java_binding_Binding_streamChunkIteratorClose as *mut c_void, + }, + + + NativeMethod { + name: JNIString::from("streamChunkIteratorFromPretty"), + sig: JNIString::from("(Ljava/lang/String;)J"), + fn_ptr: risingwave_java_binding::Java_com_risingwave_java_binding_Binding_streamChunkIteratorFromPretty as *mut c_void, + }, + + + NativeMethod { + name: JNIString::from("sendMsgToChannel"), + sig: JNIString::from("(JLjava/lang/Object;)V"), + fn_ptr: risingwave_java_binding::Java_com_risingwave_java_binding_Binding_sendMsgToChannel as *mut c_void, + }, + + ]).unwrap(); let _ = env.call_static_method("com/risingwave/connector/ConnectorService", "main", "([Ljava/lang/String;)V", &[JValue::Object(&jarray)]).inspect_err(|e| eprintln!("{:?}", e)); } diff --git a/src/java_binding/Cargo.toml b/src/java_binding/Cargo.toml index 9eda6a43e5bb2..38ee45081d72e 100644 --- a/src/java_binding/Cargo.toml +++ b/src/java_binding/Cargo.toml @@ -37,8 +37,6 @@ tracing = "0.1" [dev-dependencies] risingwave_expr = { workspace = true } -[lib] -crate-type = ["cdylib"] [[bin]] name = "data-chunk-payload-generator" diff --git a/src/java_binding/src/lib.rs b/src/java_binding/src/lib.rs index 4570471195b2d..4c77177702c1f 100644 --- a/src/java_binding/src/lib.rs +++ b/src/java_binding/src/lib.rs @@ -819,6 +819,11 @@ pub extern "system" fn Java_com_risingwave_java_binding_Binding_rowClose<'a>( pointer.drop() } +pub fn run_this_func_to_get_valid_ptr_from_java_binding() { + println!("run_this_func_to_get_valid_ptr_from_java_binding") +} + + #[no_mangle] pub extern "system" fn Java_com_risingwave_java_binding_Binding_sendMsgToChannel<'a>( mut env: EnvParam<'a>, From 89d421f816d62eccc1935e28ad4550ad4e4a9eeb Mon Sep 17 00:00:00 2001 From: Dylan Chen Date: Fri, 25 Aug 2023 19:06:01 +0800 Subject: [PATCH 08/23] fmt --- .../source/core/SourceHandlerFactory.java | 4 +- src/common/Cargo.toml | 2 +- src/common/src/jvm_runtime.rs | 19 +++-- src/common/src/lib.rs | 2 +- src/compute/Cargo.toml | 4 +- src/compute/src/lib.rs | 25 +++++-- src/connector/Cargo.toml | 2 +- src/connector/src/source/cdc/source/reader.rs | 45 ++++++------ src/java_binding/Cargo.toml | 1 - src/java_binding/src/lib.rs | 70 ++++++++++++++----- 10 files changed, 111 insertions(+), 63 deletions(-) diff --git a/java/connector-node/risingwave-connector-service/src/main/java/com/risingwave/connector/source/core/SourceHandlerFactory.java b/java/connector-node/risingwave-connector-service/src/main/java/com/risingwave/connector/source/core/SourceHandlerFactory.java index 8974861d29e18..ec51d3123e448 100644 --- a/java/connector-node/risingwave-connector-service/src/main/java/com/risingwave/connector/source/core/SourceHandlerFactory.java +++ b/java/connector-node/risingwave-connector-service/src/main/java/com/risingwave/connector/source/core/SourceHandlerFactory.java @@ -56,7 +56,7 @@ public static void startJniSourceHandler( var config = new DbzConnectorConfig( source, sourceId, startOffset, mutableUserProps, snapshotDone); - JniSourceHandler hanlder = new JniSourceHandler(config); - hanlder.start(channelPtr); + JniSourceHandler handler = new JniSourceHandler(config); + handler.start(channelPtr); } } diff --git a/src/common/Cargo.toml b/src/common/Cargo.toml index 7cac255f5249b..5869c2a6e1634 100644 --- a/src/common/Cargo.toml +++ b/src/common/Cargo.toml @@ -46,6 +46,7 @@ humantime = "2.1" hytra = { workspace = true } itertools = "0.11" itoa = "1.0" +jni = { version = "0.21.1", features = ["invocation"] } lru = { git = "https://github.com/risingwavelabs/lru-rs.git", rev = "cb2d7c7" } memcomparable = { version = "0.2", features = ["decimal"] } num-integer = "0.1" @@ -98,7 +99,6 @@ tracing-subscriber = "0.3.17" twox-hash = "1" url = "2" uuid = "1.4.1" -jni = { version = "0.21.1", features = ["invocation"] } [target.'cfg(not(madsim))'.dependencies] workspace-hack = { path = "../workspace-hack" } diff --git a/src/common/src/jvm_runtime.rs b/src/common/src/jvm_runtime.rs index faa45ddac3018..d61376e550186 100644 --- a/src/common/src/jvm_runtime.rs +++ b/src/common/src/jvm_runtime.rs @@ -1,10 +1,11 @@ use core::option::Option::Some; use core::result::Result::{Err, Ok}; -use risingwave_pb::connector_service::GetEventStreamResponse; use std::fs; use std::path::Path; use std::sync::{Arc, LazyLock}; -use jni::{InitArgsBuilder, JavaVM, JNIVersion}; + +use jni::{InitArgsBuilder, JNIVersion, JavaVM}; +use risingwave_pb::connector_service::GetEventStreamResponse; use tokio::sync::mpsc::Sender; pub static JVM: LazyLock> = LazyLock::new(|| { @@ -19,12 +20,10 @@ pub static JVM: LazyLock> = LazyLock::new(|| { let mut class_vec = vec![]; if let Ok(entries) = fs::read_dir(dir) { - for entry in entries { - if let Ok(entry) = entry { - if let Some(name) = entry.path().file_name() { - println!("{:?}", name); - class_vec.push(String::from( dir_path.to_owned() + name.to_str().to_owned().unwrap())); - } + for entry in entries.flatten() { + if let Some(name) = entry.path().file_name() { + println!("{:?}", name); + class_vec.push(dir_path.to_owned() + name.to_str().to_owned().unwrap()); } } } else { @@ -40,7 +39,7 @@ pub static JVM: LazyLock> = LazyLock::new(|| { // Here we enable some extra JNI checks useful during development // .option("-Xcheck:jni") .option("-ea") - .option(format!("-Djava.class.path={}", class_vec.join(":")) ) + .option(format!("-Djava.class.path={}", class_vec.join(":"))) .option("-agentlib:jdwp=transport=dt_socket,server=y,suspend=n,address=9111") .build() .unwrap(); @@ -49,7 +48,7 @@ pub static JVM: LazyLock> = LazyLock::new(|| { let jvm = match JavaVM::new(jvm_args) { Err(err) => { panic!("{:?}", err) - }, + } Ok(jvm) => jvm, }; diff --git a/src/common/src/lib.rs b/src/common/src/lib.rs index 22d046b9377f0..b7f5d99e4da85 100644 --- a/src/common/src/lib.rs +++ b/src/common/src/lib.rs @@ -68,11 +68,11 @@ pub mod telemetry; pub mod transaction; pub mod format; +pub mod jvm_runtime; pub mod metrics; pub mod test_utils; pub mod types; pub mod vnode_mapping; -pub mod jvm_runtime; pub mod test_prelude { pub use super::array::{DataChunkTestExt, StreamChunkTestExt}; diff --git a/src/compute/Cargo.toml b/src/compute/Cargo.toml index 607cb394f4a7e..5a3f57cb4df4f 100644 --- a/src/compute/Cargo.toml +++ b/src/compute/Cargo.toml @@ -23,8 +23,8 @@ either = "1" futures = { version = "0.3", default-features = false, features = ["alloc"] } futures-async-stream = { workspace = true } hyper = "0.14" -jni = { version = "0.21.1", features = ["invocation"] } itertools = "0.11" +jni = { version = "0.21.1", features = ["invocation"] } maplit = "1.0.2" pprof = { version = "0.12", features = ["flamegraph"] } prometheus = { version = "0.13" } @@ -33,12 +33,12 @@ risingwave_common = { workspace = true } risingwave_common_service = { workspace = true } risingwave_connector = { workspace = true } risingwave_hummock_sdk = { workspace = true } +risingwave_java_binding = { workspace = true } risingwave_pb = { workspace = true } risingwave_rpc_client = { workspace = true } risingwave_source = { workspace = true } risingwave_storage = { workspace = true } risingwave_stream = { workspace = true } -risingwave_java_binding = { workspace = true } serde = { version = "1", features = ["derive"] } serde_json = "1" diff --git a/src/compute/src/lib.rs b/src/compute/src/lib.rs index ae5b2e443c1d9..9235d19c56ace 100644 --- a/src/compute/src/lib.rs +++ b/src/compute/src/lib.rs @@ -34,16 +34,17 @@ pub mod telemetry; use std::ffi::c_void; use std::future::Future; use std::pin::Pin; + use clap::{Parser, ValueEnum}; -use jni::NativeMethod; use jni::objects::{JObject, JValue}; use jni::strings::JNIString; +use jni::NativeMethod; use risingwave_common::config::{AsyncStackTraceOption, OverrideConfig}; +use risingwave_common::jvm_runtime::JVM; use risingwave_common::util::resource_util::cpu::total_cpu_available; use risingwave_common::util::resource_util::memory::total_memory_available_bytes; -use serde::{Deserialize, Serialize}; -use risingwave_common::jvm_runtime::JVM; use risingwave_java_binding::run_this_func_to_get_valid_ptr_from_java_binding; +use serde::{Deserialize, Serialize}; /// Command-line arguments for compute-node. #[derive(Parser, Clone, Debug, OverrideConfig)] @@ -227,7 +228,6 @@ pub fn start( run_jvm(); }); - for join_handle in join_handle_vec { join_handle.await.unwrap(); } @@ -237,11 +237,15 @@ pub fn start( fn run_jvm() { let mut env = JVM.attach_current_thread_as_daemon().unwrap(); let string_class = env.find_class("java/lang/String").unwrap(); - let jarray = env.new_object_array(0, string_class, JObject::null()).unwrap(); + let jarray = env + .new_object_array(0, string_class, JObject::null()) + .unwrap(); run_this_func_to_get_valid_ptr_from_java_binding(); - let binding_class = env.find_class("com/risingwave/java/binding/Binding").unwrap(); + let binding_class = env + .find_class("com/risingwave/java/binding/Binding") + .unwrap(); env.register_native_methods(binding_class, &[ NativeMethod { name: JNIString::from("vnodeCount"), @@ -440,7 +444,14 @@ fn run_jvm() { ]).unwrap(); - let _ = env.call_static_method("com/risingwave/connector/ConnectorService", "main", "([Ljava/lang/String;)V", &[JValue::Object(&jarray)]).inspect_err(|e| eprintln!("{:?}", e)); + let _ = env + .call_static_method( + "com/risingwave/connector/ConnectorService", + "main", + "([Ljava/lang/String;)V", + &[JValue::Object(&jarray)], + ) + .inspect_err(|e| eprintln!("{:?}", e)); } fn default_total_memory_bytes() -> usize { diff --git a/src/connector/Cargo.toml b/src/connector/Cargo.toml index 367fc04f3b32c..bbc344ed2374e 100644 --- a/src/connector/Cargo.toml +++ b/src/connector/Cargo.toml @@ -50,10 +50,10 @@ glob = "0.3" google-cloud-pubsub = "0.17" icelake = { workspace = true } itertools = "0.11" +jni = { version = "0.21.1", features = ["invocation"] } maplit = "1.0.2" moka = { version = "0.11", features = ["future"] } nexmark = { version = "0.2", features = ["serde"] } -jni = { version = "0.21.1", features = ["invocation"] } num-bigint = "0.4" opendal = "0.39" parking_lot = "0.12" diff --git a/src/connector/src/source/cdc/source/reader.rs b/src/connector/src/source/cdc/source/reader.rs index a887935e2af8c..27b8767d87297 100644 --- a/src/connector/src/source/cdc/source/reader.rs +++ b/src/connector/src/source/cdc/source/reader.rs @@ -19,10 +19,10 @@ use async_trait::async_trait; use futures::{pin_mut, StreamExt, TryStreamExt}; use futures_async_stream::try_stream; use jni::objects::{JObject, JValue}; -use tokio::sync::mpsc; -use risingwave_common::jvm_runtime::{JVM, MyJniSender}; +use risingwave_common::jvm_runtime::{MyJniSender, JVM}; use risingwave_common::util::addr::HostAddr; use risingwave_pb::connector_service::GetEventStreamResponse; +use tokio::sync::mpsc; use crate::impl_common_split_reader_logic; use crate::parser::ParserConfig; @@ -35,7 +35,6 @@ use crate::source::{ impl_common_split_reader_logic!(CdcSplitReader, CdcProperties); - pub struct CdcSplitReader { source_id: u64, start_offset: Option, @@ -184,13 +183,22 @@ impl CdcSplitReader { let source_type = self.conn_props.get_source_type_pb()?; - tokio::task::spawn_blocking(move || { let mut env = JVM.attach_current_thread_as_daemon().unwrap(); - env.find_class("com/risingwave/proto/ConnectorServiceProto$SourceType").inspect_err(|e| eprintln!("{:?}", e)).unwrap(); + env.find_class("com/risingwave/proto/ConnectorServiceProto$SourceType") + .inspect_err(|e| eprintln!("{:?}", e)) + .unwrap(); let source_type_arg = JValue::from(source_type as i32); - let st = env.call_static_method("com/risingwave/proto/ConnectorServiceProto$SourceType", "forNumber", "(I)Lcom/risingwave/proto/ConnectorServiceProto$SourceType;", &[source_type_arg]).inspect_err(|e| eprintln!("{:?}", e)).unwrap(); + let st = env + .call_static_method( + "com/risingwave/proto/ConnectorServiceProto$SourceType", + "forNumber", + "(I)Lcom/risingwave/proto/ConnectorServiceProto$SourceType;", + &[source_type_arg], + ) + .inspect_err(|e| eprintln!("{:?}", e)) + .unwrap(); let st = env.call_static_method("com/risingwave/connector/api/source/SourceTypeE", "valueOf", "(Lcom/risingwave/proto/ConnectorServiceProto$SourceType;)Lcom/risingwave/connector/api/source/SourceTypeE;", &[(&st).into()]).inspect_err(|e| eprintln!("{:?}", e)).unwrap(); let source_id_arg = JValue::from(self.source_id as i64); @@ -198,26 +206,25 @@ impl CdcSplitReader { let start_offset = match self.start_offset { Some(start_offset) => { let start_offset = env.new_string(start_offset).unwrap(); - env.call_method(start_offset, "toString", "()Ljava/lang/String;", &[]).unwrap() - }, - None => { - jni::objects::JValueGen::Object(JObject::null()) + env.call_method(start_offset, "toString", "()Ljava/lang/String;", &[]) + .unwrap() } + None => jni::objects::JValueGen::Object(JObject::null()), }; let hashmap_class = "java/util/HashMap"; let hashmap_constructor_signature = "()V"; let hashmap_put_signature = "(Ljava/lang/Object;Ljava/lang/Object;)Ljava/lang/Object;"; - let java_map = env.new_object(hashmap_class, hashmap_constructor_signature, &[]).unwrap(); - for (key, value) in properties.iter() { + let java_map = env + .new_object(hashmap_class, hashmap_constructor_signature, &[]) + .unwrap(); + for (key, value) in &properties { let key = env.new_string(key.to_string()).unwrap(); let value = env.new_string(value.to_string()).unwrap(); - let args = [ - JValue::Object(&key), - JValue::Object(&value), - ]; - env.call_method(&java_map, "put", hashmap_put_signature, &args).unwrap(); + let args = [JValue::Object(&key), JValue::Object(&value)]; + env.call_method(&java_map, "put", hashmap_put_signature, &args) + .unwrap(); } let snapshot_done = JValue::from(self.snapshot_done); @@ -234,7 +241,7 @@ impl CdcSplitReader { }); while let Some(GetEventStreamResponse { events, .. }) = rx.recv().await { - println!("recieve events {:?}", events.len()); + println!("receive events {:?}", events.len()); if events.is_empty() { continue; } @@ -246,5 +253,3 @@ impl CdcSplitReader { } } } - - diff --git a/src/java_binding/Cargo.toml b/src/java_binding/Cargo.toml index 38ee45081d72e..89bcd90a0a047 100644 --- a/src/java_binding/Cargo.toml +++ b/src/java_binding/Cargo.toml @@ -37,7 +37,6 @@ tracing = "0.1" [dev-dependencies] risingwave_expr = { workspace = true } - [[bin]] name = "data-chunk-payload-generator" test = false diff --git a/src/java_binding/src/lib.rs b/src/java_binding/src/lib.rs index 4c77177702c1f..c59890a90a387 100644 --- a/src/java_binding/src/lib.rs +++ b/src/java_binding/src/lib.rs @@ -39,15 +39,15 @@ use jni::JNIEnv; use prost::{DecodeError, Message}; use risingwave_common::array::{ArrayError, StreamChunk}; use risingwave_common::hash::VirtualNode; +use risingwave_common::jvm_runtime::MyJniSender; use risingwave_common::row::{OwnedRow, Row}; use risingwave_common::test_prelude::StreamChunkTestExt; use risingwave_common::types::ScalarRefImpl; use risingwave_common::util::panic::rw_catch_unwind; +use risingwave_pb::connector_service::{CdcMessage, GetEventStreamResponse}; use risingwave_storage::error::StorageError; use thiserror::Error; use tokio::runtime::Runtime; -use risingwave_common::jvm_runtime::MyJniSender; -use risingwave_pb::connector_service::{CdcMessage, GetEventStreamResponse}; use crate::stream_chunk_iterator::{StreamChunkIterator, StreamChunkRow}; @@ -823,53 +823,83 @@ pub fn run_this_func_to_get_valid_ptr_from_java_binding() { println!("run_this_func_to_get_valid_ptr_from_java_binding") } - #[no_mangle] pub extern "system" fn Java_com_risingwave_java_binding_Binding_sendMsgToChannel<'a>( mut env: EnvParam<'a>, channel: Pointer<'a, MyJniSender>, mut msg: JObject<'a>, ) { - let source_id = env.env.call_method(&mut msg, "getSourceId", "()J", &[]).unwrap(); + let source_id = env + .env + .call_method(&mut msg, "getSourceId", "()J", &[]) + .unwrap(); let source_id = source_id.j().unwrap(); - let events_list = env.env.call_method(&mut msg, "getEventsList", "()Ljava/util/List;", &[]).unwrap(); + let events_list = env + .env + .call_method(&mut msg, "getEventsList", "()Ljava/util/List;", &[]) + .unwrap(); let mut events_list = match events_list { JValueGen::Object(obj) => obj, - _ => unreachable!() + _ => unreachable!(), }; - - let size = env.env.call_method(&mut events_list, "size", "()I", &[]).unwrap().i().unwrap(); + let size = env + .env + .call_method(&mut events_list, "size", "()I", &[]) + .unwrap() + .i() + .unwrap(); let mut events = Vec::with_capacity(size as usize); for i in 0..size { - let java_element = env.call_method(&mut events_list, "get", "(I)Ljava/lang/Object;", &[JValue::from(i as i32)]).unwrap(); + let java_element = env + .call_method( + &mut events_list, + "get", + "(I)Ljava/lang/Object;", + &[JValue::from(i)], + ) + .unwrap(); let mut java_element = match java_element { JValueGen::Object(obj) => obj, - _ => unreachable!() + _ => unreachable!(), }; - let payload = env.call_method(&mut java_element, "getPayload", "()Ljava/lang/String;", &[]).unwrap(); + let payload = env + .call_method(&mut java_element, "getPayload", "()Ljava/lang/String;", &[]) + .unwrap(); let payload = match payload { JValueGen::Object(obj) => obj, - _ => unreachable!() + _ => unreachable!(), }; let payload: String = env.get_string(&JString::from(payload)).unwrap().into(); - let partition = env.call_method(&mut java_element, "getPartition", "()Ljava/lang/String;", &[]).unwrap(); + let partition = env + .call_method( + &mut java_element, + "getPartition", + "()Ljava/lang/String;", + &[], + ) + .unwrap(); let partition = match partition { JValueGen::Object(obj) => obj, - _ => unreachable!() + _ => unreachable!(), }; let partition: String = env.get_string(&JString::from(partition)).unwrap().into(); - let offset = env.call_method(&mut java_element, "getOffset", "()Ljava/lang/String;", &[]).unwrap(); + let offset = env + .call_method(&mut java_element, "getOffset", "()Ljava/lang/String;", &[]) + .unwrap(); let offset = match offset { JValueGen::Object(obj) => obj, - _ => unreachable!() + _ => unreachable!(), }; let offset: String = env.get_string(&JString::from(offset)).unwrap().into(); - println!("source_id = {:?}, payload = {:?}, partition = {:?}, offset = {:?}", source_id, payload, partition, offset); + println!( + "source_id = {:?}, payload = {:?}, partition = {:?}, offset = {:?}", + source_id, payload, partition, offset + ); events.push(CdcMessage { payload, partition, @@ -881,7 +911,11 @@ pub extern "system" fn Java_com_risingwave_java_binding_Binding_sendMsgToChannel events, }; println!("before send"); - let _ = channel.as_ref().blocking_send(get_event_stream_response).inspect_err(|e| eprintln!("{:?}", e)).unwrap(); + channel + .as_ref() + .blocking_send(get_event_stream_response) + .inspect_err(|e| eprintln!("{:?}", e)) + .unwrap(); println!("send successfully"); } From 61f5463a39061b0a8a8d48f6b7d684299d156d02 Mon Sep 17 00:00:00 2001 From: Dylan Chen Date: Fri, 25 Aug 2023 19:55:24 +0800 Subject: [PATCH 09/23] use never inline for run_this_func_to_get_valid_ptr_from_java_binding --- src/compute/src/lib.rs | 10 ++++++---- src/java_binding/src/lib.rs | 5 ++--- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/src/compute/src/lib.rs b/src/compute/src/lib.rs index 9235d19c56ace..e63f47af0fa9b 100644 --- a/src/compute/src/lib.rs +++ b/src/compute/src/lib.rs @@ -236,11 +236,8 @@ pub fn start( fn run_jvm() { let mut env = JVM.attach_current_thread_as_daemon().unwrap(); - let string_class = env.find_class("java/lang/String").unwrap(); - let jarray = env - .new_object_array(0, string_class, JObject::null()) - .unwrap(); + // FIXME: remove this function would cause segment fault. run_this_func_to_get_valid_ptr_from_java_binding(); let binding_class = env @@ -444,6 +441,11 @@ fn run_jvm() { ]).unwrap(); + let string_class = env.find_class("java/lang/String").unwrap(); + let jarray = env + .new_object_array(0, string_class, JObject::null()) + .unwrap(); + let _ = env .call_static_method( "com/risingwave/connector/ConnectorService", diff --git a/src/java_binding/src/lib.rs b/src/java_binding/src/lib.rs index c59890a90a387..73ba4dcc1f847 100644 --- a/src/java_binding/src/lib.rs +++ b/src/java_binding/src/lib.rs @@ -819,9 +819,8 @@ pub extern "system" fn Java_com_risingwave_java_binding_Binding_rowClose<'a>( pointer.drop() } -pub fn run_this_func_to_get_valid_ptr_from_java_binding() { - println!("run_this_func_to_get_valid_ptr_from_java_binding") -} +#[inline(never)] +pub fn run_this_func_to_get_valid_ptr_from_java_binding() {} #[no_mangle] pub extern "system" fn Java_com_risingwave_java_binding_Binding_sendMsgToChannel<'a>( From 1f2d3ae1e6456d551f80b39cc3004f3b7b7ab25f Mon Sep 17 00:00:00 2001 From: Dylan Chen Date: Mon, 28 Aug 2023 14:00:53 +0800 Subject: [PATCH 10/23] refine sendMsgToChannel --- java/com_risingwave_java_binding_Binding.h | 4 ++-- .../source/core/JniSourceHandler.java | 9 +++++++- .../com/risingwave/java/binding/Binding.java | 2 +- src/compute/src/lib.rs | 2 +- src/java_binding/src/lib.rs | 22 ++++++++++++------- 5 files changed, 26 insertions(+), 13 deletions(-) diff --git a/java/com_risingwave_java_binding_Binding.h b/java/com_risingwave_java_binding_Binding.h index c2c235921756f..052eba8915b0a 100644 --- a/java/com_risingwave_java_binding_Binding.h +++ b/java/com_risingwave_java_binding_Binding.h @@ -226,9 +226,9 @@ JNIEXPORT jlong JNICALL Java_com_risingwave_java_binding_Binding_streamChunkIter /* * Class: com_risingwave_java_binding_Binding * Method: sendMsgToChannel - * Signature: (JLjava/lang/Object;)V + * Signature: (JLjava/lang/Object;)Z */ -JNIEXPORT void JNICALL Java_com_risingwave_java_binding_Binding_sendMsgToChannel +JNIEXPORT jboolean JNICALL Java_com_risingwave_java_binding_Binding_sendMsgToChannel (JNIEnv *, jclass, jlong, jobject); #ifdef __cplusplus diff --git a/java/connector-node/risingwave-connector-service/src/main/java/com/risingwave/connector/source/core/JniSourceHandler.java b/java/connector-node/risingwave-connector-service/src/main/java/com/risingwave/connector/source/core/JniSourceHandler.java index c21487e7777a4..f26e8e8b20d30 100644 --- a/java/connector-node/risingwave-connector-service/src/main/java/com/risingwave/connector/source/core/JniSourceHandler.java +++ b/java/connector-node/risingwave-connector-service/src/main/java/com/risingwave/connector/source/core/JniSourceHandler.java @@ -59,7 +59,14 @@ public void run() { config.getSourceId(), resp.getEventsCount()); - Binding.sendMsgToChannel(channelPtr, resp); + boolean success = Binding.sendMsgToChannel(channelPtr, resp); + if (!success) { + LOG.info( + "Engine#{}: JNI sender broken detected, stop the engine", + config.getSourceId()); + runner.stop(); + return; + } } } catch (Throwable e) { LOG.error("Poll engine output channel fail. ", e); diff --git a/java/java-binding/src/main/java/com/risingwave/java/binding/Binding.java b/java/java-binding/src/main/java/com/risingwave/java/binding/Binding.java index 93f10a1f829f8..18439184c7ebd 100644 --- a/java/java-binding/src/main/java/com/risingwave/java/binding/Binding.java +++ b/java/java-binding/src/main/java/com/risingwave/java/binding/Binding.java @@ -85,5 +85,5 @@ public class Binding { static native long streamChunkIteratorFromPretty(String str); - public static native void sendMsgToChannel(long channelPtr, Object msg); + public static native boolean sendMsgToChannel(long channelPtr, Object msg); } diff --git a/src/compute/src/lib.rs b/src/compute/src/lib.rs index e63f47af0fa9b..4370eb8a0c306 100644 --- a/src/compute/src/lib.rs +++ b/src/compute/src/lib.rs @@ -435,7 +435,7 @@ fn run_jvm() { NativeMethod { name: JNIString::from("sendMsgToChannel"), - sig: JNIString::from("(JLjava/lang/Object;)V"), + sig: JNIString::from("(JLjava/lang/Object;)Z"), fn_ptr: risingwave_java_binding::Java_com_risingwave_java_binding_Binding_sendMsgToChannel as *mut c_void, }, diff --git a/src/java_binding/src/lib.rs b/src/java_binding/src/lib.rs index 73ba4dcc1f847..046183d4d0760 100644 --- a/src/java_binding/src/lib.rs +++ b/src/java_binding/src/lib.rs @@ -34,7 +34,9 @@ use jni::objects::{ JValue, JValueGen, JValueOwned, ReleaseMode, }; use jni::signature::ReturnType; -use jni::sys::{jboolean, jbyte, jdouble, jfloat, jint, jlong, jshort, jsize, jvalue}; +use jni::sys::{ + jboolean, jbyte, jdouble, jfloat, jint, jlong, jshort, jsize, jvalue, JNI_FALSE, JNI_TRUE, +}; use jni::JNIEnv; use prost::{DecodeError, Message}; use risingwave_common::array::{ArrayError, StreamChunk}; @@ -827,7 +829,7 @@ pub extern "system" fn Java_com_risingwave_java_binding_Binding_sendMsgToChannel mut env: EnvParam<'a>, channel: Pointer<'a, MyJniSender>, mut msg: JObject<'a>, -) { +) -> jboolean { let source_id = env .env .call_method(&mut msg, "getSourceId", "()J", &[]) @@ -910,12 +912,16 @@ pub extern "system" fn Java_com_risingwave_java_binding_Binding_sendMsgToChannel events, }; println!("before send"); - channel - .as_ref() - .blocking_send(get_event_stream_response) - .inspect_err(|e| eprintln!("{:?}", e)) - .unwrap(); - println!("send successfully"); + match channel.as_ref().blocking_send(get_event_stream_response) { + Ok(_) => { + println!("send successfully"); + JNI_TRUE + } + Err(e) => { + eprintln!("send error. {:?}", e); + JNI_FALSE + } + } } #[cfg(test)] From 2e3c93f18ab5ccc35e928116e6ef1f733cf7d094 Mon Sep 17 00:00:00 2001 From: Eric Fu Date: Mon, 28 Aug 2023 16:03:01 +0800 Subject: [PATCH 11/23] add copyright Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/common/src/jvm_runtime.rs | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/src/common/src/jvm_runtime.rs b/src/common/src/jvm_runtime.rs index d61376e550186..c3c13aac17d53 100644 --- a/src/common/src/jvm_runtime.rs +++ b/src/common/src/jvm_runtime.rs @@ -1,3 +1,17 @@ +// Copyright 2023 RisingWave Labs +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + use core::option::Option::Some; use core::result::Result::{Err, Ok}; use std::fs; From 41583db96b6ca246225f6b878e0a698647afcda5 Mon Sep 17 00:00:00 2001 From: Dylan Chen Date: Mon, 28 Aug 2023 20:24:56 +0800 Subject: [PATCH 12/23] support resouce reclamation --- .../source/core/JniSourceHandler.java | 18 +++++++++--------- src/java_binding/src/lib.rs | 9 +++++++++ 2 files changed, 18 insertions(+), 9 deletions(-) diff --git a/java/connector-node/risingwave-connector-service/src/main/java/com/risingwave/connector/source/core/JniSourceHandler.java b/java/connector-node/risingwave-connector-service/src/main/java/com/risingwave/connector/source/core/JniSourceHandler.java index f26e8e8b20d30..f65dc2a6b5672 100644 --- a/java/connector-node/risingwave-connector-service/src/main/java/com/risingwave/connector/source/core/JniSourceHandler.java +++ b/java/connector-node/risingwave-connector-service/src/main/java/com/risingwave/connector/source/core/JniSourceHandler.java @@ -58,15 +58,15 @@ public void run() { "Engine#{}: emit one chunk {} events to network ", config.getSourceId(), resp.getEventsCount()); - - boolean success = Binding.sendMsgToChannel(channelPtr, resp); - if (!success) { - LOG.info( - "Engine#{}: JNI sender broken detected, stop the engine", - config.getSourceId()); - runner.stop(); - return; - } + } + // If resp is null means just check whether channel is closed. + boolean success = Binding.sendMsgToChannel(channelPtr, resp); + if (!success) { + LOG.info( + "Engine#{}: JNI sender broken detected, stop the engine", + config.getSourceId()); + runner.stop(); + return; } } catch (Throwable e) { LOG.error("Poll engine output channel fail. ", e); diff --git a/src/java_binding/src/lib.rs b/src/java_binding/src/lib.rs index 046183d4d0760..c786f18187f12 100644 --- a/src/java_binding/src/lib.rs +++ b/src/java_binding/src/lib.rs @@ -830,6 +830,15 @@ pub extern "system" fn Java_com_risingwave_java_binding_Binding_sendMsgToChannel channel: Pointer<'a, MyJniSender>, mut msg: JObject<'a>, ) -> jboolean { + // If msg is null means just check whether channel is closed. + if msg.is_null() { + if channel.as_ref().is_closed() { + return JNI_FALSE; + } else { + return JNI_TRUE; + } + } + let source_id = env .env .call_method(&mut msg, "getSourceId", "()J", &[]) From 3041df25ffad5de91bfd44d4e3b4af364a079592 Mon Sep 17 00:00:00 2001 From: Dylan Chen Date: Mon, 28 Aug 2023 20:30:39 +0800 Subject: [PATCH 13/23] drop channel pointer properly --- src/java_binding/src/lib.rs | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/java_binding/src/lib.rs b/src/java_binding/src/lib.rs index c786f18187f12..3a7f982189ce5 100644 --- a/src/java_binding/src/lib.rs +++ b/src/java_binding/src/lib.rs @@ -833,6 +833,8 @@ pub extern "system" fn Java_com_risingwave_java_binding_Binding_sendMsgToChannel // If msg is null means just check whether channel is closed. if msg.is_null() { if channel.as_ref().is_closed() { + // Drop channel as well. + channel.drop(); return JNI_FALSE; } else { return JNI_TRUE; @@ -927,6 +929,7 @@ pub extern "system" fn Java_com_risingwave_java_binding_Binding_sendMsgToChannel JNI_TRUE } Err(e) => { + channel.drop(); eprintln!("send error. {:?}", e); JNI_FALSE } From ef0f0c892789833eab584321008da8db6aec0f7f Mon Sep 17 00:00:00 2001 From: Dylan Chen Date: Thu, 31 Aug 2023 17:14:35 +0800 Subject: [PATCH 14/23] serialize proto msg in jni --- java/com_risingwave_java_binding_Binding.h | 4 +- .../source/core/JniSourceHandler.java | 7 +- .../com/risingwave/java/binding/Binding.java | 2 +- src/compute/src/lib.rs | 2 +- src/java_binding/src/lib.rs | 129 ++++-------------- 5 files changed, 35 insertions(+), 109 deletions(-) diff --git a/java/com_risingwave_java_binding_Binding.h b/java/com_risingwave_java_binding_Binding.h index 052eba8915b0a..ba0e53c4afb05 100644 --- a/java/com_risingwave_java_binding_Binding.h +++ b/java/com_risingwave_java_binding_Binding.h @@ -226,10 +226,10 @@ JNIEXPORT jlong JNICALL Java_com_risingwave_java_binding_Binding_streamChunkIter /* * Class: com_risingwave_java_binding_Binding * Method: sendMsgToChannel - * Signature: (JLjava/lang/Object;)Z + * Signature: (J[B)Z */ JNIEXPORT jboolean JNICALL Java_com_risingwave_java_binding_Binding_sendMsgToChannel - (JNIEnv *, jclass, jlong, jobject); + (JNIEnv *, jclass, jlong, jbyteArray); #ifdef __cplusplus } diff --git a/java/connector-node/risingwave-connector-service/src/main/java/com/risingwave/connector/source/core/JniSourceHandler.java b/java/connector-node/risingwave-connector-service/src/main/java/com/risingwave/connector/source/core/JniSourceHandler.java index f65dc2a6b5672..6ad323f87faa7 100644 --- a/java/connector-node/risingwave-connector-service/src/main/java/com/risingwave/connector/source/core/JniSourceHandler.java +++ b/java/connector-node/risingwave-connector-service/src/main/java/com/risingwave/connector/source/core/JniSourceHandler.java @@ -49,6 +49,7 @@ public void run() { // Thread will block on the channel to get output from engine var resp = runner.getEngine().getOutputChannel().poll(500, TimeUnit.MILLISECONDS); + boolean success; if (resp != null) { ConnectorNodeMetrics.incSourceRowsReceived( config.getSourceType().toString(), @@ -58,9 +59,11 @@ public void run() { "Engine#{}: emit one chunk {} events to network ", config.getSourceId(), resp.getEventsCount()); + success = Binding.sendMsgToChannel(channelPtr, resp.toByteArray()); + } else { + // If resp is null means just check whether channel is closed. + success = Binding.sendMsgToChannel(channelPtr, null); } - // If resp is null means just check whether channel is closed. - boolean success = Binding.sendMsgToChannel(channelPtr, resp); if (!success) { LOG.info( "Engine#{}: JNI sender broken detected, stop the engine", diff --git a/java/java-binding/src/main/java/com/risingwave/java/binding/Binding.java b/java/java-binding/src/main/java/com/risingwave/java/binding/Binding.java index 18439184c7ebd..9829f6a4a4004 100644 --- a/java/java-binding/src/main/java/com/risingwave/java/binding/Binding.java +++ b/java/java-binding/src/main/java/com/risingwave/java/binding/Binding.java @@ -85,5 +85,5 @@ public class Binding { static native long streamChunkIteratorFromPretty(String str); - public static native boolean sendMsgToChannel(long channelPtr, Object msg); + public static native boolean sendMsgToChannel(long channelPtr, byte[] msg); } diff --git a/src/compute/src/lib.rs b/src/compute/src/lib.rs index c722e47183b55..0751618f3d21a 100644 --- a/src/compute/src/lib.rs +++ b/src/compute/src/lib.rs @@ -433,7 +433,7 @@ fn run_jvm() { NativeMethod { name: JNIString::from("sendMsgToChannel"), - sig: JNIString::from("(JLjava/lang/Object;)Z"), + sig: JNIString::from("(J[B)Z"), fn_ptr: risingwave_java_binding::Java_com_risingwave_java_binding_Binding_sendMsgToChannel as *mut c_void, }, diff --git a/src/java_binding/src/lib.rs b/src/java_binding/src/lib.rs index c7b0bd11958e9..4af1d8377084f 100644 --- a/src/java_binding/src/lib.rs +++ b/src/java_binding/src/lib.rs @@ -46,7 +46,7 @@ use risingwave_common::row::{OwnedRow, Row}; use risingwave_common::test_prelude::StreamChunkTestExt; use risingwave_common::types::ScalarRefImpl; use risingwave_common::util::panic::rw_catch_unwind; -use risingwave_pb::connector_service::{CdcMessage, GetEventStreamResponse}; +use risingwave_pb::connector_service::GetEventStreamResponse; use risingwave_storage::error::StorageError; use thiserror::Error; use tokio::runtime::Runtime; @@ -829,114 +829,37 @@ pub fn run_this_func_to_get_valid_ptr_from_java_binding() {} #[no_mangle] pub extern "system" fn Java_com_risingwave_java_binding_Binding_sendMsgToChannel<'a>( - mut env: EnvParam<'a>, + env: EnvParam<'a>, channel: Pointer<'a, MyJniSender>, - mut msg: JObject<'a>, + msg: JByteArray<'a>, ) -> jboolean { - // If msg is null means just check whether channel is closed. - if msg.is_null() { - if channel.as_ref().is_closed() { - // Drop channel as well. - channel.drop(); - return JNI_FALSE; - } else { - return JNI_TRUE; + execute_and_catch(env, move |env| { + // If msg is null means just check whether channel is closed. + if msg.is_null() { + if channel.as_ref().is_closed() { + // Drop channel as well. + channel.drop(); + return Ok(JNI_FALSE); + } else { + return Ok(JNI_TRUE); + } } - } - let source_id = env - .env - .call_method(&mut msg, "getSourceId", "()J", &[]) - .unwrap(); - let source_id = source_id.j().unwrap(); - - let events_list = env - .env - .call_method(&mut msg, "getEventsList", "()Ljava/util/List;", &[]) - .unwrap(); - let mut events_list = match events_list { - JValueGen::Object(obj) => obj, - _ => unreachable!(), - }; - - let size = env - .env - .call_method(&mut events_list, "size", "()I", &[]) - .unwrap() - .i() - .unwrap(); - let mut events = Vec::with_capacity(size as usize); - for i in 0..size { - let java_element = env - .call_method( - &mut events_list, - "get", - "(I)Ljava/lang/Object;", - &[JValue::from(i)], - ) - .unwrap(); - let mut java_element = match java_element { - JValueGen::Object(obj) => obj, - _ => unreachable!(), - }; - let payload = env - .call_method(&mut java_element, "getPayload", "()Ljava/lang/String;", &[]) - .unwrap(); - let payload = match payload { - JValueGen::Object(obj) => obj, - _ => unreachable!(), - }; - let payload: String = env.get_string(&JString::from(payload)).unwrap().into(); - - let partition = env - .call_method( - &mut java_element, - "getPartition", - "()Ljava/lang/String;", - &[], - ) - .unwrap(); - let partition = match partition { - JValueGen::Object(obj) => obj, - _ => unreachable!(), - }; - let partition: String = env.get_string(&JString::from(partition)).unwrap().into(); - - let offset = env - .call_method(&mut java_element, "getOffset", "()Ljava/lang/String;", &[]) - .unwrap(); - let offset = match offset { - JValueGen::Object(obj) => obj, - _ => unreachable!(), - }; - let offset: String = env.get_string(&JString::from(offset)).unwrap().into(); + let get_event_stream_response: GetEventStreamResponse = Message::decode(to_guarded_slice(&msg, env)?.deref())?; - println!( - "source_id = {:?}, payload = {:?}, partition = {:?}, offset = {:?}", - source_id, payload, partition, offset - ); - events.push(CdcMessage { - payload, - partition, - offset, - }) - } - let get_event_stream_response = GetEventStreamResponse { - source_id: source_id as u64, - events, - }; - println!("before send"); - match channel.as_ref().blocking_send(get_event_stream_response) { - Ok(_) => { - println!("send successfully"); - JNI_TRUE - } - Err(e) => { - channel.drop(); - eprintln!("send error. {:?}", e); - JNI_FALSE + println!("before send"); + match channel.as_ref().blocking_send(get_event_stream_response) { + Ok(_) => { + println!("send successfully"); + Ok(JNI_TRUE) + } + Err(e) => { + channel.drop(); + eprintln!("send error. {:?}", e); + Ok(JNI_FALSE) + } } - } + }) } #[cfg(test)] From a940f8d9c42dca9e594b641af497d6dd493ab966 Mon Sep 17 00:00:00 2001 From: Dylan Chen Date: Fri, 1 Sep 2023 11:58:25 +0800 Subject: [PATCH 15/23] refactor --- Cargo.lock | 3 +- src/common/Cargo.toml | 1 - src/common/src/lib.rs | 1 - src/compute/src/lib.rs | 8 ++-- src/connector/Cargo.toml | 3 +- src/connector/src/source/cdc/source/reader.rs | 48 +++++++++---------- .../src/jvm_runtime.rs | 16 +++---- src/java_binding/src/lib.rs | 19 +++++--- src/workspace-hack/Cargo.toml | 2 + 9 files changed, 54 insertions(+), 47 deletions(-) rename src/{common => java_binding}/src/jvm_runtime.rs (79%) diff --git a/Cargo.lock b/Cargo.lock index a4999b0b35426..caa5f993245a0 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -6503,7 +6503,6 @@ dependencies = [ "hytra", "itertools 0.11.0", "itoa", - "jni", "libc", "lru 0.7.6", "mach2", @@ -6735,6 +6734,7 @@ dependencies = [ "rand", "reqwest", "risingwave_common", + "risingwave_java_binding", "risingwave_pb", "risingwave_rpc_client", "rust_decimal", @@ -9862,6 +9862,7 @@ dependencies = [ "hyper", "indexmap 1.9.3", "itertools 0.10.5", + "jni", "lexical-core", "lexical-parse-float", "lexical-parse-integer", diff --git a/src/common/Cargo.toml b/src/common/Cargo.toml index 43953865b8367..4a5903d7c4f43 100644 --- a/src/common/Cargo.toml +++ b/src/common/Cargo.toml @@ -47,7 +47,6 @@ hyper = "0.14" hytra = { workspace = true } itertools = "0.11" itoa = "1.0" -jni = { version = "0.21.1", features = ["invocation"] } lru = { git = "https://github.com/risingwavelabs/lru-rs.git", rev = "cb2d7c7" } memcomparable = { version = "0.2", features = ["decimal"] } num-integer = "0.1" diff --git a/src/common/src/lib.rs b/src/common/src/lib.rs index 4d89db380de5a..554815d43e753 100644 --- a/src/common/src/lib.rs +++ b/src/common/src/lib.rs @@ -69,7 +69,6 @@ pub mod telemetry; pub mod transaction; pub mod format; -pub mod jvm_runtime; pub mod metrics; pub mod test_utils; pub mod types; diff --git a/src/compute/src/lib.rs b/src/compute/src/lib.rs index 0751618f3d21a..49287cc13bd6e 100644 --- a/src/compute/src/lib.rs +++ b/src/compute/src/lib.rs @@ -41,9 +41,9 @@ use jni::objects::{JObject, JValue}; use jni::strings::JNIString; use jni::NativeMethod; use risingwave_common::config::{AsyncStackTraceOption, OverrideConfig}; -use risingwave_common::jvm_runtime::JVM; use risingwave_common::util::resource_util::cpu::total_cpu_available; use risingwave_common::util::resource_util::memory::total_memory_available_bytes; +use risingwave_java_binding::jvm_runtime::JVM; use risingwave_java_binding::run_this_func_to_get_valid_ptr_from_java_binding; use serde::{Deserialize, Serialize}; @@ -222,7 +222,7 @@ pub fn start(opts: ComputeNodeOpts) -> Pin + Send>> let (join_handle_vec, _shutdown_send) = compute_node_serve(listen_addr, advertise_addr, opts).await; - tokio::task::spawn_blocking(move || { + std::thread::spawn(move || { run_jvm(); }); @@ -444,6 +444,7 @@ fn run_jvm() { .new_object_array(0, string_class, JObject::null()) .unwrap(); + // FIXME: if we finish rewriting all RPCs to JNI calls, we don't need to run main anymore. let _ = env .call_static_method( "com/risingwave/connector/ConnectorService", @@ -451,7 +452,8 @@ fn run_jvm() { "([Ljava/lang/String;)V", &[JValue::Object(&jarray)], ) - .inspect_err(|e| eprintln!("{:?}", e)); + .inspect_err(|e| eprintln!("{:?}", e)) + .unwrap(); } fn default_total_memory_bytes() -> usize { diff --git a/src/connector/Cargo.toml b/src/connector/Cargo.toml index 5636f91e8b884..394cb3292e0ef 100644 --- a/src/connector/Cargo.toml +++ b/src/connector/Cargo.toml @@ -50,8 +50,8 @@ glob = "0.3" google-cloud-pubsub = "0.19" icelake = { workspace = true } itertools = "0.11" -jsonschema-transpiler = "1.10.0" jni = { version = "0.21.1", features = ["invocation"] } +jsonschema-transpiler = "1.10.0" maplit = "1.0.2" moka = { version = "0.11", features = ["future"] } mysql_async = { version = "0.31", default-features = false, features = ["default"] } @@ -80,6 +80,7 @@ rdkafka = { workspace = true, features = [ ] } reqwest = { version = "0.11", features = ["json"] } risingwave_common = { workspace = true } +risingwave_java_binding = { workspace = true } risingwave_pb = { workspace = true } risingwave_rpc_client = { workspace = true } rust_decimal = "1" diff --git a/src/connector/src/source/cdc/source/reader.rs b/src/connector/src/source/cdc/source/reader.rs index 27b8767d87297..58144d35fd5e1 100644 --- a/src/connector/src/source/cdc/source/reader.rs +++ b/src/connector/src/source/cdc/source/reader.rs @@ -19,8 +19,9 @@ use async_trait::async_trait; use futures::{pin_mut, StreamExt, TryStreamExt}; use futures_async_stream::try_stream; use jni::objects::{JObject, JValue}; -use risingwave_common::jvm_runtime::{MyJniSender, JVM}; use risingwave_common::util::addr::HostAddr; +use risingwave_java_binding::jvm_runtime::JVM; +use risingwave_java_binding::GetEventStreamJniSender; use risingwave_pb::connector_service::GetEventStreamResponse; use tokio::sync::mpsc; @@ -179,29 +180,29 @@ impl CdcSplitReader { let (tx, mut rx) = mpsc::channel(1024); - let tx: Box = Box::new(tx); + let tx: Box = Box::new(tx); let source_type = self.conn_props.get_source_type_pb()?; - tokio::task::spawn_blocking(move || { + std::thread::spawn(move || { let mut env = JVM.attach_current_thread_as_daemon().unwrap(); - env.find_class("com/risingwave/proto/ConnectorServiceProto$SourceType") - .inspect_err(|e| eprintln!("{:?}", e)) - .unwrap(); - let source_type_arg = JValue::from(source_type as i32); let st = env .call_static_method( "com/risingwave/proto/ConnectorServiceProto$SourceType", "forNumber", "(I)Lcom/risingwave/proto/ConnectorServiceProto$SourceType;", - &[source_type_arg], + &[JValue::from(source_type as i32)], ) - .inspect_err(|e| eprintln!("{:?}", e)) + .inspect_err(|e| tracing::error!("{:?}", e)) .unwrap(); - let st = env.call_static_method("com/risingwave/connector/api/source/SourceTypeE", "valueOf", "(Lcom/risingwave/proto/ConnectorServiceProto$SourceType;)Lcom/risingwave/connector/api/source/SourceTypeE;", &[(&st).into()]).inspect_err(|e| eprintln!("{:?}", e)).unwrap(); - let source_id_arg = JValue::from(self.source_id as i64); + let st = env.call_static_method( + "com/risingwave/connector/api/source/SourceTypeE", + "valueOf", + "(Lcom/risingwave/proto/ConnectorServiceProto$SourceType;)Lcom/risingwave/connector/api/source/SourceTypeE;", + &[(&st).into()] + ).inspect_err(|e| tracing::error!("{:?}", e)).unwrap(); let start_offset = match self.start_offset { Some(start_offset) => { @@ -212,36 +213,35 @@ impl CdcSplitReader { None => jni::objects::JValueGen::Object(JObject::null()), }; - let hashmap_class = "java/util/HashMap"; - let hashmap_constructor_signature = "()V"; - let hashmap_put_signature = "(Ljava/lang/Object;Ljava/lang/Object;)Ljava/lang/Object;"; + let java_map = env.new_object("java/util/HashMap", "()V", &[]).unwrap(); - let java_map = env - .new_object(hashmap_class, hashmap_constructor_signature, &[]) - .unwrap(); for (key, value) in &properties { let key = env.new_string(key.to_string()).unwrap(); let value = env.new_string(value.to_string()).unwrap(); let args = [JValue::Object(&key), JValue::Object(&value)]; - env.call_method(&java_map, "put", hashmap_put_signature, &args) - .unwrap(); + env.call_method( + &java_map, + "put", + "(Ljava/lang/Object;Ljava/lang/Object;)Ljava/lang/Object;", + &args, + ) + .inspect_err(|e| tracing::error!("{:?}", e)) + .unwrap(); } - let snapshot_done = JValue::from(self.snapshot_done); - let channel_ptr = Box::into_raw(tx) as i64; - println!("channel_ptr = {}", channel_ptr); let channel_ptr = JValue::from(channel_ptr); let _ = env.call_static_method( "com/risingwave/connector/source/core/SourceHandlerFactory", "startJniSourceHandler", "(Lcom/risingwave/connector/api/source/SourceTypeE;JLjava/lang/String;Ljava/util/Map;ZJ)V", - &[(&st).into(), source_id_arg, (&start_offset).into(), JValue::Object(&java_map), snapshot_done, channel_ptr]).inspect_err(|e| eprintln!("{:?}", e)).unwrap(); + &[(&st).into(), JValue::from(self.source_id as i64), (&start_offset).into(), JValue::Object(&java_map), JValue::from(self.snapshot_done), channel_ptr] + ).inspect_err(|e| tracing::error!("{:?}", e)).unwrap(); }); while let Some(GetEventStreamResponse { events, .. }) = rx.recv().await { - println!("receive events {:?}", events.len()); + tracing::debug!("receive events {:?}", events.len()); if events.is_empty() { continue; } diff --git a/src/common/src/jvm_runtime.rs b/src/java_binding/src/jvm_runtime.rs similarity index 79% rename from src/common/src/jvm_runtime.rs rename to src/java_binding/src/jvm_runtime.rs index c3c13aac17d53..455999ef1f453 100644 --- a/src/common/src/jvm_runtime.rs +++ b/src/java_binding/src/jvm_runtime.rs @@ -19,16 +19,14 @@ use std::path::Path; use std::sync::{Arc, LazyLock}; use jni::{InitArgsBuilder, JNIVersion, JavaVM}; -use risingwave_pb::connector_service::GetEventStreamResponse; -use tokio::sync::mpsc::Sender; pub static JVM: LazyLock> = LazyLock::new(|| { - let dir_path = ".risingwave/bin/connector-node/libs/"; + let libs_path = ".risingwave/bin/connector-node/libs/"; - let dir = Path::new(dir_path); + let dir = Path::new(libs_path); if !dir.is_dir() { - panic!("{} is not a directory", dir_path); + panic!("{} is not a directory", libs_path); } let mut class_vec = vec![]; @@ -36,12 +34,11 @@ pub static JVM: LazyLock> = LazyLock::new(|| { if let Ok(entries) = fs::read_dir(dir) { for entry in entries.flatten() { if let Some(name) = entry.path().file_name() { - println!("{:?}", name); - class_vec.push(dir_path.to_owned() + name.to_str().to_owned().unwrap()); + class_vec.push(libs_path.to_owned() + name.to_str().to_owned().unwrap()); } } } else { - println!("failed to read directory {}", dir_path); + panic!("failed to read directory {}", libs_path); } // Build the VM properties @@ -66,7 +63,6 @@ pub static JVM: LazyLock> = LazyLock::new(|| { Ok(jvm) => jvm, }; + tracing::info!("initialize JVM successfully"); Arc::new(jvm) }); - -pub type MyJniSender = Sender; diff --git a/src/java_binding/src/lib.rs b/src/java_binding/src/lib.rs index 4af1d8377084f..4f4afe88392b8 100644 --- a/src/java_binding/src/lib.rs +++ b/src/java_binding/src/lib.rs @@ -20,6 +20,7 @@ #![feature(result_option_inspect)] pub mod hummock_iterator; +pub mod jvm_runtime; pub mod stream_chunk_iterator; use std::backtrace::Backtrace; @@ -41,7 +42,6 @@ use jni::JNIEnv; use prost::{DecodeError, Message}; use risingwave_common::array::{ArrayError, StreamChunk}; use risingwave_common::hash::VirtualNode; -use risingwave_common::jvm_runtime::MyJniSender; use risingwave_common::row::{OwnedRow, Row}; use risingwave_common::test_prelude::StreamChunkTestExt; use risingwave_common::types::ScalarRefImpl; @@ -50,8 +50,10 @@ use risingwave_pb::connector_service::GetEventStreamResponse; use risingwave_storage::error::StorageError; use thiserror::Error; use tokio::runtime::Runtime; +use tokio::sync::mpsc::Sender; use crate::stream_chunk_iterator::{StreamChunkIterator, StreamChunkRow}; +pub type GetEventStreamJniSender = Sender; static RUNTIME: LazyLock = LazyLock::new(|| tokio::runtime::Runtime::new().unwrap()); @@ -827,10 +829,14 @@ pub extern "system" fn Java_com_risingwave_java_binding_Binding_rowClose<'a>( #[inline(never)] pub fn run_this_func_to_get_valid_ptr_from_java_binding() {} +/// Send messages to the channel received by `CdcSplitReader`. +/// If msg is null, just check whether the channel is closed. +/// Return true if sending is successful, otherwise, return false so that caller can stop +/// gracefully. #[no_mangle] pub extern "system" fn Java_com_risingwave_java_binding_Binding_sendMsgToChannel<'a>( env: EnvParam<'a>, - channel: Pointer<'a, MyJniSender>, + channel: Pointer<'a, GetEventStreamJniSender>, msg: JByteArray<'a>, ) -> jboolean { execute_and_catch(env, move |env| { @@ -845,17 +851,18 @@ pub extern "system" fn Java_com_risingwave_java_binding_Binding_sendMsgToChannel } } - let get_event_stream_response: GetEventStreamResponse = Message::decode(to_guarded_slice(&msg, env)?.deref())?; + let get_event_stream_response: GetEventStreamResponse = + Message::decode(to_guarded_slice(&msg, env)?.deref())?; - println!("before send"); + tracing::debug!("before send"); match channel.as_ref().blocking_send(get_event_stream_response) { Ok(_) => { - println!("send successfully"); + tracing::debug!("send successfully"); Ok(JNI_TRUE) } Err(e) => { channel.drop(); - eprintln!("send error. {:?}", e); + tracing::debug!("send error. {:?}", e); Ok(JNI_FALSE) } } diff --git a/src/workspace-hack/Cargo.toml b/src/workspace-hack/Cargo.toml index c97eae2c7b92a..98bb62d27e6af 100644 --- a/src/workspace-hack/Cargo.toml +++ b/src/workspace-hack/Cargo.toml @@ -47,6 +47,7 @@ hashbrown-5ef9efb8ec2df382 = { package = "hashbrown", version = "0.12", features hyper = { version = "0.14", features = ["full"] } indexmap = { version = "1", default-features = false, features = ["std"] } itertools = { version = "0.10" } +jni = { version = "0.21", features = ["invocation"] } lexical-core = { version = "0.8", features = ["format"] } lexical-parse-float = { version = "0.8", default-features = false, features = ["format", "std"] } lexical-parse-integer = { version = "0.8", default-features = false, features = ["format", "std"] } @@ -140,6 +141,7 @@ hashbrown-5ef9efb8ec2df382 = { package = "hashbrown", version = "0.12", features hyper = { version = "0.14", features = ["full"] } indexmap = { version = "1", default-features = false, features = ["std"] } itertools = { version = "0.10" } +jni = { version = "0.21", features = ["invocation"] } lexical-core = { version = "0.8", features = ["format"] } lexical-parse-float = { version = "0.8", default-features = false, features = ["format", "std"] } lexical-parse-integer = { version = "0.8", default-features = false, features = ["format", "std"] } From d0ef67f19a75100ee9879703c5c3bc6036104593 Mon Sep 17 00:00:00 2001 From: Dylan Chen Date: Mon, 4 Sep 2023 20:05:16 +0800 Subject: [PATCH 16/23] first version of embedded connector node --- Cargo.lock | 28 +- Cargo.toml | 2 + java/com_risingwave_java_binding_Binding.h | 4 +- .../source/core/JniSourceHandler.java | 4 +- .../com/risingwave/java/binding/Binding.java | 11 +- src/compute/Cargo.toml | 2 +- src/compute/src/lib.rs | 237 +---- src/connector/Cargo.toml | 2 +- src/connector/src/source/cdc/source/reader.rs | 4 +- src/java_binding/Cargo.toml | 4 + src/java_binding/src/jvm_runtime.rs | 68 -- src/java_binding/src/lib.rs | 867 +---------------- src/jni_core/Cargo.toml | 38 + .../data-chunk-payload-convert-generator.rs | 97 ++ .../src/bin/data-chunk-payload-generator.rs | 92 ++ .../src/hummock_iterator.rs | 0 src/jni_core/src/jvm_runtime.rs | 271 ++++++ src/jni_core/src/lib.rs | 887 ++++++++++++++++++ .../src/stream_chunk_iterator.rs | 0 src/meta/Cargo.toml | 1 + src/meta/src/lib.rs | 4 + 21 files changed, 1442 insertions(+), 1181 deletions(-) delete mode 100644 src/java_binding/src/jvm_runtime.rs create mode 100644 src/jni_core/Cargo.toml create mode 100644 src/jni_core/src/bin/data-chunk-payload-convert-generator.rs create mode 100644 src/jni_core/src/bin/data-chunk-payload-generator.rs rename src/{java_binding => jni_core}/src/hummock_iterator.rs (100%) create mode 100644 src/jni_core/src/jvm_runtime.rs create mode 100644 src/jni_core/src/lib.rs rename src/{java_binding => jni_core}/src/stream_chunk_iterator.rs (100%) diff --git a/Cargo.lock b/Cargo.lock index caa5f993245a0..8d61306bdbd99 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -6664,7 +6664,7 @@ dependencies = [ "risingwave_common_service", "risingwave_connector", "risingwave_hummock_sdk", - "risingwave_java_binding", + "risingwave_jni_core", "risingwave_pb", "risingwave_rpc_client", "risingwave_source", @@ -6734,7 +6734,7 @@ dependencies = [ "rand", "reqwest", "risingwave_common", - "risingwave_java_binding", + "risingwave_jni_core", "risingwave_pb", "risingwave_rpc_client", "rust_decimal", @@ -6997,6 +6997,29 @@ dependencies = [ [[package]] name = "risingwave_java_binding" version = "0.1.0" +dependencies = [ + "bytes", + "futures", + "itertools 0.11.0", + "jni", + "madsim-tokio", + "prost", + "risingwave_common", + "risingwave_expr", + "risingwave_hummock_sdk", + "risingwave_jni_core", + "risingwave_object_store", + "risingwave_pb", + "risingwave_storage", + "serde", + "serde_json", + "thiserror", + "tracing", +] + +[[package]] +name = "risingwave_jni_core" +version = "0.1.0" dependencies = [ "bytes", "futures", @@ -7059,6 +7082,7 @@ dependencies = [ "risingwave_common_service", "risingwave_connector", "risingwave_hummock_sdk", + "risingwave_jni_core", "risingwave_object_store", "risingwave_pb", "risingwave_rpc_client", diff --git a/Cargo.toml b/Cargo.toml index d8713d66c3276..0c56df58c5e24 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -15,6 +15,7 @@ members = [ "src/frontend", "src/frontend/planner_test", "src/java_binding", + "src/jni_core", "src/meta", "src/object_store", "src/prost", @@ -114,6 +115,7 @@ risingwave_test_runner = { path = "./src/test_runner" } risingwave_udf = { path = "./src/udf" } risingwave_variables = { path = "./src/utils/variables" } risingwave_java_binding = { path = "./src/java_binding" } +risingwave_jni_core = { path = "src/jni_core" } [profile.dev] lto = 'off' diff --git a/java/com_risingwave_java_binding_Binding.h b/java/com_risingwave_java_binding_Binding.h index ba0e53c4afb05..a3e9aa95ec84e 100644 --- a/java/com_risingwave_java_binding_Binding.h +++ b/java/com_risingwave_java_binding_Binding.h @@ -225,10 +225,10 @@ JNIEXPORT jlong JNICALL Java_com_risingwave_java_binding_Binding_streamChunkIter /* * Class: com_risingwave_java_binding_Binding - * Method: sendMsgToChannel + * Method: sendCdcSourceMsgToChannel * Signature: (J[B)Z */ -JNIEXPORT jboolean JNICALL Java_com_risingwave_java_binding_Binding_sendMsgToChannel +JNIEXPORT jboolean JNICALL Java_com_risingwave_java_binding_Binding_sendCdcSourceMsgToChannel (JNIEnv *, jclass, jlong, jbyteArray); #ifdef __cplusplus diff --git a/java/connector-node/risingwave-connector-service/src/main/java/com/risingwave/connector/source/core/JniSourceHandler.java b/java/connector-node/risingwave-connector-service/src/main/java/com/risingwave/connector/source/core/JniSourceHandler.java index 6ad323f87faa7..c4fe4be63fc89 100644 --- a/java/connector-node/risingwave-connector-service/src/main/java/com/risingwave/connector/source/core/JniSourceHandler.java +++ b/java/connector-node/risingwave-connector-service/src/main/java/com/risingwave/connector/source/core/JniSourceHandler.java @@ -59,10 +59,10 @@ public void run() { "Engine#{}: emit one chunk {} events to network ", config.getSourceId(), resp.getEventsCount()); - success = Binding.sendMsgToChannel(channelPtr, resp.toByteArray()); + success = Binding.sendCdcSourceMsgToChannel(channelPtr, resp.toByteArray()); } else { // If resp is null means just check whether channel is closed. - success = Binding.sendMsgToChannel(channelPtr, null); + success = Binding.sendCdcSourceMsgToChannel(channelPtr, null); } if (!success) { LOG.info( diff --git a/java/java-binding/src/main/java/com/risingwave/java/binding/Binding.java b/java/java-binding/src/main/java/com/risingwave/java/binding/Binding.java index 9829f6a4a4004..4a79033b147a8 100644 --- a/java/java-binding/src/main/java/com/risingwave/java/binding/Binding.java +++ b/java/java-binding/src/main/java/com/risingwave/java/binding/Binding.java @@ -14,11 +14,16 @@ package com.risingwave.java.binding; -// import io.questdb.jar.jni.JarJniLoader; +import io.questdb.jar.jni.JarJniLoader; public class Binding { + private static final boolean IS_EMBEDDED_CONNECTOR = + Boolean.parseBoolean(System.getProperty("is_embedded_connector")); + static { - // JarJniLoader.loadLib(Binding.class, "/risingwave/jni", "risingwave_java_binding"); + if (!IS_EMBEDDED_CONNECTOR) { + JarJniLoader.loadLib(Binding.class, "/risingwave/jni", "risingwave_java_binding"); + } } public static native int vnodeCount(); @@ -85,5 +90,5 @@ public class Binding { static native long streamChunkIteratorFromPretty(String str); - public static native boolean sendMsgToChannel(long channelPtr, byte[] msg); + public static native boolean sendCdcSourceMsgToChannel(long channelPtr, byte[] msg); } diff --git a/src/compute/Cargo.toml b/src/compute/Cargo.toml index 5a3f57cb4df4f..8845dd0d80de2 100644 --- a/src/compute/Cargo.toml +++ b/src/compute/Cargo.toml @@ -33,7 +33,7 @@ risingwave_common = { workspace = true } risingwave_common_service = { workspace = true } risingwave_connector = { workspace = true } risingwave_hummock_sdk = { workspace = true } -risingwave_java_binding = { workspace = true } +risingwave_jni_core = { workspace = true } risingwave_pb = { workspace = true } risingwave_rpc_client = { workspace = true } risingwave_source = { workspace = true } diff --git a/src/compute/src/lib.rs b/src/compute/src/lib.rs index 49287cc13bd6e..39153653fe374 100644 --- a/src/compute/src/lib.rs +++ b/src/compute/src/lib.rs @@ -32,19 +32,14 @@ pub mod rpc; pub mod server; pub mod telemetry; -use std::ffi::c_void; use std::future::Future; use std::pin::Pin; use clap::{Parser, ValueEnum}; -use jni::objects::{JObject, JValue}; -use jni::strings::JNIString; -use jni::NativeMethod; use risingwave_common::config::{AsyncStackTraceOption, OverrideConfig}; use risingwave_common::util::resource_util::cpu::total_cpu_available; use risingwave_common::util::resource_util::memory::total_memory_available_bytes; -use risingwave_java_binding::jvm_runtime::JVM; -use risingwave_java_binding::run_this_func_to_get_valid_ptr_from_java_binding; +use risingwave_jni_core::jvm_runtime; use serde::{Deserialize, Serialize}; /// Command-line arguments for compute-node. @@ -219,243 +214,17 @@ pub fn start(opts: ComputeNodeOpts) -> Pin + Send>> .unwrap(); tracing::info!("advertise addr is {}", advertise_addr); + jvm_runtime::register_native_method_for_jvm(); + let (join_handle_vec, _shutdown_send) = compute_node_serve(listen_addr, advertise_addr, opts).await; - std::thread::spawn(move || { - run_jvm(); - }); - for join_handle in join_handle_vec { join_handle.await.unwrap(); } }) } -fn run_jvm() { - let mut env = JVM.attach_current_thread_as_daemon().unwrap(); - - // FIXME: remove this function would cause segment fault. - run_this_func_to_get_valid_ptr_from_java_binding(); - - let binding_class = env - .find_class("com/risingwave/java/binding/Binding") - .unwrap(); - env.register_native_methods(binding_class, &[ - NativeMethod { - name: JNIString::from("vnodeCount"), - sig: JNIString::from("()I"), - fn_ptr: risingwave_java_binding::Java_com_risingwave_java_binding_Binding_vnodeCount as *mut c_void, - }, - - - NativeMethod { - name: JNIString::from("hummockIteratorNew"), - sig: JNIString::from("([B)J"), - fn_ptr: risingwave_java_binding::Java_com_risingwave_java_binding_Binding_hummockIteratorNew as *mut c_void, - }, - - - NativeMethod { - name: JNIString::from("hummockIteratorNext"), - sig: JNIString::from("(J)J"), - fn_ptr: risingwave_java_binding::Java_com_risingwave_java_binding_Binding_hummockIteratorNext as *mut c_void, - }, - - - NativeMethod { - name: JNIString::from("hummockIteratorClose"), - sig: JNIString::from("(J)V"), - fn_ptr: risingwave_java_binding::Java_com_risingwave_java_binding_Binding_hummockIteratorClose as *mut c_void, - }, - - - NativeMethod { - name: JNIString::from("rowGetKey"), - sig: JNIString::from("(J)[B"), - fn_ptr: risingwave_java_binding::Java_com_risingwave_java_binding_Binding_rowGetKey as *mut c_void, - }, - - - NativeMethod { - name: JNIString::from("rowGetOp"), - sig: JNIString::from("(J)I"), - fn_ptr: risingwave_java_binding::Java_com_risingwave_java_binding_Binding_rowGetOp as *mut c_void, - }, - - - NativeMethod { - name: JNIString::from("rowIsNull"), - sig: JNIString::from("(JI)Z"), - fn_ptr: risingwave_java_binding::Java_com_risingwave_java_binding_Binding_rowIsNull as *mut c_void, - }, - - - NativeMethod { - name: JNIString::from("rowGetInt16Value"), - sig: JNIString::from("(JI)S"), - fn_ptr: risingwave_java_binding::Java_com_risingwave_java_binding_Binding_rowGetInt16Value as *mut c_void, - }, - - - NativeMethod { - name: JNIString::from("rowGetInt32Value"), - sig: JNIString::from("(JI)I"), - fn_ptr: risingwave_java_binding::Java_com_risingwave_java_binding_Binding_rowGetInt32Value as *mut c_void, - }, - - - NativeMethod { - name: JNIString::from("rowGetInt64Value"), - sig: JNIString::from("(JI)J"), - fn_ptr: risingwave_java_binding::Java_com_risingwave_java_binding_Binding_rowGetInt64Value as *mut c_void, - }, - - - NativeMethod { - name: JNIString::from("rowGetFloatValue"), - sig: JNIString::from("(JI)F"), - fn_ptr: risingwave_java_binding::Java_com_risingwave_java_binding_Binding_rowGetFloatValue as *mut c_void, - }, - - - NativeMethod { - name: JNIString::from("rowGetDoubleValue"), - sig: JNIString::from("(JI)D"), - fn_ptr: risingwave_java_binding::Java_com_risingwave_java_binding_Binding_rowGetDoubleValue as *mut c_void, - }, - - - NativeMethod { - name: JNIString::from("rowGetBooleanValue"), - sig: JNIString::from("(JI)Z"), - fn_ptr: risingwave_java_binding::Java_com_risingwave_java_binding_Binding_rowGetBooleanValue as *mut c_void, - }, - - - NativeMethod { - name: JNIString::from("rowGetStringValue"), - sig: JNIString::from("(JI)Ljava/lang/String;"), - fn_ptr: risingwave_java_binding::Java_com_risingwave_java_binding_Binding_rowGetStringValue as *mut c_void, - }, - - - NativeMethod { - name: JNIString::from("rowGetTimestampValue"), - sig: JNIString::from("(JI)Ljava/sql/Timestamp;"), - fn_ptr: risingwave_java_binding::Java_com_risingwave_java_binding_Binding_rowGetTimestampValue as *mut c_void, - }, - - - NativeMethod { - name: JNIString::from("rowGetDecimalValue"), - sig: JNIString::from("(JI)Ljava/math/BigDecimal;"), - fn_ptr: risingwave_java_binding::Java_com_risingwave_java_binding_Binding_rowGetDecimalValue as *mut c_void, - }, - - - NativeMethod { - name: JNIString::from("rowGetTimeValue"), - sig: JNIString::from("(JI)Ljava/sql/Time;"), - fn_ptr: risingwave_java_binding::Java_com_risingwave_java_binding_Binding_rowGetTimeValue as *mut c_void, - }, - - - NativeMethod { - name: JNIString::from("rowGetDateValue"), - sig: JNIString::from("(JI)Ljava/sql/Date;"), - fn_ptr: risingwave_java_binding::Java_com_risingwave_java_binding_Binding_rowGetDateValue as *mut c_void, - }, - - - NativeMethod { - name: JNIString::from("rowGetIntervalValue"), - sig: JNIString::from("(JI)Ljava/lang/String;"), - fn_ptr: risingwave_java_binding::Java_com_risingwave_java_binding_Binding_rowGetIntervalValue as *mut c_void, - }, - - - NativeMethod { - name: JNIString::from("rowGetJsonbValue"), - sig: JNIString::from("(JI)Ljava/lang/String;"), - fn_ptr: risingwave_java_binding::Java_com_risingwave_java_binding_Binding_rowGetJsonbValue as *mut c_void, - }, - - - NativeMethod { - name: JNIString::from("rowGetByteaValue"), - sig: JNIString::from("(JI)[B"), - fn_ptr: risingwave_java_binding::Java_com_risingwave_java_binding_Binding_rowGetByteaValue as *mut c_void, - }, - - - NativeMethod { - name: JNIString::from("rowGetArrayValue"), - sig: JNIString::from("(JILjava/lang/Class;)Ljava/lang/Object;"), - fn_ptr: risingwave_java_binding::Java_com_risingwave_java_binding_Binding_rowGetArrayValue as *mut c_void, - }, - - - NativeMethod { - name: JNIString::from("rowClose"), - sig: JNIString::from("(J)V"), - fn_ptr: risingwave_java_binding::Java_com_risingwave_java_binding_Binding_rowClose as *mut c_void, - }, - - - NativeMethod { - name: JNIString::from("streamChunkIteratorNew"), - sig: JNIString::from("([B)J"), - fn_ptr: risingwave_java_binding::Java_com_risingwave_java_binding_Binding_streamChunkIteratorNew as *mut c_void, - }, - - - NativeMethod { - name: JNIString::from("streamChunkIteratorNext"), - sig: JNIString::from("(J)J"), - fn_ptr: risingwave_java_binding::Java_com_risingwave_java_binding_Binding_streamChunkIteratorNext as *mut c_void, - }, - - - NativeMethod { - name: JNIString::from("streamChunkIteratorClose"), - sig: JNIString::from("(J)V"), - fn_ptr: risingwave_java_binding::Java_com_risingwave_java_binding_Binding_streamChunkIteratorClose as *mut c_void, - }, - - - NativeMethod { - name: JNIString::from("streamChunkIteratorFromPretty"), - sig: JNIString::from("(Ljava/lang/String;)J"), - fn_ptr: risingwave_java_binding::Java_com_risingwave_java_binding_Binding_streamChunkIteratorFromPretty as *mut c_void, - }, - - - NativeMethod { - name: JNIString::from("sendMsgToChannel"), - sig: JNIString::from("(J[B)Z"), - fn_ptr: risingwave_java_binding::Java_com_risingwave_java_binding_Binding_sendMsgToChannel as *mut c_void, - }, - - ]).unwrap(); - - let string_class = env.find_class("java/lang/String").unwrap(); - let jarray = env - .new_object_array(0, string_class, JObject::null()) - .unwrap(); - - // FIXME: if we finish rewriting all RPCs to JNI calls, we don't need to run main anymore. - let _ = env - .call_static_method( - "com/risingwave/connector/ConnectorService", - "main", - "([Ljava/lang/String;)V", - &[JValue::Object(&jarray)], - ) - .inspect_err(|e| eprintln!("{:?}", e)) - .unwrap(); -} - fn default_total_memory_bytes() -> usize { total_memory_available_bytes() } diff --git a/src/connector/Cargo.toml b/src/connector/Cargo.toml index 394cb3292e0ef..ae5c5b038a81b 100644 --- a/src/connector/Cargo.toml +++ b/src/connector/Cargo.toml @@ -80,7 +80,7 @@ rdkafka = { workspace = true, features = [ ] } reqwest = { version = "0.11", features = ["json"] } risingwave_common = { workspace = true } -risingwave_java_binding = { workspace = true } +risingwave_jni_core = { workspace = true } risingwave_pb = { workspace = true } risingwave_rpc_client = { workspace = true } rust_decimal = "1" diff --git a/src/connector/src/source/cdc/source/reader.rs b/src/connector/src/source/cdc/source/reader.rs index 58144d35fd5e1..76e7a1c675d97 100644 --- a/src/connector/src/source/cdc/source/reader.rs +++ b/src/connector/src/source/cdc/source/reader.rs @@ -20,8 +20,8 @@ use futures::{pin_mut, StreamExt, TryStreamExt}; use futures_async_stream::try_stream; use jni::objects::{JObject, JValue}; use risingwave_common::util::addr::HostAddr; -use risingwave_java_binding::jvm_runtime::JVM; -use risingwave_java_binding::GetEventStreamJniSender; +use risingwave_jni_core::jvm_runtime::JVM; +use risingwave_jni_core::GetEventStreamJniSender; use risingwave_pb::connector_service::GetEventStreamResponse; use tokio::sync::mpsc; diff --git a/src/java_binding/Cargo.toml b/src/java_binding/Cargo.toml index 89bcd90a0a047..3eafacc84a49c 100644 --- a/src/java_binding/Cargo.toml +++ b/src/java_binding/Cargo.toml @@ -17,6 +17,7 @@ jni = "0.21.1" prost = "0.11" risingwave_common = { workspace = true } risingwave_hummock_sdk = { workspace = true } +risingwave_jni_core = { workspace = true } risingwave_object_store = { workspace = true } risingwave_pb = { workspace = true } risingwave_storage = { workspace = true } @@ -37,6 +38,9 @@ tracing = "0.1" [dev-dependencies] risingwave_expr = { workspace = true } +[lib] +crate-type = ["cdylib"] + [[bin]] name = "data-chunk-payload-generator" test = false diff --git a/src/java_binding/src/jvm_runtime.rs b/src/java_binding/src/jvm_runtime.rs deleted file mode 100644 index 455999ef1f453..0000000000000 --- a/src/java_binding/src/jvm_runtime.rs +++ /dev/null @@ -1,68 +0,0 @@ -// Copyright 2023 RisingWave Labs -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -use core::option::Option::Some; -use core::result::Result::{Err, Ok}; -use std::fs; -use std::path::Path; -use std::sync::{Arc, LazyLock}; - -use jni::{InitArgsBuilder, JNIVersion, JavaVM}; - -pub static JVM: LazyLock> = LazyLock::new(|| { - let libs_path = ".risingwave/bin/connector-node/libs/"; - - let dir = Path::new(libs_path); - - if !dir.is_dir() { - panic!("{} is not a directory", libs_path); - } - - let mut class_vec = vec![]; - - if let Ok(entries) = fs::read_dir(dir) { - for entry in entries.flatten() { - if let Some(name) = entry.path().file_name() { - class_vec.push(libs_path.to_owned() + name.to_str().to_owned().unwrap()); - } - } - } else { - panic!("failed to read directory {}", libs_path); - } - - // Build the VM properties - let jvm_args = InitArgsBuilder::new() - // Pass the JNI API version (default is 8) - .version(JNIVersion::V8) - // You can additionally pass any JVM options (standard, like a system property, - // or VM-specific). - // Here we enable some extra JNI checks useful during development - // .option("-Xcheck:jni") - .option("-ea") - .option(format!("-Djava.class.path={}", class_vec.join(":"))) - .option("-agentlib:jdwp=transport=dt_socket,server=y,suspend=n,address=9111") - .build() - .unwrap(); - - // Create a new VM - let jvm = match JavaVM::new(jvm_args) { - Err(err) => { - panic!("{:?}", err) - } - Ok(jvm) => jvm, - }; - - tracing::info!("initialize JVM successfully"); - Arc::new(jvm) -}); diff --git a/src/java_binding/src/lib.rs b/src/java_binding/src/lib.rs index 4f4afe88392b8..28c8f0419aa86 100644 --- a/src/java_binding/src/lib.rs +++ b/src/java_binding/src/lib.rs @@ -19,869 +19,4 @@ #![feature(type_alias_impl_trait)] #![feature(result_option_inspect)] -pub mod hummock_iterator; -pub mod jvm_runtime; -pub mod stream_chunk_iterator; - -use std::backtrace::Backtrace; -use std::marker::PhantomData; -use std::ops::{Deref, DerefMut}; -use std::slice::from_raw_parts; -use std::sync::{Arc, LazyLock, OnceLock}; - -use hummock_iterator::{HummockJavaBindingIterator, KeyedRow}; -use jni::objects::{ - AutoElements, GlobalRef, JByteArray, JClass, JMethodID, JObject, JStaticMethodID, JString, - JValue, JValueGen, JValueOwned, ReleaseMode, -}; -use jni::signature::ReturnType; -use jni::sys::{ - jboolean, jbyte, jdouble, jfloat, jint, jlong, jshort, jsize, jvalue, JNI_FALSE, JNI_TRUE, -}; -use jni::JNIEnv; -use prost::{DecodeError, Message}; -use risingwave_common::array::{ArrayError, StreamChunk}; -use risingwave_common::hash::VirtualNode; -use risingwave_common::row::{OwnedRow, Row}; -use risingwave_common::test_prelude::StreamChunkTestExt; -use risingwave_common::types::ScalarRefImpl; -use risingwave_common::util::panic::rw_catch_unwind; -use risingwave_pb::connector_service::GetEventStreamResponse; -use risingwave_storage::error::StorageError; -use thiserror::Error; -use tokio::runtime::Runtime; -use tokio::sync::mpsc::Sender; - -use crate::stream_chunk_iterator::{StreamChunkIterator, StreamChunkRow}; -pub type GetEventStreamJniSender = Sender; - -static RUNTIME: LazyLock = LazyLock::new(|| tokio::runtime::Runtime::new().unwrap()); - -#[derive(Error, Debug)] -enum BindingError { - #[error("JniError {error}")] - Jni { - #[from] - error: jni::errors::Error, - backtrace: Backtrace, - }, - - #[error("StorageError {error}")] - Storage { - #[from] - error: StorageError, - backtrace: Backtrace, - }, - - #[error("DecodeError {error}")] - Decode { - #[from] - error: DecodeError, - backtrace: Backtrace, - }, - - #[error("StreamChunkArrayError {error}")] - StreamChunkArray { - #[from] - error: ArrayError, - backtrace: Backtrace, - }, -} - -type Result = std::result::Result; - -fn to_guarded_slice<'array, 'env>( - array: &'array JByteArray<'env>, - env: &'array mut JNIEnv<'env>, -) -> Result> { - unsafe { - let array = env.get_array_elements(array, ReleaseMode::NoCopyBack)?; - let slice = from_raw_parts(array.as_ptr() as *mut u8, array.len()); - - Ok(SliceGuard { - _array: array, - slice, - }) - } -} - -/// Wrapper around `&[u8]` derived from `jbyteArray` to prevent it from being auto-released. -pub struct SliceGuard<'env, 'array> { - _array: AutoElements<'env, 'env, 'array, jbyte>, - slice: &'array [u8], -} - -impl<'env, 'array> Deref for SliceGuard<'env, 'array> { - type Target = [u8]; - - fn deref(&self) -> &Self::Target { - self.slice - } -} - -#[repr(transparent)] -pub struct Pointer<'a, T> { - pointer: jlong, - _phantom: PhantomData<&'a T>, -} - -impl<'a, T> Default for Pointer<'a, T> { - fn default() -> Self { - Self { - pointer: 0, - _phantom: Default::default(), - } - } -} - -impl From for Pointer<'static, T> { - fn from(value: T) -> Self { - Pointer { - pointer: Box::into_raw(Box::new(value)) as jlong, - _phantom: PhantomData, - } - } -} - -impl Pointer<'static, T> { - fn null() -> Self { - Pointer { - pointer: 0, - _phantom: PhantomData, - } - } -} - -impl<'a, T> Pointer<'a, T> { - fn as_ref(&self) -> &'a T { - debug_assert!(self.pointer != 0); - unsafe { &*(self.pointer as *const T) } - } - - fn as_mut(&mut self) -> &'a mut T { - debug_assert!(self.pointer != 0); - unsafe { &mut *(self.pointer as *mut T) } - } - - fn drop(self) { - debug_assert!(self.pointer != 0); - unsafe { drop(Box::from_raw(self.pointer as *mut T)) } - } -} - -/// In most Jni interfaces, the first parameter is `JNIEnv`, and the second parameter is `JClass`. -/// This struct simply encapsulates the two common parameters into a single struct for simplicity. -#[repr(C)] -pub struct EnvParam<'a> { - env: JNIEnv<'a>, - class: JClass<'a>, -} - -impl<'a> Deref for EnvParam<'a> { - type Target = JNIEnv<'a>; - - fn deref(&self) -> &Self::Target { - &self.env - } -} - -impl<'a> DerefMut for EnvParam<'a> { - fn deref_mut(&mut self) -> &mut Self::Target { - &mut self.env - } -} - -impl<'a> EnvParam<'a> { - pub fn get_class(&self) -> &JClass<'a> { - &self.class - } -} - -fn execute_and_catch<'env, F, Ret>(mut env: EnvParam<'env>, inner: F) -> Ret -where - F: FnOnce(&mut EnvParam<'env>) -> Result, - Ret: Default + 'env, -{ - match rw_catch_unwind(std::panic::AssertUnwindSafe(|| inner(&mut env))) { - Ok(Ok(ret)) => ret, - Ok(Err(e)) => { - match e { - BindingError::Jni { - error: jni::errors::Error::JavaException, - backtrace, - } => { - tracing::error!("get JavaException thrown from: {:?}", backtrace); - // the exception is already thrown. No need to throw again - } - _ => { - env.throw(format!("get error while processing: {:?}", e)) - .expect("should be able to throw"); - } - } - Ret::default() - } - Err(e) => { - env.throw(format!("panic while processing: {:?}", e)) - .expect("should be able to throw"); - Ret::default() - } - } -} - -pub enum JavaBindingRowInner { - Keyed(KeyedRow), - StreamChunk(StreamChunkRow), -} -#[derive(Default)] -pub struct JavaClassMethodCache { - big_decimal_ctor: OnceLock<(GlobalRef, JMethodID)>, - timestamp_ctor: OnceLock<(GlobalRef, JMethodID)>, - - date_ctor: OnceLock<(GlobalRef, JStaticMethodID)>, - time_ctor: OnceLock<(GlobalRef, JStaticMethodID)>, -} - -pub struct JavaBindingRow { - inner: JavaBindingRowInner, - class_cache: Arc, -} - -impl JavaBindingRow { - fn with_stream_chunk( - underlying: StreamChunkRow, - class_cache: Arc, - ) -> Self { - Self { - inner: JavaBindingRowInner::StreamChunk(underlying), - class_cache, - } - } - - fn with_keyed(underlying: KeyedRow, class_cache: Arc) -> Self { - Self { - inner: JavaBindingRowInner::Keyed(underlying), - class_cache, - } - } - - fn as_keyed(&self) -> &KeyedRow { - match &self.inner { - JavaBindingRowInner::Keyed(r) => r, - _ => unreachable!("can only call as_keyed for KeyedRow"), - } - } - - fn as_stream_chunk(&self) -> &StreamChunkRow { - match &self.inner { - JavaBindingRowInner::StreamChunk(r) => r, - _ => unreachable!("can only call as_stream_chunk for StreamChunkRow"), - } - } -} - -impl Deref for JavaBindingRow { - type Target = OwnedRow; - - fn deref(&self) -> &Self::Target { - match &self.inner { - JavaBindingRowInner::Keyed(r) => r.row(), - JavaBindingRowInner::StreamChunk(r) => r.row(), - } - } -} - -#[no_mangle] -pub extern "system" fn Java_com_risingwave_java_binding_Binding_vnodeCount( - _env: EnvParam<'_>, -) -> jint { - VirtualNode::COUNT as jint -} - -#[no_mangle] -pub extern "system" fn Java_com_risingwave_java_binding_Binding_hummockIteratorNew<'a>( - env: EnvParam<'a>, - read_plan: JByteArray<'a>, -) -> Pointer<'static, HummockJavaBindingIterator> { - execute_and_catch(env, move |env| { - let read_plan = Message::decode(to_guarded_slice(&read_plan, env)?.deref())?; - let iter = RUNTIME.block_on(HummockJavaBindingIterator::new(read_plan))?; - Ok(iter.into()) - }) -} - -#[no_mangle] -pub extern "system" fn Java_com_risingwave_java_binding_Binding_hummockIteratorNext<'a>( - env: EnvParam<'a>, - mut pointer: Pointer<'a, HummockJavaBindingIterator>, -) -> Pointer<'static, JavaBindingRow> { - execute_and_catch(env, move |_env| { - let iter = pointer.as_mut(); - match RUNTIME.block_on(iter.next())? { - None => Ok(Pointer::null()), - Some(row) => Ok(JavaBindingRow::with_keyed(row, iter.class_cache.clone()).into()), - } - }) -} - -#[no_mangle] -pub extern "system" fn Java_com_risingwave_java_binding_Binding_hummockIteratorClose( - _env: EnvParam<'_>, - pointer: Pointer<'_, HummockJavaBindingIterator>, -) { - pointer.drop(); -} - -#[no_mangle] -pub extern "system" fn Java_com_risingwave_java_binding_Binding_streamChunkIteratorNew<'a>( - env: EnvParam<'a>, - stream_chunk_payload: JByteArray<'a>, -) -> Pointer<'static, StreamChunkIterator> { - execute_and_catch(env, move |env| { - let prost_stream_chumk = - Message::decode(to_guarded_slice(&stream_chunk_payload, env)?.deref())?; - let iter = StreamChunkIterator::new(StreamChunk::from_protobuf(&prost_stream_chumk)?); - Ok(iter.into()) - }) -} - -#[no_mangle] -pub extern "system" fn Java_com_risingwave_java_binding_Binding_streamChunkIteratorFromPretty< - 'a, ->( - env: EnvParam<'a>, - str: JString<'a>, -) -> Pointer<'static, StreamChunkIterator> { - execute_and_catch(env, move |env: &mut EnvParam<'_>| { - let iter = StreamChunkIterator::new(StreamChunk::from_pretty( - env.get_string(&str) - .expect("cannot get java string") - .to_str() - .unwrap(), - )); - Ok(iter.into()) - }) -} - -#[no_mangle] -pub extern "system" fn Java_com_risingwave_java_binding_Binding_streamChunkIteratorNext<'a>( - env: EnvParam<'a>, - mut pointer: Pointer<'a, StreamChunkIterator>, -) -> Pointer<'static, JavaBindingRow> { - execute_and_catch(env, move |_env| { - let iter = pointer.as_mut(); - match iter.next() { - None => Ok(Pointer::null()), - Some(row) => { - Ok(JavaBindingRow::with_stream_chunk(row, iter.class_cache.clone()).into()) - } - } - }) -} - -#[no_mangle] -pub extern "system" fn Java_com_risingwave_java_binding_Binding_streamChunkIteratorClose( - _env: EnvParam<'_>, - pointer: Pointer<'_, StreamChunkIterator>, -) { - pointer.drop(); -} - -#[no_mangle] -pub extern "system" fn Java_com_risingwave_java_binding_Binding_rowGetKey<'a>( - env: EnvParam<'a>, - pointer: Pointer<'a, JavaBindingRow>, -) -> JByteArray<'a> { - execute_and_catch(env, move |env: &mut EnvParam<'_>| { - Ok(env.byte_array_from_slice(pointer.as_ref().as_keyed().key())?) - }) -} - -#[no_mangle] -pub extern "system" fn Java_com_risingwave_java_binding_Binding_rowGetOp<'a>( - env: EnvParam<'a>, - pointer: Pointer<'a, JavaBindingRow>, -) -> jint { - execute_and_catch(env, move |_env| { - Ok(pointer.as_ref().as_stream_chunk().op() as jint) - }) -} - -#[no_mangle] -pub extern "system" fn Java_com_risingwave_java_binding_Binding_rowIsNull<'a>( - env: EnvParam<'a>, - pointer: Pointer<'a, JavaBindingRow>, - idx: jint, -) -> jboolean { - execute_and_catch(env, move |_env| { - Ok(pointer.as_ref().datum_at(idx as usize).is_none() as jboolean) - }) -} - -#[no_mangle] -pub extern "system" fn Java_com_risingwave_java_binding_Binding_rowGetInt16Value<'a>( - env: EnvParam<'a>, - pointer: Pointer<'a, JavaBindingRow>, - idx: jint, -) -> jshort { - execute_and_catch(env, move |_env| { - Ok(pointer - .as_ref() - .datum_at(idx as usize) - .unwrap() - .into_int16()) - }) -} - -#[no_mangle] -pub extern "system" fn Java_com_risingwave_java_binding_Binding_rowGetInt32Value<'a>( - env: EnvParam<'a>, - pointer: Pointer<'a, JavaBindingRow>, - idx: jint, -) -> jint { - execute_and_catch(env, move |_env| { - Ok(pointer - .as_ref() - .datum_at(idx as usize) - .unwrap() - .into_int32()) - }) -} - -#[no_mangle] -pub extern "system" fn Java_com_risingwave_java_binding_Binding_rowGetInt64Value<'a>( - env: EnvParam<'a>, - pointer: Pointer<'a, JavaBindingRow>, - idx: jint, -) -> jlong { - execute_and_catch(env, move |_env| { - Ok(pointer - .as_ref() - .datum_at(idx as usize) - .unwrap() - .into_int64()) - }) -} - -#[no_mangle] -pub extern "system" fn Java_com_risingwave_java_binding_Binding_rowGetFloatValue<'a>( - env: EnvParam<'a>, - pointer: Pointer<'a, JavaBindingRow>, - idx: jint, -) -> jfloat { - execute_and_catch(env, move |_env| { - Ok(pointer - .as_ref() - .datum_at(idx as usize) - .unwrap() - .into_float32() - .into()) - }) -} - -#[no_mangle] -pub extern "system" fn Java_com_risingwave_java_binding_Binding_rowGetDoubleValue<'a>( - env: EnvParam<'a>, - pointer: Pointer<'a, JavaBindingRow>, - idx: jint, -) -> jdouble { - execute_and_catch(env, move |_env| { - Ok(pointer - .as_ref() - .datum_at(idx as usize) - .unwrap() - .into_float64() - .into()) - }) -} - -#[no_mangle] -pub extern "system" fn Java_com_risingwave_java_binding_Binding_rowGetBooleanValue<'a>( - env: EnvParam<'a>, - pointer: Pointer<'a, JavaBindingRow>, - idx: jint, -) -> jboolean { - execute_and_catch(env, move |_env| { - Ok(pointer.as_ref().datum_at(idx as usize).unwrap().into_bool() as jboolean) - }) -} - -#[no_mangle] -pub extern "system" fn Java_com_risingwave_java_binding_Binding_rowGetStringValue<'a>( - env: EnvParam<'a>, - pointer: Pointer<'a, JavaBindingRow>, - idx: jint, -) -> JString<'a> { - execute_and_catch(env, move |env: &mut EnvParam<'a>| { - Ok(env.new_string(pointer.as_ref().datum_at(idx as usize).unwrap().into_utf8())?) - }) -} - -#[no_mangle] -pub extern "system" fn Java_com_risingwave_java_binding_Binding_rowGetIntervalValue<'a>( - env: EnvParam<'a>, - pointer: Pointer<'a, JavaBindingRow>, - idx: jint, -) -> JString<'a> { - execute_and_catch(env, move |env: &mut EnvParam<'a>| { - let interval = pointer - .as_ref() - .datum_at(idx as usize) - .unwrap() - .into_interval() - .as_iso_8601(); - Ok(env.new_string(interval)?) - }) -} - -#[no_mangle] -pub extern "system" fn Java_com_risingwave_java_binding_Binding_rowGetJsonbValue<'a>( - env: EnvParam<'a>, - pointer: Pointer<'a, JavaBindingRow>, - idx: jint, -) -> JString<'a> { - execute_and_catch(env, move |env: &mut EnvParam<'_>| { - let jsonb = pointer - .as_ref() - .datum_at(idx as usize) - .unwrap() - .into_jsonb() - .to_string(); - Ok(env.new_string(jsonb)?) - }) -} - -#[no_mangle] -pub extern "system" fn Java_com_risingwave_java_binding_Binding_rowGetTimestampValue<'a>( - env: EnvParam<'a>, - pointer: Pointer<'a, JavaBindingRow>, - idx: jint, -) -> JObject<'a> { - execute_and_catch(env, move |env: &mut EnvParam<'_>| { - let scalar_value = pointer.as_ref().datum_at(idx as usize).unwrap(); - let millis = match scalar_value { - // supports sinking rw timestamptz to mysql timestamp - ScalarRefImpl::Timestamptz(tz) => tz.timestamp_millis(), - ScalarRefImpl::Timestamp(ts) => ts.0.timestamp_millis(), - _ => panic!("expect timestamp or timestamptz"), - }; - let (ts_class_ref, constructor) = pointer - .as_ref() - .class_cache - .timestamp_ctor - .get_or_try_init(|| { - let cls = env.find_class("java/sql/Timestamp")?; - let init_method = env.get_method_id(&cls, "", "(J)V")?; - Ok::<_, jni::errors::Error>((env.new_global_ref(cls)?, init_method)) - })?; - unsafe { - let ts_class = <&JClass<'_>>::from(ts_class_ref.as_obj()); - let date_obj = - env.new_object_unchecked(ts_class, *constructor, &[jvalue { j: millis }])?; - Ok(date_obj) - } - }) -} - -#[no_mangle] -pub extern "system" fn Java_com_risingwave_java_binding_Binding_rowGetDecimalValue<'a>( - env: EnvParam<'a>, - pointer: Pointer<'a, JavaBindingRow>, - idx: jint, -) -> JObject<'a> { - execute_and_catch(env, move |env: &mut EnvParam<'_>| { - let value = pointer - .as_ref() - .datum_at(idx as usize) - .unwrap() - .into_decimal() - .to_string(); - let string_value = env.new_string(value)?; - let (decimal_class_ref, constructor) = pointer - .as_ref() - .class_cache - .big_decimal_ctor - .get_or_try_init(|| { - let cls = env.find_class("java/math/BigDecimal")?; - let init_method = env.get_method_id(&cls, "", "(Ljava/lang/String;)V")?; - Ok::<_, jni::errors::Error>((env.new_global_ref(cls)?, init_method)) - })?; - unsafe { - let decimal_class = <&JClass<'_>>::from(decimal_class_ref.as_obj()); - let date_obj = env.new_object_unchecked( - decimal_class, - *constructor, - &[jvalue { - l: string_value.into_raw(), - }], - )?; - Ok(date_obj) - } - }) -} - -#[no_mangle] -pub extern "system" fn Java_com_risingwave_java_binding_Binding_rowGetDateValue<'a>( - env: EnvParam<'a>, - pointer: Pointer<'a, JavaBindingRow>, - idx: jint, -) -> JObject<'a> { - execute_and_catch(env, move |env: &mut EnvParam<'_>| { - let value = pointer - .as_ref() - .datum_at(idx as usize) - .unwrap() - .into_date() - .0 - .to_string(); - - let string_value = env.new_string(value)?; - let (class_ref, constructor) = - pointer.as_ref().class_cache.date_ctor.get_or_try_init(|| { - let cls = env.find_class("java/sql/Date")?; - let init_method = env.get_static_method_id( - &cls, - "valueOf", - "(Ljava/lang/String;)Ljava/sql/Date;", - )?; - Ok::<_, jni::errors::Error>((env.new_global_ref(cls)?, init_method)) - })?; - unsafe { - let JValueOwned::Object(date_obj) = env.call_static_method_unchecked( - <&JClass<'_>>::from(class_ref.as_obj()), - *constructor, - ReturnType::Object, - &[jvalue { - l: string_value.into_raw(), - }], - )? - else { - return Err(BindingError::from(jni::errors::Error::MethodNotFound { - name: "valueOf".to_string(), - sig: "(Ljava/lang/String;)Ljava/sql/Date;".into(), - })); - }; - Ok(date_obj) - } - }) -} - -#[no_mangle] -pub extern "system" fn Java_com_risingwave_java_binding_Binding_rowGetTimeValue<'a>( - env: EnvParam<'a>, - pointer: Pointer<'a, JavaBindingRow>, - idx: jint, -) -> JObject<'a> { - execute_and_catch(env, move |env: &mut EnvParam<'_>| { - let value = pointer - .as_ref() - .datum_at(idx as usize) - .unwrap() - .into_time() - .0 - .to_string(); - - let string_value = env.new_string(value)?; - let (class_ref, constructor) = - pointer.as_ref().class_cache.time_ctor.get_or_try_init(|| { - let cls = env.find_class("java/sql/Time")?; - let init_method = env.get_static_method_id( - &cls, - "valueOf", - "(Ljava/lang/String;)Ljava/sql/Time;", - )?; - Ok::<_, jni::errors::Error>((env.new_global_ref(cls)?, init_method)) - })?; - unsafe { - let class = <&JClass<'_>>::from(class_ref.as_obj()); - match env.call_static_method_unchecked( - class, - *constructor, - ReturnType::Object, - &[jvalue { - l: string_value.into_raw(), - }], - )? { - JValueGen::Object(obj) => Ok(obj), - _ => Err(BindingError::from(jni::errors::Error::MethodNotFound { - name: "valueOf".to_string(), - sig: "(Ljava/lang/String;)Ljava/sql/Time;".into(), - })), - } - } - }) -} - -#[no_mangle] -pub extern "system" fn Java_com_risingwave_java_binding_Binding_rowGetByteaValue<'a>( - env: EnvParam<'a>, - pointer: Pointer<'a, JavaBindingRow>, - idx: jint, -) -> JByteArray<'a> { - execute_and_catch(env, move |env: &mut EnvParam<'_>| { - let bytes = pointer - .as_ref() - .datum_at(idx as usize) - .unwrap() - .into_bytea(); - Ok(env.byte_array_from_slice(bytes)?) - }) -} - -#[no_mangle] -pub extern "system" fn Java_com_risingwave_java_binding_Binding_rowGetArrayValue<'a>( - env: EnvParam<'a>, - pointer: Pointer<'a, JavaBindingRow>, - idx: jint, - class: JClass<'a>, -) -> JObject<'a> { - execute_and_catch(env, move |env: &mut EnvParam<'_>| { - let elems = pointer - .as_ref() - .datum_at(idx as usize) - .unwrap() - .into_list() - .iter(); - - // convert the Rust elements to a Java object array (Object[]) - let jarray = env.new_object_array(elems.len() as jsize, &class, JObject::null())?; - - for (i, ele) in elems.enumerate() { - let index = i as jsize; - match ele { - None => env.set_object_array_element(&jarray, i as jsize, JObject::null())?, - Some(val) => match val { - ScalarRefImpl::Int16(v) => { - let obj = env.call_static_method( - &class, - "valueOf", - "(S)Ljava.lang.Short;", - &[JValue::from(v as jshort)], - )?; - if let JValueOwned::Object(o) = obj { - env.set_object_array_element(&jarray, index, &o)? - } - } - ScalarRefImpl::Int32(v) => { - let obj = env.call_static_method( - &class, - "valueOf", - "(I)Ljava.lang.Integer;", - &[JValue::from(v as jint)], - )?; - if let JValueOwned::Object(o) = obj { - env.set_object_array_element(&jarray, index, &o)? - } - } - ScalarRefImpl::Int64(v) => { - let obj = env.call_static_method( - &class, - "valueOf", - "(J)Ljava.lang.Long;", - &[JValue::from(v as jlong)], - )?; - if let JValueOwned::Object(o) = obj { - env.set_object_array_element(&jarray, index, &o)? - } - } - ScalarRefImpl::Float32(v) => { - let obj = env.call_static_method( - &class, - "valueOf", - "(F)Ljava/lang/Float;", - &[JValue::from(v.into_inner() as jfloat)], - )?; - if let JValueOwned::Object(o) = obj { - env.set_object_array_element(&jarray, index, &o)? - } - } - ScalarRefImpl::Float64(v) => { - let obj = env.call_static_method( - &class, - "valueOf", - "(D)Ljava/lang/Double;", - &[JValue::from(v.into_inner() as jdouble)], - )?; - if let JValueOwned::Object(o) = obj { - env.set_object_array_element(&jarray, index, &o)? - } - } - ScalarRefImpl::Utf8(v) => { - let obj = env.new_string(v)?; - env.set_object_array_element(&jarray, index, obj)? - } - _ => env.set_object_array_element(&jarray, index, JObject::null())?, - }, - } - } - let output = unsafe { JObject::from_raw(jarray.into_raw()) }; - Ok(output) - }) -} - -#[no_mangle] -pub extern "system" fn Java_com_risingwave_java_binding_Binding_rowClose<'a>( - _env: EnvParam<'a>, - pointer: Pointer<'a, JavaBindingRow>, -) { - pointer.drop() -} - -#[inline(never)] -pub fn run_this_func_to_get_valid_ptr_from_java_binding() {} - -/// Send messages to the channel received by `CdcSplitReader`. -/// If msg is null, just check whether the channel is closed. -/// Return true if sending is successful, otherwise, return false so that caller can stop -/// gracefully. -#[no_mangle] -pub extern "system" fn Java_com_risingwave_java_binding_Binding_sendMsgToChannel<'a>( - env: EnvParam<'a>, - channel: Pointer<'a, GetEventStreamJniSender>, - msg: JByteArray<'a>, -) -> jboolean { - execute_and_catch(env, move |env| { - // If msg is null means just check whether channel is closed. - if msg.is_null() { - if channel.as_ref().is_closed() { - // Drop channel as well. - channel.drop(); - return Ok(JNI_FALSE); - } else { - return Ok(JNI_TRUE); - } - } - - let get_event_stream_response: GetEventStreamResponse = - Message::decode(to_guarded_slice(&msg, env)?.deref())?; - - tracing::debug!("before send"); - match channel.as_ref().blocking_send(get_event_stream_response) { - Ok(_) => { - tracing::debug!("send successfully"); - Ok(JNI_TRUE) - } - Err(e) => { - channel.drop(); - tracing::debug!("send error. {:?}", e); - Ok(JNI_FALSE) - } - } - }) -} - -#[cfg(test)] -mod tests { - use risingwave_common::types::{DataType, Timestamptz}; - use risingwave_expr::vector_op::cast::literal_parsing; - - /// make sure that the [`ScalarRefImpl::Int64`] received by - /// [`Java_com_risingwave_java_binding_Binding_rowGetTimestampValue`] - /// is of type [`DataType::Timestamptz`] stored in microseconds - #[test] - fn test_timestamptz_to_i64() { - assert_eq!( - literal_parsing(&DataType::Timestamptz, "2023-06-01 09:45:00+08:00").unwrap(), - Timestamptz::from_micros(1_685_583_900_000_000).into() - ); - } -} +pub use risingwave_jni_core::*; diff --git a/src/jni_core/Cargo.toml b/src/jni_core/Cargo.toml new file mode 100644 index 0000000000000..e0e4bd75022fd --- /dev/null +++ b/src/jni_core/Cargo.toml @@ -0,0 +1,38 @@ +[package] +name = "risingwave_jni_core" +version = "0.1.0" +edition = "2021" + +[package.metadata.cargo-machete] +ignored = ["workspace-hack"] + +[package.metadata.cargo-udeps.ignore] +normal = ["workspace-hack"] + +[dependencies] +bytes = "1" +futures = { version = "0.3", default-features = false, features = ["alloc"] } +itertools = "0.11" +jni = "0.21.1" +prost = "0.11" +risingwave_common = { workspace = true } +risingwave_hummock_sdk = { workspace = true } +risingwave_object_store = { workspace = true } +risingwave_pb = { workspace = true } +risingwave_storage = { workspace = true } +serde = { version = "1.0", features = ["derive"] } +serde_json = "1.0" +thiserror = "1" +tokio = { version = "0.2", package = "madsim-tokio", features = [ + "fs", + "rt", + "rt-multi-thread", + "sync", + "macros", + "time", + "signal", +] } +tracing = "0.1" + +[dev-dependencies] +risingwave_expr = { workspace = true } diff --git a/src/jni_core/src/bin/data-chunk-payload-convert-generator.rs b/src/jni_core/src/bin/data-chunk-payload-convert-generator.rs new file mode 100644 index 0000000000000..75d5afb8d27dc --- /dev/null +++ b/src/jni_core/src/bin/data-chunk-payload-convert-generator.rs @@ -0,0 +1,97 @@ +// Copyright 2023 RisingWave Labs +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +use std::env; +use std::fs::File; +use std::io::{Read, Write}; +use std::process::exit; + +use prost::Message; +use risingwave_common::array::{Op, StreamChunk}; +use risingwave_common::row::OwnedRow; +use risingwave_common::types::{DataType, ScalarImpl}; +use risingwave_common::util::chunk_coalesce::DataChunkBuilder; +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Deserialize, Serialize)] +struct Line { + id: u32, + name: String, +} + +#[derive(Debug, Deserialize, Serialize)] +struct Operation { + op_type: u32, + line: Line, +} + +fn convert_to_op(value: u32) -> Option { + match value { + 1 => Some(Op::Insert), + 2 => Some(Op::Delete), + 3 => Some(Op::UpdateInsert), + 4 => Some(Op::UpdateDelete), + _ => None, + } +} + +fn main() { + let args: Vec = env::args().collect(); + if args.len() <= 1 { + println!("No input file name"); + exit(0); + } + // Read the JSON file + let mut file = File::open(&args[1]).expect("Failed to open file"); + let mut contents = String::new(); + file.read_to_string(&mut contents) + .expect("Failed to read file"); + + // Parse the JSON data + let data: Vec> = serde_json::from_str(&contents).expect("Failed to parse JSON"); + + let data_types: Vec<_> = vec![DataType::Int32, DataType::Varchar]; + + // Access the data + let mut row_count = 0; + for operations in &data { + row_count += operations.len(); + } + let mut ops = Vec::with_capacity(row_count); + let mut builder = DataChunkBuilder::new(data_types, row_count * 1024); + + for operations in data { + for operation in operations { + let mut row_value = Vec::with_capacity(10); + row_value.push(Some(ScalarImpl::Int32(operation.line.id as i32))); + row_value.push(Some(ScalarImpl::Utf8(operation.line.name.into_boxed_str()))); + let _ = builder.append_one_row(OwnedRow::new(row_value)); + // let op: Op = unsafe { ::std::mem::transmute(operation.op_type as u8) }; + if let Some(op) = convert_to_op(operation.op_type) { + ops.push(op); + } else { + println!("Invalid value"); + } + } + } + + let data_chunk = builder.consume_all().expect("should not be empty"); + let stream_chunk = StreamChunk::from_parts(ops, data_chunk); + let prost_stream_chunk: risingwave_pb::data::StreamChunk = stream_chunk.to_protobuf(); + + let payload = Message::encode_to_vec(&prost_stream_chunk); + + std::io::stdout() + .write_all(&payload) + .expect("should success"); +} diff --git a/src/jni_core/src/bin/data-chunk-payload-generator.rs b/src/jni_core/src/bin/data-chunk-payload-generator.rs new file mode 100644 index 0000000000000..f4d0dd6ff16f9 --- /dev/null +++ b/src/jni_core/src/bin/data-chunk-payload-generator.rs @@ -0,0 +1,92 @@ +// Copyright 2023 RisingWave Labs +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::env; +use std::io::Write; + +use prost::Message; +use risingwave_common::array::{Op, StreamChunk}; +use risingwave_common::row::OwnedRow; +use risingwave_common::types::{DataType, ScalarImpl, Timestamp, F32, F64}; +use risingwave_common::util::chunk_coalesce::DataChunkBuilder; + +fn build_row(index: usize) -> OwnedRow { + let mut row_value = Vec::with_capacity(10); + row_value.push(Some(ScalarImpl::Int16(index as i16))); + row_value.push(Some(ScalarImpl::Int32(index as i32))); + row_value.push(Some(ScalarImpl::Int64(index as i64))); + row_value.push(Some(ScalarImpl::Float32(F32::from(index as f32)))); + row_value.push(Some(ScalarImpl::Float64(F64::from(index as f64)))); + row_value.push(Some(ScalarImpl::Bool(index % 3 == 0))); + row_value.push(Some(ScalarImpl::Utf8( + format!("{}", index).repeat((index % 10) + 1).into(), + ))); + row_value.push(Some(ScalarImpl::Timestamp( + Timestamp::from_timestamp_uncheck(index as _, 0), + ))); + row_value.push(Some(ScalarImpl::Decimal(index.into()))); + row_value.push(if index % 5 == 0 { + None + } else { + Some(ScalarImpl::Int64(index as i64)) + }); + + OwnedRow::new(row_value) +} + +fn main() { + let args: Vec = env::args().collect(); + let mut flag = false; + let mut row_count: usize = 30000; + if args.len() > 1 { + flag = true; + row_count = args[1].parse().unwrap(); + } + let data_types = vec![ + DataType::Int16, + DataType::Int32, + DataType::Int64, + DataType::Float32, + DataType::Float64, + DataType::Boolean, + DataType::Varchar, + DataType::Timestamp, + DataType::Decimal, + DataType::Int64, + ]; + let mut ops = Vec::with_capacity(row_count); + let mut builder = DataChunkBuilder::new(data_types, row_count * 1024); + for i in 0..row_count { + assert!( + builder.append_one_row(build_row(i)).is_none(), + "should not finish" + ); + // In unit test, it does not support delete operation + if flag || i % 2 == 0 { + ops.push(Op::Insert); + } else { + ops.push(Op::Delete); + } + } + + let data_chunk = builder.consume_all().expect("should not be empty"); + let stream_chunk = StreamChunk::from_parts(ops, data_chunk); + let prost_stream_chunk = stream_chunk.to_protobuf(); + + let payload = Message::encode_to_vec(&prost_stream_chunk); + + std::io::stdout() + .write_all(&payload) + .expect("should success"); +} diff --git a/src/java_binding/src/hummock_iterator.rs b/src/jni_core/src/hummock_iterator.rs similarity index 100% rename from src/java_binding/src/hummock_iterator.rs rename to src/jni_core/src/hummock_iterator.rs diff --git a/src/jni_core/src/jvm_runtime.rs b/src/jni_core/src/jvm_runtime.rs new file mode 100644 index 0000000000000..93f56195646b1 --- /dev/null +++ b/src/jni_core/src/jvm_runtime.rs @@ -0,0 +1,271 @@ +// Copyright 2023 RisingWave Labs +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use core::option::Option::Some; +use core::result::Result::{Err, Ok}; +use std::ffi::c_void; +use std::fs; +use std::path::Path; +use std::sync::atomic::{AtomicBool, Ordering}; +use std::sync::{Arc, LazyLock}; + +use jni::strings::JNIString; +use jni::{InitArgsBuilder, JNIVersion, JavaVM, NativeMethod}; + +use crate::run_this_func_to_get_valid_ptr_from_java_binding; + +pub static JVM: LazyLock> = LazyLock::new(|| { + let libs_path = ".risingwave/bin/connector-node/libs/"; + + let dir = Path::new(libs_path); + + if !dir.is_dir() { + panic!("{} is not a directory", libs_path); + } + + let mut class_vec = vec![]; + + if let Ok(entries) = fs::read_dir(dir) { + for entry in entries.flatten() { + if let Some(name) = entry.path().file_name() { + class_vec.push(libs_path.to_owned() + name.to_str().to_owned().unwrap()); + } + } + } else { + panic!("failed to read directory {}", libs_path); + } + + // Build the VM properties + let jvm_args = InitArgsBuilder::new() + // Pass the JNI API version (default is 8) + .version(JNIVersion::V8) + // You can additionally pass any JVM options (standard, like a system property, + // or VM-specific). + // Here we enable some extra JNI checks useful during development + // .option("-Xcheck:jni") + .option("-ea") + .option("-Dis_embedded_connector=true") + .option(format!("-Djava.class.path={}", class_vec.join(":"))) + // TODO: remove it + // .option("-agentlib:jdwp=transport=dt_socket,server=y,suspend=n,address=9111") + .build() + .unwrap(); + + // Create a new VM + let jvm = match JavaVM::new(jvm_args) { + Err(err) => { + panic!("{:?}", err) + } + Ok(jvm) => jvm, + }; + + tracing::info!("initialize JVM successfully"); + Arc::new(jvm) +}); + +static REGISTERED: AtomicBool = AtomicBool::new(false); + +pub fn register_native_method_for_jvm() { + // Ensure registering only once. + if REGISTERED + .compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst) + .is_err() + { + return; + } + + let mut env = JVM + .attach_current_thread() + .inspect_err(|e| tracing::error!("{:?}", e)) + .unwrap(); + + // FIXME: remove this function would cause segment fault. + run_this_func_to_get_valid_ptr_from_java_binding(); + + let binding_class = env + .find_class("com/risingwave/java/binding/Binding") + .unwrap(); + env.register_native_methods( + binding_class, + &[ + NativeMethod { + name: JNIString::from("vnodeCount"), + sig: JNIString::from("()I"), + fn_ptr: crate::Java_com_risingwave_java_binding_Binding_vnodeCount as *mut c_void, + }, + NativeMethod { + name: JNIString::from("hummockIteratorNew"), + sig: JNIString::from("([B)J"), + fn_ptr: crate::Java_com_risingwave_java_binding_Binding_hummockIteratorNew + as *mut c_void, + }, + NativeMethod { + name: JNIString::from("hummockIteratorNext"), + sig: JNIString::from("(J)J"), + fn_ptr: crate::Java_com_risingwave_java_binding_Binding_hummockIteratorNext + as *mut c_void, + }, + NativeMethod { + name: JNIString::from("hummockIteratorClose"), + sig: JNIString::from("(J)V"), + fn_ptr: crate::Java_com_risingwave_java_binding_Binding_hummockIteratorClose + as *mut c_void, + }, + NativeMethod { + name: JNIString::from("rowGetKey"), + sig: JNIString::from("(J)[B"), + fn_ptr: crate::Java_com_risingwave_java_binding_Binding_rowGetKey as *mut c_void, + }, + NativeMethod { + name: JNIString::from("rowGetOp"), + sig: JNIString::from("(J)I"), + fn_ptr: crate::Java_com_risingwave_java_binding_Binding_rowGetOp as *mut c_void, + }, + NativeMethod { + name: JNIString::from("rowIsNull"), + sig: JNIString::from("(JI)Z"), + fn_ptr: crate::Java_com_risingwave_java_binding_Binding_rowIsNull as *mut c_void, + }, + NativeMethod { + name: JNIString::from("rowGetInt16Value"), + sig: JNIString::from("(JI)S"), + fn_ptr: crate::Java_com_risingwave_java_binding_Binding_rowGetInt16Value + as *mut c_void, + }, + NativeMethod { + name: JNIString::from("rowGetInt32Value"), + sig: JNIString::from("(JI)I"), + fn_ptr: crate::Java_com_risingwave_java_binding_Binding_rowGetInt32Value + as *mut c_void, + }, + NativeMethod { + name: JNIString::from("rowGetInt64Value"), + sig: JNIString::from("(JI)J"), + fn_ptr: crate::Java_com_risingwave_java_binding_Binding_rowGetInt64Value + as *mut c_void, + }, + NativeMethod { + name: JNIString::from("rowGetFloatValue"), + sig: JNIString::from("(JI)F"), + fn_ptr: crate::Java_com_risingwave_java_binding_Binding_rowGetFloatValue + as *mut c_void, + }, + NativeMethod { + name: JNIString::from("rowGetDoubleValue"), + sig: JNIString::from("(JI)D"), + fn_ptr: crate::Java_com_risingwave_java_binding_Binding_rowGetDoubleValue + as *mut c_void, + }, + NativeMethod { + name: JNIString::from("rowGetBooleanValue"), + sig: JNIString::from("(JI)Z"), + fn_ptr: crate::Java_com_risingwave_java_binding_Binding_rowGetBooleanValue + as *mut c_void, + }, + NativeMethod { + name: JNIString::from("rowGetStringValue"), + sig: JNIString::from("(JI)Ljava/lang/String;"), + fn_ptr: crate::Java_com_risingwave_java_binding_Binding_rowGetStringValue + as *mut c_void, + }, + NativeMethod { + name: JNIString::from("rowGetTimestampValue"), + sig: JNIString::from("(JI)Ljava/sql/Timestamp;"), + fn_ptr: crate::Java_com_risingwave_java_binding_Binding_rowGetTimestampValue + as *mut c_void, + }, + NativeMethod { + name: JNIString::from("rowGetDecimalValue"), + sig: JNIString::from("(JI)Ljava/math/BigDecimal;"), + fn_ptr: crate::Java_com_risingwave_java_binding_Binding_rowGetDecimalValue + as *mut c_void, + }, + NativeMethod { + name: JNIString::from("rowGetTimeValue"), + sig: JNIString::from("(JI)Ljava/sql/Time;"), + fn_ptr: crate::Java_com_risingwave_java_binding_Binding_rowGetTimeValue + as *mut c_void, + }, + NativeMethod { + name: JNIString::from("rowGetDateValue"), + sig: JNIString::from("(JI)Ljava/sql/Date;"), + fn_ptr: crate::Java_com_risingwave_java_binding_Binding_rowGetDateValue + as *mut c_void, + }, + NativeMethod { + name: JNIString::from("rowGetIntervalValue"), + sig: JNIString::from("(JI)Ljava/lang/String;"), + fn_ptr: crate::Java_com_risingwave_java_binding_Binding_rowGetIntervalValue + as *mut c_void, + }, + NativeMethod { + name: JNIString::from("rowGetJsonbValue"), + sig: JNIString::from("(JI)Ljava/lang/String;"), + fn_ptr: crate::Java_com_risingwave_java_binding_Binding_rowGetJsonbValue + as *mut c_void, + }, + NativeMethod { + name: JNIString::from("rowGetByteaValue"), + sig: JNIString::from("(JI)[B"), + fn_ptr: crate::Java_com_risingwave_java_binding_Binding_rowGetByteaValue + as *mut c_void, + }, + NativeMethod { + name: JNIString::from("rowGetArrayValue"), + sig: JNIString::from("(JILjava/lang/Class;)Ljava/lang/Object;"), + fn_ptr: crate::Java_com_risingwave_java_binding_Binding_rowGetArrayValue + as *mut c_void, + }, + NativeMethod { + name: JNIString::from("rowClose"), + sig: JNIString::from("(J)V"), + fn_ptr: crate::Java_com_risingwave_java_binding_Binding_rowClose as *mut c_void, + }, + NativeMethod { + name: JNIString::from("streamChunkIteratorNew"), + sig: JNIString::from("([B)J"), + fn_ptr: crate::Java_com_risingwave_java_binding_Binding_streamChunkIteratorNew + as *mut c_void, + }, + NativeMethod { + name: JNIString::from("streamChunkIteratorNext"), + sig: JNIString::from("(J)J"), + fn_ptr: crate::Java_com_risingwave_java_binding_Binding_streamChunkIteratorNext + as *mut c_void, + }, + NativeMethod { + name: JNIString::from("streamChunkIteratorClose"), + sig: JNIString::from("(J)V"), + fn_ptr: crate::Java_com_risingwave_java_binding_Binding_streamChunkIteratorClose + as *mut c_void, + }, + NativeMethod { + name: JNIString::from("streamChunkIteratorFromPretty"), + sig: JNIString::from("(Ljava/lang/String;)J"), + fn_ptr: + crate::Java_com_risingwave_java_binding_Binding_streamChunkIteratorFromPretty + as *mut c_void, + }, + NativeMethod { + name: JNIString::from("sendCdcSourceMsgToChannel"), + sig: JNIString::from("(J[B)Z"), + fn_ptr: crate::Java_com_risingwave_java_binding_Binding_sendCdcSourceMsgToChannel + as *mut c_void, + }, + ], + ) + .unwrap(); + + tracing::info!("register native methods for jvm successfully"); +} diff --git a/src/jni_core/src/lib.rs b/src/jni_core/src/lib.rs new file mode 100644 index 0000000000000..ce4da4da0534f --- /dev/null +++ b/src/jni_core/src/lib.rs @@ -0,0 +1,887 @@ +// Copyright 2023 RisingWave Labs +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#![feature(error_generic_member_access)] +#![feature(provide_any)] +#![feature(lazy_cell)] +#![feature(once_cell_try)] +#![feature(type_alias_impl_trait)] +#![feature(result_option_inspect)] + +pub mod hummock_iterator; +pub mod jvm_runtime; +pub mod stream_chunk_iterator; + +use std::backtrace::Backtrace; +use std::marker::PhantomData; +use std::ops::{Deref, DerefMut}; +use std::slice::from_raw_parts; +use std::sync::{Arc, LazyLock, OnceLock}; + +use hummock_iterator::{HummockJavaBindingIterator, KeyedRow}; +use jni::objects::{ + AutoElements, GlobalRef, JByteArray, JClass, JMethodID, JObject, JStaticMethodID, JString, + JValue, JValueGen, JValueOwned, ReleaseMode, +}; +use jni::signature::ReturnType; +use jni::sys::{ + jboolean, jbyte, jdouble, jfloat, jint, jlong, jshort, jsize, jvalue, JNI_FALSE, JNI_TRUE, +}; +use jni::JNIEnv; +use prost::{DecodeError, Message}; +use risingwave_common::array::{ArrayError, StreamChunk}; +use risingwave_common::hash::VirtualNode; +use risingwave_common::row::{OwnedRow, Row}; +use risingwave_common::test_prelude::StreamChunkTestExt; +use risingwave_common::types::ScalarRefImpl; +use risingwave_common::util::panic::rw_catch_unwind; +use risingwave_pb::connector_service::GetEventStreamResponse; +use risingwave_storage::error::StorageError; +use thiserror::Error; +use tokio::runtime::Runtime; +use tokio::sync::mpsc::Sender; + +use crate::stream_chunk_iterator::{StreamChunkIterator, StreamChunkRow}; +pub type GetEventStreamJniSender = Sender; + +static RUNTIME: LazyLock = LazyLock::new(|| tokio::runtime::Runtime::new().unwrap()); + +#[derive(Error, Debug)] +enum BindingError { + #[error("JniError {error}")] + Jni { + #[from] + error: jni::errors::Error, + backtrace: Backtrace, + }, + + #[error("StorageError {error}")] + Storage { + #[from] + error: StorageError, + backtrace: Backtrace, + }, + + #[error("DecodeError {error}")] + Decode { + #[from] + error: DecodeError, + backtrace: Backtrace, + }, + + #[error("StreamChunkArrayError {error}")] + StreamChunkArray { + #[from] + error: ArrayError, + backtrace: Backtrace, + }, +} + +type Result = std::result::Result; + +fn to_guarded_slice<'array, 'env>( + array: &'array JByteArray<'env>, + env: &'array mut JNIEnv<'env>, +) -> Result> { + unsafe { + let array = env.get_array_elements(array, ReleaseMode::NoCopyBack)?; + let slice = from_raw_parts(array.as_ptr() as *mut u8, array.len()); + + Ok(SliceGuard { + _array: array, + slice, + }) + } +} + +/// Wrapper around `&[u8]` derived from `jbyteArray` to prevent it from being auto-released. +pub struct SliceGuard<'env, 'array> { + _array: AutoElements<'env, 'env, 'array, jbyte>, + slice: &'array [u8], +} + +impl<'env, 'array> Deref for SliceGuard<'env, 'array> { + type Target = [u8]; + + fn deref(&self) -> &Self::Target { + self.slice + } +} + +#[repr(transparent)] +pub struct Pointer<'a, T> { + pointer: jlong, + _phantom: PhantomData<&'a T>, +} + +impl<'a, T> Default for Pointer<'a, T> { + fn default() -> Self { + Self { + pointer: 0, + _phantom: Default::default(), + } + } +} + +impl From for Pointer<'static, T> { + fn from(value: T) -> Self { + Pointer { + pointer: Box::into_raw(Box::new(value)) as jlong, + _phantom: PhantomData, + } + } +} + +impl Pointer<'static, T> { + fn null() -> Self { + Pointer { + pointer: 0, + _phantom: PhantomData, + } + } +} + +impl<'a, T> Pointer<'a, T> { + fn as_ref(&self) -> &'a T { + debug_assert!(self.pointer != 0); + unsafe { &*(self.pointer as *const T) } + } + + fn as_mut(&mut self) -> &'a mut T { + debug_assert!(self.pointer != 0); + unsafe { &mut *(self.pointer as *mut T) } + } + + fn drop(self) { + debug_assert!(self.pointer != 0); + unsafe { drop(Box::from_raw(self.pointer as *mut T)) } + } +} + +/// In most Jni interfaces, the first parameter is `JNIEnv`, and the second parameter is `JClass`. +/// This struct simply encapsulates the two common parameters into a single struct for simplicity. +#[repr(C)] +pub struct EnvParam<'a> { + env: JNIEnv<'a>, + class: JClass<'a>, +} + +impl<'a> Deref for EnvParam<'a> { + type Target = JNIEnv<'a>; + + fn deref(&self) -> &Self::Target { + &self.env + } +} + +impl<'a> DerefMut for EnvParam<'a> { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.env + } +} + +impl<'a> EnvParam<'a> { + pub fn get_class(&self) -> &JClass<'a> { + &self.class + } +} + +fn execute_and_catch<'env, F, Ret>(mut env: EnvParam<'env>, inner: F) -> Ret +where + F: FnOnce(&mut EnvParam<'env>) -> Result, + Ret: Default + 'env, +{ + match rw_catch_unwind(std::panic::AssertUnwindSafe(|| inner(&mut env))) { + Ok(Ok(ret)) => ret, + Ok(Err(e)) => { + match e { + BindingError::Jni { + error: jni::errors::Error::JavaException, + backtrace, + } => { + tracing::error!("get JavaException thrown from: {:?}", backtrace); + // the exception is already thrown. No need to throw again + } + _ => { + env.throw(format!("get error while processing: {:?}", e)) + .expect("should be able to throw"); + } + } + Ret::default() + } + Err(e) => { + env.throw(format!("panic while processing: {:?}", e)) + .expect("should be able to throw"); + Ret::default() + } + } +} + +pub enum JavaBindingRowInner { + Keyed(KeyedRow), + StreamChunk(StreamChunkRow), +} +#[derive(Default)] +pub struct JavaClassMethodCache { + big_decimal_ctor: OnceLock<(GlobalRef, JMethodID)>, + timestamp_ctor: OnceLock<(GlobalRef, JMethodID)>, + + date_ctor: OnceLock<(GlobalRef, JStaticMethodID)>, + time_ctor: OnceLock<(GlobalRef, JStaticMethodID)>, +} + +pub struct JavaBindingRow { + inner: JavaBindingRowInner, + class_cache: Arc, +} + +impl JavaBindingRow { + fn with_stream_chunk( + underlying: StreamChunkRow, + class_cache: Arc, + ) -> Self { + Self { + inner: JavaBindingRowInner::StreamChunk(underlying), + class_cache, + } + } + + fn with_keyed(underlying: KeyedRow, class_cache: Arc) -> Self { + Self { + inner: JavaBindingRowInner::Keyed(underlying), + class_cache, + } + } + + fn as_keyed(&self) -> &KeyedRow { + match &self.inner { + JavaBindingRowInner::Keyed(r) => r, + _ => unreachable!("can only call as_keyed for KeyedRow"), + } + } + + fn as_stream_chunk(&self) -> &StreamChunkRow { + match &self.inner { + JavaBindingRowInner::StreamChunk(r) => r, + _ => unreachable!("can only call as_stream_chunk for StreamChunkRow"), + } + } +} + +impl Deref for JavaBindingRow { + type Target = OwnedRow; + + fn deref(&self) -> &Self::Target { + match &self.inner { + JavaBindingRowInner::Keyed(r) => r.row(), + JavaBindingRowInner::StreamChunk(r) => r.row(), + } + } +} + +#[no_mangle] +pub extern "system" fn Java_com_risingwave_java_binding_Binding_vnodeCount( + _env: EnvParam<'_>, +) -> jint { + VirtualNode::COUNT as jint +} + +#[no_mangle] +pub extern "system" fn Java_com_risingwave_java_binding_Binding_hummockIteratorNew<'a>( + env: EnvParam<'a>, + read_plan: JByteArray<'a>, +) -> Pointer<'static, HummockJavaBindingIterator> { + execute_and_catch(env, move |env| { + let read_plan = Message::decode(to_guarded_slice(&read_plan, env)?.deref())?; + let iter = RUNTIME.block_on(HummockJavaBindingIterator::new(read_plan))?; + Ok(iter.into()) + }) +} + +#[no_mangle] +pub extern "system" fn Java_com_risingwave_java_binding_Binding_hummockIteratorNext<'a>( + env: EnvParam<'a>, + mut pointer: Pointer<'a, HummockJavaBindingIterator>, +) -> Pointer<'static, JavaBindingRow> { + execute_and_catch(env, move |_env| { + let iter = pointer.as_mut(); + match RUNTIME.block_on(iter.next())? { + None => Ok(Pointer::null()), + Some(row) => Ok(JavaBindingRow::with_keyed(row, iter.class_cache.clone()).into()), + } + }) +} + +#[no_mangle] +pub extern "system" fn Java_com_risingwave_java_binding_Binding_hummockIteratorClose( + _env: EnvParam<'_>, + pointer: Pointer<'_, HummockJavaBindingIterator>, +) { + pointer.drop(); +} + +#[no_mangle] +pub extern "system" fn Java_com_risingwave_java_binding_Binding_streamChunkIteratorNew<'a>( + env: EnvParam<'a>, + stream_chunk_payload: JByteArray<'a>, +) -> Pointer<'static, StreamChunkIterator> { + execute_and_catch(env, move |env| { + let prost_stream_chumk = + Message::decode(to_guarded_slice(&stream_chunk_payload, env)?.deref())?; + let iter = StreamChunkIterator::new(StreamChunk::from_protobuf(&prost_stream_chumk)?); + Ok(iter.into()) + }) +} + +#[no_mangle] +pub extern "system" fn Java_com_risingwave_java_binding_Binding_streamChunkIteratorFromPretty< + 'a, +>( + env: EnvParam<'a>, + str: JString<'a>, +) -> Pointer<'static, StreamChunkIterator> { + execute_and_catch(env, move |env: &mut EnvParam<'_>| { + let iter = StreamChunkIterator::new(StreamChunk::from_pretty( + env.get_string(&str) + .expect("cannot get java string") + .to_str() + .unwrap(), + )); + Ok(iter.into()) + }) +} + +#[no_mangle] +pub extern "system" fn Java_com_risingwave_java_binding_Binding_streamChunkIteratorNext<'a>( + env: EnvParam<'a>, + mut pointer: Pointer<'a, StreamChunkIterator>, +) -> Pointer<'static, JavaBindingRow> { + execute_and_catch(env, move |_env| { + let iter = pointer.as_mut(); + match iter.next() { + None => Ok(Pointer::null()), + Some(row) => { + Ok(JavaBindingRow::with_stream_chunk(row, iter.class_cache.clone()).into()) + } + } + }) +} + +#[no_mangle] +pub extern "system" fn Java_com_risingwave_java_binding_Binding_streamChunkIteratorClose( + _env: EnvParam<'_>, + pointer: Pointer<'_, StreamChunkIterator>, +) { + pointer.drop(); +} + +#[no_mangle] +pub extern "system" fn Java_com_risingwave_java_binding_Binding_rowGetKey<'a>( + env: EnvParam<'a>, + pointer: Pointer<'a, JavaBindingRow>, +) -> JByteArray<'a> { + execute_and_catch(env, move |env: &mut EnvParam<'_>| { + Ok(env.byte_array_from_slice(pointer.as_ref().as_keyed().key())?) + }) +} + +#[no_mangle] +pub extern "system" fn Java_com_risingwave_java_binding_Binding_rowGetOp<'a>( + env: EnvParam<'a>, + pointer: Pointer<'a, JavaBindingRow>, +) -> jint { + execute_and_catch(env, move |_env| { + Ok(pointer.as_ref().as_stream_chunk().op() as jint) + }) +} + +#[no_mangle] +pub extern "system" fn Java_com_risingwave_java_binding_Binding_rowIsNull<'a>( + env: EnvParam<'a>, + pointer: Pointer<'a, JavaBindingRow>, + idx: jint, +) -> jboolean { + execute_and_catch(env, move |_env| { + Ok(pointer.as_ref().datum_at(idx as usize).is_none() as jboolean) + }) +} + +#[no_mangle] +pub extern "system" fn Java_com_risingwave_java_binding_Binding_rowGetInt16Value<'a>( + env: EnvParam<'a>, + pointer: Pointer<'a, JavaBindingRow>, + idx: jint, +) -> jshort { + execute_and_catch(env, move |_env| { + Ok(pointer + .as_ref() + .datum_at(idx as usize) + .unwrap() + .into_int16()) + }) +} + +#[no_mangle] +pub extern "system" fn Java_com_risingwave_java_binding_Binding_rowGetInt32Value<'a>( + env: EnvParam<'a>, + pointer: Pointer<'a, JavaBindingRow>, + idx: jint, +) -> jint { + execute_and_catch(env, move |_env| { + Ok(pointer + .as_ref() + .datum_at(idx as usize) + .unwrap() + .into_int32()) + }) +} + +#[no_mangle] +pub extern "system" fn Java_com_risingwave_java_binding_Binding_rowGetInt64Value<'a>( + env: EnvParam<'a>, + pointer: Pointer<'a, JavaBindingRow>, + idx: jint, +) -> jlong { + execute_and_catch(env, move |_env| { + Ok(pointer + .as_ref() + .datum_at(idx as usize) + .unwrap() + .into_int64()) + }) +} + +#[no_mangle] +pub extern "system" fn Java_com_risingwave_java_binding_Binding_rowGetFloatValue<'a>( + env: EnvParam<'a>, + pointer: Pointer<'a, JavaBindingRow>, + idx: jint, +) -> jfloat { + execute_and_catch(env, move |_env| { + Ok(pointer + .as_ref() + .datum_at(idx as usize) + .unwrap() + .into_float32() + .into()) + }) +} + +#[no_mangle] +pub extern "system" fn Java_com_risingwave_java_binding_Binding_rowGetDoubleValue<'a>( + env: EnvParam<'a>, + pointer: Pointer<'a, JavaBindingRow>, + idx: jint, +) -> jdouble { + execute_and_catch(env, move |_env| { + Ok(pointer + .as_ref() + .datum_at(idx as usize) + .unwrap() + .into_float64() + .into()) + }) +} + +#[no_mangle] +pub extern "system" fn Java_com_risingwave_java_binding_Binding_rowGetBooleanValue<'a>( + env: EnvParam<'a>, + pointer: Pointer<'a, JavaBindingRow>, + idx: jint, +) -> jboolean { + execute_and_catch(env, move |_env| { + Ok(pointer.as_ref().datum_at(idx as usize).unwrap().into_bool() as jboolean) + }) +} + +#[no_mangle] +pub extern "system" fn Java_com_risingwave_java_binding_Binding_rowGetStringValue<'a>( + env: EnvParam<'a>, + pointer: Pointer<'a, JavaBindingRow>, + idx: jint, +) -> JString<'a> { + execute_and_catch(env, move |env: &mut EnvParam<'a>| { + Ok(env.new_string(pointer.as_ref().datum_at(idx as usize).unwrap().into_utf8())?) + }) +} + +#[no_mangle] +pub extern "system" fn Java_com_risingwave_java_binding_Binding_rowGetIntervalValue<'a>( + env: EnvParam<'a>, + pointer: Pointer<'a, JavaBindingRow>, + idx: jint, +) -> JString<'a> { + execute_and_catch(env, move |env: &mut EnvParam<'a>| { + let interval = pointer + .as_ref() + .datum_at(idx as usize) + .unwrap() + .into_interval() + .as_iso_8601(); + Ok(env.new_string(interval)?) + }) +} + +#[no_mangle] +pub extern "system" fn Java_com_risingwave_java_binding_Binding_rowGetJsonbValue<'a>( + env: EnvParam<'a>, + pointer: Pointer<'a, JavaBindingRow>, + idx: jint, +) -> JString<'a> { + execute_and_catch(env, move |env: &mut EnvParam<'_>| { + let jsonb = pointer + .as_ref() + .datum_at(idx as usize) + .unwrap() + .into_jsonb() + .to_string(); + Ok(env.new_string(jsonb)?) + }) +} + +#[no_mangle] +pub extern "system" fn Java_com_risingwave_java_binding_Binding_rowGetTimestampValue<'a>( + env: EnvParam<'a>, + pointer: Pointer<'a, JavaBindingRow>, + idx: jint, +) -> JObject<'a> { + execute_and_catch(env, move |env: &mut EnvParam<'_>| { + let scalar_value = pointer.as_ref().datum_at(idx as usize).unwrap(); + let millis = match scalar_value { + // supports sinking rw timestamptz to mysql timestamp + ScalarRefImpl::Timestamptz(tz) => tz.timestamp_millis(), + ScalarRefImpl::Timestamp(ts) => ts.0.timestamp_millis(), + _ => panic!("expect timestamp or timestamptz"), + }; + let (ts_class_ref, constructor) = pointer + .as_ref() + .class_cache + .timestamp_ctor + .get_or_try_init(|| { + let cls = env.find_class("java/sql/Timestamp")?; + let init_method = env.get_method_id(&cls, "", "(J)V")?; + Ok::<_, jni::errors::Error>((env.new_global_ref(cls)?, init_method)) + })?; + unsafe { + let ts_class = <&JClass<'_>>::from(ts_class_ref.as_obj()); + let date_obj = + env.new_object_unchecked(ts_class, *constructor, &[jvalue { j: millis }])?; + Ok(date_obj) + } + }) +} + +#[no_mangle] +pub extern "system" fn Java_com_risingwave_java_binding_Binding_rowGetDecimalValue<'a>( + env: EnvParam<'a>, + pointer: Pointer<'a, JavaBindingRow>, + idx: jint, +) -> JObject<'a> { + execute_and_catch(env, move |env: &mut EnvParam<'_>| { + let value = pointer + .as_ref() + .datum_at(idx as usize) + .unwrap() + .into_decimal() + .to_string(); + let string_value = env.new_string(value)?; + let (decimal_class_ref, constructor) = pointer + .as_ref() + .class_cache + .big_decimal_ctor + .get_or_try_init(|| { + let cls = env.find_class("java/math/BigDecimal")?; + let init_method = env.get_method_id(&cls, "", "(Ljava/lang/String;)V")?; + Ok::<_, jni::errors::Error>((env.new_global_ref(cls)?, init_method)) + })?; + unsafe { + let decimal_class = <&JClass<'_>>::from(decimal_class_ref.as_obj()); + let date_obj = env.new_object_unchecked( + decimal_class, + *constructor, + &[jvalue { + l: string_value.into_raw(), + }], + )?; + Ok(date_obj) + } + }) +} + +#[no_mangle] +pub extern "system" fn Java_com_risingwave_java_binding_Binding_rowGetDateValue<'a>( + env: EnvParam<'a>, + pointer: Pointer<'a, JavaBindingRow>, + idx: jint, +) -> JObject<'a> { + execute_and_catch(env, move |env: &mut EnvParam<'_>| { + let value = pointer + .as_ref() + .datum_at(idx as usize) + .unwrap() + .into_date() + .0 + .to_string(); + + let string_value = env.new_string(value)?; + let (class_ref, constructor) = + pointer.as_ref().class_cache.date_ctor.get_or_try_init(|| { + let cls = env.find_class("java/sql/Date")?; + let init_method = env.get_static_method_id( + &cls, + "valueOf", + "(Ljava/lang/String;)Ljava/sql/Date;", + )?; + Ok::<_, jni::errors::Error>((env.new_global_ref(cls)?, init_method)) + })?; + unsafe { + let JValueOwned::Object(date_obj) = env.call_static_method_unchecked( + <&JClass<'_>>::from(class_ref.as_obj()), + *constructor, + ReturnType::Object, + &[jvalue { + l: string_value.into_raw(), + }], + )? + else { + return Err(BindingError::from(jni::errors::Error::MethodNotFound { + name: "valueOf".to_string(), + sig: "(Ljava/lang/String;)Ljava/sql/Date;".into(), + })); + }; + Ok(date_obj) + } + }) +} + +#[no_mangle] +pub extern "system" fn Java_com_risingwave_java_binding_Binding_rowGetTimeValue<'a>( + env: EnvParam<'a>, + pointer: Pointer<'a, JavaBindingRow>, + idx: jint, +) -> JObject<'a> { + execute_and_catch(env, move |env: &mut EnvParam<'_>| { + let value = pointer + .as_ref() + .datum_at(idx as usize) + .unwrap() + .into_time() + .0 + .to_string(); + + let string_value = env.new_string(value)?; + let (class_ref, constructor) = + pointer.as_ref().class_cache.time_ctor.get_or_try_init(|| { + let cls = env.find_class("java/sql/Time")?; + let init_method = env.get_static_method_id( + &cls, + "valueOf", + "(Ljava/lang/String;)Ljava/sql/Time;", + )?; + Ok::<_, jni::errors::Error>((env.new_global_ref(cls)?, init_method)) + })?; + unsafe { + let class = <&JClass<'_>>::from(class_ref.as_obj()); + match env.call_static_method_unchecked( + class, + *constructor, + ReturnType::Object, + &[jvalue { + l: string_value.into_raw(), + }], + )? { + JValueGen::Object(obj) => Ok(obj), + _ => Err(BindingError::from(jni::errors::Error::MethodNotFound { + name: "valueOf".to_string(), + sig: "(Ljava/lang/String;)Ljava/sql/Time;".into(), + })), + } + } + }) +} + +#[no_mangle] +pub extern "system" fn Java_com_risingwave_java_binding_Binding_rowGetByteaValue<'a>( + env: EnvParam<'a>, + pointer: Pointer<'a, JavaBindingRow>, + idx: jint, +) -> JByteArray<'a> { + execute_and_catch(env, move |env: &mut EnvParam<'_>| { + let bytes = pointer + .as_ref() + .datum_at(idx as usize) + .unwrap() + .into_bytea(); + Ok(env.byte_array_from_slice(bytes)?) + }) +} + +#[no_mangle] +pub extern "system" fn Java_com_risingwave_java_binding_Binding_rowGetArrayValue<'a>( + env: EnvParam<'a>, + pointer: Pointer<'a, JavaBindingRow>, + idx: jint, + class: JClass<'a>, +) -> JObject<'a> { + execute_and_catch(env, move |env: &mut EnvParam<'_>| { + let elems = pointer + .as_ref() + .datum_at(idx as usize) + .unwrap() + .into_list() + .iter(); + + // convert the Rust elements to a Java object array (Object[]) + let jarray = env.new_object_array(elems.len() as jsize, &class, JObject::null())?; + + for (i, ele) in elems.enumerate() { + let index = i as jsize; + match ele { + None => env.set_object_array_element(&jarray, i as jsize, JObject::null())?, + Some(val) => match val { + ScalarRefImpl::Int16(v) => { + let obj = env.call_static_method( + &class, + "valueOf", + "(S)Ljava.lang.Short;", + &[JValue::from(v as jshort)], + )?; + if let JValueOwned::Object(o) = obj { + env.set_object_array_element(&jarray, index, &o)? + } + } + ScalarRefImpl::Int32(v) => { + let obj = env.call_static_method( + &class, + "valueOf", + "(I)Ljava.lang.Integer;", + &[JValue::from(v as jint)], + )?; + if let JValueOwned::Object(o) = obj { + env.set_object_array_element(&jarray, index, &o)? + } + } + ScalarRefImpl::Int64(v) => { + let obj = env.call_static_method( + &class, + "valueOf", + "(J)Ljava.lang.Long;", + &[JValue::from(v as jlong)], + )?; + if let JValueOwned::Object(o) = obj { + env.set_object_array_element(&jarray, index, &o)? + } + } + ScalarRefImpl::Float32(v) => { + let obj = env.call_static_method( + &class, + "valueOf", + "(F)Ljava/lang/Float;", + &[JValue::from(v.into_inner() as jfloat)], + )?; + if let JValueOwned::Object(o) = obj { + env.set_object_array_element(&jarray, index, &o)? + } + } + ScalarRefImpl::Float64(v) => { + let obj = env.call_static_method( + &class, + "valueOf", + "(D)Ljava/lang/Double;", + &[JValue::from(v.into_inner() as jdouble)], + )?; + if let JValueOwned::Object(o) = obj { + env.set_object_array_element(&jarray, index, &o)? + } + } + ScalarRefImpl::Utf8(v) => { + let obj = env.new_string(v)?; + env.set_object_array_element(&jarray, index, obj)? + } + _ => env.set_object_array_element(&jarray, index, JObject::null())?, + }, + } + } + let output = unsafe { JObject::from_raw(jarray.into_raw()) }; + Ok(output) + }) +} + +#[no_mangle] +pub extern "system" fn Java_com_risingwave_java_binding_Binding_rowClose<'a>( + _env: EnvParam<'a>, + pointer: Pointer<'a, JavaBindingRow>, +) { + pointer.drop() +} + +#[inline(never)] +pub fn run_this_func_to_get_valid_ptr_from_java_binding() {} + +/// Send messages to the channel received by `CdcSplitReader`. +/// If msg is null, just check whether the channel is closed. +/// Return true if sending is successful, otherwise, return false so that caller can stop +/// gracefully. +#[no_mangle] +pub extern "system" fn Java_com_risingwave_java_binding_Binding_sendCdcSourceMsgToChannel<'a>( + env: EnvParam<'a>, + channel: Pointer<'a, GetEventStreamJniSender>, + msg: JByteArray<'a>, +) -> jboolean { + execute_and_catch(env, move |env| { + // If msg is null means just check whether channel is closed. + if msg.is_null() { + if channel.as_ref().is_closed() { + // Drop channel as well. + channel.drop(); + return Ok(JNI_FALSE); + } else { + return Ok(JNI_TRUE); + } + } + + let get_event_stream_response: GetEventStreamResponse = + Message::decode(to_guarded_slice(&msg, env)?.deref())?; + + tracing::debug!("before send"); + match channel.as_ref().blocking_send(get_event_stream_response) { + Ok(_) => { + tracing::debug!("send successfully"); + Ok(JNI_TRUE) + } + Err(e) => { + channel.drop(); + tracing::debug!("send error. {:?}", e); + Ok(JNI_FALSE) + } + } + }) +} + +#[cfg(test)] +mod tests { + use risingwave_common::types::{DataType, Timestamptz}; + use risingwave_expr::vector_op::cast::literal_parsing; + + /// make sure that the [`ScalarRefImpl::Int64`] received by + /// [`Java_com_risingwave_java_binding_Binding_rowGetTimestampValue`] + /// is of type [`DataType::Timestamptz`] stored in microseconds + #[test] + fn test_timestamptz_to_i64() { + assert_eq!( + literal_parsing(&DataType::Timestamptz, "2023-06-01 09:45:00+08:00").unwrap(), + Timestamptz::from_micros(1_685_583_900_000_000).into() + ); + } +} diff --git a/src/java_binding/src/stream_chunk_iterator.rs b/src/jni_core/src/stream_chunk_iterator.rs similarity index 100% rename from src/java_binding/src/stream_chunk_iterator.rs rename to src/jni_core/src/stream_chunk_iterator.rs diff --git a/src/meta/Cargo.toml b/src/meta/Cargo.toml index 61fb6118fdb3b..b17013decb5e9 100644 --- a/src/meta/Cargo.toml +++ b/src/meta/Cargo.toml @@ -49,6 +49,7 @@ risingwave_common = { workspace = true } risingwave_common_service = { workspace = true } risingwave_connector = { workspace = true } risingwave_hummock_sdk = { workspace = true } +risingwave_jni_core = { workspace = true } risingwave_object_store = { workspace = true } risingwave_pb = { workspace = true } risingwave_rpc_client = { workspace = true } diff --git a/src/meta/src/lib.rs b/src/meta/src/lib.rs index 53a962f86c5b5..b54beb27d3475 100644 --- a/src/meta/src/lib.rs +++ b/src/meta/src/lib.rs @@ -192,6 +192,7 @@ use std::future::Future; use std::pin::Pin; use risingwave_common::config::{load_config, MetaBackend, RwConfig}; +use risingwave_jni_core::jvm_runtime; use tracing::info; /// Start meta node @@ -246,6 +247,9 @@ pub fn start(opts: MetaNodeOpts) -> Pin + Send>> { dashboard_addr, ui_path: opts.dashboard_ui_path, }; + + jvm_runtime::register_native_method_for_jvm(); + let (mut join_handle, leader_lost_handle, shutdown_send) = rpc_serve( add_info, backend, From df9fc72d447b1ab094e551d1b3fc5e228ef06017 Mon Sep 17 00:00:00 2001 From: Dylan Chen Date: Mon, 4 Sep 2023 20:10:13 +0800 Subject: [PATCH 17/23] refine --- src/connector/src/source/cdc/source/reader.rs | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/connector/src/source/cdc/source/reader.rs b/src/connector/src/source/cdc/source/reader.rs index 76e7a1c675d97..b1e995af3f641 100644 --- a/src/connector/src/source/cdc/source/reader.rs +++ b/src/connector/src/source/cdc/source/reader.rs @@ -99,8 +99,9 @@ impl SplitReader for CdcSplitReader { } impl CdcSplitReader { + /// RPC version which is deprecated #[try_stream(boxed, ok = Vec, error = anyhow::Error)] - async fn ____into_data_stream(self) { + async fn into_data_stream_rpc_version(self) { let cdc_client = self.source_ctx.connector_client.clone().ok_or_else(|| { anyhow!("connector node endpoint not specified or unable to connect to connector node") })?; @@ -158,6 +159,7 @@ impl CdcSplitReader { } } + /// JNI version #[try_stream(boxed, ok = Vec, error = anyhow::Error)] async fn into_data_stream(self) { // rewrite the hostname and port for the split From 071b96d4dddfcbca0000997b9a7659d65bdb9376 Mon Sep 17 00:00:00 2001 From: Dylan Chen Date: Tue, 5 Sep 2023 23:48:11 +0800 Subject: [PATCH 18/23] better naming --- .../connector/ConnectorService.java | 3 -- .../source/core/DbzCdcEngineRunner.java | 2 +- ...eHandler.java => JniDbzSourceHandler.java} | 31 ++++++++++++++++--- .../source/core/SourceHandlerFactory.java | 20 ------------ src/connector/src/source/cdc/source/reader.rs | 16 ++++++---- src/jni_core/src/jvm_runtime.rs | 4 +-- 6 files changed, 40 insertions(+), 36 deletions(-) rename java/connector-node/risingwave-connector-service/src/main/java/com/risingwave/connector/source/core/{JniSourceHandler.java => JniDbzSourceHandler.java} (75%) diff --git a/java/connector-node/risingwave-connector-service/src/main/java/com/risingwave/connector/ConnectorService.java b/java/connector-node/risingwave-connector-service/src/main/java/com/risingwave/connector/ConnectorService.java index 4c68ff45b88d0..810fd9d0f26f4 100644 --- a/java/connector-node/risingwave-connector-service/src/main/java/com/risingwave/connector/ConnectorService.java +++ b/java/connector-node/risingwave-connector-service/src/main/java/com/risingwave/connector/ConnectorService.java @@ -37,9 +37,6 @@ public static void main(String[] args) throws Exception { CommandLineParser parser = new DefaultParser(); CommandLine cmd = parser.parse(options, args); - java.lang.Thread.currentThread() - .setContextClassLoader(java.lang.ClassLoader.getSystemClassLoader()); - // Quoted from the debezium document: // > Your application should always properly stop the engine to ensure graceful and complete // > shutdown and that each source record is sent to the application exactly one time. diff --git a/java/connector-node/risingwave-connector-service/src/main/java/com/risingwave/connector/source/core/DbzCdcEngineRunner.java b/java/connector-node/risingwave-connector-service/src/main/java/com/risingwave/connector/source/core/DbzCdcEngineRunner.java index f69f5774f3d84..ba9511b02303b 100644 --- a/java/connector-node/risingwave-connector-service/src/main/java/com/risingwave/connector/source/core/DbzCdcEngineRunner.java +++ b/java/connector-node/risingwave-connector-service/src/main/java/com/risingwave/connector/source/core/DbzCdcEngineRunner.java @@ -70,7 +70,7 @@ public static CdcEngineRunner newCdcEngineRunner( return runner; } - public static CdcEngineRunner newCdcEngineRunnerV2(DbzConnectorConfig config) { + public static CdcEngineRunner newCdcEngineRunner(DbzConnectorConfig config) { DbzCdcEngineRunner runner = null; try { var sourceId = config.getSourceId(); diff --git a/java/connector-node/risingwave-connector-service/src/main/java/com/risingwave/connector/source/core/JniSourceHandler.java b/java/connector-node/risingwave-connector-service/src/main/java/com/risingwave/connector/source/core/JniDbzSourceHandler.java similarity index 75% rename from java/connector-node/risingwave-connector-service/src/main/java/com/risingwave/connector/source/core/JniSourceHandler.java rename to java/connector-node/risingwave-connector-service/src/main/java/com/risingwave/connector/source/core/JniDbzSourceHandler.java index c4fe4be63fc89..7b800f690bd2e 100644 --- a/java/connector-node/risingwave-connector-service/src/main/java/com/risingwave/connector/source/core/JniSourceHandler.java +++ b/java/connector-node/risingwave-connector-service/src/main/java/com/risingwave/connector/source/core/JniDbzSourceHandler.java @@ -15,23 +15,46 @@ package com.risingwave.connector.source.core; import com.risingwave.connector.api.source.CdcEngineRunner; +import com.risingwave.connector.api.source.SourceTypeE; import com.risingwave.connector.source.common.DbzConnectorConfig; import com.risingwave.java.binding.Binding; import com.risingwave.metrics.ConnectorNodeMetrics; +import java.util.HashMap; +import java.util.Map; import java.util.concurrent.TimeUnit; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -/** handler for starting a debezium source connectors */ -public class JniSourceHandler { +/** handler for starting a debezium source connectors for jni */ +public class JniDbzSourceHandler { static final Logger LOG = LoggerFactory.getLogger(DbzSourceHandler.class); private final DbzConnectorConfig config; - public JniSourceHandler(DbzConnectorConfig config) { + public JniDbzSourceHandler(DbzConnectorConfig config) { this.config = config; } + public static void runJniDbzSourceThread( + SourceTypeE source, + long sourceId, + String startOffset, + Map userProps, + boolean snapshotDone, + long channelPtr) { + // For jni.rs + java.lang.Thread.currentThread() + .setContextClassLoader(java.lang.ClassLoader.getSystemClassLoader()); + // userProps extracted from grpc request, underlying implementation is UnmodifiableMap + Map mutableUserProps = new HashMap<>(userProps); + mutableUserProps.put("source.id", Long.toString(sourceId)); + var config = + new DbzConnectorConfig( + source, sourceId, startOffset, mutableUserProps, snapshotDone); + JniDbzSourceHandler handler = new JniDbzSourceHandler(config); + handler.start(channelPtr); + } + class OnReadyHandler implements Runnable { private final CdcEngineRunner runner; private final long channelPtr; @@ -79,7 +102,7 @@ public void run() { } public void start(long channelPtr) { - var runner = DbzCdcEngineRunner.newCdcEngineRunnerV2(config); + var runner = DbzCdcEngineRunner.newCdcEngineRunner(config); if (runner == null) { return; } diff --git a/java/connector-node/risingwave-connector-service/src/main/java/com/risingwave/connector/source/core/SourceHandlerFactory.java b/java/connector-node/risingwave-connector-service/src/main/java/com/risingwave/connector/source/core/SourceHandlerFactory.java index ec51d3123e448..b60bcb4f7da5a 100644 --- a/java/connector-node/risingwave-connector-service/src/main/java/com/risingwave/connector/source/core/SourceHandlerFactory.java +++ b/java/connector-node/risingwave-connector-service/src/main/java/com/risingwave/connector/source/core/SourceHandlerFactory.java @@ -39,24 +39,4 @@ public static SourceHandler createSourceHandler( source, sourceId, startOffset, mutableUserProps, snapshotDone); return new DbzSourceHandler(config); } - - public static void startJniSourceHandler( - SourceTypeE source, - long sourceId, - String startOffset, - Map userProps, - boolean snapshotDone, - long channelPtr) { - // For jni.rs - java.lang.Thread.currentThread() - .setContextClassLoader(java.lang.ClassLoader.getSystemClassLoader()); - // userProps extracted from grpc request, underlying implementation is UnmodifiableMap - Map mutableUserProps = new HashMap<>(userProps); - mutableUserProps.put("source.id", Long.toString(sourceId)); - var config = - new DbzConnectorConfig( - source, sourceId, startOffset, mutableUserProps, snapshotDone); - JniSourceHandler handler = new JniSourceHandler(config); - handler.start(channelPtr); - } } diff --git a/src/connector/src/source/cdc/source/reader.rs b/src/connector/src/source/cdc/source/reader.rs index b1e995af3f641..55cd39214eba8 100644 --- a/src/connector/src/source/cdc/source/reader.rs +++ b/src/connector/src/source/cdc/source/reader.rs @@ -196,7 +196,7 @@ impl CdcSplitReader { "(I)Lcom/risingwave/proto/ConnectorServiceProto$SourceType;", &[JValue::from(source_type as i32)], ) - .inspect_err(|e| tracing::error!("{:?}", e)) + .inspect_err(|e| tracing::error!("jni call error: {:?}", e)) .unwrap(); let st = env.call_static_method( @@ -204,7 +204,9 @@ impl CdcSplitReader { "valueOf", "(Lcom/risingwave/proto/ConnectorServiceProto$SourceType;)Lcom/risingwave/connector/api/source/SourceTypeE;", &[(&st).into()] - ).inspect_err(|e| tracing::error!("{:?}", e)).unwrap(); + ) + .inspect_err(|e| tracing::error!("jni call error: {:?}", e)) + .unwrap(); let start_offset = match self.start_offset { Some(start_offset) => { @@ -227,7 +229,7 @@ impl CdcSplitReader { "(Ljava/lang/Object;Ljava/lang/Object;)Ljava/lang/Object;", &args, ) - .inspect_err(|e| tracing::error!("{:?}", e)) + .inspect_err(|e| tracing::error!("jni call error: {:?}", e)) .unwrap(); } @@ -235,11 +237,13 @@ impl CdcSplitReader { let channel_ptr = JValue::from(channel_ptr); let _ = env.call_static_method( - "com/risingwave/connector/source/core/SourceHandlerFactory", - "startJniSourceHandler", + "com/risingwave/connector/source/core/JniDbzSourceHandler", + "runJniDbzSourceThread", "(Lcom/risingwave/connector/api/source/SourceTypeE;JLjava/lang/String;Ljava/util/Map;ZJ)V", &[(&st).into(), JValue::from(self.source_id as i64), (&start_offset).into(), JValue::Object(&java_map), JValue::from(self.snapshot_done), channel_ptr] - ).inspect_err(|e| tracing::error!("{:?}", e)).unwrap(); + ) + .inspect_err(|e| tracing::error!("jni call error: {:?}", e)) + .unwrap(); }); while let Some(GetEventStreamResponse { events, .. }) = rx.recv().await { diff --git a/src/jni_core/src/jvm_runtime.rs b/src/jni_core/src/jvm_runtime.rs index 93f56195646b1..8e6734b2171cf 100644 --- a/src/jni_core/src/jvm_runtime.rs +++ b/src/jni_core/src/jvm_runtime.rs @@ -87,10 +87,10 @@ pub fn register_native_method_for_jvm() { let mut env = JVM .attach_current_thread() - .inspect_err(|e| tracing::error!("{:?}", e)) + .inspect_err(|e| tracing::error!("jni call error: {:?}", e)) .unwrap(); - // FIXME: remove this function would cause segment fault. + // FIXME: remove this function might cause segment fault. run_this_func_to_get_valid_ptr_from_java_binding(); let binding_class = env From d9562ff5d7b74c6a5459cdbbc673ff6eeea63ac5 Mon Sep 17 00:00:00 2001 From: Dylan Chen Date: Wed, 6 Sep 2023 10:52:19 +0800 Subject: [PATCH 19/23] remove jvm related codes --- Cargo.lock | 15 - java/com_risingwave_java_binding_Binding.h | 8 - .../source/core/DbzCdcEngineRunner.java | 27 -- .../source/core/JniDbzSourceHandler.java | 128 --------- .../com/risingwave/java/binding/Binding.java | 9 +- src/compute/Cargo.toml | 1 - src/compute/src/lib.rs | 3 - src/connector/Cargo.toml | 1 - src/connector/src/source/cdc/source/reader.rs | 115 +++----- src/jni_core/src/jvm_runtime.rs | 271 ------------------ src/jni_core/src/lib.rs | 48 +--- src/meta/src/lib.rs | 3 - src/workspace-hack/Cargo.toml | 2 - 13 files changed, 38 insertions(+), 593 deletions(-) delete mode 100644 java/connector-node/risingwave-connector-service/src/main/java/com/risingwave/connector/source/core/JniDbzSourceHandler.java delete mode 100644 src/jni_core/src/jvm_runtime.rs diff --git a/Cargo.lock b/Cargo.lock index 38949830591f9..c5b427716937d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3852,16 +3852,6 @@ version = "1.0.9" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "af150ab688ff2122fcef229be89cb50dd66af9e01a4ff320cc137eecc9bacc38" -[[package]] -name = "java-locator" -version = "0.1.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "90003f2fd9c52f212c21d8520f1128da0080bad6fff16b68fe6e7f2f0c3780c2" -dependencies = [ - "glob", - "lazy_static", -] - [[package]] name = "jni" version = "0.21.1" @@ -3871,9 +3861,7 @@ dependencies = [ "cesu8", "cfg-if", "combine", - "java-locator", "jni-sys", - "libloading", "log", "thiserror", "walkdir", @@ -6827,7 +6815,6 @@ dependencies = [ "futures-async-stream", "hyper", "itertools 0.11.0", - "jni", "madsim-tokio", "madsim-tonic", "maplit", @@ -6891,7 +6878,6 @@ dependencies = [ "hyper-tls", "icelake", "itertools 0.11.0", - "jni", "jsonschema-transpiler", "madsim-rdkafka", "madsim-tokio", @@ -10128,7 +10114,6 @@ dependencies = [ "hyper", "indexmap 1.9.3", "itertools 0.10.5", - "jni", "lexical-core", "lexical-parse-float", "lexical-parse-integer", diff --git a/java/com_risingwave_java_binding_Binding.h b/java/com_risingwave_java_binding_Binding.h index a3e9aa95ec84e..c2c027ed22b58 100644 --- a/java/com_risingwave_java_binding_Binding.h +++ b/java/com_risingwave_java_binding_Binding.h @@ -223,14 +223,6 @@ JNIEXPORT void JNICALL Java_com_risingwave_java_binding_Binding_streamChunkItera JNIEXPORT jlong JNICALL Java_com_risingwave_java_binding_Binding_streamChunkIteratorFromPretty (JNIEnv *, jclass, jstring); -/* - * Class: com_risingwave_java_binding_Binding - * Method: sendCdcSourceMsgToChannel - * Signature: (J[B)Z - */ -JNIEXPORT jboolean JNICALL Java_com_risingwave_java_binding_Binding_sendCdcSourceMsgToChannel - (JNIEnv *, jclass, jlong, jbyteArray); - #ifdef __cplusplus } #endif diff --git a/java/connector-node/risingwave-connector-service/src/main/java/com/risingwave/connector/source/core/DbzCdcEngineRunner.java b/java/connector-node/risingwave-connector-service/src/main/java/com/risingwave/connector/source/core/DbzCdcEngineRunner.java index ba9511b02303b..e9fef6e869c04 100644 --- a/java/connector-node/risingwave-connector-service/src/main/java/com/risingwave/connector/source/core/DbzCdcEngineRunner.java +++ b/java/connector-node/risingwave-connector-service/src/main/java/com/risingwave/connector/source/core/DbzCdcEngineRunner.java @@ -70,33 +70,6 @@ public static CdcEngineRunner newCdcEngineRunner( return runner; } - public static CdcEngineRunner newCdcEngineRunner(DbzConnectorConfig config) { - DbzCdcEngineRunner runner = null; - try { - var sourceId = config.getSourceId(); - var engine = - new DbzCdcEngine( - config.getSourceId(), - config.getResolvedDebeziumProps(), - (success, message, error) -> { - if (!success) { - LOG.error( - "engine#{} terminated with error. message: {}", - sourceId, - message, - error); - } else { - LOG.info("engine#{} stopped normally. {}", sourceId, message); - } - }); - - runner = new DbzCdcEngineRunner(engine); - } catch (Exception e) { - LOG.error("failed to create the CDC engine", e); - } - return runner; - } - /** Start to run the cdc engine */ public void start() { if (isRunning()) { diff --git a/java/connector-node/risingwave-connector-service/src/main/java/com/risingwave/connector/source/core/JniDbzSourceHandler.java b/java/connector-node/risingwave-connector-service/src/main/java/com/risingwave/connector/source/core/JniDbzSourceHandler.java deleted file mode 100644 index 7b800f690bd2e..0000000000000 --- a/java/connector-node/risingwave-connector-service/src/main/java/com/risingwave/connector/source/core/JniDbzSourceHandler.java +++ /dev/null @@ -1,128 +0,0 @@ -// Copyright 2023 RisingWave Labs -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package com.risingwave.connector.source.core; - -import com.risingwave.connector.api.source.CdcEngineRunner; -import com.risingwave.connector.api.source.SourceTypeE; -import com.risingwave.connector.source.common.DbzConnectorConfig; -import com.risingwave.java.binding.Binding; -import com.risingwave.metrics.ConnectorNodeMetrics; -import java.util.HashMap; -import java.util.Map; -import java.util.concurrent.TimeUnit; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -/** handler for starting a debezium source connectors for jni */ -public class JniDbzSourceHandler { - static final Logger LOG = LoggerFactory.getLogger(DbzSourceHandler.class); - - private final DbzConnectorConfig config; - - public JniDbzSourceHandler(DbzConnectorConfig config) { - this.config = config; - } - - public static void runJniDbzSourceThread( - SourceTypeE source, - long sourceId, - String startOffset, - Map userProps, - boolean snapshotDone, - long channelPtr) { - // For jni.rs - java.lang.Thread.currentThread() - .setContextClassLoader(java.lang.ClassLoader.getSystemClassLoader()); - // userProps extracted from grpc request, underlying implementation is UnmodifiableMap - Map mutableUserProps = new HashMap<>(userProps); - mutableUserProps.put("source.id", Long.toString(sourceId)); - var config = - new DbzConnectorConfig( - source, sourceId, startOffset, mutableUserProps, snapshotDone); - JniDbzSourceHandler handler = new JniDbzSourceHandler(config); - handler.start(channelPtr); - } - - class OnReadyHandler implements Runnable { - private final CdcEngineRunner runner; - private final long channelPtr; - - public OnReadyHandler(CdcEngineRunner runner, long channelPtr) { - this.runner = runner; - this.channelPtr = channelPtr; - } - - @Override - public void run() { - while (runner.isRunning()) { - try { - // check whether the send queue has room for new messages - // Thread will block on the channel to get output from engine - var resp = - runner.getEngine().getOutputChannel().poll(500, TimeUnit.MILLISECONDS); - boolean success; - if (resp != null) { - ConnectorNodeMetrics.incSourceRowsReceived( - config.getSourceType().toString(), - String.valueOf(config.getSourceId()), - resp.getEventsCount()); - LOG.info( - "Engine#{}: emit one chunk {} events to network ", - config.getSourceId(), - resp.getEventsCount()); - success = Binding.sendCdcSourceMsgToChannel(channelPtr, resp.toByteArray()); - } else { - // If resp is null means just check whether channel is closed. - success = Binding.sendCdcSourceMsgToChannel(channelPtr, null); - } - if (!success) { - LOG.info( - "Engine#{}: JNI sender broken detected, stop the engine", - config.getSourceId()); - runner.stop(); - return; - } - } catch (Throwable e) { - LOG.error("Poll engine output channel fail. ", e); - } - } - } - } - - public void start(long channelPtr) { - var runner = DbzCdcEngineRunner.newCdcEngineRunner(config); - if (runner == null) { - return; - } - - try { - // Start the engine - runner.start(); - LOG.info("Start consuming events of table {}", config.getSourceId()); - - final OnReadyHandler onReadyHandler = new OnReadyHandler(runner, channelPtr); - - onReadyHandler.run(); - - } catch (Throwable t) { - LOG.error("Cdc engine failed.", t); - try { - runner.stop(); - } catch (Exception e) { - LOG.warn("Failed to stop Engine#{}", config.getSourceId(), e); - } - } - } -} diff --git a/java/java-binding/src/main/java/com/risingwave/java/binding/Binding.java b/java/java-binding/src/main/java/com/risingwave/java/binding/Binding.java index 4a79033b147a8..3f05768ec74b8 100644 --- a/java/java-binding/src/main/java/com/risingwave/java/binding/Binding.java +++ b/java/java-binding/src/main/java/com/risingwave/java/binding/Binding.java @@ -17,13 +17,8 @@ import io.questdb.jar.jni.JarJniLoader; public class Binding { - private static final boolean IS_EMBEDDED_CONNECTOR = - Boolean.parseBoolean(System.getProperty("is_embedded_connector")); - static { - if (!IS_EMBEDDED_CONNECTOR) { - JarJniLoader.loadLib(Binding.class, "/risingwave/jni", "risingwave_java_binding"); - } + JarJniLoader.loadLib(Binding.class, "/risingwave/jni", "risingwave_java_binding"); } public static native int vnodeCount(); @@ -89,6 +84,4 @@ public class Binding { static native void streamChunkIteratorClose(long pointer); static native long streamChunkIteratorFromPretty(String str); - - public static native boolean sendCdcSourceMsgToChannel(long channelPtr, byte[] msg); } diff --git a/src/compute/Cargo.toml b/src/compute/Cargo.toml index 8845dd0d80de2..8f59d2c21f6cb 100644 --- a/src/compute/Cargo.toml +++ b/src/compute/Cargo.toml @@ -24,7 +24,6 @@ futures = { version = "0.3", default-features = false, features = ["alloc"] } futures-async-stream = { workspace = true } hyper = "0.14" itertools = "0.11" -jni = { version = "0.21.1", features = ["invocation"] } maplit = "1.0.2" pprof = { version = "0.12", features = ["flamegraph"] } prometheus = { version = "0.13" } diff --git a/src/compute/src/lib.rs b/src/compute/src/lib.rs index 39153653fe374..6ca06f3253815 100644 --- a/src/compute/src/lib.rs +++ b/src/compute/src/lib.rs @@ -39,7 +39,6 @@ use clap::{Parser, ValueEnum}; use risingwave_common::config::{AsyncStackTraceOption, OverrideConfig}; use risingwave_common::util::resource_util::cpu::total_cpu_available; use risingwave_common::util::resource_util::memory::total_memory_available_bytes; -use risingwave_jni_core::jvm_runtime; use serde::{Deserialize, Serialize}; /// Command-line arguments for compute-node. @@ -214,8 +213,6 @@ pub fn start(opts: ComputeNodeOpts) -> Pin + Send>> .unwrap(); tracing::info!("advertise addr is {}", advertise_addr); - jvm_runtime::register_native_method_for_jvm(); - let (join_handle_vec, _shutdown_send) = compute_node_serve(listen_addr, advertise_addr, opts).await; diff --git a/src/connector/Cargo.toml b/src/connector/Cargo.toml index c1a748e34f0d2..c7945bf63b239 100644 --- a/src/connector/Cargo.toml +++ b/src/connector/Cargo.toml @@ -53,7 +53,6 @@ hyper = "0.14" hyper-tls = "0.5" icelake = { workspace = true } itertools = "0.11" -jni = { version = "0.21.1", features = ["invocation"] } jsonschema-transpiler = "1.10.0" maplit = "1.0.2" moka = { version = "0.11", features = ["future"] } diff --git a/src/connector/src/source/cdc/source/reader.rs b/src/connector/src/source/cdc/source/reader.rs index 370b478318372..32e0c27ca2856 100644 --- a/src/connector/src/source/cdc/source/reader.rs +++ b/src/connector/src/source/cdc/source/reader.rs @@ -16,13 +16,10 @@ use std::str::FromStr; use anyhow::{anyhow, Result}; use async_trait::async_trait; +use futures::pin_mut; use futures_async_stream::try_stream; -use jni::objects::{JObject, JValue}; use risingwave_common::util::addr::HostAddr; -use risingwave_jni_core::jvm_runtime::JVM; -use risingwave_jni_core::GetEventStreamJniSender; use risingwave_pb::connector_service::GetEventStreamResponse; -use tokio::sync::mpsc; use crate::parser::ParserConfig; use crate::source::base::SourceMessage; @@ -98,8 +95,12 @@ impl SplitReader for CdcSplitReader { } impl CommonSplitReader for CdcSplitReader { - #[try_stream(boxed, ok = Vec, error = anyhow::Error)] + #[try_stream(ok = Vec, error = anyhow::Error)] async fn into_data_stream(self) { + let cdc_client = self.source_ctx.connector_client.clone().ok_or_else(|| { + anyhow!("connector node endpoint not specified or unable to connect to connector node") + })?; + // rewrite the hostname and port for the split let mut properties = self.conn_props.props.clone(); @@ -118,82 +119,38 @@ impl CommonSplitReader for CdcSplitReader { properties.insert("table.name".into(), table_name); } - let (tx, mut rx) = mpsc::channel(1024); - - let tx: Box = Box::new(tx); - - let source_type = self.conn_props.get_source_type_pb()?; - - std::thread::spawn(move || { - let mut env = JVM.attach_current_thread_as_daemon().unwrap(); - - let st = env - .call_static_method( - "com/risingwave/proto/ConnectorServiceProto$SourceType", - "forNumber", - "(I)Lcom/risingwave/proto/ConnectorServiceProto$SourceType;", - &[JValue::from(source_type as i32)], - ) - .inspect_err(|e| tracing::error!("jni call error: {:?}", e)) - .unwrap(); - - let st = env.call_static_method( - "com/risingwave/connector/api/source/SourceTypeE", - "valueOf", - "(Lcom/risingwave/proto/ConnectorServiceProto$SourceType;)Lcom/risingwave/connector/api/source/SourceTypeE;", - &[(&st).into()] - ) - .inspect_err(|e| tracing::error!("jni call error: {:?}", e)) - .unwrap(); - - let start_offset = match self.start_offset { - Some(start_offset) => { - let start_offset = env.new_string(start_offset).unwrap(); - env.call_method(start_offset, "toString", "()Ljava/lang/String;", &[]) - .unwrap() + let cdc_stream = cdc_client + .start_source_stream( + self.source_id, + self.conn_props.get_source_type_pb()?, + self.start_offset, + properties, + self.snapshot_done, + ) + .await + .inspect_err(|err| tracing::error!("connector node start stream error: {}", err))?; + pin_mut!(cdc_stream); + #[for_await] + for event_res in cdc_stream { + match event_res { + Ok(GetEventStreamResponse { events, .. }) => { + if events.is_empty() { + continue; + } + let mut msgs = Vec::with_capacity(events.len()); + for event in events { + msgs.push(SourceMessage::from(event)); + } + yield msgs; + } + Err(e) => { + return Err(anyhow!( + "Cdc service error: code {}, msg {}", + e.code(), + e.message() + )) } - None => jni::objects::JValueGen::Object(JObject::null()), - }; - - let java_map = env.new_object("java/util/HashMap", "()V", &[]).unwrap(); - - for (key, value) in &properties { - let key = env.new_string(key.to_string()).unwrap(); - let value = env.new_string(value.to_string()).unwrap(); - let args = [JValue::Object(&key), JValue::Object(&value)]; - env.call_method( - &java_map, - "put", - "(Ljava/lang/Object;Ljava/lang/Object;)Ljava/lang/Object;", - &args, - ) - .inspect_err(|e| tracing::error!("jni call error: {:?}", e)) - .unwrap(); - } - - let channel_ptr = Box::into_raw(tx) as i64; - let channel_ptr = JValue::from(channel_ptr); - - let _ = env.call_static_method( - "com/risingwave/connector/source/core/JniDbzSourceHandler", - "runJniDbzSourceThread", - "(Lcom/risingwave/connector/api/source/SourceTypeE;JLjava/lang/String;Ljava/util/Map;ZJ)V", - &[(&st).into(), JValue::from(self.source_id as i64), (&start_offset).into(), JValue::Object(&java_map), JValue::from(self.snapshot_done), channel_ptr] - ) - .inspect_err(|e| tracing::error!("jni call error: {:?}", e)) - .unwrap(); - }); - - while let Some(GetEventStreamResponse { events, .. }) = rx.recv().await { - tracing::debug!("receive events {:?}", events.len()); - if events.is_empty() { - continue; - } - let mut msgs = Vec::with_capacity(events.len()); - for event in events { - msgs.push(SourceMessage::from(event)); } - yield msgs; } } } diff --git a/src/jni_core/src/jvm_runtime.rs b/src/jni_core/src/jvm_runtime.rs deleted file mode 100644 index 8e6734b2171cf..0000000000000 --- a/src/jni_core/src/jvm_runtime.rs +++ /dev/null @@ -1,271 +0,0 @@ -// Copyright 2023 RisingWave Labs -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -use core::option::Option::Some; -use core::result::Result::{Err, Ok}; -use std::ffi::c_void; -use std::fs; -use std::path::Path; -use std::sync::atomic::{AtomicBool, Ordering}; -use std::sync::{Arc, LazyLock}; - -use jni::strings::JNIString; -use jni::{InitArgsBuilder, JNIVersion, JavaVM, NativeMethod}; - -use crate::run_this_func_to_get_valid_ptr_from_java_binding; - -pub static JVM: LazyLock> = LazyLock::new(|| { - let libs_path = ".risingwave/bin/connector-node/libs/"; - - let dir = Path::new(libs_path); - - if !dir.is_dir() { - panic!("{} is not a directory", libs_path); - } - - let mut class_vec = vec![]; - - if let Ok(entries) = fs::read_dir(dir) { - for entry in entries.flatten() { - if let Some(name) = entry.path().file_name() { - class_vec.push(libs_path.to_owned() + name.to_str().to_owned().unwrap()); - } - } - } else { - panic!("failed to read directory {}", libs_path); - } - - // Build the VM properties - let jvm_args = InitArgsBuilder::new() - // Pass the JNI API version (default is 8) - .version(JNIVersion::V8) - // You can additionally pass any JVM options (standard, like a system property, - // or VM-specific). - // Here we enable some extra JNI checks useful during development - // .option("-Xcheck:jni") - .option("-ea") - .option("-Dis_embedded_connector=true") - .option(format!("-Djava.class.path={}", class_vec.join(":"))) - // TODO: remove it - // .option("-agentlib:jdwp=transport=dt_socket,server=y,suspend=n,address=9111") - .build() - .unwrap(); - - // Create a new VM - let jvm = match JavaVM::new(jvm_args) { - Err(err) => { - panic!("{:?}", err) - } - Ok(jvm) => jvm, - }; - - tracing::info!("initialize JVM successfully"); - Arc::new(jvm) -}); - -static REGISTERED: AtomicBool = AtomicBool::new(false); - -pub fn register_native_method_for_jvm() { - // Ensure registering only once. - if REGISTERED - .compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst) - .is_err() - { - return; - } - - let mut env = JVM - .attach_current_thread() - .inspect_err(|e| tracing::error!("jni call error: {:?}", e)) - .unwrap(); - - // FIXME: remove this function might cause segment fault. - run_this_func_to_get_valid_ptr_from_java_binding(); - - let binding_class = env - .find_class("com/risingwave/java/binding/Binding") - .unwrap(); - env.register_native_methods( - binding_class, - &[ - NativeMethod { - name: JNIString::from("vnodeCount"), - sig: JNIString::from("()I"), - fn_ptr: crate::Java_com_risingwave_java_binding_Binding_vnodeCount as *mut c_void, - }, - NativeMethod { - name: JNIString::from("hummockIteratorNew"), - sig: JNIString::from("([B)J"), - fn_ptr: crate::Java_com_risingwave_java_binding_Binding_hummockIteratorNew - as *mut c_void, - }, - NativeMethod { - name: JNIString::from("hummockIteratorNext"), - sig: JNIString::from("(J)J"), - fn_ptr: crate::Java_com_risingwave_java_binding_Binding_hummockIteratorNext - as *mut c_void, - }, - NativeMethod { - name: JNIString::from("hummockIteratorClose"), - sig: JNIString::from("(J)V"), - fn_ptr: crate::Java_com_risingwave_java_binding_Binding_hummockIteratorClose - as *mut c_void, - }, - NativeMethod { - name: JNIString::from("rowGetKey"), - sig: JNIString::from("(J)[B"), - fn_ptr: crate::Java_com_risingwave_java_binding_Binding_rowGetKey as *mut c_void, - }, - NativeMethod { - name: JNIString::from("rowGetOp"), - sig: JNIString::from("(J)I"), - fn_ptr: crate::Java_com_risingwave_java_binding_Binding_rowGetOp as *mut c_void, - }, - NativeMethod { - name: JNIString::from("rowIsNull"), - sig: JNIString::from("(JI)Z"), - fn_ptr: crate::Java_com_risingwave_java_binding_Binding_rowIsNull as *mut c_void, - }, - NativeMethod { - name: JNIString::from("rowGetInt16Value"), - sig: JNIString::from("(JI)S"), - fn_ptr: crate::Java_com_risingwave_java_binding_Binding_rowGetInt16Value - as *mut c_void, - }, - NativeMethod { - name: JNIString::from("rowGetInt32Value"), - sig: JNIString::from("(JI)I"), - fn_ptr: crate::Java_com_risingwave_java_binding_Binding_rowGetInt32Value - as *mut c_void, - }, - NativeMethod { - name: JNIString::from("rowGetInt64Value"), - sig: JNIString::from("(JI)J"), - fn_ptr: crate::Java_com_risingwave_java_binding_Binding_rowGetInt64Value - as *mut c_void, - }, - NativeMethod { - name: JNIString::from("rowGetFloatValue"), - sig: JNIString::from("(JI)F"), - fn_ptr: crate::Java_com_risingwave_java_binding_Binding_rowGetFloatValue - as *mut c_void, - }, - NativeMethod { - name: JNIString::from("rowGetDoubleValue"), - sig: JNIString::from("(JI)D"), - fn_ptr: crate::Java_com_risingwave_java_binding_Binding_rowGetDoubleValue - as *mut c_void, - }, - NativeMethod { - name: JNIString::from("rowGetBooleanValue"), - sig: JNIString::from("(JI)Z"), - fn_ptr: crate::Java_com_risingwave_java_binding_Binding_rowGetBooleanValue - as *mut c_void, - }, - NativeMethod { - name: JNIString::from("rowGetStringValue"), - sig: JNIString::from("(JI)Ljava/lang/String;"), - fn_ptr: crate::Java_com_risingwave_java_binding_Binding_rowGetStringValue - as *mut c_void, - }, - NativeMethod { - name: JNIString::from("rowGetTimestampValue"), - sig: JNIString::from("(JI)Ljava/sql/Timestamp;"), - fn_ptr: crate::Java_com_risingwave_java_binding_Binding_rowGetTimestampValue - as *mut c_void, - }, - NativeMethod { - name: JNIString::from("rowGetDecimalValue"), - sig: JNIString::from("(JI)Ljava/math/BigDecimal;"), - fn_ptr: crate::Java_com_risingwave_java_binding_Binding_rowGetDecimalValue - as *mut c_void, - }, - NativeMethod { - name: JNIString::from("rowGetTimeValue"), - sig: JNIString::from("(JI)Ljava/sql/Time;"), - fn_ptr: crate::Java_com_risingwave_java_binding_Binding_rowGetTimeValue - as *mut c_void, - }, - NativeMethod { - name: JNIString::from("rowGetDateValue"), - sig: JNIString::from("(JI)Ljava/sql/Date;"), - fn_ptr: crate::Java_com_risingwave_java_binding_Binding_rowGetDateValue - as *mut c_void, - }, - NativeMethod { - name: JNIString::from("rowGetIntervalValue"), - sig: JNIString::from("(JI)Ljava/lang/String;"), - fn_ptr: crate::Java_com_risingwave_java_binding_Binding_rowGetIntervalValue - as *mut c_void, - }, - NativeMethod { - name: JNIString::from("rowGetJsonbValue"), - sig: JNIString::from("(JI)Ljava/lang/String;"), - fn_ptr: crate::Java_com_risingwave_java_binding_Binding_rowGetJsonbValue - as *mut c_void, - }, - NativeMethod { - name: JNIString::from("rowGetByteaValue"), - sig: JNIString::from("(JI)[B"), - fn_ptr: crate::Java_com_risingwave_java_binding_Binding_rowGetByteaValue - as *mut c_void, - }, - NativeMethod { - name: JNIString::from("rowGetArrayValue"), - sig: JNIString::from("(JILjava/lang/Class;)Ljava/lang/Object;"), - fn_ptr: crate::Java_com_risingwave_java_binding_Binding_rowGetArrayValue - as *mut c_void, - }, - NativeMethod { - name: JNIString::from("rowClose"), - sig: JNIString::from("(J)V"), - fn_ptr: crate::Java_com_risingwave_java_binding_Binding_rowClose as *mut c_void, - }, - NativeMethod { - name: JNIString::from("streamChunkIteratorNew"), - sig: JNIString::from("([B)J"), - fn_ptr: crate::Java_com_risingwave_java_binding_Binding_streamChunkIteratorNew - as *mut c_void, - }, - NativeMethod { - name: JNIString::from("streamChunkIteratorNext"), - sig: JNIString::from("(J)J"), - fn_ptr: crate::Java_com_risingwave_java_binding_Binding_streamChunkIteratorNext - as *mut c_void, - }, - NativeMethod { - name: JNIString::from("streamChunkIteratorClose"), - sig: JNIString::from("(J)V"), - fn_ptr: crate::Java_com_risingwave_java_binding_Binding_streamChunkIteratorClose - as *mut c_void, - }, - NativeMethod { - name: JNIString::from("streamChunkIteratorFromPretty"), - sig: JNIString::from("(Ljava/lang/String;)J"), - fn_ptr: - crate::Java_com_risingwave_java_binding_Binding_streamChunkIteratorFromPretty - as *mut c_void, - }, - NativeMethod { - name: JNIString::from("sendCdcSourceMsgToChannel"), - sig: JNIString::from("(J[B)Z"), - fn_ptr: crate::Java_com_risingwave_java_binding_Binding_sendCdcSourceMsgToChannel - as *mut c_void, - }, - ], - ) - .unwrap(); - - tracing::info!("register native methods for jvm successfully"); -} diff --git a/src/jni_core/src/lib.rs b/src/jni_core/src/lib.rs index ce4da4da0534f..62625e14d21fe 100644 --- a/src/jni_core/src/lib.rs +++ b/src/jni_core/src/lib.rs @@ -20,7 +20,6 @@ #![feature(result_option_inspect)] pub mod hummock_iterator; -pub mod jvm_runtime; pub mod stream_chunk_iterator; use std::backtrace::Backtrace; @@ -35,9 +34,7 @@ use jni::objects::{ JValue, JValueGen, JValueOwned, ReleaseMode, }; use jni::signature::ReturnType; -use jni::sys::{ - jboolean, jbyte, jdouble, jfloat, jint, jlong, jshort, jsize, jvalue, JNI_FALSE, JNI_TRUE, -}; +use jni::sys::{jboolean, jbyte, jdouble, jfloat, jint, jlong, jshort, jsize, jvalue}; use jni::JNIEnv; use prost::{DecodeError, Message}; use risingwave_common::array::{ArrayError, StreamChunk}; @@ -826,49 +823,6 @@ pub extern "system" fn Java_com_risingwave_java_binding_Binding_rowClose<'a>( pointer.drop() } -#[inline(never)] -pub fn run_this_func_to_get_valid_ptr_from_java_binding() {} - -/// Send messages to the channel received by `CdcSplitReader`. -/// If msg is null, just check whether the channel is closed. -/// Return true if sending is successful, otherwise, return false so that caller can stop -/// gracefully. -#[no_mangle] -pub extern "system" fn Java_com_risingwave_java_binding_Binding_sendCdcSourceMsgToChannel<'a>( - env: EnvParam<'a>, - channel: Pointer<'a, GetEventStreamJniSender>, - msg: JByteArray<'a>, -) -> jboolean { - execute_and_catch(env, move |env| { - // If msg is null means just check whether channel is closed. - if msg.is_null() { - if channel.as_ref().is_closed() { - // Drop channel as well. - channel.drop(); - return Ok(JNI_FALSE); - } else { - return Ok(JNI_TRUE); - } - } - - let get_event_stream_response: GetEventStreamResponse = - Message::decode(to_guarded_slice(&msg, env)?.deref())?; - - tracing::debug!("before send"); - match channel.as_ref().blocking_send(get_event_stream_response) { - Ok(_) => { - tracing::debug!("send successfully"); - Ok(JNI_TRUE) - } - Err(e) => { - channel.drop(); - tracing::debug!("send error. {:?}", e); - Ok(JNI_FALSE) - } - } - }) -} - #[cfg(test)] mod tests { use risingwave_common::types::{DataType, Timestamptz}; diff --git a/src/meta/src/lib.rs b/src/meta/src/lib.rs index b54beb27d3475..e812ee5217296 100644 --- a/src/meta/src/lib.rs +++ b/src/meta/src/lib.rs @@ -192,7 +192,6 @@ use std::future::Future; use std::pin::Pin; use risingwave_common::config::{load_config, MetaBackend, RwConfig}; -use risingwave_jni_core::jvm_runtime; use tracing::info; /// Start meta node @@ -248,8 +247,6 @@ pub fn start(opts: MetaNodeOpts) -> Pin + Send>> { ui_path: opts.dashboard_ui_path, }; - jvm_runtime::register_native_method_for_jvm(); - let (mut join_handle, leader_lost_handle, shutdown_send) = rpc_serve( add_info, backend, diff --git a/src/workspace-hack/Cargo.toml b/src/workspace-hack/Cargo.toml index ee1214540c502..a63eaa9abf35c 100644 --- a/src/workspace-hack/Cargo.toml +++ b/src/workspace-hack/Cargo.toml @@ -49,7 +49,6 @@ hashbrown-5ef9efb8ec2df382 = { package = "hashbrown", version = "0.12", features hyper = { version = "0.14", features = ["full"] } indexmap = { version = "1", default-features = false, features = ["std"] } itertools = { version = "0.10" } -jni = { version = "0.21", features = ["invocation"] } lexical-core = { version = "0.8", features = ["format"] } lexical-parse-float = { version = "0.8", default-features = false, features = ["format", "std"] } lexical-parse-integer = { version = "0.8", default-features = false, features = ["format", "std"] } @@ -146,7 +145,6 @@ hashbrown-5ef9efb8ec2df382 = { package = "hashbrown", version = "0.12", features hyper = { version = "0.14", features = ["full"] } indexmap = { version = "1", default-features = false, features = ["std"] } itertools = { version = "0.10" } -jni = { version = "0.21", features = ["invocation"] } lexical-core = { version = "0.8", features = ["format"] } lexical-parse-float = { version = "0.8", default-features = false, features = ["format", "std"] } lexical-parse-integer = { version = "0.8", default-features = false, features = ["format", "std"] } From 4f4bb7f3765ceff669fcd22ad293870a1fa7de83 Mon Sep 17 00:00:00 2001 From: Dylan Chen Date: Wed, 6 Sep 2023 11:13:11 +0800 Subject: [PATCH 20/23] remove bin in jni_core --- .../data-chunk-payload-convert-generator.rs | 97 ------------------- .../src/bin/data-chunk-payload-generator.rs | 92 ------------------ 2 files changed, 189 deletions(-) delete mode 100644 src/jni_core/src/bin/data-chunk-payload-convert-generator.rs delete mode 100644 src/jni_core/src/bin/data-chunk-payload-generator.rs diff --git a/src/jni_core/src/bin/data-chunk-payload-convert-generator.rs b/src/jni_core/src/bin/data-chunk-payload-convert-generator.rs deleted file mode 100644 index 75d5afb8d27dc..0000000000000 --- a/src/jni_core/src/bin/data-chunk-payload-convert-generator.rs +++ /dev/null @@ -1,97 +0,0 @@ -// Copyright 2023 RisingWave Labs -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -use std::env; -use std::fs::File; -use std::io::{Read, Write}; -use std::process::exit; - -use prost::Message; -use risingwave_common::array::{Op, StreamChunk}; -use risingwave_common::row::OwnedRow; -use risingwave_common::types::{DataType, ScalarImpl}; -use risingwave_common::util::chunk_coalesce::DataChunkBuilder; -use serde::{Deserialize, Serialize}; - -#[derive(Debug, Deserialize, Serialize)] -struct Line { - id: u32, - name: String, -} - -#[derive(Debug, Deserialize, Serialize)] -struct Operation { - op_type: u32, - line: Line, -} - -fn convert_to_op(value: u32) -> Option { - match value { - 1 => Some(Op::Insert), - 2 => Some(Op::Delete), - 3 => Some(Op::UpdateInsert), - 4 => Some(Op::UpdateDelete), - _ => None, - } -} - -fn main() { - let args: Vec = env::args().collect(); - if args.len() <= 1 { - println!("No input file name"); - exit(0); - } - // Read the JSON file - let mut file = File::open(&args[1]).expect("Failed to open file"); - let mut contents = String::new(); - file.read_to_string(&mut contents) - .expect("Failed to read file"); - - // Parse the JSON data - let data: Vec> = serde_json::from_str(&contents).expect("Failed to parse JSON"); - - let data_types: Vec<_> = vec![DataType::Int32, DataType::Varchar]; - - // Access the data - let mut row_count = 0; - for operations in &data { - row_count += operations.len(); - } - let mut ops = Vec::with_capacity(row_count); - let mut builder = DataChunkBuilder::new(data_types, row_count * 1024); - - for operations in data { - for operation in operations { - let mut row_value = Vec::with_capacity(10); - row_value.push(Some(ScalarImpl::Int32(operation.line.id as i32))); - row_value.push(Some(ScalarImpl::Utf8(operation.line.name.into_boxed_str()))); - let _ = builder.append_one_row(OwnedRow::new(row_value)); - // let op: Op = unsafe { ::std::mem::transmute(operation.op_type as u8) }; - if let Some(op) = convert_to_op(operation.op_type) { - ops.push(op); - } else { - println!("Invalid value"); - } - } - } - - let data_chunk = builder.consume_all().expect("should not be empty"); - let stream_chunk = StreamChunk::from_parts(ops, data_chunk); - let prost_stream_chunk: risingwave_pb::data::StreamChunk = stream_chunk.to_protobuf(); - - let payload = Message::encode_to_vec(&prost_stream_chunk); - - std::io::stdout() - .write_all(&payload) - .expect("should success"); -} diff --git a/src/jni_core/src/bin/data-chunk-payload-generator.rs b/src/jni_core/src/bin/data-chunk-payload-generator.rs deleted file mode 100644 index f4d0dd6ff16f9..0000000000000 --- a/src/jni_core/src/bin/data-chunk-payload-generator.rs +++ /dev/null @@ -1,92 +0,0 @@ -// Copyright 2023 RisingWave Labs -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -use std::env; -use std::io::Write; - -use prost::Message; -use risingwave_common::array::{Op, StreamChunk}; -use risingwave_common::row::OwnedRow; -use risingwave_common::types::{DataType, ScalarImpl, Timestamp, F32, F64}; -use risingwave_common::util::chunk_coalesce::DataChunkBuilder; - -fn build_row(index: usize) -> OwnedRow { - let mut row_value = Vec::with_capacity(10); - row_value.push(Some(ScalarImpl::Int16(index as i16))); - row_value.push(Some(ScalarImpl::Int32(index as i32))); - row_value.push(Some(ScalarImpl::Int64(index as i64))); - row_value.push(Some(ScalarImpl::Float32(F32::from(index as f32)))); - row_value.push(Some(ScalarImpl::Float64(F64::from(index as f64)))); - row_value.push(Some(ScalarImpl::Bool(index % 3 == 0))); - row_value.push(Some(ScalarImpl::Utf8( - format!("{}", index).repeat((index % 10) + 1).into(), - ))); - row_value.push(Some(ScalarImpl::Timestamp( - Timestamp::from_timestamp_uncheck(index as _, 0), - ))); - row_value.push(Some(ScalarImpl::Decimal(index.into()))); - row_value.push(if index % 5 == 0 { - None - } else { - Some(ScalarImpl::Int64(index as i64)) - }); - - OwnedRow::new(row_value) -} - -fn main() { - let args: Vec = env::args().collect(); - let mut flag = false; - let mut row_count: usize = 30000; - if args.len() > 1 { - flag = true; - row_count = args[1].parse().unwrap(); - } - let data_types = vec![ - DataType::Int16, - DataType::Int32, - DataType::Int64, - DataType::Float32, - DataType::Float64, - DataType::Boolean, - DataType::Varchar, - DataType::Timestamp, - DataType::Decimal, - DataType::Int64, - ]; - let mut ops = Vec::with_capacity(row_count); - let mut builder = DataChunkBuilder::new(data_types, row_count * 1024); - for i in 0..row_count { - assert!( - builder.append_one_row(build_row(i)).is_none(), - "should not finish" - ); - // In unit test, it does not support delete operation - if flag || i % 2 == 0 { - ops.push(Op::Insert); - } else { - ops.push(Op::Delete); - } - } - - let data_chunk = builder.consume_all().expect("should not be empty"); - let stream_chunk = StreamChunk::from_parts(ops, data_chunk); - let prost_stream_chunk = stream_chunk.to_protobuf(); - - let payload = Message::encode_to_vec(&prost_stream_chunk); - - std::io::stdout() - .write_all(&payload) - .expect("should success"); -} From 018b33239208b64359da1d88c712a1855643fb88 Mon Sep 17 00:00:00 2001 From: Dylan Chen Date: Wed, 6 Sep 2023 11:52:40 +0800 Subject: [PATCH 21/23] refine cargo toml --- Cargo.lock | 18 ------------------ src/compute/Cargo.toml | 1 - src/connector/Cargo.toml | 1 - src/java_binding/Cargo.toml | 23 ----------------------- src/java_binding/src/lib.rs | 1 - src/meta/Cargo.toml | 1 - 6 files changed, 45 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index c5b427716937d..dfaa858f59b15 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -6826,7 +6826,6 @@ dependencies = [ "risingwave_common_service", "risingwave_connector", "risingwave_hummock_sdk", - "risingwave_jni_core", "risingwave_pb", "risingwave_rpc_client", "risingwave_source", @@ -6898,7 +6897,6 @@ dependencies = [ "rand", "reqwest", "risingwave_common", - "risingwave_jni_core", "risingwave_pb", "risingwave_rpc_client", "rust_decimal", @@ -7162,23 +7160,8 @@ dependencies = [ name = "risingwave_java_binding" version = "0.1.0" dependencies = [ - "bytes", - "futures", - "itertools 0.11.0", - "jni", - "madsim-tokio", - "prost", - "risingwave_common", "risingwave_expr", - "risingwave_hummock_sdk", "risingwave_jni_core", - "risingwave_object_store", - "risingwave_pb", - "risingwave_storage", - "serde", - "serde_json", - "thiserror", - "tracing", ] [[package]] @@ -7246,7 +7229,6 @@ dependencies = [ "risingwave_common_service", "risingwave_connector", "risingwave_hummock_sdk", - "risingwave_jni_core", "risingwave_object_store", "risingwave_pb", "risingwave_rpc_client", diff --git a/src/compute/Cargo.toml b/src/compute/Cargo.toml index 8f59d2c21f6cb..70aaf895e7b73 100644 --- a/src/compute/Cargo.toml +++ b/src/compute/Cargo.toml @@ -32,7 +32,6 @@ risingwave_common = { workspace = true } risingwave_common_service = { workspace = true } risingwave_connector = { workspace = true } risingwave_hummock_sdk = { workspace = true } -risingwave_jni_core = { workspace = true } risingwave_pb = { workspace = true } risingwave_rpc_client = { workspace = true } risingwave_source = { workspace = true } diff --git a/src/connector/Cargo.toml b/src/connector/Cargo.toml index c7945bf63b239..4188291614311 100644 --- a/src/connector/Cargo.toml +++ b/src/connector/Cargo.toml @@ -82,7 +82,6 @@ rdkafka = { workspace = true, features = [ ] } reqwest = { version = "0.11", features = ["json"] } risingwave_common = { workspace = true } -risingwave_jni_core = { workspace = true } risingwave_pb = { workspace = true } risingwave_rpc_client = { workspace = true } rust_decimal = "1" diff --git a/src/java_binding/Cargo.toml b/src/java_binding/Cargo.toml index 3eafacc84a49c..996aa8683a8cd 100644 --- a/src/java_binding/Cargo.toml +++ b/src/java_binding/Cargo.toml @@ -10,30 +10,7 @@ ignored = ["workspace-hack"] normal = ["workspace-hack"] [dependencies] -bytes = "1" -futures = { version = "0.3", default-features = false, features = ["alloc"] } -itertools = "0.11" -jni = "0.21.1" -prost = "0.11" -risingwave_common = { workspace = true } -risingwave_hummock_sdk = { workspace = true } risingwave_jni_core = { workspace = true } -risingwave_object_store = { workspace = true } -risingwave_pb = { workspace = true } -risingwave_storage = { workspace = true } -serde = { version = "1.0", features = ["derive"] } -serde_json = "1.0" -thiserror = "1" -tokio = { version = "0.2", package = "madsim-tokio", features = [ - "fs", - "rt", - "rt-multi-thread", - "sync", - "macros", - "time", - "signal", -] } -tracing = "0.1" [dev-dependencies] risingwave_expr = { workspace = true } diff --git a/src/java_binding/src/lib.rs b/src/java_binding/src/lib.rs index 28c8f0419aa86..6ccc450c09d5a 100644 --- a/src/java_binding/src/lib.rs +++ b/src/java_binding/src/lib.rs @@ -17,6 +17,5 @@ #![feature(lazy_cell)] #![feature(once_cell_try)] #![feature(type_alias_impl_trait)] -#![feature(result_option_inspect)] pub use risingwave_jni_core::*; diff --git a/src/meta/Cargo.toml b/src/meta/Cargo.toml index b17013decb5e9..61fb6118fdb3b 100644 --- a/src/meta/Cargo.toml +++ b/src/meta/Cargo.toml @@ -49,7 +49,6 @@ risingwave_common = { workspace = true } risingwave_common_service = { workspace = true } risingwave_connector = { workspace = true } risingwave_hummock_sdk = { workspace = true } -risingwave_jni_core = { workspace = true } risingwave_object_store = { workspace = true } risingwave_pb = { workspace = true } risingwave_rpc_client = { workspace = true } From 25e1a8a7a165cd352aa7d0fb63184f66eb55b48e Mon Sep 17 00:00:00 2001 From: Dylan Chen Date: Wed, 6 Sep 2023 12:04:26 +0800 Subject: [PATCH 22/23] add necessary dependencies to java_binding --- Cargo.lock | 5 +++++ src/java_binding/Cargo.toml | 6 ++++++ 2 files changed, 11 insertions(+) diff --git a/Cargo.lock b/Cargo.lock index dfaa858f59b15..61894661accde 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -7160,8 +7160,13 @@ dependencies = [ name = "risingwave_java_binding" version = "0.1.0" dependencies = [ + "prost", + "risingwave_common", "risingwave_expr", "risingwave_jni_core", + "risingwave_pb", + "serde", + "serde_json", ] [[package]] diff --git a/src/java_binding/Cargo.toml b/src/java_binding/Cargo.toml index 996aa8683a8cd..a0177b0dd2536 100644 --- a/src/java_binding/Cargo.toml +++ b/src/java_binding/Cargo.toml @@ -10,7 +10,13 @@ ignored = ["workspace-hack"] normal = ["workspace-hack"] [dependencies] +prost = "0.11" +risingwave_common = { workspace = true } +risingwave_pb = { workspace = true } risingwave_jni_core = { workspace = true } +serde = { version = "1.0", features = ["derive"] } +serde_json = "1.0" + [dev-dependencies] risingwave_expr = { workspace = true } From 8b5b94a3905bd504aefcb9d943e4719c9c161d58 Mon Sep 17 00:00:00 2001 From: Dylan Chen Date: Wed, 6 Sep 2023 12:07:02 +0800 Subject: [PATCH 23/23] fmt --- src/java_binding/Cargo.toml | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/java_binding/Cargo.toml b/src/java_binding/Cargo.toml index a0177b0dd2536..3280125f3ac49 100644 --- a/src/java_binding/Cargo.toml +++ b/src/java_binding/Cargo.toml @@ -12,12 +12,11 @@ normal = ["workspace-hack"] [dependencies] prost = "0.11" risingwave_common = { workspace = true } -risingwave_pb = { workspace = true } risingwave_jni_core = { workspace = true } +risingwave_pb = { workspace = true } serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" - [dev-dependencies] risingwave_expr = { workspace = true }