diff --git a/arrow-array/src/array/fixed_size_list_array.rs b/arrow-array/src/array/fixed_size_list_array.rs index 72855cef1f04..a972eb32fe80 100644 --- a/arrow-array/src/array/fixed_size_list_array.rs +++ b/arrow-array/src/array/fixed_size_list_array.rs @@ -159,8 +159,7 @@ impl FixedSizeListArray { if let Some(n) = nulls.as_ref() { if n.len() != len { return Err(ArrowError::InvalidArgumentError(format!( - "Incorrect length of null buffer for FixedSizeListArray, expected {} got {}", - len, + "Incorrect length of null buffer for FixedSizeListArray, expected {len} got {}", n.len(), ))); } @@ -521,7 +520,9 @@ mod tests { } #[test] - #[should_panic(expected = "assertion failed: (offset + length) <= self.len()")] + #[should_panic( + expected = "Attempting to slice an array with offset 0, len 9 when self.len is 8" + )] // Different error messages, so skip for now // https://github.com/apache/arrow-rs/issues/1545 #[cfg(not(feature = "force_validate"))] diff --git a/arrow-array/src/array/run_array.rs b/arrow-array/src/array/run_array.rs index 81c8cdcea4d3..e1d532ff4003 100644 --- a/arrow-array/src/array/run_array.rs +++ b/arrow-array/src/array/run_array.rs @@ -254,14 +254,19 @@ impl RunArray { } impl From for RunArray { - // The method assumes the caller already validated the data using `ArrayData::validate_data()` fn from(data: ArrayData) -> Self { - match data.data_type() { - DataType::RunEndEncoded(_, _) => {} - _ => { - panic!("Invalid data type for RunArray. The data type should be DataType::RunEndEncoded"); - } - } + Self::from(&data) + } +} + +impl From<&ArrayData> for RunArray { + // The method assumes the caller already validated the data using `ArrayData::validate_data()` + fn from(data: &ArrayData) -> Self { + let DataType::RunEndEncoded(_, _) = data.data_type() else { + panic!( + "Invalid data type for RunArray. 
The data type should be DataType::RunEndEncoded" + ); + }; // Safety // ArrayData is valid diff --git a/arrow-buffer/src/buffer/immutable.rs b/arrow-buffer/src/buffer/immutable.rs index 8d1a46583fca..816ea1ad8371 100644 --- a/arrow-buffer/src/buffer/immutable.rs +++ b/arrow-buffer/src/buffer/immutable.rs @@ -261,11 +261,11 @@ impl Buffer { } /// Returns a slice of this buffer starting at a certain bit offset. - /// If the offset is byte-aligned the returned buffer is a shallow clone, + /// If the offset and length are byte-aligned the returned buffer is a shallow clone, /// otherwise a new buffer is allocated and filled with a copy of the bits in the range. pub fn bit_slice(&self, offset: usize, len: usize) -> Self { - if offset % 8 == 0 { - return self.slice(offset / 8); + if offset % 8 == 0 && len % 8 == 0 { + return self.slice_with_length(offset / 8, len / 8); } bitwise_unary_op_helper(self, offset, len, |a| a) diff --git a/arrow-data/src/data.rs b/arrow-data/src/data.rs index 8af2a91cf159..e3a85d8f1564 100644 --- a/arrow-data/src/data.rs +++ b/arrow-data/src/data.rs @@ -423,72 +423,98 @@ impl ArrayData { size } - /// Returns the total number of the bytes of memory occupied by - /// the buffers by this slice of [`ArrayData`] (See also diagram on [`ArrayData`]). + /// Returns the total number of the bytes of memory occupied by the buffers by this slice of + /// [`ArrayData`] (See also diagram on [`ArrayData`]). /// - /// This is approximately the number of bytes if a new - /// [`ArrayData`] was formed by creating new [`Buffer`]s with - /// exactly the data needed. + /// This is approximately the number of bytes if a new [`ArrayData`] was formed by creating new + /// [`Buffer`]s with exactly the data needed. /// - /// For example, a [`DataType::Int64`] with `100` elements, - /// [`Self::get_slice_memory_size`] would return `100 * 8 = 800`. 
If - /// the [`ArrayData`] was then [`Self::slice`]ed to refer to its - /// first `20` elements, then [`Self::get_slice_memory_size`] on the - /// sliced [`ArrayData`] would return `20 * 8 = 160`. - pub fn get_slice_memory_size(&self) -> Result { - let mut result: usize = 0; + /// For example, a [`DataType::Int64`] with `100` elements, [`Self::get_slice_memory_size`] + /// would return `100 * 8 = 800`. If the [`ArrayData`] was then [`Self::slice`]d to refer to + /// its first `20` elements, then [`Self::get_slice_memory_size`] on the sliced [`ArrayData`] + /// would return `20 * 8 = 160`. + /// + /// The `alignment` parameter is used to add padding to each buffer being counted, to ensure + /// the size for each one is aligned to `alignment` bytes (if it is `Some`). This function + /// assumes that `alignment` is a power of 2. + pub fn get_slice_memory_size_with_alignment( + &self, + alignment: Option, + ) -> Result { + // Note: This accounts for data used by the Dictionary DataType that isn't actually encoded + // as a part of `write_array_data` in arrow-ipc - specifically, the `values` part of + // each Dictionary is stored in the `child_data` of the `ArrayData` it produces, but (for + // reasons not yet understood) `write_array_data` doesn't encode those values. let layout = layout(&self.data_type); - for spec in layout.buffers.iter() { - match spec { + // Just pulled from arrow-ipc + #[inline] + fn pad_to_alignment(alignment: u8, len: usize) -> usize { + let a = usize::from(alignment.saturating_sub(1)); + ((len + a) & !a) - len + } + + let mut result = layout.buffers.iter().map(|spec| { + let size = match spec { BufferSpec::FixedWidth { byte_width, .. 
} => { - let buffer_size = self.len.checked_mul(*byte_width).ok_or_else(|| { + let num_elems = match self.data_type { + // On these offsets-plus-values datatypes, their offset buffer (which is + // FixedWidth and thus the one we're looking at right now in this + // FixedWidth arm) contains self.len + 1 elements due to the way the + // offsets are encoded as overlapping pairs of (start, (end+start), + // (end+start), etc). + DataType::Utf8 | DataType::Binary | DataType::LargeUtf8 | DataType::LargeBinary => self.len + 1, + _ => self.len + }; + + num_elems.checked_mul(*byte_width).ok_or_else(|| { ArrowError::ComputeError( "Integer overflow computing buffer size".to_string(), ) - })?; - result += buffer_size; - } - BufferSpec::VariableWidth => { - let buffer_len: usize; - match self.data_type { - DataType::Utf8 | DataType::Binary => { - let offsets = self.typed_offsets::()?; - buffer_len = (offsets[self.len] - offsets[0] ) as usize; - } - DataType::LargeUtf8 | DataType::LargeBinary => { - let offsets = self.typed_offsets::()?; - buffer_len = (offsets[self.len] - offsets[0]) as usize; - } - _ => { - return Err(ArrowError::NotYetImplemented(format!( - "Invalid data type for VariableWidth buffer. Expected Utf8, LargeUtf8, Binary or LargeBinary. Got {}", - self.data_type - ))) - } - }; - result += buffer_len; + }) } - BufferSpec::BitMap => { - let buffer_size = bit_util::ceil(self.len, 8); - result += buffer_size; - } - BufferSpec::AlwaysNull => { - // Nothing to do + BufferSpec::VariableWidth => match &self.data_type { + // UTF8 and Binary have two buffers - one for the offsets, one for the values. + // When calculating size, the offset buffer's size is calculated by the + // FixedWidth buffer arm above, so we just need to count the value bytes here. 
+ DataType::Utf8 | DataType::Binary => { + self.typed_offsets::() + .map(|off| (off[self.len] - off[0]) as usize) + } + DataType::LargeUtf8 | DataType::LargeBinary => { + self.typed_offsets::() + .map(|off| (off[self.len] - off[0]) as usize) + } + dt => Err(ArrowError::NotYetImplemented(format!( + "Invalid data type for VariableWidth buffer. Expected Utf8, LargeUtf8, Binary or LargeBinary. Got {dt}", + ))), } - } - } + BufferSpec::BitMap => Ok(bit_util::ceil(self.len, 8)), + // Nothing to do when AlwaysNull + BufferSpec::AlwaysNull => Ok(0) + }?; + + Ok(size + alignment.map_or(0, |a| pad_to_alignment(a, size))) + }).sum::>()?; if self.nulls().is_some() { - result += bit_util::ceil(self.len, 8); + let null_len = bit_util::ceil(self.len, 8); + result += null_len + alignment.map_or(0, |a| pad_to_alignment(a, null_len)); } for child in &self.child_data { - result += child.get_slice_memory_size()?; + result += child.get_slice_memory_size_with_alignment(alignment)?; } + Ok(result) } + /// Equivalent to calling [`Self::get_slice_memory_size_with_alignment()`] with `None` for the + /// alignment + pub fn get_slice_memory_size(&self) -> Result { + self.get_slice_memory_size_with_alignment(None) + } + /// Returns the total number of bytes of memory occupied /// physically by this [`ArrayData`] and all its [`Buffer`]s and /// children. (See also diagram on [`ArrayData`]). @@ -523,15 +549,16 @@ impl ArrayData { /// /// Panics if `offset + length > self.len()`. 
pub fn slice(&self, offset: usize, length: usize) -> ArrayData { - assert!((offset + length) <= self.len()); + if (offset + length) > self.len() { + panic!("Attempting to slice an array with offset {offset}, len {length} when self.len is {}", self.len); + } if let DataType::Struct(_) = self.data_type() { // Slice into children - let new_offset = self.offset + offset; - let new_data = ArrayData { + ArrayData { data_type: self.data_type().clone(), len: length, - offset: new_offset, + offset: self.offset + offset, buffers: self.buffers.clone(), // Slice child data, to propagate offsets down to them child_data: self @@ -540,9 +567,7 @@ impl ArrayData { .map(|data| data.slice(offset, length)) .collect(), nulls: self.nulls.as_ref().map(|x| x.slice(offset, length)), - }; - - new_data + } } else { let mut new_data = self.clone(); @@ -888,7 +913,7 @@ impl ArrayData { ))); } - Ok(&buffer.typed_data::()[self.offset..self.offset + len]) + Ok(&buffer.typed_data::()[self.offset..][..len]) } /// Does a cheap sanity check that the `self.len` values in `buffer` are valid @@ -2186,11 +2211,11 @@ mod tests { ) .unwrap(); let string_data_slice = string_data.slice(1, 2); + + let data_len = string_data.get_slice_memory_size().unwrap(); + let slice_len = string_data_slice.get_slice_memory_size().unwrap(); //4 bytes of offset and 2 bytes of data reduced by slicing. 
- assert_eq!( - string_data.get_slice_memory_size().unwrap() - 6, - string_data_slice.get_slice_memory_size().unwrap() - ); + assert_eq!(data_len - 6, slice_len); } #[test] diff --git a/arrow-flight/src/encode.rs b/arrow-flight/src/encode.rs index acfbd9b53030..53db10f9ce45 100644 --- a/arrow-flight/src/encode.rs +++ b/arrow-flight/src/encode.rs @@ -327,6 +327,10 @@ impl FlightDataEncoder { /// Encodes batch into one or more `FlightData` messages in self.queue fn encode_batch(&mut self, batch: RecordBatch) -> Result<()> { + if batch.num_rows() == 0 { + return Ok(()); + } + let schema = match &self.schema { Some(schema) => schema.clone(), // encode the schema if this is the first time we have seen it @@ -338,12 +342,12 @@ impl FlightDataEncoder { DictionaryHandling::Hydrate => hydrate_dictionaries(&batch, schema)?, }; - for batch in split_batch_for_grpc_response(batch, self.max_flight_data_size) { - let (flight_dictionaries, flight_batch) = self.encoder.encode_batch(&batch)?; + let (flight_dictionaries, flight_batches) = self + .encoder + .encode_batch(&batch, self.max_flight_data_size)?; - self.queue_messages(flight_dictionaries); - self.queue_message(flight_batch); - } + self.queue_messages(flight_dictionaries); + self.queue_messages(flight_batches); Ok(()) } @@ -563,38 +567,6 @@ fn prepare_schema_for_flight( Schema::new(fields).with_metadata(schema.metadata().clone()) } -/// Split [`RecordBatch`] so it hopefully fits into a gRPC response. -/// -/// Data is zero-copy sliced into batches. 
-/// -/// Note: this method does not take into account already sliced -/// arrays: -fn split_batch_for_grpc_response( - batch: RecordBatch, - max_flight_data_size: usize, -) -> Vec { - let size = batch - .columns() - .iter() - .map(|col| col.get_buffer_memory_size()) - .sum::(); - - let n_batches = - (size / max_flight_data_size + usize::from(size % max_flight_data_size != 0)).max(1); - let rows_per_batch = (batch.num_rows() / n_batches).max(1); - let mut out = Vec::with_capacity(n_batches + 1); - - let mut offset = 0; - while offset < batch.num_rows() { - let length = (rows_per_batch).min(batch.num_rows() - offset); - out.push(batch.slice(offset, length)); - - offset += length; - } - - out -} - /// The data needed to encode a stream of flight data, holding on to /// shared Dictionaries. /// @@ -626,16 +598,23 @@ impl FlightIpcEncoder { } /// Convert a `RecordBatch` to a Vec of `FlightData` representing - /// dictionaries and a `FlightData` representing the batch - fn encode_batch(&mut self, batch: &RecordBatch) -> Result<(Vec, FlightData)> { - let (encoded_dictionaries, encoded_batch) = - self.data_gen - .encoded_batch(batch, &mut self.dictionary_tracker, &self.options)?; + /// dictionaries and a Vec of `FlightData`s representing the batch + fn encode_batch( + &mut self, + batch: &RecordBatch, + max_flight_data_size: usize, + ) -> Result<(Vec, Vec)> { + let (encoded_dictionaries, encoded_batches) = self.data_gen.encoded_batch_with_size( + batch, + &mut self.dictionary_tracker, + &self.options, + max_flight_data_size, + )?; let flight_dictionaries = encoded_dictionaries.into_iter().map(Into::into).collect(); - let flight_batch = encoded_batch.into(); + let flight_batches = encoded_batches.into_iter().map(Into::into).collect(); - Ok((flight_dictionaries, flight_batch)) + Ok((flight_dictionaries, flight_batches)) } } @@ -684,7 +663,10 @@ fn hydrate_dictionary(array: &ArrayRef, data_type: &DataType) -> Result fn test_encode_flight_data() { @@ -711,12 +696,13 @@ mod 
tests { .expect("cannot create record batch"); let schema = batch.schema_ref(); - let (_, baseline_flight_batch) = make_flight_data(&batch, &options); + let (_, baseline_flight_batch) = utils::flight_data_from_arrow_batch(&batch, &options); let big_batch = batch.slice(0, batch.num_rows() - 1); let optimized_big_batch = hydrate_dictionaries(&big_batch, Arc::clone(schema)).expect("failed to optimize"); - let (_, optimized_big_flight_batch) = make_flight_data(&optimized_big_batch, &options); + let (_, optimized_big_flight_batch) = + utils::flight_data_from_arrow_batch(&optimized_big_batch, &options); assert_eq!( baseline_flight_batch.data_body.len(), @@ -726,7 +712,8 @@ mod tests { let small_batch = batch.slice(0, 1); let optimized_small_batch = hydrate_dictionaries(&small_batch, Arc::clone(schema)).expect("failed to optimize"); - let (_, optimized_small_flight_batch) = make_flight_data(&optimized_small_batch, &options); + let (_, optimized_small_flight_batch) = + utils::flight_data_from_arrow_batch(&optimized_small_batch, &options); assert!( baseline_flight_batch.data_body.len() > optimized_small_flight_batch.data_body.len() @@ -999,12 +986,16 @@ mod tests { ))], ); - struct_builder.field_builder::>>>(0).unwrap().append_value(vec![Some("a"), None, Some("b")]); + struct_builder.field_builder::>>>(0) + .unwrap() + .append_value(vec![Some("a"), None, Some("b")]); struct_builder.append(true); let arr1 = struct_builder.finish(); - struct_builder.field_builder::>>>(0).unwrap().append_value(vec![Some("c"), None, Some("d")]); + struct_builder.field_builder::>>>(0) + .unwrap() + .append_value(vec![Some("c"), None, Some("d")]); struct_builder.append(true); let arr2 = struct_builder.finish(); @@ -1212,6 +1203,11 @@ mod tests { .into_iter() .collect::(); + let mut field_types = union_fields.iter().map(|(_, field)| field.data_type()); + let dict_list_ty = field_types.next().unwrap(); + let struct_ty = field_types.next().unwrap(); + let string_ty = field_types.next().unwrap(); + 
let struct_fields = vec![Field::new_list( "dict_list", Field::new_dictionary("item", DataType::UInt16, DataType::Utf8, true), @@ -1230,9 +1226,9 @@ mod tests { type_id_buffer, None, vec![ - Arc::new(arr1) as Arc, - new_null_array(union_fields.iter().nth(1).unwrap().1.data_type(), 1), - new_null_array(union_fields.iter().nth(2).unwrap().1.data_type(), 1), + Arc::new(arr1), + new_null_array(struct_ty, 1), + new_null_array(string_ty, 1), ], ) .unwrap(); @@ -1248,9 +1244,9 @@ mod tests { type_id_buffer, None, vec![ - new_null_array(union_fields.iter().next().unwrap().1.data_type(), 1), + new_null_array(dict_list_ty, 1), Arc::new(arr2), - new_null_array(union_fields.iter().nth(2).unwrap().1.data_type(), 1), + new_null_array(string_ty, 1), ], ) .unwrap(); @@ -1261,8 +1257,8 @@ mod tests { type_id_buffer, None, vec![ - new_null_array(union_fields.iter().next().unwrap().1.data_type(), 1), - new_null_array(union_fields.iter().nth(1).unwrap().1.data_type(), 1), + new_null_array(dict_list_ty, 1), + new_null_array(struct_ty, 1), Arc::new(StringArray::from(vec!["e"])), ], ) @@ -1485,34 +1481,40 @@ mod tests { hydrate_dictionaries(&batch, batch.schema()).expect("failed to optimize"); } - pub fn make_flight_data( - batch: &RecordBatch, - options: &IpcWriteOptions, - ) -> (Vec, FlightData) { - let data_gen = IpcDataGenerator::default(); - let mut dictionary_tracker = DictionaryTracker::new_with_preserve_dict_id(false, true); - - let (encoded_dictionaries, encoded_batch) = data_gen - .encoded_batch(batch, &mut dictionary_tracker, options) - .expect("DictionaryTracker configured above to not error on replacement"); - - let flight_dictionaries = encoded_dictionaries.into_iter().map(Into::into).collect(); - let flight_batch = encoded_batch.into(); - - (flight_dictionaries, flight_batch) - } + #[tokio::test] + async fn test_split_batch_for_grpc_response() { + async fn get_decoded(schema: SchemaRef, encoded: Vec) -> Vec { + FlightDataDecoder::new(futures::stream::iter( + 
std::iter::once(SchemaAsIpc::new(&schema, &IpcWriteOptions::default()).into()) + .chain(encoded.into_iter().map(FlightData::from)) + .map(Ok), + )) + .collect::>>() + .await + .into_iter() + .map(|r| r.unwrap()) + .filter_map(|data| match data.payload { + DecodedPayload::RecordBatch(rb) => Some(rb), + _ => None, + }) + .collect() + } - #[test] - fn test_split_batch_for_grpc_response() { let max_flight_data_size = 1024; + let write_opts = IpcWriteOptions::default(); + let mut dict_tracker = DictionaryTracker::new(false); + let gen = IpcDataGenerator {}; // no split let c = UInt32Array::from(vec![1, 2, 3, 4, 5, 6]); let batch = RecordBatch::try_from_iter(vec![("a", Arc::new(c) as ArrayRef)]) .expect("cannot create record batch"); - let split = split_batch_for_grpc_response(batch.clone(), max_flight_data_size); + let split = gen + .encoded_batch_with_size(&batch, &mut dict_tracker, &write_opts, max_flight_data_size) + .unwrap() + .1; assert_eq!(split.len(), 1); - assert_eq!(batch, split[0]); + assert_eq!(batch, get_decoded(batch.schema(), split).await[0]); // split once let n_rows = max_flight_data_size + 1; @@ -1520,58 +1522,21 @@ mod tests { let c = UInt8Array::from((0..n_rows).map(|i| (i % 256) as u8).collect::>()); let batch = RecordBatch::try_from_iter(vec![("a", Arc::new(c) as ArrayRef)]) .expect("cannot create record batch"); - let split = split_batch_for_grpc_response(batch.clone(), max_flight_data_size); - assert_eq!(split.len(), 3); + let split = gen + .encoded_batch_with_size(&batch, &mut dict_tracker, &write_opts, max_flight_data_size) + .unwrap() + .1; + assert_eq!(split.len(), 2); + let batches = get_decoded(batch.schema(), split).await; assert_eq!( - split.iter().map(|batch| batch.num_rows()).sum::(), + batches.iter().map(RecordBatch::num_rows).sum::(), n_rows ); - let a = pretty_format_batches(&split).unwrap().to_string(); + let a = pretty_format_batches(&batches).unwrap().to_string(); let b = pretty_format_batches(&[batch]).unwrap().to_string(); 
assert_eq!(a, b); } - #[test] - fn test_split_batch_for_grpc_response_sizes() { - // 2000 8 byte entries into 2k pieces: 8 chunks of 250 rows - verify_split(2000, 2 * 1024, vec![250, 250, 250, 250, 250, 250, 250, 250]); - - // 2000 8 byte entries into 4k pieces: 4 chunks of 500 rows - verify_split(2000, 4 * 1024, vec![500, 500, 500, 500]); - - // 2023 8 byte entries into 3k pieces does not divide evenly - verify_split(2023, 3 * 1024, vec![337, 337, 337, 337, 337, 337, 1]); - - // 10 8 byte entries into 1 byte pieces means each rows gets its own - verify_split(10, 1, vec![1, 1, 1, 1, 1, 1, 1, 1, 1, 1]); - - // 10 8 byte entries into 1k byte pieces means one piece - verify_split(10, 1024, vec![10]); - } - - /// Creates a UInt64Array of 8 byte integers with input_rows rows - /// `max_flight_data_size_bytes` pieces and verifies the row counts in - /// those pieces - fn verify_split( - num_input_rows: u64, - max_flight_data_size_bytes: usize, - expected_sizes: Vec, - ) { - let array: UInt64Array = (0..num_input_rows).collect(); - - let batch = RecordBatch::try_from_iter(vec![("a", Arc::new(array) as ArrayRef)]) - .expect("cannot create record batch"); - - let input_rows = batch.num_rows(); - - let split = split_batch_for_grpc_response(batch.clone(), max_flight_data_size_bytes); - let sizes: Vec<_> = split.iter().map(RecordBatch::num_rows).collect(); - let output_rows: usize = sizes.iter().sum(); - - assert_eq!(sizes, expected_sizes, "mismatch for {batch:?}"); - assert_eq!(input_rows, output_rows, "mismatch for {batch:?}"); - } - // test sending record batches // test sending record batches with multiple different dictionaries @@ -1582,7 +1547,7 @@ mod tests { let s2 = StringArray::from_iter_values(std::iter::repeat("6bytes").take(1024)); let i2 = Int64Array::from_iter_values(0..1024); - let batch = RecordBatch::try_from_iter(vec![ + let batch = RecordBatch::try_from_iter([ ("s1", Arc::new(s1) as _), ("i1", Arc::new(i1) as _), ("s2", Arc::new(s2) as _), @@ -1590,18 
+1555,16 @@ mod tests { ]) .unwrap(); - verify_encoded_split(batch, 112).await; + verify_encoded_split_no_overage(batch).await; } #[tokio::test] async fn flight_data_size_uneven_variable_lengths() { // each row has a longer string than the last with increasing lengths 0 --> 1024 let array = StringArray::from_iter_values((0..1024).map(|i| "*".repeat(i))); - let batch = RecordBatch::try_from_iter(vec![("data", Arc::new(array) as _)]).unwrap(); + let batch = RecordBatch::try_from_iter([("data", Arc::new(array) as _)]).unwrap(); - // overage is much higher than ideal - // https://github.com/apache/arrow-rs/issues/3478 - verify_encoded_split(batch, 4304).await; + verify_encoded_split_no_overage(batch).await; } #[tokio::test] @@ -1634,10 +1597,7 @@ mod tests { ]) .unwrap(); - // 5k over limit (which is 2x larger than limit of 5k) - // overage is much higher than ideal - // https://github.com/apache/arrow-rs/issues/3478 - verify_encoded_split(batch, 5800).await; + verify_encoded_split_no_overage(batch).await; } #[tokio::test] @@ -1653,7 +1613,7 @@ mod tests { let batch = RecordBatch::try_from_iter(vec![("a1", Arc::new(array) as _)]).unwrap(); - verify_encoded_split(batch, 160).await; + verify_encoded_split_no_overage(batch).await; } #[tokio::test] @@ -1665,9 +1625,7 @@ mod tests { let batch = RecordBatch::try_from_iter(vec![("a1", Arc::new(array) as _)]).unwrap(); - // overage is much higher than ideal - // https://github.com/apache/arrow-rs/issues/3478 - verify_encoded_split(batch, 3328).await; + verify_encoded_split_no_overage(batch).await; } #[tokio::test] @@ -1679,9 +1637,7 @@ mod tests { let batch = RecordBatch::try_from_iter(vec![("a1", Arc::new(array) as _)]).unwrap(); - // overage is much higher than ideal - // https://github.com/apache/arrow-rs/issues/3478 - verify_encoded_split(batch, 5280).await; + verify_encoded_split_no_overage(batch).await; } #[tokio::test] @@ -1704,9 +1660,7 @@ mod tests { ]) .unwrap(); - // overage is much higher than ideal - // 
https://github.com/apache/arrow-rs/issues/3478 - verify_encoded_split(batch, 4128).await; + verify_encoded_split_no_overage(batch).await; } /// Return size, in memory of flight data @@ -1726,59 +1680,49 @@ mod tests { /// Coverage for /// - /// Encodes the specified batch using several values of - /// `max_flight_data_size` between 1K to 5K and ensures that the - /// resulting size of the flight data stays within the limit - /// + `allowed_overage` - /// - /// `allowed_overage` is how far off the actual data encoding is - /// from the target limit that was set. It is an improvement when - /// the allowed_overage decreses. - /// - /// Note this overhead will likely always be greater than zero to - /// account for encoding overhead such as IPC headers and padding. + /// Encodes the specified batch using several values of `max_flight_data_size` between 1K and 5K + /// and ensures that the resulting size of the flight data stays within the limit, except + /// in cases where only 1 row is sent - if only 1 row is sent, then we know that there was no + /// way to keep the data within the limit (since the minimum possible amount of data was sent), + /// so we allow it to go over. 
/// - /// - async fn verify_encoded_split(batch: RecordBatch, allowed_overage: usize) { + async fn verify_encoded_split_no_overage(batch: RecordBatch) { let num_rows = batch.num_rows(); - // Track the overall required maximum overage - let mut max_overage_seen = 0; - for max_flight_data_size in [1024, 2021, 5000] { println!("Encoding {num_rows} with a maximum size of {max_flight_data_size}"); - let mut stream = FlightDataEncoderBuilder::new() + let stream = FlightDataEncoderBuilder::new() .with_max_flight_data_size(max_flight_data_size) // use 8-byte alignment - default alignment is 64 which produces bigger ipc data .with_options(IpcWriteOptions::try_new(8, false, MetadataVersion::V5).unwrap()) .build(futures::stream::iter([Ok(batch.clone())])); + let mut stream = FlightDataDecoder::new(stream); + let mut i = 0; while let Some(data) = stream.next().await.transpose().unwrap() { - let actual_data_size = flight_data_size(&data); + let actual_data_size = flight_data_size(&data.inner); let actual_overage = actual_data_size.saturating_sub(max_flight_data_size); - assert!( - actual_overage <= allowed_overage, - "encoded data[{i}]: actual size {actual_data_size}, \ - actual_overage: {actual_overage} \ - allowed_overage: {allowed_overage}" - ); + let is_1_row = + matches!(data.payload, DecodedPayload::RecordBatch(rb) if rb.num_rows() == 1); + + // If only 1 row was sent over via this recordBatch, there was no way to avoid + // going over the limit. 
There's currently no mechanism for splitting a single row + // of results over multiple messages, so we allow going over the limit if it's the + // bare minimum over (1 row) + if !is_1_row { + assert_eq!( + actual_overage, + 0, + "encoded data[{i}]: actual size {actual_data_size}, actual_overage: {actual_overage}" + ); + } i += 1; - - max_overage_seen = max_overage_seen.max(actual_overage) } } - - // ensure that the specified overage is exactly the maxmium so - // that when the splitting logic improves, the tests must be - // updated to reflect the better logic - assert_eq!( - allowed_overage, max_overage_seen, - "Specified overage was too high" - ); } } diff --git a/arrow-flight/src/utils.rs b/arrow-flight/src/utils.rs index f6129ddfe248..92d0dcfb809b 100644 --- a/arrow-flight/src/utils.rs +++ b/arrow-flight/src/utils.rs @@ -38,18 +38,34 @@ pub fn flight_data_from_arrow_batch( batch: &RecordBatch, options: &IpcWriteOptions, ) -> (Vec, FlightData) { + let (flight_dictionaries, mut flight_batches) = _flight_data_from_arrow_batch(batch, options); + + assert_eq!( + flight_batches.len(), + 1, + "encoded_batch with a max size of usize::MAX should not be able to return more or less than 1 batch" + ); + let flight_batch = flight_batches.pop().unwrap(); + + (flight_dictionaries, flight_batch) +} + +fn _flight_data_from_arrow_batch( + batch: &RecordBatch, + options: &IpcWriteOptions, +) -> (Vec, Vec) { let data_gen = writer::IpcDataGenerator::default(); let mut dictionary_tracker = writer::DictionaryTracker::new_with_preserve_dict_id(false, options.preserve_dict_id()); - let (encoded_dictionaries, encoded_batch) = data_gen - .encoded_batch(batch, &mut dictionary_tracker, options) + let (encoded_dictionaries, encoded_batches) = data_gen + .encoded_batch_with_size(batch, &mut dictionary_tracker, options, usize::MAX) .expect("DictionaryTracker configured above to not error on replacement"); let flight_dictionaries = encoded_dictionaries.into_iter().map(Into::into).collect(); - 
let flight_batch = encoded_batch.into(); + let flight_batches = encoded_batches.into_iter().map(Into::into).collect(); - (flight_dictionaries, flight_batch) + (flight_dictionaries, flight_batches) } /// Convert a slice of wire protocol `FlightData`s into a vector of `RecordBatch`es @@ -150,15 +166,15 @@ pub fn batches_to_flight_data( let mut flight_data = vec![]; let data_gen = writer::IpcDataGenerator::default(); - let mut dictionary_tracker = + let mut dict_tracker = writer::DictionaryTracker::new_with_preserve_dict_id(false, options.preserve_dict_id()); for batch in batches.iter() { let (encoded_dictionaries, encoded_batch) = - data_gen.encoded_batch(batch, &mut dictionary_tracker, &options)?; + data_gen.encoded_batch_with_size(batch, &mut dict_tracker, &options, usize::MAX)?; dictionaries.extend(encoded_dictionaries.into_iter().map(Into::into)); - flight_data.push(encoded_batch.into()); + flight_data.extend(encoded_batch.into_iter().map(Into::into)); } let mut stream = Vec::with_capacity(1 + dictionaries.len() + flight_data.len()); diff --git a/arrow-integration-testing/src/flight_client_scenarios/integration_test.rs b/arrow-integration-testing/src/flight_client_scenarios/integration_test.rs index c8289ff446a0..7882e6539c46 100644 --- a/arrow-integration-testing/src/flight_client_scenarios/integration_test.rs +++ b/arrow-integration-testing/src/flight_client_scenarios/integration_test.rs @@ -125,16 +125,9 @@ async fn send_batch( batch: &RecordBatch, options: &writer::IpcWriteOptions, ) -> Result { - let data_gen = writer::IpcDataGenerator::default(); - let mut dictionary_tracker = writer::DictionaryTracker::new_with_preserve_dict_id(false, true); - - let (encoded_dictionaries, encoded_batch) = data_gen - .encoded_batch(batch, &mut dictionary_tracker, options) - .expect("DictionaryTracker configured above to not error on replacement"); - - let dictionary_flight_data: Vec = - encoded_dictionaries.into_iter().map(Into::into).collect(); - let mut batch_flight_data: 
FlightData = encoded_batch.into(); + #[allow(deprecated)] + let (dictionary_flight_data, mut batch_flight_data) = + arrow_flight::utils::flight_data_from_arrow_batch(batch, options); upload_tx .send_all(&mut stream::iter(dictionary_flight_data).map(Ok)) diff --git a/arrow-integration-testing/src/flight_server_scenarios/integration_test.rs b/arrow-integration-testing/src/flight_server_scenarios/integration_test.rs index 0f404b2ae289..524339ecc389 100644 --- a/arrow-integration-testing/src/flight_server_scenarios/integration_test.rs +++ b/arrow-integration-testing/src/flight_server_scenarios/integration_test.rs @@ -27,7 +27,7 @@ use arrow::{ buffer::Buffer, datatypes::Schema, datatypes::SchemaRef, - ipc::{self, reader, writer}, + ipc::{self, reader}, record_batch::RecordBatch, }; use arrow_flight::{ @@ -127,22 +127,16 @@ impl FlightService for FlightServiceImpl { .iter() .enumerate() .flat_map(|(counter, batch)| { - let data_gen = writer::IpcDataGenerator::default(); - let mut dictionary_tracker = - writer::DictionaryTracker::new_with_preserve_dict_id(false, true); - - let (encoded_dictionaries, encoded_batch) = data_gen - .encoded_batch(batch, &mut dictionary_tracker, &options) - .expect("DictionaryTracker configured above to not error on replacement"); - - let dictionary_flight_data = encoded_dictionaries.into_iter().map(Into::into); - let mut batch_flight_data: FlightData = encoded_batch.into(); + #[allow(deprecated)] + let (dictionary_flight_data, mut batch_flight_data) = + arrow_flight::utils::flight_data_from_arrow_batch(batch, &options); // Only the record batch's FlightData gets app_metadata let metadata = counter.to_string().into(); batch_flight_data.app_metadata = metadata; dictionary_flight_data + .into_iter() .chain(std::iter::once(batch_flight_data)) .map(Ok) }); diff --git a/arrow-ipc/src/reader.rs b/arrow-ipc/src/reader.rs index 0820e3590827..d9d157b72606 100644 --- a/arrow-ipc/src/reader.rs +++ b/arrow-ipc/src/reader.rs @@ -1058,7 +1058,7 @@ impl 
FileReader { /// Try to create a new file reader. /// /// There is no internal buffering. If buffered reads are needed you likely want to use - /// [`FileReader::try_new_buffered`] instead. + /// [`FileReader::try_new_buffered`] instead. /// /// # Errors /// @@ -1785,7 +1785,7 @@ mod tests { // can be compared as such. assert_eq!(input_batch.column(1), output_batch.column(1)); - let run_array_1_unsliced = unslice_run_array(run_array_1_sliced.into_data()).unwrap(); + let run_array_1_unsliced = unslice_run_array(&run_array_1_sliced.into_data()).unwrap(); assert_eq!(run_array_1_unsliced, output_batch.column(0).into_data()); } diff --git a/arrow-ipc/src/writer.rs b/arrow-ipc/src/writer.rs index b5c4dd95ed9f..181baae3a5f4 100644 --- a/arrow-ipc/src/writer.rs +++ b/arrow-ipc/src/writer.rs @@ -20,12 +20,15 @@ //! The `FileWriter` and `StreamWriter` have similar interfaces, //! however the `FileWriter` expects a reader that supports `Seek`ing -use std::cmp::min; -use std::collections::HashMap; -use std::io::{BufWriter, Write}; -use std::sync::Arc; +use std::{ + borrow::Borrow, + cmp::min, + collections::HashMap, + io::{BufWriter, Write}, + sync::Arc, +}; -use flatbuffers::FlatBufferBuilder; +use flatbuffers::{FlatBufferBuilder, UnionWIPOffset, WIPOffset}; use arrow_array::builder::BufferBuilder; use arrow_array::cast::*; @@ -36,9 +39,11 @@ use arrow_buffer::{ArrowNativeType, Buffer, MutableBuffer}; use arrow_data::{layout, ArrayData, ArrayDataBuilder, BufferSpec}; use arrow_schema::*; -use crate::compression::CompressionCodec; -use crate::convert::IpcSchemaEncoder; -use crate::CONTINUATION_MARKER; +use crate::{ + compression::CompressionCodec, convert::IpcSchemaEncoder, BodyCompressionBuilder, + BodyCompressionMethod, DictionaryBatchBuilder, MessageBuilder, MessageHeader, RecordBatchArgs, + CONTINUATION_MARKER, +}; /// IPC write options used to control the behaviour of the [`IpcDataGenerator`] #[derive(Debug, Clone)] @@ -157,7 +162,7 @@ impl IpcWriteOptions { impl Default 
for IpcWriteOptions { fn default() -> Self { Self { - alignment: 64, + alignment: DEFAULT_ALIGNMENT, write_legacy_ipc_format: false, metadata_version: crate::MetadataVersion::V5, batch_compression_type: None, @@ -416,7 +421,7 @@ impl IpcDataGenerator { dict_id_seq, )?; - // It's importnat to only take the dict_id at this point, because the dict ID + // It's important to only take the dict_id at this point, because the dict ID // sequence is assigned depth-first, so we need to first encode children and have // them take their assigned dict IDs before we take the dict ID for this field. let dict_id = dict_id_seq @@ -448,15 +453,45 @@ impl IpcDataGenerator { Ok(()) } + /// Calls [`Self::encoded_batch_with_size`] with no limit, returning the first (and only) + /// [`EncodedData`] that is produced. This method should be used over + /// [`Self::encoded_batch_with_size`] if the consumer has no concerns about encoded message + /// size limits + pub fn encoded_batch( + &self, + batch: &RecordBatch, + dictionary_tracker: &mut DictionaryTracker, + write_options: &IpcWriteOptions, + ) -> Result<(Vec, EncodedData), ArrowError> { + let (encoded_dictionaries, mut encoded_messages) = + self.encoded_batch_with_size(batch, dictionary_tracker, write_options, usize::MAX)?; + + assert_eq!( + encoded_messages.len(), + 1, + "encoded_batch with max size of usize::MAX should not be able to return more or less than 1 batch" + ); + + Ok((encoded_dictionaries, encoded_messages.pop().unwrap())) + } + /// Encodes a batch to a number of [EncodedData] items (dictionary batches + the record batch). /// The [DictionaryTracker] keeps track of dictionaries with new `dict_id`s (so they are only sent once) /// Make sure the [DictionaryTracker] is initialized at the start of the stream. - pub fn encoded_batch( + /// The `max_encoded_data_size` is used to control how much space each encoded [`RecordBatch`] is + /// allowed to take up. 
+ /// + /// Each [`EncodedData`] in the second element of the returned tuple will be smaller than + /// `max_encoded_data_size` bytes, if possible at all. However, this API has no support for + /// splitting rows into multiple [`EncodedData`]s, so if a row is larger, by itself, than + /// `max_encoded_data_size`, it will be encoded to a message which is larger than the limit. + pub fn encoded_batch_with_size( &self, batch: &RecordBatch, dictionary_tracker: &mut DictionaryTracker, write_options: &IpcWriteOptions, - ) -> Result<(Vec, EncodedData), ArrowError> { + max_encoded_data_size: usize, + ) -> Result<(Vec, Vec), ArrowError> { let schema = batch.schema(); let mut encoded_dictionaries = Vec::with_capacity(schema.flattened_fields().len()); @@ -474,100 +509,11 @@ impl IpcDataGenerator { )?; } - let encoded_message = self.record_batch_to_bytes(batch, write_options)?; + let encoded_message = + chunked_encoded_batch_bytes(batch, write_options, max_encoded_data_size)?; Ok((encoded_dictionaries, encoded_message)) } - /// Write a `RecordBatch` into two sets of bytes, one for the header (crate::Message) and the - /// other for the batch's data - fn record_batch_to_bytes( - &self, - batch: &RecordBatch, - write_options: &IpcWriteOptions, - ) -> Result { - let mut fbb = FlatBufferBuilder::new(); - - let mut nodes: Vec = vec![]; - let mut buffers: Vec = vec![]; - let mut arrow_data: Vec = vec![]; - let mut offset = 0; - - // get the type of compression - let batch_compression_type = write_options.batch_compression_type; - - let compression = batch_compression_type.map(|batch_compression_type| { - let mut c = crate::BodyCompressionBuilder::new(&mut fbb); - c.add_method(crate::BodyCompressionMethod::BUFFER); - c.add_codec(batch_compression_type); - c.finish() - }); - - let compression_codec: Option = - batch_compression_type.map(TryInto::try_into).transpose()?; - - let mut variadic_buffer_counts = vec![]; - - for array in batch.columns() { - let array_data = array.to_data(); - 
offset = write_array_data( - &array_data, - &mut buffers, - &mut arrow_data, - &mut nodes, - offset, - array.len(), - array.null_count(), - compression_codec, - write_options, - )?; - - append_variadic_buffer_counts(&mut variadic_buffer_counts, &array_data); - } - // pad the tail of body data - let len = arrow_data.len(); - let pad_len = pad_to_alignment(write_options.alignment, len); - arrow_data.extend_from_slice(&PADDING[..pad_len]); - - // write data - let buffers = fbb.create_vector(&buffers); - let nodes = fbb.create_vector(&nodes); - let variadic_buffer = if variadic_buffer_counts.is_empty() { - None - } else { - Some(fbb.create_vector(&variadic_buffer_counts)) - }; - - let root = { - let mut batch_builder = crate::RecordBatchBuilder::new(&mut fbb); - batch_builder.add_length(batch.num_rows() as i64); - batch_builder.add_nodes(nodes); - batch_builder.add_buffers(buffers); - if let Some(c) = compression { - batch_builder.add_compression(c); - } - - if let Some(v) = variadic_buffer { - batch_builder.add_variadicBufferCounts(v); - } - let b = batch_builder.finish(); - b.as_union_value() - }; - // create an crate::Message - let mut message = crate::MessageBuilder::new(&mut fbb); - message.add_version(write_options.metadata_version); - message.add_header_type(crate::MessageHeader::RecordBatch); - message.add_bodyLength(arrow_data.len() as i64); - message.add_header(root); - let root = message.finish(); - fbb.finish(root, None); - let finished_data = fbb.finished_data(); - - Ok(EncodedData { - ipc_message: finished_data.to_vec(), - arrow_data, - }) - } - /// Write dictionary values into two sets of bytes, one for the header (crate::Message) and the /// other for the data fn dictionary_batch_to_bytes( @@ -576,92 +522,28 @@ impl IpcDataGenerator { array_data: &ArrayData, write_options: &IpcWriteOptions, ) -> Result { - let mut fbb = FlatBufferBuilder::new(); - - let mut nodes: Vec = vec![]; - let mut buffers: Vec = vec![]; - let mut arrow_data: Vec = vec![]; - - // 
get the type of compression - let batch_compression_type = write_options.batch_compression_type; - - let compression = batch_compression_type.map(|batch_compression_type| { - let mut c = crate::BodyCompressionBuilder::new(&mut fbb); - c.add_method(crate::BodyCompressionMethod::BUFFER); - c.add_codec(batch_compression_type); - c.finish() - }); - - let compression_codec: Option = batch_compression_type - .map(|batch_compression_type| batch_compression_type.try_into()) - .transpose()?; - - write_array_data( - array_data, - &mut buffers, - &mut arrow_data, - &mut nodes, - 0, + let mut encoded_datas = encode_array_datas( + &[array_data.clone()], array_data.len(), - array_data.null_count(), - compression_codec, + |fbb, offset| { + let mut builder = DictionaryBatchBuilder::new(&mut fbb.fbb); + builder.add_id(dict_id); + builder.add_data(offset); + builder.finish().as_union_value() + }, + MessageHeader::DictionaryBatch, + // ASK: No maximum message size here? + usize::MAX, write_options, )?; - let mut variadic_buffer_counts = vec![]; - append_variadic_buffer_counts(&mut variadic_buffer_counts, array_data); - - // pad the tail of body data - let len = arrow_data.len(); - let pad_len = pad_to_alignment(write_options.alignment, len); - arrow_data.extend_from_slice(&PADDING[..pad_len]); - - // write data - let buffers = fbb.create_vector(&buffers); - let nodes = fbb.create_vector(&nodes); - let variadic_buffer = if variadic_buffer_counts.is_empty() { - None - } else { - Some(fbb.create_vector(&variadic_buffer_counts)) - }; - - let root = { - let mut batch_builder = crate::RecordBatchBuilder::new(&mut fbb); - batch_builder.add_length(array_data.len() as i64); - batch_builder.add_nodes(nodes); - batch_builder.add_buffers(buffers); - if let Some(c) = compression { - batch_builder.add_compression(c); - } - if let Some(v) = variadic_buffer { - batch_builder.add_variadicBufferCounts(v); - } - batch_builder.finish() - }; - - let root = { - let mut batch_builder = 
crate::DictionaryBatchBuilder::new(&mut fbb); - batch_builder.add_id(dict_id); - batch_builder.add_data(root); - batch_builder.finish().as_union_value() - }; - - let root = { - let mut message_builder = crate::MessageBuilder::new(&mut fbb); - message_builder.add_version(write_options.metadata_version); - message_builder.add_header_type(crate::MessageHeader::DictionaryBatch); - message_builder.add_bodyLength(arrow_data.len() as i64); - message_builder.add_header(root); - message_builder.finish() - }; - - fbb.finish(root, None); - let finished_data = fbb.finished_data(); + assert_eq!( + encoded_datas.len(), + 1, + "encode_array_datas with a max size of usize::MAX should not be able to return more or less than 1 batch" + ); - Ok(EncodedData { - ipc_message: finished_data.to_vec(), - arrow_data, - }) + Ok(encoded_datas.pop().unwrap()) } } @@ -684,7 +566,7 @@ fn append_variadic_buffer_counts(counts: &mut Vec, array: &ArrayData) { } } -pub(crate) fn unslice_run_array(arr: ArrayData) -> Result { +pub(crate) fn unslice_run_array(arr: &ArrayData) -> Result { match arr.data_type() { DataType::RunEndEncoded(k, _) => match k.data_type() { DataType::Int16 => { @@ -971,10 +853,11 @@ impl FileWriter { )); } - let (encoded_dictionaries, encoded_message) = self.data_gen.encoded_batch( + let (encoded_dictionaries, encoded_messages) = self.data_gen.encoded_batch_with_size( batch, &mut self.dictionary_tracker, &self.write_options, + usize::MAX, )?; for encoded_dictionary in encoded_dictionaries { @@ -986,15 +869,22 @@ impl FileWriter { self.block_offsets += meta + data; } - let (meta, data) = write_message(&mut self.writer, encoded_message, &self.write_options)?; - // add a record block for the footer - let block = crate::Block::new( - self.block_offsets as i64, - meta as i32, // TODO: is this still applicable? 
- data as i64, - ); - self.record_blocks.push(block); - self.block_offsets += meta + data; + // theoretically, since the maximum size for encoding is usize::MAX, there should never be + // more than 1 encoded message. However, since there's no need to assert that (i.e. if + // someone changes usize::MAX above to be a lower limit, that's fine), we just assume + // there can be many messages + for encoded_message in encoded_messages { + let (meta, data) = + write_message(&mut self.writer, encoded_message, &self.write_options)?; + // add a record block for the footer + let block = crate::Block::new( + self.block_offsets as i64, + meta as i32, // TODO: is this still applicable? + data as i64, + ); + self.record_blocks.push(block); + self.block_offsets += meta + data; + } Ok(()) } @@ -1168,16 +1058,23 @@ impl StreamWriter { )); } - let (encoded_dictionaries, encoded_message) = self + let (encoded_dictionaries, encoded_messages) = self .data_gen - .encoded_batch(batch, &mut self.dictionary_tracker, &self.write_options) + .encoded_batch_with_size( + batch, + &mut self.dictionary_tracker, + &self.write_options, + usize::MAX, + ) .expect("StreamWriter is configured to not error on dictionary replacement"); for encoded_dictionary in encoded_dictionaries { write_message(&mut self.writer, encoded_dictionary, &self.write_options)?; } - write_message(&mut self.writer, encoded_message, &self.write_options)?; + for message in encoded_messages { + write_message(&mut self.writer, message, &self.write_options)?; + } Ok(()) } @@ -1272,6 +1169,7 @@ impl RecordBatchWriter for StreamWriter { } } +#[derive(Debug)] /// Stores the encoded data, which is an crate::Message, and optional Arrow data pub struct EncodedData { /// An encoded crate::Message @@ -1279,6 +1177,7 @@ pub struct EncodedData { /// Arrow buffers to be written, should be an empty vec for schema messages pub arrow_data: Vec, } + /// Write a message's IPC data and buffers, returning metadata and buffer data lengths written
pub fn write_message( mut writer: W, @@ -1416,38 +1315,67 @@ fn get_buffer_element_width(spec: &BufferSpec) -> usize { /// Common functionality for re-encoding offsets. Returns the new offsets as well as /// original start offset and length for use in slicing child data. -fn reencode_offsets( - offsets: &Buffer, - data: &ArrayData, -) -> (Buffer, usize, usize) { - let offsets_slice: &[O] = offsets.typed_data::(); - let offset_slice = &offsets_slice[data.offset()..data.offset() + data.len() + 1]; - - let start_offset = offset_slice.first().unwrap(); - let end_offset = offset_slice.last().unwrap(); - - let offsets = match start_offset.as_usize() { - 0 => offsets.clone(), - _ => offset_slice.iter().map(|x| *x - *start_offset).collect(), +/// +/// # Panics +/// +/// Will panic if you call this on an `ArrayData` that does not have a buffer of offsets as the +/// very first buffer (i.e. expects this to be a valid variable-length array) +fn reencode_offsets(data: &ArrayData) -> (Buffer, usize, usize) { + // first we want to see: what is the offset of this `ArrayData` into the buffer (which is a + // buffer of offsets into the buffer of data) + let orig_offset = data.offset(); + // and also we need to get the buffer of offsets + let offsets_buf = &data.buffers()[0]; + + // then we have to turn it into a typed slice that we can read and manipulate below if + // needed, and slice it according to the size that we need to return + // we need to do `self.len + 1` instead of just `self.len` because the offsets are encoded + // as overlapping pairs - e.g. an array of two items, starting at idx 0, and spanning two + // each, would be encoded as [0, 2, 4]. + let offsets_slice = &offsets_buf.typed_data::()[orig_offset..][..data.len() + 1]; + + // and now we can see what the very first offset and the very last offset is + let start_offset = offsets_slice.first().unwrap(); + let end_offset = offsets_slice.last().unwrap().as_usize(); + + // if the start offset is just 0, i.e. 
it points to the very beginning of the values of + // this `ArrayData`, then we don't need to shift anything to be a 'correct' offset 'cause + // all the offsets in this buffer are already offset by 0. + // But if it's not 0, then we need to shift them all so that the offsets don't start at + // some weird value. + let (start_offset, offsets) = match start_offset.as_usize() { + 0 => (0, offsets_slice.to_vec().into()), + start => ( + start, + offsets_slice.iter().map(|x| *x - *start_offset).collect(), + ), }; - let start_offset = start_offset.as_usize(); - let end_offset = end_offset.as_usize(); - (offsets, start_offset, end_offset - start_offset) } -/// Returns the values and offsets [`Buffer`] for a ByteArray with offset type `O` +/// Returns the offsets and values [`Buffer`]s for a ByteArray with offset type `O` /// /// In particular, this handles re-encoding the offsets if they don't start at `0`, /// slicing the values buffer as appropriate. This helps reduce the encoded /// size of sliced arrays, as values that have been sliced away are not encoded -fn get_byte_array_buffers(data: &ArrayData) -> (Buffer, Buffer) { +/// +/// # Panics +/// +/// Panics if self.buffers does not contain at least 2 buffers (this code expects that the +/// first will contain the offsets for this variable-length array and the other will contain +/// the values) +pub fn get_byte_array_buffers(data: &ArrayData) -> (Buffer, Buffer) { if data.is_empty() { return (MutableBuffer::new(0).into(), MutableBuffer::new(0).into()); } - let (offsets, original_start_offset, len) = reencode_offsets::(&data.buffers()[0], data); + // get the buffer of offsets, now shifted so they are shifted to be accurate to the slice + // of values that we'll be taking (e.g. if they previously said [0, 3, 5, 7], but we slice + // to only get the last offset, they'll be shifted to be [0, 2], since that would be the + // offset pair for the last value in this shifted slice). 
+ // also, in this example, original_start_offset would be 5 and len would be 2. + let (offsets, original_start_offset, len) = reencode_offsets::(data); let values = data.buffers()[1].slice_with_length(original_start_offset, len); (offsets, values) } @@ -1462,280 +1390,521 @@ fn get_list_array_buffers(data: &ArrayData) -> (Buffer, Arra ); } - let (offsets, original_start_offset, len) = reencode_offsets::(&data.buffers()[0], data); + let (offsets, original_start_offset, len) = reencode_offsets::(data); let child_data = data.child_data()[0].slice(original_start_offset, len); (offsets, child_data) } -/// Write array data to a vector of bytes -#[allow(clippy::too_many_arguments)] -fn write_array_data( - array_data: &ArrayData, - buffers: &mut Vec, - arrow_data: &mut Vec, - nodes: &mut Vec, - offset: i64, - num_rows: usize, - null_count: usize, - compression_codec: Option, +const DEFAULT_ALIGNMENT: u8 = 64; +const PADDING: [u8; DEFAULT_ALIGNMENT as usize] = [0; DEFAULT_ALIGNMENT as usize]; + +/// Calculate an alignment boundary and return the number of bytes needed to pad to the alignment boundary +#[inline] +fn pad_to_alignment(alignment: u8, len: usize) -> usize { + let a = usize::from(alignment - 1); + ((len + a) & !a) - len +} + +fn chunked_encoded_batch_bytes( + batch: &RecordBatch, write_options: &IpcWriteOptions, -) -> Result { - let mut offset = offset; - if !matches!(array_data.data_type(), DataType::Null) { - nodes.push(crate::FieldNode::new(num_rows as i64, null_count as i64)); - } else { - // NullArray's null_count equals to len, but the `null_count` passed in is from ArrayData - // where null_count is always 0. 
- nodes.push(crate::FieldNode::new(num_rows as i64, num_rows as i64)); - } - if has_validity_bitmap(array_data.data_type(), write_options) { - // write null buffer if exists - let null_buffer = match array_data.nulls() { - None => { - // create a buffer and fill it with valid bits - let num_bytes = bit_util::ceil(num_rows, 8); - let buffer = MutableBuffer::new(num_bytes); - let buffer = buffer.with_bitset(num_bytes, true); - buffer.into() - } - Some(buffer) => buffer.inner().sliced(), + max_encoded_data_size: usize, +) -> Result, ArrowError> { + encode_array_datas( + &batch + .columns() + .iter() + .map(ArrayRef::to_data) + .collect::>(), + batch.num_rows(), + |_, offset| offset.as_union_value(), + MessageHeader::RecordBatch, + max_encoded_data_size, + write_options, + ) +} + +fn get_encoded_arr_batch_size>( + iter: impl IntoIterator, + write_options: &IpcWriteOptions, +) -> Result { + iter.into_iter() + .map(|arr| { + let arr = arr.borrow(); + arr.get_slice_memory_size_with_alignment(Some(write_options.alignment)) + .and_then(|mut size| { + let didnt_count_nulls = arr.nulls().is_none(); + let will_write_nulls = has_validity_bitmap(arr.data_type(), write_options); + + if will_write_nulls && didnt_count_nulls { + let null_len = bit_util::ceil(arr.len(), 8); + size += null_len + pad_to_alignment(write_options.alignment, null_len) + } + + // TODO: This is ugly. We remove the child_data size in RunEndEncoded because + // it was calculated as the size existing in memory but we care about the size + // when it's decoded and then encoded into a flatbuffer. Afaik, this is the + // only data type where the size in memory is not the same size as when encoded + // (since it has a different representation in memory), so it's not horrible, + // but it's definitely not ideal. 
+ if let DataType::RunEndEncoded(_, _) = arr.data_type() { + size -= arr + .child_data() + .iter() + .map(|data| { + data.get_slice_memory_size_with_alignment(Some( + write_options.alignment, + )) + }) + .sum::>()?; + + size += unslice_run_array(arr)? + .child_data() + .iter() + .map(|data| get_encoded_arr_batch_size([data], write_options)) + .sum::>()?; + } + + Ok(size) + }) + }) + .sum() +} + +fn encode_array_datas( + arr_datas: &[ArrayData], + n_rows: usize, + encode_root: impl Fn( + &mut FlatBufferSizeTracker, + WIPOffset, + ) -> WIPOffset, + header_type: MessageHeader, + mut max_msg_size: usize, + write_options: &IpcWriteOptions, +) -> Result, ArrowError> { + let mut fbb = FlatBufferSizeTracker::for_dry_run(arr_datas.len()); + fbb.encode_array_datas( + arr_datas, + n_rows as i64, + &encode_root, + header_type, + write_options, + )?; + + let header_len = fbb.fbb.finished_data().len(); + max_msg_size = max_msg_size.saturating_sub(header_len).max(1); + + let total_size = get_encoded_arr_batch_size(arr_datas.iter(), write_options)?; + + let n_batches = bit_util::ceil(total_size, max_msg_size); + let mut out = Vec::with_capacity(n_batches); + + let mut offset = 0; + while offset < n_rows.max(1) { + let slice_arrays = |len: usize| { + arr_datas.iter().map(move |arr| { + if len >= arr.len() { + arr.clone() + } else { + arr.slice(offset, len) + } + }) }; - offset = write_buffer( - null_buffer.as_slice(), - buffers, - arrow_data, - offset, - compression_codec, - write_options.alignment, - )?; - } + let rows_left = n_rows - offset; + // TODO? maybe this could be more efficient by continually approximating the maximum number + // of rows based on (size / n_rows) of the current ArrayData slice until we've found the + // maximum that can fit? e.g. 
'oh, it's 200 bytes and 10 rows, so each row is probably 20 + // bytes - let's do (max_size / 20) rows and see if that fits' + let length = (1..=rows_left) + .find(|len| { + // If we've exhausted the available length of the array datas, then just return - + // we've got it. + if offset + len > n_rows { + return true; + } - let data_type = array_data.data_type(); - if matches!(data_type, DataType::Binary | DataType::Utf8) { - let (offsets, values) = get_byte_array_buffers::(array_data); - for buffer in [offsets, values] { - offset = write_buffer( - buffer.as_slice(), - buffers, - arrow_data, - offset, - compression_codec, - write_options.alignment, + // we can unwrap this here b/c this only errors on malformed buffer-type/data-type + // combinations, and if any of these arrays had that, this function would've + // already short-circuited on an earlier call of this function + get_encoded_arr_batch_size(slice_arrays(*len), write_options).unwrap() + > max_msg_size + }) + // If no rows fit in the given max size, we want to try to get the data across anyways, + // so that just means doing a single row. Calling `max(2)` is how we ensure that - if + // the very first item would go over the max size, giving us a length of 0, we want to + // set this to `2` so that taking away 1 leaves us with one row to encode. + .map(|len| len.max(2) - 1) + // If all rows can comfortably fit in this given size, then just get them all + .unwrap_or(rows_left); + + // We could get into a situation where we were given all 0-row arrays to be sent over + // flight - we do need to send a flight message to show that there is no data, but we also + // can't have `length` be 0 at this point because it could also be that all rows are too + // large to send with the provided limits and so we just want to try to send one now + // anyways, so the checks in this fn are just how we cover our bases there.
+ let new_arrs = slice_arrays(length).collect::>(); + + // If we've got at least one row to encode or if we have 0 rows to encode but we haven't + // encoded anything yet, then continue with encoding. We don't need to do encoding, though, + // if we've already encoded some rows and there are no rows left + if length != 0 || offset == 0 { + fbb.reset_for_real_run(); + fbb.encode_array_datas( + &new_arrs, + length as i64, + &encode_root, + header_type, + write_options, )?; + + let finished_data = fbb.fbb.finished_data(); + + out.push(EncodedData { + ipc_message: finished_data.to_vec(), + arrow_data: fbb.arrow_data.clone(), + }); } - } else if matches!(data_type, DataType::BinaryView | DataType::Utf8View) { - // Slicing the views buffer is safe and easy, - // but pruning unneeded data buffers is much more nuanced since it's complicated to prove that no views reference the pruned buffers - // - // Current implementation just serialize the raw arrays as given and not try to optimize anything. - // If users wants to "compact" the arrays prior to sending them over IPC, - // they should consider the gc API suggested in #5513 - for buffer in array_data.buffers() { - offset = write_buffer( - buffer.as_slice(), - buffers, - arrow_data, - offset, - compression_codec, - write_options.alignment, - )?; + + // If length == 0, that means they gave us ArrayData with no rows, so a single iteration is + // always sufficient. + if length == 0 { + break; } - } else if matches!(data_type, DataType::LargeBinary | DataType::LargeUtf8) { - let (offsets, values) = get_byte_array_buffers::(array_data); - for buffer in [offsets, values] { - offset = write_buffer( - buffer.as_slice(), - buffers, - arrow_data, - offset, - compression_codec, - write_options.alignment, + + offset += length; + } + + Ok(out) +} + +/// A struct to help ensure that the size of encoded flight messages never goes over a provided +/// limit (except in ridiculous cases like a limit of 1 byte).
The way it does so is by first +/// running through a provided slice of [`ArrayData`], producing an IPC header message for that +/// slice, and then subtracting the size of that generated header from the message limit it has +/// been given. Because IPC header message sizes don't change due to a different amount of rows, +/// this header size will stay consistent throughout the entire time that we have to transmit a +/// chunk of rows, so we can just subtract it from the overall limit and use that to check +/// different slices of `ArrayData` against to know how many to transmit each time. +/// +/// This whole process is done in [`encode_array_datas()`] above +#[derive(Default)] +struct FlatBufferSizeTracker<'fbb> { + // the builder and backing flatbuffer that we use to write the arrow data into. + fbb: FlatBufferBuilder<'fbb>, + // tracks the data in `arrow_data` - `buffers` contains the offsets and length of different + // buffers encoded within the big chunk that is `arrow_data`. + buffers: Vec, + // the raw array data that we need to send across the wire + arrow_data: Vec, + nodes: Vec, + dry_run: bool, +} + +impl<'fbb> FlatBufferSizeTracker<'fbb> { + /// Preferred initializer, as this should always be used with a dry-run before a real run to + /// figure out the size of the IPC header. + #[must_use] + fn for_dry_run(capacity: usize) -> Self { + Self { + dry_run: true, + buffers: Vec::with_capacity(capacity), + nodes: Vec::with_capacity(capacity), + ..Self::default() + } + } + + /// Should be called in-between calls to `encode_array_datas` to ensure we don't accidentally + /// keep & encode old data each time. + fn reset_for_real_run(&mut self) { + self.fbb.reset(); + self.buffers.clear(); + self.arrow_data.clear(); + self.nodes.clear(); + self.dry_run = false; + + // this helps us avoid completely re-allocating the buffers by just creating a new `Self`. + // So everything should be allocated correctly now besides arrow_data. 
If we're calling + // this after only a dry run, `arrow_data` shouldn't have anything written into it, but + // we call this after every real run loop, so we still need to clear it. + } + + fn encode_array_datas( + &mut self, + arr_datas: &[ArrayData], + n_rows: i64, + encode_root: impl FnOnce( + &mut FlatBufferSizeTracker, + WIPOffset, + ) -> WIPOffset, + header_type: MessageHeader, + write_options: &IpcWriteOptions, + ) -> Result<(), ArrowError> { + let batch_compression_type = write_options.batch_compression_type; + + let compression = batch_compression_type.map(|compression_type| { + let mut builder = BodyCompressionBuilder::new(&mut self.fbb); + builder.add_method(BodyCompressionMethod::BUFFER); + builder.add_codec(compression_type); + builder.finish() + }); + + let mut variadic_buffer_counts = Vec::::default(); + let mut offset = 0; + + for array in arr_datas { + self.write_array_data( + array, + &mut offset, + array.len(), + array.null_count(), + write_options, )?; + + append_variadic_buffer_counts(&mut variadic_buffer_counts, array); } - } else if DataType::is_numeric(data_type) - || DataType::is_temporal(data_type) - || matches!( - array_data.data_type(), - DataType::FixedSizeBinary(_) | DataType::Dictionary(_, _) - ) - { - // Truncate values - assert_eq!(array_data.buffers().len(), 1); - - let buffer = &array_data.buffers()[0]; - let layout = layout(data_type); - let spec = &layout.buffers[0]; - - let byte_width = get_buffer_element_width(spec); - let min_length = array_data.len() * byte_width; - let buffer_slice = if buffer_need_truncate(array_data.offset(), buffer, spec, min_length) { - let byte_offset = array_data.offset() * byte_width; - let buffer_length = min(min_length, buffer.len() - byte_offset); - &buffer.as_slice()[byte_offset..(byte_offset + buffer_length)] - } else { - buffer.as_slice() - }; - offset = write_buffer( - buffer_slice, - buffers, - arrow_data, - offset, - compression_codec, - write_options.alignment, - )?; - } else if 
matches!(data_type, DataType::Boolean) { - // Bools are special because the payload (= 1 bit) is smaller than the physical container elements (= bytes). - // The array data may not start at the physical boundary of the underlying buffer, so we need to shift bits around. - assert_eq!(array_data.buffers().len(), 1); - - let buffer = &array_data.buffers()[0]; - let buffer = buffer.bit_slice(array_data.offset(), array_data.len()); - offset = write_buffer( - &buffer, - buffers, - arrow_data, - offset, - compression_codec, - write_options.alignment, - )?; - } else if matches!( - data_type, - DataType::List(_) | DataType::LargeList(_) | DataType::Map(_, _) - ) { - assert_eq!(array_data.buffers().len(), 1); - assert_eq!(array_data.child_data().len(), 1); - - // Truncate offsets and the child data to avoid writing unnecessary data - let (offsets, sliced_child_data) = match data_type { - DataType::List(_) => get_list_array_buffers::(array_data), - DataType::Map(_, _) => get_list_array_buffers::(array_data), - DataType::LargeList(_) => get_list_array_buffers::(array_data), - _ => unreachable!(), - }; - offset = write_buffer( - offsets.as_slice(), - buffers, - arrow_data, - offset, - compression_codec, - write_options.alignment, - )?; - offset = write_array_data( - &sliced_child_data, - buffers, - arrow_data, - nodes, - offset, - sliced_child_data.len(), - sliced_child_data.null_count(), - compression_codec, - write_options, - )?; - return Ok(offset); - } else { - for buffer in array_data.buffers() { - offset = write_buffer( - buffer, - buffers, - arrow_data, + + // pad the tail of the body data + let pad_len = pad_to_alignment(write_options.alignment, self.arrow_data.len()); + self.arrow_data.extend_from_slice(&PADDING[..pad_len]); + + let buffers = self.fbb.create_vector(&self.buffers); + let nodes = self.fbb.create_vector(&self.nodes); + let variadic_buffer = (!variadic_buffer_counts.is_empty()) + .then(|| self.fbb.create_vector(&variadic_buffer_counts)); + + let root = 
crate::RecordBatch::create( + &mut self.fbb, + &RecordBatchArgs { + length: n_rows, + nodes: Some(nodes), + buffers: Some(buffers), + compression, + variadicBufferCounts: variadic_buffer, + }, + ); + + let root = encode_root(self, root); + + let arrow_len = self.arrow_data.len() as i64; + + let mut msg_bldr = MessageBuilder::new(&mut self.fbb); + msg_bldr.add_version(write_options.metadata_version); + msg_bldr.add_header_type(header_type); + msg_bldr.add_header(root); + msg_bldr.add_bodyLength(arrow_len); + let msg = msg_bldr.finish(); + + self.fbb.finish(msg, None); + Ok(()) + } + + fn write_array_data( + &mut self, + array_data: &ArrayData, + offset: &mut i64, + num_rows: usize, + null_count: usize, + write_options: &IpcWriteOptions, + ) -> Result<(), ArrowError> { + let compression_codec: Option = write_options + .batch_compression_type + .map(TryInto::try_into) + .transpose()?; + + // NullArray's null_count equals to len, but the `null_count` passed in is from ArrayData + // where null_count is always 0. + self.nodes.push(crate::FieldNode::new( + num_rows as i64, + match array_data.data_type() { + DataType::Null => num_rows, + _ => null_count, + } as i64, + )); + + if has_validity_bitmap(array_data.data_type(), write_options) { + // write null buffer if exists + let null_buffer = match array_data.nulls() { + None => { + let num_bytes = bit_util::ceil(num_rows, 8); + // create a buffer and fill it with valid bits + MutableBuffer::new(num_bytes) + .with_bitset(num_bytes, true) + .into() + } + Some(buffer) => buffer.inner().sliced(), + }; + + self.write_buffer( + &null_buffer, offset, compression_codec, write_options.alignment, )?; } - } - match array_data.data_type() { - DataType::Dictionary(_, _) => {} - DataType::RunEndEncoded(_, _) => { - // unslice the run encoded array. 
- let arr = unslice_run_array(array_data.clone())?; - // recursively write out nested structures - for data_ref in arr.child_data() { - // write the nested data (e.g list data) - offset = write_array_data( - data_ref, - buffers, - arrow_data, - nodes, + let mut write_byte_array_byffers = |(offsets, values): (Buffer, Buffer)| { + self.write_buffer(&offsets, offset, compression_codec, write_options.alignment)?; + self.write_buffer(&values, offset, compression_codec, write_options.alignment) + }; + + match array_data.data_type() { + DataType::Binary | DataType::Utf8 => { + write_byte_array_byffers(get_byte_array_buffers::(array_data))? + } + DataType::LargeBinary | DataType::LargeUtf8 => { + write_byte_array_byffers(get_byte_array_buffers::(array_data))? + } + dt if DataType::is_numeric(dt) + || DataType::is_temporal(dt) + || matches!( + dt, + DataType::FixedSizeBinary(_) | DataType::Dictionary(_, _) + ) => + { + // Truncate values + let [buffer] = array_data.buffers() else { + panic!("Temporal, Numeric, FixedSizeBinary, and Dictionary data types must contain only one buffer"); + }; + + let layout = layout(dt); + let spec = &layout.buffers[0]; + + let byte_width = get_buffer_element_width(spec); + let min_length = array_data.len() * byte_width; + let mut buffer_slice = buffer.as_slice(); + + if buffer_need_truncate(array_data.offset(), buffer, spec, min_length) { + let byte_offset = array_data.offset() * byte_width; + let buffer_length = min(min_length, buffer.len() - byte_offset); + buffer_slice = &buffer_slice[byte_offset..(byte_offset + buffer_length)]; + } + + self.write_buffer( + buffer_slice, offset, - data_ref.len(), - data_ref.null_count(), compression_codec, - write_options, + write_options.alignment, )?; } + DataType::Boolean => { + // Bools are special because the payload (= 1 bit) is smaller than the physical container elements (= bytes). 
+ // The array data may not start at the physical boundary of the underlying buffer, + // so we need to shift bits around. + let [single_buf] = array_data.buffers() else { + panic!("ArrayData of type Boolean should only contain 1 buffer"); + }; + + let buffer = &single_buf.bit_slice(array_data.offset(), array_data.len()); + self.write_buffer(buffer, offset, compression_codec, write_options.alignment)?; + } + dt @ (DataType::List(_) | DataType::LargeList(_) | DataType::Map(_, _)) => { + assert_eq!(array_data.buffers().len(), 1); + assert_eq!(array_data.child_data().len(), 1); + + // Truncate offsets and the child data to avoid writing unnecessary data + let (offsets, sliced_child_data) = match dt { + DataType::List(_) | DataType::Map(_, _) => { + get_list_array_buffers::(array_data) + } + DataType::LargeList(_) => get_list_array_buffers::(array_data), + _ => unreachable!(), + }; + self.write_buffer(&offsets, offset, compression_codec, write_options.alignment)?; + return self.write_array_data( + &sliced_child_data, + offset, + sliced_child_data.len(), + sliced_child_data.null_count(), + write_options, + ); + } + _ => { + // This accommodates for even the `View` types (e.g. BinaryView and Utf8View): + // Slicing the views buffer is safe and easy, + // but pruning unneeded data buffers is much more nuanced since it's complicated + // to prove that no views reference the pruned buffers + // + // Current implementation just serialize the raw arrays as given and not try to optimize anything. 
+ // If users want to "compact" the arrays prior to sending them over IPC, + // they should consider the gc API suggested in #5513 + for buffer in array_data.buffers() { + self.write_buffer(buffer, offset, compression_codec, write_options.alignment)?; + } + } } - _ => { - // recursively write out nested structures - for data_ref in array_data.child_data() { - // write the nested data (e.g list data) - offset = write_array_data( + + let mut write_arr = |arr: &ArrayData| { + arr.child_data().iter().try_for_each(|data_ref| { + self.write_array_data( data_ref, - buffers, - arrow_data, - nodes, offset, data_ref.len(), data_ref.null_count(), - compression_codec, write_options, - )?; - } + ) + }) + }; + + match array_data.data_type() { + DataType::Dictionary(_, _) => Ok(()), + // unslice the run encoded array. + DataType::RunEndEncoded(_, _) => write_arr(&unslice_run_array(array_data)?), + // recursively write out nested structures + _ => write_arr(array_data), } } - Ok(offset) -} -/// Write a buffer into `arrow_data`, a vector of bytes, and adds its -/// [`crate::Buffer`] to `buffers`. Returns the new offset in `arrow_data` -/// -/// -/// From -/// Each constituent buffer is first compressed with the indicated -/// compressor, and then written with the uncompressed length in the first 8 -/// bytes as a 64-bit little-endian signed integer followed by the compressed -/// buffer bytes (and then padding as required by the protocol). The -/// uncompressed length may be set to -1 to indicate that the data that -/// follows is not compressed, which can be useful for cases where -/// compression does not yield appreciable savings.
-fn write_buffer( - buffer: &[u8], // input - buffers: &mut Vec, // output buffer descriptors - arrow_data: &mut Vec, // output stream - offset: i64, // current output stream offset - compression_codec: Option, - alignment: u8, -) -> Result { - let len: i64 = match compression_codec { - Some(compressor) => compressor.compress_to_vec(buffer, arrow_data)?, - None => { - arrow_data.extend_from_slice(buffer); - buffer.len() - } - } - .try_into() - .map_err(|e| { - ArrowError::InvalidArgumentError(format!("Could not convert compressed size to i64: {e}")) - })?; - - // make new index entry - buffers.push(crate::Buffer::new(offset, len)); - // padding and make offset aligned - let pad_len = pad_to_alignment(alignment, len as usize); - arrow_data.extend_from_slice(&PADDING[..pad_len]); - - Ok(offset + len + (pad_len as i64)) -} + /// Write a buffer into `arrow_data`, a vector of bytes, and adds its + /// [`crate::Buffer`] to `buffers`. Modifies the offset passed in to respect the new value. + /// + /// From + /// Each constituent buffer is first compressed with the indicated + /// compressor, and then written with the uncompressed length in the first 8 + /// bytes as a 64-bit little-endian signed integer followed by the compressed + /// buffer bytes (and then padding as required by the protocol). The + /// uncompressed length may be set to -1 to indicate that the data that + /// follows is not compressed, which can be useful for cases where + /// compression does not yield appreciable savings. 
+ fn write_buffer( + &mut self, + // input + buffer: &[u8], + // current output stream offset + offset: &mut i64, + compression_codec: Option, + alignment: u8, + ) -> Result<(), ArrowError> { + let len: i64 = if self.dry_run { + // Flatbuffers will essentially optimize this away if we say the len is 0 for all of + // these, so to make sure the header size is the same in the dry run and in the real + // thing, we need to set this to a non-zero value + 1 + } else { + match compression_codec { + Some(compressor) => compressor.compress_to_vec(buffer, &mut self.arrow_data)?, + None => { + self.arrow_data.extend_from_slice(buffer); + buffer.len() + } + } + .try_into() + .map_err(|e| { + ArrowError::InvalidArgumentError(format!( + "Could not convert compressed size to i64: {e}" + )) + })? + }; -const PADDING: [u8; 64] = [0; 64]; + // make new index entry + self.buffers.push(crate::Buffer::new(*offset, len)); + // padding and make offset aligned + let pad_len = pad_to_alignment(alignment, len as usize); + self.arrow_data.extend_from_slice(&PADDING[..pad_len]); -/// Calculate an alignment boundary and return the number of bytes needed to pad to the alignment boundary -#[inline] -fn pad_to_alignment(alignment: u8, len: usize) -> usize { - let a = usize::from(alignment - 1); - ((len + a) & !a) - len + *offset += len + (pad_len as i64); + Ok(()) + } } #[cfg(test)] @@ -2780,4 +2949,73 @@ mod tests { assert_eq!(stream_bytes_written_on_flush, expected_stream_flushed_bytes); assert_eq!(file_bytes_written_on_flush, expected_file_flushed_bytes); } + + #[test] + fn encoded_arr_data_same_size_as_compute_api() { + fn encode_test(arr: T) { + println!("Checking arr {arr:?}"); + + let arr_data = arr.to_data(); + + let write_options = IpcWriteOptions::default() + .try_with_compression(None) + .unwrap(); + + let compute_size = get_encoded_arr_batch_size([&arr_data], &write_options).unwrap(); + let num_rows = arr_data.len(); + + let encoded = encode_array_datas( + &[arr_data], + num_rows, + 
|_, root| root.as_union_value(), + MessageHeader::RecordBatch, + usize::MAX, + &write_options, + ) + .unwrap() + .pop() + .unwrap(); + + assert_eq!(compute_size, encoded.arrow_data.len()); + } + + let str_arr = [Some("fooo"), Some("ba"), Some("bazrrrrrrrrr"), Some("quz")] + .into_iter() + .collect::(); + let int_arr = [None, Some(2), Some(1), Some(3)] + .into_iter() + .collect::(); + encode_test(str_arr.clone()); + encode_test(int_arr.clone()); + + // For some reason, DictionaryArrays don't encode their `values` in the flight messages. I + // don't know why that is, but that will cause this test to fail. + // encode_test(DictionaryArray::new(int_arr, Arc::new(str_arr))); + + let time_arr = [Some(0), Some(14000), Some(-1), Some(-1)] + .into_iter() + .collect::(); + encode_test(time_arr); + + let list_field: FieldRef = Arc::new(Field::new("a", DataType::Int32, true)); + let all_null_list = FixedSizeListArray::new_null(Arc::clone(&list_field), 3, 8); + encode_test(all_null_list); + + let list = FixedSizeListArray::new(list_field, 2, make_array(int_arr.to_data()), None); + encode_test(list); + + let vals: Vec> = vec![Some(1), None, Some(2), Some(3), Some(4), None, Some(5)]; + let repeats: Vec = vec![3, 4, 1, 2]; + let mut input_array: Vec> = Vec::with_capacity(80); + for ix in 0_usize..32 { + let repeat: usize = repeats[ix % repeats.len()]; + let val: Option = vals[ix % vals.len()]; + input_array.resize(input_array.len() + repeat, val); + } + let mut builder = + PrimitiveRunBuilder::::with_capacity(input_array.len()); + builder.extend(input_array); + let run_array = builder.finish(); + encode_test(run_array); + } }