
Commit

Fix nerfstudio-project#3064 and some minor changes
Jianbo Ye committed May 1, 2024
1 parent 759fda8 commit 7a08c24
Showing 3 changed files with 17 additions and 16 deletions.
5 changes: 4 additions & 1 deletion nerfstudio/data/datamanagers/full_images_datamanager.py
@@ -157,9 +157,12 @@ def _load_images(
         def undistort_idx(idx: int) -> Dict[str, torch.Tensor]:
             data = dataset.get_data(idx, image_type=self.config.cache_images_type)
             camera = dataset.cameras[idx].reshape(())
-            K = camera.get_intrinsics_matrices().numpy()
+            assert data["image"].shape[1] == camera.width.item() and data["image"].shape[0] == camera.height.item(), \
+                f'The size of image ({data["image"].shape[1]}, {data["image"].shape[0]}) loaded ' \
+                f'does not match the camera parameters ({camera.width.item(), camera.height.item()})'
             if camera.distortion_params is None:
                 return data
+            K = camera.get_intrinsics_matrices().numpy()
             distortion_params = camera.distortion_params.numpy()
             image = data["image"].numpy()
 
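The added assert fails fast when an image on disk no longer matches the recorded camera resolution (for example after an external resize), instead of surfacing a confusing shape error later during undistortion. A minimal standalone sketch of the same check, using hypothetical dimensions in place of nerfstudio's Cameras object:

import numpy as np

# Hypothetical camera metadata; images are (H, W, C) arrays, so shape[1] must
# equal the camera width and shape[0] the camera height.
width, height = 640, 480
image = np.zeros((height, width, 3), dtype=np.uint8)

assert image.shape[1] == width and image.shape[0] == height, (
    f"The size of image ({image.shape[1]}, {image.shape[0]}) loaded "
    f"does not match the camera parameters ({width}, {height})"
)
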
24 changes: 10 additions & 14 deletions nerfstudio/models/splatfacto.py
@@ -649,6 +649,13 @@ def _downscale_if_required(self, image):
             return resize_image(image, d)
         return image
 
+    @staticmethod
+    def get_empty_outputs(camera, background):
+        rgb = background.repeat(int(camera.height.item()), int(camera.width.item()), 1)
+        depth = background.new_ones(*rgb.shape[:2], 1) * 10
+        accumulation = background.new_zeros(*rgb.shape[:2], 1)
+        return {"rgb": rgb, "depth": depth, "accumulation": accumulation, "background": background}
+
     def get_outputs(self, camera: Cameras) -> Dict[str, Union[torch.Tensor, List]]:
         """Takes in a Ray Bundle and returns a dictionary of outputs.
@@ -687,10 +694,7 @@ def get_outputs(self, camera: Cameras) -> Dict[str, Union[torch.Tensor, List]]:
         if self.crop_box is not None and not self.training:
             crop_ids = self.crop_box.within(self.means).squeeze()
             if crop_ids.sum() == 0:
-                rgb = background.repeat(int(camera.height.item()), int(camera.width.item()), 1)
-                depth = background.new_ones(*rgb.shape[:2], 1) * 10
-                accumulation = background.new_zeros(*rgb.shape[:2], 1)
-                return {"rgb": rgb, "depth": depth, "accumulation": accumulation, "background": background}
+                return self.get_empty_outputs(camera, background)
         else:
             crop_ids = None
         camera_downscale = self._get_downscale_factor()
@@ -750,11 +754,7 @@ def get_outputs(self, camera: Cameras) -> Dict[str, Union[torch.Tensor, List]]:
         camera.rescale_output_resolution(camera_downscale)
 
         if (self.radii).sum() == 0:
-            rgb = background.repeat(H, W, 1)
-            depth = background.new_ones(*rgb.shape[:2], 1) * 10
-            accumulation = background.new_zeros(*rgb.shape[:2], 1)
-
-            return {"rgb": rgb, "depth": depth, "accumulation": accumulation, "background": background}
+            return self.get_empty_outputs(camera, background)
 
         if self.config.sh_degree > 0:
             viewdirs = means_crop.detach() - optimized_camera_to_world.detach()[:3, 3]  # (N, 3)
@@ -925,11 +925,7 @@ def get_image_metrics_and_images(
             A dictionary of metrics.
         """
         gt_rgb = self.composite_with_background(self.get_gt_img(batch["image"]), outputs["background"])
-        d = self._get_downscale_factor()
-        if d > 1:
-            predicted_rgb = resize_image(outputs["rgb"], d)
-        else:
-            predicted_rgb = outputs["rgb"]
+        predicted_rgb = outputs["rgb"]
 
         combined_rgb = torch.cat([gt_rgb, predicted_rgb], dim=1)
 
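The three inline early-return blocks are consolidated into the new get_empty_outputs helper, which produces a background-only render whenever nothing falls inside the crop box or no Gaussian projects into view. A standalone sketch of what that "empty" output contains, using hypothetical sizes and a hypothetical background color in place of the model's real values:

import torch

height, width = 4, 6                        # stand-ins for camera.height/width
background = torch.tensor([0.1, 0.2, 0.3])  # stand-in for the model's background color

rgb = background.repeat(height, width, 1)               # (H, W, 3) solid background image
depth = background.new_ones(height, width, 1) * 10      # (H, W, 1) constant far depth
accumulation = background.new_zeros(height, width, 1)   # (H, W, 1) zero opacity everywhere

outputs = {"rgb": rgb, "depth": depth, "accumulation": accumulation, "background": background}
assert outputs["rgb"].shape == (height, width, 3)
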
4 changes: 3 additions & 1 deletion nerfstudio/viewer/render_state_machine.py
@@ -179,7 +179,9 @@ def _render_img(self, camera_state: CameraState):
 
         desired_depth_pixels = {"low_move": 128, "low_static": 128, "high": 512}[self.state] ** 2
         current_depth_pixels = outputs["depth"].shape[0] * outputs["depth"].shape[1]
-        scale = min(desired_depth_pixels / current_depth_pixels, 1.0)
+
+        # from the panel of ns-viewer, it is possible for user to enter zero resolution
+        scale = min(desired_depth_pixels / max(1, current_depth_pixels), 1.0)
 
         outputs["gl_z_buf_depth"] = F.interpolate(
             outputs["depth"].squeeze(dim=-1)[None, None, ...],
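The max(1, ...) guard keeps the scale computation from dividing by zero when the viewer panel is given a zero resolution; in that case the depth buffer is simply left unscaled. A small sketch with hypothetical depth tensors, mirroring the changed lines:

import torch
import torch.nn.functional as F

def compute_scale(depth: torch.Tensor, desired_depth_pixels: int) -> float:
    # Guarded division: a 0x0 depth buffer yields scale 1.0 instead of ZeroDivisionError.
    current_depth_pixels = depth.shape[0] * depth.shape[1]
    return min(desired_depth_pixels / max(1, current_depth_pixels), 1.0)

depth = torch.rand(720, 1280, 1)                 # stand-in for outputs["depth"]
scale = compute_scale(depth, 512**2)
resized = F.interpolate(depth.squeeze(dim=-1)[None, None, ...], scale_factor=scale, mode="bilinear")

empty = torch.zeros(0, 0, 1)                     # user-entered zero resolution
print(compute_scale(empty, 512**2))              # 1.0 instead of a crash
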
