Skip to content

Commit

Permalink
GH-44363: [C#] Handle Flight data with zero batches (#45315)
Browse files Browse the repository at this point in the history
### Rationale for this change

See #44363. This improves compatibility with other Flight implementations and means user code works with empty data without needing to treat it as a special case to work around this limitation.

### What changes are included in this PR?

* Adds new async overloads of `FlightClient.StartPut` that immediately send the schema, before any data batches are sent.
* Updates the test server to send the schema on `DoGet` even when there are no data batches.
* Enables the `primitive_no_batches` test case for C# Flight.

### Are these changes tested?

Yes, using a new unit test and with the integration tests.

### Are there any user-facing changes?

Yes. New overloads of the `FlightClient.StartPut` method have been added that are async and accept a `Schema` parameter, and ensure the schema is sent when no data batches are sent.

* GitHub Issue: #44363

Authored-by: Adam Reeve <[email protected]>
Signed-off-by: Curt Hagenlocher <[email protected]>
  • Loading branch information
adamreeve authored Jan 21, 2025
1 parent ba79a48 commit ead8d6f
Show file tree
Hide file tree
Showing 7 changed files with 109 additions and 20 deletions.
55 changes: 55 additions & 0 deletions csharp/src/Apache.Arrow.Flight/Client/FlightClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -98,11 +98,39 @@ public AsyncUnaryCall<FlightInfo> GetInfo(FlightDescriptor flightDescriptor, Met
flightInfoResult.Dispose);
}

/// <summary>
/// Start a Flight Put request.
/// </summary>
/// <param name="flightDescriptor">Descriptor for the data to be put</param>
/// <param name="headers">gRPC headers to send with the request</param>
/// <returns>A <see cref="FlightRecordBatchDuplexStreamingCall" /> object used to write data batches and receive responses</returns>
public FlightRecordBatchDuplexStreamingCall StartPut(FlightDescriptor flightDescriptor, Metadata headers = null)
{
return StartPut(flightDescriptor, headers, null, CancellationToken.None);
}

/// <summary>
/// Start a Flight Put request.
/// </summary>
/// <param name="flightDescriptor">Descriptor for the data to be put</param>
/// <param name="schema">The schema of the data</param>
/// <param name="headers">gRPC headers to send with the request</param>
/// <returns>A <see cref="FlightRecordBatchDuplexStreamingCall" /> object used to write data batches and receive responses</returns>
/// <remarks>Using this method rather than a StartPut overload that doesn't accept a schema
/// means that the schema is sent even if no data batches are sent</remarks>
public Task<FlightRecordBatchDuplexStreamingCall> StartPut(FlightDescriptor flightDescriptor, Schema schema, Metadata headers = null)
{
return StartPut(flightDescriptor, schema, headers, null, CancellationToken.None);
}

/// <summary>
/// Start a Flight Put request.
/// </summary>
/// <param name="flightDescriptor">Descriptor for the data to be put</param>
/// <param name="headers">gRPC headers to send with the request</param>
/// <param name="deadline">Optional deadline. The request will be cancelled if this deadline is reached.</param>
/// <param name="cancellationToken">Optional token for cancelling the request</param>
/// <returns>A <see cref="FlightRecordBatchDuplexStreamingCall" /> object used to write data batches and receive responses</returns>
public FlightRecordBatchDuplexStreamingCall StartPut(FlightDescriptor flightDescriptor, Metadata headers, System.DateTime? deadline, CancellationToken cancellationToken = default)
{
var channels = _client.DoPut(headers, deadline, cancellationToken);
Expand All @@ -117,6 +145,33 @@ public FlightRecordBatchDuplexStreamingCall StartPut(FlightDescriptor flightDesc
channels.Dispose);
}

/// <summary>
/// Start a Flight Put request.
/// </summary>
/// <param name="flightDescriptor">Descriptor for the data to be put</param>
/// <param name="schema">The schema of the data</param>
/// <param name="headers">gRPC headers to send with the request</param>
/// <param name="deadline">Optional deadline. The request will be cancelled if this deadline is reached.</param>
/// <param name="cancellationToken">Optional token for cancelling the request</param>
/// <returns>A <see cref="FlightRecordBatchDuplexStreamingCall" /> object used to write data batches and receive responses</returns>
/// <remarks>Using this method rather than a StartPut overload that doesn't accept a schema
/// means that the schema is sent even if no data batches are sent</remarks>
public async Task<FlightRecordBatchDuplexStreamingCall> StartPut(FlightDescriptor flightDescriptor, Schema schema, Metadata headers, System.DateTime? deadline, CancellationToken cancellationToken = default)
{
var channels = _client.DoPut(headers, deadline, cancellationToken);
var requestStream = new FlightClientRecordBatchStreamWriter(channels.RequestStream, flightDescriptor);
var readStream = new StreamReader<Protocol.PutResult, FlightPutResult>(channels.ResponseStream, putResult => new FlightPutResult(putResult));
var streamingCall = new FlightRecordBatchDuplexStreamingCall(
requestStream,
readStream,
channels.ResponseHeadersAsync,
channels.GetStatus,
channels.GetTrailers,
channels.Dispose);
await streamingCall.RequestStream.SetupStream(schema).ConfigureAwait(false);
return streamingCall;
}

public AsyncDuplexStreamingCall<FlightHandshakeRequest, FlightHandshakeResponse> Handshake(Metadata headers = null)
{
return Handshake(headers, null, CancellationToken.None);
Expand Down
21 changes: 17 additions & 4 deletions csharp/src/Apache.Arrow.Flight/FlightRecordBatchStreamWriter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,22 @@ private protected FlightRecordBatchStreamWriter(IAsyncStreamWriter<Protocol.Flig
_flightDescriptor = flightDescriptor;
}

private void SetupStream(Schema schema)
/// <summary>
/// Configure the data stream to write to.
/// </summary>
/// <remarks>
/// The stream will be set up automatically when writing a RecordBatch if required,
/// but calling this method before writing any data allows handling empty streams.
/// </remarks>
/// <param name="schema">The schema of data to be written to this stream</param>
public async Task SetupStream(Schema schema)
{
if (_flightDataStream != null)
{
throw new InvalidOperationException("Flight data stream is already set");
}
_flightDataStream = new FlightDataStream(_clientStreamWriter, _flightDescriptor, schema);
await _flightDataStream.SendSchema().ConfigureAwait(false);
}

public WriteOptions WriteOptions { get => throw new NotImplementedException(); set => throw new NotImplementedException(); }
Expand All @@ -50,14 +63,14 @@ public Task WriteAsync(RecordBatch message)
return WriteAsync(message, default);
}

public Task WriteAsync(RecordBatch message, ByteString applicationMetadata)
public async Task WriteAsync(RecordBatch message, ByteString applicationMetadata)
{
if (_flightDataStream == null)
{
SetupStream(message.Schema);
await SetupStream(message.Schema).ConfigureAwait(false);
}

return _flightDataStream.Write(message, applicationMetadata);
await _flightDataStream.Write(message, applicationMetadata);
}

protected virtual void Dispose(bool disposing)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ public FlightDataStream(IAsyncStreamWriter<Protocol.FlightData> clientStreamWrit
_flightDescriptor = flightDescriptor;
}

private async Task SendSchema()
public async Task SendSchema()
{
_currentFlightData = new Protocol.FlightData();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ public async Task RunClient(int serverPort)
var batches = jsonFile.Batches.Select(batch => batch.ToArrow(schema, dictionaries)).ToArray();

// 1. Put the data to the server.
await UploadBatches(client, descriptor, batches).ConfigureAwait(false);
await UploadBatches(client, descriptor, schema, batches).ConfigureAwait(false);

// 2. Get the ticket for the data.
var info = await client.GetInfo(descriptor).ConfigureAwait(false);
Expand Down Expand Up @@ -112,9 +112,10 @@ public async Task RunClient(int serverPort)
}
}

private static async Task UploadBatches(FlightClient client, FlightDescriptor descriptor, RecordBatch[] batches)
private static async Task UploadBatches(
FlightClient client, FlightDescriptor descriptor, Schema schema, RecordBatch[] batches)
{
using var putCall = client.StartPut(descriptor);
using var putCall = await client.StartPut(descriptor, schema);
using var writer = putCall.RequestStream;

try
Expand Down
3 changes: 2 additions & 1 deletion csharp/test/Apache.Arrow.Flight.TestWeb/TestFlightServer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,10 @@ public override async Task DoGet(FlightTicket ticket, FlightServerRecordBatchStr

if(_flightStore.Flights.TryGetValue(flightDescriptor, out var flightHolder))
{
await responseStream.SetupStream(flightHolder.GetFlightInfo().Schema);

var batches = flightHolder.GetRecordBatches();


foreach(var batch in batches)
{
await responseStream.WriteAsync(batch.RecordBatch, batch.Metadata);
Expand Down
36 changes: 29 additions & 7 deletions csharp/test/Apache.Arrow.Flight.Tests/FlightTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,13 @@ private RecordBatch CreateTestBatch(int startValue, int length)
return batchBuilder.Build();
}

private Schema GetStoreSchema(FlightDescriptor flightDescriptor)
{
Assert.Contains(flightDescriptor, (IReadOnlyDictionary<FlightDescriptor, FlightHolder>)_flightStore.Flights);

var flightHolder = _flightStore.Flights[flightDescriptor];
return flightHolder.GetFlightInfo().Schema;
}

private IEnumerable<RecordBatchWithMetadata> GetStoreBatch(FlightDescriptor flightDescriptor)
{
Expand Down Expand Up @@ -88,7 +95,7 @@ public async Task TestPutSingleRecordBatch()
var flightDescriptor = FlightDescriptor.CreatePathDescriptor("test");
var expectedBatch = CreateTestBatch(0, 100);

var putStream = _flightClient.StartPut(flightDescriptor);
var putStream = await _flightClient.StartPut(flightDescriptor, expectedBatch.Schema);
await putStream.RequestStream.WriteAsync(expectedBatch);
await putStream.RequestStream.CompleteAsync();
var putResults = await putStream.ResponseStream.ToListAsync();
Expand All @@ -108,7 +115,7 @@ public async Task TestPutTwoRecordBatches()
var expectedBatch1 = CreateTestBatch(0, 100);
var expectedBatch2 = CreateTestBatch(0, 100);

var putStream = _flightClient.StartPut(flightDescriptor);
var putStream = await _flightClient.StartPut(flightDescriptor, expectedBatch1.Schema);
await putStream.RequestStream.WriteAsync(expectedBatch1);
await putStream.RequestStream.WriteAsync(expectedBatch2);
await putStream.RequestStream.CompleteAsync();
Expand All @@ -123,6 +130,23 @@ public async Task TestPutTwoRecordBatches()
ArrowReaderVerifier.CompareBatches(expectedBatch2, actualBatches[1].RecordBatch);
}

[Fact]
public async Task TestPutZeroRecordBatches()
{
var flightDescriptor = FlightDescriptor.CreatePathDescriptor("test");
var schema = CreateTestBatch(0, 1).Schema;

var putStream = await _flightClient.StartPut(flightDescriptor, schema);
await putStream.RequestStream.CompleteAsync();
var putResults = await putStream.ResponseStream.ToListAsync();

Assert.Empty(putResults);

var actualSchema = GetStoreSchema(flightDescriptor);

SchemaComparer.Compare(schema, actualSchema);
}

[Fact]
public async Task TestGetRecordBatchWithDelayedSchema()
{
Expand Down Expand Up @@ -230,7 +254,7 @@ public async Task TestPutWithMetadata()
var expectedBatch = CreateTestBatch(0, 100);
var expectedMetadata = ByteString.CopyFromUtf8("test metadata");

var putStream = _flightClient.StartPut(flightDescriptor);
var putStream = await _flightClient.StartPut(flightDescriptor, expectedBatch.Schema);
await putStream.RequestStream.WriteAsync(expectedBatch, expectedMetadata);
await putStream.RequestStream.CompleteAsync();
var putResults = await putStream.ResponseStream.ToListAsync();
Expand Down Expand Up @@ -471,8 +495,7 @@ public async Task EnsureCallRaisesDeadlineExceeded()
exception = await Assert.ThrowsAsync<RpcException>(async () => await duplexStreamingCall.RequestStream.WriteAsync(batch));
Assert.Equal(StatusCode.DeadlineExceeded, exception.StatusCode);

var putStream = _flightClient.StartPut(flightDescriptor, null, deadline);
exception = await Assert.ThrowsAsync<RpcException>(async () => await putStream.RequestStream.WriteAsync(batch));
exception = await Assert.ThrowsAsync<RpcException>(async () => await _flightClient.StartPut(flightDescriptor, batch.Schema, null, deadline));
Assert.Equal(StatusCode.DeadlineExceeded, exception.StatusCode);

exception = await Assert.ThrowsAsync<RpcException>(async () => await _flightClient.GetSchema(flightDescriptor, null, deadline));
Expand Down Expand Up @@ -514,8 +537,7 @@ public async Task EnsureCallRaisesRequestCancelled()
exception = await Assert.ThrowsAsync<RpcException>(async () => await duplexStreamingCall.RequestStream.WriteAsync(batch));
Assert.Equal(StatusCode.Cancelled, exception.StatusCode);

var putStream = _flightClient.StartPut(flightDescriptor, null, null, cts.Token);
exception = await Assert.ThrowsAsync<RpcException>(async () => await putStream.RequestStream.WriteAsync(batch));
exception = await Assert.ThrowsAsync<RpcException>(async () => await _flightClient.StartPut(flightDescriptor, batch.Schema, null, null, cts.Token));
Assert.Equal(StatusCode.Cancelled, exception.StatusCode);

exception = await Assert.ThrowsAsync<RpcException>(async () => await _flightClient.GetSchema(flightDescriptor, null, null, cts.Token));
Expand Down
5 changes: 1 addition & 4 deletions dev/archery/archery/integration/datagen.py
Original file line number Diff line number Diff line change
Expand Up @@ -1890,10 +1890,7 @@ def _temp_path():
return

file_objs = [
generate_primitive_case([], name='primitive_no_batches')
# TODO(https://github.com/apache/arrow/issues/44363)
.skip_format(SKIP_FLIGHT, 'C#'),

generate_primitive_case([], name='primitive_no_batches'),
generate_primitive_case([17, 20], name='primitive'),
generate_primitive_case([0, 0, 0], name='primitive_zerolength'),

Expand Down

0 comments on commit ead8d6f

Please sign in to comment.