diff --git a/datafusion/core/src/datasource/physical_plan/parquet/row_group_filter.rs b/datafusion/core/src/datasource/physical_plan/parquet/row_group_filter.rs index 516310dc81ae6..17e3d904cfcbc 100644 --- a/datafusion/core/src/datasource/physical_plan/parquet/row_group_filter.rs +++ b/datafusion/core/src/datasource/physical_plan/parquet/row_group_filter.rs @@ -17,6 +17,7 @@ use crate::datasource::listing::FileRange; use crate::physical_optimizer::pruning::{PruningPredicate, PruningStatistics}; +use arrow::datatypes::LargeUtf8Type; use arrow::{array::ArrayRef, datatypes::Schema}; use arrow_array::BooleanArray; use datafusion_common::{Column, Result, ScalarValue}; @@ -228,6 +229,85 @@ struct BloomFilterStatistics { column_sbbf: HashMap, } +impl BloomFilterStatistics { + /// Helper function for checking if [`Sbbf`] filter contains [`ScalarValue`]. + /// + /// In case the type of scalar is not supported, returns `true`, assuming that the + /// value may be present. + fn check_scalar(sbbf: &Sbbf, value: &ScalarValue, parquet_type: &Type) -> bool { + match value { + ScalarValue::Utf8(Some(v)) + | ScalarValue::Utf8View(Some(v)) + | ScalarValue::LargeUtf8(Some(v)) => sbbf.check(&v.as_str()), + ScalarValue::Binary(Some(v)) | ScalarValue::BinaryView(Some(v)) => { + sbbf.check(v) + } + ScalarValue::FixedSizeBinary(_size, Some(v)) => sbbf.check(v), + ScalarValue::Boolean(Some(v)) => sbbf.check(v), + ScalarValue::Float64(Some(v)) => sbbf.check(v), + ScalarValue::Float32(Some(v)) => sbbf.check(v), + ScalarValue::Int64(Some(v)) => sbbf.check(v), + ScalarValue::Int32(Some(v)) => sbbf.check(v), + ScalarValue::UInt64(Some(v)) => sbbf.check(v), + ScalarValue::UInt32(Some(v)) => sbbf.check(v), + ScalarValue::Decimal128(Some(v), p, s) => match parquet_type { + Type::INT32 => { + //https://github.com/apache/parquet-format/blob/eb4b31c1d64a01088d02a2f9aefc6c17c54cc6fc/Encodings.md?plain=1#L35-L42 + // All physical type are little-endian + if *p > 9 { + //DECIMAL can be used to annotate the following types: + // + // int32: for 1 <= precision <= 9 + // int64: for 1 <= precision <= 18 + return true; + } + let b = (*v as i32).to_le_bytes(); + // Use Decimal constructor after https://github.com/apache/arrow-rs/issues/5325 + let decimal = Decimal::Int32 { + value: b, + precision: *p as i32, + scale: *s as i32, + }; + sbbf.check(&decimal) + } + Type::INT64 => { + if *p > 18 { + return true; + } + let b = (*v as i64).to_le_bytes(); + let decimal = Decimal::Int64 { + value: b, + precision: *p as i32, + scale: *s as i32, + }; + sbbf.check(&decimal) + } + Type::FIXED_LEN_BYTE_ARRAY => { + // keep with from_bytes_to_i128 + let b = v.to_be_bytes().to_vec(); + // Use Decimal constructor after https://github.com/apache/arrow-rs/issues/5325 + let decimal = Decimal::Bytes { + value: b.into(), + precision: *p as i32, + scale: *s as i32, + }; + sbbf.check(&decimal) + } + _ => true, + }, + // Bloom filter pruning is performed only for Utf8 dictionary types since + // pruning predicate is not created for Dictionary(Numeric/Binary) types + ScalarValue::Dictionary(_, inner) => match inner.as_ref() { + ScalarValue::Utf8(_) | ScalarValue::LargeUtf8(_) => { + BloomFilterStatistics::check_scalar(sbbf, inner, parquet_type) + } + _ => true, + }, + _ => true, + } + } +} + impl PruningStatistics for BloomFilterStatistics { fn min_values(&self, _column: &Column) -> Option { None @@ -268,70 +348,7 @@ impl PruningStatistics for BloomFilterStatistics { let known_not_present = values .iter() - .map(|value| { - match value { - ScalarValue::Utf8(Some(v)) | ScalarValue::Utf8View(Some(v)) => { - sbbf.check(&v.as_str()) - } - ScalarValue::Binary(Some(v)) | ScalarValue::BinaryView(Some(v)) => { - sbbf.check(v) - } - ScalarValue::FixedSizeBinary(_size, Some(v)) => sbbf.check(v), - ScalarValue::Boolean(Some(v)) => sbbf.check(v), - ScalarValue::Float64(Some(v)) => sbbf.check(v), - ScalarValue::Float32(Some(v)) => sbbf.check(v), - ScalarValue::Int64(Some(v)) => sbbf.check(v), - ScalarValue::Int32(Some(v)) => sbbf.check(v), - ScalarValue::UInt64(Some(v)) => sbbf.check(v), - ScalarValue::UInt32(Some(v)) => sbbf.check(v), - ScalarValue::Decimal128(Some(v), p, s) => match parquet_type { - Type::INT32 => { - //https://github.com/apache/parquet-format/blob/eb4b31c1d64a01088d02a2f9aefc6c17c54cc6fc/Encodings.md?plain=1#L35-L42 - // All physical type are little-endian - if *p > 9 { - //DECIMAL can be used to annotate the following types: - // - // int32: for 1 <= precision <= 9 - // int64: for 1 <= precision <= 18 - return true; - } - let b = (*v as i32).to_le_bytes(); - // Use Decimal constructor after https://github.com/apache/arrow-rs/issues/5325 - let decimal = Decimal::Int32 { - value: b, - precision: *p as i32, - scale: *s as i32, - }; - sbbf.check(&decimal) - } - Type::INT64 => { - if *p > 18 { - return true; - } - let b = (*v as i64).to_le_bytes(); - let decimal = Decimal::Int64 { - value: b, - precision: *p as i32, - scale: *s as i32, - }; - sbbf.check(&decimal) - } - Type::FIXED_LEN_BYTE_ARRAY => { - // keep with from_bytes_to_i128 - let b = v.to_be_bytes().to_vec(); - // Use Decimal constructor after https://github.com/apache/arrow-rs/issues/5325 - let decimal = Decimal::Bytes { - value: b.into(), - precision: *p as i32, - scale: *s as i32, - }; - sbbf.check(&decimal) - } - _ => true, - }, - _ => true, - } - }) + .map(|value| BloomFilterStatistics::check_scalar(sbbf, value, parquet_type)) // The row group doesn't contain any of the values if // all the checks are false .all(|v| !v); diff --git a/datafusion/core/tests/parquet/mod.rs b/datafusion/core/tests/parquet/mod.rs index 46be2433116a4..9ecad98b2973a 100644 --- a/datafusion/core/tests/parquet/mod.rs +++ b/datafusion/core/tests/parquet/mod.rs @@ -17,14 +17,14 @@ //! Parquet integration tests use crate::parquet::utils::MetricsFinder; -use arrow::array::Decimal128Array; use arrow::{ array::{ make_array, Array, ArrayRef, BinaryArray, Date32Array, Date64Array, - FixedSizeBinaryArray, Float64Array, Int16Array, Int32Array, Int64Array, - Int8Array, LargeBinaryArray, LargeStringArray, StringArray, - TimestampMicrosecondArray, TimestampMillisecondArray, TimestampNanosecondArray, - TimestampSecondArray, UInt16Array, UInt32Array, UInt64Array, UInt8Array, + Decimal128Array, DictionaryArray, FixedSizeBinaryArray, Float64Array, Int16Array, + Int32Array, Int64Array, Int8Array, LargeBinaryArray, LargeStringArray, + StringArray, TimestampMicrosecondArray, TimestampMillisecondArray, + TimestampNanosecondArray, TimestampSecondArray, UInt16Array, UInt32Array, + UInt64Array, UInt8Array, }, datatypes::{DataType, Field, Schema}, record_batch::RecordBatch, @@ -64,7 +64,7 @@ fn init() { // ---------------------- /// What data to use -#[derive(Debug, Clone, Copy)] +#[derive(Debug, Clone)] enum Scenario { Timestamps, Dates, @@ -84,6 +84,7 @@ enum Scenario { WithNullValues, WithNullValuesPageLevel, UTF8, + Dictionary, } enum Unit { @@ -740,6 +741,32 @@ fn make_utf8_batch(value: Vec>) -> RecordBatch { .unwrap() } +fn make_dictionary_batch(strings: Vec<&str>) -> RecordBatch { + let keys = Int32Array::from_iter(0..strings.len() as i32); + + let utf8_values = StringArray::from(strings.clone()); + let utf8_dict = DictionaryArray::new(keys.clone(), Arc::new(utf8_values)); + + let large_utf8 = LargeStringArray::from(strings.clone()); + let large_utf8_dict = DictionaryArray::new(keys.clone(), Arc::new(large_utf8)); + + let binary = + BinaryArray::from_iter_values(strings.iter().cloned().map(|v| v.as_bytes())); + let binary_dict = DictionaryArray::new(keys.clone(), Arc::new(binary)); + + let large_binary = + LargeBinaryArray::from_iter_values(strings.iter().cloned().map(|v| v.as_bytes())); + let large_binary_dict = DictionaryArray::new(keys.clone(), Arc::new(large_binary)); + + RecordBatch::try_from_iter(vec![ + ("utf8", Arc::new(utf8_dict) as _), + ("large_utf8", Arc::new(large_utf8_dict) as _), + ("binary", Arc::new(binary_dict) as _), + ("large_binary", Arc::new(large_binary_dict) as _), + ]) + .unwrap() +} + fn create_data_batch(scenario: Scenario) -> Vec { match scenario { Scenario::Timestamps => { @@ -961,6 +988,13 @@ fn create_data_batch(scenario: Scenario) -> Vec { ]), ] } + + Scenario::Dictionary => { + vec![ + make_dictionary_batch(vec!["a", "b", "c", "d", "e"]), + make_dictionary_batch(vec!["f", "g", "h", "i", "j"]), + ] + } } } diff --git a/datafusion/core/tests/parquet/row_group_pruning.rs b/datafusion/core/tests/parquet/row_group_pruning.rs index 536ac5414a9a8..6e53ce0df3058 100644 --- a/datafusion/core/tests/parquet/row_group_pruning.rs +++ b/datafusion/core/tests/parquet/row_group_pruning.rs @@ -1323,3 +1323,54 @@ async fn test_row_group_with_null_values() { .test_row_group_prune() .await; } + +#[tokio::test] +async fn test_bloom_filter_dict() { + RowGroupPruningTest::new() + .with_scenario(Scenario::Dictionary) + .with_query("SELECT * FROM t WHERE utf8 = 'h'") + .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(1)) + .with_pruned_by_stats(Some(1)) + .with_expected_rows(1) + .with_pruned_by_bloom_filter(Some(0)) + .with_matched_by_bloom_filter(Some(1)) + .test_row_group_prune() + .await; + + RowGroupPruningTest::new() + .with_scenario(Scenario::Dictionary) + .with_query("SELECT * FROM t WHERE utf8 = 'ab'") + .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(1)) + .with_pruned_by_stats(Some(1)) + .with_expected_rows(0) + .with_pruned_by_bloom_filter(Some(1)) + .with_matched_by_bloom_filter(Some(0)) + .test_row_group_prune() + .await; + + RowGroupPruningTest::new() + .with_scenario(Scenario::Dictionary) + .with_query("SELECT * FROM t WHERE large_utf8 = 'b'") + .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(1)) + .with_pruned_by_stats(Some(1)) + .with_expected_rows(1) + .with_pruned_by_bloom_filter(Some(0)) + .with_matched_by_bloom_filter(Some(1)) + .test_row_group_prune() + .await; + + RowGroupPruningTest::new() + .with_scenario(Scenario::Dictionary) + .with_query("SELECT * FROM t WHERE large_utf8 = 'cd'") + .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(1)) + .with_pruned_by_stats(Some(1)) + .with_expected_rows(0) + .with_pruned_by_bloom_filter(Some(1)) + .with_matched_by_bloom_filter(Some(0)) + .test_row_group_prune() + .await; +}