Skip to content

Commit

Permalink
add function to create ProjectionMask from column names
Browse files Browse the repository at this point in the history
  • Loading branch information
etseidl committed Dec 12, 2024
1 parent 4acf4d3 commit 1abe85f
Showing 1 changed file with 177 additions and 1 deletion.
178 changes: 177 additions & 1 deletion parquet/src/arrow/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -108,12 +108,14 @@ pub mod async_writer;
mod record_reader;
experimental!(mod schema);

use std::sync::Arc;

pub use self::arrow_writer::ArrowWriter;
#[cfg(feature = "async")]
pub use self::async_reader::ParquetRecordBatchStreamBuilder;
#[cfg(feature = "async")]
pub use self::async_writer::AsyncArrowWriter;
use crate::schema::types::SchemaDescriptor;
use crate::schema::types::{SchemaDescriptor, Type};
use arrow_schema::{FieldRef, Schema};

pub use self::schema::{
Expand Down Expand Up @@ -206,6 +208,71 @@ impl ProjectionMask {
Self { mask: Some(mask) }
}

// Given a starting point in the schema, do a DFS for that node adding leaf paths to `paths`.
fn find_leaves(root: &Arc<Type>, parent: Option<&String>, paths: &mut Vec<String>) {
let path = parent
.map(|p| [p, root.name()].join("."))
.unwrap_or(root.name().to_string());
if root.is_group() {
for child in root.get_fields() {
Self::find_leaves(child, Some(&path), paths);
}
} else {
// Reached a leaf, add to paths
paths.push(path);
}
}

/// Create a [`ProjectionMask`] which selects only the named columns
///
/// All leaf columns that fall below a given name will be selected. For example, given
/// the schema
/// ```ignore
/// message schema {
/// OPTIONAL group a (MAP) {
/// REPEATED group key_value {
/// REQUIRED BYTE_ARRAY key (UTF8); // leaf index 0
/// OPTIONAL group value (MAP) {
/// REPEATED group key_value {
/// REQUIRED INT32 key; // leaf index 1
/// REQUIRED BOOLEAN value; // leaf index 2
/// }
/// }
/// }
/// }
/// REQUIRED INT32 b; // leaf index 3
/// REQUIRED DOUBLE c; // leaf index 4
/// }
/// ```
/// `["a.key_value.value", "c"]` would return leaf columns 1, 2, and 4. `["a"]` would return
/// columns 0, 1, and 2.
///
/// Note: repeated or out of order indices will not impact the final mask.
///
/// i.e. `["b", "c"]` will construct the same mask as `["c", "b", "c"]`.
pub fn columns<'a>(
schema: &SchemaDescriptor,
names: impl IntoIterator<Item = &'a str>,
) -> Self {
// first make vector of paths for leaf columns
let mut paths: Vec<String> = vec![];
for root in schema.root_schema().get_fields() {
Self::find_leaves(root, None, &mut paths);
}
assert_eq!(paths.len(), schema.num_columns());

let mut mask = vec![false; schema.num_columns()];
for name in names {
for idx in 0..schema.num_columns() {
if paths[idx].starts_with(name) {
mask[idx] = true;
}
}
}

Self { mask: Some(mask) }
}

/// Returns true if the leaf column `leaf_idx` is included by the mask
pub fn leaf_included(&self, leaf_idx: usize) -> bool {
self.mask.as_ref().map(|m| m[leaf_idx]).unwrap_or(true)
Expand Down Expand Up @@ -242,10 +309,14 @@ mod test {
use crate::arrow::ArrowWriter;
use crate::file::metadata::{ParquetMetaData, ParquetMetaDataReader, ParquetMetaDataWriter};
use crate::file::properties::{EnabledStatistics, WriterProperties};
use crate::schema::parser::parse_message_type;
use crate::schema::types::SchemaDescriptor;
use arrow_array::{ArrayRef, Int32Array, RecordBatch};
use bytes::Bytes;
use std::sync::Arc;

use super::ProjectionMask;

#[test]
// Reproducer for https://github.com/apache/arrow-rs/issues/6464
fn test_metadata_read_write_partial_offset() {
Expand Down Expand Up @@ -371,4 +442,109 @@ mod test {
.unwrap();
Bytes::from(buf)
}

#[test]
fn test_mask_from_column_names() {
let message_type = "
message test_schema {
OPTIONAL group a (MAP) {
REPEATED group key_value {
REQUIRED BYTE_ARRAY key (UTF8);
OPTIONAL group value (MAP) {
REPEATED group key_value {
REQUIRED INT32 key;
REQUIRED BOOLEAN value;
}
}
}
}
REQUIRED INT32 b;
REQUIRED DOUBLE c;
}
";
let parquet_group_type = parse_message_type(message_type).unwrap();
let schema = SchemaDescriptor::new(Arc::new(parquet_group_type));

let mask = ProjectionMask::columns(&schema, ["foo", "bar"]);
assert_eq!(mask.mask.unwrap(), vec![false; 5]);

let mask = ProjectionMask::columns(&schema, []);
assert_eq!(mask.mask.unwrap(), vec![false; 5]);

let mask = ProjectionMask::columns(&schema, ["a", "c"]);
assert_eq!(mask.mask.unwrap(), [true, true, true, false, true]);

let mask = ProjectionMask::columns(&schema, ["a.key_value.key", "c"]);
assert_eq!(mask.mask.unwrap(), [true, false, false, false, true]);

let mask = ProjectionMask::columns(&schema, ["a.key_value.value", "b"]);
assert_eq!(mask.mask.unwrap(), [false, true, true, true, false]);

let message_type = "
message test_schema {
OPTIONAL group a (LIST) {
REPEATED group list {
OPTIONAL group element (LIST) {
REPEATED group list {
OPTIONAL group element (LIST) {
REPEATED group list {
OPTIONAL BYTE_ARRAY element (UTF8);
}
}
}
}
}
}
REQUIRED INT32 b;
}
";
let parquet_group_type = parse_message_type(message_type).unwrap();
let schema = SchemaDescriptor::new(Arc::new(parquet_group_type));

let mask = ProjectionMask::columns(&schema, ["a", "b"]);
assert_eq!(mask.mask.unwrap(), [true, true]);

let mask = ProjectionMask::columns(&schema, ["a.list.element", "b"]);
assert_eq!(mask.mask.unwrap(), [true, true]);

let mask =
ProjectionMask::columns(&schema, ["a.list.element.list.element.list.element", "b"]);
assert_eq!(mask.mask.unwrap(), [true, true]);

let mask = ProjectionMask::columns(&schema, ["b"]);
assert_eq!(mask.mask.unwrap(), [false, true]);

let message_type = "
message test_schema {
OPTIONAL INT32 a;
OPTIONAL INT32 b;
OPTIONAL INT32 c;
OPTIONAL INT32 d;
OPTIONAL INT32 e;
}
";
let parquet_group_type = parse_message_type(message_type).unwrap();
let schema = SchemaDescriptor::new(Arc::new(parquet_group_type));

let mask = ProjectionMask::columns(&schema, ["a", "b"]);
assert_eq!(mask.mask.unwrap(), [true, true, false, false, false]);

let mask = ProjectionMask::columns(&schema, ["d", "b", "d"]);
assert_eq!(mask.mask.unwrap(), [false, true, false, true, false]);

let message_type = "
message test_schema {
OPTIONAL INT32 a;
OPTIONAL INT32 b;
OPTIONAL INT32 a;
OPTIONAL INT32 d;
OPTIONAL INT32 e;
}
";
let parquet_group_type = parse_message_type(message_type).unwrap();
let schema = SchemaDescriptor::new(Arc::new(parquet_group_type));

let mask = ProjectionMask::columns(&schema, ["a", "e"]);
assert_eq!(mask.mask.unwrap(), [true, false, true, false, true]);
}
}

0 comments on commit 1abe85f

Please sign in to comment.