Skip to content

Commit

Permalink
Backport improvements
Browse files Browse the repository at this point in the history
  • Loading branch information
Dandandan committed Nov 28, 2023
1 parent c388634 commit cc2cb39
Showing 1 changed file with 43 additions and 75 deletions.
118 changes: 43 additions & 75 deletions ballista/executor/src/flight_service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,13 @@ use std::fs::File;
use std::pin::Pin;

use arrow::ipc::CompressionType;
use arrow_flight::SchemaAsIpc;
use arrow_flight::encode::FlightDataEncoderBuilder;
use arrow_flight::error::FlightError;
use ballista_core::error::BallistaError;
use ballista_core::serde::decode_protobuf;
use ballista_core::serde::scheduler::Action as BallistaAction;

use arrow::ipc::writer::{DictionaryTracker, IpcDataGenerator, IpcWriteOptions};
use arrow::ipc::writer::IpcWriteOptions;
use arrow_flight::{
flight_service_server::FlightService, Action, ActionType, Criteria, Empty,
FlightData, FlightDescriptor, FlightInfo, HandshakeRequest, HandshakeResponse,
Expand All @@ -36,20 +37,16 @@ use arrow_flight::{
use datafusion::arrow::{
error::ArrowError, ipc::reader::FileReader, record_batch::RecordBatch,
};
use futures::{Stream, StreamExt};
use log::{debug, info, warn};
use futures::{Stream, StreamExt, TryStreamExt};
use log::{debug, info};
use std::io::{Read, Seek};
use tokio::sync::mpsc::channel;
use tokio::{
sync::mpsc::{Receiver, Sender},
task,
};
use tokio::sync::mpsc::error::SendError;
use tokio::{sync::mpsc::Sender, task};
use tokio_stream::wrappers::ReceiverStream;
use tonic::metadata::MetadataValue;
use tonic::{Request, Response, Status, Streaming};

type FlightDataSender = Sender<Result<FlightData, Status>>;
type FlightDataReceiver = Receiver<Result<FlightData, Status>>;
use tracing::warn;

/// Service implementing the Apache Arrow Flight Protocol
#[derive(Clone)]
Expand All @@ -68,7 +65,7 @@ impl Default for BallistaFlightService {
}

type BoxedFlightStream<T> =
Pin<Box<dyn Stream<Item = Result<T, Status>> + Send + Sync + 'static>>;
Pin<Box<dyn Stream<Item = Result<T, Status>> + Send + 'static>>;

#[tonic::async_trait]
impl FlightService for BallistaFlightService {
Expand Down Expand Up @@ -102,19 +99,25 @@ impl FlightService for BallistaFlightService {
let reader =
FileReader::try_new(file, None).map_err(|e| from_arrow_err(&e))?;

let (tx, rx): (FlightDataSender, FlightDataReceiver) = channel(2);

let file_path = path.to_owned();
// Arrow IPC reader does not implement Sync + Send so we need to use a channel
// to communicate
task::spawn(async move {
if let Err(e) = stream_flight_data(file_path, reader, tx).await {
warn!("Error streaming results: {:?}", e);
let (tx, rx) = channel(2);
let schema = reader.schema();
task::spawn_blocking(move || {
if let Err(e) = read_partition(reader, tx) {
warn!(error = %e, "error streaming shuffle partition");
}
});

let write_options: IpcWriteOptions = IpcWriteOptions::default()
.try_with_compression(Some(CompressionType::LZ4_FRAME))
.map_err(|e| from_arrow_err(&e))?;
let flight_data_stream = FlightDataEncoderBuilder::new()
.with_schema(schema)
.with_options(write_options)
.build(ReceiverStream::new(rx))
.map_err(|err| Status::from_error(Box::new(err)));

Ok(Response::new(
Box::pin(ReceiverStream::new(rx)) as Self::DoGetStream
Box::pin(flight_data_stream) as Self::DoGetStream
))
}
}
Expand Down Expand Up @@ -149,7 +152,7 @@ impl FlightService for BallistaFlightService {
let output = futures::stream::iter(vec![result]);
let str = format!("Bearer {token}");
let mut resp: Response<
Pin<Box<dyn Stream<Item = Result<_, Status>> + Sync + Send>>,
Pin<Box<dyn Stream<Item = Result<_, Status>> + Send + 'static>>,
> = Response::new(Box::pin(output));
let md = MetadataValue::try_from(str)
.map_err(|_| Status::invalid_argument("authorization not parsable"))?;
Expand Down Expand Up @@ -203,69 +206,34 @@ impl FlightService for BallistaFlightService {
}
}

/// Convert a single RecordBatch into an iterator of FlightData (containing
/// dictionaries and batches)
fn create_flight_iter(
batch: &RecordBatch,
options: &IpcWriteOptions,
) -> Box<dyn Iterator<Item = Result<FlightData, Status>>> {
let data_gen = IpcDataGenerator::default();
let mut dictionary_tracker = DictionaryTracker::new(false);
let res = data_gen.encoded_batch(batch, &mut dictionary_tracker, options);
match res {
Ok((dicts, batch)) => {
let flights = dicts
.into_iter()
.chain(std::iter::once(batch))
.map(|x| x.into());
Box::new(flights.map(Ok))
}
Err(e) => Box::new(std::iter::once(Err(from_arrow_err(&e)))),
}
}

async fn stream_flight_data<T>(
file_path: String,
fn read_partition<T>(
reader: FileReader<T>,
tx: FlightDataSender,
) -> Result<(), Status>
tx: Sender<Result<RecordBatch, FlightError>>,
) -> Result<(), FlightError>
where
T: Read + Seek,
{
let options = arrow::ipc::writer::IpcWriteOptions::default()
.try_with_compression(Some(CompressionType::LZ4_FRAME))
.map_err(|x| from_arrow_err(&x))?;
let schema_flight_data = SchemaAsIpc::new(reader.schema().as_ref(), &options).into();
send_response(&tx, Ok(schema_flight_data)).await?;
if tx.is_closed() {
return Err(FlightError::Tonic(Status::internal(
"Can't send a batch, channel is closed",
)));
}

let mut row_count = 0;
for batch in reader {
if let Ok(x) = &batch {
row_count += x.num_rows();
}
let batch_flight_data: Vec<_> = batch
.map(|b| create_flight_iter(&b, &options).collect())
.map_err(|e| from_arrow_err(&e))?;
for batch in batch_flight_data.into_iter() {
send_response(&tx, batch).await?;
}
tx.blocking_send(batch.map_err(|err| err.into()))
.map_err(|err| {
if let SendError(Err(err)) = err {
err
} else {
FlightError::Tonic(Status::internal(
"Can't send a batch, something went wrong",
))
}
})?
}
debug!(
"FetchPartition streamed {} rows for file {}",
row_count, file_path
);
Ok(())
}

async fn send_response(
tx: &FlightDataSender,
data: Result<FlightData, Status>,
) -> Result<(), Status> {
tx.send(data)
.await
.map_err(|e| Status::internal(format!("{e:?}")))
}

fn from_arrow_err(e: &ArrowError) -> Status {
Status::internal(format!("ArrowError: {e:?}"))
}
Expand Down

0 comments on commit cc2cb39

Please sign in to comment.