Skip to content

Commit

Permalink
refactor(connector): replace sink coordinator rpc with jni (#12724)
Browse files Browse the repository at this point in the history
  • Loading branch information
chenzl25 authored Oct 11, 2023
1 parent e0059df commit ac2a58c
Show file tree
Hide file tree
Showing 15 changed files with 333 additions and 77 deletions.
16 changes: 16 additions & 0 deletions java/com_risingwave_java_binding_Binding.h

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
// Copyright 2023 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package com.risingwave.connector;

import com.risingwave.java.binding.Binding;
import com.risingwave.proto.ConnectorServiceProto;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class JniSinkCoordinatorHandler {
private static final Logger LOG = LoggerFactory.getLogger(JniSinkCoordinatorHandler.class);

public static void runJniSinkCoordinatorThread(long requestRxPtr, long responseTxPtr) {
// For jni.rs
java.lang.Thread.currentThread()
.setContextClassLoader(java.lang.ClassLoader.getSystemClassLoader());
JniSinkCoordinatorResponseObserver responseObserver =
new JniSinkCoordinatorResponseObserver(responseTxPtr);
SinkCoordinatorStreamObserver sinkCoordinatorStreamObserver =
new SinkCoordinatorStreamObserver(responseObserver);
try {
byte[] requestBytes;
while ((requestBytes = Binding.recvSinkCoordinatorRequestFromChannel(requestRxPtr))
!= null) {
var request =
ConnectorServiceProto.SinkCoordinatorStreamRequest.parseFrom(requestBytes);
sinkCoordinatorStreamObserver.onNext(request);
if (!responseObserver.isSuccess()) {
throw new RuntimeException("fail to sendSinkCoordinatorResponseToChannel");
}
}
sinkCoordinatorStreamObserver.onCompleted();
} catch (Throwable t) {
sinkCoordinatorStreamObserver.onError(t);
}
LOG.info("end of runJniSinkCoordinatorThread");
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
// Copyright 2023 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package com.risingwave.connector;

import com.risingwave.java.binding.Binding;
import com.risingwave.proto.ConnectorServiceProto;
import io.grpc.stub.StreamObserver;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class JniSinkCoordinatorResponseObserver
implements StreamObserver<ConnectorServiceProto.SinkCoordinatorStreamResponse> {
private static final Logger LOG =
LoggerFactory.getLogger(JniSinkCoordinatorResponseObserver.class);
private long responseTxPtr;

private boolean success;

public JniSinkCoordinatorResponseObserver(long responseTxPtr) {
this.responseTxPtr = responseTxPtr;
}

@Override
public void onNext(ConnectorServiceProto.SinkCoordinatorStreamResponse response) {
this.success =
Binding.sendSinkCoordinatorResponseToChannel(
this.responseTxPtr, response.toByteArray());
}

@Override
public void onError(Throwable throwable) {
LOG.error("JniSinkCoordinatorHandler onError: ", throwable);
}

@Override
public void onCompleted() {
LOG.info("JniSinkCoordinatorHandler onCompleted");
}

public boolean isSuccess() {
return success;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@ public static void runJniSinkWriterThread(long requestRxPtr, long responseTxPtr)
// For jni.rs
java.lang.Thread.currentThread()
.setContextClassLoader(java.lang.ClassLoader.getSystemClassLoader());
JniSinkResponseObserver responseObserver = new JniSinkResponseObserver(responseTxPtr);
JniSinkWriterResponseObserver responseObserver =
new JniSinkWriterResponseObserver(responseTxPtr);
SinkWriterStreamObserver sinkWriterStreamObserver =
new SinkWriterStreamObserver(responseObserver);
try {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,14 @@
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class JniSinkResponseObserver
public class JniSinkWriterResponseObserver
implements StreamObserver<ConnectorServiceProto.SinkWriterStreamResponse> {
private static final Logger LOG = LoggerFactory.getLogger(JniSinkResponseObserver.class);
private static final Logger LOG = LoggerFactory.getLogger(JniSinkWriterResponseObserver.class);
private long responseTxPtr;

private boolean success;

public JniSinkResponseObserver(long responseTxPtr) {
public JniSinkWriterResponseObserver(long responseTxPtr) {
this.responseTxPtr = responseTxPtr;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,4 +82,8 @@ public class Binding {
public static native byte[] recvSinkWriterRequestFromChannel(long channelPtr);

public static native boolean sendSinkWriterResponseToChannel(long channelPtr, byte[] msg);

public static native byte[] recvSinkCoordinatorRequestFromChannel(long channelPtr);

public static native boolean sendSinkCoordinatorResponseToChannel(long channelPtr, byte[] msg);
}
6 changes: 1 addition & 5 deletions src/connector/src/sink/iceberg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@ use risingwave_common::error::anyhow_error;
use risingwave_pb::connector_service::sink_metadata::Metadata::Serialized;
use risingwave_pb::connector_service::sink_metadata::SerializedMetadata;
use risingwave_pb::connector_service::SinkMetadata;
use risingwave_rpc_client::ConnectorClient;
use serde::de;
use serde_derive::Deserialize;
use url::Url;
Expand Down Expand Up @@ -363,10 +362,7 @@ impl Sink for IcebergSink {
.into_log_sinker(writer_param.sink_metrics))
}

async fn new_coordinator(
&self,
_connector_client: Option<ConnectorClient>,
) -> Result<Self::Coordinator> {
async fn new_coordinator(&self) -> Result<Self::Coordinator> {
let table = self.create_table().await?;
let partition_type = table.current_partition_type()?;

Expand Down
7 changes: 2 additions & 5 deletions src/connector/src/sink/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ use risingwave_common::metrics::{
use risingwave_pb::catalog::PbSinkType;
use risingwave_pb::connector_service::{PbSinkParam, SinkMetadata, TableSchema};
use risingwave_rpc_client::error::RpcError;
use risingwave_rpc_client::{ConnectorClient, MetaClient};
use risingwave_rpc_client::MetaClient;
use thiserror::Error;
pub use tracing;

Expand Down Expand Up @@ -272,10 +272,7 @@ pub trait Sink: TryFrom<SinkParam, Error = SinkError> {
async fn validate(&self) -> Result<()>;
async fn new_log_sinker(&self, writer_param: SinkWriterParam) -> Result<Self::LogSinker>;
#[expect(clippy::unused_async)]
async fn new_coordinator(
&self,
_connector_client: Option<ConnectorClient>,
) -> Result<Self::Coordinator> {
async fn new_coordinator(&self) -> Result<Self::Coordinator> {
Err(SinkError::Coordinator(anyhow!("no coordinator")))
}
}
Expand Down
149 changes: 130 additions & 19 deletions src/connector/src/sink/remote.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@ use risingwave_common::array::StreamChunk;
use risingwave_common::error::anyhow_error;
use risingwave_common::types::DataType;
use risingwave_jni_core::jvm_runtime::JVM;
use risingwave_pb::connector_service::sink_coordinator_stream_request::{
CommitMetadata, StartCoordinator,
};
use risingwave_pb::connector_service::sink_writer_stream_request::write_batch::json_payload::RowOp;
use risingwave_pb::connector_service::sink_writer_stream_request::write_batch::{
JsonPayload, Payload, StreamChunkPayload,
Expand All @@ -34,10 +37,10 @@ use risingwave_pb::connector_service::sink_writer_stream_request::{
};
use risingwave_pb::connector_service::sink_writer_stream_response::CommitResponse;
use risingwave_pb::connector_service::{
sink_writer_stream_response, SinkMetadata, SinkPayloadFormat, SinkWriterStreamRequest,
SinkWriterStreamResponse, ValidateSinkRequest, ValidateSinkResponse,
sink_coordinator_stream_request, sink_coordinator_stream_response, sink_writer_stream_response,
SinkCoordinatorStreamRequest, SinkCoordinatorStreamResponse, SinkMetadata, SinkPayloadFormat,
SinkWriterStreamRequest, SinkWriterStreamResponse, ValidateSinkRequest, ValidateSinkResponse,
};
use risingwave_rpc_client::{ConnectorClient, SinkCoordinatorStreamHandle};
use tokio::sync::mpsc;
use tokio::sync::mpsc::{Receiver, Sender};
use tracing::warn;
Expand All @@ -46,7 +49,6 @@ use super::encoder::{JsonEncoder, RowEncoder};
use crate::sink::coordinate::CoordinatedSinkWriter;
use crate::sink::encoder::TimestampHandlingMode;
use crate::sink::writer::{LogSinkerOf, SinkWriter, SinkWriterExt};
use crate::sink::SinkError::Remote;
use crate::sink::{
DummySinkCommitCoordinator, Result, Sink, SinkCommitCoordinator, SinkError, SinkParam,
SinkWriterParam,
Expand Down Expand Up @@ -229,16 +231,52 @@ impl<R: RemoteSinkTrait> Sink for CoordinatedRemoteSink<R> {
.into_log_sinker(writer_param.sink_metrics))
}

async fn new_coordinator(
&self,
connector_client: Option<ConnectorClient>,
) -> Result<Self::Coordinator> {
RemoteCoordinator::new(
connector_client
.ok_or_else(|| Remote(anyhow_error!("no connector client specified")))?,
self.0.param.clone(),
)
.await
async fn new_coordinator(&self) -> Result<Self::Coordinator> {
RemoteCoordinator::new(self.0.param.clone()).await
}
}

#[derive(Debug)]
pub struct SinkCoordinatorStreamJniHandle {
request_tx: Sender<SinkCoordinatorStreamRequest>,
response_rx: Receiver<SinkCoordinatorStreamResponse>,
}

impl SinkCoordinatorStreamJniHandle {
pub async fn commit(&mut self, epoch: u64, metadata: Vec<SinkMetadata>) -> Result<()> {
self.request_tx
.send(SinkCoordinatorStreamRequest {
request: Some(sink_coordinator_stream_request::Request::Commit(
CommitMetadata { epoch, metadata },
)),
})
.await
.map_err(|err| SinkError::Internal(err.into()))?;

match self.response_rx.recv().await {
Some(SinkCoordinatorStreamResponse {
response:
Some(sink_coordinator_stream_response::Response::Commit(
sink_coordinator_stream_response::CommitResponse {
epoch: response_epoch,
},
)),
}) => {
if epoch == response_epoch {
Ok(())
} else {
Err(SinkError::Internal(anyhow!(
"get different response epoch to commit epoch: {} {}",
epoch,
response_epoch
)))
}
}
msg => Err(SinkError::Internal(anyhow!(
"should get Commit response but get {:?}",
msg
))),
}
}
}

Expand Down Expand Up @@ -555,15 +593,88 @@ where
}

pub struct RemoteCoordinator<R: RemoteSinkTrait> {
stream_handle: SinkCoordinatorStreamHandle,
stream_handle: SinkCoordinatorStreamJniHandle,
_phantom: PhantomData<R>,
}

impl<R: RemoteSinkTrait> RemoteCoordinator<R> {
pub async fn new(client: ConnectorClient, param: SinkParam) -> Result<Self> {
let stream_handle = client
.start_sink_coordinator_stream(param.to_proto())
.await?;
pub async fn new(param: SinkParam) -> Result<Self> {
let (request_tx, request_rx) = mpsc::channel(DEFAULT_CHANNEL_SIZE);
let (response_tx, response_rx) = mpsc::channel(DEFAULT_CHANNEL_SIZE);

let mut stream_handle = SinkCoordinatorStreamJniHandle {
request_tx,
response_rx,
};

std::thread::spawn(move || {
let mut env = JVM.as_ref().unwrap().attach_current_thread().unwrap();

let result = env.call_static_method(
"com/risingwave/connector/JniSinkCoordinatorHandler",
"runJniSinkCoordinatorThread",
"(JJ)V",
&[
JValue::from(
&request_rx as *const Receiver<SinkCoordinatorStreamRequest> as i64,
),
JValue::from(
&response_tx as *const Sender<SinkCoordinatorStreamResponse> as i64,
),
],
);

match result {
Ok(_) => {
tracing::info!("end of jni call runJniSinkCoordinatorThread");
}
Err(e) => {
tracing::error!("jni call error: {:?}", e);
}
};
});

let sink_coordinator_stream_request = SinkCoordinatorStreamRequest {
request: Some(sink_coordinator_stream_request::Request::Start(
StartCoordinator {
param: Some(param.to_proto()),
},
)),
};

// First request
stream_handle
.request_tx
.send(sink_coordinator_stream_request)
.await
.map_err(|err| {
SinkError::Internal(anyhow!(
"fail to send start request for connector `{}`: {:?}",
R::SINK_NAME,
err
))
})?;

// First response
match stream_handle.response_rx.recv().await {
Some(SinkCoordinatorStreamResponse {
response: Some(sink_coordinator_stream_response::Response::Start(_)),
}) => {}
msg => {
return Err(SinkError::Internal(anyhow!(
"should get start response for connector `{}` but get {:?}",
R::SINK_NAME,
msg
)));
}
};

tracing::trace!(
"{:?} RemoteCoordinator started with properties: {:?}",
R::SINK_NAME,
&param.properties
);

Ok(RemoteCoordinator {
stream_handle,
_phantom: PhantomData,
Expand Down
Loading

0 comments on commit ac2a58c

Please sign in to comment.