From 501acadededce2b3876c9a875087b70c08e39cfe Mon Sep 17 00:00:00 2001 From: William Wen Date: Mon, 30 Oct 2023 12:52:13 +0800 Subject: [PATCH 1/4] refactor(connector-node): jni reuse bidi stream handle --- src/connector/src/sink/remote.rs | 460 ++++++------------ .../src/source/cdc/enumerator/mod.rs | 2 +- src/connector/src/source/cdc/source/reader.rs | 7 +- src/jni_core/src/jvm_runtime.rs | 117 +++-- src/jni_core/src/lib.rs | 13 +- src/jni_core/src/macros.rs | 24 +- .../src/manager/sink_coordination/manager.rs | 93 ++-- src/rpc_client/src/connector_client.rs | 18 +- src/rpc_client/src/lib.rs | 89 ++-- src/rpc_client/src/sink_coordinate_client.rs | 25 +- 10 files changed, 351 insertions(+), 497 deletions(-) diff --git a/src/connector/src/sink/remote.rs b/src/connector/src/sink/remote.rs index 3c52cb720dbd4..a3e06b8a22cc3 100644 --- a/src/connector/src/sink/remote.rs +++ b/src/connector/src/sink/remote.rs @@ -12,8 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. +use std::any::Any; use std::collections::HashMap; -use std::fmt::Formatter; use std::future::Future; use std::marker::PhantomData; use std::ops::Deref; @@ -21,25 +21,24 @@ use std::time::Instant; use anyhow::anyhow; use async_trait::async_trait; -use futures::stream::Peekable; use futures::{StreamExt, TryFutureExt, TryStreamExt}; use itertools::Itertools; use jni::objects::{JByteArray, JValue, JValueOwned}; +use jni::JavaVM; use prost::Message; use risingwave_common::array::StreamChunk; use risingwave_common::error::anyhow_error; use risingwave_common::types::DataType; use risingwave_common::util::await_future_with_monitor_error_stream; use risingwave_jni_core::jvm_runtime::JVM; -use risingwave_pb::connector_service::sink_coordinator_stream_request::{ - CommitMetadata, StartCoordinator, -}; +use risingwave_jni_core::{gen_class_name, gen_jni_sig, JniReceiverType, JniSenderType}; +use risingwave_pb::connector_service::sink_coordinator_stream_request::StartCoordinator; use risingwave_pb::connector_service::sink_writer_stream_request::write_batch::json_payload::RowOp; use risingwave_pb::connector_service::sink_writer_stream_request::write_batch::{ JsonPayload, Payload, StreamChunkPayload, }; use risingwave_pb::connector_service::sink_writer_stream_request::{ - Barrier, BeginEpoch, Request as SinkRequest, StartSink, WriteBatch, + Request as SinkRequest, StartSink, }; use risingwave_pb::connector_service::sink_writer_stream_response::CommitResponse; use risingwave_pb::connector_service::{ @@ -47,6 +46,10 @@ use risingwave_pb::connector_service::{ SinkCoordinatorStreamRequest, SinkCoordinatorStreamResponse, SinkMetadata, SinkPayloadFormat, SinkWriterStreamRequest, SinkWriterStreamResponse, ValidateSinkRequest, ValidateSinkResponse, }; +use risingwave_rpc_client::error::RpcError; +use risingwave_rpc_client::{ + BidiStreamReceiver, SinkCoordinatorStreamHandle, SinkWriterStreamHandle, DEFAULT_BUFFER_SIZE, +}; use tokio::sync::mpsc; use tokio::sync::mpsc::{Receiver, Sender}; use tokio_stream::wrappers::ReceiverStream; @@ -150,8 +153,7 @@ impl Sink for RemoteSink { }).try_collect()?; let mut env = JVM - .get_or_init() - .map_err(|err| SinkError::Internal(err.into()))? + .get_with_err()? .attach_current_thread() .map_err(|err| SinkError::Internal(err.into()))?; let validate_sink_request = ValidateSinkRequest { @@ -217,13 +219,13 @@ impl RemoteLogSinker { /// Await the given future while monitoring on error of the receiver stream. async fn await_future_with_monitor_receiver_err>>( - receiver: &mut SinkWriterStreamJniReceiver, + receiver: &mut BidiStreamReceiver, future: F, ) -> Result { - match await_future_with_monitor_error_stream(&mut receiver.response_stream, future).await { - Ok(result) => result, + match await_future_with_monitor_error_stream(&mut receiver.stream, future).await { + Ok(result) => Ok(result?), Err(None) => Err(SinkError::Remote(anyhow!("end of remote receiver stream"))), - Err(Some(err)) => Err(SinkError::Internal(err)), + Err(Some(err)) => Err(SinkError::Remote(err.into())), } } @@ -254,7 +256,7 @@ impl LogSinker for RemoteLogSinker { loop { let (epoch, item): (u64, LogStoreReadItem) = await_future_with_monitor_receiver_err( - &mut sink_writer.stream_handle.response_rx, + &mut sink_writer.stream_handle.response_stream, log_reader.next_item().map_err(SinkError::Internal), ) .await?; @@ -378,155 +380,6 @@ impl Sink for CoordinatedRemoteSink { } } -#[derive(Debug)] -pub struct SinkCoordinatorStreamJniHandle { - request_tx: Sender, - response_rx: Receiver, -} - -impl SinkCoordinatorStreamJniHandle { - pub async fn commit(&mut self, epoch: u64, metadata: Vec) -> Result<()> { - self.request_tx - .send(SinkCoordinatorStreamRequest { - request: Some(sink_coordinator_stream_request::Request::Commit( - CommitMetadata { epoch, metadata }, - )), - }) - .await - .map_err(|err| SinkError::Internal(err.into()))?; - - match self.response_rx.recv().await { - Some(SinkCoordinatorStreamResponse { - response: - Some(sink_coordinator_stream_response::Response::Commit( - sink_coordinator_stream_response::CommitResponse { - epoch: response_epoch, - }, - )), - }) => { - if epoch == response_epoch { - Ok(()) - } else { - Err(SinkError::Internal(anyhow!( - "get different response epoch to commit epoch: {} {}", - epoch, - response_epoch - ))) - } - } - msg => Err(SinkError::Internal(anyhow!( - "should get Commit response but get {:?}", - msg - ))), - } - } -} - -struct SinkWriterStreamJniSender { - request_tx: Sender, -} - -impl SinkWriterStreamJniSender { - pub async fn start_epoch(&mut self, epoch: u64) -> Result<()> { - self.request_tx - .send(SinkWriterStreamRequest { - request: Some(SinkRequest::BeginEpoch(BeginEpoch { epoch })), - }) - .await - .map_err(|err| SinkError::Internal(err.into())) - } - - pub async fn write_batch(&mut self, epoch: u64, batch_id: u64, payload: Payload) -> Result<()> { - self.request_tx - .send(SinkWriterStreamRequest { - request: Some(SinkRequest::WriteBatch(WriteBatch { - epoch, - batch_id, - payload: Some(payload), - })), - }) - .await - .map_err(|err| SinkError::Internal(err.into())) - } - - pub async fn barrier(&mut self, epoch: u64, is_checkpoint: bool) -> Result<()> { - self.request_tx - .send(SinkWriterStreamRequest { - request: Some(SinkRequest::Barrier(Barrier { - epoch, - is_checkpoint, - })), - }) - .await - .map_err(|err| SinkError::Internal(err.into())) - } -} - -struct SinkWriterStreamJniReceiver { - response_stream: Peekable>>, -} - -impl SinkWriterStreamJniReceiver { - async fn next_commit_response(&mut self) -> Result { - match self.response_stream.try_next().await { - Ok(Some(SinkWriterStreamResponse { - response: Some(sink_writer_stream_response::Response::Commit(rsp)), - })) => Ok(rsp), - msg => Err(SinkError::Internal(anyhow!( - "should get Sync response but get {:?}", - msg - ))), - } - } -} - -const DEFAULT_CHANNEL_SIZE: usize = 16; -struct SinkWriterStreamJniHandle { - request_tx: SinkWriterStreamJniSender, - response_rx: SinkWriterStreamJniReceiver, -} - -impl std::fmt::Debug for SinkWriterStreamJniHandle { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - f.debug_struct("SinkWriterStreamJniHandle").finish() - } -} - -impl SinkWriterStreamJniHandle { - async fn start_epoch(&mut self, epoch: u64) -> Result<()> { - await_future_with_monitor_receiver_err( - &mut self.response_rx, - self.request_tx.start_epoch(epoch), - ) - .await - } - - async fn write_batch(&mut self, epoch: u64, batch_id: u64, payload: Payload) -> Result<()> { - await_future_with_monitor_receiver_err( - &mut self.response_rx, - self.request_tx.write_batch(epoch, batch_id, payload), - ) - .await - } - - async fn barrier(&mut self, epoch: u64) -> Result<()> { - await_future_with_monitor_receiver_err( - &mut self.response_rx, - self.request_tx.barrier(epoch, false), - ) - .await - } - - async fn commit(&mut self, epoch: u64) -> Result { - await_future_with_monitor_receiver_err( - &mut self.response_rx, - self.request_tx.barrier(epoch, true), - ) - .await?; - self.response_rx.next_commit_response().await - } -} - pub type RemoteSinkWriter = RemoteSinkWriterInner<(), R>; pub type CoordinatedRemoteSinkWriter = RemoteSinkWriterInner, R>; @@ -535,7 +388,7 @@ pub struct RemoteSinkWriterInner { epoch: Option, batch_id: u64, payload_format: SinkPayloadFormat, - stream_handle: SinkWriterStreamJniHandle, + stream_handle: SinkWriterStreamHandle, json_encoder: JsonEncoder, sink_metrics: SinkMetrics, _phantom: PhantomData<(SM, R)>, @@ -547,84 +400,12 @@ impl RemoteSinkWriterInner { connector_params: ConnectorParams, sink_metrics: SinkMetrics, ) -> Result { - let (request_tx, request_rx) = mpsc::channel(DEFAULT_CHANNEL_SIZE); - let (response_tx, response_rx) = mpsc::channel(DEFAULT_CHANNEL_SIZE); - - let mut response_stream = ReceiverStream::new(response_rx).peekable(); - - std::thread::spawn(move || { - let mut env = JVM.get_or_init().unwrap().attach_current_thread().unwrap(); - - let result = env.call_static_method( - "com/risingwave/connector/JniSinkWriterHandler", - "runJniSinkWriterThread", - "(JJ)V", - &[ - JValue::from(&request_rx as *const Receiver as i64), - JValue::from( - &response_tx as *const Sender> - as i64, - ), - ], - ); - - match result { - Ok(_) => { - tracing::info!("end of jni call runJniSinkWriterThread"); - } - Err(e) => { - tracing::error!("jni call error: {:?}", e); - } - }; - }); - - let sink_writer_stream_request = SinkWriterStreamRequest { - request: Some(SinkRequest::Start(StartSink { - sink_param: Some(param.to_proto()), - format: connector_params.sink_payload_format as i32, - })), - }; - - // First request - request_tx - .send(sink_writer_stream_request) - .await - .map_err(|err| { - SinkError::Internal(anyhow!( - "fail to send start request for connector `{}`: {:?}", - R::SINK_NAME, - err - )) - })?; - - // First response - match response_stream.try_next().await { - Ok(Some(SinkWriterStreamResponse { - response: Some(sink_writer_stream_response::Response::Start(_)), - })) => {} - Ok(msg) => { - return Err(SinkError::Internal(anyhow!( - "should get start response for connector `{}` but get {:?}", - R::SINK_NAME, - msg - ))); - } - Err(e) => return Err(SinkError::Internal(e)), - }; - - tracing::trace!( - "{:?} sink stream started with properties: {:?}", - R::SINK_NAME, - ¶m.properties - ); + let stream_handle = EmbeddedConnectorClient::new()? + .start_sink_writer_stream(param.clone(), connector_params.sink_payload_format) + .await?; let schema = param.schema(); - let stream_handle = SinkWriterStreamJniHandle { - request_tx: SinkWriterStreamJniSender { request_tx }, - response_rx: SinkWriterStreamJniReceiver { response_stream }, - }; - Ok(Self { properties: param.properties, epoch: None, @@ -637,7 +418,6 @@ impl RemoteSinkWriterInner { }) } - #[cfg(test)] fn for_test( response_receiver: Receiver>, request_sender: Sender, @@ -660,14 +440,12 @@ impl RemoteSinkWriterInner { }, ]); - let stream_handle = SinkWriterStreamJniHandle { - request_tx: SinkWriterStreamJniSender { - request_tx: request_sender, - }, - response_rx: SinkWriterStreamJniReceiver { - response_stream: ReceiverStream::new(response_receiver).peekable(), - }, - }; + let stream_handle = SinkWriterStreamHandle::for_test( + request_sender, + ReceiverStream::new(response_receiver) + .map_err(RpcError::from) + .boxed(), + ); RemoteSinkWriter { properties, @@ -795,81 +573,15 @@ where } pub struct RemoteCoordinator { - stream_handle: SinkCoordinatorStreamJniHandle, + stream_handle: SinkCoordinatorStreamHandle, _phantom: PhantomData, } impl RemoteCoordinator { pub async fn new(param: SinkParam) -> Result { - let (request_tx, request_rx) = mpsc::channel(DEFAULT_CHANNEL_SIZE); - let (response_tx, response_rx) = mpsc::channel(DEFAULT_CHANNEL_SIZE); - - let mut stream_handle = SinkCoordinatorStreamJniHandle { - request_tx, - response_rx, - }; - - std::thread::spawn(move || { - let mut env = JVM.get_or_init().unwrap().attach_current_thread().unwrap(); - - let result = env.call_static_method( - "com/risingwave/connector/JniSinkCoordinatorHandler", - "runJniSinkCoordinatorThread", - "(JJ)V", - &[ - JValue::from( - &request_rx as *const Receiver as i64, - ), - JValue::from( - &response_tx as *const Sender as i64, - ), - ], - ); - - match result { - Ok(_) => { - tracing::info!("end of jni call runJniSinkCoordinatorThread"); - } - Err(e) => { - tracing::error!("jni call error: {:?}", e); - } - }; - }); - - let sink_coordinator_stream_request = SinkCoordinatorStreamRequest { - request: Some(sink_coordinator_stream_request::Request::Start( - StartCoordinator { - param: Some(param.to_proto()), - }, - )), - }; - - // First request - stream_handle - .request_tx - .send(sink_coordinator_stream_request) - .await - .map_err(|err| { - SinkError::Internal(anyhow!( - "fail to send start request for connector `{}`: {:?}", - R::SINK_NAME, - err - )) - })?; - - // First response - match stream_handle.response_rx.recv().await { - Some(SinkCoordinatorStreamResponse { - response: Some(sink_coordinator_stream_response::Response::Start(_)), - }) => {} - msg => { - return Err(SinkError::Internal(anyhow!( - "should get start response for connector `{}` but get {:?}", - R::SINK_NAME, - msg - ))); - } - }; + let stream_handle = EmbeddedConnectorClient::new()? + .start_sink_coordinator_stream(param.clone()) + .await?; tracing::trace!( "{:?} RemoteCoordinator started with properties: {:?}", @@ -895,6 +607,120 @@ impl SinkCommitCoordinator for RemoteCoordinator { } } +struct EmbeddedConnectorClient { + jvm: &'static JavaVM, +} + +impl EmbeddedConnectorClient { + fn new() -> Result { + let jvm = JVM.get().map_err(|e| anyhow!("cannot get jvm: {:?}", e))?; + Ok(EmbeddedConnectorClient { jvm }) + } + + async fn start_sink_writer_stream( + &self, + sink_param: SinkParam, + sink_payload_format: SinkPayloadFormat, + ) -> Result { + let (handle, first_rsp) = SinkWriterStreamHandle::initialize( + SinkWriterStreamRequest { + request: Some(SinkRequest::Start(StartSink { + sink_param: Some(sink_param.to_proto()), + format: sink_payload_format as i32, + })), + }, + |rx| async move { + let rx = self.start_jvm_worker_thread( + gen_class_name!(com.risingwave.connector.JniSinkWriterHandler), + "runJniSinkWriterThread", + rx, + ); + Ok(ReceiverStream::new(rx).map_err(RpcError::from)) + }, + ) + .await?; + + match first_rsp { + SinkWriterStreamResponse { + response: Some(sink_writer_stream_response::Response::Start(_)), + } => Ok(handle), + msg => Err(SinkError::Internal(anyhow!( + "should get start response but get {:?}", + msg + ))), + } + } + + pub async fn start_sink_coordinator_stream( + &self, + param: SinkParam, + ) -> Result { + let (handle, first_rsp) = SinkCoordinatorStreamHandle::initialize( + SinkCoordinatorStreamRequest { + request: Some(sink_coordinator_stream_request::Request::Start( + StartCoordinator { + param: Some(param.to_proto()), + }, + )), + }, + |rx| async move { + let rx = self.start_jvm_worker_thread( + gen_class_name!(com.risingwave.connector.JniSinkCoordinatorHandler), + "runJniSinkCoordinatorThread", + rx, + ); + Ok(ReceiverStream::new(rx).map_err(RpcError::from)) + }, + ) + .await?; + + match first_rsp { + SinkCoordinatorStreamResponse { + response: Some(sink_coordinator_stream_response::Response::Start(_)), + } => Ok(handle), + msg => Err(SinkError::Internal(anyhow!( + "should get start response but get {:?}", + msg + ))), + } + } + + fn start_jvm_worker_thread( + &self, + class_name: &'static str, + method_name: &'static str, + request_rx: JniReceiverType, + ) -> Receiver> { + let (response_tx, response_rx): (JniSenderType, _) = + mpsc::channel(DEFAULT_BUFFER_SIZE); + + let jvm = self.jvm; + std::thread::spawn(move || { + let mut env = jvm.attach_current_thread().unwrap(); + + let result = env.call_static_method( + class_name, + method_name, + gen_jni_sig!(void f(long, long)), + &[ + JValue::from(&request_rx as *const JniReceiverType as i64), + JValue::from(&response_tx as *const JniSenderType as i64), + ], + ); + + match result { + Ok(_) => { + tracing::info!("end of jni call {}::{}", class_name, method_name); + } + Err(e) => { + tracing::error!("jni call error: {:?}", e); + } + }; + }); + response_rx + } +} + #[cfg(test)] mod test { use std::time::Duration; diff --git a/src/connector/src/source/cdc/enumerator/mod.rs b/src/connector/src/source/cdc/enumerator/mod.rs index e88440bc876e1..803e092ab90a3 100644 --- a/src/connector/src/source/cdc/enumerator/mod.rs +++ b/src/connector/src/source/cdc/enumerator/mod.rs @@ -69,7 +69,7 @@ where SourceType::from(T::source_type()) ); - let mut env = JVM.get_or_init()?.attach_current_thread()?; + let mut env = JVM.get_with_err()?.attach_current_thread()?; let validate_source_request = ValidateSourceRequest { source_id: context.info.source_id as u64, diff --git a/src/connector/src/source/cdc/source/reader.rs b/src/connector/src/source/cdc/source/reader.rs index 4d25d82c106c3..d4382b67739e0 100644 --- a/src/connector/src/source/cdc/source/reader.rs +++ b/src/connector/src/source/cdc/source/reader.rs @@ -123,8 +123,9 @@ impl CommonSplitReader for CdcSplitReader { let (tx, mut rx) = mpsc::channel(DEFAULT_CHANNEL_SIZE); - // Force init, because we don't want to see initialization failure in the following thread. - JVM.get_or_init()?; + let jvm = JVM + .get() + .map_err(|e| anyhow!("jvm not initialized properly: {:?}", e))?; let get_event_stream_request = GetEventStreamRequest { source_id: self.source_id, @@ -138,7 +139,7 @@ impl CommonSplitReader for CdcSplitReader { let source_type = get_event_stream_request.source_type.to_string(); std::thread::spawn(move || { - let mut env = JVM.get_or_init().unwrap().attach_current_thread().unwrap(); + let mut env = jvm.attach_current_thread().unwrap(); let get_event_stream_request_bytes = env .byte_array_from_slice(&Message::encode_to_vec(&get_event_stream_request)) diff --git a/src/jni_core/src/jvm_runtime.rs b/src/jni_core/src/jvm_runtime.rs index bd1f068b6eaee..8a51c862506ed 100644 --- a/src/jni_core/src/jvm_runtime.rs +++ b/src/jni_core/src/jvm_runtime.rs @@ -18,50 +18,49 @@ use std::fs; use std::path::Path; use std::sync::OnceLock; +use anyhow::anyhow; use jni::objects::JValueOwned; use jni::strings::JNIString; use jni::{InitArgsBuilder, JNIVersion, JavaVM, NativeMethod}; -use risingwave_common::error::{ErrorCode, RwError}; use risingwave_common::util::resource_util::memory::system_memory_available_bytes; +use tracing::error; /// Use 10% of compute total memory by default. Compute node uses 0.7 * system memory by default. const DEFAULT_MEMORY_PROPORTION: f64 = 0.07; -pub static JVM: JavaVmWrapper = JavaVmWrapper::new(); +pub static JVM: JavaVmWrapper = JavaVmWrapper; -pub struct JavaVmWrapper(OnceLock>); +pub struct JavaVmWrapper; impl JavaVmWrapper { - const fn new() -> Self { - Self(OnceLock::new()) + pub fn get(&self) -> Result<&'static JavaVM, &String> { + static JVM_RESULT: OnceLock> = OnceLock::new(); + JVM_RESULT + .get_or_init(|| { + Self::inner_new().inspect_err(|e| error!("failed to init jvm: {:?}", e)) + }) + .as_ref() } - pub fn get(&self) -> Option<&Result> { - self.0.get() + pub fn get_with_err(&self) -> anyhow::Result<&'static JavaVM> { + self.get() + .map_err(|e| anyhow!("jvm not initialized properly: {:?}", e)) } - pub fn get_or_init(&self) -> Result<&JavaVM, &RwError> { - self.0.get_or_init(Self::inner_new).as_ref() - } - - fn inner_new() -> Result { + fn inner_new() -> Result { let libs_path = if let Ok(libs_path) = std::env::var("CONNECTOR_LIBS_PATH") { libs_path } else { - return Err(ErrorCode::InternalError( - "environment variable CONNECTOR_LIBS_PATH is not specified".to_string(), - ) - .into()); + return Err("environment variable CONNECTOR_LIBS_PATH is not specified".to_string()); }; let dir = Path::new(&libs_path); if !dir.is_dir() { - return Err(ErrorCode::InternalError(format!( + return Err(format!( "CONNECTOR_LIBS_PATH \"{}\" is not a directory", libs_path - )) - .into()); + )); } let mut class_vec = vec![]; @@ -70,16 +69,16 @@ impl JavaVmWrapper { for entry in entries.flatten() { let entry_path = entry.path(); if entry_path.file_name().is_some() { - let path = std::fs::canonicalize(entry_path)?; + let path = std::fs::canonicalize(entry_path) + .expect("valid entry_path obtained from fs::read_dir"); class_vec.push(path.to_str().unwrap().to_string()); } } } else { - return Err(ErrorCode::InternalError(format!( + return Err(format!( "failed to read CONNECTOR_LIBS_PATH \"{}\"", libs_path - )) - .into()); + )); } let jvm_heap_size = if let Ok(heap_size) = std::env::var("JVM_HEAP_SIZE") { @@ -101,20 +100,23 @@ impl JavaVmWrapper { .option(format!("-Xmx{}", jvm_heap_size)); tracing::info!("JVM args: {:?}", args_builder); - let jvm_args = args_builder.build().unwrap(); + let jvm_args = args_builder + .build() + .map_err(|e| format!("invalid jvm args: {:?}", e))?; // Create a new VM let jvm = match JavaVM::new(jvm_args) { Err(err) => { tracing::error!("fail to new JVM {:?}", err); - return Err(ErrorCode::InternalError("fail to new JVM".to_string()).into()); + return Err("fail to new JVM".to_string()); } Ok(jvm) => jvm, }; tracing::info!("initialize JVM successfully"); - register_native_method_for_jvm(&jvm).unwrap(); + register_native_method_for_jvm(&jvm) + .map_err(|e| format!("failed to register native method: {:?}", e))?; Ok(jvm) } @@ -160,40 +162,35 @@ pub fn register_native_method_for_jvm(jvm: &JavaVM) -> Result<(), jni::errors::E /// Load JVM memory statistics from the runtime. If JVM is not initialized or fail to initialize, return zero. pub fn load_jvm_memory_stats() -> (usize, usize) { - if let Some(jvm) = JVM.get() { - match jvm { - Ok(jvm) => { - let mut env = jvm.attach_current_thread().unwrap(); - let runtime_instance = env - .call_static_method( - "java/lang/Runtime", - "getRuntime", - "()Ljava/lang/Runtime;", - &[], - ) - .unwrap(); - - let runtime_instance = match runtime_instance { - JValueOwned::Object(o) => o, - _ => unreachable!(), - }; - - let total_memory = env - .call_method(runtime_instance.as_ref(), "totalMemory", "()J", &[]) - .unwrap() - .j() - .unwrap(); - - let free_memory = env - .call_method(runtime_instance, "freeMemory", "()J", &[]) - .unwrap() - .j() - .unwrap(); - - (total_memory as usize, (total_memory - free_memory) as usize) - } - Err(_) => (0, 0), - } + if let Ok(jvm) = JVM.get() { + let mut env = jvm.attach_current_thread().unwrap(); + let runtime_instance = env + .call_static_method( + "java/lang/Runtime", + "getRuntime", + "()Ljava/lang/Runtime;", + &[], + ) + .unwrap(); + + let runtime_instance = match runtime_instance { + JValueOwned::Object(o) => o, + _ => unreachable!(), + }; + + let total_memory = env + .call_method(runtime_instance.as_ref(), "totalMemory", "()J", &[]) + .unwrap() + .j() + .unwrap(); + + let free_memory = env + .call_method(runtime_instance, "freeMemory", "()J", &[]) + .unwrap() + .j() + .unwrap(); + + (total_memory as usize, (total_memory - free_memory) as usize) } else { (0, 0) } diff --git a/src/jni_core/src/lib.rs b/src/jni_core/src/lib.rs index c92ad2f146e6c..414d8348760c1 100644 --- a/src/jni_core/src/lib.rs +++ b/src/jni_core/src/lib.rs @@ -879,12 +879,15 @@ extern "system" fn Java_com_risingwave_java_binding_Binding_sendCdcSourceMsgToCh }) } +pub type JniSenderType = Sender>; +pub type JniReceiverType = Receiver; + #[no_mangle] pub extern "system" fn Java_com_risingwave_java_binding_Binding_recvSinkWriterRequestFromChannel< 'a, >( env: EnvParam<'a>, - mut channel: Pointer<'a, Receiver>, + mut channel: Pointer<'a, JniReceiverType>, ) -> JByteArray<'a> { execute_and_catch(env, move |env| match channel.as_mut().blocking_recv() { Some(msg) => { @@ -902,7 +905,7 @@ pub extern "system" fn Java_com_risingwave_java_binding_Binding_sendSinkWriterRe 'a, >( env: EnvParam<'a>, - channel: Pointer<'a, Sender>>, + channel: Pointer<'a, JniSenderType>, msg: JByteArray<'a>, ) -> jboolean { execute_and_catch(env, move |env| { @@ -927,7 +930,7 @@ pub extern "system" fn Java_com_risingwave_java_binding_Binding_recvSinkCoordina 'a, >( env: EnvParam<'a>, - mut channel: Pointer<'a, Receiver>, + mut channel: Pointer<'a, JniReceiverType>, ) -> JByteArray<'a> { execute_and_catch(env, move |env| match channel.as_mut().blocking_recv() { Some(msg) => { @@ -945,7 +948,7 @@ pub extern "system" fn Java_com_risingwave_java_binding_Binding_sendSinkCoordina 'a, >( env: EnvParam<'a>, - channel: Pointer<'a, Sender>, + channel: Pointer<'a, JniSenderType>, msg: JByteArray<'a>, ) -> jboolean { execute_and_catch(env, move |env| { @@ -954,7 +957,7 @@ pub extern "system" fn Java_com_risingwave_java_binding_Binding_sendSinkCoordina match channel .as_ref() - .blocking_send(sink_coordinator_stream_response) + .blocking_send(Ok(sink_coordinator_stream_response)) { Ok(_) => Ok(JNI_TRUE), Err(e) => { diff --git a/src/jni_core/src/macros.rs b/src/jni_core/src/macros.rs index 0ca0748fb0206..2484fbba9488c 100644 --- a/src/jni_core/src/macros.rs +++ b/src/jni_core/src/macros.rs @@ -18,20 +18,20 @@ macro_rules! gen_class_name { stringify! {$last} }; ($first:ident . $($rest:ident).+) => { - concat! {stringify! {$first}, "/", gen_class_name! {$($rest).+} } + concat! {stringify! {$first}, "/", $crate::gen_class_name! {$($rest).+} } } } #[macro_export] macro_rules! gen_jni_sig_inner { ($(public)? static native $($rest:tt)*) => { - gen_jni_sig_inner! { $($rest)* } + $crate::gen_jni_sig_inner! { $($rest)* } }; ($($ret:ident).+ $($func_name:ident)? ($($args:tt)*)) => { - concat! {"(", gen_jni_sig_inner!{$($args)*}, ")", gen_jni_sig_inner! {$($ret).+} } + concat! {"(", $crate::gen_jni_sig_inner!{$($args)*}, ")", $crate::gen_jni_sig_inner! {$($ret).+} } }; ($($ret:ident).+ [] $($func_name:ident)? ($($args:tt)*)) => { - concat! {"(", gen_jni_sig_inner!{$($args)*}, ")", gen_jni_sig_inner! {$($ret).+ []} } + concat! {"(", $crate::gen_jni_sig_inner!{$($args)*}, ")", $crate::gen_jni_sig_inner! {$($ret).+ []} } }; (boolean) => { "Z" @@ -61,25 +61,25 @@ macro_rules! gen_jni_sig_inner { "V" }; (String) => { - gen_jni_sig_inner! { java.lang.String } + $crate::gen_jni_sig_inner! { java.lang.String } }; (Object) => { - gen_jni_sig_inner! { java.lang.Object } + $crate::gen_jni_sig_inner! { java.lang.Object } }; (Class) => { - gen_jni_sig_inner! { java.lang.Class } + $crate::gen_jni_sig_inner! { java.lang.Class } }; ($($class_part:ident).+) => { - concat! {"L", gen_class_name! {$($class_part).+}, ";"} + concat! {"L", $crate::gen_class_name! {$($class_part).+}, ";"} }; ($($class_part:ident).+ $(.)? [] $($param_name:ident)? $(,$($rest:tt)*)?) => { - concat! { "[", gen_jni_sig_inner! {$($class_part).+}, gen_jni_sig_inner! {$($($rest)*)?}} + concat! { "[", $crate::gen_jni_sig_inner! {$($class_part).+}, $crate::gen_jni_sig_inner! {$($($rest)*)?}} }; (Class $(< ? >)? $($param_name:ident)? $(,$($rest:tt)*)?) => { - concat! { gen_jni_sig_inner! { Class }, gen_jni_sig_inner! {$($($rest)*)?}} + concat! { $crate::gen_jni_sig_inner! { Class }, $crate::gen_jni_sig_inner! {$($($rest)*)?}} }; ($($class_part:ident).+ $($param_name:ident)? $(,$($rest:tt)*)?) => { - concat! { gen_jni_sig_inner! {$($class_part).+}, gen_jni_sig_inner! {$($($rest)*)?}} + concat! { $crate::gen_jni_sig_inner! {$($class_part).+}, $crate::gen_jni_sig_inner! {$($($rest)*)?}} }; () => { "" @@ -93,7 +93,7 @@ macro_rules! gen_jni_sig_inner { macro_rules! gen_jni_sig { ($($input:tt)*) => {{ // this macro only provide with a expression context - gen_jni_sig_inner! {$($input)*} + $crate::gen_jni_sig_inner! {$($input)*} }} } diff --git a/src/meta/src/manager/sink_coordination/manager.rs b/src/meta/src/manager/sink_coordination/manager.rs index 720a698fa8e72..4c2b2a37295b1 100644 --- a/src/meta/src/manager/sink_coordination/manager.rs +++ b/src/meta/src/manager/sink_coordination/manager.rs @@ -362,6 +362,7 @@ mod tests { use risingwave_pb::connector_service::sink_metadata::{Metadata, SerializedMetadata}; use risingwave_pb::connector_service::SinkMetadata; use risingwave_rpc_client::CoordinatorStreamHandle; + use tokio_stream::wrappers::ReceiverStream; use crate::manager::sink_coordination::coordinator_worker::CoordinatorWorker; use crate::manager::sink_coordination::{NewSinkWriterRequest, SinkCoordinatorManager}; @@ -481,19 +482,15 @@ mod tests { }); let build_client = |vnode| async { - CoordinatorStreamHandle::new_with_init_stream( - param.to_proto(), - vnode, - |stream_req| async { - Ok(tonic::Response::new( - manager - .handle_new_request(stream_req.into_inner().map(Ok).boxed()) - .await - .unwrap() - .boxed(), - )) - }, - ) + CoordinatorStreamHandle::new_with_init_stream(param.to_proto(), vnode, |rx| async { + Ok(tonic::Response::new( + manager + .handle_new_request(ReceiverStream::new(rx).map(Ok).boxed()) + .await + .unwrap() + .boxed(), + )) + }) .await .unwrap() }; @@ -647,19 +644,15 @@ mod tests { }); let build_client = |vnode| async { - CoordinatorStreamHandle::new_with_init_stream( - param.to_proto(), - vnode, - |stream_req| async { - Ok(tonic::Response::new( - manager - .handle_new_request(stream_req.into_inner().map(Ok).boxed()) - .await - .unwrap() - .boxed(), - )) - }, - ) + CoordinatorStreamHandle::new_with_init_stream(param.to_proto(), vnode, |rx| async { + Ok(tonic::Response::new( + manager + .handle_new_request(ReceiverStream::new(rx).map(Ok).boxed()) + .await + .unwrap() + .boxed(), + )) + }) .await .unwrap() }; @@ -710,10 +703,10 @@ mod tests { let mut build_client_future1 = pin!(CoordinatorStreamHandle::new_with_init_stream( param.to_proto(), Bitmap::zeros(VirtualNode::COUNT), - |stream_req| async { + |rx| async { Ok(tonic::Response::new( manager - .handle_new_request(stream_req.into_inner().map(Ok).boxed()) + .handle_new_request(ReceiverStream::new(rx).map(Ok).boxed()) .await .unwrap() .boxed(), @@ -778,19 +771,15 @@ mod tests { }); let build_client = |vnode| async { - CoordinatorStreamHandle::new_with_init_stream( - param.to_proto(), - vnode, - |stream_req| async { - Ok(tonic::Response::new( - manager - .handle_new_request(stream_req.into_inner().map(Ok).boxed()) - .await - .unwrap() - .boxed(), - )) - }, - ) + CoordinatorStreamHandle::new_with_init_stream(param.to_proto(), vnode, |rx| async { + Ok(tonic::Response::new( + manager + .handle_new_request(ReceiverStream::new(rx).map(Ok).boxed()) + .await + .unwrap() + .boxed(), + )) + }) .await .unwrap() }; @@ -863,19 +852,15 @@ mod tests { }); let build_client = |vnode| async { - CoordinatorStreamHandle::new_with_init_stream( - param.to_proto(), - vnode, - |stream_req| async { - Ok(tonic::Response::new( - manager - .handle_new_request(stream_req.into_inner().map(Ok).boxed()) - .await - .unwrap() - .boxed(), - )) - }, - ) + CoordinatorStreamHandle::new_with_init_stream(param.to_proto(), vnode, |rx| async { + Ok(tonic::Response::new( + manager + .handle_new_request(ReceiverStream::new(rx).map(Ok).boxed()) + .await + .unwrap() + .boxed(), + )) + }) .await .unwrap() }; diff --git a/src/rpc_client/src/connector_client.rs b/src/rpc_client/src/connector_client.rs index 12386628e5516..4a4036ff4c874 100644 --- a/src/rpc_client/src/connector_client.rs +++ b/src/rpc_client/src/connector_client.rs @@ -17,6 +17,7 @@ use std::fmt::Debug; use std::time::Duration; use anyhow::anyhow; +use futures::TryStreamExt; use risingwave_common::config::{MAX_CONNECTION_WINDOW_SIZE, STREAM_WINDOW_SIZE}; use risingwave_common::monitor::connection::{EndpointExt, TcpConfig}; use risingwave_pb::connector_service::connector_service_client::ConnectorServiceClient; @@ -29,6 +30,7 @@ use risingwave_pb::connector_service::sink_writer_stream_request::{ }; use risingwave_pb::connector_service::sink_writer_stream_response::CommitResponse; use risingwave_pb::connector_service::*; +use tokio_stream::wrappers::ReceiverStream; use tonic::transport::{Channel, Endpoint}; use tonic::Streaming; use tracing::error; @@ -262,7 +264,13 @@ impl ConnectorClient { format: sink_payload_format as i32, })), }, - |req_stream| async move { rpc_client.sink_writer_stream(req_stream).await }, + |rx| async move { + rpc_client + .sink_writer_stream(ReceiverStream::new(rx)) + .await + .map(|response| response.into_inner().map_err(RpcError::from)) + .map_err(RpcError::from) + }, ) .await?; @@ -288,7 +296,13 @@ impl ConnectorClient { StartCoordinator { param: Some(param) }, )), }, - |req_stream| async move { rpc_client.sink_coordinator_stream(req_stream).await }, + |rx| async move { + rpc_client + .sink_coordinator_stream(ReceiverStream::new(rx)) + .await + .map(|response| response.into_inner().map_err(RpcError::from)) + .map_err(RpcError::from) + }, ) .await?; diff --git a/src/rpc_client/src/lib.rs b/src/rpc_client/src/lib.rs index 6afa67ef88efe..f7a90a36e125c 100644 --- a/src/rpc_client/src/lib.rs +++ b/src/rpc_client/src/lib.rs @@ -42,9 +42,7 @@ use rand::prelude::SliceRandom; use risingwave_common::util::addr::HostAddr; use risingwave_pb::common::WorkerNode; use risingwave_pb::meta::heartbeat_request::extra_info; -use tokio::sync::mpsc::{channel, Sender}; -use tokio_stream::wrappers::ReceiverStream; -use tonic::{Request, Response, Status}; +use tokio::sync::mpsc::{channel, Receiver, Sender}; pub mod error; use error::{Result, RpcError}; @@ -171,83 +169,106 @@ macro_rules! meta_rpc_client_method_impl { } } -pub struct BidiStreamHandle { - request_sender: Sender, - response_stream: Peekable>>, +pub const DEFAULT_BUFFER_SIZE: usize = 16; + +pub struct BidiStreamSender { + tx: Sender, +} + +impl BidiStreamSender { + pub async fn send_request(&mut self, request: REQ) -> Result<()> { + self.tx + .send(request) + .await + .map_err(|_| anyhow!("unable to send request {}", type_name::()).into()) + } +} + +pub struct BidiStreamReceiver { + pub stream: Peekable>>, +} + +impl BidiStreamReceiver { + pub async fn next_response(&mut self) -> Result { + self.stream + .next() + .await + .ok_or_else(|| anyhow!("end of response stream"))? + } +} + +pub struct BidiStreamHandle { + pub request_sender: BidiStreamSender, + pub response_stream: BidiStreamReceiver, } -impl Debug for BidiStreamHandle { +impl Debug for BidiStreamHandle { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { f.write_str(type_name::()) } } -impl BidiStreamHandle { +impl BidiStreamHandle { pub fn for_test( request_sender: Sender, - response_stream: BoxStream<'static, std::result::Result>, + response_stream: BoxStream<'static, Result>, ) -> Self { Self { - request_sender, - response_stream: response_stream.peekable(), + request_sender: BidiStreamSender { tx: request_sender }, + response_stream: BidiStreamReceiver { + stream: response_stream.peekable(), + }, } } pub async fn initialize< - F: FnOnce(Request>) -> Fut, - St: Stream> + Send + Unpin + 'static, - Fut: Future, Status>> + Send, + F: FnOnce(Receiver) -> Fut, + St: Stream> + Send + Unpin + 'static, + Fut: Future> + Send, >( first_request: REQ, init_stream_fn: F, ) -> Result<(Self, RSP)> { - const SINK_WRITER_REQUEST_BUFFER_SIZE: usize = 16; - let (request_sender, request_receiver) = channel(SINK_WRITER_REQUEST_BUFFER_SIZE); + let (request_sender, request_receiver) = channel(DEFAULT_BUFFER_SIZE); // Send initial request in case of the blocking receive call from creating streaming request request_sender .send(first_request) .await - .map_err(|err| anyhow!(err.to_string()))?; + .map_err(|_err| anyhow!("unable to send first request of {}", type_name::()))?; - let mut response_stream = - init_stream_fn(Request::new(ReceiverStream::new(request_receiver))) - .await? - .into_inner(); + let mut response_stream = init_stream_fn(request_receiver).await?; let first_response = response_stream .next() .await - .ok_or_else(|| anyhow!("get empty response from start sink request"))??; + .ok_or_else(|| anyhow!("get empty response from firstrequest"))??; Ok(( Self { - request_sender, - response_stream: response_stream.boxed().peekable(), + request_sender: BidiStreamSender { tx: request_sender }, + response_stream: BidiStreamReceiver { + stream: response_stream.boxed().peekable(), + }, }, first_response, )) } pub async fn next_response(&mut self) -> Result { - Ok(self - .response_stream - .next() - .await - .ok_or_else(|| anyhow!("end of response stream"))??) + self.response_stream.next_response().await } pub async fn send_request(&mut self, request: REQ) -> Result<()> { match await_future_with_monitor_error_stream( - &mut self.response_stream, - self.request_sender.send(request), + &mut self.response_stream.stream, + self.request_sender.send_request(request), ) .await { - Ok(send_result) => send_result - .map_err(|_| anyhow!("unable to send request {}", type_name::()).into()), + Ok(send_result) => send_result, Err(None) => Err(anyhow!("end of response stream").into()), - Err(Some(e)) => Err(e.into()), + Err(Some(e)) => Err(e), } } } diff --git a/src/rpc_client/src/sink_coordinate_client.rs b/src/rpc_client/src/sink_coordinate_client.rs index 2afa878e2cd34..0eb8d97e5aa9f 100644 --- a/src/rpc_client/src/sink_coordinate_client.rs +++ b/src/rpc_client/src/sink_coordinate_client.rs @@ -15,7 +15,7 @@ use std::future::Future; use anyhow::anyhow; -use futures::Stream; +use futures::{Stream, TryStreamExt}; use risingwave_common::buffer::Bitmap; use risingwave_pb::connector_service::coordinate_request::{ CommitRequest, StartCoordinationRequest, @@ -24,9 +24,11 @@ use risingwave_pb::connector_service::{ coordinate_request, coordinate_response, CoordinateRequest, CoordinateResponse, PbSinkParam, SinkMetadata, }; +use tokio::sync::mpsc::Receiver; use tokio_stream::wrappers::ReceiverStream; -use tonic::{Request, Response, Status}; +use tonic::{Response, Status}; +use crate::error::RpcError; use crate::{BidiStreamHandle, SinkCoordinationRpcClient}; pub type CoordinatorStreamHandle = BidiStreamHandle; @@ -36,9 +38,9 @@ impl CoordinatorStreamHandle { mut client: SinkCoordinationRpcClient, param: PbSinkParam, vnode_bitmap: Bitmap, - ) -> anyhow::Result { - Self::new_with_init_stream(param, vnode_bitmap, |req_stream| async move { - client.coordinate(req_stream).await + ) -> Result { + Self::new_with_init_stream(param, vnode_bitmap, |rx| async move { + client.coordinate(ReceiverStream::new(rx)).await }) .await } @@ -47,9 +49,9 @@ impl CoordinatorStreamHandle { param: PbSinkParam, vnode_bitmap: Bitmap, init_stream: F, - ) -> anyhow::Result + ) -> Result where - F: FnOnce(Request>) -> Fut, + F: FnOnce(Receiver) -> Fut + Send, St: Stream> + Send + Unpin + 'static, Fut: Future, Status>> + Send, { @@ -62,14 +64,19 @@ impl CoordinatorStreamHandle { }, )), }, - init_stream, + move |rx| async move { + init_stream(rx) + .await + .map(|response| response.into_inner().map_err(RpcError::from)) + .map_err(RpcError::from) + }, ) .await?; match first_response { CoordinateResponse { msg: Some(coordinate_response::Msg::StartResponse(_)), } => Ok(stream_handle), - msg => Err(anyhow!("should get start response but get {:?}", msg)), + msg => Err(anyhow!("should get start response but get {:?}", msg).into()), } } From a7eb77cde98a360da22967ad82454830c029d495 Mon Sep 17 00:00:00 2001 From: William Wen Date: Mon, 30 Oct 2023 13:35:31 +0800 Subject: [PATCH 2/4] better error handling --- src/connector/src/lib.rs | 1 + src/connector/src/sink/remote.rs | 17 ++++-- src/connector/src/source/cdc/source/reader.rs | 32 +++++++++--- src/jni_core/src/jvm_runtime.rs | 52 ++++++++++--------- src/jni_core/src/lib.rs | 15 +++--- 5 files changed, 74 insertions(+), 43 deletions(-) diff --git a/src/connector/src/lib.rs b/src/connector/src/lib.rs index aa613b4043c23..8aa465b6e29c9 100644 --- a/src/connector/src/lib.rs +++ b/src/connector/src/lib.rs @@ -30,6 +30,7 @@ #![feature(iter_from_coroutine)] #![feature(if_let_guard)] #![feature(iterator_try_collect)] +#![feature(try_blocks)] use std::time::Duration; diff --git a/src/connector/src/sink/remote.rs b/src/connector/src/sink/remote.rs index a3e06b8a22cc3..20b4c6005d64b 100644 --- a/src/connector/src/sink/remote.rs +++ b/src/connector/src/sink/remote.rs @@ -689,22 +689,29 @@ impl EmbeddedConnectorClient { &self, class_name: &'static str, method_name: &'static str, - request_rx: JniReceiverType, + mut request_rx: JniReceiverType, ) -> Receiver> { - let (response_tx, response_rx): (JniSenderType, _) = + let (mut response_tx, response_rx): (JniSenderType, _) = mpsc::channel(DEFAULT_BUFFER_SIZE); let jvm = self.jvm; std::thread::spawn(move || { - let mut env = jvm.attach_current_thread().unwrap(); + let mut env = match jvm.attach_current_thread() { + Ok(env) => env, + Err(e) => { + let _ = response_tx + .blocking_send(Err(anyhow!("failed to attach current thread: {:?}", e))); + return; + } + }; let result = env.call_static_method( class_name, method_name, gen_jni_sig!(void f(long, long)), &[ - JValue::from(&request_rx as *const JniReceiverType as i64), - JValue::from(&response_tx as *const JniSenderType as i64), + JValue::from(&mut request_rx as *mut JniReceiverType as i64), + JValue::from(&mut response_tx as *mut JniSenderType as i64), ], ); diff --git a/src/connector/src/source/cdc/source/reader.rs b/src/connector/src/source/cdc/source/reader.rs index d4382b67739e0..1f10f71afc249 100644 --- a/src/connector/src/source/cdc/source/reader.rs +++ b/src/connector/src/source/cdc/source/reader.rs @@ -22,7 +22,7 @@ use jni::objects::JValue; use prost::Message; use risingwave_common::util::addr::HostAddr; use risingwave_jni_core::jvm_runtime::JVM; -use risingwave_jni_core::GetEventStreamJniSender; +use risingwave_jni_core::JniSenderType; use risingwave_pb::connector_service::{GetEventStreamRequest, GetEventStreamResponse}; use tokio::sync::mpsc; @@ -121,7 +121,7 @@ impl CommonSplitReader for CdcSplitReader { properties.insert("table.name".into(), table_name); } - let (tx, mut rx) = mpsc::channel(DEFAULT_CHANNEL_SIZE); + let (mut tx, mut rx) = mpsc::channel(DEFAULT_CHANNEL_SIZE); let jvm = JVM .get() @@ -139,18 +139,33 @@ impl CommonSplitReader for CdcSplitReader { let source_type = get_event_stream_request.source_type.to_string(); std::thread::spawn(move || { - let mut env = jvm.attach_current_thread().unwrap(); + let result: anyhow::Result<_> = try { + let env = jvm.attach_current_thread()?; + + let get_event_stream_request_bytes = + env.byte_array_from_slice(&Message::encode_to_vec(&get_event_stream_request))?; + + (env, get_event_stream_request_bytes) + }; + + let (mut env, get_event_stream_request_bytes) = match result { + Ok(inner) => inner, + Err(e) => { + let _ = tx.blocking_send(Err(anyhow!( + "err before calling runJniDbzSourceThread: {:?}", + e + ))); + return; + } + }; - let get_event_stream_request_bytes = env - .byte_array_from_slice(&Message::encode_to_vec(&get_event_stream_request)) - .unwrap(); let result = env.call_static_method( "com/risingwave/connector/source/core/JniDbzSourceHandler", "runJniDbzSourceThread", "([BJ)V", &[ JValue::Object(&get_event_stream_request_bytes), - JValue::from(&tx as *const GetEventStreamJniSender as i64), + JValue::from(&mut tx as *mut JniSenderType as i64), ], ); @@ -164,7 +179,8 @@ impl CommonSplitReader for CdcSplitReader { } }); - while let Some(GetEventStreamResponse { events, .. }) = rx.recv().await { + while let Some(result) = rx.recv().await { + let GetEventStreamResponse { events, .. } = result?; tracing::trace!("receive events {:?}", events.len()); self.source_ctx .metrics diff --git a/src/jni_core/src/jvm_runtime.rs b/src/jni_core/src/jvm_runtime.rs index 8a51c862506ed..1ed33f181aab6 100644 --- a/src/jni_core/src/jvm_runtime.rs +++ b/src/jni_core/src/jvm_runtime.rs @@ -125,13 +125,11 @@ impl JavaVmWrapper { pub fn register_native_method_for_jvm(jvm: &JavaVM) -> Result<(), jni::errors::Error> { let mut env = jvm .attach_current_thread() - .inspect_err(|e| tracing::error!("jvm attach thread error: {:?}", e)) - .unwrap(); + .inspect_err(|e| tracing::error!("jvm attach thread error: {:?}", e))?; let binding_class = env .find_class("com/risingwave/java/binding/Binding") - .inspect_err(|e| tracing::error!("jvm find class error: {:?}", e)) - .unwrap(); + .inspect_err(|e| tracing::error!("jvm find class error: {:?}", e))?; use crate::*; macro_rules! gen_native_method_array { () => {{ @@ -163,34 +161,40 @@ pub fn register_native_method_for_jvm(jvm: &JavaVM) -> Result<(), jni::errors::E /// Load JVM memory statistics from the runtime. If JVM is not initialized or fail to initialize, return zero. pub fn load_jvm_memory_stats() -> (usize, usize) { if let Ok(jvm) = JVM.get() { - let mut env = jvm.attach_current_thread().unwrap(); - let runtime_instance = env - .call_static_method( + let result: Result<(usize, usize), jni::errors::Error> = try { + let mut env = jvm.attach_current_thread()?; + + let runtime_instance = env.call_static_method( "java/lang/Runtime", "getRuntime", "()Ljava/lang/Runtime;", &[], - ) - .unwrap(); + )?; - let runtime_instance = match runtime_instance { - JValueOwned::Object(o) => o, - _ => unreachable!(), - }; + let runtime_instance = match runtime_instance { + JValueOwned::Object(o) => o, + _ => unreachable!(), + }; - let total_memory = env - .call_method(runtime_instance.as_ref(), "totalMemory", "()J", &[]) - .unwrap() - .j() - .unwrap(); + let total_memory = env + .call_method(runtime_instance.as_ref(), "totalMemory", "()J", &[])? + .j() + .expect("should be long"); - let free_memory = env - .call_method(runtime_instance, "freeMemory", "()J", &[]) - .unwrap() - .j() - .unwrap(); + let free_memory = env + .call_method(runtime_instance, "freeMemory", "()J", &[])? + .j() + .expect("should be long"); - (total_memory as usize, (total_memory - free_memory) as usize) + (total_memory as usize, (total_memory - free_memory) as usize) + }; + match result { + Ok(ret) => ret, + Err(e) => { + error!("failed to collect jvm stats: {:?}", e); + (0, 0) + } + } } else { (0, 0) } diff --git a/src/jni_core/src/lib.rs b/src/jni_core/src/lib.rs index 414d8348760c1..962eb121baed2 100644 --- a/src/jni_core/src/lib.rs +++ b/src/jni_core/src/lib.rs @@ -17,6 +17,7 @@ #![feature(once_cell_try)] #![feature(type_alias_impl_trait)] #![feature(result_option_inspect)] +#![feature(try_blocks)] pub mod hummock_iterator; pub mod jvm_runtime; @@ -60,7 +61,6 @@ use tokio::sync::mpsc::{Receiver, Sender}; use crate::hummock_iterator::HummockJavaBindingIterator; pub use crate::jvm_runtime::register_native_method_for_jvm; use crate::stream_chunk_iterator::{into_iter, StreamChunkRowIterator}; -pub type GetEventStreamJniSender = Sender; static RUNTIME: LazyLock = LazyLock::new(|| tokio::runtime::Runtime::new().unwrap()); @@ -846,6 +846,9 @@ extern "system" fn Java_com_risingwave_java_binding_Binding_iteratorGetArrayValu }) } +pub type JniSenderType = Sender>; +pub type JniReceiverType = Receiver; + /// Send messages to the channel received by `CdcSplitReader`. /// If msg is null, just check whether the channel is closed. /// Return true if sending is successful, otherwise, return false so that caller can stop @@ -853,7 +856,7 @@ extern "system" fn Java_com_risingwave_java_binding_Binding_iteratorGetArrayValu #[no_mangle] extern "system" fn Java_com_risingwave_java_binding_Binding_sendCdcSourceMsgToChannel<'a>( env: EnvParam<'a>, - channel: Pointer<'a, GetEventStreamJniSender>, + channel: Pointer<'a, JniSenderType>, msg: JByteArray<'a>, ) -> jboolean { execute_and_catch(env, move |env| { @@ -869,7 +872,10 @@ extern "system" fn Java_com_risingwave_java_binding_Binding_sendCdcSourceMsgToCh let get_event_stream_response: GetEventStreamResponse = Message::decode(to_guarded_slice(&msg, env)?.deref())?; - match channel.as_ref().blocking_send(get_event_stream_response) { + match channel + .as_ref() + .blocking_send(Ok(get_event_stream_response)) + { Ok(_) => Ok(JNI_TRUE), Err(e) => { tracing::info!("send error. {:?}", e); @@ -879,9 +885,6 @@ extern "system" fn Java_com_risingwave_java_binding_Binding_sendCdcSourceMsgToCh }) } -pub type JniSenderType = Sender>; -pub type JniReceiverType = Receiver; - #[no_mangle] pub extern "system" fn Java_com_risingwave_java_binding_Binding_recvSinkWriterRequestFromChannel< 'a, From aa7af6b32768b24c9bb38261b678f4162ec4fb60 Mon Sep 17 00:00:00 2001 From: William Wen Date: Mon, 30 Oct 2023 17:45:21 +0800 Subject: [PATCH 3/4] do not initialize jvm when collecting metrics --- src/connector/src/sink/remote.rs | 2 +- .../src/source/cdc/enumerator/mod.rs | 2 +- src/jni_core/src/jvm_runtime.rs | 33 +++++++++++-------- 3 files changed, 21 insertions(+), 16 deletions(-) diff --git a/src/connector/src/sink/remote.rs b/src/connector/src/sink/remote.rs index 20b4c6005d64b..dcba09007a677 100644 --- a/src/connector/src/sink/remote.rs +++ b/src/connector/src/sink/remote.rs @@ -153,7 +153,7 @@ impl Sink for RemoteSink { }).try_collect()?; let mut env = JVM - .get_with_err()? + .get()? .attach_current_thread() .map_err(|err| SinkError::Internal(err.into()))?; let validate_sink_request = ValidateSinkRequest { diff --git a/src/connector/src/source/cdc/enumerator/mod.rs b/src/connector/src/source/cdc/enumerator/mod.rs index 803e092ab90a3..10213c5e714bc 100644 --- a/src/connector/src/source/cdc/enumerator/mod.rs +++ b/src/connector/src/source/cdc/enumerator/mod.rs @@ -69,7 +69,7 @@ where SourceType::from(T::source_type()) ); - let mut env = JVM.get_with_err()?.attach_current_thread()?; + let mut env = JVM.get()?.attach_current_thread()?; let validate_source_request = ValidateSourceRequest { source_id: context.info.source_id as u64, diff --git a/src/jni_core/src/jvm_runtime.rs b/src/jni_core/src/jvm_runtime.rs index 1ed33f181aab6..ec833b36577cb 100644 --- a/src/jni_core/src/jvm_runtime.rs +++ b/src/jni_core/src/jvm_runtime.rs @@ -29,22 +29,28 @@ use tracing::error; const DEFAULT_MEMORY_PROPORTION: f64 = 0.07; pub static JVM: JavaVmWrapper = JavaVmWrapper; +static JVM_RESULT: OnceLock> = OnceLock::new(); pub struct JavaVmWrapper; impl JavaVmWrapper { - pub fn get(&self) -> Result<&'static JavaVM, &String> { - static JVM_RESULT: OnceLock> = OnceLock::new(); - JVM_RESULT - .get_or_init(|| { - Self::inner_new().inspect_err(|e| error!("failed to init jvm: {:?}", e)) - }) - .as_ref() + pub fn for_initialized_jvm O>(&self, f: F) -> Option { + JVM_RESULT.get().and_then(|result| { + if let Ok(jvm) = result { + Some(f(jvm)) + } else { + None + } + }) } - pub fn get_with_err(&self) -> anyhow::Result<&'static JavaVM> { - self.get() - .map_err(|e| anyhow!("jvm not initialized properly: {:?}", e)) + pub fn get(&self) -> anyhow::Result<&'static JavaVM> { + match JVM_RESULT.get_or_init(|| { + Self::inner_new().inspect_err(|e| error!("failed to init jvm: {:?}", e)) + }) { + Ok(jvm) => Ok(jvm), + Err(e) => Err(anyhow!("jvm not initialized properly: {:?}", e)), + } } fn inner_new() -> Result { @@ -160,7 +166,7 @@ pub fn register_native_method_for_jvm(jvm: &JavaVM) -> Result<(), jni::errors::E /// Load JVM memory statistics from the runtime. If JVM is not initialized or fail to initialize, return zero. pub fn load_jvm_memory_stats() -> (usize, usize) { - if let Ok(jvm) = JVM.get() { + JVM.for_initialized_jvm(|jvm| { let result: Result<(usize, usize), jni::errors::Error> = try { let mut env = jvm.attach_current_thread()?; @@ -195,7 +201,6 @@ pub fn load_jvm_memory_stats() -> (usize, usize) { (0, 0) } } - } else { - (0, 0) - } + }) + .unwrap_or((0, 0)) } From 7042a94e8471bb460819d0b2146ab53657986422 Mon Sep 17 00:00:00 2001 From: William Wen Date: Tue, 31 Oct 2023 13:28:15 +0800 Subject: [PATCH 4/4] change name to get_or_init --- src/connector/src/sink/remote.rs | 4 +- .../src/source/cdc/enumerator/mod.rs | 2 +- src/connector/src/source/cdc/source/reader.rs | 2 +- src/jni_core/src/jvm_runtime.rs | 81 ++++++++----------- 4 files changed, 39 insertions(+), 50 deletions(-) diff --git a/src/connector/src/sink/remote.rs b/src/connector/src/sink/remote.rs index dcba09007a677..0b5948646637b 100644 --- a/src/connector/src/sink/remote.rs +++ b/src/connector/src/sink/remote.rs @@ -153,7 +153,7 @@ impl Sink for RemoteSink { }).try_collect()?; let mut env = JVM - .get()? + .get_or_init()? .attach_current_thread() .map_err(|err| SinkError::Internal(err.into()))?; let validate_sink_request = ValidateSinkRequest { @@ -613,7 +613,7 @@ struct EmbeddedConnectorClient { impl EmbeddedConnectorClient { fn new() -> Result { - let jvm = JVM.get().map_err(|e| anyhow!("cannot get jvm: {:?}", e))?; + let jvm = JVM.get_or_init()?; Ok(EmbeddedConnectorClient { jvm }) } diff --git a/src/connector/src/source/cdc/enumerator/mod.rs b/src/connector/src/source/cdc/enumerator/mod.rs index 10213c5e714bc..e88440bc876e1 100644 --- a/src/connector/src/source/cdc/enumerator/mod.rs +++ b/src/connector/src/source/cdc/enumerator/mod.rs @@ -69,7 +69,7 @@ where SourceType::from(T::source_type()) ); - let mut env = JVM.get()?.attach_current_thread()?; + let mut env = JVM.get_or_init()?.attach_current_thread()?; let validate_source_request = ValidateSourceRequest { source_id: context.info.source_id as u64, diff --git a/src/connector/src/source/cdc/source/reader.rs b/src/connector/src/source/cdc/source/reader.rs index 1f10f71afc249..491ce18cfd493 100644 --- a/src/connector/src/source/cdc/source/reader.rs +++ b/src/connector/src/source/cdc/source/reader.rs @@ -124,7 +124,7 @@ impl CommonSplitReader for CdcSplitReader { let (mut tx, mut rx) = mpsc::channel(DEFAULT_CHANNEL_SIZE); let jvm = JVM - .get() + .get_or_init() .map_err(|e| anyhow!("jvm not initialized properly: {:?}", e))?; let get_event_stream_request = GetEventStreamRequest { diff --git a/src/jni_core/src/jvm_runtime.rs b/src/jni_core/src/jvm_runtime.rs index ec833b36577cb..d3a6077e1079b 100644 --- a/src/jni_core/src/jvm_runtime.rs +++ b/src/jni_core/src/jvm_runtime.rs @@ -19,7 +19,6 @@ use std::path::Path; use std::sync::OnceLock; use anyhow::anyhow; -use jni::objects::JValueOwned; use jni::strings::JNIString; use jni::{InitArgsBuilder, JNIVersion, JavaVM, NativeMethod}; use risingwave_common::util::resource_util::memory::system_memory_available_bytes; @@ -34,17 +33,7 @@ static JVM_RESULT: OnceLock> = OnceLock::new(); pub struct JavaVmWrapper; impl JavaVmWrapper { - pub fn for_initialized_jvm O>(&self, f: F) -> Option { - JVM_RESULT.get().and_then(|result| { - if let Ok(jvm) = result { - Some(f(jvm)) - } else { - None - } - }) - } - - pub fn get(&self) -> anyhow::Result<&'static JavaVM> { + pub fn get_or_init(&self) -> anyhow::Result<&'static JavaVM> { match JVM_RESULT.get_or_init(|| { Self::inner_new().inspect_err(|e| error!("failed to init jvm: {:?}", e)) }) { @@ -166,41 +155,41 @@ pub fn register_native_method_for_jvm(jvm: &JavaVM) -> Result<(), jni::errors::E /// Load JVM memory statistics from the runtime. If JVM is not initialized or fail to initialize, return zero. pub fn load_jvm_memory_stats() -> (usize, usize) { - JVM.for_initialized_jvm(|jvm| { - let result: Result<(usize, usize), jni::errors::Error> = try { - let mut env = jvm.attach_current_thread()?; - - let runtime_instance = env.call_static_method( - "java/lang/Runtime", - "getRuntime", - "()Ljava/lang/Runtime;", - &[], - )?; - - let runtime_instance = match runtime_instance { - JValueOwned::Object(o) => o, - _ => unreachable!(), + match JVM_RESULT.get() { + Some(Ok(jvm)) => { + let result: Result<(usize, usize), jni::errors::Error> = try { + let mut env = jvm.attach_current_thread()?; + + let runtime_instance = env + .call_static_method( + "java/lang/Runtime", + "getRuntime", + "()Ljava/lang/Runtime;", + &[], + )? + .l() + .expect("should be object"); + + let total_memory = env + .call_method(runtime_instance.as_ref(), "totalMemory", "()J", &[])? + .j() + .expect("should be long"); + + let free_memory = env + .call_method(runtime_instance, "freeMemory", "()J", &[])? + .j() + .expect("should be long"); + + (total_memory as usize, (total_memory - free_memory) as usize) }; - - let total_memory = env - .call_method(runtime_instance.as_ref(), "totalMemory", "()J", &[])? - .j() - .expect("should be long"); - - let free_memory = env - .call_method(runtime_instance, "freeMemory", "()J", &[])? - .j() - .expect("should be long"); - - (total_memory as usize, (total_memory - free_memory) as usize) - }; - match result { - Ok(ret) => ret, - Err(e) => { - error!("failed to collect jvm stats: {:?}", e); - (0, 0) + match result { + Ok(ret) => ret, + Err(e) => { + error!("failed to collect jvm stats: {:?}", e); + (0, 0) + } } } - }) - .unwrap_or((0, 0)) + _ => (0, 0), + } }