From 501acadededce2b3876c9a875087b70c08e39cfe Mon Sep 17 00:00:00 2001 From: William Wen Date: Mon, 30 Oct 2023 12:52:13 +0800 Subject: [PATCH] refactor(connector-node): jni reuse bidi stream handle --- src/connector/src/sink/remote.rs | 460 ++++++------------ .../src/source/cdc/enumerator/mod.rs | 2 +- src/connector/src/source/cdc/source/reader.rs | 7 +- src/jni_core/src/jvm_runtime.rs | 117 +++-- src/jni_core/src/lib.rs | 13 +- src/jni_core/src/macros.rs | 24 +- .../src/manager/sink_coordination/manager.rs | 93 ++-- src/rpc_client/src/connector_client.rs | 18 +- src/rpc_client/src/lib.rs | 89 ++-- src/rpc_client/src/sink_coordinate_client.rs | 25 +- 10 files changed, 351 insertions(+), 497 deletions(-) diff --git a/src/connector/src/sink/remote.rs b/src/connector/src/sink/remote.rs index 3c52cb720dbd4..a3e06b8a22cc3 100644 --- a/src/connector/src/sink/remote.rs +++ b/src/connector/src/sink/remote.rs @@ -12,8 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. +use std::any::Any; use std::collections::HashMap; -use std::fmt::Formatter; use std::future::Future; use std::marker::PhantomData; use std::ops::Deref; @@ -21,25 +21,24 @@ use std::time::Instant; use anyhow::anyhow; use async_trait::async_trait; -use futures::stream::Peekable; use futures::{StreamExt, TryFutureExt, TryStreamExt}; use itertools::Itertools; use jni::objects::{JByteArray, JValue, JValueOwned}; +use jni::JavaVM; use prost::Message; use risingwave_common::array::StreamChunk; use risingwave_common::error::anyhow_error; use risingwave_common::types::DataType; use risingwave_common::util::await_future_with_monitor_error_stream; use risingwave_jni_core::jvm_runtime::JVM; -use risingwave_pb::connector_service::sink_coordinator_stream_request::{ - CommitMetadata, StartCoordinator, -}; +use risingwave_jni_core::{gen_class_name, gen_jni_sig, JniReceiverType, JniSenderType}; +use risingwave_pb::connector_service::sink_coordinator_stream_request::StartCoordinator; use risingwave_pb::connector_service::sink_writer_stream_request::write_batch::json_payload::RowOp; use risingwave_pb::connector_service::sink_writer_stream_request::write_batch::{ JsonPayload, Payload, StreamChunkPayload, }; use risingwave_pb::connector_service::sink_writer_stream_request::{ - Barrier, BeginEpoch, Request as SinkRequest, StartSink, WriteBatch, + Request as SinkRequest, StartSink, }; use risingwave_pb::connector_service::sink_writer_stream_response::CommitResponse; use risingwave_pb::connector_service::{ @@ -47,6 +46,10 @@ use risingwave_pb::connector_service::{ SinkCoordinatorStreamRequest, SinkCoordinatorStreamResponse, SinkMetadata, SinkPayloadFormat, SinkWriterStreamRequest, SinkWriterStreamResponse, ValidateSinkRequest, ValidateSinkResponse, }; +use risingwave_rpc_client::error::RpcError; +use risingwave_rpc_client::{ + BidiStreamReceiver, SinkCoordinatorStreamHandle, SinkWriterStreamHandle, DEFAULT_BUFFER_SIZE, +}; use tokio::sync::mpsc; use tokio::sync::mpsc::{Receiver, Sender}; use tokio_stream::wrappers::ReceiverStream; @@ -150,8 +153,7 @@ impl Sink for RemoteSink { }).try_collect()?; let mut env = JVM - .get_or_init() - .map_err(|err| SinkError::Internal(err.into()))? + .get_with_err()? .attach_current_thread() .map_err(|err| SinkError::Internal(err.into()))?; let validate_sink_request = ValidateSinkRequest { @@ -217,13 +219,13 @@ impl RemoteLogSinker { /// Await the given future while monitoring on error of the receiver stream. async fn await_future_with_monitor_receiver_err>>( - receiver: &mut SinkWriterStreamJniReceiver, + receiver: &mut BidiStreamReceiver, future: F, ) -> Result { - match await_future_with_monitor_error_stream(&mut receiver.response_stream, future).await { - Ok(result) => result, + match await_future_with_monitor_error_stream(&mut receiver.stream, future).await { + Ok(result) => Ok(result?), Err(None) => Err(SinkError::Remote(anyhow!("end of remote receiver stream"))), - Err(Some(err)) => Err(SinkError::Internal(err)), + Err(Some(err)) => Err(SinkError::Remote(err.into())), } } @@ -254,7 +256,7 @@ impl LogSinker for RemoteLogSinker { loop { let (epoch, item): (u64, LogStoreReadItem) = await_future_with_monitor_receiver_err( - &mut sink_writer.stream_handle.response_rx, + &mut sink_writer.stream_handle.response_stream, log_reader.next_item().map_err(SinkError::Internal), ) .await?; @@ -378,155 +380,6 @@ impl Sink for CoordinatedRemoteSink { } } -#[derive(Debug)] -pub struct SinkCoordinatorStreamJniHandle { - request_tx: Sender, - response_rx: Receiver, -} - -impl SinkCoordinatorStreamJniHandle { - pub async fn commit(&mut self, epoch: u64, metadata: Vec) -> Result<()> { - self.request_tx - .send(SinkCoordinatorStreamRequest { - request: Some(sink_coordinator_stream_request::Request::Commit( - CommitMetadata { epoch, metadata }, - )), - }) - .await - .map_err(|err| SinkError::Internal(err.into()))?; - - match self.response_rx.recv().await { - Some(SinkCoordinatorStreamResponse { - response: - Some(sink_coordinator_stream_response::Response::Commit( - sink_coordinator_stream_response::CommitResponse { - epoch: response_epoch, - }, - )), - }) => { - if epoch == response_epoch { - Ok(()) - } else { - Err(SinkError::Internal(anyhow!( - "get different response epoch to commit epoch: {} {}", - epoch, - response_epoch - ))) - } - } - msg => Err(SinkError::Internal(anyhow!( - "should get Commit response but get {:?}", - msg - ))), - } - } -} - -struct SinkWriterStreamJniSender { - request_tx: Sender, -} - -impl SinkWriterStreamJniSender { - pub async fn start_epoch(&mut self, epoch: u64) -> Result<()> { - self.request_tx - .send(SinkWriterStreamRequest { - request: Some(SinkRequest::BeginEpoch(BeginEpoch { epoch })), - }) - .await - .map_err(|err| SinkError::Internal(err.into())) - } - - pub async fn write_batch(&mut self, epoch: u64, batch_id: u64, payload: Payload) -> Result<()> { - self.request_tx - .send(SinkWriterStreamRequest { - request: Some(SinkRequest::WriteBatch(WriteBatch { - epoch, - batch_id, - payload: Some(payload), - })), - }) - .await - .map_err(|err| SinkError::Internal(err.into())) - } - - pub async fn barrier(&mut self, epoch: u64, is_checkpoint: bool) -> Result<()> { - self.request_tx - .send(SinkWriterStreamRequest { - request: Some(SinkRequest::Barrier(Barrier { - epoch, - is_checkpoint, - })), - }) - .await - .map_err(|err| SinkError::Internal(err.into())) - } -} - -struct SinkWriterStreamJniReceiver { - response_stream: Peekable>>, -} - -impl SinkWriterStreamJniReceiver { - async fn next_commit_response(&mut self) -> Result { - match self.response_stream.try_next().await { - Ok(Some(SinkWriterStreamResponse { - response: Some(sink_writer_stream_response::Response::Commit(rsp)), - })) => Ok(rsp), - msg => Err(SinkError::Internal(anyhow!( - "should get Sync response but get {:?}", - msg - ))), - } - } -} - -const DEFAULT_CHANNEL_SIZE: usize = 16; -struct SinkWriterStreamJniHandle { - request_tx: SinkWriterStreamJniSender, - response_rx: SinkWriterStreamJniReceiver, -} - -impl std::fmt::Debug for SinkWriterStreamJniHandle { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - f.debug_struct("SinkWriterStreamJniHandle").finish() - } -} - -impl SinkWriterStreamJniHandle { - async fn start_epoch(&mut self, epoch: u64) -> Result<()> { - await_future_with_monitor_receiver_err( - &mut self.response_rx, - self.request_tx.start_epoch(epoch), - ) - .await - } - - async fn write_batch(&mut self, epoch: u64, batch_id: u64, payload: Payload) -> Result<()> { - await_future_with_monitor_receiver_err( - &mut self.response_rx, - self.request_tx.write_batch(epoch, batch_id, payload), - ) - .await - } - - async fn barrier(&mut self, epoch: u64) -> Result<()> { - await_future_with_monitor_receiver_err( - &mut self.response_rx, - self.request_tx.barrier(epoch, false), - ) - .await - } - - async fn commit(&mut self, epoch: u64) -> Result { - await_future_with_monitor_receiver_err( - &mut self.response_rx, - self.request_tx.barrier(epoch, true), - ) - .await?; - self.response_rx.next_commit_response().await - } -} - pub type RemoteSinkWriter = RemoteSinkWriterInner<(), R>; pub type CoordinatedRemoteSinkWriter = RemoteSinkWriterInner, R>; @@ -535,7 +388,7 @@ pub struct RemoteSinkWriterInner { epoch: Option, batch_id: u64, payload_format: SinkPayloadFormat, - stream_handle: SinkWriterStreamJniHandle, + stream_handle: SinkWriterStreamHandle, json_encoder: JsonEncoder, sink_metrics: SinkMetrics, _phantom: PhantomData<(SM, R)>, @@ -547,84 +400,12 @@ impl RemoteSinkWriterInner { connector_params: ConnectorParams, sink_metrics: SinkMetrics, ) -> Result { - let (request_tx, request_rx) = mpsc::channel(DEFAULT_CHANNEL_SIZE); - let (response_tx, response_rx) = mpsc::channel(DEFAULT_CHANNEL_SIZE); - - let mut response_stream = ReceiverStream::new(response_rx).peekable(); - - std::thread::spawn(move || { - let mut env = JVM.get_or_init().unwrap().attach_current_thread().unwrap(); - - let result = env.call_static_method( - "com/risingwave/connector/JniSinkWriterHandler", - "runJniSinkWriterThread", - "(JJ)V", - &[ - JValue::from(&request_rx as *const Receiver as i64), - JValue::from( - &response_tx as *const Sender> - as i64, - ), - ], - ); - - match result { - Ok(_) => { - tracing::info!("end of jni call runJniSinkWriterThread"); - } - Err(e) => { - tracing::error!("jni call error: {:?}", e); - } - }; - }); - - let sink_writer_stream_request = SinkWriterStreamRequest { - request: Some(SinkRequest::Start(StartSink { - sink_param: Some(param.to_proto()), - format: connector_params.sink_payload_format as i32, - })), - }; - - // First request - request_tx - .send(sink_writer_stream_request) - .await - .map_err(|err| { - SinkError::Internal(anyhow!( - "fail to send start request for connector `{}`: {:?}", - R::SINK_NAME, - err - )) - })?; - - // First response - match response_stream.try_next().await { - Ok(Some(SinkWriterStreamResponse { - response: Some(sink_writer_stream_response::Response::Start(_)), - })) => {} - Ok(msg) => { - return Err(SinkError::Internal(anyhow!( - "should get start response for connector `{}` but get {:?}", - R::SINK_NAME, - msg - ))); - } - Err(e) => return Err(SinkError::Internal(e)), - }; - - tracing::trace!( - "{:?} sink stream started with properties: {:?}", - R::SINK_NAME, - ¶m.properties - ); + let stream_handle = EmbeddedConnectorClient::new()? + .start_sink_writer_stream(param.clone(), connector_params.sink_payload_format) + .await?; let schema = param.schema(); - let stream_handle = SinkWriterStreamJniHandle { - request_tx: SinkWriterStreamJniSender { request_tx }, - response_rx: SinkWriterStreamJniReceiver { response_stream }, - }; - Ok(Self { properties: param.properties, epoch: None, @@ -637,7 +418,6 @@ impl RemoteSinkWriterInner { }) } - #[cfg(test)] fn for_test( response_receiver: Receiver>, request_sender: Sender, @@ -660,14 +440,12 @@ impl RemoteSinkWriterInner { }, ]); - let stream_handle = SinkWriterStreamJniHandle { - request_tx: SinkWriterStreamJniSender { - request_tx: request_sender, - }, - response_rx: SinkWriterStreamJniReceiver { - response_stream: ReceiverStream::new(response_receiver).peekable(), - }, - }; + let stream_handle = SinkWriterStreamHandle::for_test( + request_sender, + ReceiverStream::new(response_receiver) + .map_err(RpcError::from) + .boxed(), + ); RemoteSinkWriter { properties, @@ -795,81 +573,15 @@ where } pub struct RemoteCoordinator { - stream_handle: SinkCoordinatorStreamJniHandle, + stream_handle: SinkCoordinatorStreamHandle, _phantom: PhantomData, } impl RemoteCoordinator { pub async fn new(param: SinkParam) -> Result { - let (request_tx, request_rx) = mpsc::channel(DEFAULT_CHANNEL_SIZE); - let (response_tx, response_rx) = mpsc::channel(DEFAULT_CHANNEL_SIZE); - - let mut stream_handle = SinkCoordinatorStreamJniHandle { - request_tx, - response_rx, - }; - - std::thread::spawn(move || { - let mut env = JVM.get_or_init().unwrap().attach_current_thread().unwrap(); - - let result = env.call_static_method( - "com/risingwave/connector/JniSinkCoordinatorHandler", - "runJniSinkCoordinatorThread", - "(JJ)V", - &[ - JValue::from( - &request_rx as *const Receiver as i64, - ), - JValue::from( - &response_tx as *const Sender as i64, - ), - ], - ); - - match result { - Ok(_) => { - tracing::info!("end of jni call runJniSinkCoordinatorThread"); - } - Err(e) => { - tracing::error!("jni call error: {:?}", e); - } - }; - }); - - let sink_coordinator_stream_request = SinkCoordinatorStreamRequest { - request: Some(sink_coordinator_stream_request::Request::Start( - StartCoordinator { - param: Some(param.to_proto()), - }, - )), - }; - - // First request - stream_handle - .request_tx - .send(sink_coordinator_stream_request) - .await - .map_err(|err| { - SinkError::Internal(anyhow!( - "fail to send start request for connector `{}`: {:?}", - R::SINK_NAME, - err - )) - })?; - - // First response - match stream_handle.response_rx.recv().await { - Some(SinkCoordinatorStreamResponse { - response: Some(sink_coordinator_stream_response::Response::Start(_)), - }) => {} - msg => { - return Err(SinkError::Internal(anyhow!( - "should get start response for connector `{}` but get {:?}", - R::SINK_NAME, - msg - ))); - } - }; + let stream_handle = EmbeddedConnectorClient::new()? + .start_sink_coordinator_stream(param.clone()) + .await?; tracing::trace!( "{:?} RemoteCoordinator started with properties: {:?}", @@ -895,6 +607,120 @@ impl SinkCommitCoordinator for RemoteCoordinator { } } +struct EmbeddedConnectorClient { + jvm: &'static JavaVM, +} + +impl EmbeddedConnectorClient { + fn new() -> Result { + let jvm = JVM.get().map_err(|e| anyhow!("cannot get jvm: {:?}", e))?; + Ok(EmbeddedConnectorClient { jvm }) + } + + async fn start_sink_writer_stream( + &self, + sink_param: SinkParam, + sink_payload_format: SinkPayloadFormat, + ) -> Result { + let (handle, first_rsp) = SinkWriterStreamHandle::initialize( + SinkWriterStreamRequest { + request: Some(SinkRequest::Start(StartSink { + sink_param: Some(sink_param.to_proto()), + format: sink_payload_format as i32, + })), + }, + |rx| async move { + let rx = self.start_jvm_worker_thread( + gen_class_name!(com.risingwave.connector.JniSinkWriterHandler), + "runJniSinkWriterThread", + rx, + ); + Ok(ReceiverStream::new(rx).map_err(RpcError::from)) + }, + ) + .await?; + + match first_rsp { + SinkWriterStreamResponse { + response: Some(sink_writer_stream_response::Response::Start(_)), + } => Ok(handle), + msg => Err(SinkError::Internal(anyhow!( + "should get start response but get {:?}", + msg + ))), + } + } + + pub async fn start_sink_coordinator_stream( + &self, + param: SinkParam, + ) -> Result { + let (handle, first_rsp) = SinkCoordinatorStreamHandle::initialize( + SinkCoordinatorStreamRequest { + request: Some(sink_coordinator_stream_request::Request::Start( + StartCoordinator { + param: Some(param.to_proto()), + }, + )), + }, + |rx| async move { + let rx = self.start_jvm_worker_thread( + gen_class_name!(com.risingwave.connector.JniSinkCoordinatorHandler), + "runJniSinkCoordinatorThread", + rx, + ); + Ok(ReceiverStream::new(rx).map_err(RpcError::from)) + }, + ) + .await?; + + match first_rsp { + SinkCoordinatorStreamResponse { + response: Some(sink_coordinator_stream_response::Response::Start(_)), + } => Ok(handle), + msg => Err(SinkError::Internal(anyhow!( + "should get start response but get {:?}", + msg + ))), + } + } + + fn start_jvm_worker_thread( + &self, + class_name: &'static str, + method_name: &'static str, + request_rx: JniReceiverType, + ) -> Receiver> { + let (response_tx, response_rx): (JniSenderType, _) = + mpsc::channel(DEFAULT_BUFFER_SIZE); + + let jvm = self.jvm; + std::thread::spawn(move || { + let mut env = jvm.attach_current_thread().unwrap(); + + let result = env.call_static_method( + class_name, + method_name, + gen_jni_sig!(void f(long, long)), + &[ + JValue::from(&request_rx as *const JniReceiverType as i64), + JValue::from(&response_tx as *const JniSenderType as i64), + ], + ); + + match result { + Ok(_) => { + tracing::info!("end of jni call {}::{}", class_name, method_name); + } + Err(e) => { + tracing::error!("jni call error: {:?}", e); + } + }; + }); + response_rx + } +} + #[cfg(test)] mod test { use std::time::Duration; diff --git a/src/connector/src/source/cdc/enumerator/mod.rs b/src/connector/src/source/cdc/enumerator/mod.rs index e88440bc876e1..803e092ab90a3 100644 --- a/src/connector/src/source/cdc/enumerator/mod.rs +++ b/src/connector/src/source/cdc/enumerator/mod.rs @@ -69,7 +69,7 @@ where SourceType::from(T::source_type()) ); - let mut env = JVM.get_or_init()?.attach_current_thread()?; + let mut env = JVM.get_with_err()?.attach_current_thread()?; let validate_source_request = ValidateSourceRequest { source_id: context.info.source_id as u64, diff --git a/src/connector/src/source/cdc/source/reader.rs b/src/connector/src/source/cdc/source/reader.rs index 4d25d82c106c3..d4382b67739e0 100644 --- a/src/connector/src/source/cdc/source/reader.rs +++ b/src/connector/src/source/cdc/source/reader.rs @@ -123,8 +123,9 @@ impl CommonSplitReader for CdcSplitReader { let (tx, mut rx) = mpsc::channel(DEFAULT_CHANNEL_SIZE); - // Force init, because we don't want to see initialization failure in the following thread. - JVM.get_or_init()?; + let jvm = JVM + .get() + .map_err(|e| anyhow!("jvm not initialized properly: {:?}", e))?; let get_event_stream_request = GetEventStreamRequest { source_id: self.source_id, @@ -138,7 +139,7 @@ impl CommonSplitReader for CdcSplitReader { let source_type = get_event_stream_request.source_type.to_string(); std::thread::spawn(move || { - let mut env = JVM.get_or_init().unwrap().attach_current_thread().unwrap(); + let mut env = jvm.attach_current_thread().unwrap(); let get_event_stream_request_bytes = env .byte_array_from_slice(&Message::encode_to_vec(&get_event_stream_request)) diff --git a/src/jni_core/src/jvm_runtime.rs b/src/jni_core/src/jvm_runtime.rs index bd1f068b6eaee..8a51c862506ed 100644 --- a/src/jni_core/src/jvm_runtime.rs +++ b/src/jni_core/src/jvm_runtime.rs @@ -18,50 +18,49 @@ use std::fs; use std::path::Path; use std::sync::OnceLock; +use anyhow::anyhow; use jni::objects::JValueOwned; use jni::strings::JNIString; use jni::{InitArgsBuilder, JNIVersion, JavaVM, NativeMethod}; -use risingwave_common::error::{ErrorCode, RwError}; use risingwave_common::util::resource_util::memory::system_memory_available_bytes; +use tracing::error; /// Use 10% of compute total memory by default. Compute node uses 0.7 * system memory by default. const DEFAULT_MEMORY_PROPORTION: f64 = 0.07; -pub static JVM: JavaVmWrapper = JavaVmWrapper::new(); +pub static JVM: JavaVmWrapper = JavaVmWrapper; -pub struct JavaVmWrapper(OnceLock>); +pub struct JavaVmWrapper; impl JavaVmWrapper { - const fn new() -> Self { - Self(OnceLock::new()) + pub fn get(&self) -> Result<&'static JavaVM, &String> { + static JVM_RESULT: OnceLock> = OnceLock::new(); + JVM_RESULT + .get_or_init(|| { + Self::inner_new().inspect_err(|e| error!("failed to init jvm: {:?}", e)) + }) + .as_ref() } - pub fn get(&self) -> Option<&Result> { - self.0.get() + pub fn get_with_err(&self) -> anyhow::Result<&'static JavaVM> { + self.get() + .map_err(|e| anyhow!("jvm not initialized properly: {:?}", e)) } - pub fn get_or_init(&self) -> Result<&JavaVM, &RwError> { - self.0.get_or_init(Self::inner_new).as_ref() - } - - fn inner_new() -> Result { + fn inner_new() -> Result { let libs_path = if let Ok(libs_path) = std::env::var("CONNECTOR_LIBS_PATH") { libs_path } else { - return Err(ErrorCode::InternalError( - "environment variable CONNECTOR_LIBS_PATH is not specified".to_string(), - ) - .into()); + return Err("environment variable CONNECTOR_LIBS_PATH is not specified".to_string()); }; let dir = Path::new(&libs_path); if !dir.is_dir() { - return Err(ErrorCode::InternalError(format!( + return Err(format!( "CONNECTOR_LIBS_PATH \"{}\" is not a directory", libs_path - )) - .into()); + )); } let mut class_vec = vec![]; @@ -70,16 +69,16 @@ impl JavaVmWrapper { for entry in entries.flatten() { let entry_path = entry.path(); if entry_path.file_name().is_some() { - let path = std::fs::canonicalize(entry_path)?; + let path = std::fs::canonicalize(entry_path) + .expect("valid entry_path obtained from fs::read_dir"); class_vec.push(path.to_str().unwrap().to_string()); } } } else { - return Err(ErrorCode::InternalError(format!( + return Err(format!( "failed to read CONNECTOR_LIBS_PATH \"{}\"", libs_path - )) - .into()); + )); } let jvm_heap_size = if let Ok(heap_size) = std::env::var("JVM_HEAP_SIZE") { @@ -101,20 +100,23 @@ impl JavaVmWrapper { .option(format!("-Xmx{}", jvm_heap_size)); tracing::info!("JVM args: {:?}", args_builder); - let jvm_args = args_builder.build().unwrap(); + let jvm_args = args_builder + .build() + .map_err(|e| format!("invalid jvm args: {:?}", e))?; // Create a new VM let jvm = match JavaVM::new(jvm_args) { Err(err) => { tracing::error!("fail to new JVM {:?}", err); - return Err(ErrorCode::InternalError("fail to new JVM".to_string()).into()); + return Err("fail to new JVM".to_string()); } Ok(jvm) => jvm, }; tracing::info!("initialize JVM successfully"); - register_native_method_for_jvm(&jvm).unwrap(); + register_native_method_for_jvm(&jvm) + .map_err(|e| format!("failed to register native method: {:?}", e))?; Ok(jvm) } @@ -160,40 +162,35 @@ pub fn register_native_method_for_jvm(jvm: &JavaVM) -> Result<(), jni::errors::E /// Load JVM memory statistics from the runtime. If JVM is not initialized or fail to initialize, return zero. pub fn load_jvm_memory_stats() -> (usize, usize) { - if let Some(jvm) = JVM.get() { - match jvm { - Ok(jvm) => { - let mut env = jvm.attach_current_thread().unwrap(); - let runtime_instance = env - .call_static_method( - "java/lang/Runtime", - "getRuntime", - "()Ljava/lang/Runtime;", - &[], - ) - .unwrap(); - - let runtime_instance = match runtime_instance { - JValueOwned::Object(o) => o, - _ => unreachable!(), - }; - - let total_memory = env - .call_method(runtime_instance.as_ref(), "totalMemory", "()J", &[]) - .unwrap() - .j() - .unwrap(); - - let free_memory = env - .call_method(runtime_instance, "freeMemory", "()J", &[]) - .unwrap() - .j() - .unwrap(); - - (total_memory as usize, (total_memory - free_memory) as usize) - } - Err(_) => (0, 0), - } + if let Ok(jvm) = JVM.get() { + let mut env = jvm.attach_current_thread().unwrap(); + let runtime_instance = env + .call_static_method( + "java/lang/Runtime", + "getRuntime", + "()Ljava/lang/Runtime;", + &[], + ) + .unwrap(); + + let runtime_instance = match runtime_instance { + JValueOwned::Object(o) => o, + _ => unreachable!(), + }; + + let total_memory = env + .call_method(runtime_instance.as_ref(), "totalMemory", "()J", &[]) + .unwrap() + .j() + .unwrap(); + + let free_memory = env + .call_method(runtime_instance, "freeMemory", "()J", &[]) + .unwrap() + .j() + .unwrap(); + + (total_memory as usize, (total_memory - free_memory) as usize) } else { (0, 0) } diff --git a/src/jni_core/src/lib.rs b/src/jni_core/src/lib.rs index c92ad2f146e6c..414d8348760c1 100644 --- a/src/jni_core/src/lib.rs +++ b/src/jni_core/src/lib.rs @@ -879,12 +879,15 @@ extern "system" fn Java_com_risingwave_java_binding_Binding_sendCdcSourceMsgToCh }) } +pub type JniSenderType = Sender>; +pub type JniReceiverType = Receiver; + #[no_mangle] pub extern "system" fn Java_com_risingwave_java_binding_Binding_recvSinkWriterRequestFromChannel< 'a, >( env: EnvParam<'a>, - mut channel: Pointer<'a, Receiver>, + mut channel: Pointer<'a, JniReceiverType>, ) -> JByteArray<'a> { execute_and_catch(env, move |env| match channel.as_mut().blocking_recv() { Some(msg) => { @@ -902,7 +905,7 @@ pub extern "system" fn Java_com_risingwave_java_binding_Binding_sendSinkWriterRe 'a, >( env: EnvParam<'a>, - channel: Pointer<'a, Sender>>, + channel: Pointer<'a, JniSenderType>, msg: JByteArray<'a>, ) -> jboolean { execute_and_catch(env, move |env| { @@ -927,7 +930,7 @@ pub extern "system" fn Java_com_risingwave_java_binding_Binding_recvSinkCoordina 'a, >( env: EnvParam<'a>, - mut channel: Pointer<'a, Receiver>, + mut channel: Pointer<'a, JniReceiverType>, ) -> JByteArray<'a> { execute_and_catch(env, move |env| match channel.as_mut().blocking_recv() { Some(msg) => { @@ -945,7 +948,7 @@ pub extern "system" fn Java_com_risingwave_java_binding_Binding_sendSinkCoordina 'a, >( env: EnvParam<'a>, - channel: Pointer<'a, Sender>, + channel: Pointer<'a, JniSenderType>, msg: JByteArray<'a>, ) -> jboolean { execute_and_catch(env, move |env| { @@ -954,7 +957,7 @@ pub extern "system" fn Java_com_risingwave_java_binding_Binding_sendSinkCoordina match channel .as_ref() - .blocking_send(sink_coordinator_stream_response) + .blocking_send(Ok(sink_coordinator_stream_response)) { Ok(_) => Ok(JNI_TRUE), Err(e) => { diff --git a/src/jni_core/src/macros.rs b/src/jni_core/src/macros.rs index 0ca0748fb0206..2484fbba9488c 100644 --- a/src/jni_core/src/macros.rs +++ b/src/jni_core/src/macros.rs @@ -18,20 +18,20 @@ macro_rules! gen_class_name { stringify! {$last} }; ($first:ident . $($rest:ident).+) => { - concat! {stringify! {$first}, "/", gen_class_name! {$($rest).+} } + concat! {stringify! {$first}, "/", $crate::gen_class_name! {$($rest).+} } } } #[macro_export] macro_rules! gen_jni_sig_inner { ($(public)? static native $($rest:tt)*) => { - gen_jni_sig_inner! { $($rest)* } + $crate::gen_jni_sig_inner! { $($rest)* } }; ($($ret:ident).+ $($func_name:ident)? ($($args:tt)*)) => { - concat! {"(", gen_jni_sig_inner!{$($args)*}, ")", gen_jni_sig_inner! {$($ret).+} } + concat! {"(", $crate::gen_jni_sig_inner!{$($args)*}, ")", $crate::gen_jni_sig_inner! {$($ret).+} } }; ($($ret:ident).+ [] $($func_name:ident)? ($($args:tt)*)) => { - concat! {"(", gen_jni_sig_inner!{$($args)*}, ")", gen_jni_sig_inner! {$($ret).+ []} } + concat! {"(", $crate::gen_jni_sig_inner!{$($args)*}, ")", $crate::gen_jni_sig_inner! {$($ret).+ []} } }; (boolean) => { "Z" @@ -61,25 +61,25 @@ macro_rules! gen_jni_sig_inner { "V" }; (String) => { - gen_jni_sig_inner! { java.lang.String } + $crate::gen_jni_sig_inner! { java.lang.String } }; (Object) => { - gen_jni_sig_inner! { java.lang.Object } + $crate::gen_jni_sig_inner! { java.lang.Object } }; (Class) => { - gen_jni_sig_inner! { java.lang.Class } + $crate::gen_jni_sig_inner! { java.lang.Class } }; ($($class_part:ident).+) => { - concat! {"L", gen_class_name! {$($class_part).+}, ";"} + concat! {"L", $crate::gen_class_name! {$($class_part).+}, ";"} }; ($($class_part:ident).+ $(.)? [] $($param_name:ident)? $(,$($rest:tt)*)?) => { - concat! { "[", gen_jni_sig_inner! {$($class_part).+}, gen_jni_sig_inner! {$($($rest)*)?}} + concat! { "[", $crate::gen_jni_sig_inner! {$($class_part).+}, $crate::gen_jni_sig_inner! {$($($rest)*)?}} }; (Class $(< ? >)? $($param_name:ident)? $(,$($rest:tt)*)?) => { - concat! { gen_jni_sig_inner! { Class }, gen_jni_sig_inner! {$($($rest)*)?}} + concat! { $crate::gen_jni_sig_inner! { Class }, $crate::gen_jni_sig_inner! {$($($rest)*)?}} }; ($($class_part:ident).+ $($param_name:ident)? $(,$($rest:tt)*)?) => { - concat! { gen_jni_sig_inner! {$($class_part).+}, gen_jni_sig_inner! {$($($rest)*)?}} + concat! { $crate::gen_jni_sig_inner! {$($class_part).+}, $crate::gen_jni_sig_inner! {$($($rest)*)?}} }; () => { "" @@ -93,7 +93,7 @@ macro_rules! gen_jni_sig_inner { macro_rules! gen_jni_sig { ($($input:tt)*) => {{ // this macro only provide with a expression context - gen_jni_sig_inner! {$($input)*} + $crate::gen_jni_sig_inner! {$($input)*} }} } diff --git a/src/meta/src/manager/sink_coordination/manager.rs b/src/meta/src/manager/sink_coordination/manager.rs index 720a698fa8e72..4c2b2a37295b1 100644 --- a/src/meta/src/manager/sink_coordination/manager.rs +++ b/src/meta/src/manager/sink_coordination/manager.rs @@ -362,6 +362,7 @@ mod tests { use risingwave_pb::connector_service::sink_metadata::{Metadata, SerializedMetadata}; use risingwave_pb::connector_service::SinkMetadata; use risingwave_rpc_client::CoordinatorStreamHandle; + use tokio_stream::wrappers::ReceiverStream; use crate::manager::sink_coordination::coordinator_worker::CoordinatorWorker; use crate::manager::sink_coordination::{NewSinkWriterRequest, SinkCoordinatorManager}; @@ -481,19 +482,15 @@ mod tests { }); let build_client = |vnode| async { - CoordinatorStreamHandle::new_with_init_stream( - param.to_proto(), - vnode, - |stream_req| async { - Ok(tonic::Response::new( - manager - .handle_new_request(stream_req.into_inner().map(Ok).boxed()) - .await - .unwrap() - .boxed(), - )) - }, - ) + CoordinatorStreamHandle::new_with_init_stream(param.to_proto(), vnode, |rx| async { + Ok(tonic::Response::new( + manager + .handle_new_request(ReceiverStream::new(rx).map(Ok).boxed()) + .await + .unwrap() + .boxed(), + )) + }) .await .unwrap() }; @@ -647,19 +644,15 @@ mod tests { }); let build_client = |vnode| async { - CoordinatorStreamHandle::new_with_init_stream( - param.to_proto(), - vnode, - |stream_req| async { - Ok(tonic::Response::new( - manager - .handle_new_request(stream_req.into_inner().map(Ok).boxed()) - .await - .unwrap() - .boxed(), - )) - }, - ) + CoordinatorStreamHandle::new_with_init_stream(param.to_proto(), vnode, |rx| async { + Ok(tonic::Response::new( + manager + .handle_new_request(ReceiverStream::new(rx).map(Ok).boxed()) + .await + .unwrap() + .boxed(), + )) + }) .await .unwrap() }; @@ -710,10 +703,10 @@ mod tests { let mut build_client_future1 = pin!(CoordinatorStreamHandle::new_with_init_stream( param.to_proto(), Bitmap::zeros(VirtualNode::COUNT), - |stream_req| async { + |rx| async { Ok(tonic::Response::new( manager - .handle_new_request(stream_req.into_inner().map(Ok).boxed()) + .handle_new_request(ReceiverStream::new(rx).map(Ok).boxed()) .await .unwrap() .boxed(), @@ -778,19 +771,15 @@ mod tests { }); let build_client = |vnode| async { - CoordinatorStreamHandle::new_with_init_stream( - param.to_proto(), - vnode, - |stream_req| async { - Ok(tonic::Response::new( - manager - .handle_new_request(stream_req.into_inner().map(Ok).boxed()) - .await - .unwrap() - .boxed(), - )) - }, - ) + CoordinatorStreamHandle::new_with_init_stream(param.to_proto(), vnode, |rx| async { + Ok(tonic::Response::new( + manager + .handle_new_request(ReceiverStream::new(rx).map(Ok).boxed()) + .await + .unwrap() + .boxed(), + )) + }) .await .unwrap() }; @@ -863,19 +852,15 @@ mod tests { }); let build_client = |vnode| async { - CoordinatorStreamHandle::new_with_init_stream( - param.to_proto(), - vnode, - |stream_req| async { - Ok(tonic::Response::new( - manager - .handle_new_request(stream_req.into_inner().map(Ok).boxed()) - .await - .unwrap() - .boxed(), - )) - }, - ) + CoordinatorStreamHandle::new_with_init_stream(param.to_proto(), vnode, |rx| async { + Ok(tonic::Response::new( + manager + .handle_new_request(ReceiverStream::new(rx).map(Ok).boxed()) + .await + .unwrap() + .boxed(), + )) + }) .await .unwrap() }; diff --git a/src/rpc_client/src/connector_client.rs b/src/rpc_client/src/connector_client.rs index 12386628e5516..4a4036ff4c874 100644 --- a/src/rpc_client/src/connector_client.rs +++ b/src/rpc_client/src/connector_client.rs @@ -17,6 +17,7 @@ use std::fmt::Debug; use std::time::Duration; use anyhow::anyhow; +use futures::TryStreamExt; use risingwave_common::config::{MAX_CONNECTION_WINDOW_SIZE, STREAM_WINDOW_SIZE}; use risingwave_common::monitor::connection::{EndpointExt, TcpConfig}; use risingwave_pb::connector_service::connector_service_client::ConnectorServiceClient; @@ -29,6 +30,7 @@ use risingwave_pb::connector_service::sink_writer_stream_request::{ }; use risingwave_pb::connector_service::sink_writer_stream_response::CommitResponse; use risingwave_pb::connector_service::*; +use tokio_stream::wrappers::ReceiverStream; use tonic::transport::{Channel, Endpoint}; use tonic::Streaming; use tracing::error; @@ -262,7 +264,13 @@ impl ConnectorClient { format: sink_payload_format as i32, })), }, - |req_stream| async move { rpc_client.sink_writer_stream(req_stream).await }, + |rx| async move { + rpc_client + .sink_writer_stream(ReceiverStream::new(rx)) + .await + .map(|response| response.into_inner().map_err(RpcError::from)) + .map_err(RpcError::from) + }, ) .await?; @@ -288,7 +296,13 @@ impl ConnectorClient { StartCoordinator { param: Some(param) }, )), }, - |req_stream| async move { rpc_client.sink_coordinator_stream(req_stream).await }, + |rx| async move { + rpc_client + .sink_coordinator_stream(ReceiverStream::new(rx)) + .await + .map(|response| response.into_inner().map_err(RpcError::from)) + .map_err(RpcError::from) + }, ) .await?; diff --git a/src/rpc_client/src/lib.rs b/src/rpc_client/src/lib.rs index 6afa67ef88efe..f7a90a36e125c 100644 --- a/src/rpc_client/src/lib.rs +++ b/src/rpc_client/src/lib.rs @@ -42,9 +42,7 @@ use rand::prelude::SliceRandom; use risingwave_common::util::addr::HostAddr; use risingwave_pb::common::WorkerNode; use risingwave_pb::meta::heartbeat_request::extra_info; -use tokio::sync::mpsc::{channel, Sender}; -use tokio_stream::wrappers::ReceiverStream; -use tonic::{Request, Response, Status}; +use tokio::sync::mpsc::{channel, Receiver, Sender}; pub mod error; use error::{Result, RpcError}; @@ -171,83 +169,106 @@ macro_rules! meta_rpc_client_method_impl { } } -pub struct BidiStreamHandle { - request_sender: Sender, - response_stream: Peekable>>, +pub const DEFAULT_BUFFER_SIZE: usize = 16; + +pub struct BidiStreamSender { + tx: Sender, +} + +impl BidiStreamSender { + pub async fn send_request(&mut self, request: REQ) -> Result<()> { + self.tx + .send(request) + .await + .map_err(|_| anyhow!("unable to send request {}", type_name::()).into()) + } +} + +pub struct BidiStreamReceiver { + pub stream: Peekable>>, +} + +impl BidiStreamReceiver { + pub async fn next_response(&mut self) -> Result { + self.stream + .next() + .await + .ok_or_else(|| anyhow!("end of response stream"))? + } +} + +pub struct BidiStreamHandle { + pub request_sender: BidiStreamSender, + pub response_stream: BidiStreamReceiver, } -impl Debug for BidiStreamHandle { +impl Debug for BidiStreamHandle { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { f.write_str(type_name::()) } } -impl BidiStreamHandle { +impl BidiStreamHandle { pub fn for_test( request_sender: Sender, - response_stream: BoxStream<'static, std::result::Result>, + response_stream: BoxStream<'static, Result>, ) -> Self { Self { - request_sender, - response_stream: response_stream.peekable(), + request_sender: BidiStreamSender { tx: request_sender }, + response_stream: BidiStreamReceiver { + stream: response_stream.peekable(), + }, } } pub async fn initialize< - F: FnOnce(Request>) -> Fut, - St: Stream> + Send + Unpin + 'static, - Fut: Future, Status>> + Send, + F: FnOnce(Receiver) -> Fut, + St: Stream> + Send + Unpin + 'static, + Fut: Future> + Send, >( first_request: REQ, init_stream_fn: F, ) -> Result<(Self, RSP)> { - const SINK_WRITER_REQUEST_BUFFER_SIZE: usize = 16; - let (request_sender, request_receiver) = channel(SINK_WRITER_REQUEST_BUFFER_SIZE); + let (request_sender, request_receiver) = channel(DEFAULT_BUFFER_SIZE); // Send initial request in case of the blocking receive call from creating streaming request request_sender .send(first_request) .await - .map_err(|err| anyhow!(err.to_string()))?; + .map_err(|_err| anyhow!("unable to send first request of {}", type_name::()))?; - let mut response_stream = - init_stream_fn(Request::new(ReceiverStream::new(request_receiver))) - .await? - .into_inner(); + let mut response_stream = init_stream_fn(request_receiver).await?; let first_response = response_stream .next() .await - .ok_or_else(|| anyhow!("get empty response from start sink request"))??; + .ok_or_else(|| anyhow!("get empty response from firstrequest"))??; Ok(( Self { - request_sender, - response_stream: response_stream.boxed().peekable(), + request_sender: BidiStreamSender { tx: request_sender }, + response_stream: BidiStreamReceiver { + stream: response_stream.boxed().peekable(), + }, }, first_response, )) } pub async fn next_response(&mut self) -> Result { - Ok(self - .response_stream - .next() - .await - .ok_or_else(|| anyhow!("end of response stream"))??) + self.response_stream.next_response().await } pub async fn send_request(&mut self, request: REQ) -> Result<()> { match await_future_with_monitor_error_stream( - &mut self.response_stream, - self.request_sender.send(request), + &mut self.response_stream.stream, + self.request_sender.send_request(request), ) .await { - Ok(send_result) => send_result - .map_err(|_| anyhow!("unable to send request {}", type_name::()).into()), + Ok(send_result) => send_result, Err(None) => Err(anyhow!("end of response stream").into()), - Err(Some(e)) => Err(e.into()), + Err(Some(e)) => Err(e), } } } diff --git a/src/rpc_client/src/sink_coordinate_client.rs b/src/rpc_client/src/sink_coordinate_client.rs index 2afa878e2cd34..0eb8d97e5aa9f 100644 --- a/src/rpc_client/src/sink_coordinate_client.rs +++ b/src/rpc_client/src/sink_coordinate_client.rs @@ -15,7 +15,7 @@ use std::future::Future; use anyhow::anyhow; -use futures::Stream; +use futures::{Stream, TryStreamExt}; use risingwave_common::buffer::Bitmap; use risingwave_pb::connector_service::coordinate_request::{ CommitRequest, StartCoordinationRequest, @@ -24,9 +24,11 @@ use risingwave_pb::connector_service::{ coordinate_request, coordinate_response, CoordinateRequest, CoordinateResponse, PbSinkParam, SinkMetadata, }; +use tokio::sync::mpsc::Receiver; use tokio_stream::wrappers::ReceiverStream; -use tonic::{Request, Response, Status}; +use tonic::{Response, Status}; +use crate::error::RpcError; use crate::{BidiStreamHandle, SinkCoordinationRpcClient}; pub type CoordinatorStreamHandle = BidiStreamHandle; @@ -36,9 +38,9 @@ impl CoordinatorStreamHandle { mut client: SinkCoordinationRpcClient, param: PbSinkParam, vnode_bitmap: Bitmap, - ) -> anyhow::Result { - Self::new_with_init_stream(param, vnode_bitmap, |req_stream| async move { - client.coordinate(req_stream).await + ) -> Result { + Self::new_with_init_stream(param, vnode_bitmap, |rx| async move { + client.coordinate(ReceiverStream::new(rx)).await }) .await } @@ -47,9 +49,9 @@ impl CoordinatorStreamHandle { param: PbSinkParam, vnode_bitmap: Bitmap, init_stream: F, - ) -> anyhow::Result + ) -> Result where - F: FnOnce(Request>) -> Fut, + F: FnOnce(Receiver) -> Fut + Send, St: Stream> + Send + Unpin + 'static, Fut: Future, Status>> + Send, { @@ -62,14 +64,19 @@ impl CoordinatorStreamHandle { }, )), }, - init_stream, + move |rx| async move { + init_stream(rx) + .await + .map(|response| response.into_inner().map_err(RpcError::from)) + .map_err(RpcError::from) + }, ) .await?; match first_response { CoordinateResponse { msg: Some(coordinate_response::Msg::StartResponse(_)), } => Ok(stream_handle), - msg => Err(anyhow!("should get start response but get {:?}", msg)), + msg => Err(anyhow!("should get start response but get {:?}", msg).into()), } }