From 01f40b420d624a38bdfca28dee5d35abdef279c4 Mon Sep 17 00:00:00 2001
From: Justin Kerr
Date: Thu, 16 Jan 2025 13:30:00 -0800
Subject: [PATCH] remove side effects from __post_init__, further tune
 num_workers defaults, add a random offset in RayBatchStream

---
 nerfstudio/configs/method_configs.py            |  4 ---
 .../data/datamanagers/base_datamanager.py       | 34 ++++---------------
 .../datamanagers/full_images_datamanager.py     | 24 ++++---------
 .../data/datamanagers/parallel_datamanager.py   |  4 +++
 nerfstudio/data/utils/dataloaders.py            |  8 +++--
 5 files changed, 24 insertions(+), 50 deletions(-)

diff --git a/nerfstudio/configs/method_configs.py b/nerfstudio/configs/method_configs.py
index eb040f0d07..52afc1c21a 100644
--- a/nerfstudio/configs/method_configs.py
+++ b/nerfstudio/configs/method_configs.py
@@ -97,7 +97,6 @@
             dataparser=NerfstudioDataParserConfig(),
             train_num_rays_per_batch=4096,
             eval_num_rays_per_batch=4096,
-            use_parallel_dataloader=True,
         ),
         model=NerfactoModelConfig(
             eval_num_rays_per_chunk=1 << 15,
@@ -135,7 +134,6 @@
             dataparser=NerfstudioDataParserConfig(),
             train_num_rays_per_batch=8192,
             eval_num_rays_per_batch=4096,
-            use_parallel_dataloader=True,
         ),
         model=NerfactoModelConfig(
             eval_num_rays_per_chunk=1 << 15,
@@ -181,7 +179,6 @@
             dataparser=NerfstudioDataParserConfig(),
             train_num_rays_per_batch=16384,
             eval_num_rays_per_batch=4096,
-            use_parallel_dataloader=True,
         ),
         model=NerfactoModelConfig(
             eval_num_rays_per_chunk=1 << 15,
@@ -231,7 +228,6 @@
             dataparser=NerfstudioDataParserConfig(),
             train_num_rays_per_batch=4096,
             eval_num_rays_per_batch=4096,
-            use_parallel_dataloader=True,
         ),
         model=DepthNerfactoModelConfig(
             eval_num_rays_per_chunk=1 << 15,
diff --git a/nerfstudio/data/datamanagers/base_datamanager.py b/nerfstudio/data/datamanagers/base_datamanager.py
index a50c142147..7b6fe80c2d 100644
--- a/nerfstudio/data/datamanagers/base_datamanager.py
+++ b/nerfstudio/data/datamanagers/base_datamanager.py
@@ -289,11 +289,13 @@ class VanillaDataManagerConfig(DataManagerConfig):
     """The image type returned from the manager; caching images as uint8 saves memory."""
     train_num_rays_per_batch: int = 1024
     """Number of rays per batch to use per training iteration."""
-    train_num_images_to_sample_from: int = -1
+    train_num_images_to_sample_from: int = 50
     """Number of images to sample from during each training iteration."""
-    train_num_times_to_repeat_images: int = -1
+    train_num_times_to_repeat_images: int = 10
     """When not training on all images, number of iterations before picking new
-    images. If -1, never pick new images."""
+    images. If -1, never pick new images.
+    Note: decreasing train_num_images_to_sample_from and increasing train_num_times_to_repeat_images alleviates the CPU bottleneck.
+    """
     eval_num_rays_per_batch: int = 1024
     """Number of rays per batch to use per eval iteration."""
     eval_num_images_to_sample_from: int = -1
@@ -310,15 +312,13 @@ class VanillaDataManagerConfig(DataManagerConfig):
     along with relevant information about camera intrinsics"""
     patch_size: int = 1
     """Size of patch to sample from. If > 1, patch-based sampling will be used."""
-    use_parallel_dataloader: bool = False
-    """Allows parallelization of the dataloading process with multiple workers prefetching RayBundles."""
     load_from_disk: bool = False
     """If True, conserves RAM by loading images from disk.
    If False, caches all the images as tensors to RAM and loads from RAM."""
-    dataloader_num_workers: int = 0
+    dataloader_num_workers: int = 4
     """The number of workers performing the dataloading from either disk/RAM, which includes
     collating, pixel sampling, unprojecting, ray generation etc."""
-    prefetch_factor: int | None = None
+    prefetch_factor: int = 10
     """The maximum number of batches a worker will start loading once an iterator is created.
     More details are described here: https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader"""
     cache_compressed_images: bool = False
@@ -340,26 +340,6 @@ def __post_init__(self):
             )
             warnings.warn("above message coming from", FutureWarning, stacklevel=3)
 
-        """
-        These heuristics allow the CPU dataloading bottleneck to equal the GPU bottleneck when training, but can be adjusted
-        Note: decreasing train_num_images_to_sample_from and increasing train_num_times_to_repeat_images alleviates CPU bottleneck.
-        """
-        if self.load_from_disk:
-            self.train_num_images_to_sample_from = (
-                50 if self.train_num_images_to_sample_from == -1 else self.train_num_images_to_sample_from
-            )
-            self.train_num_times_to_repeat_images = (
-                10 if self.train_num_times_to_repeat_images == -1 else self.train_num_times_to_repeat_images
-            )
-            self.prefetch_factor = self.train_num_times_to_repeat_images if self.use_parallel_dataloader else None
-
-        if self.use_parallel_dataloader:
-            try:
-                torch.multiprocessing.set_start_method("spawn")
-            except RuntimeError:
-                assert torch.multiprocessing.get_start_method() == "spawn", 'start method must be "spawn"'
-            self.dataloader_num_workers = 4 if self.dataloader_num_workers == 0 else self.dataloader_num_workers
-
 
 
 TDataset = TypeVar("TDataset", bound=InputDataset, default=InputDataset)
diff --git a/nerfstudio/data/datamanagers/full_images_datamanager.py b/nerfstudio/data/datamanagers/full_images_datamanager.py
index d126c03ec4..2b2436ba80 100644
--- a/nerfstudio/data/datamanagers/full_images_datamanager.py
+++ b/nerfstudio/data/datamanagers/full_images_datamanager.py
@@ -81,30 +81,15 @@ class FullImageDatamanagerConfig(DataManagerConfig):
     fps_reset_every: int = 100
     """The number of iterations before the fps sampler resets, which is essentially drawing
     fps_reset_every samples from the pool of all training cameras without replacement before a new
     round of sampling starts."""
-    dataloader_num_workers: int = 0
+    dataloader_num_workers: int = 4
     """The number of workers performing the dataloading from either disk/RAM, which includes
     collating, pixel sampling, unprojecting, ray generation etc."""
-    prefetch_factor: int = 0
+    prefetch_factor: int = 4
     """The maximum number of batches a worker will start loading once an iterator is created.
     More details are described here: https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader"""
     cache_compressed_images: bool = False
     """If True, cache raw image files as byte strings to RAM."""
 
-    def __post_init__(self):
-        if self.cache_images == "disk":
-            # If a user would like to load from disk, we pre-emptively set the number of
-            # workers and prefetch factor to parallelize the dataloading process.
-            try:
-                torch.multiprocessing.set_start_method("spawn")
-            except RuntimeError:
-                assert torch.multiprocessing.get_start_method() == "spawn", 'start method must be "spawn"'
-            if self.prefetch_factor == 0:
-                CONSOLE.log('cache_images set to "disk" with no prefetch factor, defaulting to 4')
-                self.prefetch_factor = 4
-            if self.dataloader_num_workers == 0:
-                CONSOLE.log('cache_images set to "disk" with 0 dataloader workers, defaulting to 4')
-                self.dataloader_num_workers = 4
-
 
 class FullImageDatamanager(DataManager, Generic[TDataset]):
     """
@@ -126,6 +111,11 @@ def __init__(
         local_rank: int = 0,
         **kwargs,
     ):
+        if config.cache_images == "disk":
+            try:
+                torch.multiprocessing.set_start_method("spawn")
+            except RuntimeError:
+                assert torch.multiprocessing.get_start_method() == "spawn", 'start method must be "spawn"'
         self.config = config
         self.device = device
         self.world_size = world_size
diff --git a/nerfstudio/data/datamanagers/parallel_datamanager.py b/nerfstudio/data/datamanagers/parallel_datamanager.py
index abd0e50fc9..d1d2da788d 100644
--- a/nerfstudio/data/datamanagers/parallel_datamanager.py
+++ b/nerfstudio/data/datamanagers/parallel_datamanager.py
@@ -66,6 +66,10 @@ def __init__(
         local_rank: int = 0,
         **kwargs,
     ):
+        try:
+            torch.multiprocessing.set_start_method("spawn")
+        except RuntimeError:
+            assert torch.multiprocessing.get_start_method() == "spawn", 'start method must be "spawn"'
         self.config = config
         self.device = device
         self.world_size = world_size
diff --git a/nerfstudio/data/utils/dataloaders.py b/nerfstudio/data/utils/dataloaders.py
index 16ecc121f0..48e32d1633 100644
--- a/nerfstudio/data/utils/dataloaders.py
+++ b/nerfstudio/data/utils/dataloaders.py
@@ -498,7 +498,6 @@ def _get_batch_list(self, indices=None):
         for idx in indices:
             res = executor.submit(self.input_dataset.__getitem__, idx)
             results.append(res)
-        # results = tqdm(results) # this is temporary and will be removed in the final push
         for res in results:
             batch_list.append(res.result())
 
@@ -545,11 +544,16 @@ def __iter__(self):
             self.ray_generator = RayGenerator(self.input_dataset.cameras)
 
         i = 0
+        true_random = random.Random(worker_info.id) if worker_info is not None else r
+        # Offset each worker's reshuffle interval so the workers don't all run out of images at once
+        repeat_offset_max = 10 if worker_info is not None else 1
+        repeat_offset = true_random.randint(0, repeat_offset_max)
         while True:
             if not self.load_from_disk:
                 collated_batch = self._cached_collated_batch
-            elif i % self.num_times_to_repeat_images == 0:
+            elif i % (self.num_times_to_repeat_images + repeat_offset) == 0:
                 r.shuffle(worker_indices)
+                repeat_offset = true_random.randint(0, repeat_offset_max)
             if self.num_images_to_sample_from == -1:
                 # if -1, the worker gets all available indices in its partition
                 image_indices = worker_indices
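
For context, the sketch below shows how the pieces above fit together: the "spawn" guard that this
patch moves out of __post_init__ and into the datamanager constructors, plus the retuned
disk-dataloading defaults. It is a minimal sketch, not part of the patch, assuming only the
nerfstudio modules already named in the diff; the NerfstudioDataParserConfig instance is an
illustrative placeholder for a real dataset.

    # Minimal sketch (not part of the patch): configuring disk-backed dataloading
    # with the defaults introduced above.
    import torch

    from nerfstudio.data.datamanagers.base_datamanager import VanillaDataManagerConfig
    from nerfstudio.data.dataparsers.nerfstudio_dataparser import NerfstudioDataParserConfig

    # Same guard the patch adds to the constructors: CUDA state cannot be shared
    # across fork()ed processes, so dataloader workers must be spawned. Calling
    # set_start_method twice raises RuntimeError, hence the try/except.
    try:
        torch.multiprocessing.set_start_method("spawn")
    except RuntimeError:
        assert torch.multiprocessing.get_start_method() == "spawn", 'start method must be "spawn"'

    config = VanillaDataManagerConfig(
        dataparser=NerfstudioDataParserConfig(),  # placeholder dataset config
        load_from_disk=True,                  # stream images from disk rather than caching them all in RAM
        train_num_images_to_sample_from=50,   # new default: sample 50 images at a time
        train_num_times_to_repeat_images=10,  # new default: reuse them for 10 iterations before resampling
        dataloader_num_workers=4,             # new default (was 0, i.e. main process only)
        prefetch_factor=10,                   # new default (was None)
    )

With the RayBatchStream change above, each spawned worker also draws repeat_offset in [0, 10] from
its own worker-seeded RNG, so the workers' resampling points are staggered rather than every worker
reloading images on the same iteration.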