Skip to content

Commit

Permalink
Implement date_part for durations (#6246)
Browse files Browse the repository at this point in the history
Signed-off-by: Nick Cameron <[email protected]>
  • Loading branch information
nrc authored Aug 19, 2024
1 parent 27789d7 commit d7ad4fe
Showing 1 changed file with 286 additions and 1 deletion.
287 changes: 286 additions & 1 deletion arrow-arith/src/temporal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ use arrow_array::timezone::Tz;
use arrow_array::types::*;
use arrow_array::*;
use arrow_buffer::ArrowNativeType;
use arrow_schema::{ArrowError, DataType, IntervalUnit};
use arrow_schema::{ArrowError, DataType, IntervalUnit, TimeUnit};

/// Valid parts to extract from date/time/timestamp arrays.
///
Expand Down Expand Up @@ -113,6 +113,7 @@ where
/// - Time32/Time64
/// - Timestamp
/// - Interval
/// - Duration
///
/// Returns an [`Int32Array`] unless input was a dictionary type, in which case returns
/// the dictionary but with this function applied onto its values.
Expand Down Expand Up @@ -154,6 +155,26 @@ pub fn date_part(array: &dyn Array, part: DatePart) -> Result<ArrayRef, ArrowErr
let array = Arc::new(array) as ArrayRef;
Ok(array)
}
DataType::Duration(TimeUnit::Second) => {
let array = as_primitive_array::<DurationSecondType>(array).date_part(part)?;
let array = Arc::new(array) as ArrayRef;
Ok(array)
}
DataType::Duration(TimeUnit::Millisecond) => {
let array = as_primitive_array::<DurationMillisecondType>(array).date_part(part)?;
let array = Arc::new(array) as ArrayRef;
Ok(array)
}
DataType::Duration(TimeUnit::Microsecond) => {
let array = as_primitive_array::<DurationMicrosecondType>(array).date_part(part)?;
let array = Arc::new(array) as ArrayRef;
Ok(array)
}
DataType::Duration(TimeUnit::Nanosecond) => {
let array = as_primitive_array::<DurationNanosecondType>(array).date_part(part)?;
let array = Arc::new(array) as ArrayRef;
Ok(array)
}
DataType::Dictionary(_, _) => {
let array = array.as_any_dictionary();
let values = date_part(array.values(), part)?;
Expand Down Expand Up @@ -482,6 +503,126 @@ impl ExtractDatePartExt for PrimitiveArray<IntervalMonthDayNanoType> {
}
}

impl ExtractDatePartExt for PrimitiveArray<DurationSecondType> {
fn date_part(&self, part: DatePart) -> Result<Int32Array, ArrowError> {
match part {
DatePart::Week => Ok(self.unary_opt(|d| (d / (60 * 60 * 24 * 7)).try_into().ok())),
DatePart::Day => Ok(self.unary_opt(|d| (d / (60 * 60 * 24)).try_into().ok())),
DatePart::Hour => Ok(self.unary_opt(|d| (d / (60 * 60)).try_into().ok())),
DatePart::Minute => Ok(self.unary_opt(|d| (d / 60).try_into().ok())),
DatePart::Second => Ok(self.unary_opt(|d| d.try_into().ok())),
DatePart::Millisecond => {
Ok(self.unary_opt(|d| d.checked_mul(1_000).and_then(|d| d.try_into().ok())))
}
DatePart::Microsecond => {
Ok(self.unary_opt(|d| d.checked_mul(1_000_000).and_then(|d| d.try_into().ok())))
}
DatePart::Nanosecond => Ok(
self.unary_opt(|d| d.checked_mul(1_000_000_000).and_then(|d| d.try_into().ok()))
),

DatePart::Year
| DatePart::Quarter
| DatePart::Month
| DatePart::DayOfWeekSunday0
| DatePart::DayOfWeekMonday0
| DatePart::DayOfYear => {
return_compute_error_with!(format!("{part} does not support"), self.data_type())
}
}
}
}

impl ExtractDatePartExt for PrimitiveArray<DurationMillisecondType> {
fn date_part(&self, part: DatePart) -> Result<Int32Array, ArrowError> {
match part {
DatePart::Week => {
Ok(self.unary_opt(|d| (d / (1_000 * 60 * 60 * 24 * 7)).try_into().ok()))
}
DatePart::Day => Ok(self.unary_opt(|d| (d / (1_000 * 60 * 60 * 24)).try_into().ok())),
DatePart::Hour => Ok(self.unary_opt(|d| (d / (1_000 * 60 * 60)).try_into().ok())),
DatePart::Minute => Ok(self.unary_opt(|d| (d / (1_000 * 60)).try_into().ok())),
DatePart::Second => Ok(self.unary_opt(|d| (d / 1_000).try_into().ok())),
DatePart::Millisecond => Ok(self.unary_opt(|d| d.try_into().ok())),
DatePart::Microsecond => {
Ok(self.unary_opt(|d| d.checked_mul(1_000).and_then(|d| d.try_into().ok())))
}
DatePart::Nanosecond => {
Ok(self.unary_opt(|d| d.checked_mul(1_000_000).and_then(|d| d.try_into().ok())))
}

DatePart::Year
| DatePart::Quarter
| DatePart::Month
| DatePart::DayOfWeekSunday0
| DatePart::DayOfWeekMonday0
| DatePart::DayOfYear => {
return_compute_error_with!(format!("{part} does not support"), self.data_type())
}
}
}
}

impl ExtractDatePartExt for PrimitiveArray<DurationMicrosecondType> {
fn date_part(&self, part: DatePart) -> Result<Int32Array, ArrowError> {
match part {
DatePart::Week => {
Ok(self.unary_opt(|d| (d / (1_000_000 * 60 * 60 * 24 * 7)).try_into().ok()))
}
DatePart::Day => {
Ok(self.unary_opt(|d| (d / (1_000_000 * 60 * 60 * 24)).try_into().ok()))
}
DatePart::Hour => Ok(self.unary_opt(|d| (d / (1_000_000 * 60 * 60)).try_into().ok())),
DatePart::Minute => Ok(self.unary_opt(|d| (d / (1_000_000 * 60)).try_into().ok())),
DatePart::Second => Ok(self.unary_opt(|d| (d / 1_000_000).try_into().ok())),
DatePart::Millisecond => Ok(self.unary_opt(|d| (d / 1_000).try_into().ok())),
DatePart::Microsecond => Ok(self.unary_opt(|d| d.try_into().ok())),
DatePart::Nanosecond => {
Ok(self.unary_opt(|d| d.checked_mul(1_000).and_then(|d| d.try_into().ok())))
}

DatePart::Year
| DatePart::Quarter
| DatePart::Month
| DatePart::DayOfWeekSunday0
| DatePart::DayOfWeekMonday0
| DatePart::DayOfYear => {
return_compute_error_with!(format!("{part} does not support"), self.data_type())
}
}
}
}

impl ExtractDatePartExt for PrimitiveArray<DurationNanosecondType> {
fn date_part(&self, part: DatePart) -> Result<Int32Array, ArrowError> {
match part {
DatePart::Week => {
Ok(self.unary_opt(|d| (d / (1_000_000_000 * 60 * 60 * 24 * 7)).try_into().ok()))
}
DatePart::Day => {
Ok(self.unary_opt(|d| (d / (1_000_000_000 * 60 * 60 * 24)).try_into().ok()))
}
DatePart::Hour => {
Ok(self.unary_opt(|d| (d / (1_000_000_000 * 60 * 60)).try_into().ok()))
}
DatePart::Minute => Ok(self.unary_opt(|d| (d / (1_000_000_000 * 60)).try_into().ok())),
DatePart::Second => Ok(self.unary_opt(|d| (d / 1_000_000_000).try_into().ok())),
DatePart::Millisecond => Ok(self.unary_opt(|d| (d / 1_000_000).try_into().ok())),
DatePart::Microsecond => Ok(self.unary_opt(|d| (d / 1_000).try_into().ok())),
DatePart::Nanosecond => Ok(self.unary_opt(|d| d.try_into().ok())),

DatePart::Year
| DatePart::Quarter
| DatePart::Month
| DatePart::DayOfWeekSunday0
| DatePart::DayOfWeekMonday0
| DatePart::DayOfYear => {
return_compute_error_with!(format!("{part} does not support"), self.data_type())
}
}
}
}

macro_rules! return_compute_error_with {
($msg:expr, $param:expr) => {
return { Err(ArrowError::ComputeError(format!("{}: {:?}", $msg, $param))) }
Expand Down Expand Up @@ -1796,4 +1937,148 @@ mod tests {
IntervalMonthDayNano::ZERO,
]));
}

#[test]
fn test_duration_second() {
let input: DurationSecondArray = vec![0, 42, 60 * 60 * 24 + 1].into();

let actual = date_part(&input, DatePart::Second).unwrap();
let actual = actual.as_primitive::<Int32Type>();
assert_eq!(0, actual.value(0));
assert_eq!(42, actual.value(1));
assert_eq!(60 * 60 * 24 + 1, actual.value(2));

let actual = date_part(&input, DatePart::Millisecond).unwrap();
let actual = actual.as_primitive::<Int32Type>();
assert_eq!(0, actual.value(0));
assert_eq!(42_000, actual.value(1));
assert_eq!((60 * 60 * 24 + 1) * 1_000, actual.value(2));

let actual = date_part(&input, DatePart::Microsecond).unwrap();
let actual = actual.as_primitive::<Int32Type>();
assert_eq!(0, actual.value(0));
assert_eq!(42_000_000, actual.value(1));
assert_eq!(0, actual.value(2));

let actual = date_part(&input, DatePart::Nanosecond).unwrap();
let actual = actual.as_primitive::<Int32Type>();
assert_eq!(0, actual.value(0));
assert_eq!(0, actual.value(1));
assert_eq!(0, actual.value(2));
}

#[test]
fn test_duration_millisecond() {
let input: DurationMillisecondArray = vec![0, 42, 60 * 60 * 24 + 1].into();

let actual = date_part(&input, DatePart::Second).unwrap();
let actual = actual.as_primitive::<Int32Type>();
assert_eq!(0, actual.value(0));
assert_eq!(0, actual.value(1));
assert_eq!((60 * 60 * 24 + 1) / 1_000, actual.value(2));

let actual = date_part(&input, DatePart::Millisecond).unwrap();
let actual = actual.as_primitive::<Int32Type>();
assert_eq!(0, actual.value(0));
assert_eq!(42, actual.value(1));
assert_eq!(60 * 60 * 24 + 1, actual.value(2));

let actual = date_part(&input, DatePart::Microsecond).unwrap();
let actual = actual.as_primitive::<Int32Type>();
assert_eq!(0, actual.value(0));
assert_eq!(42_000, actual.value(1));
assert_eq!((60 * 60 * 24 + 1) * 1_000, actual.value(2));

let actual = date_part(&input, DatePart::Nanosecond).unwrap();
let actual = actual.as_primitive::<Int32Type>();
assert_eq!(0, actual.value(0));
assert_eq!(42_000_000, actual.value(1));
assert_eq!(0, actual.value(2));
}

#[test]
fn test_duration_microsecond() {
let input: DurationMicrosecondArray = vec![0, 42, 60 * 60 * 24 + 1].into();

let actual = date_part(&input, DatePart::Second).unwrap();
let actual = actual.as_primitive::<Int32Type>();
assert_eq!(0, actual.value(0));
assert_eq!(0, actual.value(1));
assert_eq!(0, actual.value(2));

let actual = date_part(&input, DatePart::Millisecond).unwrap();
let actual = actual.as_primitive::<Int32Type>();
assert_eq!(0, actual.value(0));
assert_eq!(0, actual.value(1));
assert_eq!((60 * 60 * 24 + 1) / 1_000, actual.value(2));

let actual = date_part(&input, DatePart::Microsecond).unwrap();
let actual = actual.as_primitive::<Int32Type>();
assert_eq!(0, actual.value(0));
assert_eq!(42, actual.value(1));
assert_eq!(60 * 60 * 24 + 1, actual.value(2));

let actual = date_part(&input, DatePart::Nanosecond).unwrap();
let actual = actual.as_primitive::<Int32Type>();
assert_eq!(0, actual.value(0));
assert_eq!(42_000, actual.value(1));
assert_eq!((60 * 60 * 24 + 1) * 1_000, actual.value(2));
}

#[test]
fn test_duration_nanosecond() {
let input: DurationNanosecondArray = vec![0, 42, 60 * 60 * 24 + 1].into();

let actual = date_part(&input, DatePart::Second).unwrap();
let actual = actual.as_primitive::<Int32Type>();
assert_eq!(0, actual.value(0));
assert_eq!(0, actual.value(1));
assert_eq!(0, actual.value(2));

let actual = date_part(&input, DatePart::Millisecond).unwrap();
let actual = actual.as_primitive::<Int32Type>();
assert_eq!(0, actual.value(0));
assert_eq!(0, actual.value(1));
assert_eq!(0, actual.value(2));

let actual = date_part(&input, DatePart::Microsecond).unwrap();
let actual = actual.as_primitive::<Int32Type>();
assert_eq!(0, actual.value(0));
assert_eq!(0, actual.value(1));
assert_eq!((60 * 60 * 24 + 1) / 1_000, actual.value(2));

let actual = date_part(&input, DatePart::Nanosecond).unwrap();
let actual = actual.as_primitive::<Int32Type>();
assert_eq!(0, actual.value(0));
assert_eq!(42, actual.value(1));
assert_eq!(60 * 60 * 24 + 1, actual.value(2));
}

#[test]
fn test_duration_invalid_parts() {
fn ensure_returns_error(array: &dyn Array) {
let invalid_parts = [
DatePart::Year,
DatePart::Quarter,
DatePart::Month,
DatePart::DayOfWeekSunday0,
DatePart::DayOfWeekMonday0,
DatePart::DayOfYear,
];

for part in invalid_parts {
let err = date_part(array, part).unwrap_err();
let expected = format!(
"Compute error: {part} does not support: {}",
array.data_type()
);
assert_eq!(expected, err.to_string());
}
}

ensure_returns_error(&DurationSecondArray::from(vec![0]));
ensure_returns_error(&DurationMillisecondArray::from(vec![0]));
ensure_returns_error(&DurationMicrosecondArray::from(vec![0]));
ensure_returns_error(&DurationNanosecondArray::from(vec![0]));
}
}

0 comments on commit d7ad4fe

Please sign in to comment.