From cbe176533b227f31f467b9c9b512d65a8406ce10 Mon Sep 17 00:00:00 2001 From: Ed Seidl Date: Wed, 18 Dec 2024 05:57:53 -0800 Subject: [PATCH] Enable string-based column projections from Parquet files (#6871) * add function to create ProjectionMask from column names * add some more tests --- parquet/src/arrow/arrow_reader/mod.rs | 68 ++++++++++ parquet/src/arrow/mod.rs | 178 +++++++++++++++++++++++++- parquet/src/arrow/schema/mod.rs | 11 ++ 3 files changed, 256 insertions(+), 1 deletion(-) diff --git a/parquet/src/arrow/arrow_reader/mod.rs b/parquet/src/arrow/arrow_reader/mod.rs index 378884a1c430..6eba04c86f91 100644 --- a/parquet/src/arrow/arrow_reader/mod.rs +++ b/parquet/src/arrow/arrow_reader/mod.rs @@ -989,6 +989,21 @@ mod tests { assert_eq!(original_schema.fields()[1], reader.schema().fields()[0]); } + #[test] + fn test_arrow_reader_single_column_by_name() { + let file = get_test_file("parquet/generated_simple_numerics/blogs.parquet"); + + let builder = ParquetRecordBatchReaderBuilder::try_new(file).unwrap(); + let original_schema = Arc::clone(builder.schema()); + + let mask = ProjectionMask::columns(builder.parquet_schema(), ["blog_id"]); + let reader = builder.with_projection(mask).build().unwrap(); + + // Verify that the schema was correctly parsed + assert_eq!(1, reader.schema().fields().len()); + assert_eq!(original_schema.fields()[1], reader.schema().fields()[0]); + } + #[test] fn test_null_column_reader_test() { let mut file = tempfile::tempfile().unwrap(); @@ -2563,6 +2578,59 @@ mod tests { } } + #[test] + // same as test_read_structs but constructs projection mask via column names + fn test_read_structs_by_name() { + let testdata = arrow::util::test_util::parquet_test_data(); + let path = format!("{testdata}/nested_structs.rust.parquet"); + let file = File::open(&path).unwrap(); + let record_batch_reader = ParquetRecordBatchReader::try_new(file, 60).unwrap(); + + for batch in record_batch_reader { + batch.unwrap(); + } + + let file = File::open(&path).unwrap(); + let builder = ParquetRecordBatchReaderBuilder::try_new(file).unwrap(); + + let mask = ProjectionMask::columns( + builder.parquet_schema(), + ["roll_num.count", "PC_CUR.mean", "PC_CUR.sum"], + ); + let projected_reader = builder + .with_projection(mask) + .with_batch_size(60) + .build() + .unwrap(); + + let expected_schema = Schema::new(vec![ + Field::new( + "roll_num", + ArrowDataType::Struct(Fields::from(vec![Field::new( + "count", + ArrowDataType::UInt64, + false, + )])), + false, + ), + Field::new( + "PC_CUR", + ArrowDataType::Struct(Fields::from(vec![ + Field::new("mean", ArrowDataType::Int64, false), + Field::new("sum", ArrowDataType::Int64, false), + ])), + false, + ), + ]); + + assert_eq!(&expected_schema, projected_reader.schema().as_ref()); + + for batch in projected_reader { + let batch = batch.unwrap(); + assert_eq!(batch.schema().as_ref(), &expected_schema); + } + } + #[test] fn test_read_maps() { let testdata = arrow::util::test_util::parquet_test_data(); diff --git a/parquet/src/arrow/mod.rs b/parquet/src/arrow/mod.rs index d77436bc1ff7..6777e00fb05c 100644 --- a/parquet/src/arrow/mod.rs +++ b/parquet/src/arrow/mod.rs @@ -108,12 +108,14 @@ pub mod async_writer; mod record_reader; experimental!(mod schema); +use std::sync::Arc; + pub use self::arrow_writer::ArrowWriter; #[cfg(feature = "async")] pub use self::async_reader::ParquetRecordBatchStreamBuilder; #[cfg(feature = "async")] pub use self::async_writer::AsyncArrowWriter; -use crate::schema::types::SchemaDescriptor; +use crate::schema::types::{SchemaDescriptor, Type}; use arrow_schema::{FieldRef, Schema}; // continue to export deprecated methods until they are removed @@ -210,6 +212,71 @@ impl ProjectionMask { Self { mask: Some(mask) } } + // Given a starting point in the schema, do a DFS for that node adding leaf paths to `paths`. + fn find_leaves(root: &Arc, parent: Option<&String>, paths: &mut Vec) { + let path = parent + .map(|p| [p, root.name()].join(".")) + .unwrap_or(root.name().to_string()); + if root.is_group() { + for child in root.get_fields() { + Self::find_leaves(child, Some(&path), paths); + } + } else { + // Reached a leaf, add to paths + paths.push(path); + } + } + + /// Create a [`ProjectionMask`] which selects only the named columns + /// + /// All leaf columns that fall below a given name will be selected. For example, given + /// the schema + /// ```ignore + /// message schema { + /// OPTIONAL group a (MAP) { + /// REPEATED group key_value { + /// REQUIRED BYTE_ARRAY key (UTF8); // leaf index 0 + /// OPTIONAL group value (MAP) { + /// REPEATED group key_value { + /// REQUIRED INT32 key; // leaf index 1 + /// REQUIRED BOOLEAN value; // leaf index 2 + /// } + /// } + /// } + /// } + /// REQUIRED INT32 b; // leaf index 3 + /// REQUIRED DOUBLE c; // leaf index 4 + /// } + /// ``` + /// `["a.key_value.value", "c"]` would return leaf columns 1, 2, and 4. `["a"]` would return + /// columns 0, 1, and 2. + /// + /// Note: repeated or out of order indices will not impact the final mask. + /// + /// i.e. `["b", "c"]` will construct the same mask as `["c", "b", "c"]`. + pub fn columns<'a>( + schema: &SchemaDescriptor, + names: impl IntoIterator, + ) -> Self { + // first make vector of paths for leaf columns + let mut paths: Vec = vec![]; + for root in schema.root_schema().get_fields() { + Self::find_leaves(root, None, &mut paths); + } + assert_eq!(paths.len(), schema.num_columns()); + + let mut mask = vec![false; schema.num_columns()]; + for name in names { + for idx in 0..schema.num_columns() { + if paths[idx].starts_with(name) { + mask[idx] = true; + } + } + } + + Self { mask: Some(mask) } + } + /// Returns true if the leaf column `leaf_idx` is included by the mask pub fn leaf_included(&self, leaf_idx: usize) -> bool { self.mask.as_ref().map(|m| m[leaf_idx]).unwrap_or(true) @@ -246,10 +313,14 @@ mod test { use crate::arrow::ArrowWriter; use crate::file::metadata::{ParquetMetaData, ParquetMetaDataReader, ParquetMetaDataWriter}; use crate::file::properties::{EnabledStatistics, WriterProperties}; + use crate::schema::parser::parse_message_type; + use crate::schema::types::SchemaDescriptor; use arrow_array::{ArrayRef, Int32Array, RecordBatch}; use bytes::Bytes; use std::sync::Arc; + use super::ProjectionMask; + #[test] // Reproducer for https://github.com/apache/arrow-rs/issues/6464 fn test_metadata_read_write_partial_offset() { @@ -375,4 +446,109 @@ mod test { .unwrap(); Bytes::from(buf) } + + #[test] + fn test_mask_from_column_names() { + let message_type = " + message test_schema { + OPTIONAL group a (MAP) { + REPEATED group key_value { + REQUIRED BYTE_ARRAY key (UTF8); + OPTIONAL group value (MAP) { + REPEATED group key_value { + REQUIRED INT32 key; + REQUIRED BOOLEAN value; + } + } + } + } + REQUIRED INT32 b; + REQUIRED DOUBLE c; + } + "; + let parquet_group_type = parse_message_type(message_type).unwrap(); + let schema = SchemaDescriptor::new(Arc::new(parquet_group_type)); + + let mask = ProjectionMask::columns(&schema, ["foo", "bar"]); + assert_eq!(mask.mask.unwrap(), vec![false; 5]); + + let mask = ProjectionMask::columns(&schema, []); + assert_eq!(mask.mask.unwrap(), vec![false; 5]); + + let mask = ProjectionMask::columns(&schema, ["a", "c"]); + assert_eq!(mask.mask.unwrap(), [true, true, true, false, true]); + + let mask = ProjectionMask::columns(&schema, ["a.key_value.key", "c"]); + assert_eq!(mask.mask.unwrap(), [true, false, false, false, true]); + + let mask = ProjectionMask::columns(&schema, ["a.key_value.value", "b"]); + assert_eq!(mask.mask.unwrap(), [false, true, true, true, false]); + + let message_type = " + message test_schema { + OPTIONAL group a (LIST) { + REPEATED group list { + OPTIONAL group element (LIST) { + REPEATED group list { + OPTIONAL group element (LIST) { + REPEATED group list { + OPTIONAL BYTE_ARRAY element (UTF8); + } + } + } + } + } + } + REQUIRED INT32 b; + } + "; + let parquet_group_type = parse_message_type(message_type).unwrap(); + let schema = SchemaDescriptor::new(Arc::new(parquet_group_type)); + + let mask = ProjectionMask::columns(&schema, ["a", "b"]); + assert_eq!(mask.mask.unwrap(), [true, true]); + + let mask = ProjectionMask::columns(&schema, ["a.list.element", "b"]); + assert_eq!(mask.mask.unwrap(), [true, true]); + + let mask = + ProjectionMask::columns(&schema, ["a.list.element.list.element.list.element", "b"]); + assert_eq!(mask.mask.unwrap(), [true, true]); + + let mask = ProjectionMask::columns(&schema, ["b"]); + assert_eq!(mask.mask.unwrap(), [false, true]); + + let message_type = " + message test_schema { + OPTIONAL INT32 a; + OPTIONAL INT32 b; + OPTIONAL INT32 c; + OPTIONAL INT32 d; + OPTIONAL INT32 e; + } + "; + let parquet_group_type = parse_message_type(message_type).unwrap(); + let schema = SchemaDescriptor::new(Arc::new(parquet_group_type)); + + let mask = ProjectionMask::columns(&schema, ["a", "b"]); + assert_eq!(mask.mask.unwrap(), [true, true, false, false, false]); + + let mask = ProjectionMask::columns(&schema, ["d", "b", "d"]); + assert_eq!(mask.mask.unwrap(), [false, true, false, true, false]); + + let message_type = " + message test_schema { + OPTIONAL INT32 a; + OPTIONAL INT32 b; + OPTIONAL INT32 a; + OPTIONAL INT32 d; + OPTIONAL INT32 e; + } + "; + let parquet_group_type = parse_message_type(message_type).unwrap(); + let schema = SchemaDescriptor::new(Arc::new(parquet_group_type)); + + let mask = ProjectionMask::columns(&schema, ["a", "e"]); + assert_eq!(mask.mask.unwrap(), [true, false, true, false, true]); + } } diff --git a/parquet/src/arrow/schema/mod.rs b/parquet/src/arrow/schema/mod.rs index 5d3d7b2a6541..212dec525833 100644 --- a/parquet/src/arrow/schema/mod.rs +++ b/parquet/src/arrow/schema/mod.rs @@ -1399,6 +1399,17 @@ mod tests { for i in 0..arrow_fields.len() { assert_eq!(&arrow_fields[i], converted_fields[i].as_ref()); } + + let mask = + ProjectionMask::columns(&parquet_schema, ["group2.leaf4", "group1.leaf1", "leaf5"]); + let converted_arrow_schema = + parquet_to_arrow_schema_by_columns(&parquet_schema, mask, None).unwrap(); + let converted_fields = converted_arrow_schema.fields(); + + assert_eq!(arrow_fields.len(), converted_fields.len()); + for i in 0..arrow_fields.len() { + assert_eq!(&arrow_fields[i], converted_fields[i].as_ref()); + } } #[test]