Skip to content

Commit

Permalink
fix(storage): don't drop mmap when writer is dropped
Browse files Browse the repository at this point in the history
  • Loading branch information
dan-starkware committed Oct 18, 2023
1 parent ca20b6c commit 0a886a0
Show file tree
Hide file tree
Showing 2 changed files with 84 additions and 55 deletions.
50 changes: 36 additions & 14 deletions crates/papyrus_storage/src/mmap_file/mmap_file_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,13 +38,9 @@ fn write_read() {

let len = writer.insert(offset, &data);
let res_writer = writer.get(LocationInFile { offset, len }).unwrap();
assert_eq!(res_writer, data);
assert_eq!(res_writer.unwrap(), data);

let another_reader = reader;
let res: Vec<u8> = reader.get(LocationInFile { offset, len }).unwrap();
assert_eq!(res, data);

let res: Vec<u8> = another_reader.get(LocationInFile { offset, len }).unwrap();
let res: Vec<u8> = reader.get(LocationInFile { offset, len }).unwrap().unwrap();
assert_eq!(res, data);

dir.close().unwrap();
Expand All @@ -66,12 +62,13 @@ fn concurrent_reads() {
let mut handles = vec![];

for _ in 0..num_threads {
let reader = reader.clone();
let handle = std::thread::spawn(move || reader.get(location_in_file).unwrap());
handles.push(handle);
}

for handle in handles {
let res: Vec<u8> = handle.join().unwrap();
let res: Vec<u8> = handle.join().unwrap().unwrap();
assert_eq!(res, data);
}

Expand Down Expand Up @@ -99,13 +96,11 @@ fn concurrent_reads_single_write() {
let mut handles = Vec::with_capacity(n);

for _ in 0..n {
let reader = reader.clone();
let reader_barrier = barrier.clone();
let first_data = first_data.clone();
handles.push(std::thread::spawn(move || {
assert_eq!(
<FileReader as Reader<Vec<u8>>>::get(&reader, first_location).unwrap(),
first_data
);
assert_eq!(reader.get(first_location).unwrap().unwrap(), first_data);
reader_barrier.wait();
// readers wait for the writer to write the value.
reader_barrier.wait();
Expand All @@ -120,7 +115,7 @@ fn concurrent_reads_single_write() {
barrier.wait();

for handle in handles {
let res: Vec<u8> = handle.join().unwrap();
let res: Vec<u8> = handle.join().unwrap().unwrap();
assert_eq!(res, second_data);
}
}
Expand Down Expand Up @@ -196,15 +191,19 @@ async fn write_read_different_locations() {
let barrier = Arc::new(Barrier::new(n_readers_per_phase + 1));
let lock = Arc::new(RwLock::new(0));

async fn reader_task(reader: FileReader, lock: Arc<RwLock<usize>>, barrier: Arc<Barrier>) {
async fn reader_task(
reader: FileReader<Vec<u8>>,
lock: Arc<RwLock<usize>>,
barrier: Arc<Barrier>,
) {
barrier.wait().await;
let round: usize;
{
round = *lock.read().await;
}
let read_offset = 3 * rand::thread_rng().gen_range(0..round + 1);
let read_location = LocationInFile { offset: read_offset, len: LEN };
let read_value: Vec<u8> = reader.get(read_location).unwrap();
let read_value: Vec<u8> = reader.get(read_location).unwrap().unwrap();
let first_expected_value: u8 = (read_offset / 3 * 2).try_into().unwrap();
let expected_value = vec![first_expected_value, first_expected_value + 1];
assert_eq!(read_value, expected_value);
Expand All @@ -213,6 +212,7 @@ async fn write_read_different_locations() {
let mut handles = Vec::new();
for round in 0..ROUNDS {
for _ in 0..n_readers_per_phase {
let reader = reader.clone();
handles.push(tokio::spawn(reader_task(reader, lock.clone(), barrier.clone())));
}

Expand All @@ -230,3 +230,25 @@ async fn write_read_different_locations() {
handle.await.unwrap();
}
}

/// Checks that a reader can still fetch previously written data after the writer is dropped.
#[test]
fn reader_when_writer_is_out_of_scope() {
    let dir = tempdir().unwrap();
    let file_path = dir.path().to_path_buf().join("test_reader_when_writer_is_out_of_scope");
    let (mut writer, reader) = open_file(get_test_config(), file_path).unwrap();

    let offset = 0;
    let data: Vec<u8> = vec![1, 2, 3];
    let len = writer.insert(offset, &data);

    // Reads the inserted bytes back through the reader, building a fresh location on each call.
    let read_back = || -> Vec<u8> {
        reader.get(LocationInFile { offset, len }).unwrap().unwrap()
    };

    // Baseline: the value is readable while the writer is still alive.
    assert_eq!(read_back(), data);

    // The same read must keep working once the writer has been dropped.
    drop(writer);
    assert_eq!(read_back(), data);

    dir.close().unwrap();
}
89 changes: 48 additions & 41 deletions crates/papyrus_storage/src/mmap_file/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ use std::fs::{File, OpenOptions};
use std::marker::PhantomData;
use std::path::PathBuf;
use std::result;
use std::sync::{Arc, Mutex};

use memmap2::{MmapMut, MmapOptions};
use serde::{Deserialize, Serialize};
Expand Down Expand Up @@ -60,8 +61,9 @@ fn validate_config(config: &MmapFileConfig) -> result::Result<(), ValidationErro
#[derive(Debug, Error)]
pub enum MMapFileError {
#[error(transparent)]
/// IO error.
IO(#[from] std::io::Error),
#[error(transparent)]
/// Integer conversion error (wraps a `std::num::TryFromIntError`).
TryFromInt(#[from] std::num::TryFromIntError),
}

/// A trait for writing to a memory mapped file.
Expand All @@ -73,7 +75,7 @@ pub trait Writer<V: StorageSerde> {
/// A trait for reading from a memory mapped file.
pub trait Reader<V: StorageSerde> {
/// Returns an object from the file.
fn get(&self, location: LocationInFile) -> Option<V>;
fn get(&self, location: LocationInFile) -> MmapFileResult<Option<V>>;
}

/// Represents a location in the file.
Expand All @@ -95,22 +97,25 @@ impl LocationInFile {

/// A wrapper around `MMapFile` that provides a write interface.
pub struct FileWriter<V: StorageSerde> {
mmap_file: MMapFile<V>,
memory_ptr: *const u8,
mmap_file: Arc<Mutex<MMapFile<V>>>,
}
impl<V: StorageSerde> FileWriter<V> {
/// Flushes the mmap to the file.
#[allow(dead_code)]
pub(crate) fn flush(&self) {
self.mmap_file.flush();
let mmap_file = self.mmap_file.lock().expect("Lock should not be poisoned");
mmap_file.flush();
}

fn grow_file_if_needed(&mut self, offset: usize) {
if self.mmap_file.size < offset + self.mmap_file.config.max_object_size {
let mut mmap_file = self.mmap_file.lock().expect("Lock should not be poisoned");
if mmap_file.size < offset + mmap_file.config.max_object_size {
debug!(
"Attempting to grow file. File size: {}, offset: {}, max_object_size: {}",
self.mmap_file.size, offset, self.mmap_file.config.max_object_size
mmap_file.size, offset, mmap_file.config.max_object_size
);
self.mmap_file.grow();
mmap_file.grow();
}
}
}
Expand All @@ -120,51 +125,63 @@ impl<V: StorageSerde + Debug> Writer<V> for FileWriter<V> {
fn insert(&mut self, offset: usize, val: &V) -> usize {
debug!("Inserting object at offset: {}", offset);
trace!("Inserting object: {:?}", val);
let mut mmap_slice = &mut self.mmap_file.mmap[offset..];
// TODO(dan): change serialize_into to return serialization size.
let _ = val.serialize_into(&mut mmap_slice);
let len = val.serialize().expect("Should be able to serialize").len();
self.mmap_file
.mmap
.flush_async_range(offset, len)
.expect("Failed to asynchronously flush the mmap after inserting");
{
let mut mmap_file = self.mmap_file.lock().expect("Lock should not be poisoned");
let mut mmap_slice = &mut mmap_file.mmap[offset..];
let _ = val.serialize_into(&mut mmap_slice);
mmap_file
.mmap
.flush_async_range(offset, len)
.expect("Failed to asynchronously flush the mmap after inserting");
}
self.grow_file_if_needed(offset + len);
len
}
}

impl<V: StorageSerde> Reader<V> for FileWriter<V> {
/// Returns an object from the file.
fn get(&self, location: LocationInFile) -> Option<V> {
self.mmap_file.get(location)
fn get(&self, location: LocationInFile) -> MmapFileResult<Option<V>> {
debug!("Reading object at location: {:?}", location);
let mut bytes = unsafe {
std::slice::from_raw_parts(
self.memory_ptr.offset(location.offset.try_into()?),
location.len,
)
};
trace!("Deserializing object: {:?}", bytes);
Ok(V::deserialize(&mut bytes))
}
}

/// A wrapper around `MMapFile` that provides a read interface.
#[derive(Clone, Copy, Debug)]
pub struct FileReader {
shared_data: *const u8,
#[derive(Clone, Debug)]
pub struct FileReader<V: StorageSerde> {
memory_ptr: *const u8,
_mmap_file: Arc<Mutex<MMapFile<V>>>,
}
unsafe impl Send for FileReader {}
unsafe impl Sync for FileReader {}
unsafe impl<V: StorageSerde> Send for FileReader<V> {}
unsafe impl<V: StorageSerde> Sync for FileReader<V> {}

impl<V: StorageSerde> Reader<V> for FileReader {
impl<V: StorageSerde> Reader<V> for FileReader<V> {
/// Returns an object from the file.
fn get(&self, location: LocationInFile) -> Option<V> {
fn get(&self, location: LocationInFile) -> MmapFileResult<Option<V>> {
debug!("Reading object at location: {:?}", location);
let mut bytes = unsafe {
std::slice::from_raw_parts(
self.shared_data
.offset(location.offset.try_into().expect("offset should fit in usize")),
self.memory_ptr.offset(location.offset.try_into()?),
location.len,
)
};
trace!("Deserializing object: {:?}", bytes);
V::deserialize(&mut bytes)
Ok(V::deserialize(&mut bytes))
}
}

/// Represents a memory mapped append only file.
#[derive(Debug)]
pub struct MMapFile<V: StorageSerde> {
config: MmapFileConfig,
file: File,
Expand All @@ -174,19 +191,7 @@ pub struct MMapFile<V: StorageSerde> {
}

impl<V: StorageSerde> MMapFile<V> {
/// Returns an object from the file.
fn get(&self, location: LocationInFile) -> Option<V> {
debug!("Reading object at location: {:?}", location);
let bytes: std::borrow::Cow<'_, [u8]> = self.get_raw(location);
trace!("Deserializing object: {:?}", bytes.as_ref());
V::deserialize(&mut bytes.as_ref())
}

/// Returns a COW pointer to a slice of the file.
fn get_raw(&self, location: LocationInFile) -> std::borrow::Cow<'_, [u8]> {
std::borrow::Cow::from(&self.mmap[location.offset..(location.offset + location.len)])
}

/// Grows the file by the growth step.
fn grow(&mut self) {
self.flush();
let new_size = self.size + self.config.growth_step;
Expand All @@ -207,22 +212,24 @@ impl<V: StorageSerde> MMapFile<V> {
pub(crate) fn open_file<V: StorageSerde>(
config: MmapFileConfig,
path: PathBuf,
) -> MmapFileResult<(FileWriter<V>, FileReader)> {
) -> MmapFileResult<(FileWriter<V>, FileReader<V>)> {
debug!("Opening file");
// TODO: move validation to caller.
config.validate().expect("Invalid config");
let file = OpenOptions::new().read(true).write(true).create(true).open(path)?;
let size = file.metadata()?.len();
let mmap = unsafe { MmapOptions::new().len(config.max_size).map_mut(&file)? };
let mmap_ptr = mmap.as_ptr();
let mmap_file = MMapFile {
config,
file,
mmap,
size: size.try_into().expect("size should fit in usize"),
_value_type: PhantomData {},
};
let reader = FileReader { shared_data: mmap_file.mmap.as_ptr() };
let mut writer = FileWriter { mmap_file };
let shared_mmap_file = Arc::new(Mutex::new(mmap_file));
let reader = FileReader { memory_ptr: mmap_ptr, _mmap_file: shared_mmap_file.clone() };
let mut writer = FileWriter { memory_ptr: mmap_ptr, mmap_file: shared_mmap_file };
writer.grow_file_if_needed(0);
Ok((writer, reader))
}

0 comments on commit 0a886a0

Please sign in to comment.