diff --git a/java/com_risingwave_java_binding_Binding.h b/java/com_risingwave_java_binding_Binding.h index a3e9aa95ec84..b28989669956 100644 --- a/java/com_risingwave_java_binding_Binding.h +++ b/java/com_risingwave_java_binding_Binding.h @@ -231,6 +231,22 @@ JNIEXPORT jlong JNICALL Java_com_risingwave_java_binding_Binding_streamChunkIter JNIEXPORT jboolean JNICALL Java_com_risingwave_java_binding_Binding_sendCdcSourceMsgToChannel (JNIEnv *, jclass, jlong, jbyteArray); +/* + * Class: com_risingwave_java_binding_Binding + * Method: recvSinkWriterRequestFromChannel + * Signature: (J)[B + */ +JNIEXPORT jbyteArray JNICALL Java_com_risingwave_java_binding_Binding_recvSinkWriterRequestFromChannel + (JNIEnv *, jclass, jlong); + +/* + * Class: com_risingwave_java_binding_Binding + * Method: sendSinkWriterResponseToChannel + * Signature: (J[B)Z + */ +JNIEXPORT jboolean JNICALL Java_com_risingwave_java_binding_Binding_sendSinkWriterResponseToChannel + (JNIEnv *, jclass, jlong, jbyteArray); + #ifdef __cplusplus } #endif diff --git a/java/connector-node/risingwave-connector-service/src/main/java/com/risingwave/connector/JniSinkResponseObserver.java b/java/connector-node/risingwave-connector-service/src/main/java/com/risingwave/connector/JniSinkResponseObserver.java new file mode 100644 index 000000000000..493a10433c53 --- /dev/null +++ b/java/connector-node/risingwave-connector-service/src/main/java/com/risingwave/connector/JniSinkResponseObserver.java @@ -0,0 +1,53 @@ +// Copyright 2023 RisingWave Labs +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package com.risingwave.connector; + +import com.risingwave.java.binding.Binding; +import com.risingwave.proto.ConnectorServiceProto; +import io.grpc.stub.StreamObserver; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public class JniSinkResponseObserver + implements StreamObserver { + private static final Logger LOG = LoggerFactory.getLogger(JniSinkResponseObserver.class); + private long responseTxPtr; + + private boolean success; + + public JniSinkResponseObserver(long responseTxPtr) { + this.responseTxPtr = responseTxPtr; + } + + @Override + public void onNext(ConnectorServiceProto.SinkWriterStreamResponse response) { + this.success = + Binding.sendSinkWriterResponseToChannel(this.responseTxPtr, response.toByteArray()); + } + + @Override + public void onError(Throwable throwable) { + LOG.error("JniSinkWriterHandler onError: ", throwable); + } + + @Override + public void onCompleted() { + LOG.info("JniSinkWriterHandler onCompleted"); + } + + public boolean isSuccess() { + return success; + } +} diff --git a/java/connector-node/risingwave-connector-service/src/main/java/com/risingwave/connector/JniSinkWriterHandler.java b/java/connector-node/risingwave-connector-service/src/main/java/com/risingwave/connector/JniSinkWriterHandler.java new file mode 100644 index 000000000000..fbafb0161847 --- /dev/null +++ b/java/connector-node/risingwave-connector-service/src/main/java/com/risingwave/connector/JniSinkWriterHandler.java @@ -0,0 +1,48 @@ +// Copyright 2023 RisingWave Labs +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package com.risingwave.connector; + +import com.risingwave.java.binding.Binding; +import com.risingwave.proto.ConnectorServiceProto; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public class JniSinkWriterHandler { + private static final Logger LOG = LoggerFactory.getLogger(JniSinkWriterHandler.class); + + public static void runJniSinkWriterThread(long requestRxPtr, long responseTxPtr) { + // For jni.rs + java.lang.Thread.currentThread() + .setContextClassLoader(java.lang.ClassLoader.getSystemClassLoader()); + JniSinkResponseObserver responseObserver = new JniSinkResponseObserver(responseTxPtr); + SinkWriterStreamObserver sinkWriterStreamObserver = + new SinkWriterStreamObserver(responseObserver); + try { + byte[] requestBytes; + while ((requestBytes = Binding.recvSinkWriterRequestFromChannel(requestRxPtr)) + != null) { + var request = ConnectorServiceProto.SinkWriterStreamRequest.parseFrom(requestBytes); + sinkWriterStreamObserver.onNext(request); + if (!responseObserver.isSuccess()) { + throw new RuntimeException("fail to sendSinkWriterResponseToChannel"); + } + } + sinkWriterStreamObserver.onCompleted(); + } catch (Throwable t) { + sinkWriterStreamObserver.onError(t); + } + LOG.info("end of runJniSinkWriterThread"); + } +} diff --git a/java/java-binding/src/main/java/com/risingwave/java/binding/Binding.java b/java/java-binding/src/main/java/com/risingwave/java/binding/Binding.java index 4a79033b147a..f72c63ae6d3e 100644 --- a/java/java-binding/src/main/java/com/risingwave/java/binding/Binding.java +++ b/java/java-binding/src/main/java/com/risingwave/java/binding/Binding.java @@ -91,4 +91,8 @@ public class Binding { static native long streamChunkIteratorFromPretty(String str); public static native boolean sendCdcSourceMsgToChannel(long channelPtr, byte[] msg); + + public static native byte[] recvSinkWriterRequestFromChannel(long channelPtr); + + public static native boolean sendSinkWriterResponseToChannel(long channelPtr, byte[] msg); } diff --git a/src/connector/src/sink/remote.rs b/src/connector/src/sink/remote.rs index e7d20d69fbbd..9c0fc6e84e97 100644 --- a/src/connector/src/sink/remote.rs +++ b/src/connector/src/sink/remote.rs @@ -29,18 +29,18 @@ use risingwave_pb::connector_service::sink_writer_stream_request::write_batch::j use risingwave_pb::connector_service::sink_writer_stream_request::write_batch::{ JsonPayload, Payload, StreamChunkPayload, }; +use risingwave_pb::connector_service::sink_writer_stream_request::{ + Barrier, BeginEpoch, Request as SinkRequest, StartSink, WriteBatch, +}; use risingwave_pb::connector_service::sink_writer_stream_response::CommitResponse; use risingwave_pb::connector_service::{ - SinkMetadata, SinkPayloadFormat, ValidateSinkRequest, ValidateSinkResponse, + sink_writer_stream_response, SinkMetadata, SinkPayloadFormat, SinkWriterStreamRequest, + SinkWriterStreamResponse, ValidateSinkRequest, ValidateSinkResponse, }; -#[cfg(test)] -use risingwave_pb::connector_service::{SinkWriterStreamRequest, SinkWriterStreamResponse}; -use risingwave_rpc_client::{ConnectorClient, SinkCoordinatorStreamHandle, SinkWriterStreamHandle}; -#[cfg(test)] -use tokio::sync::mpsc::{Sender, UnboundedReceiver}; -#[cfg(test)] -use tonic::Status; -use tracing::{error, warn}; +use risingwave_rpc_client::{ConnectorClient, SinkCoordinatorStreamHandle}; +use tokio::sync::mpsc; +use tokio::sync::mpsc::{Receiver, Sender}; +use tracing::warn; use super::encoder::{JsonEncoder, RowEncoder}; use crate::sink::coordinate::CoordinatedSinkWriter; @@ -242,6 +242,71 @@ impl Sink for CoordinatedRemoteSink { } } +const DEFAULT_CHANNEL_SIZE: usize = 16; +#[derive(Debug)] +pub struct SinkWriterStreamJniHandle { + request_tx: Sender, + response_rx: Receiver, +} + +impl SinkWriterStreamJniHandle { + pub async fn start_epoch(&mut self, epoch: u64) -> Result<()> { + self.request_tx + .send(SinkWriterStreamRequest { + request: Some(SinkRequest::BeginEpoch(BeginEpoch { epoch })), + }) + .await + .map_err(|err| SinkError::Internal(err.into())) + } + + pub async fn write_batch(&mut self, epoch: u64, batch_id: u64, payload: Payload) -> Result<()> { + self.request_tx + .send(SinkWriterStreamRequest { + request: Some(SinkRequest::WriteBatch(WriteBatch { + epoch, + batch_id, + payload: Some(payload), + })), + }) + .await + .map_err(|err| SinkError::Internal(err.into())) + } + + pub async fn barrier(&mut self, epoch: u64) -> Result<()> { + self.request_tx + .send(SinkWriterStreamRequest { + request: Some(SinkRequest::Barrier(Barrier { + epoch, + is_checkpoint: false, + })), + }) + .await + .map_err(|err| SinkError::Internal(err.into())) + } + + pub async fn commit(&mut self, epoch: u64) -> Result { + self.request_tx + .send(SinkWriterStreamRequest { + request: Some(SinkRequest::Barrier(Barrier { + epoch, + is_checkpoint: true, + })), + }) + .await + .map_err(|err| SinkError::Internal(err.into()))?; + + match self.response_rx.recv().await { + Some(SinkWriterStreamResponse { + response: Some(sink_writer_stream_response::Response::Commit(rsp)), + }) => Ok(rsp), + msg => Err(SinkError::Internal(anyhow!( + "should get Sync response but get {:?}", + msg + ))), + } + } +} + pub type RemoteSinkWriter = RemoteSinkWriterInner<(), R>; pub type CoordinatedRemoteSinkWriter = RemoteSinkWriterInner, R>; @@ -250,28 +315,78 @@ pub struct RemoteSinkWriterInner { epoch: Option, batch_id: u64, payload_format: SinkPayloadFormat, - stream_handle: SinkWriterStreamHandle, + stream_handle: SinkWriterStreamJniHandle, json_encoder: JsonEncoder, _phantom: PhantomData<(SM, R)>, } impl RemoteSinkWriterInner { pub async fn new(param: SinkParam, connector_params: ConnectorParams) -> Result { - let client = connector_params.connector_client.ok_or_else(|| { - SinkError::Remote(anyhow_error!( - "connector node endpoint not specified or unable to connect to connector node" - )) - })?; - let stream_handle = client - .start_sink_writer_stream(param.to_proto(), connector_params.sink_payload_format) + let (request_tx, request_rx) = mpsc::channel(DEFAULT_CHANNEL_SIZE); + let (response_tx, response_rx) = mpsc::channel(DEFAULT_CHANNEL_SIZE); + + let mut stream_handle = SinkWriterStreamJniHandle { + request_tx, + response_rx, + }; + + std::thread::spawn(move || { + let mut env = JVM.as_ref().unwrap().attach_current_thread().unwrap(); + + let result = env.call_static_method( + "com/risingwave/connector/JniSinkWriterHandler", + "runJniSinkWriterThread", + "(JJ)V", + &[ + JValue::from(&request_rx as *const Receiver as i64), + JValue::from(&response_tx as *const Sender as i64), + ], + ); + + match result { + Ok(_) => { + tracing::info!("end of jni call runJniSinkWriterThread"); + } + Err(e) => { + tracing::error!("jni call error: {:?}", e); + } + }; + }); + + let sink_writer_stream_request = SinkWriterStreamRequest { + request: Some(SinkRequest::Start(StartSink { + sink_param: Some(param.to_proto()), + format: connector_params.sink_payload_format as i32, + })), + }; + + // First request + stream_handle + .request_tx + .send(sink_writer_stream_request) .await - .inspect_err(|e| { - error!( - "failed to start sink stream for connector `{}`: {:?}", + .map_err(|err| { + SinkError::Internal(anyhow!( + "fail to send start request for connector `{}`: {:?}", R::SINK_NAME, - e - ) + err + )) })?; + + // First response + match stream_handle.response_rx.recv().await { + Some(SinkWriterStreamResponse { + response: Some(sink_writer_stream_response::Response::Start(_)), + }) => {} + msg => { + return Err(SinkError::Internal(anyhow!( + "should get start response for connector `{}` but get {:?}", + R::SINK_NAME, + msg + ))); + } + }; + tracing::trace!( "{:?} sink stream started with properties: {:?}", R::SINK_NAME, @@ -293,7 +408,7 @@ impl RemoteSinkWriterInner { #[cfg(test)] fn for_test( - response_receiver: UnboundedReceiver>, + response_receiver: Receiver, request_sender: Sender, ) -> RemoteSinkWriter { use risingwave_common::catalog::{Field, Schema}; @@ -314,13 +429,10 @@ impl RemoteSinkWriterInner { }, ]); - use futures::StreamExt; - use tokio_stream::wrappers::UnboundedReceiverStream; - - let stream_handle = SinkWriterStreamHandle::for_test( - request_sender, - UnboundedReceiverStream::new(response_receiver).boxed(), - ); + let stream_handle = SinkWriterStreamJniHandle { + request_tx: request_sender, + response_rx: response_receiver, + }; RemoteSinkWriter { properties, @@ -494,7 +606,7 @@ mod test { #[tokio::test] async fn test_epoch_check() { let (request_sender, mut request_recv) = mpsc::channel(16); - let (_, resp_recv) = mpsc::unbounded_channel(); + let (_, resp_recv) = mpsc::channel(16); let mut sink = >::for_test(resp_recv, request_sender); let chunk = StreamChunk::from_pretty( @@ -532,7 +644,7 @@ mod test { #[tokio::test] async fn test_remote_sink() { let (request_sender, mut request_receiver) = mpsc::channel(16); - let (response_sender, response_receiver) = mpsc::unbounded_channel(); + let (response_sender, response_receiver) = mpsc::channel(16); let mut sink = >::for_test(response_receiver, request_sender); let chunk_a = StreamChunk::from_pretty( @@ -588,12 +700,13 @@ mod test { // test commit response_sender - .send(Ok(SinkWriterStreamResponse { + .send(SinkWriterStreamResponse { response: Some(Response::Commit(CommitResponse { epoch: 2022, metadata: None, })), - })) + }) + .await .expect("test failed: failed to sync epoch"); sink.barrier(true).await.unwrap(); let commit_request = request_receiver.recv().await.unwrap(); diff --git a/src/connector/src/source/cdc/source/reader.rs b/src/connector/src/source/cdc/source/reader.rs index d4b20a86d7a2..d6210b4146af 100644 --- a/src/connector/src/source/cdc/source/reader.rs +++ b/src/connector/src/source/cdc/source/reader.rs @@ -135,11 +135,7 @@ impl CommonSplitReader for CdcSplitReader { let source_type = get_event_stream_request.source_type.to_string(); std::thread::spawn(move || { - let mut env = JVM - .as_ref() - .unwrap() - .attach_current_thread_as_daemon() - .unwrap(); + let mut env = JVM.as_ref().unwrap().attach_current_thread().unwrap(); let get_event_stream_request_bytes = env .byte_array_from_slice(&Message::encode_to_vec(&get_event_stream_request)) diff --git a/src/jni_core/src/lib.rs b/src/jni_core/src/lib.rs index 17e42d9fc8f5..3e6a12943b49 100644 --- a/src/jni_core/src/lib.rs +++ b/src/jni_core/src/lib.rs @@ -47,11 +47,13 @@ use risingwave_common::row::{OwnedRow, Row}; use risingwave_common::test_prelude::StreamChunkTestExt; use risingwave_common::types::ScalarRefImpl; use risingwave_common::util::panic::rw_catch_unwind; -use risingwave_pb::connector_service::GetEventStreamResponse; +use risingwave_pb::connector_service::{ + GetEventStreamResponse, SinkWriterStreamRequest, SinkWriterStreamResponse, +}; use risingwave_storage::error::StorageError; use thiserror::Error; use tokio::runtime::Runtime; -use tokio::sync::mpsc::Sender; +use tokio::sync::mpsc::{Receiver, Sender}; pub use crate::jvm_runtime::register_native_method_for_jvm; use crate::stream_chunk_iterator::{StreamChunkIterator, StreamChunkRow}; @@ -853,14 +855,50 @@ pub extern "system" fn Java_com_risingwave_java_binding_Binding_sendCdcSourceMsg let get_event_stream_response: GetEventStreamResponse = Message::decode(to_guarded_slice(&msg, env)?.deref())?; - tracing::debug!("before send"); match channel.as_ref().blocking_send(get_event_stream_response) { - Ok(_) => { - tracing::debug!("send successfully"); - Ok(JNI_TRUE) + Ok(_) => Ok(JNI_TRUE), + Err(e) => { + tracing::info!("send error. {:?}", e); + Ok(JNI_FALSE) } + } + }) +} + +#[no_mangle] +pub extern "system" fn Java_com_risingwave_java_binding_Binding_recvSinkWriterRequestFromChannel< + 'a, +>( + env: EnvParam<'a>, + mut channel: Pointer<'a, Receiver>, +) -> JByteArray<'a> { + execute_and_catch(env, move |env| match channel.as_mut().blocking_recv() { + Some(msg) => { + let bytes = env + .byte_array_from_slice(&Message::encode_to_vec(&msg)) + .unwrap(); + Ok(bytes) + } + None => Ok(JObject::null().into()), + }) +} + +#[no_mangle] +pub extern "system" fn Java_com_risingwave_java_binding_Binding_sendSinkWriterResponseToChannel< + 'a, +>( + env: EnvParam<'a>, + channel: Pointer<'a, Sender>, + msg: JByteArray<'a>, +) -> jboolean { + execute_and_catch(env, move |env| { + let sink_writer_stream_response: SinkWriterStreamResponse = + Message::decode(to_guarded_slice(&msg, env)?.deref())?; + + match channel.as_ref().blocking_send(sink_writer_stream_response) { + Ok(_) => Ok(JNI_TRUE), Err(e) => { - tracing::debug!("send error. {:?}", e); + tracing::info!("send error. {:?}", e); Ok(JNI_FALSE) } } diff --git a/src/jni_core/src/macros.rs b/src/jni_core/src/macros.rs index bdb5c60ec3f8..a4fe12f66919 100644 --- a/src/jni_core/src/macros.rs +++ b/src/jni_core/src/macros.rs @@ -167,6 +167,10 @@ macro_rules! for_all_plain_native_methods { static native long streamChunkIteratorFromPretty(String str); public static native boolean sendCdcSourceMsgToChannel(long channelPtr, byte[] msg); + + public static native byte[] recvSinkWriterRequestFromChannel(long channelPtr); + + public static native boolean sendSinkWriterResponseToChannel(long channelPtr, byte[] msg); } $(,$args)* }