Skip to content

Commit

Permalink
better error handling
Browse files Browse the repository at this point in the history
  • Loading branch information
wenym1 committed Oct 30, 2023
1 parent 501acad commit a7eb77c
Show file tree
Hide file tree
Showing 5 changed files with 74 additions and 43 deletions.
1 change: 1 addition & 0 deletions src/connector/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
#![feature(iter_from_coroutine)]
#![feature(if_let_guard)]
#![feature(iterator_try_collect)]
#![feature(try_blocks)]

use std::time::Duration;

Expand Down
17 changes: 12 additions & 5 deletions src/connector/src/sink/remote.rs
Original file line number Diff line number Diff line change
Expand Up @@ -689,22 +689,29 @@ impl EmbeddedConnectorClient {
&self,
class_name: &'static str,
method_name: &'static str,
request_rx: JniReceiverType<REQ>,
mut request_rx: JniReceiverType<REQ>,
) -> Receiver<std::result::Result<RSP, anyhow::Error>> {
let (response_tx, response_rx): (JniSenderType<RSP>, _) =
let (mut response_tx, response_rx): (JniSenderType<RSP>, _) =
mpsc::channel(DEFAULT_BUFFER_SIZE);

let jvm = self.jvm;
std::thread::spawn(move || {
let mut env = jvm.attach_current_thread().unwrap();
let mut env = match jvm.attach_current_thread() {
Ok(env) => env,
Err(e) => {
let _ = response_tx
.blocking_send(Err(anyhow!("failed to attach current thread: {:?}", e)));
return;
}
};

let result = env.call_static_method(
class_name,
method_name,
gen_jni_sig!(void f(long, long)),
&[
JValue::from(&request_rx as *const JniReceiverType<REQ> as i64),
JValue::from(&response_tx as *const JniSenderType<RSP> as i64),
JValue::from(&mut request_rx as *mut JniReceiverType<REQ> as i64),
JValue::from(&mut response_tx as *mut JniSenderType<RSP> as i64),
],
);

Expand Down
32 changes: 24 additions & 8 deletions src/connector/src/source/cdc/source/reader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ use jni::objects::JValue;
use prost::Message;
use risingwave_common::util::addr::HostAddr;
use risingwave_jni_core::jvm_runtime::JVM;
use risingwave_jni_core::GetEventStreamJniSender;
use risingwave_jni_core::JniSenderType;
use risingwave_pb::connector_service::{GetEventStreamRequest, GetEventStreamResponse};
use tokio::sync::mpsc;

Expand Down Expand Up @@ -121,7 +121,7 @@ impl<T: CdcSourceTypeTrait> CommonSplitReader for CdcSplitReader<T> {
properties.insert("table.name".into(), table_name);
}

let (tx, mut rx) = mpsc::channel(DEFAULT_CHANNEL_SIZE);
let (mut tx, mut rx) = mpsc::channel(DEFAULT_CHANNEL_SIZE);

let jvm = JVM
.get()
Expand All @@ -139,18 +139,33 @@ impl<T: CdcSourceTypeTrait> CommonSplitReader for CdcSplitReader<T> {
let source_type = get_event_stream_request.source_type.to_string();

std::thread::spawn(move || {
let mut env = jvm.attach_current_thread().unwrap();
let result: anyhow::Result<_> = try {
let env = jvm.attach_current_thread()?;

let get_event_stream_request_bytes =
env.byte_array_from_slice(&Message::encode_to_vec(&get_event_stream_request))?;

(env, get_event_stream_request_bytes)
};

let (mut env, get_event_stream_request_bytes) = match result {
Ok(inner) => inner,
Err(e) => {
let _ = tx.blocking_send(Err(anyhow!(
"err before calling runJniDbzSourceThread: {:?}",
e
)));
return;
}
};

let get_event_stream_request_bytes = env
.byte_array_from_slice(&Message::encode_to_vec(&get_event_stream_request))
.unwrap();
let result = env.call_static_method(
"com/risingwave/connector/source/core/JniDbzSourceHandler",
"runJniDbzSourceThread",
"([BJ)V",
&[
JValue::Object(&get_event_stream_request_bytes),
JValue::from(&tx as *const GetEventStreamJniSender as i64),
JValue::from(&mut tx as *mut JniSenderType<GetEventStreamResponse> as i64),
],
);

Expand All @@ -164,7 +179,8 @@ impl<T: CdcSourceTypeTrait> CommonSplitReader for CdcSplitReader<T> {
}
});

while let Some(GetEventStreamResponse { events, .. }) = rx.recv().await {
while let Some(result) = rx.recv().await {
let GetEventStreamResponse { events, .. } = result?;
tracing::trace!("receive events {:?}", events.len());
self.source_ctx
.metrics
Expand Down
52 changes: 28 additions & 24 deletions src/jni_core/src/jvm_runtime.rs
Original file line number Diff line number Diff line change
Expand Up @@ -125,13 +125,11 @@ impl JavaVmWrapper {
pub fn register_native_method_for_jvm(jvm: &JavaVM) -> Result<(), jni::errors::Error> {
let mut env = jvm
.attach_current_thread()
.inspect_err(|e| tracing::error!("jvm attach thread error: {:?}", e))
.unwrap();
.inspect_err(|e| tracing::error!("jvm attach thread error: {:?}", e))?;

let binding_class = env
.find_class("com/risingwave/java/binding/Binding")
.inspect_err(|e| tracing::error!("jvm find class error: {:?}", e))
.unwrap();
.inspect_err(|e| tracing::error!("jvm find class error: {:?}", e))?;
use crate::*;
macro_rules! gen_native_method_array {
() => {{
Expand Down Expand Up @@ -163,34 +161,40 @@ pub fn register_native_method_for_jvm(jvm: &JavaVM) -> Result<(), jni::errors::E
/// Load JVM memory statistics from the runtime. If JVM is not initialized or fail to initialize, return zero.
pub fn load_jvm_memory_stats() -> (usize, usize) {
if let Ok(jvm) = JVM.get() {
let mut env = jvm.attach_current_thread().unwrap();
let runtime_instance = env
.call_static_method(
let result: Result<(usize, usize), jni::errors::Error> = try {
let mut env = jvm.attach_current_thread()?;

let runtime_instance = env.call_static_method(
"java/lang/Runtime",
"getRuntime",
"()Ljava/lang/Runtime;",
&[],
)
.unwrap();
)?;

let runtime_instance = match runtime_instance {
JValueOwned::Object(o) => o,
_ => unreachable!(),
};
let runtime_instance = match runtime_instance {
JValueOwned::Object(o) => o,
_ => unreachable!(),
};

let total_memory = env
.call_method(runtime_instance.as_ref(), "totalMemory", "()J", &[])
.unwrap()
.j()
.unwrap();
let total_memory = env
.call_method(runtime_instance.as_ref(), "totalMemory", "()J", &[])?
.j()
.expect("should be long");

let free_memory = env
.call_method(runtime_instance, "freeMemory", "()J", &[])
.unwrap()
.j()
.unwrap();
let free_memory = env
.call_method(runtime_instance, "freeMemory", "()J", &[])?
.j()
.expect("should be long");

(total_memory as usize, (total_memory - free_memory) as usize)
(total_memory as usize, (total_memory - free_memory) as usize)
};
match result {
Ok(ret) => ret,
Err(e) => {
error!("failed to collect jvm stats: {:?}", e);
(0, 0)
}
}
} else {
(0, 0)
}
Expand Down
15 changes: 9 additions & 6 deletions src/jni_core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#![feature(once_cell_try)]
#![feature(type_alias_impl_trait)]
#![feature(result_option_inspect)]
#![feature(try_blocks)]

pub mod hummock_iterator;
pub mod jvm_runtime;
Expand Down Expand Up @@ -60,7 +61,6 @@ use tokio::sync::mpsc::{Receiver, Sender};
use crate::hummock_iterator::HummockJavaBindingIterator;
pub use crate::jvm_runtime::register_native_method_for_jvm;
use crate::stream_chunk_iterator::{into_iter, StreamChunkRowIterator};
pub type GetEventStreamJniSender = Sender<GetEventStreamResponse>;

static RUNTIME: LazyLock<Runtime> = LazyLock::new(|| tokio::runtime::Runtime::new().unwrap());

Expand Down Expand Up @@ -846,14 +846,17 @@ extern "system" fn Java_com_risingwave_java_binding_Binding_iteratorGetArrayValu
})
}

pub type JniSenderType<T> = Sender<anyhow::Result<T>>;
pub type JniReceiverType<T> = Receiver<T>;

/// Send messages to the channel received by `CdcSplitReader`.
/// If msg is null, just check whether the channel is closed.
/// Return true if sending is successful, otherwise, return false so that caller can stop
/// gracefully.
#[no_mangle]
extern "system" fn Java_com_risingwave_java_binding_Binding_sendCdcSourceMsgToChannel<'a>(
env: EnvParam<'a>,
channel: Pointer<'a, GetEventStreamJniSender>,
channel: Pointer<'a, JniSenderType<GetEventStreamResponse>>,
msg: JByteArray<'a>,
) -> jboolean {
execute_and_catch(env, move |env| {
Expand All @@ -869,7 +872,10 @@ extern "system" fn Java_com_risingwave_java_binding_Binding_sendCdcSourceMsgToCh
let get_event_stream_response: GetEventStreamResponse =
Message::decode(to_guarded_slice(&msg, env)?.deref())?;

match channel.as_ref().blocking_send(get_event_stream_response) {
match channel
.as_ref()
.blocking_send(Ok(get_event_stream_response))
{
Ok(_) => Ok(JNI_TRUE),
Err(e) => {
tracing::info!("send error. {:?}", e);
Expand All @@ -879,9 +885,6 @@ extern "system" fn Java_com_risingwave_java_binding_Binding_sendCdcSourceMsgToCh
})
}

pub type JniSenderType<T> = Sender<anyhow::Result<T>>;
pub type JniReceiverType<T> = Receiver<T>;

#[no_mangle]
pub extern "system" fn Java_com_risingwave_java_binding_Binding_recvSinkWriterRequestFromChannel<
'a,
Expand Down

0 comments on commit a7eb77c

Please sign in to comment.