Skip to content

Commit

Permalink
Introduce FPSampler for train cameras
Browse files Browse the repository at this point in the history
  • Loading branch information
jb-ye authored and Jianbo Ye committed May 28, 2024
1 parent 12f2e68 commit 6672de7
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 4 deletions.
48 changes: 45 additions & 3 deletions nerfstudio/data/datamanagers/full_images_datamanager.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from typing import Dict, ForwardRef, Generic, List, Literal, Optional, Tuple, Type, Union, cast, get_args, get_origin

import cv2
import fpsample
import numpy as np
import torch
from rich.progress import track
Expand Down Expand Up @@ -68,6 +69,15 @@ class FullImageDatamanagerConfig(DataManagerConfig):
"""The image type returned from manager, caching images in uint8 saves memory"""
max_thread_workers: Optional[int] = None
"""The maximum number of threads to use for caching images. If None, uses all available threads."""
train_cameras_sampling_strategy: Literal["random", "fps"] = "random"
"""Specifies which sampling strategy is used to generate train cameras, 'random' means sampling
uniformly at random without replacement, 'fps' means farthest point sampling, which helps reduce the artifacts
due to oversampling subsets of cameras that are very close to each other."""
train_cameras_sampling_seed: int = 42
"""Random seed for sampling train cameras. Fixing seed may help reduce variance of trained models across
different runs."""
fps_reset_every: int = 100
"""The number of iterations after which the fps sampler is reset, repeatedly (capped at the number of train cameras)."""


class FullImageDatamanager(DataManager, Generic[TDataset]):
Expand Down Expand Up @@ -123,12 +133,44 @@ def __init__(
self.exclude_batch_keys_from_device.remove("image")

# Some logic to make sure we sample every camera in equal amounts
self.train_unseen_cameras = [i for i in range(len(self.train_dataset))]
self.train_unseen_cameras = self.sample_train_cameras()
self.eval_unseen_cameras = [i for i in range(len(self.eval_dataset))]
assert len(self.train_unseen_cameras) > 0, "No data found in dataset"

super().__init__()

def sample_train_cameras(self) -> List[int]:
    """Return a list of train camera indices in the order they should be visited.

    The order is produced by the strategy named in
    ``self.config.train_cameras_sampling_strategy``:

    * ``"random"``: every train camera index exactly once, shuffled with a generator that is
      seeded once from ``self.config.train_cameras_sampling_seed`` and reused across calls.
    * ``"fps"``: at most ``self.config.fps_reset_every`` indices picked by farthest point
      sampling over the camera origins, augmented with a per-camera count of consecutive
      calls in which the camera was not picked.

    Returns:
        The sampled camera indices. Empty when the train dataset is empty, so that the caller
        can report the missing data itself.

    Raises:
        ValueError: If the strategy is unknown, or if the "fps" strategy is selected with a
            non-positive ``fps_reset_every``.
    """
    strategy = self.config.train_cameras_sampling_strategy
    num_train_cameras = len(self.train_dataset)

    if strategy == "random":
        # Created lazily on the first call; later calls continue from the generator's state,
        # so successive shuffles differ but the whole sequence is reproducible from the seed.
        if not hasattr(self, "random_generator"):
            self.random_generator = random.Random(self.config.train_cameras_sampling_seed)
        indices = list(range(num_train_cameras))
        self.random_generator.shuffle(indices)
        return indices

    if strategy == "fps":
        if self.config.fps_reset_every < 1:
            raise ValueError(f"fps_reset_every must be a positive integer, got {self.config.fps_reset_every}")
        if num_train_cameras == 0:
            # Nothing to sample from; mirror the "random" strategy instead of handing an
            # empty point set to fpsample.
            return []

        if not hasattr(self, "train_unsampled_epoch_count"):
            # NOTE(review): this seeds NumPy's *global* RNG, which presumably is what fpsample
            # uses to pick its starting point. Other users of np.random are affected too.
            np.random.seed(self.config.train_cameras_sampling_seed)  # fix random seed of fpsample
            self.train_unsampled_epoch_count = np.zeros(num_train_cameras)

        # Index 3 along the last axis of camera_to_worlds -- presumably the translation column
        # of each camera-to-world matrix, i.e. the camera origin (TODO confirm layout).
        camera_origins = self.train_dataset.cameras.camera_to_worlds[..., 3].numpy()
        # We concatenate camera origins with weighted train_unsampled_epoch_count because we want to
        # increase the chance to sample camera that hasn't been sampled in consecutive epochs previously.
        # We assume the camera origins are also rescaled, so the weight 0.1 is relative to the scale of scene
        data = np.concatenate(
            (camera_origins, 0.1 * np.expand_dims(self.train_unsampled_epoch_count, axis=-1)), axis=-1
        )
        # Never request more samples than there are cameras.
        num_samples = min(self.config.fps_reset_every, num_train_cameras)
        kdline_fps_samples_idx = fpsample.bucket_fps_kdline_sampling(data, num_samples, h=3)

        # Age every camera by one call, then reset the counter of the cameras just picked.
        self.train_unsampled_epoch_count += 1
        self.train_unsampled_epoch_count[kdline_fps_samples_idx] = 0
        return kdline_fps_samples_idx.tolist()

    raise ValueError(f"Unknown train camera sampling strategy: {self.config.train_cameras_sampling_strategy}")

@cached_property
def cached_train(self) -> List[Dict[str, torch.Tensor]]:
"""Get the training images. Will load and undistort the images the
Expand Down Expand Up @@ -288,10 +330,10 @@ def next_train(self, step: int) -> Tuple[Cameras, Dict]:
"""Returns the next training batch
Returns a Camera instead of raybundle"""
image_idx = self.train_unseen_cameras.pop(random.randint(0, len(self.train_unseen_cameras) - 1))
image_idx = self.train_unseen_cameras.pop(0)
# Make sure to re-populate the unseen cameras list if we have exhausted it
if len(self.train_unseen_cameras) == 0:
self.train_unseen_cameras = [i for i in range(len(self.train_dataset))]
self.train_unseen_cameras = self.sample_train_cameras()

data = deepcopy(self.cached_train[image_idx])
data["image"] = data["image"].to(self.device)
Expand Down
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,8 @@ dependencies = [
"gsplat>=0.1.11",
"pytorch-msssim",
"pathos",
"packaging"
"packaging",
"fpsample"
]

[project.urls]
Expand Down

0 comments on commit 6672de7

Please sign in to comment.