Skip to content

Commit

Permalink
refactor(connector-node): jni reuse bidi stream handle
Browse files Browse the repository at this point in the history
  • Loading branch information
wenym1 committed Oct 30, 2023
1 parent b0f266b commit 501acad
Show file tree
Hide file tree
Showing 10 changed files with 351 additions and 497 deletions.
460 changes: 143 additions & 317 deletions src/connector/src/sink/remote.rs

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion src/connector/src/source/cdc/enumerator/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ where
SourceType::from(T::source_type())
);

let mut env = JVM.get_or_init()?.attach_current_thread()?;
let mut env = JVM.get_with_err()?.attach_current_thread()?;

let validate_source_request = ValidateSourceRequest {
source_id: context.info.source_id as u64,
Expand Down
7 changes: 4 additions & 3 deletions src/connector/src/source/cdc/source/reader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -123,8 +123,9 @@ impl<T: CdcSourceTypeTrait> CommonSplitReader for CdcSplitReader<T> {

let (tx, mut rx) = mpsc::channel(DEFAULT_CHANNEL_SIZE);

// Force init, because we don't want to see initialization failure in the following thread.
JVM.get_or_init()?;
let jvm = JVM
.get()
.map_err(|e| anyhow!("jvm not initialized properly: {:?}", e))?;

let get_event_stream_request = GetEventStreamRequest {
source_id: self.source_id,
Expand All @@ -138,7 +139,7 @@ impl<T: CdcSourceTypeTrait> CommonSplitReader for CdcSplitReader<T> {
let source_type = get_event_stream_request.source_type.to_string();

std::thread::spawn(move || {
let mut env = JVM.get_or_init().unwrap().attach_current_thread().unwrap();
let mut env = jvm.attach_current_thread().unwrap();

let get_event_stream_request_bytes = env
.byte_array_from_slice(&Message::encode_to_vec(&get_event_stream_request))
Expand Down
117 changes: 57 additions & 60 deletions src/jni_core/src/jvm_runtime.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,50 +18,49 @@ use std::fs;
use std::path::Path;
use std::sync::OnceLock;

use anyhow::anyhow;
use jni::objects::JValueOwned;
use jni::strings::JNIString;
use jni::{InitArgsBuilder, JNIVersion, JavaVM, NativeMethod};
use risingwave_common::error::{ErrorCode, RwError};
use risingwave_common::util::resource_util::memory::system_memory_available_bytes;
use tracing::error;

/// Use 10% of compute total memory by default. Compute node uses 0.7 * system memory by default.
const DEFAULT_MEMORY_PROPORTION: f64 = 0.07;

pub static JVM: JavaVmWrapper = JavaVmWrapper::new();
pub static JVM: JavaVmWrapper = JavaVmWrapper;

pub struct JavaVmWrapper(OnceLock<Result<JavaVM, RwError>>);
pub struct JavaVmWrapper;

impl JavaVmWrapper {
const fn new() -> Self {
Self(OnceLock::new())
pub fn get(&self) -> Result<&'static JavaVM, &String> {
static JVM_RESULT: OnceLock<Result<JavaVM, String>> = OnceLock::new();
JVM_RESULT
.get_or_init(|| {
Self::inner_new().inspect_err(|e| error!("failed to init jvm: {:?}", e))
})
.as_ref()
}

pub fn get(&self) -> Option<&Result<JavaVM, RwError>> {
self.0.get()
pub fn get_with_err(&self) -> anyhow::Result<&'static JavaVM> {
self.get()
.map_err(|e| anyhow!("jvm not initialized properly: {:?}", e))
}

pub fn get_or_init(&self) -> Result<&JavaVM, &RwError> {
self.0.get_or_init(Self::inner_new).as_ref()
}

fn inner_new() -> Result<JavaVM, RwError> {
fn inner_new() -> Result<JavaVM, String> {
let libs_path = if let Ok(libs_path) = std::env::var("CONNECTOR_LIBS_PATH") {
libs_path
} else {
return Err(ErrorCode::InternalError(
"environment variable CONNECTOR_LIBS_PATH is not specified".to_string(),
)
.into());
return Err("environment variable CONNECTOR_LIBS_PATH is not specified".to_string());
};

let dir = Path::new(&libs_path);

if !dir.is_dir() {
return Err(ErrorCode::InternalError(format!(
return Err(format!(
"CONNECTOR_LIBS_PATH \"{}\" is not a directory",
libs_path
))
.into());
));
}

let mut class_vec = vec![];
Expand All @@ -70,16 +69,16 @@ impl JavaVmWrapper {
for entry in entries.flatten() {
let entry_path = entry.path();
if entry_path.file_name().is_some() {
let path = std::fs::canonicalize(entry_path)?;
let path = std::fs::canonicalize(entry_path)
.expect("valid entry_path obtained from fs::read_dir");
class_vec.push(path.to_str().unwrap().to_string());
}
}
} else {
return Err(ErrorCode::InternalError(format!(
return Err(format!(
"failed to read CONNECTOR_LIBS_PATH \"{}\"",
libs_path
))
.into());
));
}

let jvm_heap_size = if let Ok(heap_size) = std::env::var("JVM_HEAP_SIZE") {
Expand All @@ -101,20 +100,23 @@ impl JavaVmWrapper {
.option(format!("-Xmx{}", jvm_heap_size));

tracing::info!("JVM args: {:?}", args_builder);
let jvm_args = args_builder.build().unwrap();
let jvm_args = args_builder
.build()
.map_err(|e| format!("invalid jvm args: {:?}", e))?;

// Create a new VM
let jvm = match JavaVM::new(jvm_args) {
Err(err) => {
tracing::error!("fail to new JVM {:?}", err);
return Err(ErrorCode::InternalError("fail to new JVM".to_string()).into());
return Err("fail to new JVM".to_string());
}
Ok(jvm) => jvm,
};

tracing::info!("initialize JVM successfully");

register_native_method_for_jvm(&jvm).unwrap();
register_native_method_for_jvm(&jvm)
.map_err(|e| format!("failed to register native method: {:?}", e))?;

Ok(jvm)
}
Expand Down Expand Up @@ -160,40 +162,35 @@ pub fn register_native_method_for_jvm(jvm: &JavaVM) -> Result<(), jni::errors::E

/// Load JVM memory statistics from the runtime. If JVM is not initialized or fail to initialize, return zero.
pub fn load_jvm_memory_stats() -> (usize, usize) {
if let Some(jvm) = JVM.get() {
match jvm {
Ok(jvm) => {
let mut env = jvm.attach_current_thread().unwrap();
let runtime_instance = env
.call_static_method(
"java/lang/Runtime",
"getRuntime",
"()Ljava/lang/Runtime;",
&[],
)
.unwrap();

let runtime_instance = match runtime_instance {
JValueOwned::Object(o) => o,
_ => unreachable!(),
};

let total_memory = env
.call_method(runtime_instance.as_ref(), "totalMemory", "()J", &[])
.unwrap()
.j()
.unwrap();

let free_memory = env
.call_method(runtime_instance, "freeMemory", "()J", &[])
.unwrap()
.j()
.unwrap();

(total_memory as usize, (total_memory - free_memory) as usize)
}
Err(_) => (0, 0),
}
if let Ok(jvm) = JVM.get() {
let mut env = jvm.attach_current_thread().unwrap();
let runtime_instance = env
.call_static_method(
"java/lang/Runtime",
"getRuntime",
"()Ljava/lang/Runtime;",
&[],
)
.unwrap();

let runtime_instance = match runtime_instance {
JValueOwned::Object(o) => o,
_ => unreachable!(),
};

let total_memory = env
.call_method(runtime_instance.as_ref(), "totalMemory", "()J", &[])
.unwrap()
.j()
.unwrap();

let free_memory = env
.call_method(runtime_instance, "freeMemory", "()J", &[])
.unwrap()
.j()
.unwrap();

(total_memory as usize, (total_memory - free_memory) as usize)
} else {
(0, 0)
}
Expand Down
13 changes: 8 additions & 5 deletions src/jni_core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -879,12 +879,15 @@ extern "system" fn Java_com_risingwave_java_binding_Binding_sendCdcSourceMsgToCh
})
}

pub type JniSenderType<T> = Sender<anyhow::Result<T>>;
pub type JniReceiverType<T> = Receiver<T>;

#[no_mangle]
pub extern "system" fn Java_com_risingwave_java_binding_Binding_recvSinkWriterRequestFromChannel<
'a,
>(
env: EnvParam<'a>,
mut channel: Pointer<'a, Receiver<SinkWriterStreamRequest>>,
mut channel: Pointer<'a, JniReceiverType<SinkWriterStreamRequest>>,
) -> JByteArray<'a> {
execute_and_catch(env, move |env| match channel.as_mut().blocking_recv() {
Some(msg) => {
Expand All @@ -902,7 +905,7 @@ pub extern "system" fn Java_com_risingwave_java_binding_Binding_sendSinkWriterRe
'a,
>(
env: EnvParam<'a>,
channel: Pointer<'a, Sender<anyhow::Result<SinkWriterStreamResponse>>>,
channel: Pointer<'a, JniSenderType<SinkWriterStreamResponse>>,
msg: JByteArray<'a>,
) -> jboolean {
execute_and_catch(env, move |env| {
Expand All @@ -927,7 +930,7 @@ pub extern "system" fn Java_com_risingwave_java_binding_Binding_recvSinkCoordina
'a,
>(
env: EnvParam<'a>,
mut channel: Pointer<'a, Receiver<SinkCoordinatorStreamRequest>>,
mut channel: Pointer<'a, JniReceiverType<SinkCoordinatorStreamRequest>>,
) -> JByteArray<'a> {
execute_and_catch(env, move |env| match channel.as_mut().blocking_recv() {
Some(msg) => {
Expand All @@ -945,7 +948,7 @@ pub extern "system" fn Java_com_risingwave_java_binding_Binding_sendSinkCoordina
'a,
>(
env: EnvParam<'a>,
channel: Pointer<'a, Sender<SinkCoordinatorStreamResponse>>,
channel: Pointer<'a, JniSenderType<SinkCoordinatorStreamResponse>>,
msg: JByteArray<'a>,
) -> jboolean {
execute_and_catch(env, move |env| {
Expand All @@ -954,7 +957,7 @@ pub extern "system" fn Java_com_risingwave_java_binding_Binding_sendSinkCoordina

match channel
.as_ref()
.blocking_send(sink_coordinator_stream_response)
.blocking_send(Ok(sink_coordinator_stream_response))
{
Ok(_) => Ok(JNI_TRUE),
Err(e) => {
Expand Down
24 changes: 12 additions & 12 deletions src/jni_core/src/macros.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,20 +18,20 @@ macro_rules! gen_class_name {
stringify! {$last}
};
($first:ident . $($rest:ident).+) => {
concat! {stringify! {$first}, "/", gen_class_name! {$($rest).+} }
concat! {stringify! {$first}, "/", $crate::gen_class_name! {$($rest).+} }
}
}

#[macro_export]
macro_rules! gen_jni_sig_inner {
($(public)? static native $($rest:tt)*) => {
gen_jni_sig_inner! { $($rest)* }
$crate::gen_jni_sig_inner! { $($rest)* }
};
($($ret:ident).+ $($func_name:ident)? ($($args:tt)*)) => {
concat! {"(", gen_jni_sig_inner!{$($args)*}, ")", gen_jni_sig_inner! {$($ret).+} }
concat! {"(", $crate::gen_jni_sig_inner!{$($args)*}, ")", $crate::gen_jni_sig_inner! {$($ret).+} }
};
($($ret:ident).+ [] $($func_name:ident)? ($($args:tt)*)) => {
concat! {"(", gen_jni_sig_inner!{$($args)*}, ")", gen_jni_sig_inner! {$($ret).+ []} }
concat! {"(", $crate::gen_jni_sig_inner!{$($args)*}, ")", $crate::gen_jni_sig_inner! {$($ret).+ []} }
};
(boolean) => {
"Z"
Expand Down Expand Up @@ -61,25 +61,25 @@ macro_rules! gen_jni_sig_inner {
"V"
};
(String) => {
gen_jni_sig_inner! { java.lang.String }
$crate::gen_jni_sig_inner! { java.lang.String }
};
(Object) => {
gen_jni_sig_inner! { java.lang.Object }
$crate::gen_jni_sig_inner! { java.lang.Object }
};
(Class) => {
gen_jni_sig_inner! { java.lang.Class }
$crate::gen_jni_sig_inner! { java.lang.Class }
};
($($class_part:ident).+) => {
concat! {"L", gen_class_name! {$($class_part).+}, ";"}
concat! {"L", $crate::gen_class_name! {$($class_part).+}, ";"}
};
($($class_part:ident).+ $(.)? [] $($param_name:ident)? $(,$($rest:tt)*)?) => {
concat! { "[", gen_jni_sig_inner! {$($class_part).+}, gen_jni_sig_inner! {$($($rest)*)?}}
concat! { "[", $crate::gen_jni_sig_inner! {$($class_part).+}, $crate::gen_jni_sig_inner! {$($($rest)*)?}}
};
(Class $(< ? >)? $($param_name:ident)? $(,$($rest:tt)*)?) => {
concat! { gen_jni_sig_inner! { Class }, gen_jni_sig_inner! {$($($rest)*)?}}
concat! { $crate::gen_jni_sig_inner! { Class }, $crate::gen_jni_sig_inner! {$($($rest)*)?}}
};
($($class_part:ident).+ $($param_name:ident)? $(,$($rest:tt)*)?) => {
concat! { gen_jni_sig_inner! {$($class_part).+}, gen_jni_sig_inner! {$($($rest)*)?}}
concat! { $crate::gen_jni_sig_inner! {$($class_part).+}, $crate::gen_jni_sig_inner! {$($($rest)*)?}}
};
() => {
""
Expand All @@ -93,7 +93,7 @@ macro_rules! gen_jni_sig_inner {
macro_rules! gen_jni_sig {
($($input:tt)*) => {{
// this macro only provide with a expression context
gen_jni_sig_inner! {$($input)*}
$crate::gen_jni_sig_inner! {$($input)*}
}}
}

Expand Down
Loading

0 comments on commit 501acad

Please sign in to comment.