Commit

Merge remote-tracking branch 'upstream/main' into hash_join_mem_mgmt
korowa committed Mar 8, 2023
2 parents 2201ec3 + deeaa56 · commit e16e20d
Showing 36 changed files with 2,065 additions and 233 deletions.
65 changes: 22 additions & 43 deletions datafusion-examples/examples/simple_udaf.rs
@@ -65,40 +65,6 @@ impl GeometricMean {
pub fn new() -> Self {
GeometricMean { n: 0, prod: 1.0 }
}

// this function receives one entry per argument of this accumulator.
// DataFusion calls this function on every row, and expects this function to update the accumulator's state.
fn update(&mut self, values: &[ScalarValue]) -> Result<()> {
// this is a one-argument UDAF, and thus we use `0`.
let value = &values[0];
match value {
// here we map `ScalarValue` to our internal state. `Float64` indicates that this function
// only accepts Float64 as its argument (DataFusion does try to coerce arguments to this type)
//
// Note that `.map` here ensures that we ignore Nulls.
ScalarValue::Float64(e) => e.map(|value| {
self.prod *= value;
self.n += 1;
}),
_ => unreachable!(""),
};
Ok(())
}

// this function receives states from other accumulators (Vec<ScalarValue>)
// and updates the accumulator.
fn merge(&mut self, states: &[ScalarValue]) -> Result<()> {
let prod = &states[0];
let n = &states[1];
match (prod, n) {
(ScalarValue::Float64(Some(prod)), ScalarValue::UInt32(Some(n))) => {
self.prod *= prod;
self.n += n;
}
_ => unreachable!(""),
};
Ok(())
}
}

// UDAFs are built using the trait `Accumulator`, that offers DataFusion the necessary functions
@@ -128,28 +94,41 @@ impl Accumulator for GeometricMean {
if values.is_empty() {
return Ok(());
}
(0..values[0].len()).try_for_each(|index| {
let v = values
.iter()
.map(|array| ScalarValue::try_from_array(array, index))
.collect::<Result<Vec<_>>>()?;
self.update(&v)
let arr = &values[0];
(0..arr.len()).try_for_each(|index| {
let v = ScalarValue::try_from_array(arr, index)?;

if let ScalarValue::Float64(Some(value)) = v {
self.prod *= value;
self.n += 1;
} else {
unreachable!("")
}
Ok(())
})
}

// Optimization hint: this trait also supports `update_batch` and `merge_batch`,
// that can be used to perform these operations on arrays instead of single values.
// By default, these methods call `update` and `merge` row by row
fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
if states.is_empty() {
return Ok(());
}
(0..states[0].len()).try_for_each(|index| {
let arr = &states[0];
(0..arr.len()).try_for_each(|index| {
let v = states
.iter()
.map(|array| ScalarValue::try_from_array(array, index))
.collect::<Result<Vec<_>>>()?;
self.merge(&v)
if let (ScalarValue::Float64(Some(prod)), ScalarValue::UInt32(Some(n))) =
(&v[0], &v[1])
{
self.prod *= prod;
self.n += n;
} else {
unreachable!("")
}
Ok(())
})
}
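
For context, here is a minimal sketch (not part of this commit) of the array-oriented approach hinted at by the "Optimization hint" comment above: instead of converting every row to a `ScalarValue`, the accumulator could downcast the argument column once and iterate the values directly. The free function below is illustrative, standing in for the method body, and assumes the single argument is a Float64 column, which is what this UDAF accepts after coercion.

```rust
use datafusion::arrow::array::{Array, ArrayRef, Float64Array};
use datafusion::error::Result;

// Illustration only: update the (prod, n) state for one batch by downcasting
// the argument column once instead of building a ScalarValue per row.
fn geometric_mean_update_batch(prod: &mut f64, n: &mut u32, values: &[ArrayRef]) -> Result<()> {
    if values.is_empty() {
        return Ok(());
    }
    let arr = values[0]
        .as_any()
        .downcast_ref::<Float64Array>()
        .expect("argument is coerced to Float64");
    for i in 0..arr.len() {
        // Skip nulls, mirroring the `.map` over `Option` in the original code.
        if arr.is_valid(i) {
            *prod *= arr.value(i);
            *n += 1;
        }
    }
    Ok(())
}
```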

52 changes: 41 additions & 11 deletions datafusion/common/src/scalar.rs
@@ -1019,7 +1019,7 @@ impl ScalarValue {
Self::List(scalars, Box::new(Field::new("item", child_type, true)))
}

// Create a zero value in the given type.
/// Create a zero value in the given type.
pub fn new_zero(datatype: &DataType) -> Result<ScalarValue> {
assert!(datatype.is_primitive());
Ok(match datatype {
@@ -1042,6 +1042,24 @@ impl ScalarValue {
})
}

/// Create a negative one value in the given type.
pub fn new_negative_one(datatype: &DataType) -> Result<ScalarValue> {
assert!(datatype.is_primitive());
Ok(match datatype {
DataType::Int8 | DataType::UInt8 => ScalarValue::Int8(Some(-1)),
DataType::Int16 | DataType::UInt16 => ScalarValue::Int16(Some(-1)),
DataType::Int32 | DataType::UInt32 => ScalarValue::Int32(Some(-1)),
DataType::Int64 | DataType::UInt64 => ScalarValue::Int64(Some(-1)),
DataType::Float32 => ScalarValue::Float32(Some(-1.0)),
DataType::Float64 => ScalarValue::Float64(Some(-1.0)),
_ => {
return Err(DataFusionError::NotImplemented(format!(
"Can't create a negative one scalar from data_type \"{datatype:?}\""
)));
}
})
}

/// Getter for the `DataType` of the value
pub fn get_datatype(&self) -> DataType {
match self {
@@ -1296,7 +1314,7 @@ impl ScalarValue {
}

macro_rules! build_array_primitive_tz {
($ARRAY_TY:ident, $SCALAR_TY:ident) => {{
($ARRAY_TY:ident, $SCALAR_TY:ident, $TZ:expr) => {{
{
let array = scalars.map(|sv| {
if let ScalarValue::$SCALAR_TY(v, _) = sv {
Expand All @@ -1310,7 +1328,7 @@ impl ScalarValue {
}
})
.collect::<Result<$ARRAY_TY>>()?;
Arc::new(array)
Arc::new(array.with_timezone_opt($TZ.clone()))
}
}};
}
@@ -1444,17 +1462,29 @@ impl ScalarValue {
DataType::Time64(TimeUnit::Nanosecond) => {
build_array_primitive!(Time64NanosecondArray, Time64Nanosecond)
}
DataType::Timestamp(TimeUnit::Second, _) => {
build_array_primitive_tz!(TimestampSecondArray, TimestampSecond)
DataType::Timestamp(TimeUnit::Second, tz) => {
build_array_primitive_tz!(TimestampSecondArray, TimestampSecond, tz)
}
DataType::Timestamp(TimeUnit::Millisecond, _) => {
build_array_primitive_tz!(TimestampMillisecondArray, TimestampMillisecond)
DataType::Timestamp(TimeUnit::Millisecond, tz) => {
build_array_primitive_tz!(
TimestampMillisecondArray,
TimestampMillisecond,
tz
)
}
DataType::Timestamp(TimeUnit::Microsecond, _) => {
build_array_primitive_tz!(TimestampMicrosecondArray, TimestampMicrosecond)
DataType::Timestamp(TimeUnit::Microsecond, tz) => {
build_array_primitive_tz!(
TimestampMicrosecondArray,
TimestampMicrosecond,
tz
)
}
DataType::Timestamp(TimeUnit::Nanosecond, _) => {
build_array_primitive_tz!(TimestampNanosecondArray, TimestampNanosecond)
DataType::Timestamp(TimeUnit::Nanosecond, tz) => {
build_array_primitive_tz!(
TimestampNanosecondArray,
TimestampNanosecond,
tz
)
}
DataType::Interval(IntervalUnit::DayTime) => {
build_array_primitive!(IntervalDayTimeArray, IntervalDayTime)
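
The `scalar.rs` hunks above add `ScalarValue::new_negative_one` next to the existing `new_zero`, and pass the timezone into `build_array_primitive_tz!` so that timestamp arrays built from scalars keep their timezone instead of dropping it. A minimal usage sketch of the two constructors follows; the helper function and the assertions are illustrative only, not part of the commit.

```rust
use arrow::datatypes::DataType;
use datafusion_common::{Result, ScalarValue};

/// Hypothetical helper: build a (0, -1) scalar pair for a primitive numeric type.
/// Both constructors assert the type is primitive and return
/// `DataFusionError::NotImplemented` for primitives they do not cover.
fn zero_and_negative_one(dt: &DataType) -> Result<(ScalarValue, ScalarValue)> {
    Ok((ScalarValue::new_zero(dt)?, ScalarValue::new_negative_one(dt)?))
}

fn main() -> Result<()> {
    let (zero, neg_one) = zero_and_negative_one(&DataType::Int32)?;
    assert_eq!(zero, ScalarValue::Int32(Some(0)));
    assert_eq!(neg_one, ScalarValue::Int32(Some(-1)));
    Ok(())
}
```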
65 changes: 41 additions & 24 deletions datafusion/core/src/dataframe.rs
@@ -20,7 +20,7 @@
use std::any::Any;
use std::sync::Arc;

use arrow::array::{ArrayRef, Int64Array, StringArray};
use arrow::array::{Array, ArrayRef, Int64Array, StringArray};
use arrow::compute::{cast, concat};
use arrow::datatypes::{DataType, Field};
use async_trait::async_trait;
@@ -318,7 +318,7 @@ impl DataFrame {
/// # #[tokio::main]
/// # async fn main() -> Result<()> {
/// let ctx = SessionContext::new();
/// let df = ctx.read_csv("tests/tpch-csv/customer.csv", CsvReadOptions::new()).await?;
/// let df = ctx.read_csv("tests/tpch-csv/customer.csv", CsvReadOptions::new()).await?;
/// df.describe().await.unwrap();
///
/// # Ok(())
@@ -329,10 +329,10 @@ impl DataFrame {
let supported_describe_functions =
vec!["count", "null_count", "mean", "std", "min", "max", "median"];

let fields_iter = self.schema().fields().iter();
let original_schema_fields = self.schema().fields().iter();

//define describe column
let mut describe_schemas = fields_iter
let mut describe_schemas = original_schema_fields
.clone()
.map(|field| {
if field.data_type().is_numeric() {
@@ -344,24 +344,38 @@ impl DataFrame {
.collect::<Vec<_>>();
describe_schemas.insert(0, Field::new("describe", DataType::Utf8, false));

//count aggregation
let cnt = self.clone().aggregate(
vec![],
original_schema_fields
.clone()
.map(|f| count(col(f.name())))
.collect::<Vec<_>>(),
)?;
// The optimization of AggregateStatistics will rewrite the physical plan
// for the count function and ignore alias functions,
// as shown in https://github.com/apache/arrow-datafusion/issues/5444.
// This logic should be removed when #5444 is fixed.
let cnt = cnt.clone().select(
cnt.schema()
.fields()
.iter()
.zip(original_schema_fields.clone())
.map(|(count_field, orgin_field)| {
col(count_field.name()).alias(orgin_field.name())
})
.collect::<Vec<_>>(),
)?;
//should be removed when #5444 is fixed
//collect recordBatch
let describe_record_batch = vec![
// count aggregation
self.clone()
.aggregate(
vec![],
fields_iter
.clone()
.map(|f| count(col(f.name())).alias(f.name()))
.collect::<Vec<_>>(),
)?
.collect()
.await?,
cnt.collect().await?,
// null_count aggregation
self.clone()
.aggregate(
vec![],
fields_iter
original_schema_fields
.clone()
.map(|f| count(is_null(col(f.name()))).alias(f.name()))
.collect::<Vec<_>>(),
@@ -372,7 +386,7 @@ impl DataFrame {
self.clone()
.aggregate(
vec![],
fields_iter
original_schema_fields
.clone()
.filter(|f| f.data_type().is_numeric())
.map(|f| avg(col(f.name())).alias(f.name()))
@@ -384,7 +398,7 @@ impl DataFrame {
self.clone()
.aggregate(
vec![],
fields_iter
original_schema_fields
.clone()
.filter(|f| f.data_type().is_numeric())
.map(|f| stddev(col(f.name())).alias(f.name()))
@@ -396,7 +410,7 @@ impl DataFrame {
self.clone()
.aggregate(
vec![],
fields_iter
original_schema_fields
.clone()
.filter(|f| {
!matches!(f.data_type(), DataType::Binary | DataType::Boolean)
@@ -410,7 +424,7 @@ impl DataFrame {
self.clone()
.aggregate(
vec![],
fields_iter
original_schema_fields
.clone()
.filter(|f| {
!matches!(f.data_type(), DataType::Binary | DataType::Boolean)
@@ -424,7 +438,7 @@ impl DataFrame {
self.clone()
.aggregate(
vec![],
fields_iter
original_schema_fields
.clone()
.filter(|f| f.data_type().is_numeric())
.map(|f| median(col(f.name())).alias(f.name()))
@@ -435,7 +449,7 @@ impl DataFrame {
];

let mut array_ref_vec: Vec<ArrayRef> = vec![];
for field in fields_iter {
for field in original_schema_fields {
let mut array_datas = vec![];
for record_batch in describe_record_batch.iter() {
let column = record_batch.get(0).unwrap().column_by_name(field.name());
@@ -928,7 +942,8 @@ impl DataFrame {
/// Write a `DataFrame` to a CSV file.
pub async fn write_csv(self, path: &str) -> Result<()> {
let plan = self.session_state.create_physical_plan(&self.plan).await?;
plan_to_csv(&self.session_state, plan, path).await
let task_ctx = Arc::new(self.task_ctx());
plan_to_csv(task_ctx, plan, path).await
}

/// Write a `DataFrame` to a Parquet file.
@@ -938,13 +953,15 @@ impl DataFrame {
writer_properties: Option<WriterProperties>,
) -> Result<()> {
let plan = self.session_state.create_physical_plan(&self.plan).await?;
plan_to_parquet(&self.session_state, plan, path, writer_properties).await
let task_ctx = Arc::new(self.task_ctx());
plan_to_parquet(task_ctx, plan, path, writer_properties).await
}

/// Executes a query and writes the results to a partitioned JSON file.
pub async fn write_json(self, path: impl AsRef<str>) -> Result<()> {
let plan = self.session_state.create_physical_plan(&self.plan).await?;
plan_to_json(&self.session_state, plan, path).await
let task_ctx = Arc::new(self.task_ctx());
plan_to_json(task_ctx, plan, path).await
}

/// Add an additional column to the DataFrame.
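
In the `dataframe.rs` hunks above, `write_csv`, `write_parquet`, and `write_json` now hand `plan_to_*` an `Arc<TaskContext>` built from the DataFrame's session state instead of a `&SessionState`; the public signatures are unchanged. A minimal usage sketch is below, with a hypothetical output directory (the input path is the one used in the doc example above).

```rust
use datafusion::error::Result;
use datafusion::prelude::{CsvReadOptions, SessionContext};

#[tokio::main]
async fn main() -> Result<()> {
    let ctx = SessionContext::new();
    let df = ctx
        .read_csv("tests/tpch-csv/customer.csv", CsvReadOptions::new())
        .await?;

    // Consumes the DataFrame and writes the results into the given directory,
    // one file per output partition.
    df.write_csv("/tmp/customer_out").await?;
    Ok(())
}
```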