Skip to content

Commit

Permalink
refactor(connector): replace sink writer rpc with jni (#12480)
Browse files Browse the repository at this point in the history
  • Loading branch information
chenzl25 authored Sep 27, 2023
1 parent 0726b59 commit 85248b7
Show file tree
Hide file tree
Showing 8 changed files with 318 additions and 46 deletions.
16 changes: 16 additions & 0 deletions java/com_risingwave_java_binding_Binding.h

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
// Copyright 2023 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package com.risingwave.connector;

import com.risingwave.java.binding.Binding;
import com.risingwave.proto.ConnectorServiceProto;
import io.grpc.stub.StreamObserver;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class JniSinkResponseObserver
implements StreamObserver<ConnectorServiceProto.SinkWriterStreamResponse> {
private static final Logger LOG = LoggerFactory.getLogger(JniSinkResponseObserver.class);
private long responseTxPtr;

private boolean success;

public JniSinkResponseObserver(long responseTxPtr) {
this.responseTxPtr = responseTxPtr;
}

@Override
public void onNext(ConnectorServiceProto.SinkWriterStreamResponse response) {
this.success =
Binding.sendSinkWriterResponseToChannel(this.responseTxPtr, response.toByteArray());
}

@Override
public void onError(Throwable throwable) {
LOG.error("JniSinkWriterHandler onError: ", throwable);
}

@Override
public void onCompleted() {
LOG.info("JniSinkWriterHandler onCompleted");
}

public boolean isSuccess() {
return success;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
// Copyright 2023 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package com.risingwave.connector;

import com.risingwave.java.binding.Binding;
import com.risingwave.proto.ConnectorServiceProto;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class JniSinkWriterHandler {
private static final Logger LOG = LoggerFactory.getLogger(JniSinkWriterHandler.class);

public static void runJniSinkWriterThread(long requestRxPtr, long responseTxPtr) {
// For jni.rs
java.lang.Thread.currentThread()
.setContextClassLoader(java.lang.ClassLoader.getSystemClassLoader());
JniSinkResponseObserver responseObserver = new JniSinkResponseObserver(responseTxPtr);
SinkWriterStreamObserver sinkWriterStreamObserver =
new SinkWriterStreamObserver(responseObserver);
try {
byte[] requestBytes;
while ((requestBytes = Binding.recvSinkWriterRequestFromChannel(requestRxPtr))
!= null) {
var request = ConnectorServiceProto.SinkWriterStreamRequest.parseFrom(requestBytes);
sinkWriterStreamObserver.onNext(request);
if (!responseObserver.isSuccess()) {
throw new RuntimeException("fail to sendSinkWriterResponseToChannel");
}
}
sinkWriterStreamObserver.onCompleted();
} catch (Throwable t) {
sinkWriterStreamObserver.onError(t);
}
LOG.info("end of runJniSinkWriterThread");
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -91,4 +91,8 @@ public class Binding {
static native long streamChunkIteratorFromPretty(String str);

public static native boolean sendCdcSourceMsgToChannel(long channelPtr, byte[] msg);

public static native byte[] recvSinkWriterRequestFromChannel(long channelPtr);

public static native boolean sendSinkWriterResponseToChannel(long channelPtr, byte[] msg);
}
181 changes: 147 additions & 34 deletions src/connector/src/sink/remote.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,18 +29,18 @@ use risingwave_pb::connector_service::sink_writer_stream_request::write_batch::j
use risingwave_pb::connector_service::sink_writer_stream_request::write_batch::{
JsonPayload, Payload, StreamChunkPayload,
};
use risingwave_pb::connector_service::sink_writer_stream_request::{
Barrier, BeginEpoch, Request as SinkRequest, StartSink, WriteBatch,
};
use risingwave_pb::connector_service::sink_writer_stream_response::CommitResponse;
use risingwave_pb::connector_service::{
SinkMetadata, SinkPayloadFormat, ValidateSinkRequest, ValidateSinkResponse,
sink_writer_stream_response, SinkMetadata, SinkPayloadFormat, SinkWriterStreamRequest,
SinkWriterStreamResponse, ValidateSinkRequest, ValidateSinkResponse,
};
#[cfg(test)]
use risingwave_pb::connector_service::{SinkWriterStreamRequest, SinkWriterStreamResponse};
use risingwave_rpc_client::{ConnectorClient, SinkCoordinatorStreamHandle, SinkWriterStreamHandle};
#[cfg(test)]
use tokio::sync::mpsc::{Sender, UnboundedReceiver};
#[cfg(test)]
use tonic::Status;
use tracing::{error, warn};
use risingwave_rpc_client::{ConnectorClient, SinkCoordinatorStreamHandle};
use tokio::sync::mpsc;
use tokio::sync::mpsc::{Receiver, Sender};
use tracing::warn;

use super::encoder::{JsonEncoder, RowEncoder};
use crate::sink::coordinate::CoordinatedSinkWriter;
Expand Down Expand Up @@ -242,6 +242,71 @@ impl<R: RemoteSinkTrait> Sink for CoordinatedRemoteSink<R> {
}
}

const DEFAULT_CHANNEL_SIZE: usize = 16;
#[derive(Debug)]
pub struct SinkWriterStreamJniHandle {
request_tx: Sender<SinkWriterStreamRequest>,
response_rx: Receiver<SinkWriterStreamResponse>,
}

impl SinkWriterStreamJniHandle {
pub async fn start_epoch(&mut self, epoch: u64) -> Result<()> {
self.request_tx
.send(SinkWriterStreamRequest {
request: Some(SinkRequest::BeginEpoch(BeginEpoch { epoch })),
})
.await
.map_err(|err| SinkError::Internal(err.into()))
}

pub async fn write_batch(&mut self, epoch: u64, batch_id: u64, payload: Payload) -> Result<()> {
self.request_tx
.send(SinkWriterStreamRequest {
request: Some(SinkRequest::WriteBatch(WriteBatch {
epoch,
batch_id,
payload: Some(payload),
})),
})
.await
.map_err(|err| SinkError::Internal(err.into()))
}

pub async fn barrier(&mut self, epoch: u64) -> Result<()> {
self.request_tx
.send(SinkWriterStreamRequest {
request: Some(SinkRequest::Barrier(Barrier {
epoch,
is_checkpoint: false,
})),
})
.await
.map_err(|err| SinkError::Internal(err.into()))
}

pub async fn commit(&mut self, epoch: u64) -> Result<CommitResponse> {
self.request_tx
.send(SinkWriterStreamRequest {
request: Some(SinkRequest::Barrier(Barrier {
epoch,
is_checkpoint: true,
})),
})
.await
.map_err(|err| SinkError::Internal(err.into()))?;

match self.response_rx.recv().await {
Some(SinkWriterStreamResponse {
response: Some(sink_writer_stream_response::Response::Commit(rsp)),
}) => Ok(rsp),
msg => Err(SinkError::Internal(anyhow!(
"should get Sync response but get {:?}",
msg
))),
}
}
}

pub type RemoteSinkWriter<R> = RemoteSinkWriterInner<(), R>;
pub type CoordinatedRemoteSinkWriter<R> = RemoteSinkWriterInner<Option<SinkMetadata>, R>;

Expand All @@ -250,28 +315,78 @@ pub struct RemoteSinkWriterInner<SM, R: RemoteSinkTrait> {
epoch: Option<u64>,
batch_id: u64,
payload_format: SinkPayloadFormat,
stream_handle: SinkWriterStreamHandle,
stream_handle: SinkWriterStreamJniHandle,
json_encoder: JsonEncoder,
_phantom: PhantomData<(SM, R)>,
}

impl<SM, R: RemoteSinkTrait> RemoteSinkWriterInner<SM, R> {
pub async fn new(param: SinkParam, connector_params: ConnectorParams) -> Result<Self> {
let client = connector_params.connector_client.ok_or_else(|| {
SinkError::Remote(anyhow_error!(
"connector node endpoint not specified or unable to connect to connector node"
))
})?;
let stream_handle = client
.start_sink_writer_stream(param.to_proto(), connector_params.sink_payload_format)
let (request_tx, request_rx) = mpsc::channel(DEFAULT_CHANNEL_SIZE);
let (response_tx, response_rx) = mpsc::channel(DEFAULT_CHANNEL_SIZE);

let mut stream_handle = SinkWriterStreamJniHandle {
request_tx,
response_rx,
};

std::thread::spawn(move || {
let mut env = JVM.as_ref().unwrap().attach_current_thread().unwrap();

let result = env.call_static_method(
"com/risingwave/connector/JniSinkWriterHandler",
"runJniSinkWriterThread",
"(JJ)V",
&[
JValue::from(&request_rx as *const Receiver<SinkWriterStreamRequest> as i64),
JValue::from(&response_tx as *const Sender<SinkWriterStreamResponse> as i64),
],
);

match result {
Ok(_) => {
tracing::info!("end of jni call runJniSinkWriterThread");
}
Err(e) => {
tracing::error!("jni call error: {:?}", e);
}
};
});

let sink_writer_stream_request = SinkWriterStreamRequest {
request: Some(SinkRequest::Start(StartSink {
sink_param: Some(param.to_proto()),
format: connector_params.sink_payload_format as i32,
})),
};

// First request
stream_handle
.request_tx
.send(sink_writer_stream_request)
.await
.inspect_err(|e| {
error!(
"failed to start sink stream for connector `{}`: {:?}",
.map_err(|err| {
SinkError::Internal(anyhow!(
"fail to send start request for connector `{}`: {:?}",
R::SINK_NAME,
e
)
err
))
})?;

// First response
match stream_handle.response_rx.recv().await {
Some(SinkWriterStreamResponse {
response: Some(sink_writer_stream_response::Response::Start(_)),
}) => {}
msg => {
return Err(SinkError::Internal(anyhow!(
"should get start response for connector `{}` but get {:?}",
R::SINK_NAME,
msg
)));
}
};

tracing::trace!(
"{:?} sink stream started with properties: {:?}",
R::SINK_NAME,
Expand All @@ -293,7 +408,7 @@ impl<SM, R: RemoteSinkTrait> RemoteSinkWriterInner<SM, R> {

#[cfg(test)]
fn for_test(
response_receiver: UnboundedReceiver<std::result::Result<SinkWriterStreamResponse, Status>>,
response_receiver: Receiver<SinkWriterStreamResponse>,
request_sender: Sender<SinkWriterStreamRequest>,
) -> RemoteSinkWriter<R> {
use risingwave_common::catalog::{Field, Schema};
Expand All @@ -314,13 +429,10 @@ impl<SM, R: RemoteSinkTrait> RemoteSinkWriterInner<SM, R> {
},
]);

use futures::StreamExt;
use tokio_stream::wrappers::UnboundedReceiverStream;

let stream_handle = SinkWriterStreamHandle::for_test(
request_sender,
UnboundedReceiverStream::new(response_receiver).boxed(),
);
let stream_handle = SinkWriterStreamJniHandle {
request_tx: request_sender,
response_rx: response_receiver,
};

RemoteSinkWriter {
properties,
Expand Down Expand Up @@ -494,7 +606,7 @@ mod test {
#[tokio::test]
async fn test_epoch_check() {
let (request_sender, mut request_recv) = mpsc::channel(16);
let (_, resp_recv) = mpsc::unbounded_channel();
let (_, resp_recv) = mpsc::channel(16);

let mut sink = <RemoteSinkWriter<TestRemote>>::for_test(resp_recv, request_sender);
let chunk = StreamChunk::from_pretty(
Expand Down Expand Up @@ -532,7 +644,7 @@ mod test {
#[tokio::test]
async fn test_remote_sink() {
let (request_sender, mut request_receiver) = mpsc::channel(16);
let (response_sender, response_receiver) = mpsc::unbounded_channel();
let (response_sender, response_receiver) = mpsc::channel(16);
let mut sink = <RemoteSinkWriter<TestRemote>>::for_test(response_receiver, request_sender);

let chunk_a = StreamChunk::from_pretty(
Expand Down Expand Up @@ -588,12 +700,13 @@ mod test {

// test commit
response_sender
.send(Ok(SinkWriterStreamResponse {
.send(SinkWriterStreamResponse {
response: Some(Response::Commit(CommitResponse {
epoch: 2022,
metadata: None,
})),
}))
})
.await
.expect("test failed: failed to sync epoch");
sink.barrier(true).await.unwrap();
let commit_request = request_receiver.recv().await.unwrap();
Expand Down
Loading

0 comments on commit 85248b7

Please sign in to comment.