
Commit 01f40b4

remove stuff from __post_init__, tune num workers more, add random offset in raybatchstream

kerrj committed Jan 16, 2025
1 parent 8d42154 commit 01f40b4
Showing 5 changed files with 24 additions and 50 deletions.
4 changes: 0 additions & 4 deletions nerfstudio/configs/method_configs.py
@@ -97,7 +97,6 @@
dataparser=NerfstudioDataParserConfig(),
train_num_rays_per_batch=4096,
eval_num_rays_per_batch=4096,
- use_parallel_dataloader=True,
),
model=NerfactoModelConfig(
eval_num_rays_per_chunk=1 << 15,
@@ -135,7 +134,6 @@
dataparser=NerfstudioDataParserConfig(),
train_num_rays_per_batch=8192,
eval_num_rays_per_batch=4096,
- use_parallel_dataloader=True,
),
model=NerfactoModelConfig(
eval_num_rays_per_chunk=1 << 15,
@@ -181,7 +179,6 @@
dataparser=NerfstudioDataParserConfig(),
train_num_rays_per_batch=16384,
eval_num_rays_per_batch=4096,
- use_parallel_dataloader=True,
),
model=NerfactoModelConfig(
eval_num_rays_per_chunk=1 << 15,
@@ -231,7 +228,6 @@
dataparser=NerfstudioDataParserConfig(),
train_num_rays_per_batch=4096,
eval_num_rays_per_batch=4096,
- use_parallel_dataloader=True,
),
model=DepthNerfactoModelConfig(
eval_num_rays_per_chunk=1 << 15,
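These four method configs simply drop the `use_parallel_dataloader=True` argument, because the flag itself is removed from the datamanager config in the next file; parallel loading is now governed by `dataloader_num_workers` and `prefetch_factor`. A minimal sketch of a comparable datamanager config after this commit, assuming the methods construct a `VanillaDataManagerConfig` (the class edited below) and that the dataparser import path matches the repository layout:

```python
# Hedged sketch, not code from this commit: the same datamanager settings with the
# removed flag gone. Field names come from the VanillaDataManagerConfig diff below;
# the import paths are assumed from the repository layout.
from nerfstudio.data.datamanagers.base_datamanager import VanillaDataManagerConfig
from nerfstudio.data.dataparsers.nerfstudio_dataparser import NerfstudioDataParserConfig

datamanager = VanillaDataManagerConfig(
    dataparser=NerfstudioDataParserConfig(),
    train_num_rays_per_batch=4096,
    eval_num_rays_per_batch=4096,
    # use_parallel_dataloader=True,  # removed: the field no longer exists
    dataloader_num_workers=4,  # new default set in this commit
    prefetch_factor=10,        # new default set in this commit
)
```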
34 changes: 7 additions & 27 deletions nerfstudio/data/datamanagers/base_datamanager.py
@@ -289,11 +289,13 @@ class VanillaDataManagerConfig(DataManagerConfig):
"""The image type returned from manager, caching images in uint8 saves memory"""
train_num_rays_per_batch: int = 1024
"""Number of rays per batch to use per training iteration."""
- train_num_images_to_sample_from: int = -1
+ train_num_images_to_sample_from: int = 50
"""Number of images to sample during training iteration."""
- train_num_times_to_repeat_images: int = -1
+ train_num_times_to_repeat_images: int = 10
"""When not training on all images, number of iterations before picking new
- images. If -1, never pick new images."""
+ images. If -1, never pick new images.
+ Note: decreasing train_num_images_to_sample_from and increasing train_num_times_to_repeat_images alleviates CPU bottleneck.
+ """
eval_num_rays_per_batch: int = 1024
"""Number of rays per batch to use per eval iteration."""
eval_num_images_to_sample_from: int = -1
@@ -310,15 +312,13 @@ class VanillaDataManagerConfig(DataManagerConfig):
along with relevant information about camera intrinsics"""
patch_size: int = 1
"""Size of patch to sample from. If > 1, patch-based sampling will be used."""
- use_parallel_dataloader: bool = False
- """Allows parallelization of the dataloading process with multiple workers prefetching RayBundles."""
load_from_disk: bool = False
"""If True, conserves RAM memory by loading images from disk.
If False, caches all the images as tensors to RAM and loads from RAM."""
- dataloader_num_workers: int = 0
+ dataloader_num_workers: int = 4
"""The number of workers performing the dataloading from either disk/RAM, which
includes collating, pixel sampling, unprojecting, ray generation etc."""
- prefetch_factor: int | None = None
+ prefetch_factor: int = 10
"""The limit number of batches a worker will start loading once an iterator is created.
More details are described here: https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader"""
cache_compressed_images: bool = False
@@ -340,26 +340,6 @@ def __post_init__(self):
)
warnings.warn("above message coming from", FutureWarning, stacklevel=3)

"""
These heuristics allow the CPU dataloading bottleneck to equal the GPU bottleneck when training, but can be adjusted
Note: decreasing train_num_images_to_sample_from and increasing train_num_times_to_repeat_images alleviates CPU bottleneck.
"""
if self.load_from_disk:
self.train_num_images_to_sample_from = (
50 if self.train_num_images_to_sample_from == -1 else self.train_num_images_to_sample_from
)
self.train_num_times_to_repeat_images = (
10 if self.train_num_times_to_repeat_images == -1 else self.train_num_times_to_repeat_images
)
self.prefetch_factor = self.train_num_times_to_repeat_images if self.use_parallel_dataloader else None

- if self.use_parallel_dataloader:
- try:
- torch.multiprocessing.set_start_method("spawn")
- except RuntimeError:
- assert torch.multiprocessing.get_start_method() == "spawn", 'start method must be "spawn"'
- self.dataloader_num_workers = 4 if self.dataloader_num_workers == 0 else self.dataloader_num_workers


TDataset = TypeVar("TDataset", bound=InputDataset, default=InputDataset)

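The removed `__post_init__` heuristics are replaced by static defaults: 50 images sampled per round, reused for 10 iterations, with 4 dataloader workers and a prefetch factor of 10. A hedged sketch of how `dataloader_num_workers` and `prefetch_factor` are typically handed to `torch.utils.data.DataLoader` (the dummy dataset and the wiring here are illustrative, not nerfstudio's actual classes):

```python
# Illustrative only: a stand-in IterableDataset plays the role of nerfstudio's
# ray-batch stream so the two tuned parameters can be shown end to end.
import torch
from torch.utils.data import DataLoader, IterableDataset


class DummyRayBatchStream(IterableDataset):
    """Stand-in for a stream that yields pre-collated ray batches."""

    def __iter__(self):
        while True:
            yield {"rays": torch.rand(4096, 3)}  # placeholder batch


if __name__ == "__main__":
    loader = DataLoader(
        DummyRayBatchStream(),
        batch_size=None,     # batches are already formed inside the dataset
        num_workers=4,       # dataloader_num_workers: new default in this commit
        prefetch_factor=10,  # prefetch_factor: new default; only valid with num_workers > 0
    )
    batch = next(iter(loader))
    print(batch["rays"].shape)  # torch.Size([4096, 3])
```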
24 changes: 7 additions & 17 deletions nerfstudio/data/datamanagers/full_images_datamanager.py
@@ -81,30 +81,15 @@ class FullImageDatamanagerConfig(DataManagerConfig):
fps_reset_every: int = 100
"""The number of iterations before one resets fps sampler repeatly, which is essentially drawing fps_reset_every
samples from the pool of all training cameras without replacement before a new round of sampling starts."""
- dataloader_num_workers: int = 0
+ dataloader_num_workers: int = 4
"""The number of workers performing the dataloading from either disk/RAM, which
includes collating, pixel sampling, unprojecting, ray generation etc."""
- prefetch_factor: int = 0
+ prefetch_factor: int = 4
"""The limit number of batches a worker will start loading once an iterator is created.
More details are described here: https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader"""
cache_compressed_images: bool = False
"""If True, cache raw image files as byte strings to RAM."""

- def __post_init__(self):
- if self.cache_images == "disk":
- # If a user would like to load from disk, we pre-emptively set the number of
- # workers and prefetch factor to parallelize the dataloading process.
- try:
- torch.multiprocessing.set_start_method("spawn")
- except RuntimeError:
- assert torch.multiprocessing.get_start_method() == "spawn", 'start method must be "spawn"'
- if self.prefetch_factor == 0:
- CONSOLE.log('cache_images set to "disk" with no prefetch factor, defaulting to 4')
- self.prefetch_factor = 4
- if self.dataloader_num_workers == 0:
- CONSOLE.log('cache_images set to "disk" with 0 dataloader workers, defaulting to 4')
- self.dataloader_num_workers = 4


class FullImageDatamanager(DataManager, Generic[TDataset]):
"""
@@ -126,6 +111,11 @@ def __init__(
local_rank: int = 0,
**kwargs,
):
+ if config.cache_images == "disk":
+ try:
+ torch.multiprocessing.set_start_method("spawn")
+ except RuntimeError:
+ assert torch.multiprocessing.get_start_method() == "spawn", 'start method must be "spawn"'
self.config = config
self.device = device
self.world_size = world_size
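The spawn setup that used to run in `__post_init__` now runs in `FullImageDatamanager.__init__`, and only when `cache_images == "disk"` actually requires worker processes. Extracted as a standalone sketch (the helper name is ours, not the repository's): `set_start_method` may only be called once per process, so a `RuntimeError` is tolerated as long as "spawn" is already in effect.

```python
# Standalone sketch of the guard added above; ensure_spawn_start_method is a
# hypothetical helper name, but its body mirrors the committed code.
import torch.multiprocessing


def ensure_spawn_start_method() -> None:
    try:
        torch.multiprocessing.set_start_method("spawn")
    except RuntimeError:
        # A start method was already chosen; it must have been "spawn".
        assert torch.multiprocessing.get_start_method() == "spawn", 'start method must be "spawn"'


if __name__ == "__main__":
    ensure_spawn_start_method()
    print(torch.multiprocessing.get_start_method())  # spawn
```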
4 changes: 4 additions & 0 deletions nerfstudio/data/datamanagers/parallel_datamanager.py
@@ -66,6 +66,10 @@ def __init__(
local_rank: int = 0,
**kwargs,
):
+ try:
+ torch.multiprocessing.set_start_method("spawn")
+ except RuntimeError:
+ assert torch.multiprocessing.get_start_method() == "spawn", 'start method must be "spawn"'
self.config = config
self.device = device
self.world_size = world_size
8 changes: 6 additions & 2 deletions nerfstudio/data/utils/dataloaders.py
@@ -498,7 +498,6 @@ def _get_batch_list(self, indices=None):
for idx in indices:
res = executor.submit(self.input_dataset.__getitem__, idx)
results.append(res)
- # results = tqdm(results) # this is temporary and will be removed in the final push
for res in results:
batch_list.append(res.result())

@@ -545,11 +544,16 @@ def __iter__(self):
self.ray_generator = RayGenerator(self.input_dataset.cameras)

i = 0
+ true_random = random.Random(worker_info.id) if worker_info is not None else r
+ # We offset the value of repeat so that they're not all running out of images at once
+ repeat_offset_max = 10 if worker_info is not None else 1
+ repeat_offset = true_random.randint(0, repeat_offset_max)
while True:
if not self.load_from_disk:
collated_batch = self._cached_collated_batch
- elif i % self.num_times_to_repeat_images == 0:
+ elif i % (self.num_times_to_repeat_images + repeat_offset) == 0:
r.shuffle(worker_indices)
+ repeat_offset = true_random.randint(0, repeat_offset_max)
if self.num_images_to_sample_from == -1:
# if -1, the worker gets all available indices in its partition
image_indices = worker_indices
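The new `repeat_offset` staggers the expensive reload step across workers: each worker seeds its own `random.Random(worker_info.id)` and adds a random 0-10 offset to its repeat period, re-drawn after every reload, so workers do not all exhaust their cached images on the same iteration. A small simulation of that effect, reusing the logic from the diff (the worker count and iteration range below are arbitrary):

```python
# Simulation only: prints which iterations each worker would reload images on
# under the committed offset logic. Values mirror the diff's defaults.
import random

NUM_TIMES_TO_REPEAT_IMAGES = 10  # new default from this commit
REPEAT_OFFSET_MAX = 10           # repeat_offset_max when worker_info is not None

for worker_id in range(4):
    true_random = random.Random(worker_id)  # per-worker RNG, as in the diff
    repeat_offset = true_random.randint(0, REPEAT_OFFSET_MAX)
    reload_iters = []
    for i in range(60):
        if i % (NUM_TIMES_TO_REPEAT_IMAGES + repeat_offset) == 0:
            reload_iters.append(i)  # this is where disk I/O would happen
            repeat_offset = true_random.randint(0, REPEAT_OFFSET_MAX)
    print(f"worker {worker_id}: reloads at iterations {reload_iters}")
```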
