From 57bbc0bcc49abd13378c5347b09ffa6b2935c279 Mon Sep 17 00:00:00 2001 From: jb-ye Date: Thu, 22 Aug 2024 02:03:55 +0000 Subject: [PATCH] use gsplat strategy interface to simplify splatfacto --- nerfstudio/configs/method_configs.py | 1 - nerfstudio/models/splatfacto.py | 311 ++++----------------------- pyproject.toml | 2 +- 3 files changed, 43 insertions(+), 271 deletions(-) diff --git a/nerfstudio/configs/method_configs.py b/nerfstudio/configs/method_configs.py index e77ab130c46..c5f471c8b15 100644 --- a/nerfstudio/configs/method_configs.py +++ b/nerfstudio/configs/method_configs.py @@ -663,7 +663,6 @@ ), model=SplatfactoModelConfig( cull_alpha_thresh=0.005, - continue_cull_post_densification=False, densify_grad_thresh=0.0006, ), ), diff --git a/nerfstudio/models/splatfacto.py b/nerfstudio/models/splatfacto.py index a6881d42449..63caf360c9f 100644 --- a/nerfstudio/models/splatfacto.py +++ b/nerfstudio/models/splatfacto.py @@ -25,7 +25,7 @@ import numpy as np import torch -from gsplat.cuda_legacy._torch_impl import quat_to_rotmat +from gsplat.strategy import DefaultStrategy try: from gsplat.rendering import rasterization @@ -135,8 +135,6 @@ class SplatfactoModelConfig(ModelConfig): """threshold of opacity for culling gaussians. One can set it to a lower value (e.g. 0.005) for higher quality.""" cull_scale_thresh: float = 0.5 """threshold of scale for culling huge gaussians""" - continue_cull_post_densification: bool = True - """If True, continue to cull gaussians post refinement""" reset_alpha_every: int = 30 """Every this many refinement steps, reset the alpha""" densify_grad_thresh: float = 0.0008 @@ -216,8 +214,6 @@ def populate_modules(self): means = torch.nn.Parameter(self.seed_points[0]) # (Location, Color) else: means = torch.nn.Parameter((torch.rand((self.config.num_random, 3)) - 0.5) * self.config.random_scale) - self.xys_grad_norm = None - self.max_2Dsize = None distances, _ = self.k_nearest_sklearn(means.data, 3) distances = torch.from_numpy(distances) # find the average of the three nearest neighbors for each point and use that as the scale @@ -286,6 +282,26 @@ def populate_modules(self): grid_W=self.config.grid_shape[2], ) + # Strategy for GS densification + self.strategy = DefaultStrategy( + prune_opa=self.config.cull_alpha_thresh, + grow_grad2d=self.config.densify_grad_thresh, + grow_scale3d=self.config.densify_size_thresh, + grow_scale2d=self.config.split_screen_size, + prune_scale3d=self.config.cull_scale_thresh, + prune_scale2d=self.config.cull_screen_size, + refine_scale2d_stop_iter=self.config.stop_screen_size_at, + refine_start_iter=self.config.warmup_length, + refine_stop_iter=self.config.stop_split_at, + reset_every=self.config.reset_alpha_every * self.config.refine_every, + refine_every=self.config.refine_every, + pause_refine_after_reset=self.num_train_data, + absgrad=True, + revised_opacity=False, + verbose=True, + ) + self.strategy_state = self.strategy.initialize_state(scene_scale=1.0) + @property def colors(self): if self.config.sh_degree > 0: @@ -364,87 +380,6 @@ def k_nearest_sklearn(self, x: torch.Tensor, k: int): # Exclude the point itself from the result and return return distances[:, 1:].astype(np.float32), indices[:, 1:].astype(np.float32) - def remove_from_optim(self, optimizer, deleted_mask, new_params): - """removes the deleted_mask from the optimizer provided""" - assert len(new_params) == 1 - # assert isinstance(optimizer, torch.optim.Adam), "Only works with Adam" - - param = optimizer.param_groups[0]["params"][0] - param_state = optimizer.state[param] - del optimizer.state[param] - - # Modify the state directly without deleting and reassigning. - if "exp_avg" in param_state: - param_state["exp_avg"] = param_state["exp_avg"][~deleted_mask] - param_state["exp_avg_sq"] = param_state["exp_avg_sq"][~deleted_mask] - - # Update the parameter in the optimizer's param group. - del optimizer.param_groups[0]["params"][0] - del optimizer.param_groups[0]["params"] - optimizer.param_groups[0]["params"] = new_params - optimizer.state[new_params[0]] = param_state - - def remove_from_all_optim(self, optimizers, deleted_mask): - param_groups = self.get_gaussian_param_groups() - for group, param in param_groups.items(): - self.remove_from_optim(optimizers.optimizers[group], deleted_mask, param) - torch.cuda.empty_cache() - - def dup_in_optim(self, optimizer, dup_mask, new_params, n=2): - """adds the parameters to the optimizer""" - param = optimizer.param_groups[0]["params"][0] - param_state = optimizer.state[param] - if "exp_avg" in param_state: - repeat_dims = (n,) + tuple(1 for _ in range(param_state["exp_avg"].dim() - 1)) - param_state["exp_avg"] = torch.cat( - [ - param_state["exp_avg"], - torch.zeros_like(param_state["exp_avg"][dup_mask.squeeze()]).repeat(*repeat_dims), - ], - dim=0, - ) - param_state["exp_avg_sq"] = torch.cat( - [ - param_state["exp_avg_sq"], - torch.zeros_like(param_state["exp_avg_sq"][dup_mask.squeeze()]).repeat(*repeat_dims), - ], - dim=0, - ) - del optimizer.state[param] - optimizer.state[new_params[0]] = param_state - optimizer.param_groups[0]["params"] = new_params - del param - - def dup_in_all_optim(self, optimizers, dup_mask, n): - param_groups = self.get_gaussian_param_groups() - for group, param in param_groups.items(): - self.dup_in_optim(optimizers.optimizers[group], dup_mask, param, n) - - def after_train(self, step: int): - assert step == self.step - # to save some training time, we no longer need to update those stats post refinement - if self.step >= self.config.stop_split_at: - return - with torch.no_grad(): - # keep track of a moving average of grad norms - visible_mask = (self.radii > 0).flatten() - grads = self.xys.absgrad[0][visible_mask].norm(dim=-1) # type: ignore - # print(f"grad norm min {grads.min().item()} max {grads.max().item()} mean {grads.mean().item()} size {grads.shape}") - if self.xys_grad_norm is None: - self.xys_grad_norm = torch.zeros(self.num_points, device=self.device, dtype=torch.float32) - self.vis_counts = torch.ones(self.num_points, device=self.device, dtype=torch.float32) - assert self.vis_counts is not None - self.vis_counts[visible_mask] += 1 - self.xys_grad_norm[visible_mask] += grads - # update the max screen size, as a ratio of number of pixels - if self.max_2Dsize is None: - self.max_2Dsize = torch.zeros_like(self.radii, dtype=torch.float32) - newradii = self.radii.detach()[visible_mask] - self.max_2Dsize[visible_mask] = torch.maximum( - self.max_2Dsize[visible_mask], - newradii / float(max(self.last_size[0], self.last_size[1])), - ) - def set_crop(self, crop_box: Optional[OrientedBox]): self.crop_box = crop_box @@ -452,199 +387,39 @@ def set_background(self, background_color: torch.Tensor): assert background_color.shape == (3,) self.background_color = background_color - def refinement_after(self, optimizers: Optimizers, step): + def step_post_backward(self, step): assert step == self.step - if self.step <= self.config.warmup_length: - return - with torch.no_grad(): - # Offset all the opacity reset logic by refine_every so that we don't - # save checkpoints right when the opacity is reset (saves every 2k) - # then cull - # only split/cull if we've seen every image since opacity reset - reset_interval = self.config.reset_alpha_every * self.config.refine_every - do_densification = ( - self.step < self.config.stop_split_at - and self.step % reset_interval > self.num_train_data + self.config.refine_every - ) - if do_densification: - # then we densify - assert self.xys_grad_norm is not None and self.vis_counts is not None and self.max_2Dsize is not None - avg_grad_norm = (self.xys_grad_norm / self.vis_counts) * 0.5 * max(self.last_size[0], self.last_size[1]) - high_grads = (avg_grad_norm > self.config.densify_grad_thresh).squeeze() - splits = (self.scales.exp().max(dim=-1).values > self.config.densify_size_thresh).squeeze() - splits &= high_grads - if self.step < self.config.stop_screen_size_at: - splits |= (self.max_2Dsize > self.config.split_screen_size).squeeze() - nsamps = self.config.n_split_samples - split_params = self.split_gaussians(splits, nsamps) - - dups = (self.scales.exp().max(dim=-1).values <= self.config.densify_size_thresh).squeeze() - dups &= high_grads - dup_params = self.dup_gaussians(dups) - for name, param in self.gauss_params.items(): - self.gauss_params[name] = torch.nn.Parameter( - torch.cat([param.detach(), split_params[name], dup_params[name]], dim=0) - ) - # append zeros to the max_2Dsize tensor - self.max_2Dsize = torch.cat( - [ - self.max_2Dsize, - torch.zeros_like(split_params["scales"][:, 0]), - torch.zeros_like(dup_params["scales"][:, 0]), - ], - dim=0, - ) - - split_idcs = torch.where(splits)[0] - self.dup_in_all_optim(optimizers, split_idcs, nsamps) - - dup_idcs = torch.where(dups)[0] - self.dup_in_all_optim(optimizers, dup_idcs, 1) - - # After a guassian is split into two new gaussians, the original one should also be pruned. - splits_mask = torch.cat( - ( - splits, - torch.zeros( - nsamps * splits.sum() + dups.sum(), - device=self.device, - dtype=torch.bool, - ), - ) - ) - - deleted_mask = self.cull_gaussians(splits_mask) - elif self.step >= self.config.stop_split_at and self.config.continue_cull_post_densification: - deleted_mask = self.cull_gaussians() - else: - # if we donot allow culling post refinement, no more gaussians will be pruned. - deleted_mask = None - - if deleted_mask is not None: - self.remove_from_all_optim(optimizers, deleted_mask) - - if self.step < self.config.stop_split_at and self.step % reset_interval == self.config.refine_every: - # Reset value is set to be twice of the cull_alpha_thresh - reset_value = self.config.cull_alpha_thresh * 2.0 - self.opacities.data = torch.clamp( - self.opacities.data, - max=torch.logit(torch.tensor(reset_value, device=self.device)).item(), - ) - # reset the exp of optimizer - optim = optimizers.optimizers["opacities"] - param = optim.param_groups[0]["params"][0] - param_state = optim.state[param] - param_state["exp_avg"] = torch.zeros_like(param_state["exp_avg"]) - param_state["exp_avg_sq"] = torch.zeros_like(param_state["exp_avg_sq"]) - - self.xys_grad_norm = None - self.vis_counts = None - self.max_2Dsize = None - - def cull_gaussians(self, extra_cull_mask: Optional[torch.Tensor] = None): - """ - This function deletes gaussians with under a certain opacity threshold - extra_cull_mask: a mask indicates extra gaussians to cull besides existing culling criterion - """ - n_bef = self.num_points - # cull transparent ones - culls = (torch.sigmoid(self.opacities) < self.config.cull_alpha_thresh).squeeze() - below_alpha_count = torch.sum(culls).item() - toobigs_count = 0 - if extra_cull_mask is not None: - culls = culls | extra_cull_mask - if self.step > self.config.refine_every * self.config.reset_alpha_every: - # cull huge ones - toobigs = (torch.exp(self.scales).max(dim=-1).values > self.config.cull_scale_thresh).squeeze() - if self.step < self.config.stop_screen_size_at: - # cull big screen space - if self.max_2Dsize is not None: - toobigs = toobigs | (self.max_2Dsize > self.config.cull_screen_size).squeeze() - culls = culls | toobigs - toobigs_count = torch.sum(toobigs).item() - for name, param in self.gauss_params.items(): - self.gauss_params[name] = torch.nn.Parameter(param[~culls]) - - CONSOLE.log( - f"Culled {n_bef - self.num_points} gaussians " - f"({below_alpha_count} below alpha thresh, {toobigs_count} too bigs, {self.num_points} remaining)" + self.strategy.step_post_backward( + params=self.gauss_params, + optimizers=self.optimizers, + state=self.strategy_state, + step=self.step, + info=self.info, + packed=False, ) - return culls - - def split_gaussians(self, split_mask, samps): - """ - This function splits gaussians that are too large - """ - n_splits = split_mask.sum().item() - CONSOLE.log(f"Splitting {split_mask.sum().item()/self.num_points} gaussians: {n_splits}/{self.num_points}") - centered_samples = torch.randn((samps * n_splits, 3), device=self.device) # Nx3 of axis-aligned scales - scaled_samples = ( - torch.exp(self.scales[split_mask].repeat(samps, 1)) * centered_samples - ) # how these scales are rotated - quats = self.quats[split_mask] / self.quats[split_mask].norm(dim=-1, keepdim=True) # normalize them first - rots = quat_to_rotmat(quats.repeat(samps, 1)) # how these scales are rotated - rotated_samples = torch.bmm(rots, scaled_samples[..., None]).squeeze() - new_means = rotated_samples + self.means[split_mask].repeat(samps, 1) - # step 2, sample new colors - new_features_dc = self.features_dc[split_mask].repeat(samps, 1) - new_features_rest = self.features_rest[split_mask].repeat(samps, 1, 1) - # step 3, sample new opacities - new_opacities = self.opacities[split_mask].repeat(samps, 1) - # step 4, sample new scales - size_fac = 1.6 - new_scales = torch.log(torch.exp(self.scales[split_mask]) / size_fac).repeat(samps, 1) - self.scales[split_mask] = torch.log(torch.exp(self.scales[split_mask]) / size_fac) - # step 5, sample new quats - new_quats = self.quats[split_mask].repeat(samps, 1) - out = { - "means": new_means, - "features_dc": new_features_dc, - "features_rest": new_features_rest, - "opacities": new_opacities, - "scales": new_scales, - "quats": new_quats, - } - for name, param in self.gauss_params.items(): - if name not in out: - out[name] = param[split_mask].repeat(samps, 1) - return out - - def dup_gaussians(self, dup_mask): - """ - This function duplicates gaussians that are too small - """ - n_dups = dup_mask.sum().item() - CONSOLE.log(f"Duplicating {dup_mask.sum().item()/self.num_points} gaussians: {n_dups}/{self.num_points}") - new_dups = {} - for name, param in self.gauss_params.items(): - new_dups[name] = param[dup_mask] - return new_dups - def get_training_callbacks( self, training_callback_attributes: TrainingCallbackAttributes ) -> List[TrainingCallback]: cbs = [] - cbs.append(TrainingCallback([TrainingCallbackLocation.BEFORE_TRAIN_ITERATION], self.step_cb)) - # The order of these matters cbs.append( TrainingCallback( - [TrainingCallbackLocation.AFTER_TRAIN_ITERATION], - self.after_train, + [TrainingCallbackLocation.BEFORE_TRAIN_ITERATION], + self.step_cb, + args=[training_callback_attributes.optimizers], ) ) cbs.append( TrainingCallback( [TrainingCallbackLocation.AFTER_TRAIN_ITERATION], - self.refinement_after, - update_every_num_iters=self.config.refine_every, - args=[training_callback_attributes.optimizers], + self.step_post_backward, ) ) return cbs - def step_cb(self, step): + def step_cb(self, optimizers: Optimizers, step): self.step = step + self.optimizers = optimizers.optimizers def get_gaussian_param_groups(self) -> Dict[str, List[Parameter]]: # Here we explicitly use the means, scales as parameters so that the user can override this function and @@ -766,7 +541,6 @@ def get_outputs(self, camera: Cameras) -> Dict[str, Union[torch.Tensor, List]]: colors_crop = torch.cat((features_dc_crop[:, None, :], features_rest_crop), dim=1) - BLOCK_WIDTH = 16 # this controls the tile size of rasterization, 16 is a good default camera_scale_fac = self._get_downscale_factor() camera.rescale_output_resolution(1 / camera_scale_fac) viewmat = get_viewmat(optimized_camera_to_world) @@ -790,9 +564,9 @@ def get_outputs(self, camera: Cameras) -> Dict[str, Union[torch.Tensor, List]]: colors_crop = torch.sigmoid(colors_crop).squeeze(1) # [N, 1, 3] -> [N, 3] sh_degree_to_use = None - render, alpha, info = rasterization( + render, alpha, self.info = rasterization( means=means_crop, - quats=quats_crop / quats_crop.norm(dim=-1, keepdim=True), + quats=quats_crop, # rasterization does normalization internally scales=torch.exp(scales_crop), opacities=torch.sigmoid(opacities_crop).squeeze(-1), colors=colors_crop, @@ -800,22 +574,21 @@ def get_outputs(self, camera: Cameras) -> Dict[str, Union[torch.Tensor, List]]: Ks=K, # [1, 3, 3] width=W, height=H, - tile_size=BLOCK_WIDTH, packed=False, near_plane=0.01, far_plane=1e10, render_mode=render_mode, sh_degree=sh_degree_to_use, sparse_grad=False, - absgrad=True, + absgrad=self.strategy.absgrad, rasterize_mode=self.config.rasterize_mode, # set some threshold to disregrad small gaussians for faster rendering. # radius_clip=3.0, ) - if self.training and info["means2d"].requires_grad: - info["means2d"].retain_grad() - self.xys = info["means2d"] # [1, N, 2] - self.radii = info["radii"][0] # [N] + if self.training: + self.strategy.step_pre_backward( + self.gauss_params, self.optimizers, self.strategy_state, self.step, self.info + ) alpha = alpha[:, ...] background = self._get_background_color() diff --git a/pyproject.toml b/pyproject.toml index 753c127629c..ae3c1d4c8c2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -62,7 +62,7 @@ dependencies = [ "xatlas", "trimesh>=3.20.2", "timm==0.6.7", - "gsplat==1.0.0", + "gsplat==1.2.0", "pytorch-msssim", "pathos", "packaging",