Skip to content

Commit

Permalink
Use AbsGrad to get better results with less gaussians (nerfstudio-pro…
Browse files Browse the repository at this point in the history
…ject#3113)

* Use AbsGrad to get better results with less gaussians

* some adjustment such that splatfacto-big don't go OOM

* use new gsplat version

---------

Co-authored-by: Jianbo Ye <[email protected]>
Co-authored-by: Justin Kerr <[email protected]>
  • Loading branch information
3 people authored Apr 29, 2024
1 parent bc9328c commit b190874
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 22 deletions.
1 change: 1 addition & 0 deletions nerfstudio/configs/method_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -658,6 +658,7 @@
model=SplatfactoModelConfig(
cull_alpha_thresh=0.005,
continue_cull_post_densification=False,
densify_grad_thresh=0.0006,
),
),
optimizers={
Expand Down
42 changes: 21 additions & 21 deletions nerfstudio/models/splatfacto.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,21 @@ def SH2RGB(sh):
return sh * C0 + 0.5


def resize_image(image: torch.Tensor, d: int):
"""
Downscale images using the same 'area' method in opencv
:param image shape [H, W, C]
:param d downscale factor (must be 2, 4, 8, etc.)
return downscaled image in shape [H//d, W//d, C]
"""
import torch.nn.functional as tf

weight = (1.0 / (d * d)) * torch.ones((1, 1, d, d), dtype=torch.float32, device=image.device)
return tf.conv2d(image.permute(2, 0, 1)[:, None, ...], weight, stride=d).squeeze(1).permute(1, 2, 0)


@dataclass
class SplatfactoModelConfig(ModelConfig):
"""Splatfacto Model Config, nerfstudio's implementation of Gaussian Splatting"""
Expand All @@ -103,7 +118,7 @@ class SplatfactoModelConfig(ModelConfig):
"""If True, continue to cull gaussians post refinement"""
reset_alpha_every: int = 30
"""Every this many refinement steps, reset the alpha"""
densify_grad_thresh: float = 0.0002
densify_grad_thresh: float = 0.0008
"""threshold of positional gradient norm for densifying gaussians"""
densify_size_thresh: float = 0.01
"""below this size, gaussians are *duplicated*, otherwise split"""
Expand Down Expand Up @@ -379,8 +394,8 @@ def after_train(self, step: int):
with torch.no_grad():
# keep track of a moving average of grad norms
visible_mask = (self.radii > 0).flatten()
assert self.xys.grad is not None
grads = self.xys.grad.detach().norm(dim=-1)
assert self.xys.absgrad is not None # type: ignore
grads = self.xys.absgrad.detach().norm(dim=-1) # type: ignore
# print(f"grad norm min {grads.min().item()} max {grads.max().item()} mean {grads.mean().item()} size {grads.shape}")
if self.xys_grad_norm is None:
self.xys_grad_norm = grads
Expand Down Expand Up @@ -631,12 +646,7 @@ def _get_downscale_factor(self):
def _downscale_if_required(self, image):
d = self._get_downscale_factor()
if d > 1:
newsize = [image.shape[0] // d, image.shape[1] // d]

# torchvision can be slow to import, so we do it lazily.
import torchvision.transforms.functional as TF

return TF.resize(image.permute(2, 0, 1), newsize, antialias=None).permute(1, 2, 0)
return resize_image(image, d)
return image

def get_outputs(self, camera: Cameras) -> Dict[str, Union[torch.Tensor, List]]:
Expand Down Expand Up @@ -746,23 +756,17 @@ def get_outputs(self, camera: Cameras) -> Dict[str, Union[torch.Tensor, List]]:

return {"rgb": rgb, "depth": depth, "accumulation": accumulation, "background": background}

# Important to allow xys grads to populate properly
if self.training:
self.xys.retain_grad()

if self.config.sh_degree > 0:
viewdirs = means_crop.detach() - optimized_camera_to_world.detach()[:3, 3] # (N, 3)
viewdirs = viewdirs / viewdirs.norm(dim=-1, keepdim=True)
n = min(self.step // self.config.sh_degree_interval, self.config.sh_degree)
rgbs = spherical_harmonics(n, viewdirs, colors_crop)
rgbs = spherical_harmonics(n, viewdirs, colors_crop) # input unnormalized viewdirs
rgbs = torch.clamp(rgbs + 0.5, min=0.0) # type: ignore
else:
rgbs = torch.sigmoid(colors_crop[:, 0, :])

assert (num_tiles_hit > 0).any() # type: ignore

# apply the compensation of screen space blurring to gaussians
opacities = None
if self.config.rasterize_mode == "antialiased":
opacities = torch.sigmoid(opacities_crop) * comp[:, None]
elif self.config.rasterize_mode == "classic":
Expand Down Expand Up @@ -923,11 +927,7 @@ def get_image_metrics_and_images(
gt_rgb = self.composite_with_background(self.get_gt_img(batch["image"]), outputs["background"])
d = self._get_downscale_factor()
if d > 1:
# torchvision can be slow to import, so we do it lazily.
import torchvision.transforms.functional as TF

newsize = [batch["image"].shape[0] // d, batch["image"].shape[1] // d]
predicted_rgb = TF.resize(outputs["rgb"].permute(2, 0, 1), newsize, antialias=None).permute(1, 2, 0)
predicted_rgb = resize_image(outputs["rgb"], d)
else:
predicted_rgb = outputs["rgb"]

Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ dependencies = [
"xatlas",
"trimesh>=3.20.2",
"timm==0.6.7",
"gsplat>=0.1.9",
"gsplat>=0.1.11",
"pytorch-msssim",
"pathos",
"packaging"
Expand Down

0 comments on commit b190874

Please sign in to comment.