Skip to content

Commit

Permalink
refactor JniSinkWriterHandler
Browse files Browse the repository at this point in the history
  • Loading branch information
chenzl25 committed Sep 25, 2023
1 parent 67ad5a4 commit c616498
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 199 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -14,224 +14,67 @@

package com.risingwave.connector;

import static com.risingwave.connector.SinkUtils.getConnectorName;
import static io.grpc.Status.*;
import static io.grpc.Status.INVALID_ARGUMENT;

import com.risingwave.connector.api.TableSchema;
import com.risingwave.connector.api.sink.*;
import com.risingwave.connector.deserializer.StreamChunkDeserializer;
import com.risingwave.java.binding.Binding;
import com.risingwave.metrics.ConnectorNodeMetrics;
import com.risingwave.metrics.MonitoredRowIterator;
import com.risingwave.proto.ConnectorServiceProto;
import java.util.Optional;
import io.grpc.stub.StreamObserver;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class JniSinkWriterHandler {
public class JniSinkWriterHandler
implements StreamObserver<ConnectorServiceProto.SinkWriterStreamResponse> {
private static final Logger LOG = LoggerFactory.getLogger(JniSinkWriterHandler.class);
private SinkWriter sink;

private String connectorName;

private long sinkId;

private TableSchema tableSchema;

private boolean epochStarted;
private long currentEpoch;
private Long currentBatchId;

private Deserializer deserializer;

private long requestRxPtr;

private long responseTxPtr;

public boolean isInitialized() {
return sink != null;
}
private boolean success;

public JniSinkWriterHandler(long requestRxPtr, long responseTxPtr) {
this.requestRxPtr = requestRxPtr;
this.responseTxPtr = responseTxPtr;
}

public static void runJniSinkWriterThread(long requestRxPtr, long responseTxPtr)
throws com.google.protobuf.InvalidProtocolBufferException {
public static void runJniSinkWriterThread(long requestRxPtr, long responseTxPtr) {
// For jni.rs
java.lang.Thread.currentThread()
.setContextClassLoader(java.lang.ClassLoader.getSystemClassLoader());

JniSinkWriterHandler handler = new JniSinkWriterHandler(requestRxPtr, responseTxPtr);

byte[] requestBytes;
while ((requestBytes = Binding.recvSinkWriterRequestFromChannel(handler.requestRxPtr))
!= null) {
var request = ConnectorServiceProto.SinkWriterStreamRequest.parseFrom(requestBytes);
if (!handler.onNext(request)) {
handler.cleanup();
return;
}
}
LOG.info("end of runJniSinkWriterThread");
}
SinkWriterStreamObserver observer = new SinkWriterStreamObserver(handler);

public boolean onNext(ConnectorServiceProto.SinkWriterStreamRequest sinkTask) {
try {
if (sinkTask.hasStart()) {
if (isInitialized()) {
throw ALREADY_EXISTS
.withDescription("Sink is already initialized")
.asRuntimeException();
}
LOG.debug("sinkTask received");
sinkId = sinkTask.getStart().getSinkParam().getSinkId();
bindSink(sinkTask.getStart().getSinkParam(), sinkTask.getStart().getFormat());
return Binding.sendSinkWriterResponseToChannel(
this.responseTxPtr,
ConnectorServiceProto.SinkWriterStreamResponse.newBuilder()
.setStart(
ConnectorServiceProto.SinkWriterStreamResponse.StartResponse
.newBuilder())
.build()
.toByteArray());
} else if (sinkTask.hasBeginEpoch()) {
if (!isInitialized()) {
throw FAILED_PRECONDITION
.withDescription("sink is not initialized, please call start first")
.asRuntimeException();
}
if (epochStarted && sinkTask.getBeginEpoch().getEpoch() <= currentEpoch) {
throw INVALID_ARGUMENT
.withDescription(
"invalid epoch: new epoch ID should be larger than current epoch")
.asRuntimeException();
}
epochStarted = true;
currentEpoch = sinkTask.getBeginEpoch().getEpoch();
LOG.debug("Epoch {} started", currentEpoch);
return Binding.sendSinkWriterResponseToChannel(this.responseTxPtr, null);
} else if (sinkTask.hasWriteBatch()) {
if (!isInitialized()) {
throw FAILED_PRECONDITION
.withDescription("Sink is not initialized. Invoke `CreateSink` first.")
.asRuntimeException();
}
if (!epochStarted) {
throw FAILED_PRECONDITION
.withDescription("Epoch is not started. Invoke `StartEpoch` first.")
.asRuntimeException();
byte[] requestBytes;
while ((requestBytes = Binding.recvSinkWriterRequestFromChannel(handler.requestRxPtr))
!= null) {
var request = ConnectorServiceProto.SinkWriterStreamRequest.parseFrom(requestBytes);

observer.onNext(request);
if (!handler.success) {
throw new RuntimeException("fail to sendSinkWriterResponseToChannel");
}
ConnectorServiceProto.SinkWriterStreamRequest.WriteBatch batch =
sinkTask.getWriteBatch();
if (batch.getEpoch() != currentEpoch) {
throw INVALID_ARGUMENT
.withDescription(
"invalid epoch: expected write to epoch "
+ currentEpoch
+ ", got "
+ sinkTask.getWriteBatch().getEpoch())
.asRuntimeException();
}
if (currentBatchId != null && batch.getBatchId() <= currentBatchId) {
throw INVALID_ARGUMENT
.withDescription(
"invalid batch ID: expected batch ID to be larger than "
+ currentBatchId
+ ", got "
+ batch.getBatchId())
.asRuntimeException();
}

try (CloseableIterator<SinkRow> rowIter = deserializer.deserialize(batch)) {
sink.write(
new MonitoredRowIterator(
rowIter, connectorName, String.valueOf(sinkId)));
}

currentBatchId = batch.getBatchId();
LOG.debug("Batch {} written to epoch {}", currentBatchId, batch.getEpoch());
return Binding.sendSinkWriterResponseToChannel(this.responseTxPtr, null);
} else if (sinkTask.hasBarrier()) {
if (!isInitialized()) {
throw FAILED_PRECONDITION
.withDescription("Sink is not initialized. Invoke `Start` first.")
.asRuntimeException();
}
if (!epochStarted) {
throw FAILED_PRECONDITION
.withDescription("Epoch is not started. Invoke `StartEpoch` first.")
.asRuntimeException();
}
if (sinkTask.getBarrier().getEpoch() != currentEpoch) {
throw INVALID_ARGUMENT
.withDescription(
"invalid epoch: expected sync to epoch "
+ currentEpoch
+ ", got "
+ sinkTask.getBarrier().getEpoch())
.asRuntimeException();
}
boolean isCheckpoint = sinkTask.getBarrier().getIsCheckpoint();
Optional<ConnectorServiceProto.SinkMetadata> metadata = sink.barrier(isCheckpoint);
currentEpoch = sinkTask.getBarrier().getEpoch();
LOG.debug("Epoch {} barrier {}", currentEpoch, isCheckpoint);
if (isCheckpoint) {
ConnectorServiceProto.SinkWriterStreamResponse.CommitResponse.Builder builder =
ConnectorServiceProto.SinkWriterStreamResponse.CommitResponse
.newBuilder()
.setEpoch(currentEpoch);
if (metadata.isPresent()) {
builder.setMetadata(metadata.get());
}
return Binding.sendSinkWriterResponseToChannel(
this.responseTxPtr,
ConnectorServiceProto.SinkWriterStreamResponse.newBuilder()
.setCommit(builder)
.build()
.toByteArray());
} else {
return Binding.sendSinkWriterResponseToChannel(this.responseTxPtr, null);
}
} else {
throw INVALID_ARGUMENT.withDescription("invalid sink task").asRuntimeException();
}
} catch (Exception e) {
LOG.error("sink task error: ", e);
return false;
observer.onCompleted();
} catch (Throwable t) {
observer.onError(t);
}
LOG.info("end of runJniSinkWriterThread");
}

private void cleanup() {
if (sink != null) {
sink.drop();
ConnectorNodeMetrics.decActiveSinkConnections(connectorName, "node1");
}
@Override
public void onNext(ConnectorServiceProto.SinkWriterStreamResponse response) {
this.success =
Binding.sendSinkWriterResponseToChannel(this.responseTxPtr, response.toByteArray());
}

private void bindSink(
ConnectorServiceProto.SinkParam sinkParam,
ConnectorServiceProto.SinkPayloadFormat format) {
tableSchema = TableSchema.fromProto(sinkParam.getTableSchema());
String connectorName = getConnectorName(sinkParam);
SinkFactory sinkFactory = SinkUtils.getSinkFactory(connectorName);
sink = sinkFactory.createWriter(tableSchema, sinkParam.getPropertiesMap());
switch (format) {
case FORMAT_UNSPECIFIED:
case UNRECOGNIZED:
throw INVALID_ARGUMENT
.withDescription("should specify payload format in request")
.asRuntimeException();
case JSON:
deserializer = new JsonDeserializer(tableSchema);
break;
case STREAM_CHUNK:
deserializer = new StreamChunkDeserializer(tableSchema);
break;
}
this.connectorName = connectorName.toUpperCase();
ConnectorNodeMetrics.incActiveSinkConnections(connectorName, "node1");
@Override
public void onError(Throwable throwable) {
LOG.error("JniSinkWriterHandler onError: ", throwable);
}

@Override
public void onCompleted() {
LOG.info("JniSinkWriterHandler onCompleted");
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ public void onError(Throwable throwable) {

@Override
public void onCompleted() {
LOG.debug("sink task completed");
LOG.info("sink task completed");
cleanup();
responseObserver.onCompleted();
}
Expand Down
8 changes: 2 additions & 6 deletions src/connector/src/sink/remote.rs
Original file line number Diff line number Diff line change
Expand Up @@ -337,11 +337,7 @@ impl<SM> RemoteSinkWriterInner<SM> {
};

std::thread::spawn(move || {
let mut env = JVM
.as_ref()
.unwrap()
.attach_current_thread_as_daemon()
.unwrap();
let mut env = JVM.as_ref().unwrap().attach_current_thread().unwrap();

let result = env.call_static_method(
"com/risingwave/connector/JniSinkWriterHandler",
Expand All @@ -360,7 +356,7 @@ impl<SM> RemoteSinkWriterInner<SM> {
Err(e) => {
tracing::error!("jni call error: {:?}", e);
}
}
};
});

let sink_writer_stream_request = SinkWriterStreamRequest {
Expand Down
6 changes: 1 addition & 5 deletions src/connector/src/source/cdc/source/reader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -135,11 +135,7 @@ impl<T: CdcSourceTypeTrait> CommonSplitReader for CdcSplitReader<T> {
let source_type = get_event_stream_request.source_type.to_string();

std::thread::spawn(move || {
let mut env = JVM
.as_ref()
.unwrap()
.attach_current_thread_as_daemon()
.unwrap();
let mut env = JVM.as_ref().unwrap().attach_current_thread().unwrap();

let get_event_stream_request_bytes = env
.byte_array_from_slice(&Message::encode_to_vec(&get_event_stream_request))
Expand Down

0 comments on commit c616498

Please sign in to comment.