Skip to content

Commit

Permalink
Merge commit '1c85d6abd46fd37fc5f6af8e558f19049c73de4c' into chunchun…
Browse files Browse the repository at this point in the history
…/update-df-june-week-1
  • Loading branch information
jeffreyssmith2nd committed Jun 13, 2024
2 parents 28282d1 + 1c85d6a commit d87d5db
Show file tree
Hide file tree
Showing 33 changed files with 2,329 additions and 1,056 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ parking_lot = "0.12"
parquet = { version = "51.0.0", default-features = false, features = ["arrow", "async", "object_store"] }
rand = "0.8"
regex = "1.8"
rstest = "0.20.0"
rstest = "0.21.0"
serde_json = "1"
sqlparser = { version = "0.45.0", features = ["visitor"] }
tempfile = "3"
Expand Down
39 changes: 35 additions & 4 deletions datafusion/common/src/scalar/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,13 @@
//! [`ScalarValue`]: stores single values
mod struct_builder;

use std::borrow::Borrow;
use std::cmp::Ordering;
use std::collections::{HashSet, VecDeque};
use std::convert::Infallible;
use std::fmt;
use std::hash::Hash;
use std::hash::Hasher;
use std::iter::repeat;
use std::str::FromStr;
use std::sync::Arc;
Expand Down Expand Up @@ -55,6 +55,7 @@ use arrow::{
use arrow_buffer::Buffer;
use arrow_schema::{UnionFields, UnionMode};

use half::f16;
pub use struct_builder::ScalarStructBuilder;

/// A dynamically typed, nullable single value.
Expand Down Expand Up @@ -192,6 +193,8 @@ pub enum ScalarValue {
Null,
/// true or false value
Boolean(Option<bool>),
/// 16bit float
Float16(Option<f16>),
/// 32bit float
Float32(Option<f32>),
/// 64bit float
Expand Down Expand Up @@ -285,6 +288,12 @@ pub enum ScalarValue {
Dictionary(Box<DataType>, Box<ScalarValue>),
}

/// Hashes the wrapped `f16` through its raw bit pattern, so the hash is
/// derived from the exact bits of the value rather than from float semantics.
impl Hash for Fl<f16> {
    fn hash<H: Hasher>(&self, state: &mut H) {
        // Floats do not implement `Hash`; their integer bit representation does.
        let bits = self.0.to_bits();
        bits.hash(state);
    }
}

// manual implementation of `PartialEq`
impl PartialEq for ScalarValue {
fn eq(&self, other: &Self) -> bool {
Expand All @@ -307,7 +316,12 @@ impl PartialEq for ScalarValue {
(Some(f1), Some(f2)) => f1.to_bits() == f2.to_bits(),
_ => v1.eq(v2),
},
(Float16(v1), Float16(v2)) => match (v1, v2) {
(Some(f1), Some(f2)) => f1.to_bits() == f2.to_bits(),
_ => v1.eq(v2),
},
(Float32(_), _) => false,
(Float16(_), _) => false,
(Float64(v1), Float64(v2)) => match (v1, v2) {
(Some(f1), Some(f2)) => f1.to_bits() == f2.to_bits(),
_ => v1.eq(v2),
Expand Down Expand Up @@ -425,7 +439,12 @@ impl PartialOrd for ScalarValue {
(Some(f1), Some(f2)) => Some(f1.total_cmp(f2)),
_ => v1.partial_cmp(v2),
},
(Float16(v1), Float16(v2)) => match (v1, v2) {
(Some(f1), Some(f2)) => Some(f1.total_cmp(f2)),
_ => v1.partial_cmp(v2),
},
(Float32(_), _) => None,
(Float16(_), _) => None,
(Float64(v1), Float64(v2)) => match (v1, v2) {
(Some(f1), Some(f2)) => Some(f1.total_cmp(f2)),
_ => v1.partial_cmp(v2),
Expand Down Expand Up @@ -637,6 +656,7 @@ impl std::hash::Hash for ScalarValue {
s.hash(state)
}
Boolean(v) => v.hash(state),
Float16(v) => v.map(Fl).hash(state),
Float32(v) => v.map(Fl).hash(state),
Float64(v) => v.map(Fl).hash(state),
Int8(v) => v.hash(state),
Expand Down Expand Up @@ -1082,6 +1102,7 @@ impl ScalarValue {
ScalarValue::TimestampNanosecond(_, tz_opt) => {
DataType::Timestamp(TimeUnit::Nanosecond, tz_opt.clone())
}
ScalarValue::Float16(_) => DataType::Float16,
ScalarValue::Float32(_) => DataType::Float32,
ScalarValue::Float64(_) => DataType::Float64,
ScalarValue::Utf8(_) => DataType::Utf8,
Expand Down Expand Up @@ -1276,6 +1297,7 @@ impl ScalarValue {
match self {
ScalarValue::Boolean(v) => v.is_none(),
ScalarValue::Null => true,
ScalarValue::Float16(v) => v.is_none(),
ScalarValue::Float32(v) => v.is_none(),
ScalarValue::Float64(v) => v.is_none(),
ScalarValue::Decimal128(v, _, _) => v.is_none(),
Expand Down Expand Up @@ -1522,6 +1544,7 @@ impl ScalarValue {
}
DataType::Null => ScalarValue::iter_to_null_array(scalars)?,
DataType::Boolean => build_array_primitive!(BooleanArray, Boolean),
DataType::Float16 => build_array_primitive!(Float16Array, Float16),
DataType::Float32 => build_array_primitive!(Float32Array, Float32),
DataType::Float64 => build_array_primitive!(Float64Array, Float64),
DataType::Int8 => build_array_primitive!(Int8Array, Int8),
Expand Down Expand Up @@ -1682,8 +1705,7 @@ impl ScalarValue {
// not supported if the TimeUnit is not valid (Time32 can
// only be used with Second and Millisecond, Time64 only
// with Microsecond and Nanosecond)
DataType::Float16
| DataType::Time32(TimeUnit::Microsecond)
DataType::Time32(TimeUnit::Microsecond)
| DataType::Time32(TimeUnit::Nanosecond)
| DataType::Time64(TimeUnit::Second)
| DataType::Time64(TimeUnit::Millisecond)
Expand All @@ -1700,7 +1722,6 @@ impl ScalarValue {
);
}
};

Ok(array)
}

Expand Down Expand Up @@ -1921,6 +1942,9 @@ impl ScalarValue {
ScalarValue::Float32(e) => {
build_array_from_option!(Float32, Float32Array, e, size)
}
ScalarValue::Float16(e) => {
build_array_from_option!(Float16, Float16Array, e, size)
}
ScalarValue::Int8(e) => build_array_from_option!(Int8, Int8Array, e, size),
ScalarValue::Int16(e) => build_array_from_option!(Int16, Int16Array, e, size),
ScalarValue::Int32(e) => build_array_from_option!(Int32, Int32Array, e, size),
Expand Down Expand Up @@ -2595,6 +2619,9 @@ impl ScalarValue {
ScalarValue::Boolean(val) => {
eq_array_primitive!(array, index, BooleanArray, val)?
}
ScalarValue::Float16(val) => {
eq_array_primitive!(array, index, Float16Array, val)?
}
ScalarValue::Float32(val) => {
eq_array_primitive!(array, index, Float32Array, val)?
}
Expand Down Expand Up @@ -2738,6 +2765,7 @@ impl ScalarValue {
+ match self {
ScalarValue::Null
| ScalarValue::Boolean(_)
| ScalarValue::Float16(_)
| ScalarValue::Float32(_)
| ScalarValue::Float64(_)
| ScalarValue::Decimal128(_, _, _)
Expand Down Expand Up @@ -3022,6 +3050,7 @@ impl TryFrom<&DataType> for ScalarValue {
fn try_from(data_type: &DataType) -> Result<Self> {
Ok(match data_type {
DataType::Boolean => ScalarValue::Boolean(None),
DataType::Float16 => ScalarValue::Float16(None),
DataType::Float64 => ScalarValue::Float64(None),
DataType::Float32 => ScalarValue::Float32(None),
DataType::Int8 => ScalarValue::Int8(None),
Expand Down Expand Up @@ -3147,6 +3176,7 @@ impl fmt::Display for ScalarValue {
write!(f, "{v:?},{p:?},{s:?}")?;
}
ScalarValue::Boolean(e) => format_option!(f, e)?,
ScalarValue::Float16(e) => format_option!(f, e)?,
ScalarValue::Float32(e) => format_option!(f, e)?,
ScalarValue::Float64(e) => format_option!(f, e)?,
ScalarValue::Int8(e) => format_option!(f, e)?,
Expand Down Expand Up @@ -3260,6 +3290,7 @@ impl fmt::Debug for ScalarValue {
ScalarValue::Decimal128(_, _, _) => write!(f, "Decimal128({self})"),
ScalarValue::Decimal256(_, _, _) => write!(f, "Decimal256({self})"),
ScalarValue::Boolean(_) => write!(f, "Boolean({self})"),
ScalarValue::Float16(_) => write!(f, "Float16({self})"),
ScalarValue::Float32(_) => write!(f, "Float32({self})"),
ScalarValue::Float64(_) => write!(f, "Float64({self})"),
ScalarValue::Int8(_) => write!(f, "Int8({self})"),
Expand Down
134 changes: 134 additions & 0 deletions datafusion/common/src/utils/memory.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

//! This module provides a function to estimate the memory size of a HashTable prior to allocation
use crate::{DataFusionError, Result};

/// Computes, ahead of allocation, an estimate in bytes of the memory a hash
/// table will occupy.
///
/// # Parameters
/// - `num_elements`: how many elements the hash table is expected to hold.
/// - `fixed_size`: constant overhead of the collection itself
///   (e.g., HashSet or HashTable); it is added to the result unchanged.
/// - `T`: the element type stored in the hash table.
///
/// # Details
/// The bucket count is deliberately overestimated so that roughly 1/8 of the
/// buckets stay empty: `num_elements * 8 / 7`, rounded up to the next power
/// of two. The returned estimate is the sum of:
/// - `size_of::<T>()` multiplied by the estimated bucket count,
/// - one byte of overhead per bucket,
/// - `fixed_size`.
///
/// # Errors
/// Returns a [`DataFusionError`] if any step of the computation overflows
/// `usize`.
///
/// # Examples
///
/// ## From within a struct
///
/// ```rust
/// # use datafusion_common::utils::memory::estimate_memory_size;
/// # use datafusion_common::Result;
///
/// struct MyStruct<T> {
///     values: Vec<T>,
///     other_data: usize,
/// }
///
/// impl<T> MyStruct<T> {
///     fn size(&self) -> Result<usize> {
///         let num_elements = self.values.len();
///         let fixed_size = std::mem::size_of_val(self) +
///           std::mem::size_of_val(&self.values);
///
///         estimate_memory_size::<T>(num_elements, fixed_size)
///     }
/// }
/// ```
///
/// ## With a simple collection
///
/// ```rust
/// # use datafusion_common::utils::memory::estimate_memory_size;
/// # use std::collections::HashMap;
///
/// let num_rows = 100;
/// let fixed_size = std::mem::size_of::<HashMap<u64, u64>>();
/// let estimated_hashtable_size =
///     estimate_memory_size::<(u64, u64)>(num_rows,fixed_size)
///         .expect("Size estimation failed");
/// ```
pub fn estimate_memory_size<T>(num_elements: usize, fixed_size: usize) -> Result<usize> {
    // Overflow-checked arithmetic; `None` means some step exceeded `usize`.
    fn checked_estimate(
        entry_size: usize,
        num_elements: usize,
        fixed_size: usize,
    ) -> Option<usize> {
        // In most cases hashbrown overestimates the bucket quantity so that
        // ~1/8 of the buckets remain empty. That is modelled by scaling the
        // element count by 8/7 (~1.14) before rounding up to a power of two.
        // Small tables (< 8 elements) end up overallocated by this formula,
        // which is acceptable for an estimate.
        let scaled = num_elements.checked_mul(8)?;
        let bucket_count = (scaled / 7).next_power_of_two();

        // entries + one byte per bucket + fixed size of the collection
        // (HashSet/HashTable)
        let entry_bytes = entry_size.checked_mul(bucket_count)?;
        let with_bucket_bytes = entry_bytes.checked_add(bucket_count)?;
        with_bucket_bytes.checked_add(fixed_size)
    }

    match checked_estimate(std::mem::size_of::<T>(), num_elements, fixed_size) {
        Some(total_bytes) => Ok(total_bytes),
        None => Err(DataFusionError::Execution(
            "usize overflow while estimating the number of buckets".to_string(),
        )),
    }
}

#[cfg(test)]
mod tests {
    use std::collections::HashSet;
    use std::mem::size_of;

    use super::estimate_memory_size;

    /// Expected estimate for a table of `u32` entries with `buckets` buckets:
    /// entry bytes + one byte per bucket + the collection's fixed overhead.
    ///
    /// The expectation is derived from `fixed_size` instead of being a
    /// hard-coded literal because `size_of::<HashSet<u32>>()` depends on the
    /// target's pointer width; literal totals would only hold on targets
    /// where that size happens to be 48 bytes.
    fn expected_size(buckets: usize, fixed_size: usize) -> usize {
        buckets * size_of::<u32>() + buckets + fixed_size
    }

    /// Checks the estimate against the documented formula for a few sizes.
    #[test]
    fn test_estimate_memory() {
        // Fixed overhead of the collection itself (target dependent).
        let fixed_size = size_of::<HashSet<u32>>();

        // estimated buckets: 1 = (0 * 8 / 7).next_power_of_two()
        let num_elements = 0;
        let estimated = estimate_memory_size::<u32>(num_elements, fixed_size).unwrap();
        assert_eq!(estimated, expected_size(1, fixed_size));

        // estimated buckets: 16 = (8 * 8 / 7).next_power_of_two()
        let num_elements = 8;
        // size (bytes): 16 * 4 + 16 + fixed_size
        let estimated = estimate_memory_size::<u32>(num_elements, fixed_size).unwrap();
        assert_eq!(estimated, expected_size(16, fixed_size));

        // estimated buckets: 64 = (40 * 8 / 7).next_power_of_two()
        let num_elements = 40;
        // size (bytes): 64 * 4 + 64 + fixed_size
        let estimated = estimate_memory_size::<u32>(num_elements, fixed_size).unwrap();
        assert_eq!(estimated, expected_size(64, fixed_size));
    }

    /// An element count whose scaled value cannot fit in `usize` must be
    /// reported as an error rather than wrapping or panicking.
    #[test]
    fn test_estimate_memory_overflow() {
        let num_elements = usize::MAX;
        let fixed_size = size_of::<HashSet<u32>>();
        let estimated = estimate_memory_size::<u32>(num_elements, fixed_size);

        assert!(estimated.is_err());
    }
}
1 change: 1 addition & 0 deletions datafusion/common/src/utils/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

//! This module provides the bisect function, which implements binary search.
pub mod memory;
pub mod proxy;

use crate::error::{_internal_datafusion_err, _internal_err};
Expand Down
Loading

0 comments on commit d87d5db

Please sign in to comment.