diff --git a/nerfstudio/scripts/render.py b/nerfstudio/scripts/render.py index 5e7a52e015..d9d815ed71 100644 --- a/nerfstudio/scripts/render.py +++ b/nerfstudio/scripts/render.py @@ -48,11 +48,14 @@ from torch import Tensor from typing_extensions import Annotated +import viser.transforms as tf + from nerfstudio.cameras.camera_paths import ( get_interpolated_camera_path, get_path_from_json, get_spiral_path, ) + from nerfstudio.cameras.cameras import Cameras, CameraType, RayBundle from nerfstudio.data.datamanagers.base_datamanager import ( VanillaDataManager, @@ -84,6 +87,8 @@ def _render_trajectory_video( depth_near_plane: Optional[float] = None, depth_far_plane: Optional[float] = None, colormap_options: colormaps.ColormapOptions = colormaps.ColormapOptions(), + render_nearest_camera=False, + check_occlusions: bool = False, ) -> None: """Helper function to create a video of the spiral trajectory. @@ -99,6 +104,8 @@ def _render_trajectory_video( depth_near_plane: Closest depth to consider when using the colormap for depth. If None, use min value. depth_far_plane: Furthest depth to consider when using the colormap for depth. If None, use max value. colormap_options: Options for colormap. + render_nearest_camera: Whether to render the nearest training camera to the rendered camera. + check_occlusions: If true, checks line-of-sight occlusions when computing camera distance and rejects cameras not visible to each other """ CONSOLE.print("[bold green]Creating trajectory " + output_format) cameras.rescale_output_resolution(rendered_resolution_scaling_factor) @@ -132,6 +139,14 @@ def _render_trajectory_video( with ExitStack() as stack: writer = None + if render_nearest_camera: + assert pipeline.datamanager.train_dataset is not None + train_dataset = pipeline.datamanager.train_dataset + train_cameras = train_dataset.cameras.to(pipeline.device) + else: + train_dataset = None + train_cameras = None + with progress: for camera_idx in progress.track(range(cameras.size), description=""): obb_box = None @@ -139,6 +154,50 @@ def _render_trajectory_video( obb_box = crop_data.obb camera_ray_bundle = cameras.generate_rays(camera_indices=camera_idx, obb_box=obb_box) + max_dist, max_idx = -1, -1 + true_max_dist, true_max_idx = -1, -1 + + if render_nearest_camera: + assert pipeline.datamanager.train_dataset is not None + assert train_dataset is not None + assert train_cameras is not None + cam_pos = cameras[camera_idx].camera_to_worlds[:, 3].cpu() + cam_quat = tf.SO3.from_matrix(cameras[camera_idx].camera_to_worlds[:3, :3].numpy(force=True)).wxyz + + for i in range(len(train_cameras)): + train_cam_pos = train_cameras[i].camera_to_worlds[:, 3].cpu() + # Make sure the line of sight from rendered cam to training cam is not blocked by any object + bundle = RayBundle( + origins=cam_pos.view(1, 3), + directions=((cam_pos - train_cam_pos) / (cam_pos - train_cam_pos).norm()).view(1, 3), + pixel_area=torch.tensor(1).view(1, 1), + nears=torch.tensor(0.05).view(1, 1), + fars=torch.tensor(100).view(1, 1), + camera_indices=torch.tensor(0).view(1, 1), + metadata={}, + ).to(pipeline.device) + outputs = pipeline.model.get_outputs(bundle) + + q = tf.SO3.from_matrix(train_cameras[i].camera_to_worlds[:3, :3].numpy(force=True)).wxyz + # calculate distance between two quaternions + rot_dist = 1 - np.dot(q, cam_quat) ** 2 + pos_dist = torch.norm(train_cam_pos - cam_pos) + dist = 0.3 * rot_dist + 0.7 * pos_dist + + if true_max_dist == -1 or dist < true_max_dist: + true_max_dist = dist + true_max_idx = i + + if outputs["depth"][0] < torch.norm(cam_pos - train_cam_pos).item(): + continue + + if check_occlusions and (max_dist == -1 or dist < max_dist): + max_dist = dist + max_idx = i + + if max_idx == -1: + max_idx = true_max_idx + if crop_data is not None: with renderers.background_color_override_context( crop_data.background_color.to(pipeline.device) @@ -181,6 +240,28 @@ def _render_trajectory_video( .numpy() ) render_image.append(output_image) + + # Add closest training image to the right of the rendered image + if render_nearest_camera: + assert train_dataset is not None + assert train_cameras is not None + img = train_dataset.get_image(max_idx) + height = cameras.image_height[0] + # maintain the resolution of the img to calculate the width from the height + width = int(img.shape[1] * (height / img.shape[0])) + resized_image = torch.nn.functional.interpolate( + img.permute(2, 0, 1)[None], size=(int(height), int(width)) + )[0].permute(1, 2, 0) + resized_image = ( + colormaps.apply_colormap( + image=resized_image, + colormap_options=colormap_options, + ) + .cpu() + .numpy() + ) + render_image.append(resized_image) + render_image = np.concatenate(render_image, axis=1) if output_format == "images": if image_format == "png": @@ -354,6 +435,10 @@ class BaseRender: """Furthest depth to consider when using the colormap for depth. If None, use max value.""" colormap_options: colormaps.ColormapOptions = colormaps.ColormapOptions() """Colormap options.""" + render_nearest_camera: bool = False + """Whether to render the nearest training camera to the rendered camera.""" + check_occlusions: bool = False + """If true, checks line-of-sight occlusions when computing camera distance and rejects cameras not visible to each other""" @dataclass @@ -418,6 +503,8 @@ def main(self) -> None: depth_near_plane=self.depth_near_plane, depth_far_plane=self.depth_far_plane, colormap_options=self.colormap_options, + render_nearest_camera=self.render_nearest_camera, + check_occlusions=self.check_occlusions, ) if ( @@ -451,6 +538,8 @@ def main(self) -> None: depth_near_plane=self.depth_near_plane, depth_far_plane=self.depth_far_plane, colormap_options=self.colormap_options, + render_nearest_camera=self.render_nearest_camera, + check_occlusions=self.check_occlusions, ) self.output_path = Path(str(left_eye_path.parent)[:-5] + ".mp4") @@ -549,6 +638,8 @@ def main(self) -> None: depth_near_plane=self.depth_near_plane, depth_far_plane=self.depth_far_plane, colormap_options=self.colormap_options, + render_nearest_camera=self.render_nearest_camera, + check_occlusions=self.check_occlusions, ) @@ -592,6 +683,8 @@ def main(self) -> None: depth_near_plane=self.depth_near_plane, depth_far_plane=self.depth_far_plane, colormap_options=self.colormap_options, + render_nearest_camera=self.render_nearest_camera, + check_occlusions=self.check_occlusions, )