Skip to content

Commit

Permalink
Use lz4 compression for shuffle files & flight stream, refactoring / …
Browse files Browse the repository at this point in the history
…improvements (#920)

* Use lz4 compression for shuffle files and streams

* Add feature

* Backport improvements

* More compression
  • Loading branch information
Dandandan authored Nov 30, 2023
1 parent e474e34 commit 1345646
Show file tree
Hide file tree
Showing 4 changed files with 61 additions and 76 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ members = [
resolver = "2"

[workspace.dependencies]
arrow = { version = "48.0.0" }
arrow = { version = "48.0.0", features=["ipc_compression"] }
arrow-flight = { version = "48.0.0", features = ["flight-sql-experimental"] }
arrow-schema = { version = "48.0.0", default-features = false }
configure_me = { version = "0.4.0" }
Expand Down
9 changes: 8 additions & 1 deletion ballista/core/src/execution_plans/shuffle_writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
//! partition is re-partitioned and streamed to disk in Arrow IPC format. Future stages of the query
//! will use the ShuffleReaderExec to read these results.
use datafusion::arrow::ipc::writer::IpcWriteOptions;
use datafusion::arrow::ipc::CompressionType;
use datafusion::physical_plan::expressions::PhysicalSortExpr;

use std::any::Any;
Expand Down Expand Up @@ -242,9 +244,14 @@ impl ShuffleWriterExec {
));
debug!("Writing results to {:?}", path);

let mut writer = IPCWriter::new(
let options = IpcWriteOptions::default()
.try_with_compression(Some(
CompressionType::LZ4_FRAME,
))?;
let mut writer = IPCWriter::new_with_options(
&path,
stream.schema().as_ref(),
options,
)?;

writer.write(&output_batch)?;
Expand Down
9 changes: 8 additions & 1 deletion ballista/core/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ use crate::serde::scheduler::PartitionStats;

use async_trait::async_trait;
use datafusion::arrow::datatypes::Schema;
use datafusion::arrow::ipc::writer::IpcWriteOptions;
use datafusion::arrow::ipc::CompressionType;
use datafusion::arrow::{ipc::writer::FileWriter, record_batch::RecordBatch};
use datafusion::datasource::physical_plan::{CsvExec, ParquetExec};
use datafusion::error::DataFusionError;
Expand Down Expand Up @@ -82,7 +84,12 @@ pub async fn write_stream_to_disk(
let mut num_rows = 0;
let mut num_batches = 0;
let mut num_bytes = 0;
let mut writer = FileWriter::try_new(file, stream.schema().as_ref())?;

let options = IpcWriteOptions::default()
.try_with_compression(Some(CompressionType::LZ4_FRAME))?;

let mut writer =
FileWriter::try_new_with_options(file, stream.schema().as_ref(), options)?;

while let Some(result) = stream.next().await {
let batch = result?;
Expand Down
117 changes: 44 additions & 73 deletions ballista/executor/src/flight_service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,14 @@ use std::convert::TryFrom;
use std::fs::File;
use std::pin::Pin;

use arrow_flight::SchemaAsIpc;
use arrow::ipc::CompressionType;
use arrow_flight::encode::FlightDataEncoderBuilder;
use arrow_flight::error::FlightError;
use ballista_core::error::BallistaError;
use ballista_core::serde::decode_protobuf;
use ballista_core::serde::scheduler::Action as BallistaAction;

use arrow::ipc::writer::{DictionaryTracker, IpcDataGenerator, IpcWriteOptions};
use arrow::ipc::writer::IpcWriteOptions;
use arrow_flight::{
flight_service_server::FlightService, Action, ActionType, Criteria, Empty,
FlightData, FlightDescriptor, FlightInfo, HandshakeRequest, HandshakeResponse,
Expand All @@ -35,20 +37,16 @@ use arrow_flight::{
use datafusion::arrow::{
error::ArrowError, ipc::reader::FileReader, record_batch::RecordBatch,
};
use futures::{Stream, StreamExt};
use log::{debug, info, warn};
use futures::{Stream, StreamExt, TryStreamExt};
use log::{debug, info};
use std::io::{Read, Seek};
use tokio::sync::mpsc::channel;
use tokio::{
sync::mpsc::{Receiver, Sender},
task,
};
use tokio::sync::mpsc::error::SendError;
use tokio::{sync::mpsc::Sender, task};
use tokio_stream::wrappers::ReceiverStream;
use tonic::metadata::MetadataValue;
use tonic::{Request, Response, Status, Streaming};

type FlightDataSender = Sender<Result<FlightData, Status>>;
type FlightDataReceiver = Receiver<Result<FlightData, Status>>;
use tracing::warn;

/// Service implementing the Apache Arrow Flight Protocol
#[derive(Clone)]
Expand All @@ -67,7 +65,7 @@ impl Default for BallistaFlightService {
}

type BoxedFlightStream<T> =
Pin<Box<dyn Stream<Item = Result<T, Status>> + Send + Sync + 'static>>;
Pin<Box<dyn Stream<Item = Result<T, Status>> + Send + 'static>>;

#[tonic::async_trait]
impl FlightService for BallistaFlightService {
Expand Down Expand Up @@ -101,19 +99,25 @@ impl FlightService for BallistaFlightService {
let reader =
FileReader::try_new(file, None).map_err(|e| from_arrow_err(&e))?;

let (tx, rx): (FlightDataSender, FlightDataReceiver) = channel(2);

let file_path = path.to_owned();
// Arrow IPC reader does not implement Sync + Send so we need to use a channel
// to communicate
task::spawn(async move {
if let Err(e) = stream_flight_data(file_path, reader, tx).await {
warn!("Error streaming results: {:?}", e);
let (tx, rx) = channel(2);
let schema = reader.schema();
task::spawn_blocking(move || {
if let Err(e) = read_partition(reader, tx) {
warn!(error = %e, "error streaming shuffle partition");
}
});

let write_options: IpcWriteOptions = IpcWriteOptions::default()
.try_with_compression(Some(CompressionType::LZ4_FRAME))
.map_err(|e| from_arrow_err(&e))?;
let flight_data_stream = FlightDataEncoderBuilder::new()
.with_schema(schema)
.with_options(write_options)
.build(ReceiverStream::new(rx))
.map_err(|err| Status::from_error(Box::new(err)));

Ok(Response::new(
Box::pin(ReceiverStream::new(rx)) as Self::DoGetStream
Box::pin(flight_data_stream) as Self::DoGetStream
))
}
}
Expand Down Expand Up @@ -148,7 +152,7 @@ impl FlightService for BallistaFlightService {
let output = futures::stream::iter(vec![result]);
let str = format!("Bearer {token}");
let mut resp: Response<
Pin<Box<dyn Stream<Item = Result<_, Status>> + Sync + Send>>,
Pin<Box<dyn Stream<Item = Result<_, Status>> + Send + 'static>>,
> = Response::new(Box::pin(output));
let md = MetadataValue::try_from(str)
.map_err(|_| Status::invalid_argument("authorization not parsable"))?;
Expand Down Expand Up @@ -202,67 +206,34 @@ impl FlightService for BallistaFlightService {
}
}

/// Convert a single RecordBatch into an iterator of FlightData (containing
/// dictionaries and batches)
fn create_flight_iter(
batch: &RecordBatch,
options: &IpcWriteOptions,
) -> Box<dyn Iterator<Item = Result<FlightData, Status>>> {
let data_gen = IpcDataGenerator::default();
let mut dictionary_tracker = DictionaryTracker::new(false);
let res = data_gen.encoded_batch(batch, &mut dictionary_tracker, options);
match res {
Ok((dicts, batch)) => {
let flights = dicts
.into_iter()
.chain(std::iter::once(batch))
.map(|x| x.into());
Box::new(flights.map(Ok))
}
Err(e) => Box::new(std::iter::once(Err(from_arrow_err(&e)))),
}
}

async fn stream_flight_data<T>(
file_path: String,
fn read_partition<T>(
reader: FileReader<T>,
tx: FlightDataSender,
) -> Result<(), Status>
tx: Sender<Result<RecordBatch, FlightError>>,
) -> Result<(), FlightError>
where
T: Read + Seek,
{
let options = arrow::ipc::writer::IpcWriteOptions::default();
let schema_flight_data = SchemaAsIpc::new(reader.schema().as_ref(), &options).into();
send_response(&tx, Ok(schema_flight_data)).await?;
if tx.is_closed() {
return Err(FlightError::Tonic(Status::internal(
"Can't send a batch, channel is closed",
)));
}

let mut row_count = 0;
for batch in reader {
if let Ok(x) = &batch {
row_count += x.num_rows();
}
let batch_flight_data: Vec<_> = batch
.map(|b| create_flight_iter(&b, &options).collect())
.map_err(|e| from_arrow_err(&e))?;
for batch in batch_flight_data.into_iter() {
send_response(&tx, batch).await?;
}
tx.blocking_send(batch.map_err(|err| err.into()))
.map_err(|err| {
if let SendError(Err(err)) = err {
err
} else {
FlightError::Tonic(Status::internal(
"Can't send a batch, something went wrong",
))
}
})?
}
debug!(
"FetchPartition streamed {} rows for file {}",
row_count, file_path
);
Ok(())
}

async fn send_response(
tx: &FlightDataSender,
data: Result<FlightData, Status>,
) -> Result<(), Status> {
tx.send(data)
.await
.map_err(|e| Status::internal(format!("{e:?}")))
}

fn from_arrow_err(e: &ArrowError) -> Status {
Status::internal(format!("ArrowError: {e:?}"))
}
Expand Down

0 comments on commit 1345646

Please sign in to comment.