diff --git a/parquet/src/arrow/mod.rs b/parquet/src/arrow/mod.rs index 2d09cd19203..a9e762dd3bf 100644 --- a/parquet/src/arrow/mod.rs +++ b/parquet/src/arrow/mod.rs @@ -108,12 +108,14 @@ pub mod async_writer; mod record_reader; experimental!(mod schema); +use std::sync::Arc; + pub use self::arrow_writer::ArrowWriter; #[cfg(feature = "async")] pub use self::async_reader::ParquetRecordBatchStreamBuilder; #[cfg(feature = "async")] pub use self::async_writer::AsyncArrowWriter; -use crate::schema::types::SchemaDescriptor; +use crate::schema::types::{SchemaDescriptor, Type}; use arrow_schema::{FieldRef, Schema}; pub use self::schema::{ @@ -206,6 +208,71 @@ impl ProjectionMask { Self { mask: Some(mask) } } + // Given a starting point in the schema, do a DFS for that node adding leaf paths to `paths`. + fn find_leaves(root: &Arc, parent: Option<&String>, paths: &mut Vec) { + let path = parent + .map(|p| [p, root.name()].join(".")) + .unwrap_or(root.name().to_string()); + if root.is_group() { + for child in root.get_fields() { + Self::find_leaves(child, Some(&path), paths); + } + } else { + // Reached a leaf, add to paths + paths.push(path); + } + } + + /// Create a [`ProjectionMask`] which selects only the named columns + /// + /// All leaf columns that fall below a given name will be selected. For example, given + /// the schema + /// ```ignore + /// message schema { + /// OPTIONAL group a (MAP) { + /// REPEATED group key_value { + /// REQUIRED BYTE_ARRAY key (UTF8); // leaf index 0 + /// OPTIONAL group value (MAP) { + /// REPEATED group key_value { + /// REQUIRED INT32 key; // leaf index 1 + /// REQUIRED BOOLEAN value; // leaf index 2 + /// } + /// } + /// } + /// } + /// REQUIRED INT32 b; // leaf index 3 + /// REQUIRED DOUBLE c; // leaf index 4 + /// } + /// ``` + /// `["a.key_value.value", "c"]` would return leaf columns 1, 2, and 4. `["a"]` would return + /// columns 0, 1, and 2. + /// + /// Note: repeated or out of order indices will not impact the final mask. + /// + /// i.e. `["b", "c"]` will construct the same mask as `["c", "b", "c"]`. + pub fn columns<'a>( + schema: &SchemaDescriptor, + names: impl IntoIterator, + ) -> Self { + // first make vector of paths for leaf columns + let mut paths: Vec = vec![]; + for root in schema.root_schema().get_fields() { + Self::find_leaves(root, None, &mut paths); + } + assert_eq!(paths.len(), schema.num_columns()); + + let mut mask = vec![false; schema.num_columns()]; + for name in names { + for idx in 0..schema.num_columns() { + if paths[idx].starts_with(name) { + mask[idx] = true; + } + } + } + + Self { mask: Some(mask) } + } + /// Returns true if the leaf column `leaf_idx` is included by the mask pub fn leaf_included(&self, leaf_idx: usize) -> bool { self.mask.as_ref().map(|m| m[leaf_idx]).unwrap_or(true) @@ -242,10 +309,14 @@ mod test { use crate::arrow::ArrowWriter; use crate::file::metadata::{ParquetMetaData, ParquetMetaDataReader, ParquetMetaDataWriter}; use crate::file::properties::{EnabledStatistics, WriterProperties}; + use crate::schema::parser::parse_message_type; + use crate::schema::types::SchemaDescriptor; use arrow_array::{ArrayRef, Int32Array, RecordBatch}; use bytes::Bytes; use std::sync::Arc; + use super::ProjectionMask; + #[test] // Reproducer for https://github.com/apache/arrow-rs/issues/6464 fn test_metadata_read_write_partial_offset() { @@ -371,4 +442,109 @@ mod test { .unwrap(); Bytes::from(buf) } + + #[test] + fn test_mask_from_column_names() { + let message_type = " + message test_schema { + OPTIONAL group a (MAP) { + REPEATED group key_value { + REQUIRED BYTE_ARRAY key (UTF8); + OPTIONAL group value (MAP) { + REPEATED group key_value { + REQUIRED INT32 key; + REQUIRED BOOLEAN value; + } + } + } + } + REQUIRED INT32 b; + REQUIRED DOUBLE c; + } + "; + let parquet_group_type = parse_message_type(message_type).unwrap(); + let schema = SchemaDescriptor::new(Arc::new(parquet_group_type)); + + let mask = ProjectionMask::columns(&schema, ["foo", "bar"]); + assert_eq!(mask.mask.unwrap(), vec![false; 5]); + + let mask = ProjectionMask::columns(&schema, []); + assert_eq!(mask.mask.unwrap(), vec![false; 5]); + + let mask = ProjectionMask::columns(&schema, ["a", "c"]); + assert_eq!(mask.mask.unwrap(), [true, true, true, false, true]); + + let mask = ProjectionMask::columns(&schema, ["a.key_value.key", "c"]); + assert_eq!(mask.mask.unwrap(), [true, false, false, false, true]); + + let mask = ProjectionMask::columns(&schema, ["a.key_value.value", "b"]); + assert_eq!(mask.mask.unwrap(), [false, true, true, true, false]); + + let message_type = " + message test_schema { + OPTIONAL group a (LIST) { + REPEATED group list { + OPTIONAL group element (LIST) { + REPEATED group list { + OPTIONAL group element (LIST) { + REPEATED group list { + OPTIONAL BYTE_ARRAY element (UTF8); + } + } + } + } + } + } + REQUIRED INT32 b; + } + "; + let parquet_group_type = parse_message_type(message_type).unwrap(); + let schema = SchemaDescriptor::new(Arc::new(parquet_group_type)); + + let mask = ProjectionMask::columns(&schema, ["a", "b"]); + assert_eq!(mask.mask.unwrap(), [true, true]); + + let mask = ProjectionMask::columns(&schema, ["a.list.element", "b"]); + assert_eq!(mask.mask.unwrap(), [true, true]); + + let mask = + ProjectionMask::columns(&schema, ["a.list.element.list.element.list.element", "b"]); + assert_eq!(mask.mask.unwrap(), [true, true]); + + let mask = ProjectionMask::columns(&schema, ["b"]); + assert_eq!(mask.mask.unwrap(), [false, true]); + + let message_type = " + message test_schema { + OPTIONAL INT32 a; + OPTIONAL INT32 b; + OPTIONAL INT32 c; + OPTIONAL INT32 d; + OPTIONAL INT32 e; + } + "; + let parquet_group_type = parse_message_type(message_type).unwrap(); + let schema = SchemaDescriptor::new(Arc::new(parquet_group_type)); + + let mask = ProjectionMask::columns(&schema, ["a", "b"]); + assert_eq!(mask.mask.unwrap(), [true, true, false, false, false]); + + let mask = ProjectionMask::columns(&schema, ["d", "b", "d"]); + assert_eq!(mask.mask.unwrap(), [false, true, false, true, false]); + + let message_type = " + message test_schema { + OPTIONAL INT32 a; + OPTIONAL INT32 b; + OPTIONAL INT32 a; + OPTIONAL INT32 d; + OPTIONAL INT32 e; + } + "; + let parquet_group_type = parse_message_type(message_type).unwrap(); + let schema = SchemaDescriptor::new(Arc::new(parquet_group_type)); + + let mask = ProjectionMask::columns(&schema, ["a", "e"]); + assert_eq!(mask.mask.unwrap(), [true, false, true, false, true]); + } }