diff --git a/src/mito2/src/memtable/merge_tree/data.rs b/src/mito2/src/memtable/merge_tree/data.rs index 2e903519e618..e4ed65f8f601 100644 --- a/src/mito2/src/memtable/merge_tree/data.rs +++ b/src/mito2/src/memtable/merge_tree/data.rs @@ -23,7 +23,7 @@ use datatypes::arrow; use datatypes::arrow::array::{RecordBatch, UInt16Array, UInt32Array}; use datatypes::arrow::datatypes::{Field, Schema, SchemaRef}; use datatypes::data_type::DataType; -use datatypes::prelude::{ConcreteDataType, MutableVector, ScalarVectorBuilder, VectorRef}; +use datatypes::prelude::{ConcreteDataType, MutableVector, ScalarVectorBuilder, Vector, VectorRef}; use datatypes::schema::ColumnSchema; use datatypes::types::TimestampType; use datatypes::vectors::{ @@ -136,11 +136,12 @@ impl DataBuffer { } } - /// Freezes `DataBuffer` to bytes. Use `pk_weights` to convert pk_id to pk sort order. + /// Freezes `DataBuffer` to bytes. Use `pk_weights` to sort rows and replace pk_index to pk_weights. /// `freeze` clears the buffers of builders. - pub fn freeze(&mut self, _pk_weights: &[u16]) -> Result { - // we need distinguish between `freeze` in `ShardWriter` And `Shard`. - todo!() + pub fn freeze(&mut self, pk_weights: &[u16]) -> Result { + let encoder = DataPartEncoder::new(&self.metadata, pk_weights, None); + let encoded = encoder.write(self)?; + Ok(DataPart::Parquet(encoded)) } /// Reads batches from data buffer without resetting builder's buffers. @@ -152,6 +153,7 @@ impl DataBuffer { pk_weights, true, true, + true, )?; DataBufferIter::new(batch) } @@ -190,12 +192,16 @@ impl LazyMutableVectorBuilder { } /// Converts `DataBuffer` to record batches, with rows sorted according to pk_weights. +/// `keep_data`: whether to keep the original data inside `DataBuffer`. +/// `dedup`: whether to true to remove the duplicated rows inside `DataBuffer`. +/// `replace_pk_index`: whether to replace the pk_index values with corresponding pk weight. fn data_buffer_to_record_batches( schema: SchemaRef, buffer: &mut DataBuffer, pk_weights: &[u16], keep_data: bool, dedup: bool, + replace_pk_index: bool, ) -> Result { let num_rows = buffer.ts_builder.len(); @@ -217,17 +223,27 @@ fn data_buffer_to_record_batches( let mut rows = build_rows_to_sort(pk_weights, &pk_index_v, &ts_v, &sequence_v); + let pk_array = if replace_pk_index { + // replace pk index values with pk weights. + Arc::new(UInt16Array::from_iter_values( + rows.iter().map(|(_, key)| key.pk_weight), + )) as Arc<_> + } else { + pk_index_v.to_arrow_array() + }; + // sort and dedup rows.sort_unstable_by(|l, r| l.1.cmp(&r.1)); if dedup { rows.dedup_by(|l, r| l.1.pk_weight == r.1.pk_weight && l.1.timestamp == r.1.timestamp); } - let indices_to_take = UInt32Array::from_iter_values(rows.into_iter().map(|v| v.0 as u32)); + + let indices_to_take = UInt32Array::from_iter_values(rows.iter().map(|(idx, _)| *idx as u32)); let mut columns = Vec::with_capacity(4 + buffer.field_builders.len()); columns.push( - arrow::compute::take(&pk_index_v.as_arrow(), &indices_to_take, None) + arrow::compute::take(&pk_array, &indices_to_take, None) .context(error::ComputeArrowSnafu)?, ); @@ -500,6 +516,7 @@ impl<'a> DataPartEncoder<'a> { self.pk_weights, false, true, + true, )?; writer.write(&rb).context(error::EncodeMemtableSnafu)?; let _file_meta = writer.close().context(error::EncodeMemtableSnafu)?; @@ -563,7 +580,8 @@ mod tests { assert_eq!(5, buffer.num_rows()); let schema = memtable_schema_to_encoded_schema(&meta); let batch = - data_buffer_to_record_batches(schema, &mut buffer, &[3, 1], keep_data, true).unwrap(); + data_buffer_to_record_batches(schema, &mut buffer, &[3, 1], keep_data, true, true) + .unwrap(); assert_eq!( vec![1, 2, 1, 2], @@ -579,7 +597,7 @@ mod tests { ); assert_eq!( - vec![1, 1, 0, 0], + vec![1, 1, 3, 3], batch .column_by_name(PK_INDEX_COLUMN_NAME) .unwrap() @@ -627,7 +645,7 @@ mod tests { assert_eq!(4, buffer.num_rows()); let schema = memtable_schema_to_encoded_schema(&meta); let batch = - data_buffer_to_record_batches(schema, &mut buffer, &[0, 1], true, true).unwrap(); + data_buffer_to_record_batches(schema, &mut buffer, &[0, 1], true, true, true).unwrap(); assert_eq!(3, batch.num_rows()); assert_eq!( @@ -681,10 +699,10 @@ mod tests { assert_eq!(5, buffer.num_rows()); let schema = memtable_schema_to_encoded_schema(&meta); let batch = - data_buffer_to_record_batches(schema, &mut buffer, &[3, 1], true, false).unwrap(); + data_buffer_to_record_batches(schema, &mut buffer, &[3, 1], true, false, true).unwrap(); assert_eq!( - vec![1, 1, 0, 0, 0], + vec![1, 1, 3, 3, 3], batch .column_by_name(PK_INDEX_COLUMN_NAME) .unwrap()