From 5ca7c07d9f3caf1a863daaf1c3f322f8b97dd84a Mon Sep 17 00:00:00 2001 From: Raz Luvaton <16746759+rluvaton@users.noreply.github.com> Date: Sat, 21 Dec 2024 18:28:17 +0200 Subject: [PATCH] add specialize code to concat lists to be able to use the concat dictionary logic --- arrow-select/src/concat.rs | 55 ++++++++++++++------------------------ 1 file changed, 20 insertions(+), 35 deletions(-) diff --git a/arrow-select/src/concat.rs b/arrow-select/src/concat.rs index 9dbae1f646c..91ce16e062a 100644 --- a/arrow-select/src/concat.rs +++ b/arrow-select/src/concat.rs @@ -130,7 +130,7 @@ fn concat_dictionaries( Ok(Arc::new(array)) } -fn concat_list_of_dictionaries( +fn concat_lists( arrays: &[&dyn Array], ) -> Result { let mut output_len = 0; @@ -145,11 +145,6 @@ fn concat_list_of_dictionaries>(); - let dictionaries: Vec<_> = lists - .iter() - .map(|x| x.values().as_ref()) - .collect(); - let lists_nulls = list_has_nulls.then(|| { let mut nulls = BooleanBufferBuilder::new(output_len); for l in &lists { @@ -161,7 +156,12 @@ fn concat_list_of_dictionaries(dictionaries.as_slice())?; + let values = lists + .iter() + .map(|x| x.values().as_ref()) + .collect::>(); + + let concatenated_values = concat(values.as_slice())?; // Merge value offsets from the lists let value_offset_buffer = @@ -175,7 +175,7 @@ fn concat_list_of_dictionaries { - return Ok(Arc::new(concat_list_of_dictionaries::<$o, $t>($arrays)?) as _) - }; -} - fn get_capacity(arrays: &[&dyn Array], data_type: &DataType) -> Capacities { match data_type { DataType::Utf8 => binary_capacity::(arrays), @@ -224,29 +218,20 @@ pub fn concat(arrays: &[&dyn Array]) -> Result { "It is not possible to concatenate arrays of different data types.".to_string(), )); } - if let DataType::Dictionary(k, _) = d { - downcast_integer! { - k.as_ref() => (dict_helper, arrays), - _ => unreachable!("illegal dictionary key type {k}") - }; - } else { - if let DataType::List(field) = d { - if let DataType::Dictionary(k, _) = field.data_type() { - downcast_integer! { - k.as_ref() => (list_dict_helper, i32, arrays), - _ => unreachable!("illegal dictionary key type {k}") - }; - } - } else if let DataType::LargeList(field) = d { - if let DataType::Dictionary(k, _) = field.data_type() { - downcast_integer! { - k.as_ref() => (list_dict_helper, i64, arrays), - _ => unreachable!("illegal dictionary key type {k}") - }; + + match d { + DataType::Dictionary(k, _) => { + downcast_integer! { + k.as_ref() => (dict_helper, arrays), + _ => unreachable!("illegal dictionary key type {k}") } } - let capacity = get_capacity(arrays, d); - concat_fallback(arrays, capacity) + DataType::List(_) => concat_lists::(arrays), + DataType::LargeList(_) => concat_lists::(arrays), + _ => { + let capacity = get_capacity(arrays, d); + concat_fallback(arrays, capacity) + } } }