diff --git a/datafusion/core/src/datasource/physical_plan/parquet/row_group_filter.rs b/datafusion/core/src/datasource/physical_plan/parquet/row_group_filter.rs index 516310dc81ae6..4d8f00816f53a 100644 --- a/datafusion/core/src/datasource/physical_plan/parquet/row_group_filter.rs +++ b/datafusion/core/src/datasource/physical_plan/parquet/row_group_filter.rs @@ -228,6 +228,80 @@ struct BloomFilterStatistics { column_sbbf: HashMap, } +impl BloomFilterStatistics { + /// Helper function for checking if [`ScalarValue`] contained in [`Sbbf`] filter. + /// + /// In case the type of scalar is not supported, returns `true`, assuming that the + /// value may be present. + fn scalar_contained(sbbf: &Sbbf, value: &ScalarValue, parquet_type: &Type) -> bool { + match value { + ScalarValue::Utf8(Some(v)) | ScalarValue::Utf8View(Some(v)) => { + sbbf.check(&v.as_str()) + } + ScalarValue::Binary(Some(v)) | ScalarValue::BinaryView(Some(v)) => { + sbbf.check(v) + } + ScalarValue::FixedSizeBinary(_size, Some(v)) => sbbf.check(v), + ScalarValue::Boolean(Some(v)) => sbbf.check(v), + ScalarValue::Float64(Some(v)) => sbbf.check(v), + ScalarValue::Float32(Some(v)) => sbbf.check(v), + ScalarValue::Int64(Some(v)) => sbbf.check(v), + ScalarValue::Int32(Some(v)) => sbbf.check(v), + ScalarValue::UInt64(Some(v)) => sbbf.check(v), + ScalarValue::UInt32(Some(v)) => sbbf.check(v), + ScalarValue::Decimal128(Some(v), p, s) => match parquet_type { + Type::INT32 => { + //https://github.com/apache/parquet-format/blob/eb4b31c1d64a01088d02a2f9aefc6c17c54cc6fc/Encodings.md?plain=1#L35-L42 + // All physical type are little-endian + if *p > 9 { + //DECIMAL can be used to annotate the following types: + // + // int32: for 1 <= precision <= 9 + // int64: for 1 <= precision <= 18 + return true; + } + let b = (*v as i32).to_le_bytes(); + // Use Decimal constructor after https://github.com/apache/arrow-rs/issues/5325 + let decimal = Decimal::Int32 { + value: b, + precision: *p as i32, + scale: *s as i32, + }; + sbbf.check(&decimal) + } + Type::INT64 => { + if *p > 18 { + return true; + } + let b = (*v as i64).to_le_bytes(); + let decimal = Decimal::Int64 { + value: b, + precision: *p as i32, + scale: *s as i32, + }; + sbbf.check(&decimal) + } + Type::FIXED_LEN_BYTE_ARRAY => { + // keep with from_bytes_to_i128 + let b = v.to_be_bytes().to_vec(); + // Use Decimal constructor after https://github.com/apache/arrow-rs/issues/5325 + let decimal = Decimal::Bytes { + value: b.into(), + precision: *p as i32, + scale: *s as i32, + }; + sbbf.check(&decimal) + } + _ => true, + }, + ScalarValue::Dictionary(_, inner) => { + BloomFilterStatistics::scalar_contained(sbbf, inner, parquet_type) + } + _ => true, + } + } +} + impl PruningStatistics for BloomFilterStatistics { fn min_values(&self, _column: &Column) -> Option { None @@ -269,68 +343,7 @@ impl PruningStatistics for BloomFilterStatistics { let known_not_present = values .iter() .map(|value| { - match value { - ScalarValue::Utf8(Some(v)) | ScalarValue::Utf8View(Some(v)) => { - sbbf.check(&v.as_str()) - } - ScalarValue::Binary(Some(v)) | ScalarValue::BinaryView(Some(v)) => { - sbbf.check(v) - } - ScalarValue::FixedSizeBinary(_size, Some(v)) => sbbf.check(v), - ScalarValue::Boolean(Some(v)) => sbbf.check(v), - ScalarValue::Float64(Some(v)) => sbbf.check(v), - ScalarValue::Float32(Some(v)) => sbbf.check(v), - ScalarValue::Int64(Some(v)) => sbbf.check(v), - ScalarValue::Int32(Some(v)) => sbbf.check(v), - ScalarValue::UInt64(Some(v)) => sbbf.check(v), - ScalarValue::UInt32(Some(v)) => sbbf.check(v), - ScalarValue::Decimal128(Some(v), p, s) => match parquet_type { - Type::INT32 => { - //https://github.com/apache/parquet-format/blob/eb4b31c1d64a01088d02a2f9aefc6c17c54cc6fc/Encodings.md?plain=1#L35-L42 - // All physical type are little-endian - if *p > 9 { - //DECIMAL can be used to annotate the following types: - // - // int32: for 1 <= precision <= 9 - // int64: for 1 <= precision <= 18 - return true; - } - let b = (*v as i32).to_le_bytes(); - // Use Decimal constructor after https://github.com/apache/arrow-rs/issues/5325 - let decimal = Decimal::Int32 { - value: b, - precision: *p as i32, - scale: *s as i32, - }; - sbbf.check(&decimal) - } - Type::INT64 => { - if *p > 18 { - return true; - } - let b = (*v as i64).to_le_bytes(); - let decimal = Decimal::Int64 { - value: b, - precision: *p as i32, - scale: *s as i32, - }; - sbbf.check(&decimal) - } - Type::FIXED_LEN_BYTE_ARRAY => { - // keep with from_bytes_to_i128 - let b = v.to_be_bytes().to_vec(); - // Use Decimal constructor after https://github.com/apache/arrow-rs/issues/5325 - let decimal = Decimal::Bytes { - value: b.into(), - precision: *p as i32, - scale: *s as i32, - }; - sbbf.check(&decimal) - } - _ => true, - }, - _ => true, - } + BloomFilterStatistics::scalar_contained(sbbf, value, parquet_type) }) // The row group doesn't contain any of the values if // all the checks are false diff --git a/datafusion/core/tests/parquet/mod.rs b/datafusion/core/tests/parquet/mod.rs index 46be2433116a4..f49d2ad44535d 100644 --- a/datafusion/core/tests/parquet/mod.rs +++ b/datafusion/core/tests/parquet/mod.rs @@ -17,16 +17,16 @@ //! Parquet integration tests use crate::parquet::utils::MetricsFinder; -use arrow::array::Decimal128Array; use arrow::{ array::{ make_array, Array, ArrayRef, BinaryArray, Date32Array, Date64Array, - FixedSizeBinaryArray, Float64Array, Int16Array, Int32Array, Int64Array, - Int8Array, LargeBinaryArray, LargeStringArray, StringArray, - TimestampMicrosecondArray, TimestampMillisecondArray, TimestampNanosecondArray, - TimestampSecondArray, UInt16Array, UInt32Array, UInt64Array, UInt8Array, + Decimal128Array, DictionaryArray, FixedSizeBinaryArray, Float64Array, Int16Array, + Int32Array, Int64Array, Int8Array, LargeBinaryArray, LargeStringArray, + StringArray, TimestampMicrosecondArray, TimestampMillisecondArray, + TimestampNanosecondArray, TimestampSecondArray, UInt16Array, UInt32Array, + UInt64Array, UInt8Array, }, - datatypes::{DataType, Field, Schema}, + datatypes::{DataType, Field, Int32Type, Schema}, record_batch::RecordBatch, util::pretty::pretty_format_batches, }; @@ -64,7 +64,7 @@ fn init() { // ---------------------- /// What data to use -#[derive(Debug, Clone, Copy)] +#[derive(Debug, Clone)] enum Scenario { Timestamps, Dates, @@ -84,6 +84,7 @@ enum Scenario { WithNullValues, WithNullValuesPageLevel, UTF8, + Dictionary(Box), } enum Unit { @@ -961,6 +962,25 @@ fn create_data_batch(scenario: Scenario) -> Vec { ]), ] } + + Scenario::Dictionary(inner) => { + let mut batches = vec![]; + for source_batch in create_data_batch(*inner) { + let mut columns = vec![]; + for (idx, column) in source_batch.columns().iter().enumerate() { + let keys = Int32Array::from_iter((0..column.len()).map(|i| i as i32)); + let dict_array = Arc::new( + DictionaryArray::::try_new(keys, column.clone()) + .unwrap(), + ) as ArrayRef; + columns + .push((source_batch.schema_ref().field(idx).name(), dict_array)); + } + batches.push(RecordBatch::try_from_iter(columns).unwrap()); + } + + batches + } } } diff --git a/datafusion/core/tests/parquet/row_group_pruning.rs b/datafusion/core/tests/parquet/row_group_pruning.rs index 536ac5414a9a8..8635da6de22e2 100644 --- a/datafusion/core/tests/parquet/row_group_pruning.rs +++ b/datafusion/core/tests/parquet/row_group_pruning.rs @@ -1323,3 +1323,18 @@ async fn test_row_group_with_null_values() { .test_row_group_prune() .await; } + +#[tokio::test] +async fn prune_by_bloom_filter_dict() { + RowGroupPruningTest::new() + .with_scenario(Scenario::Dictionary(Box::new(Scenario::UTF8))) + .with_query("SELECT * FROM t WHERE utf8 = 'ab'") + .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(1)) + .with_pruned_by_stats(Some(1)) + .with_expected_rows(0) + .with_pruned_by_bloom_filter(Some(1)) + .with_matched_by_bloom_filter(Some(0)) + .test_row_group_prune() + .await; +}