fix the bugs associated with blender data (nerfstudio-project#2704)
* fix the bug when camera.distortion_params is None

* Handle background color override when using blender.

* fix bare except

* format

* Update background color override in Blender dataparser

* Add ability to download EyefulTower dataset

* wip before I copy linning's stuff in

* Generate per-resolution cameras.xml

* Generate transforms.json at download

* Fix a couple of quotes

* Use official EyefulTower splits for train and val

* Disable projectaria-tools on windows

* Fix extra imports

* Add a new nerfacto method tuned for EyefulTower

* Split eyefultower download into a separate file

* Fix typo

* Add some fisheye support for eyeful data

* Reformatted imports to not be dumb

* Apparently this file was missed when formatting originally

* Added 1k resolution scenes

* revert method_configs.py to original values

* Also add 1k exrs

* Add option to modify bg color in gaussian splatting

* fix back the config, bg color should work now

* removed camera optimizer for gs to align with main

* Address feedback

* Revert changes to pyproject.toml, to be added in a later PR

* Oops, probably shouldn't have gotten rid of awscli ...

* adding support for bg color, tested and reformatted now

* formatted

* formatted

* changed glob variable name

* Refactor background color variable name

* prevent viser overriding

---------

Co-authored-by: Vasu Agrawal <[email protected]>
Co-authored-by: Ethan Weber <[email protected]>
Co-authored-by: Brent Yi <[email protected]>
Co-authored-by: Justin Kerr <[email protected]>
5 people authored Jan 20, 2024
1 parent ff4002d commit c83ed7d
Showing 4 changed files with 95 additions and 28 deletions.
7 changes: 6 additions & 1 deletion nerfstudio/data/datamanagers/full_images_datamanager.py
@@ -107,7 +107,10 @@ def __init__(
         self.train_dataset = self.create_train_dataset()
         self.eval_dataset = self.create_eval_dataset()
         if len(self.train_dataset) > 500 and self.config.cache_images == "gpu":
-            CONSOLE.print("Train dataset has over 500 images, overriding cach_images to cpu", style="bold yellow")
+            CONSOLE.print(
+                "Train dataset has over 500 images, overriding cache_images to cpu",
+                style="bold yellow",
+            )
             self.config.cache_images = "cpu"
         self.cached_train, self.cached_eval = self.cache_images(self.config.cache_images)
         self.exclude_batch_keys_from_device = self.train_dataset.exclude_batch_keys_from_device
@@ -133,6 +136,7 @@ def cache_images(self, cache_images_option):
             camera = self.train_dataset.cameras[i].reshape(())
             K = camera.get_intrinsics_matrices().numpy()
             if camera.distortion_params is None:
+                cached_train.append(data)
                 continue
             distortion_params = camera.distortion_params.numpy()
             image = data["image"].numpy()
@@ -158,6 +162,7 @@ def cache_images(self, cache_images_option):
             camera = self.eval_dataset.cameras[i].reshape(())
             K = camera.get_intrinsics_matrices().numpy()
             if camera.distortion_params is None:
+                cached_eval.append(data)
                 continue
             distortion_params = camera.distortion_params.numpy()
             image = data["image"].numpy()
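
The two hunks above apply the same fix to the train and eval caching loops: when a camera has no distortion parameters, the sample is now appended to the cache before the continue instead of being silently dropped. A minimal standalone sketch of the pattern (the dataset, undistort_image, and cached names are hypothetical, not the datamanager's actual code):

    cached = []
    for i in range(len(dataset)):
        data = dataset[i]
        camera = dataset.cameras[i].reshape(())
        if camera.distortion_params is None:
            # Nothing to undistort, but the sample must still be cached.
            cached.append(data)
            continue
        data["image"] = undistort_image(data["image"], camera)  # hypothetical helper
        cached.append(data)
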
9 changes: 4 additions & 5 deletions nerfstudio/data/dataparsers/blender_dataparser.py
@@ -57,13 +57,12 @@ def __init__(self, config: BlenderDataParserConfig):
         self.data: Path = config.data
         self.scale_factor: float = config.scale_factor
         self.alpha_color = config.alpha_color
-
-    def _generate_dataparser_outputs(self, split="train"):
         if self.alpha_color is not None:
-            alpha_color_tensor = get_color(self.alpha_color)
+            self.alpha_color_tensor = get_color(self.alpha_color)
         else:
-            alpha_color_tensor = None
+            self.alpha_color_tensor = None

+    def _generate_dataparser_outputs(self, split="train"):
         meta = load_from_json(self.data / f"transforms_{split}.json")
         image_filenames = []
         poses = []
@@ -98,7 +97,7 @@ def _generate_dataparser_outputs(self, split="train"):
         dataparser_outputs = DataparserOutputs(
             image_filenames=image_filenames,
             cameras=cameras,
-            alpha_color=alpha_color_tensor,
+            alpha_color=self.alpha_color_tensor,
             scene_box=scene_box,
             dataparser_scale=self.scale_factor,
         )
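
Hoisting the get_color call into __init__ computes the alpha-color tensor once and exposes it as an attribute, so every _generate_dataparser_outputs call (train, val, test) shares the same value. A rough usage sketch, assuming nerfstudio's standard config setup() flow and a Blender-format dataset at the config's data path:

    from nerfstudio.data.dataparsers.blender_dataparser import BlenderDataParserConfig

    # alpha_color="white" is converted to tensor([1., 1., 1.]) at construction time.
    parser = BlenderDataParserConfig(alpha_color="white").setup()
    outputs = parser.get_dataparser_outputs(split="train")
    assert outputs.alpha_color is parser.alpha_color_tensor
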
106 changes: 84 additions & 22 deletions nerfstudio/models/splatfacto.py
@@ -31,6 +31,7 @@
 from gsplat.sh import num_sh_bases, spherical_harmonics
 from pytorch_msssim import SSIM
 from torch.nn import Parameter
+from typing_extensions import Literal

 from nerfstudio.cameras.cameras import Cameras
 from nerfstudio.data.scene_box import OrientedBox
@@ -40,6 +41,7 @@
 # need following import for background color override
 from nerfstudio.model_components import renderers
 from nerfstudio.models.base_model import Model, ModelConfig
+from nerfstudio.utils.colors import get_color
 from nerfstudio.utils.rich_utils import CONSOLE


@@ -109,6 +111,8 @@ class SplatfactoModelConfig(ModelConfig):
     """period of steps where gaussians are culled and densified"""
     resolution_schedule: int = 250
     """training starts at 1/d resolution, every n steps this is doubled"""
+    background_color: Literal["random", "black", "white"] = "random"
+    """Whether to randomize the background color."""
     num_downscales: int = 0
     """at the beginning, resolution is 1/2^d, where d is this number"""
     cull_alpha_thresh: float = 0.1
@@ -135,6 +139,10 @@
     """stop culling/splitting at this step WRT screen size of gaussians"""
     random_init: bool = False
     """whether to initialize the positions uniformly randomly (not SFM points)"""
+    num_random: int = 50000
+    """Number of gaussians to initialize if random init is used"""
+    random_scale: float = 10.0
+    "Size of the cube to initialize random gaussians within"
     ssim_lambda: float = 0.2
     """weight of ssim loss"""
     stop_split_at: int = 15000
@@ -171,7 +179,7 @@ def populate_modules(self):
         if self.seed_points is not None and not self.config.random_init:
             self.means = torch.nn.Parameter(self.seed_points[0])  # (Location, Color)
         else:
-            self.means = torch.nn.Parameter((torch.rand((500000, 3)) - 0.5) * 10)
+            self.means = torch.nn.Parameter((torch.rand((self.config.num_random, 3)) - 0.5) * self.config.random_scale)
         self.xys_grad_norm = None
         self.max_2Dsize = None
         distances, _ = self.k_nearest_sklearn(self.means.data, 3)
@@ -213,7 +221,10 @@ def populate_modules(self):
         self.step = 0

         self.crop_box: Optional[OrientedBox] = None
-        self.back_color = torch.zeros(3)
+        if self.config.background_color == "random":
+            self.background_color = torch.rand(3)
+        else:
+            self.background_color = get_color(self.config.background_color)

     @property
     def colors(self):
@@ -295,7 +306,10 @@ def dup_in_optim(self, optimizer, dup_mask, new_params, n=2):
             param_state = optimizer.state[param]
             repeat_dims = (n,) + tuple(1 for _ in range(param_state["exp_avg"].dim() - 1))
             param_state["exp_avg"] = torch.cat(
-                [param_state["exp_avg"], torch.zeros_like(param_state["exp_avg"][dup_mask.squeeze()]).repeat(*repeat_dims)],
+                [
+                    param_state["exp_avg"],
+                    torch.zeros_like(param_state["exp_avg"][dup_mask.squeeze()]).repeat(*repeat_dims),
+                ],
                 dim=0,
             )
             param_state["exp_avg_sq"] = torch.cat(
@@ -339,15 +353,16 @@ def after_train(self, step: int):
                 self.max_2Dsize = torch.zeros_like(self.radii, dtype=torch.float32)
             newradii = self.radii.detach()[visible_mask]
             self.max_2Dsize[visible_mask] = torch.maximum(
-                self.max_2Dsize[visible_mask], newradii / float(max(self.last_size[0], self.last_size[1]))
+                self.max_2Dsize[visible_mask],
+                newradii / float(max(self.last_size[0], self.last_size[1])),
             )

     def set_crop(self, crop_box: Optional[OrientedBox]):
         self.crop_box = crop_box

-    def set_background(self, back_color: torch.Tensor):
-        assert back_color.shape == (3,)
-        self.back_color = back_color
+    def set_background(self, background_color: torch.Tensor):
+        assert background_color.shape == (3,)
+        self.background_color = background_color

     def refinement_after(self, optimizers: Optimizers, step):
         assert step == self.step
@@ -394,17 +409,31 @@ def refinement_after(self, optimizers: Optimizers, step):
             ) = self.dup_gaussians(dups)
             self.means = Parameter(torch.cat([self.means.detach(), split_means, dup_means], dim=0))
             self.features_dc = Parameter(
-                torch.cat([self.features_dc.detach(), split_features_dc, dup_features_dc], dim=0)
+                torch.cat(
+                    [self.features_dc.detach(), split_features_dc, dup_features_dc],
+                    dim=0,
+                )
             )
             self.features_rest = Parameter(
-                torch.cat([self.features_rest.detach(), split_features_rest, dup_features_rest], dim=0)
+                torch.cat(
+                    [
+                        self.features_rest.detach(),
+                        split_features_rest,
+                        dup_features_rest,
+                    ],
+                    dim=0,
+                )
             )
             self.opacities = Parameter(torch.cat([self.opacities.detach(), split_opacities, dup_opacities], dim=0))
             self.scales = Parameter(torch.cat([self.scales.detach(), split_scales, dup_scales], dim=0))
             self.quats = Parameter(torch.cat([self.quats.detach(), split_quats, dup_quats], dim=0))
             # append zeros to the max_2Dsize tensor
             self.max_2Dsize = torch.cat(
-                [self.max_2Dsize, torch.zeros_like(split_scales[:, 0]), torch.zeros_like(dup_scales[:, 0])],
+                [
+                    self.max_2Dsize,
+                    torch.zeros_like(split_scales[:, 0]),
+                    torch.zeros_like(dup_scales[:, 0]),
+                ],
                 dim=0,
             )

@@ -416,7 +445,14 @@ def refinement_after(self, optimizers: Optimizers, step):

             # After a guassian is split into two new gaussians, the original one should also be pruned.
             splits_mask = torch.cat(
-                (splits, torch.zeros(nsamps * splits.sum() + dups.sum(), device=self.device, dtype=torch.bool))
+                (
+                    splits,
+                    torch.zeros(
+                        nsamps * splits.sum() + dups.sum(),
+                        device=self.device,
+                        dtype=torch.bool,
+                    ),
+                )
             )

             deleted_mask = self.cull_gaussians(splits_mask)
@@ -433,7 +469,8 @@ def refinement_after(self, optimizers: Optimizers, step):
             # Reset value is set to be twice of the cull_alpha_thresh
             reset_value = self.config.cull_alpha_thresh * 2.0
             self.opacities.data = torch.clamp(
-                self.opacities.data, max=torch.logit(torch.tensor(reset_value, device=self.device)).item()
+                self.opacities.data,
+                max=torch.logit(torch.tensor(reset_value, device=self.device)).item(),
             )
             # reset the exp of optimizer
             optim = optimizers.optimizers["opacity"]
@@ -507,7 +544,14 @@ def split_gaussians(self, split_mask, samps):
         self.scales[split_mask] = torch.log(torch.exp(self.scales[split_mask]) / size_fac)
         # step 5, sample new quats
         new_quats = self.quats[split_mask].repeat(samps, 1)
-        return new_means, new_features_dc, new_features_rest, new_opacities, new_scales, new_quats
+        return (
+            new_means,
+            new_features_dc,
+            new_features_rest,
+            new_opacities,
+            new_scales,
+            new_quats,
+        )

     def dup_gaussians(self, dup_mask):
         """
@@ -521,7 +565,14 @@ def dup_gaussians(self, dup_mask):
         dup_opacities = self.opacities[dup_mask]
         dup_scales = self.scales[dup_mask]
         dup_quats = self.quats[dup_mask]
-        return dup_means, dup_features_dc, dup_features_rest, dup_opacities, dup_scales, dup_quats
+        return (
+            dup_means,
+            dup_features_dc,
+            dup_features_rest,
+            dup_opacities,
+            dup_scales,
+            dup_quats,
+        )

     @property
     def num_points(self):
@@ -573,7 +624,10 @@ def get_param_groups(self) -> Dict[str, List[Parameter]]:

     def _get_downscale_factor(self):
         if self.training:
-            return 2 ** max((self.config.num_downscales - self.step // self.config.resolution_schedule), 0)
+            return 2 ** max(
+                (self.config.num_downscales - self.step // self.config.resolution_schedule),
+                0,
+            )
         else:
             return 1

@@ -591,14 +645,23 @@ def get_outputs(self, camera: Cameras) -> Dict[str, Union[torch.Tensor, List]]:
             print("Called get_outputs with not a camera")
             return {}
         assert camera.shape[0] == 1, "Only one camera at a time"
+
+        # get the background color
         if self.training:
-            background = torch.rand(3, device=self.device)
+            if self.config.background_color == "random":
+                background = torch.rand(3, device=self.device)
+            elif self.config.background_color == "white":
+                background = torch.ones(3, device=self.device)
+            elif self.config.background_color == "black":
+                background = torch.zeros(3, device=self.device)
+            else:
+                background = self.background_color.to(self.device)
         else:
             # logic for setting the background of the scene
             if renderers.BACKGROUND_COLOR_OVERRIDE is not None:
-                background = renderers.BACKGROUND_COLOR_OVERRIDE
+                background = renderers.BACKGROUND_COLOR_OVERRIDE.to(self.device)
             else:
-                background = self.back_color.to(self.device)
+                background = self.background_color.to(self.device)
+
         if self.crop_box is not None and not self.training:
             crop_ids = self.crop_box.within(self.means).squeeze()
             if crop_ids.sum() == 0:
@@ -684,9 +747,7 @@ def get_outputs(self, camera: Cameras) -> Dict[str, Union[torch.Tensor, List]]:

         # rescale the camera back to original dimensions
         camera.rescale_output_resolution(camera_downscale)
-
         assert (num_tiles_hit > 0).any()  # type: ignore
-
         rgb = rasterize_gaussians(  # type: ignore
             self.xys,
             depths,
scale_exp = torch.exp(self.scales)
scale_reg = (
torch.maximum(
scale_exp.amax(dim=-1) / scale_exp.amin(dim=-1), torch.tensor(self.config.max_gauss_ratio)
scale_exp.amax(dim=-1) / scale_exp.amin(dim=-1),
torch.tensor(self.config.max_gauss_ratio),
)
- self.config.max_gauss_ratio
)
Expand Down
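
Taken together, the splatfacto changes expose three new knobs on SplatfactoModelConfig: background_color, num_random, and random_scale. A hedged construction example (field names are from the diff above; the chosen values are illustrative, not recommendations):

    from nerfstudio.models.splatfacto import SplatfactoModelConfig

    config = SplatfactoModelConfig(
        background_color="white",  # "random" (default), "black", or "white"
        random_init=True,          # initialize from random points instead of SfM seeds
        num_random=50000,          # number of gaussians when random_init is used
        random_scale=10.0,         # side length of the cube the means are sampled in
    )

At evaluation time, renderers.BACKGROUND_COLOR_OVERRIDE still takes precedence when set (e.g., by the viewer), and it is now moved to the model's device before compositing.
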
1 change: 1 addition & 0 deletions nerfstudio/utils/tensor_dataclass.py
@@ -142,6 +142,7 @@ def _broadcast_dict_fields(self, dict_: Dict, batch_shape) -> Dict:
             elif isinstance(v, Dict):
                 new_dict[k] = self._broadcast_dict_fields(v, batch_shape)
             else:
+                # Don't broadcast the remaining fields
                 new_dict[k] = v
         return new_dict

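The new comment documents the fall-through branch of _broadcast_dict_fields: tensor leaves are broadcast to the batch shape, nested dictionaries recurse, and everything else passes through untouched. A rough free-function approximation of that contract (a sketch, not the class's actual method):

    import torch

    def broadcast_dict_fields(dict_: dict, batch_shape) -> dict:
        new_dict = {}
        for k, v in dict_.items():
            if isinstance(v, torch.Tensor):
                # Tensor leaves gain the batch dimensions, keeping the last axis.
                new_dict[k] = v.broadcast_to((*batch_shape, v.shape[-1]))
            elif isinstance(v, dict):
                # Nested dictionaries are handled recursively.
                new_dict[k] = broadcast_dict_fields(v, batch_shape)
            else:
                # Don't broadcast the remaining fields
                new_dict[k] = v
        return new_dict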
