diff --git a/parquet/src/arrow/async_reader/mod.rs b/parquet/src/arrow/async_reader/mod.rs index c408456df147..96715e1164b2 100644 --- a/parquet/src/arrow/async_reader/mod.rs +++ b/parquet/src/arrow/async_reader/mod.rs @@ -613,6 +613,9 @@ impl std::fmt::Debug for StreamState { /// An asynchronous [`Stream`](https://docs.rs/futures/latest/futures/stream/trait.Stream.html) of [`RecordBatch`] /// for a parquet file that can be constructed using [`ParquetRecordBatchStreamBuilder`]. +/// +/// `ParquetRecordBatchStream` also provides [`ParquetRecordBatchStream::next_row_group`] for fetching row groups, +/// allowing users to decode record batches separately from I/O. pub struct ParquetRecordBatchStream { metadata: Arc, @@ -654,6 +657,70 @@ impl ParquetRecordBatchStream { } } +impl ParquetRecordBatchStream +where + T: AsyncFileReader + Unpin + Send + 'static, +{ + /// Fetches the next row group from the stream. + /// + /// Users can continue to call this function to get row groups and decode them concurrently. + /// + /// ## Notes + /// + /// ParquetRecordBatchStream should be used either as a `Stream` or with `next_row_group`; they should not be used simultaneously. + /// + /// ## Returns + /// + /// - `Ok(None)` if the stream has ended. + /// - `Err(error)` if the stream has errored. All subsequent calls will return `Ok(None)`. + /// - `Ok(Some(reader))` which holds all the data for the row group. + pub async fn next_row_group(&mut self) -> Result> { + loop { + match &mut self.state { + StreamState::Decoding(_) | StreamState::Reading(_) => { + return Err(ParquetError::General( + "Cannot combine the use of next_row_group with the Stream API".to_string(), + )) + } + StreamState::Init => { + let row_group_idx = match self.row_groups.pop_front() { + Some(idx) => idx, + None => return Ok(None), + }; + + let row_count = self.metadata.row_group(row_group_idx).num_rows() as usize; + + let selection = self.selection.as_mut().map(|s| s.split_off(row_count)); + + let reader_factory = self.reader.take().expect("lost reader"); + + let (reader_factory, maybe_reader) = reader_factory + .read_row_group( + row_group_idx, + selection, + self.projection.clone(), + self.batch_size, + ) + .await + .map_err(|err| { + self.state = StreamState::Error; + err + })?; + self.reader = Some(reader_factory); + + if let Some(reader) = maybe_reader { + return Ok(Some(reader)); + } else { + // All rows skipped, read next row group + continue; + } + } + StreamState::Error => return Ok(None), // Ends the stream as error happens. + } + } + } +} + impl Stream for ParquetRecordBatchStream where T: AsyncFileReader + Unpin + Send + 'static, @@ -1020,6 +1087,71 @@ mod tests { ); } + #[tokio::test] + async fn test_async_reader_with_next_row_group() { + let testdata = arrow::util::test_util::parquet_test_data(); + let path = format!("{testdata}/alltypes_plain.parquet"); + let data = Bytes::from(std::fs::read(path).unwrap()); + + let metadata = ParquetMetaDataReader::new() + .parse_and_finish(&data) + .unwrap(); + let metadata = Arc::new(metadata); + + assert_eq!(metadata.num_row_groups(), 1); + + let async_reader = TestReader { + data: data.clone(), + metadata: metadata.clone(), + requests: Default::default(), + }; + + let requests = async_reader.requests.clone(); + let builder = ParquetRecordBatchStreamBuilder::new(async_reader) + .await + .unwrap(); + + let mask = ProjectionMask::leaves(builder.parquet_schema(), vec![1, 2]); + let mut stream = builder + .with_projection(mask.clone()) + .with_batch_size(1024) + .build() + .unwrap(); + + let mut readers = vec![]; + while let Some(reader) = stream.next_row_group().await.unwrap() { + readers.push(reader); + } + + let async_batches: Vec<_> = readers + .into_iter() + .flat_map(|r| r.map(|v| v.unwrap()).collect::>()) + .collect(); + + let sync_batches = ParquetRecordBatchReaderBuilder::try_new(data) + .unwrap() + .with_projection(mask) + .with_batch_size(104) + .build() + .unwrap() + .collect::>>() + .unwrap(); + + assert_eq!(async_batches, sync_batches); + + let requests = requests.lock().unwrap(); + let (offset_1, length_1) = metadata.row_group(0).column(1).byte_range(); + let (offset_2, length_2) = metadata.row_group(0).column(2).byte_range(); + + assert_eq!( + &requests[..], + &[ + offset_1 as usize..(offset_1 + length_1) as usize, + offset_2 as usize..(offset_2 + length_2) as usize + ] + ); + } + #[tokio::test] async fn test_async_reader_with_index() { let testdata = arrow::util::test_util::parquet_test_data();