Further IO improvements (#2477)
fiedorowicz1 authored Oct 18, 2024
1 parent 1e5114c commit 04c297b
Showing 2 changed files with 89 additions and 59 deletions.
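In brief, from the diff below: cosmoflow_dataset.py drops the unused `Dataset` import and the per-sample `astype(np.float32)`/`np.ascontiguousarray` copies in `__getitem__`, besides quote-style cleanups. data.py reworks the `DataReader` IO path: sample/label/response sizes are computed once in `__init__`; shared-memory segments are created by the parent and recycled through a free list (`self.shms` / `self.returned_shms`) instead of being created and unlinked per sample by the workers; the batch buffers are cached across `get_batch` calls; worker-side handles are unregistered from the multiprocessing resource tracker so pool shutdown does not tear down parent-owned segments; and worker exceptions are re-raised after `cf.wait`.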
33 changes: 19 additions & 14 deletions applications/physics/cosmology/cosmoflow/cosmoflow_dataset.py
@@ -1,31 +1,36 @@
 import numpy as np
 from glob import glob
-from lbann.util.data import Sample, SampleDims, Dataset, DistConvDataset
+from lbann.util.data import Sample, SampleDims, DistConvDataset
 import h5py as h5
 import os
 
 
 class CosmoFlowDataset(DistConvDataset):
     def __init__(self, data_dir, input_width, num_secrets):
         self.data_dir = data_dir
         self.input_width = input_width
         self.num_secrets = num_secrets
-        self.samples = glob(os.path.join(data_dir, '*.hdf5'))
+        self.samples = glob(os.path.join(data_dir, "*.hdf5"))
         self.samples.sort()
 
     def __len__(self):
         return len(self.samples)
 
     def __getitem__(self, index) -> Sample:
-        data = h5.File(self.samples[index], 'r')
+        data = h5.File(self.samples[index], "r")
         slice_width = self.input_width // self.num_io_partitions
         slice_ind = self.rank % self.num_io_partitions
-        full = data['full'][:,
-                            slice_ind*slice_width:(slice_ind+1)*slice_width,
-                            :self.input_width,
-                            :self.input_width].astype(np.float32)
-        par = data['unitPar'][:].astype(np.float32)
-        return Sample(sample=np.ascontiguousarray(full), response=par)
+        full = data["full"][
+            :,
+            slice_ind * slice_width : (slice_ind + 1) * slice_width,
+            : self.input_width,
+            : self.input_width,
+        ]
+        par = data["unitPar"][:].astype(np.float32)
+        return Sample(sample=full, response=par)
 
     def get_sample_dims(self):
-        return SampleDims(sample=[4, self.input_width, self.input_width, self.input_width], response=self.num_secrets)
+        return SampleDims(
+            sample=[4, self.input_width, self.input_width, self.input_width],
+            response=self.num_secrets,
+        )
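Two things worth noting in this file: the per-`__getitem__` `astype(np.float32)` and `np.ascontiguousarray` copies are gone, presumably because the sample is now cast and laid out when the reader copies it into a typed shared-memory buffer (see data.py below), and the slicing arithmetic partitions each volume along one spatial axis per DistConv IO partition. A small sketch of that arithmetic with illustrative values (`input_width`, `num_io_partitions`, and `rank` stand in for the dataset attributes used above):

import numpy as np

# Illustrative values, not taken from the repository.
input_width = 8        # spatial extent consumed by the model
num_io_partitions = 4  # DistConv partitions splitting each sample's IO
rank = 6               # this process's rank

slice_width = input_width // num_io_partitions                   # 2
slice_ind = rank % num_io_partitions                             # 6 % 4 = 2
lo, hi = slice_ind * slice_width, (slice_ind + 1) * slice_width  # 4, 6

# Stand-in for the "full" HDF5 dataset: a 4-channel 16^3 cube.
volume = np.arange(4 * 16**3, dtype=np.float32).reshape(4, 16, 16, 16)
slab = volume[:, lo:hi, :input_width, :input_width]
print(slab.shape)  # (4, 2, 8, 8) -- one rank's slab of the sample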
115 changes: 70 additions & 45 deletions python/lbann/util/data.py
@@ -9,6 +9,7 @@
 from typing import Dict, List, Optional, Union
 from numpy.typing import ArrayLike
 import concurrent.futures as cf
+from multiprocessing import resource_tracker
 
 
 class Sample:
@@ -183,11 +184,14 @@ def __init__(
         self.dataset = dataset
         self.num_procs = num_procs
         self.prefetch_factor = prefetch_factor
-        self.dtype = dtype
+        self.dtype = np.dtype(dtype)
         self.sample_dims = dataset.get_sample_dims()
         self.num_io_partitions = 1
         self.loaded_samples = []
         self.thread_pool = cf.ThreadPoolExecutor(max_workers=num_procs)
+        self.shms = {}
+        self.returned_shms = []
+        self.batch = None
 
         if isinstance(self.dataset, DistConvDataset):
             self.num_io_partitions = self.dataset.num_io_partitions
@@ -198,6 +202,19 @@ def __init__(
                 initargs=(self.dataset,),
             )
 
+        self.shm_size = 0
+        if hasattr(self.sample_dims, "sample"):
+            self.sample_size = (
+                np.prod(self.sample_dims.sample) // self.num_io_partitions
+            )
+            self.shm_size += self.sample_size
+        if hasattr(self.sample_dims, "label"):
+            self.label_size = np.prod(self.sample_dims.sample)
+            self.shm_size += self.label_size
+        if hasattr(self.sample_dims, "response"):
+            self.response_size = self.sample_dims.response
+            self.shm_size += self.response_size
+
     @staticmethod
     def init_worker(dataset):
         """
@@ -225,8 +242,12 @@ def terminate(self) -> None:
         """
         self.pool.terminate()
 
+        for shm in self.shms.values():
+            shm.close()
+            shm.unlink()
+
     @staticmethod
-    def load_sample(ind) -> Sample:
+    def load_sample(ind, shm_name, shm_size, dtype) -> Sample:
         """
         Loads the sample from the dataset at the specified index.
         This function must be called from a worker process.
@@ -240,19 +261,10 @@ def load_sample(ind) -> Sample:
         """
         samp = g_dataset[ind]
 
-        shm_size = 0
-        dtype = None
-        if hasattr(samp, "sample"):
-            dtype = samp.sample.dtype
-            shm_size += samp.sample.size
-        if hasattr(samp, "label"):
-            dtype = samp.label.dtype
-            shm_size += samp.label.size
-        if hasattr(samp, "response"):
-            dtype = samp.response.dtype
-            shm_size += samp.response.size
-
-        shm = SharedMemory(create=True, size=shm_size * dtype.itemsize)
+        shm = SharedMemory(name=shm_name)
+        resource_tracker.unregister(
+            shm._name, "shared_memory"
+        )  # Prevent the resource tracker from interfering during process pool shutdown
         shm_arr = np.ndarray(shm_size, dtype=dtype, buffer=shm.buf)
 
         offset = 0
@@ -270,14 +282,23 @@
             offset = new_offset
 
         shm.close()
-        return shm.name, shm_size
+        return shm.name
 
     def load_next_sample_async(self, ind: int):
        """
        Submit the next sample index to be loaded to the worker pool.
        """
+        if not self.returned_shms:
+            shm = SharedMemory(create=True, size=self.shm_size * self.dtype.itemsize)
+            shm_name = shm.name
+            self.shms[shm_name] = shm
+        else:
+            shm_name = self.returned_shms.pop()
+
         self.loaded_samples.append(
-            self.pool.apply_async(DataReader.load_sample, (ind,))
+            self.pool.apply_async(
+                DataReader.load_sample, (ind, shm_name, self.shm_size, self.dtype)
+            )
         )
 
     def queue_samples(self, inds: List[int]) -> None:
@@ -301,44 +322,45 @@ def get_batch(self, batch_size: int) -> Dict[str, Union[np.ndarray, int]]:
         :rtype: Dict[str, Union[np.ndarray, int]]
         """
 
-        batch = {}
-        if hasattr(self.sample_dims, "sample"):
-            sample_size = np.prod(self.sample_dims.sample) // self.num_io_partitions
-            batch["sample"] = np.empty([batch_size, sample_size], dtype=self.dtype)
-            batch["sample_ptr"] = batch["sample"].ctypes.data
-        if hasattr(self.sample_dims, "label"):
-            label_size = np.prod(self.sample_dims.sample)
-            batch["label"] = np.empty([batch_size, label_size], dtype=self.dtype)
-            batch["label_ptr"] = batch["label"].ctypes.data
-        if hasattr(self.sample_dims, "response"):
-            response_size = self.sample_dims.response
-            batch["response"] = np.empty([batch_size, response_size], dtype=self.dtype)
-            batch["response_ptr"] = batch["response"].ctypes.data
+        if self.batch is None:
+            batch = {}
+            if hasattr(self.sample_dims, "sample"):
+                batch["sample"] = np.empty(
+                    [batch_size, self.sample_size], dtype=self.dtype
+                )
+                batch["sample_ptr"] = batch["sample"].ctypes.data
+            if hasattr(self.sample_dims, "label"):
+                batch["label"] = np.empty(
+                    [batch_size, self.label_size], dtype=self.dtype
+                )
+                batch["label_ptr"] = batch["label"].ctypes.data
+            if hasattr(self.sample_dims, "response"):
+                batch["response"] = np.empty(
+                    [batch_size, self.response_size], dtype=self.dtype
+                )
+                batch["response_ptr"] = batch["response"].ctypes.data
+            self.batch = batch
 
         def copy_to_array(i, sample):
-            shm_name, shm_size = sample.get()
-
-            shm = SharedMemory(name=shm_name)
-            shm_arr = np.ndarray(shm_size, dtype=self.dtype, buffer=shm.buf)
+            shm_name = sample.get()
+            shm = self.shms[shm_name]
+            shm_arr = np.ndarray(self.shm_size, dtype=self.dtype, buffer=shm.buf)
 
             offset = 0
             if hasattr(self.sample_dims, "sample"):
-                new_offset = offset + sample_size
-                batch["sample"][i, :] = shm_arr[offset:new_offset]
+                new_offset = offset + self.sample_size
+                self.batch["sample"][i, :] = shm_arr[offset:new_offset]
                 offset = new_offset
             if hasattr(self.sample_dims, "label"):
-                new_offset = offset + label_size
-                batch["label"][i, :] = shm_arr[offset:new_offset]
+                new_offset = offset + self.label_size
+                self.batch["label"][i, :] = shm_arr[offset:new_offset]
                 offset = new_offset
             if hasattr(self.sample_dims, "response"):
-                new_offset = offset + response_size
-                batch["response"][i, :] = shm_arr[offset:new_offset]
+                new_offset = offset + self.response_size
+                self.batch["response"][i, :] = shm_arr[offset:new_offset]
                 offset = new_offset
 
-            del shm_arr
-
-            shm.close()
-            shm.unlink()
+            self.returned_shms.append(shm_name)
 
         futures = []
for i in range(batch_size):
Expand All @@ -347,8 +369,11 @@ def copy_to_array(i, sample):
)

cf.wait(futures)
# Check for any exceptions
for f in futures:
f.result()

return batch
return self.batch


def construct_python_dataset_reader(
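Taken together, the data.py changes replace a create-fill-unlink shared-memory cycle per sample with a pool of parent-owned segments that are recycled between batches, and the `resource_tracker.unregister` call keeps worker-side resource trackers from destroying segments the parent still owns. Below is a minimal, self-contained sketch of that recycling pattern; names such as `fill_sample`, `SAMPLE_SIZE`, and the free list are illustrative stand-ins, not LBANN API:

import numpy as np
from multiprocessing import Pool, resource_tracker
from multiprocessing.shared_memory import SharedMemory

SAMPLE_SIZE = 1024  # elements per sample; plays the role of shm_size above
DTYPE = np.dtype(np.float32)

def fill_sample(ind, shm_name):
    # Worker: attach to a block the parent owns and fill it in place.
    shm = SharedMemory(name=shm_name)
    # The parent created this segment, so keep this process's resource
    # tracker from unlinking it when the pool shuts down (as in the diff).
    resource_tracker.unregister(shm._name, "shared_memory")
    arr = np.ndarray(SAMPLE_SIZE, dtype=DTYPE, buffer=shm.buf)
    arr[:] = ind  # pretend this is an expensive load
    shm.close()
    return shm_name

if __name__ == "__main__":
    shms, free = {}, []
    batch = np.empty((4, SAMPLE_SIZE), dtype=DTYPE)
    with Pool(2) as pool:
        results = []
        for ind in range(4):
            if free:
                name = free.pop()
            else:  # grow the pool only until steady state is reached
                shm = SharedMemory(create=True, size=SAMPLE_SIZE * DTYPE.itemsize)
                shms[shm.name] = shm
                name = shm.name
            results.append(pool.apply_async(fill_sample, (ind, name)))
        for i, res in enumerate(results):
            name = res.get()
            batch[i, :] = np.ndarray(SAMPLE_SIZE, dtype=DTYPE, buffer=shms[name].buf)
            free.append(name)  # recycle instead of unlink
    for shm in shms.values():
        shm.close()
        shm.unlink()
    print(batch[:, 0])  # [0. 1. 2. 3.]

The diff has the same shape: `load_next_sample_async` grows `self.shms` only while the free list is empty, `load_sample` fills a segment it did not create, `copy_to_array` returns the segment's name to `returned_shms` instead of unlinking it, and `terminate` finally closes and unlinks every pooled segment.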
