feat(jni): pass stream chunk directly without serde (#13430)
wenym1 authored Dec 2, 2023
1 parent b149c67 commit ab011eb
Showing 12 changed files with 230 additions and 77 deletions.
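In short: instead of serializing every SinkWriterStreamRequest into protobuf bytes on the Rust side and parsing it again in Java, the JNI channel can now hand over a stream chunk by reference. Binding.recvSinkWriterRequestFromChannel returns a new JniSinkWriterStreamRequest wrapper, and the connector protocol gains a stream_chunk_ref_pointer case in the WriteBatch payload oneof. The sketch below is illustrative only (it is not part of the commit; ReceiveLoopSketch, requestRxPtr, and consumer are placeholders) and shows the shape of the resulting Java-side receive loop, mirroring the JniSinkWriterHandler change further down:

import com.risingwave.java.binding.Binding;
import com.risingwave.java.binding.JniSinkWriterStreamRequest;
import com.risingwave.proto.ConnectorServiceProto;
import java.util.function.Consumer;

public class ReceiveLoopSketch {
    static void run(
            long requestRxPtr,
            Consumer<ConnectorServiceProto.SinkWriterStreamRequest> consumer)
            throws Exception {
        while (true) {
            try (JniSinkWriterStreamRequest request =
                    Binding.recvSinkWriterRequestFromChannel(requestRxPtr)) {
                if (request == null) {
                    break; // channel closed, stop the writer thread
                }
                // For the ref-pointer case, asPbRequest() embeds only the chunk
                // pointer; the chunk data is never serialized or parsed again.
                consumer.accept(request.asPbRequest());
            }
        }
    }
}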
6 changes: 3 additions & 3 deletions e2e_test/iceberg/start_spark_connect_server.sh
@@ -1,18 +1,18 @@
 set -ex
 
 ICEBERG_VERSION=1.3.1
-SPARK_VERSION=3.4.1
+SPARK_VERSION=3.4.2
 
 PACKAGES="org.apache.iceberg:iceberg-spark-runtime-3.4_2.12:$ICEBERG_VERSION,org.apache.hadoop:hadoop-aws:3.3.2"
 PACKAGES="$PACKAGES,org.apache.spark:spark-connect_2.12:$SPARK_VERSION"
 
 SPARK_FILE="spark-${SPARK_VERSION}-bin-hadoop3.tgz"
 
 
-wget https://dlcdn.apache.org/spark/spark-3.4.1/$SPARK_FILE
+wget https://dlcdn.apache.org/spark/spark-${SPARK_VERSION}/$SPARK_FILE
 tar -xzf $SPARK_FILE --no-same-owner
 
-./spark-3.4.1-bin-hadoop3/sbin/start-connect-server.sh --packages $PACKAGES \
+./spark-${SPARK_VERSION}-bin-hadoop3/sbin/start-connect-server.sh --packages $PACKAGES \
 --master local[3] \
 --conf spark.driver.bindAddress=0.0.0.0 \
 --conf spark.sql.catalog.demo=org.apache.iceberg.spark.SparkCatalog \
@@ -15,7 +15,7 @@
 package com.risingwave.connector;
 
 import com.risingwave.java.binding.Binding;
-import com.risingwave.proto.ConnectorServiceProto;
+import com.risingwave.java.binding.JniSinkWriterStreamRequest;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
@@ -31,11 +31,14 @@ public static void runJniSinkWriterThread(long requestRxPtr, long responseTxPtr)
         SinkWriterStreamObserver sinkWriterStreamObserver =
                 new SinkWriterStreamObserver(responseObserver);
         try {
-            byte[] requestBytes;
-            while ((requestBytes = Binding.recvSinkWriterRequestFromChannel(requestRxPtr))
-                    != null) {
-                var request = ConnectorServiceProto.SinkWriterStreamRequest.parseFrom(requestBytes);
-                sinkWriterStreamObserver.onNext(request);
+            while (true) {
+                try (JniSinkWriterStreamRequest request =
+                        Binding.recvSinkWriterRequestFromChannel(requestRxPtr)) {
+                    if (request == null) {
+                        break;
+                    }
+                    sinkWriterStreamObserver.onNext(request.asPbRequest());
+                }
                 if (!responseObserver.isSuccess()) {
                     throw new RuntimeException("fail to sendSinkWriterResponseToChannel");
                 }
@@ -239,16 +239,21 @@ static ValueGetter[] buildValueGetter(TableSchema tableSchema) {
     @Override
     public CloseableIterable<SinkRow> deserialize(
             ConnectorServiceProto.SinkWriterStreamRequest.WriteBatch writeBatch) {
-        if (!writeBatch.hasStreamChunkPayload()) {
+        if (writeBatch.hasStreamChunkPayload()) {
+            StreamChunkPayload streamChunkPayload = writeBatch.getStreamChunkPayload();
+            return new StreamChunkIterable(
+                    StreamChunk.fromPayload(streamChunkPayload.getBinaryData().toByteArray()),
+                    valueGetters);
+        } else if (writeBatch.hasStreamChunkRefPointer()) {
+            return new StreamChunkIterable(
+                    StreamChunk.fromRefPointer(writeBatch.getStreamChunkRefPointer()),
+                    valueGetters);
+        } else {
             throw INVALID_ARGUMENT
                     .withDescription(
                             "expected StreamChunkPayload, got " + writeBatch.getPayloadCase())
                     .asRuntimeException();
         }
-        StreamChunkPayload streamChunkPayload = writeBatch.getStreamChunkPayload();
-        return new StreamChunkIterable(
-                StreamChunk.fromPayload(streamChunkPayload.getBinaryData().toByteArray()),
-                valueGetters);
     }
 
     static class StreamChunkRowWrapper implements SinkRow {
@@ -83,7 +83,8 @@ public class Binding {
 
     public static native boolean sendCdcSourceMsgToChannel(long channelPtr, byte[] msg);
 
-    public static native byte[] recvSinkWriterRequestFromChannel(long channelPtr);
+    public static native com.risingwave.java.binding.JniSinkWriterStreamRequest
+            recvSinkWriterRequestFromChannel(long channelPtr);
 
     public static native boolean sendSinkWriterResponseToChannel(long channelPtr, byte[] msg);
 
@@ -0,0 +1,79 @@
// Copyright 2023 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package com.risingwave.java.binding;

import com.google.protobuf.InvalidProtocolBufferException;
import com.risingwave.proto.ConnectorServiceProto;

public class JniSinkWriterStreamRequest implements AutoCloseable {
    private final ConnectorServiceProto.SinkWriterStreamRequest pbRequest;
    private final StreamChunk chunk;
    private final long epoch;
    private final long batchId;
    private final boolean isPb;

    JniSinkWriterStreamRequest(ConnectorServiceProto.SinkWriterStreamRequest pbRequest) {
        this.pbRequest = pbRequest;
        this.chunk = null;
        this.epoch = 0;
        this.batchId = 0;
        this.isPb = true;
    }

    JniSinkWriterStreamRequest(StreamChunk chunk, long epoch, long batchId) {
        this.pbRequest = null;
        this.chunk = chunk;
        this.epoch = epoch;
        this.batchId = batchId;
        this.isPb = false;
    }

    public static JniSinkWriterStreamRequest fromSerializedPayload(byte[] payload) {
        try {
            return new JniSinkWriterStreamRequest(
                    ConnectorServiceProto.SinkWriterStreamRequest.parseFrom(payload));
        } catch (InvalidProtocolBufferException e) {
            throw new RuntimeException(e);
        }
    }

    public static JniSinkWriterStreamRequest fromStreamChunkOwnedPointer(
            long pointer, long epoch, long batchId) {
        return new JniSinkWriterStreamRequest(
                StreamChunk.fromOwnedPointer(pointer), epoch, batchId);
    }

    public ConnectorServiceProto.SinkWriterStreamRequest asPbRequest() {
        if (isPb) {
            return pbRequest;
        } else {
            return ConnectorServiceProto.SinkWriterStreamRequest.newBuilder()
                    .setWriteBatch(
                            ConnectorServiceProto.SinkWriterStreamRequest.WriteBatch.newBuilder()
                                    .setEpoch(epoch)
                                    .setBatchId(batchId)
                                    .setStreamChunkRefPointer(chunk.getPointer())
                                    .build())
                    .build();
        }
    }

    @Override
    public void close() throws Exception {
        if (!isPb && chunk != null) {
            this.chunk.close();
        }
    }
}
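A brief usage sketch of the new wrapper (illustrative only; JniSinkWriterStreamRequestSketch and its arguments are placeholders, not part of the commit): a request carries either an already-serialized protobuf message or an owned native StreamChunk, and closing it releases the chunk only in the latter case.

import com.risingwave.java.binding.JniSinkWriterStreamRequest;
import com.risingwave.proto.ConnectorServiceProto;

public class JniSinkWriterStreamRequestSketch {
    static void example(byte[] payloadBytes, long chunkPointer, long epoch, long batchId)
            throws Exception {
        // Path 1: the request arrives as serialized protobuf bytes (the pre-existing path).
        try (JniSinkWriterStreamRequest fromBytes =
                JniSinkWriterStreamRequest.fromSerializedPayload(payloadBytes)) {
            ConnectorServiceProto.SinkWriterStreamRequest pb = fromBytes.asPbRequest();
            // close() has nothing to release here: there is no native chunk.
        }

        // Path 2: the request arrives as an owned pointer to a native StreamChunk
        // (the new serde-free path). close() releases the native chunk.
        try (JniSinkWriterStreamRequest fromChunk =
                JniSinkWriterStreamRequest.fromStreamChunkOwnedPointer(
                        chunkPointer, epoch, batchId)) {
            ConnectorServiceProto.SinkWriterStreamRequest pb = fromChunk.asPbRequest();
            // pb.getWriteBatch().getStreamChunkRefPointer() now carries the pointer.
        }
    }
}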
@@ -16,15 +16,25 @@
 
 public class StreamChunk implements AutoCloseable {
     private final long pointer;
+    private final boolean isOwnedChunk;
     private boolean isClosed;
 
-    StreamChunk(long pointer) {
+    StreamChunk(long pointer, boolean isOwnedChunk) {
         this.pointer = pointer;
+        this.isOwnedChunk = isOwnedChunk;
         this.isClosed = false;
     }
 
     public static StreamChunk fromPayload(byte[] streamChunkPayload) {
-        return new StreamChunk(Binding.newStreamChunkFromPayload(streamChunkPayload));
+        return new StreamChunk(Binding.newStreamChunkFromPayload(streamChunkPayload), true);
+    }
+
+    public static StreamChunk fromRefPointer(long pointer) {
+        return new StreamChunk(pointer, false);
+    }
+
+    public static StreamChunk fromOwnedPointer(long pointer) {
+        return new StreamChunk(pointer, true);
     }
 
     /**
@@ -34,13 +44,15 @@ public static StreamChunk fromPayload(byte[] streamChunkPayload) {
      * 40"
      */
     public static StreamChunk fromPretty(String str) {
-        return new StreamChunk(Binding.newStreamChunkFromPretty(str));
+        return new StreamChunk(Binding.newStreamChunkFromPretty(str), true);
     }
 
     @Override
     public void close() {
         if (!isClosed) {
-            Binding.streamChunkClose(pointer);
+            if (this.isOwnedChunk) {
+                Binding.streamChunkClose(pointer);
+            }
             this.isClosed = true;
         }
     }
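The new isOwnedChunk flag decides who releases the native chunk. A minimal sketch of the two cases (illustrative only; StreamChunkOwnershipSketch and nativePointer are placeholders): chunks built from a payload or an owned pointer free native memory on close(), while a ref-pointer chunk leaves that to its owner, such as the enclosing JniSinkWriterStreamRequest.

import com.risingwave.java.binding.StreamChunk;

public class StreamChunkOwnershipSketch {
    static void example(long nativePointer) {
        // Owned: close() calls Binding.streamChunkClose(...) and frees the native chunk.
        try (StreamChunk owned = StreamChunk.fromOwnedPointer(nativePointer)) {
            // use the chunk ...
        }

        // Borrowed: close() only marks the wrapper closed; the native chunk
        // must be released by whoever owns it.
        try (StreamChunk borrowed = StreamChunk.fromRefPointer(nativePointer)) {
            // read-only use of a chunk owned elsewhere ...
        }
    }
}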
4 changes: 4 additions & 0 deletions proto/connector_service.proto
@@ -61,6 +61,10 @@ message SinkWriterStreamRequest {
    oneof payload {
      JsonPayload json_payload = 1;
      StreamChunkPayload stream_chunk_payload = 2;
+     // This is a reference pointer to a StreamChunk. The StreamChunk is owned
+     // by the JniSinkWriterStreamRequest, which should handle the release of StreamChunk.
+     // Index set to 5 because 3 and 4 have been occupied by `batch_id` and `epoch`
+     int64 stream_chunk_ref_pointer = 5;
    }
 
    uint64 batch_id = 3;