Skip to content

Commit

Permalink
Enable string-based column projections from Parquet files (apache#6871)
Browse files Browse the repository at this point in the history
* add function to create ProjectionMask from column names

* add some more tests
  • Loading branch information
etseidl authored and CurtHagenlocher committed Dec 28, 2024
1 parent 63f5d5e commit b8cc13e
Show file tree
Hide file tree
Showing 3 changed files with 256 additions and 1 deletion.
68 changes: 68 additions & 0 deletions parquet/src/arrow/arrow_reader/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -989,6 +989,21 @@ mod tests {
assert_eq!(original_schema.fields()[1], reader.schema().fields()[0]);
}

#[test]
fn test_arrow_reader_single_column_by_name() {
let file = get_test_file("parquet/generated_simple_numerics/blogs.parquet");

let builder = ParquetRecordBatchReaderBuilder::try_new(file).unwrap();
let original_schema = Arc::clone(builder.schema());

let mask = ProjectionMask::columns(builder.parquet_schema(), ["blog_id"]);
let reader = builder.with_projection(mask).build().unwrap();

// Verify that the schema was correctly parsed
assert_eq!(1, reader.schema().fields().len());
assert_eq!(original_schema.fields()[1], reader.schema().fields()[0]);
}

#[test]
fn test_null_column_reader_test() {
let mut file = tempfile::tempfile().unwrap();
Expand Down Expand Up @@ -2563,6 +2578,59 @@ mod tests {
}
}

#[test]
// same as test_read_structs but constructs projection mask via column names
fn test_read_structs_by_name() {
let testdata = arrow::util::test_util::parquet_test_data();
let path = format!("{testdata}/nested_structs.rust.parquet");
let file = File::open(&path).unwrap();
let record_batch_reader = ParquetRecordBatchReader::try_new(file, 60).unwrap();

for batch in record_batch_reader {
batch.unwrap();
}

let file = File::open(&path).unwrap();
let builder = ParquetRecordBatchReaderBuilder::try_new(file).unwrap();

let mask = ProjectionMask::columns(
builder.parquet_schema(),
["roll_num.count", "PC_CUR.mean", "PC_CUR.sum"],
);
let projected_reader = builder
.with_projection(mask)
.with_batch_size(60)
.build()
.unwrap();

let expected_schema = Schema::new(vec![
Field::new(
"roll_num",
ArrowDataType::Struct(Fields::from(vec![Field::new(
"count",
ArrowDataType::UInt64,
false,
)])),
false,
),
Field::new(
"PC_CUR",
ArrowDataType::Struct(Fields::from(vec![
Field::new("mean", ArrowDataType::Int64, false),
Field::new("sum", ArrowDataType::Int64, false),
])),
false,
),
]);

assert_eq!(&expected_schema, projected_reader.schema().as_ref());

for batch in projected_reader {
let batch = batch.unwrap();
assert_eq!(batch.schema().as_ref(), &expected_schema);
}
}

#[test]
fn test_read_maps() {
let testdata = arrow::util::test_util::parquet_test_data();
Expand Down
178 changes: 177 additions & 1 deletion parquet/src/arrow/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -108,12 +108,14 @@ pub mod async_writer;
mod record_reader;
experimental!(mod schema);

use std::sync::Arc;

pub use self::arrow_writer::ArrowWriter;
#[cfg(feature = "async")]
pub use self::async_reader::ParquetRecordBatchStreamBuilder;
#[cfg(feature = "async")]
pub use self::async_writer::AsyncArrowWriter;
use crate::schema::types::SchemaDescriptor;
use crate::schema::types::{SchemaDescriptor, Type};
use arrow_schema::{FieldRef, Schema};

// continue to export deprecated methods until they are removed
Expand Down Expand Up @@ -210,6 +212,71 @@ impl ProjectionMask {
Self { mask: Some(mask) }
}

// Given a starting point in the schema, do a DFS for that node adding leaf paths to `paths`.
fn find_leaves(root: &Arc<Type>, parent: Option<&String>, paths: &mut Vec<String>) {
let path = parent
.map(|p| [p, root.name()].join("."))
.unwrap_or(root.name().to_string());
if root.is_group() {
for child in root.get_fields() {
Self::find_leaves(child, Some(&path), paths);
}
} else {
// Reached a leaf, add to paths
paths.push(path);
}
}

/// Create a [`ProjectionMask`] which selects only the named columns
///
/// All leaf columns that fall below a given name will be selected. For example, given
/// the schema
/// ```ignore
/// message schema {
/// OPTIONAL group a (MAP) {
/// REPEATED group key_value {
/// REQUIRED BYTE_ARRAY key (UTF8); // leaf index 0
/// OPTIONAL group value (MAP) {
/// REPEATED group key_value {
/// REQUIRED INT32 key; // leaf index 1
/// REQUIRED BOOLEAN value; // leaf index 2
/// }
/// }
/// }
/// }
/// REQUIRED INT32 b; // leaf index 3
/// REQUIRED DOUBLE c; // leaf index 4
/// }
/// ```
/// `["a.key_value.value", "c"]` would return leaf columns 1, 2, and 4. `["a"]` would return
/// columns 0, 1, and 2.
///
/// Note: repeated or out of order indices will not impact the final mask.
///
/// i.e. `["b", "c"]` will construct the same mask as `["c", "b", "c"]`.
pub fn columns<'a>(
schema: &SchemaDescriptor,
names: impl IntoIterator<Item = &'a str>,
) -> Self {
// first make vector of paths for leaf columns
let mut paths: Vec<String> = vec![];
for root in schema.root_schema().get_fields() {
Self::find_leaves(root, None, &mut paths);
}
assert_eq!(paths.len(), schema.num_columns());

let mut mask = vec![false; schema.num_columns()];
for name in names {
for idx in 0..schema.num_columns() {
if paths[idx].starts_with(name) {
mask[idx] = true;
}
}
}

Self { mask: Some(mask) }
}

/// Returns true if the leaf column `leaf_idx` is included by the mask
pub fn leaf_included(&self, leaf_idx: usize) -> bool {
self.mask.as_ref().map(|m| m[leaf_idx]).unwrap_or(true)
Expand Down Expand Up @@ -246,10 +313,14 @@ mod test {
use crate::arrow::ArrowWriter;
use crate::file::metadata::{ParquetMetaData, ParquetMetaDataReader, ParquetMetaDataWriter};
use crate::file::properties::{EnabledStatistics, WriterProperties};
use crate::schema::parser::parse_message_type;
use crate::schema::types::SchemaDescriptor;
use arrow_array::{ArrayRef, Int32Array, RecordBatch};
use bytes::Bytes;
use std::sync::Arc;

use super::ProjectionMask;

#[test]
// Reproducer for https://github.com/apache/arrow-rs/issues/6464
fn test_metadata_read_write_partial_offset() {
Expand Down Expand Up @@ -375,4 +446,109 @@ mod test {
.unwrap();
Bytes::from(buf)
}

#[test]
fn test_mask_from_column_names() {
let message_type = "
message test_schema {
OPTIONAL group a (MAP) {
REPEATED group key_value {
REQUIRED BYTE_ARRAY key (UTF8);
OPTIONAL group value (MAP) {
REPEATED group key_value {
REQUIRED INT32 key;
REQUIRED BOOLEAN value;
}
}
}
}
REQUIRED INT32 b;
REQUIRED DOUBLE c;
}
";
let parquet_group_type = parse_message_type(message_type).unwrap();
let schema = SchemaDescriptor::new(Arc::new(parquet_group_type));

let mask = ProjectionMask::columns(&schema, ["foo", "bar"]);
assert_eq!(mask.mask.unwrap(), vec![false; 5]);

let mask = ProjectionMask::columns(&schema, []);
assert_eq!(mask.mask.unwrap(), vec![false; 5]);

let mask = ProjectionMask::columns(&schema, ["a", "c"]);
assert_eq!(mask.mask.unwrap(), [true, true, true, false, true]);

let mask = ProjectionMask::columns(&schema, ["a.key_value.key", "c"]);
assert_eq!(mask.mask.unwrap(), [true, false, false, false, true]);

let mask = ProjectionMask::columns(&schema, ["a.key_value.value", "b"]);
assert_eq!(mask.mask.unwrap(), [false, true, true, true, false]);

let message_type = "
message test_schema {
OPTIONAL group a (LIST) {
REPEATED group list {
OPTIONAL group element (LIST) {
REPEATED group list {
OPTIONAL group element (LIST) {
REPEATED group list {
OPTIONAL BYTE_ARRAY element (UTF8);
}
}
}
}
}
}
REQUIRED INT32 b;
}
";
let parquet_group_type = parse_message_type(message_type).unwrap();
let schema = SchemaDescriptor::new(Arc::new(parquet_group_type));

let mask = ProjectionMask::columns(&schema, ["a", "b"]);
assert_eq!(mask.mask.unwrap(), [true, true]);

let mask = ProjectionMask::columns(&schema, ["a.list.element", "b"]);
assert_eq!(mask.mask.unwrap(), [true, true]);

let mask =
ProjectionMask::columns(&schema, ["a.list.element.list.element.list.element", "b"]);
assert_eq!(mask.mask.unwrap(), [true, true]);

let mask = ProjectionMask::columns(&schema, ["b"]);
assert_eq!(mask.mask.unwrap(), [false, true]);

let message_type = "
message test_schema {
OPTIONAL INT32 a;
OPTIONAL INT32 b;
OPTIONAL INT32 c;
OPTIONAL INT32 d;
OPTIONAL INT32 e;
}
";
let parquet_group_type = parse_message_type(message_type).unwrap();
let schema = SchemaDescriptor::new(Arc::new(parquet_group_type));

let mask = ProjectionMask::columns(&schema, ["a", "b"]);
assert_eq!(mask.mask.unwrap(), [true, true, false, false, false]);

let mask = ProjectionMask::columns(&schema, ["d", "b", "d"]);
assert_eq!(mask.mask.unwrap(), [false, true, false, true, false]);

let message_type = "
message test_schema {
OPTIONAL INT32 a;
OPTIONAL INT32 b;
OPTIONAL INT32 a;
OPTIONAL INT32 d;
OPTIONAL INT32 e;
}
";
let parquet_group_type = parse_message_type(message_type).unwrap();
let schema = SchemaDescriptor::new(Arc::new(parquet_group_type));

let mask = ProjectionMask::columns(&schema, ["a", "e"]);
assert_eq!(mask.mask.unwrap(), [true, false, true, false, true]);
}
}
11 changes: 11 additions & 0 deletions parquet/src/arrow/schema/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1400,6 +1400,17 @@ mod tests {
for i in 0..arrow_fields.len() {
assert_eq!(&arrow_fields[i], converted_fields[i].as_ref());
}

let mask =
ProjectionMask::columns(&parquet_schema, ["group2.leaf4", "group1.leaf1", "leaf5"]);
let converted_arrow_schema =
parquet_to_arrow_schema_by_columns(&parquet_schema, mask, None).unwrap();
let converted_fields = converted_arrow_schema.fields();

assert_eq!(arrow_fields.len(), converted_fields.len());
for i in 0..arrow_fields.len() {
assert_eq!(&arrow_fields[i], converted_fields[i].as_ref());
}
}

#[test]
Expand Down

0 comments on commit b8cc13e

Please sign in to comment.