From 04c297bddf76144385be8ce39cab9394b47401a3 Mon Sep 17 00:00:00 2001
From: Pier Fiedorowicz <117680821+fiedorowicz1@users.noreply.github.com>
Date: Thu, 17 Oct 2024 17:09:45 -0700
Subject: [PATCH] Further IO improvements (#2477)

---
 .../cosmology/cosmoflow/cosmoflow_dataset.py  |  33 ++---
 python/lbann/util/data.py                     | 115 +++++++++++-------
 2 files changed, 89 insertions(+), 59 deletions(-)

diff --git a/applications/physics/cosmology/cosmoflow/cosmoflow_dataset.py b/applications/physics/cosmology/cosmoflow/cosmoflow_dataset.py
index 5c909465826..80342ab5007 100644
--- a/applications/physics/cosmology/cosmoflow/cosmoflow_dataset.py
+++ b/applications/physics/cosmology/cosmoflow/cosmoflow_dataset.py
@@ -1,31 +1,36 @@
 import numpy as np
 from glob import glob
-from lbann.util.data import Sample, SampleDims, Dataset, DistConvDataset
+from lbann.util.data import Sample, SampleDims, DistConvDataset
 import h5py as h5
 import os
-
+
 
 class CosmoFlowDataset(DistConvDataset):
     def __init__(self, data_dir, input_width, num_secrets):
         self.data_dir = data_dir
         self.input_width = input_width
         self.num_secrets = num_secrets
-        self.samples = glob(os.path.join(data_dir, '*.hdf5'))
+        self.samples = glob(os.path.join(data_dir, "*.hdf5"))
         self.samples.sort()
-
+
     def __len__(self):
         return len(self.samples)
-
+
     def __getitem__(self, index) -> Sample:
-        data = h5.File(self.samples[index], 'r')
+        data = h5.File(self.samples[index], "r")
         slice_width = self.input_width // self.num_io_partitions
         slice_ind = self.rank % self.num_io_partitions
-        full = data['full'][:,
-                            slice_ind*slice_width:(slice_ind+1)*slice_width,
-                            :self.input_width,
-                            :self.input_width].astype(np.float32)
-        par = data['unitPar'][:].astype(np.float32)
-        return Sample(sample=np.ascontiguousarray(full), response=par)
-
+        full = data["full"][
+            :,
+            slice_ind * slice_width : (slice_ind + 1) * slice_width,
+            : self.input_width,
+            : self.input_width,
+        ]
+        par = data["unitPar"][:].astype(np.float32)
+        return Sample(sample=full, response=par)
+
     def get_sample_dims(self):
-        return SampleDims(sample=[4, self.input_width, self.input_width, self.input_width], response=self.num_secrets)
+        return SampleDims(
+            sample=[4, self.input_width, self.input_width, self.input_width],
+            response=self.num_secrets,
+        )
diff --git a/python/lbann/util/data.py b/python/lbann/util/data.py
index 7c6ef7030d3..9a9b43f607d 100644
--- a/python/lbann/util/data.py
+++ b/python/lbann/util/data.py
@@ -9,6 +9,7 @@
 from typing import Dict, List, Optional, Union
 from numpy.typing import ArrayLike
 import concurrent.futures as cf
+from multiprocessing import resource_tracker
 
 
 class Sample:
@@ -183,11 +184,14 @@ def __init__(
         self.dataset = dataset
         self.num_procs = num_procs
         self.prefetch_factor = prefetch_factor
-        self.dtype = dtype
+        self.dtype = np.dtype(dtype)
         self.sample_dims = dataset.get_sample_dims()
         self.num_io_partitions = 1
         self.loaded_samples = []
         self.thread_pool = cf.ThreadPoolExecutor(max_workers=num_procs)
+        self.shms = {}
+        self.returned_shms = []
+        self.batch = None
 
         if isinstance(self.dataset, DistConvDataset):
             self.num_io_partitions = self.dataset.num_io_partitions
@@ -198,6 +202,19 @@ def __init__(
                 initargs=(self.dataset,),
             )
 
+        self.shm_size = 0
+        if hasattr(self.sample_dims, "sample"):
+            self.sample_size = (
+                np.prod(self.sample_dims.sample) // self.num_io_partitions
+            )
+            self.shm_size += self.sample_size
+        if hasattr(self.sample_dims, "label"):
+            self.label_size = np.prod(self.sample_dims.sample)
+            self.shm_size += self.label_size
+        if hasattr(self.sample_dims, "response"):
+            self.response_size = self.sample_dims.response
+            self.shm_size += self.response_size
+
     @staticmethod
     def init_worker(dataset):
         """
@@ -225,8 +242,12 @@ def terminate(self) -> None:
         """
         self.pool.terminate()
 
+        for shm in self.shms.values():
+            shm.close()
+            shm.unlink()
+
     @staticmethod
-    def load_sample(ind) -> Sample:
+    def load_sample(ind, shm_name, shm_size, dtype) -> Sample:
         """
         Loads the sample from the dataset at the specified index. This function
         must be called from a worker process.
@@ -240,19 +261,10 @@ def load_sample(ind) -> Sample:
         """
         samp = g_dataset[ind]
 
-        shm_size = 0
-        dtype = None
-        if hasattr(samp, "sample"):
-            dtype = samp.sample.dtype
-            shm_size += samp.sample.size
-        if hasattr(samp, "label"):
-            dtype = samp.label.dtype
-            shm_size += samp.label.size
-        if hasattr(samp, "response"):
-            dtype = samp.response.dtype
-            shm_size += samp.response.size
-
-        shm = SharedMemory(create=True, size=shm_size * dtype.itemsize)
+        shm = SharedMemory(name=shm_name)
+        resource_tracker.unregister(
+            shm._name, "shared_memory"
+        )  # Prevent the resource tracker from interfering during process pool shutdown
         shm_arr = np.ndarray(shm_size, dtype=dtype, buffer=shm.buf)
 
         offset = 0
@@ -270,14 +282,23 @@ def load_sample(ind) -> Sample:
             offset = new_offset
 
         shm.close()
-        return shm.name, shm_size
+        return shm.name
 
     def load_next_sample_async(self, ind: int):
         """
         Submit the next sample index to be loaded to the worker pool.
         """
+        if not self.returned_shms:
+            shm = SharedMemory(create=True, size=self.shm_size * self.dtype.itemsize)
+            shm_name = shm.name
+            self.shms[shm_name] = shm
+        else:
+            shm_name = self.returned_shms.pop()
+
         self.loaded_samples.append(
-            self.pool.apply_async(DataReader.load_sample, (ind,))
+            self.pool.apply_async(
+                DataReader.load_sample, (ind, shm_name, self.shm_size, self.dtype)
+            )
         )
 
     def queue_samples(self, inds: List[int]) -> None:
@@ -301,44 +322,45 @@ def get_batch(self, batch_size: int) -> Dict[str, Union[np.ndarray, int]]:
         :rtype: Dict[str, Union[np.ndarray, int]]
         """
 
-        batch = {}
-        if hasattr(self.sample_dims, "sample"):
-            sample_size = np.prod(self.sample_dims.sample) // self.num_io_partitions
-            batch["sample"] = np.empty([batch_size, sample_size], dtype=self.dtype)
-            batch["sample_ptr"] = batch["sample"].ctypes.data
-        if hasattr(self.sample_dims, "label"):
-            label_size = np.prod(self.sample_dims.sample)
-            batch["label"] = np.empty([batch_size, label_size], dtype=self.dtype)
-            batch["label_ptr"] = batch["label"].ctypes.data
-        if hasattr(self.sample_dims, "response"):
-            response_size = self.sample_dims.response
-            batch["response"] = np.empty([batch_size, response_size], dtype=self.dtype)
-            batch["response_ptr"] = batch["response"].ctypes.data
+        if self.batch is None:
+            batch = {}
+            if hasattr(self.sample_dims, "sample"):
+                batch["sample"] = np.empty(
+                    [batch_size, self.sample_size], dtype=self.dtype
+                )
+                batch["sample_ptr"] = batch["sample"].ctypes.data
+            if hasattr(self.sample_dims, "label"):
+                batch["label"] = np.empty(
+                    [batch_size, self.label_size], dtype=self.dtype
+                )
+                batch["label_ptr"] = batch["label"].ctypes.data
+            if hasattr(self.sample_dims, "response"):
+                batch["response"] = np.empty(
+                    [batch_size, self.response_size], dtype=self.dtype
+                )
+                batch["response_ptr"] = batch["response"].ctypes.data
+            self.batch = batch
 
         def copy_to_array(i, sample):
-            shm_name, shm_size = sample.get()
-
-            shm = SharedMemory(name=shm_name)
-            shm_arr = np.ndarray(shm_size, dtype=self.dtype, buffer=shm.buf)
+            shm_name = sample.get()
+            shm = self.shms[shm_name]
+            shm_arr = np.ndarray(self.shm_size, dtype=self.dtype, buffer=shm.buf)
 
             offset = 0
             if hasattr(self.sample_dims, "sample"):
-                new_offset = offset + sample_size
-                batch["sample"][i, :] = shm_arr[offset:new_offset]
+                new_offset = offset + self.sample_size
+                self.batch["sample"][i, :] = shm_arr[offset:new_offset]
                 offset = new_offset
             if hasattr(self.sample_dims, "label"):
-                new_offset = offset + label_size
-                batch["label"][i, :] = shm_arr[offset:new_offset]
+                new_offset = offset + self.label_size
+                self.batch["label"][i, :] = shm_arr[offset:new_offset]
                 offset = new_offset
             if hasattr(self.sample_dims, "response"):
-                new_offset = offset + response_size
-                batch["response"][i, :] = shm_arr[offset:new_offset]
+                new_offset = offset + self.response_size
+                self.batch["response"][i, :] = shm_arr[offset:new_offset]
                 offset = new_offset
 
-            del shm_arr
-
-            shm.close()
-            shm.unlink()
+            self.returned_shms.append(shm_name)
 
         futures = []
         for i in range(batch_size):
@@ -347,7 +369,10 @@ def copy_to_array(i, sample):
             )
         cf.wait(futures)
+        # Check for any exceptions
+        for f in futures:
+            f.result()
 
-        return batch
+        return self.batch
 
 
 def construct_python_dataset_reader(
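
Illustrative usage (not part of the patch): a minimal sketch of how the reworked
shared-memory loading path above might be driven directly, assuming DataReader is the
class defined in python/lbann/util/data.py and that its constructor accepts the dataset
plus num_procs, prefetch_factor, and dtype, as suggested by the attributes assigned in
__init__; the exact signature, and the full signature of construct_python_dataset_reader
(truncated above), are not shown in this patch. The data directory, sizes, and the
rank/num_io_partitions values are hypothetical; the latter are normally configured by
LBANN for DistConv runs.

    import numpy as np
    from lbann.util.data import DataReader
    from cosmoflow_dataset import CosmoFlowDataset

    if __name__ == "__main__":
        # Hypothetical path and sizes, for illustration only.
        dataset = CosmoFlowDataset("/path/to/hdf5", input_width=128, num_secrets=4)
        dataset.rank = 0              # normally set by LBANN for DistConv runs
        dataset.num_io_partitions = 1

        # Constructor argument names/order are an assumption (see lead-in).
        reader = DataReader(dataset, num_procs=4, prefetch_factor=8, dtype=np.float32)

        # Each queued index is loaded by a worker process into a reusable
        # SharedMemory block sized from get_sample_dims() (shm_size in __init__).
        reader.queue_samples(list(range(16)))

        # get_batch() copies worker output into a persistent batch buffer and
        # recycles the shared-memory segments through returned_shms.
        batch = reader.get_batch(16)
        print(batch["sample"].shape, batch["response"].shape)

        # Close and unlink all shared-memory segments and stop the worker pool.
        reader.terminate()

One design point worth noting from the diff: shared-memory segments are now allocated
once by the parent and reused across samples, instead of being created and unlinked per
sample by the workers, which is why the worker-side resource_tracker.unregister call and
the parent-side cleanup in terminate() were added.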