From 65b3f2c138a70718b51cc117f7a6f1ac4c0ba041 Mon Sep 17 00:00:00 2001 From: Gwo Tzu-Hsing Date: Wed, 24 Jul 2024 22:32:15 +0800 Subject: [PATCH] fix: getting record from sstable does not handle nullable logic --- Cargo.toml | 1 + src/compaction/mod.rs | 6 +- src/lib.rs | 107 +++++++++++------------------- src/ondisk/sstable.rs | 60 ++++++++--------- src/option.rs | 132 +++++++++++++++++++++++++++++++++++++ src/record/internal.rs | 27 +++++++- src/record/mod.rs | 4 +- src/serdes/mod.rs | 6 +- src/serdes/option.rs | 2 +- src/stream/level.rs | 74 ++++++++++----------- src/stream/mod.rs | 6 +- src/stream/record_batch.rs | 2 +- src/timestamp/mod.rs | 2 +- src/transaction.rs | 44 +++++++------ src/version/set.rs | 2 +- 15 files changed, 306 insertions(+), 169 deletions(-) create mode 100644 src/option.rs diff --git a/Cargo.toml b/Cargo.toml index 46f29fa3..f539f2e2 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -10,6 +10,7 @@ tokio = ["dep:tokio"] [dependencies] arrow = "52" async-lock = "3" +crc32fast = "1" crossbeam-skiplist = "0.1" flume = { version = "0.11", features = ["async"] } futures-core = "0.3" diff --git a/src/compaction/mod.rs b/src/compaction/mod.rs index 8b31c915..bb36a1a7 100644 --- a/src/compaction/mod.rs +++ b/src/compaction/mod.rs @@ -330,7 +330,7 @@ where max = Some(key.value.to_key()); written_size += key.size(); - builder.push(key, Some(entry.value())); + builder.push(key, entry.value()); if written_size >= option.max_sst_file_size { Self::build_table( @@ -523,7 +523,7 @@ pub(crate) mod tests { }); let scope = Compactor::::minor_compaction( - &DbOption::new(temp_dir.path()), + &DbOption::from(temp_dir.path()), VecDeque::from(vec![batch_2, batch_1]), ) .await @@ -537,7 +537,7 @@ pub(crate) mod tests { async fn major_compaction() { let temp_dir = TempDir::new().unwrap(); - let mut option = DbOption::new(temp_dir.path()); + let mut option = DbOption::from(temp_dir.path()); option.major_threshold_with_sst_size = 2; let option = Arc::new(option); diff --git a/src/lib.rs b/src/lib.rs index b8fb68e8..a511260e 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,21 +1,20 @@ #![allow(dead_code)] -pub(crate) mod arrows; +mod arrows; mod compaction; pub mod executor; pub mod fs; mod inmem; mod ondisk; +pub mod option; mod record; mod scope; pub mod serdes; mod stream; mod timestamp; mod transaction; -pub(crate) mod version; +mod version; -use std::{ - collections::VecDeque, io, marker::PhantomData, mem, ops::Bound, path::PathBuf, sync::Arc, -}; +use std::{collections::VecDeque, io, marker::PhantomData, mem, ops::Bound, sync::Arc}; use async_lock::{RwLock, RwLockReadGuard}; use fs::FileProvider; @@ -26,7 +25,6 @@ use lockable::LockableHashMap; use parquet::{ arrow::{arrow_to_parquet_schema, ProjectionMask}, errors::ParquetError, - file::properties::WriterProperties, }; use record::Record; use thiserror::Error; @@ -34,32 +32,13 @@ use timestamp::Timestamp; use tracing::error; use transaction::Transaction; +pub use crate::option::*; use crate::{ executor::Executor, - fs::{FileId, FileType}, stream::{merge::MergeStream, Entry, ScanStream}, version::{cleaner::Cleaner, set::VersionSet, Version, VersionError}, }; -type LockMap = Arc>; - -pub enum Projection { - All, - Parts(Vec), -} - -#[derive(Debug)] -pub struct DbOption { - pub path: PathBuf, - pub max_mem_table_size: usize, - pub immutable_chunk_num: usize, - pub major_threshold_with_sst_size: usize, - pub level_sst_magnification: usize, - pub max_sst_file_size: usize, - pub clean_channel_buffer: usize, - pub write_parquet_option: Option, -} - pub struct DB where R: Record, @@ -71,46 +50,6 @@ where _p: PhantomData, } -impl DbOption { - pub fn new(path: impl Into + Send) -> Self { - DbOption { - path: path.into(), - max_mem_table_size: 8 * 1024 * 1024, - immutable_chunk_num: 3, - major_threshold_with_sst_size: 10, - level_sst_magnification: 10, - max_sst_file_size: 24 * 1024 * 1024, - clean_channel_buffer: 10, - write_parquet_option: None, - } - } - - pub(crate) fn table_path(&self, gen: &FileId) -> PathBuf { - self.path.join(format!("{}.{}", gen, FileType::Parquet)) - } - - pub(crate) fn wal_path(&self, gen: &FileId) -> PathBuf { - self.path.join(format!("{}.{}", gen, FileType::Wal)) - } - - pub(crate) fn version_path(&self) -> PathBuf { - self.path.join(format!("version.{}", FileType::Log)) - } - - pub(crate) fn is_threshold_exceeded_major( - &self, - version: &Version, - level: usize, - ) -> bool - where - R: Record, - E: FileProvider, - { - Version::::tables_len(version, level) - >= (self.major_threshold_with_sst_size * self.level_sst_magnification.pow(level as u32)) - } -} - impl DB where R: Record + Send, @@ -350,6 +289,13 @@ where Parquet(#[from] ParquetError), } +type LockMap = Arc>; + +pub enum Projection { + All, + Parts(Vec), +} + #[cfg(test)] pub(crate) mod tests { use std::{collections::VecDeque, sync::Arc}; @@ -359,10 +305,12 @@ pub(crate) mod tests { datatypes::{DataType, Field, Schema, UInt32Type}, }; use async_lock::RwLock; + use futures_util::io; use once_cell::sync::Lazy; use parquet::arrow::ProjectionMask; use tracing::error; + use crate::serdes::{Decode, Encode}; use crate::{ executor::{tokio::TokioExecutor, Executor}, inmem::{ @@ -381,6 +329,17 @@ pub(crate) mod tests { pub vobool: Option, } + impl Decode for Test { + type Error = io::Error; + + async fn decode(reader: &mut R) -> Result + where + R: futures_io::AsyncRead + Unpin, + { + todo!() + } + } + impl Record for Test { type Columns = TestImmutableArrays; @@ -426,6 +385,21 @@ pub(crate) mod tests { pub vbool: Option, } + impl<'r> Encode for TestRef<'r> { + type Error = io::Error; + + async fn encode(&self, writer: &mut W) -> Result<(), Self::Error> + where + W: io::AsyncWrite + Unpin + Send + Sync, + { + todo!() + } + + fn size(&self) -> usize { + todo!() + } + } + impl<'r> RecordRef<'r> for TestRef<'r> { type Record = Test; @@ -466,7 +440,6 @@ pub(crate) mod tests { if !vbool_array.is_null(offset) { vbool = Some(vbool_array.value(offset)); } - column_i += 1; } let record = TestRef { diff --git a/src/ondisk/sstable.rs b/src/ondisk/sstable.rs index 9d6c4145..dacf8de9 100644 --- a/src/ondisk/sstable.rs +++ b/src/ondisk/sstable.rs @@ -158,7 +158,7 @@ pub(crate) mod tests { async fn write_sstable() { let temp_dir = tempfile::tempdir().unwrap(); let record_batch = get_test_record_batch::( - Arc::new(DbOption::new(temp_dir.path())), + Arc::new(DbOption::from(temp_dir.path())), TokioExecutor::new(), ) .await; @@ -185,7 +185,7 @@ pub(crate) mod tests { async fn projection_query() { let temp_dir = tempfile::tempdir().unwrap(); let record_batch = get_test_record_batch::( - Arc::new(DbOption::new(temp_dir.path())), + Arc::new(DbOption::from(temp_dir.path())), TokioExecutor::new(), ) .await; @@ -212,9 +212,9 @@ pub(crate) mod tests { .await .unwrap() .unwrap(); - assert_eq!(test_ref_1.get().vstring, "hello"); - assert_eq!(test_ref_1.get().vu32, Some(12)); - assert_eq!(test_ref_1.get().vbool, None); + assert_eq!(test_ref_1.get().unwrap().vstring, "hello"); + assert_eq!(test_ref_1.get().unwrap().vu32, Some(12)); + assert_eq!(test_ref_1.get().unwrap().vbool, None); } { let test_ref_2 = open_sstable::(&table_path) @@ -229,9 +229,9 @@ pub(crate) mod tests { .await .unwrap() .unwrap(); - assert_eq!(test_ref_2.get().vstring, "hello"); - assert_eq!(test_ref_2.get().vu32, None); - assert_eq!(test_ref_2.get().vbool, Some(true)); + assert_eq!(test_ref_2.get().unwrap().vstring, "hello"); + assert_eq!(test_ref_2.get().unwrap().vu32, None); + assert_eq!(test_ref_2.get().unwrap().vbool, Some(true)); } { let test_ref_3 = open_sstable::(&table_path) @@ -246,9 +246,9 @@ pub(crate) mod tests { .await .unwrap() .unwrap(); - assert_eq!(test_ref_3.get().vstring, "hello"); - assert_eq!(test_ref_3.get().vu32, None); - assert_eq!(test_ref_3.get().vbool, None); + assert_eq!(test_ref_3.get().unwrap().vstring, "hello"); + assert_eq!(test_ref_3.get().unwrap().vu32, None); + assert_eq!(test_ref_3.get().unwrap().vbool, None); } } @@ -256,7 +256,7 @@ pub(crate) mod tests { async fn projection_scan() { let temp_dir = tempfile::tempdir().unwrap(); let record_batch = get_test_record_batch::( - Arc::new(DbOption::new(temp_dir.path())), + Arc::new(DbOption::from(temp_dir.path())), TokioExecutor::new(), ) .await; @@ -284,14 +284,14 @@ pub(crate) mod tests { .unwrap(); let entry_0 = test_ref_1.next().await.unwrap().unwrap(); - assert_eq!(entry_0.get().vstring, "hello"); - assert_eq!(entry_0.get().vu32, Some(12)); - assert_eq!(entry_0.get().vbool, None); + assert_eq!(entry_0.get().unwrap().vstring, "hello"); + assert_eq!(entry_0.get().unwrap().vu32, Some(12)); + assert_eq!(entry_0.get().unwrap().vbool, None); let entry_1 = test_ref_1.next().await.unwrap().unwrap(); - assert_eq!(entry_1.get().vstring, "world"); - assert_eq!(entry_1.get().vu32, Some(12)); - assert_eq!(entry_1.get().vbool, None); + assert_eq!(entry_1.get().unwrap().vstring, "world"); + assert_eq!(entry_1.get().unwrap().vu32, Some(12)); + assert_eq!(entry_1.get().unwrap().vbool, None); } { let mut test_ref_2 = open_sstable::(&table_path) @@ -309,14 +309,14 @@ pub(crate) mod tests { .unwrap(); let entry_0 = test_ref_2.next().await.unwrap().unwrap(); - assert_eq!(entry_0.get().vstring, "hello"); - assert_eq!(entry_0.get().vu32, None); - assert_eq!(entry_0.get().vbool, Some(true)); + assert_eq!(entry_0.get().unwrap().vstring, "hello"); + assert_eq!(entry_0.get().unwrap().vu32, None); + assert_eq!(entry_0.get().unwrap().vbool, Some(true)); let entry_1 = test_ref_2.next().await.unwrap().unwrap(); - assert_eq!(entry_1.get().vstring, "world"); - assert_eq!(entry_1.get().vu32, None); - assert_eq!(entry_1.get().vbool, None); + assert_eq!(entry_1.get().unwrap().vstring, "world"); + assert_eq!(entry_1.get().unwrap().vu32, None); + assert_eq!(entry_1.get().unwrap().vbool, None); } { let mut test_ref_3 = open_sstable::(&table_path) @@ -334,14 +334,14 @@ pub(crate) mod tests { .unwrap(); let entry_0 = test_ref_3.next().await.unwrap().unwrap(); - assert_eq!(entry_0.get().vstring, "hello"); - assert_eq!(entry_0.get().vu32, None); - assert_eq!(entry_0.get().vbool, None); + assert_eq!(entry_0.get().unwrap().vstring, "hello"); + assert_eq!(entry_0.get().unwrap().vu32, None); + assert_eq!(entry_0.get().unwrap().vbool, None); let entry_1 = test_ref_3.next().await.unwrap().unwrap(); - assert_eq!(entry_1.get().vstring, "world"); - assert_eq!(entry_1.get().vu32, None); - assert_eq!(entry_1.get().vbool, None); + assert_eq!(entry_1.get().unwrap().vstring, "world"); + assert_eq!(entry_1.get().unwrap().vu32, None); + assert_eq!(entry_1.get().unwrap().vbool, None); } } } diff --git a/src/option.rs b/src/option.rs new file mode 100644 index 00000000..5823b40d --- /dev/null +++ b/src/option.rs @@ -0,0 +1,132 @@ +use std::path::PathBuf; + +use parquet::file::properties::WriterProperties; + +use crate::{ + fs::{FileId, FileProvider, FileType}, + record::Record, + version::Version, +}; + +#[derive(Debug)] +pub struct DbOption { + pub(crate) path: PathBuf, + pub(crate) max_mem_table_size: usize, + pub(crate) immutable_chunk_num: usize, + pub(crate) major_threshold_with_sst_size: usize, + pub(crate) level_sst_magnification: usize, + pub(crate) max_sst_file_size: usize, + pub(crate) clean_channel_buffer: usize, + pub(crate) write_parquet_option: Option, + + pub(crate) use_wal: bool, +} + +impl

From

for DbOption +where + P: Into, +{ + fn from(path: P) -> Self { + DbOption { + path: path.into(), + max_mem_table_size: 8 * 1024 * 1024, + immutable_chunk_num: 3, + major_threshold_with_sst_size: 10, + level_sst_magnification: 10, + max_sst_file_size: 24 * 1024 * 1024, + clean_channel_buffer: 10, + write_parquet_option: None, + + use_wal: true, + } + } +} + +impl DbOption { + pub fn path(self, path: impl Into) -> Self { + DbOption { + path: path.into(), + ..self + } + } + + pub fn max_mem_table_size(self, max_mem_table_size: usize) -> Self { + DbOption { + max_mem_table_size, + ..self + } + } + + pub fn immutable_chunk_num(self, immutable_chunk_num: usize) -> Self { + DbOption { + immutable_chunk_num, + ..self + } + } + + pub fn major_threshold_with_sst_size(self, major_threshold_with_sst_size: usize) -> Self { + DbOption { + major_threshold_with_sst_size, + ..self + } + } + + pub fn level_sst_magnification(self, level_sst_magnification: usize) -> Self { + DbOption { + level_sst_magnification, + ..self + } + } + + pub fn max_sst_file_size(self, max_sst_file_size: usize) -> Self { + DbOption { + max_sst_file_size, + ..self + } + } + + pub fn clean_channel_buffer(self, clean_channel_buffer: usize) -> Self { + DbOption { + clean_channel_buffer, + ..self + } + } + + pub fn write_parquet_option(self, write_parquet_option: WriterProperties) -> Self { + DbOption { + write_parquet_option: Some(write_parquet_option), + ..self + } + } + + pub fn use_wal(self, use_wal: bool) -> Self { + DbOption { use_wal, ..self } + } +} + +impl DbOption { + pub(crate) fn table_path(&self, gen: &FileId) -> PathBuf { + self.path.join(format!("{}.{}", gen, FileType::Parquet)) + } + + pub(crate) fn wal_path(&self, gen: &FileId) -> PathBuf { + self.path.join(format!("{}.{}", gen, FileType::Wal)) + } + + pub(crate) fn version_path(&self) -> PathBuf { + self.path.join(format!("version.{}", FileType::Log)) + } + + pub(crate) fn is_threshold_exceeded_major( + &self, + version: &Version, + level: usize, + ) -> bool + where + R: Record, + E: FileProvider, + { + Version::::tables_len(version, level) + >= (self.major_threshold_with_sst_size * self.level_sst_magnification.pow(level as u32)) + } +} diff --git a/src/record/internal.rs b/src/record/internal.rs index 709ff2e2..a34fd860 100644 --- a/src/record/internal.rs +++ b/src/record/internal.rs @@ -1,6 +1,7 @@ use std::{marker::PhantomData, mem::transmute}; use super::{Key, Record, RecordRef}; +use crate::serdes::Encode; use crate::timestamp::{Timestamp, Timestamped}; #[derive(Debug)] @@ -37,7 +38,29 @@ where unsafe { transmute(Timestamped::new(self.record.key(), self.ts)) } } - pub(crate) fn get(&self) -> R { - self.record + pub(crate) fn get(&self) -> Option { + if self.null { + return None; + } + + Some(self.record) } } + +// impl<'r, R> Encode for InternalRecordRef<'r, R> +// where +// R: RecordRef<'r>, +// { +// type Error; + +// async fn encode(&self, writer: &mut W) -> Result<(), Self::Error> +// where +// W: futures_io::AsyncWrite + Unpin + Send + Sync, +// { +// todo!() +// } + +// fn size(&self) -> usize { +// todo!() +// } +// } diff --git a/src/record/mod.rs b/src/record/mod.rs index 53c326d9..14909854 100644 --- a/src/record/mod.rs +++ b/src/record/mod.rs @@ -15,7 +15,9 @@ use crate::{ serdes::{Decode, Encode}, }; -pub trait Key: 'static + Encode + Decode + Ord + Clone + Send + Hash + std::fmt::Debug { +pub trait Key: + 'static + Encode + Decode + Ord + Clone + Send + Sync + Hash + std::fmt::Debug +{ type Ref<'r>: KeyRef<'r, Key = Self> + Copy + Debug where Self: 'r; diff --git a/src/serdes/mod.rs b/src/serdes/mod.rs index beb9e6b0..ba0942b1 100644 --- a/src/serdes/mod.rs +++ b/src/serdes/mod.rs @@ -8,17 +8,17 @@ use std::{future::Future, io}; use futures_io::{AsyncRead, AsyncWrite}; -pub trait Encode: Send + Sync { +pub trait Encode { type Error: From + std::error::Error + Send + Sync + 'static; - fn encode(&self, writer: &mut W) -> impl Future> + Send + fn encode(&self, writer: &mut W) -> impl Future> where W: AsyncWrite + Unpin + Send + Sync; fn size(&self) -> usize; } -impl Encode for &T { +impl Encode for &T { type Error = T::Error; async fn encode(&self, writer: &mut W) -> Result<(), Self::Error> diff --git a/src/serdes/option.rs b/src/serdes/option.rs index 397e4397..ff3475ee 100644 --- a/src/serdes/option.rs +++ b/src/serdes/option.rs @@ -32,7 +32,7 @@ where impl Encode for Option where - V: Encode + Sync, + V: Encode, { type Error = EncodeError; diff --git a/src/stream/level.rs b/src/stream/level.rs index d0ae2ba3..531a07d5 100644 --- a/src/stream/level.rs +++ b/src/stream/level.rs @@ -163,7 +163,7 @@ mod tests { #[tokio::test] async fn projection_scan() { let temp_dir = TempDir::new().unwrap(); - let option = Arc::new(DbOption::new(temp_dir.path())); + let option = Arc::new(DbOption::from(temp_dir.path())); let (_, version) = build_version(&option).await; @@ -184,23 +184,23 @@ mod tests { .unwrap(); let entry_0 = level_stream_1.next().await.unwrap().unwrap(); - assert!(entry_0.get().vu32.is_some()); - assert!(entry_0.get().vbool.is_none()); + assert!(entry_0.get().unwrap().vu32.is_some()); + assert!(entry_0.get().unwrap().vbool.is_none()); let entry_1 = level_stream_1.next().await.unwrap().unwrap(); - assert!(entry_1.get().vu32.is_some()); - assert!(entry_1.get().vbool.is_none()); + assert!(entry_1.get().unwrap().vu32.is_some()); + assert!(entry_1.get().unwrap().vbool.is_none()); let entry_2 = level_stream_1.next().await.unwrap().unwrap(); - assert!(entry_2.get().vu32.is_some()); - assert!(entry_2.get().vbool.is_none()); + assert!(entry_2.get().unwrap().vu32.is_some()); + assert!(entry_2.get().unwrap().vbool.is_none()); let entry_3 = level_stream_1.next().await.unwrap().unwrap(); - assert!(entry_3.get().vu32.is_some()); - assert!(entry_3.get().vbool.is_none()); + assert!(entry_3.get().unwrap().vu32.is_some()); + assert!(entry_3.get().unwrap().vbool.is_none()); let entry_4 = level_stream_1.next().await.unwrap().unwrap(); - assert!(entry_4.get().vu32.is_some()); - assert!(entry_4.get().vbool.is_none()); + assert!(entry_4.get().unwrap().vu32.is_some()); + assert!(entry_4.get().unwrap().vbool.is_none()); let entry_5 = level_stream_1.next().await.unwrap().unwrap(); - assert!(entry_5.get().vu32.is_some()); - assert!(entry_5.get().vbool.is_none()); + assert!(entry_5.get().unwrap().vu32.is_some()); + assert!(entry_5.get().unwrap().vbool.is_none()); } { let mut level_stream_1 = LevelStream::new( @@ -219,23 +219,23 @@ mod tests { .unwrap(); let entry_0 = level_stream_1.next().await.unwrap().unwrap(); - assert!(entry_0.get().vu32.is_none()); - assert!(entry_0.get().vbool.is_some()); + assert!(entry_0.get().unwrap().vu32.is_none()); + assert!(entry_0.get().unwrap().vbool.is_some()); let entry_1 = level_stream_1.next().await.unwrap().unwrap(); - assert!(entry_1.get().vu32.is_none()); - assert!(entry_1.get().vbool.is_some()); + assert!(entry_1.get().unwrap().vu32.is_none()); + assert!(entry_1.get().unwrap().vbool.is_some()); let entry_2 = level_stream_1.next().await.unwrap().unwrap(); - assert!(entry_2.get().vu32.is_none()); - assert!(entry_2.get().vbool.is_some()); + assert!(entry_2.get().unwrap().vu32.is_none()); + assert!(entry_2.get().unwrap().vbool.is_some()); let entry_3 = level_stream_1.next().await.unwrap().unwrap(); - assert!(entry_3.get().vu32.is_none()); - assert!(entry_3.get().vbool.is_some()); + assert!(entry_3.get().unwrap().vu32.is_none()); + assert!(entry_3.get().unwrap().vbool.is_some()); let entry_4 = level_stream_1.next().await.unwrap().unwrap(); - assert!(entry_4.get().vu32.is_none()); - assert!(entry_4.get().vbool.is_some()); + assert!(entry_4.get().unwrap().vu32.is_none()); + assert!(entry_4.get().unwrap().vbool.is_some()); let entry_5 = level_stream_1.next().await.unwrap().unwrap(); - assert!(entry_5.get().vu32.is_none()); - assert!(entry_5.get().vbool.is_some()); + assert!(entry_5.get().unwrap().vu32.is_none()); + assert!(entry_5.get().unwrap().vbool.is_some()); } { let mut level_stream_1 = LevelStream::new( @@ -254,23 +254,23 @@ mod tests { .unwrap(); let entry_0 = level_stream_1.next().await.unwrap().unwrap(); - assert!(entry_0.get().vu32.is_none()); - assert!(entry_0.get().vbool.is_none()); + assert!(entry_0.get().unwrap().vu32.is_none()); + assert!(entry_0.get().unwrap().vbool.is_none()); let entry_1 = level_stream_1.next().await.unwrap().unwrap(); - assert!(entry_1.get().vu32.is_none()); - assert!(entry_1.get().vbool.is_none()); + assert!(entry_1.get().unwrap().vu32.is_none()); + assert!(entry_1.get().unwrap().vbool.is_none()); let entry_2 = level_stream_1.next().await.unwrap().unwrap(); - assert!(entry_2.get().vu32.is_none()); - assert!(entry_2.get().vbool.is_none()); + assert!(entry_2.get().unwrap().vu32.is_none()); + assert!(entry_2.get().unwrap().vbool.is_none()); let entry_3 = level_stream_1.next().await.unwrap().unwrap(); - assert!(entry_3.get().vu32.is_none()); - assert!(entry_3.get().vbool.is_none()); + assert!(entry_3.get().unwrap().vu32.is_none()); + assert!(entry_3.get().unwrap().vbool.is_none()); let entry_4 = level_stream_1.next().await.unwrap().unwrap(); - assert!(entry_4.get().vu32.is_none()); - assert!(entry_4.get().vbool.is_none()); + assert!(entry_4.get().unwrap().vu32.is_none()); + assert!(entry_4.get().unwrap().vbool.is_none()); let entry_5 = level_stream_1.next().await.unwrap().unwrap(); - assert!(entry_5.get().vu32.is_none()); - assert!(entry_5.get().vbool.is_none()); + assert!(entry_5.get().unwrap().vu32.is_none()); + assert!(entry_5.get().unwrap().vbool.is_none()); } } } diff --git a/src/stream/mod.rs b/src/stream/mod.rs index ed2b50a4..45e06746 100644 --- a/src/stream/mod.rs +++ b/src/stream/mod.rs @@ -55,10 +55,10 @@ where } } - pub(crate) fn value(&self) -> R::Ref<'_> { + pub(crate) fn value(&self) -> Option> { match self { - Entry::Transaction((_, value)) => value.as_ref().map(R::as_record_ref).unwrap(), - Entry::Mutable(entry) => entry.value().as_ref().map(R::as_record_ref).unwrap(), + Entry::Transaction((_, value)) => value.as_ref().map(R::as_record_ref), + Entry::Mutable(entry) => entry.value().as_ref().map(R::as_record_ref), Entry::SsTable(entry) => entry.get(), Entry::Immutable(entry) => entry.get(), Entry::Level(entry) => entry.get(), diff --git a/src/stream/record_batch.rs b/src/stream/record_batch.rs index fe0a176c..e9ba6741 100644 --- a/src/stream/record_batch.rs +++ b/src/stream/record_batch.rs @@ -42,7 +42,7 @@ where *self.record_ref.value().value() } - pub fn get(&self) -> R::Ref<'_> { + pub fn get(&self) -> Option> { // Safety: shorter lifetime of the key must be safe unsafe { transmute(self.record_ref.get()) } } diff --git a/src/timestamp/mod.rs b/src/timestamp/mod.rs index a4cfe039..dbb1981c 100644 --- a/src/timestamp/mod.rs +++ b/src/timestamp/mod.rs @@ -37,7 +37,7 @@ impl Timestamp { impl Encode for Timestamp { type Error = io::Error; - fn encode(&self, writer: &mut W) -> impl Future> + Send + fn encode(&self, writer: &mut W) -> impl Future> where W: AsyncWrite + Unpin + Send + Sync, { diff --git a/src/transaction.rs b/src/transaction.rs index 83114d05..1f82a58f 100644 --- a/src/transaction.rs +++ b/src/transaction.rs @@ -81,7 +81,13 @@ where .share .get(&self.version, key, self.ts, projection) .await? - .map(TransactionEntry::Stream), + .and_then(|entry| { + if entry.value().is_none() { + None + } else { + TransactionEntry::Stream(entry).into() + } + }), }) } @@ -153,12 +159,12 @@ impl<'entry, R> TransactionEntry<'entry, R> where R: Record, { - pub fn get(&self) -> R::Ref<'_> { + pub fn get(&self) -> Option> { match self { TransactionEntry::Stream(entry) => entry.value(), TransactionEntry::Local(value) => { // Safety: shorter lifetime must be safe - unsafe { transmute::, R::Ref<'_>>(*value) } + Some(unsafe { transmute::, R::Ref<'_>>(*value) }) } } } @@ -197,7 +203,7 @@ mod tests { let temp_dir = TempDir::new().unwrap(); let db = DB::::new( - Arc::new(DbOption::new(temp_dir.path())), + Arc::new(DbOption::from(temp_dir.path())), TokioExecutor::new(), ) .await @@ -233,7 +239,7 @@ mod tests { let temp_dir = TempDir::new().unwrap(); let db = DB::::new( - Arc::new(DbOption::new(temp_dir.path())), + Arc::new(DbOption::from(temp_dir.path())), TokioExecutor::new(), ) .await @@ -268,7 +274,7 @@ mod tests { let temp_dir = TempDir::new().unwrap(); let db = DB::::new( - Arc::new(DbOption::new(temp_dir.path())), + Arc::new(DbOption::from(temp_dir.path())), TokioExecutor::new(), ) .await @@ -284,9 +290,9 @@ mod tests { let key = 0.to_string(); let entry = txn1.get(&key, Projection::All).await.unwrap().unwrap(); - assert_eq!(entry.get().vstring, 0.to_string()); - assert_eq!(entry.get().vu32, Some(0)); - assert_eq!(entry.get().vbool, Some(true)); + assert_eq!(entry.get().unwrap().vstring, 0.to_string()); + assert_eq!(entry.get().unwrap().vu32, Some(0)); + assert_eq!(entry.get().unwrap().vbool, Some(true)); drop(entry); txn1.commit().await.unwrap(); @@ -295,7 +301,7 @@ mod tests { #[tokio::test] async fn transaction_scan() { let temp_dir = TempDir::new().unwrap(); - let option = Arc::new(DbOption::new(temp_dir.path())); + let option = Arc::new(DbOption::from(temp_dir.path())); let (_, version) = build_version(&option).await; let schema = build_schema().await; @@ -320,31 +326,31 @@ mod tests { let entry_0 = stream.next().await.unwrap().unwrap(); assert_eq!(entry_0.key().value, "1"); - assert!(entry_0.value().vbool.is_none()); + assert!(entry_0.value().unwrap().vbool.is_none()); let entry_1 = stream.next().await.unwrap().unwrap(); assert_eq!(entry_1.key().value, "2"); - assert!(entry_1.value().vbool.is_none()); + assert!(entry_1.value().unwrap().vbool.is_none()); let entry_2 = stream.next().await.unwrap().unwrap(); assert_eq!(entry_2.key().value, "3"); - assert!(entry_2.value().vbool.is_none()); + assert!(entry_2.value().unwrap().vbool.is_none()); let entry_3 = stream.next().await.unwrap().unwrap(); assert_eq!(entry_3.key().value, "4"); - assert!(entry_3.value().vbool.is_none()); + assert!(entry_3.value().unwrap().vbool.is_none()); let entry_4 = stream.next().await.unwrap().unwrap(); assert_eq!(entry_4.key().value, "5"); - assert!(entry_4.value().vbool.is_none()); + assert!(entry_4.value().unwrap().vbool.is_none()); let entry_5 = stream.next().await.unwrap().unwrap(); assert_eq!(entry_5.key().value, "6"); - assert!(entry_5.value().vbool.is_none()); + assert!(entry_5.value().unwrap().vbool.is_none()); let entry_6 = stream.next().await.unwrap().unwrap(); assert_eq!(entry_6.key().value, "7"); - assert!(entry_6.value().vbool.is_none()); + assert!(entry_6.value().unwrap().vbool.is_none()); let entry_7 = stream.next().await.unwrap().unwrap(); assert_eq!(entry_7.key().value, "8"); - assert!(entry_7.value().vbool.is_none()); + assert!(entry_7.value().unwrap().vbool.is_none()); let entry_8 = stream.next().await.unwrap().unwrap(); assert_eq!(entry_8.key().value, "9"); - assert!(entry_8.value().vbool.is_none()); + assert!(entry_8.value().unwrap().vbool.is_none()); let entry_9 = stream.next().await.unwrap().unwrap(); assert_eq!(entry_9.key().value, "alice"); let entry_10 = stream.next().await.unwrap().unwrap(); diff --git a/src/version/set.rs b/src/version/set.rs index 8a208c53..e8ecc6da 100644 --- a/src/version/set.rs +++ b/src/version/set.rs @@ -206,7 +206,7 @@ pub(crate) mod tests { async fn timestamp_persistence() { let temp_dir = TempDir::new().unwrap(); let (sender, _) = bounded(1); - let option = Arc::new(DbOption::new(temp_dir.path())); + let option = Arc::new(DbOption::from(temp_dir.path())); let version_set: VersionSet = VersionSet::new(sender.clone(), option.clone())