Skip to content

Commit

Permalink
add specialize code to concat lists to be able to use the concat dict…
Browse files Browse the repository at this point in the history
…ionary logic
  • Loading branch information
rluvaton committed Dec 21, 2024
1 parent 87c2865 commit 5ca7c07
Showing 1 changed file with 20 additions and 35 deletions.
55 changes: 20 additions & 35 deletions arrow-select/src/concat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ fn concat_dictionaries<K: ArrowDictionaryKeyType>(
Ok(Arc::new(array))
}

fn concat_list_of_dictionaries<OffsetSize: OffsetSizeTrait, K: ArrowDictionaryKeyType>(
fn concat_lists<OffsetSize: OffsetSizeTrait>(
arrays: &[&dyn Array],
) -> Result<ArrayRef, ArrowError> {
let mut output_len = 0;
Expand All @@ -145,11 +145,6 @@ fn concat_list_of_dictionaries<OffsetSize: OffsetSizeTrait, K: ArrowDictionaryKe
})
.collect::<Vec<_>>();

let dictionaries: Vec<_> = lists
.iter()
.map(|x| x.values().as_ref())
.collect();

let lists_nulls = list_has_nulls.then(|| {
let mut nulls = BooleanBufferBuilder::new(output_len);
for l in &lists {
Expand All @@ -161,7 +156,12 @@ fn concat_list_of_dictionaries<OffsetSize: OffsetSizeTrait, K: ArrowDictionaryKe
NullBuffer::new(nulls.finish())
});

let concat_dictionaries = concat_dictionaries::<K>(dictionaries.as_slice())?;
let values = lists
.iter()
.map(|x| x.values().as_ref())
.collect::<Vec<_>>();

let concatenated_values = concat(values.as_slice())?;

// Merge value offsets from the lists
let value_offset_buffer =
Expand All @@ -175,7 +175,7 @@ fn concat_list_of_dictionaries<OffsetSize: OffsetSizeTrait, K: ArrowDictionaryKe
// `GenericListArray` must only have 1 buffer
.buffers(vec![value_offset_buffer])
// `GenericListArray` must only have 1 child_data
.child_data(vec![concat_dictionaries.to_data()]);
.child_data(vec![concatenated_values.to_data()]);

// TODO - maybe use build_unchecked?
let array_data = builder.build()?;
Expand All @@ -190,12 +190,6 @@ macro_rules! dict_helper {
};
}

macro_rules! list_dict_helper {
($t:ty, $o: ty, $arrays:expr) => {
return Ok(Arc::new(concat_list_of_dictionaries::<$o, $t>($arrays)?) as _)
};
}

fn get_capacity(arrays: &[&dyn Array], data_type: &DataType) -> Capacities {
match data_type {
DataType::Utf8 => binary_capacity::<Utf8Type>(arrays),
Expand Down Expand Up @@ -224,29 +218,20 @@ pub fn concat(arrays: &[&dyn Array]) -> Result<ArrayRef, ArrowError> {
"It is not possible to concatenate arrays of different data types.".to_string(),
));
}
if let DataType::Dictionary(k, _) = d {
downcast_integer! {
k.as_ref() => (dict_helper, arrays),
_ => unreachable!("illegal dictionary key type {k}")
};
} else {
if let DataType::List(field) = d {
if let DataType::Dictionary(k, _) = field.data_type() {
downcast_integer! {
k.as_ref() => (list_dict_helper, i32, arrays),
_ => unreachable!("illegal dictionary key type {k}")
};
}
} else if let DataType::LargeList(field) = d {
if let DataType::Dictionary(k, _) = field.data_type() {
downcast_integer! {
k.as_ref() => (list_dict_helper, i64, arrays),
_ => unreachable!("illegal dictionary key type {k}")
};

match d {
DataType::Dictionary(k, _) => {
downcast_integer! {
k.as_ref() => (dict_helper, arrays),
_ => unreachable!("illegal dictionary key type {k}")
}
}
let capacity = get_capacity(arrays, d);
concat_fallback(arrays, capacity)
DataType::List(_) => concat_lists::<i32>(arrays),
DataType::LargeList(_) => concat_lists::<i64>(arrays),
_ => {
let capacity = get_capacity(arrays, d);
concat_fallback(arrays, capacity)
}
}
}

Expand Down

0 comments on commit 5ca7c07

Please sign in to comment.