
Render nearest training view #2384

Merged · 17 commits · Nov 13, 2023
93 changes: 93 additions & 0 deletions nerfstudio/scripts/render.py
@@ -47,13 +47,18 @@
from torch import Tensor
from typing_extensions import Annotated

from scipy.spatial.transform import Rotation
Collaborator:

for consistency might be nice to use the viser transform library

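A minimal sketch of that suggestion, assuming viser's vendored transforms module exposes viser.transforms.SO3 with a from_matrix constructor and a wxyz property (verify against the installed viser version; note scipy's as_quat returns xyzw order while viser uses wxyz):

    import numpy as np
    from viser import transforms as vtf

    R = np.eye(3)  # stand-in for camera_to_worlds[:3, :3]
    q_wxyz = vtf.SO3.from_matrix(R).wxyz  # quaternion as (w, x, y, z)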

from nerfstudio.cameras.camera_paths import (
get_interpolated_camera_path,
get_path_from_json,
get_spiral_path,
)

from nerfstudio.cameras.rays import RayBundle
from nerfstudio.cameras.cameras import Cameras, CameraType
from nerfstudio.data.datamanagers.base_datamanager import VanillaDataManager
from nerfstudio.data.datasets.base_dataset import InputDataset, DataparserOutputs
from nerfstudio.data.scene_box import SceneBox
from nerfstudio.model_components import renderers
from nerfstudio.pipelines.base_pipeline import Pipeline
@@ -75,6 +80,8 @@ def _render_trajectory_video(
image_format: Literal["jpeg", "png"] = "jpeg",
jpeg_quality: int = 100,
colormap_options: colormaps.ColormapOptions = colormaps.ColormapOptions(),
render_nearest_camera=False,
check_occlusions: bool = False,
Collaborator:
add to docstring

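A possible addition to the function docstring per the comment above (the wording is a sketch, not taken from the PR; the second entry borrows the reviewer's phrasing from later in this thread):

        render_nearest_camera: If True, render the nearest training image
            alongside each frame of the trajectory.
        check_occlusions: If true, checks line-of-sight occlusions when
            computing camera distance and rejects cameras not visible to
            each other.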
) -> None:
"""Helper function to create a video of the spiral trajectory.

@@ -121,6 +128,20 @@ def _render_trajectory_video(
with ExitStack() as stack:
writer = None

train_cameras = Cameras(
Collaborator:
this should be inside the if render_nearest_camera statement below

camera_to_worlds=torch.zeros((0, 4, 4)),
fx=torch.zeros((0,)),
fy=torch.zeros((0,)),
cx=torch.zeros((0,)),
cy=torch.zeros((0,)),
)
train_dataset = InputDataset(DataparserOutputs([], train_cameras))

if render_nearest_camera:
assert pipeline.datamanager.train_dataset is not None
train_dataset = pipeline.datamanager.train_dataset
train_cameras = train_dataset.cameras.to(pipeline.device)

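A sketch of the reviewer's suggestion above: drop the placeholder dataset and keep all setup inside the conditional. This assumes, as in this diff, that every later use of train_dataset and train_cameras is itself guarded by render_nearest_camera:

    if render_nearest_camera:
        assert pipeline.datamanager.train_dataset is not None
        train_dataset = pipeline.datamanager.train_dataset
        train_cameras = train_dataset.cameras.to(pipeline.device)
    else:
        # names stay bound even when the feature is off
        train_dataset = None
        train_cameras = None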
with progress:
for camera_idx in progress.track(range(cameras.size), description=""):
aabb_box = None
@@ -130,6 +151,49 @@
aabb_box = SceneBox(torch.stack([bounding_box_min, bounding_box_max]).to(pipeline.device))
camera_ray_bundle = cameras.generate_rays(camera_indices=camera_idx, aabb_box=aabb_box)

max_dist, max_idx = -1, -1
true_max_dist, true_max_idx = -1, -1

if render_nearest_camera:
cam_pos = cameras[camera_idx].camera_to_worlds[:, 3].cpu()
cam_rot = Rotation.from_matrix(cameras[camera_idx].camera_to_worlds[:3, :3].cpu())
cam_quat = cam_rot.as_quat()

for i in range(len(train_cameras)):
train_cam_pos = train_cameras[i].camera_to_worlds[:, 3].cpu()
# Make sure the line of sight from rendered cam to training cam is not blocked by any object
Collaborator:
Is this necessary? I think it would be nice if the result is deterministic, regardless of how the NeRF is. Have you found it to be needed in some cases?

Contributor (author):
I added a flag check-occlusions so you can toggle whether or not you want the check

bundle = RayBundle(
origins=cam_pos.view(1, 3),
directions=((cam_pos - train_cam_pos) / (cam_pos - train_cam_pos).norm()).view(1, 3),
pixel_area=torch.tensor(1).view(1, 1),
nears=torch.tensor(0.05).view(1, 1),
fars=torch.tensor(100).view(1, 1),
camera_indices=torch.tensor(0).view(1, 1),
metadata={},
).to(pipeline.device)
outputs = pipeline.model.get_outputs(bundle)
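# (editor's note, not in the PR) as written, the ray originates at the
# render camera but its direction (cam_pos - train_cam_pos) points away
# from the training camera; a line-of-sight test toward the training
# camera would normally use (train_cam_pos - cam_pos), normalized.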

r = Rotation.from_matrix(train_cameras[i].camera_to_worlds[:3, :3].cpu())
q = r.as_quat()
# calculate distance between two quaternions
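# (editor's note) 1 - <q1, q2>^2 is a standard rotation-distance proxy:
# it is zero for identical rotations and invariant to the quaternion sign
# ambiguity q ~ -q; the 0.3/0.7 weights below are a heuristic blend of
# the rotation and translation terms.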
rot_dist = 1 - np.dot(q, cam_quat) ** 2
pos_dist = torch.norm(train_cam_pos - cam_pos)
dist = 0.3 * rot_dist + 0.7 * pos_dist

if true_max_dist == -1 or dist < true_max_dist:
true_max_dist = dist
true_max_idx = i

if outputs["depth"][0] < torch.norm(cam_pos - train_cam_pos).item():
continue

if check_occlusions and (max_dist == -1 or dist < max_dist):
max_dist = dist
max_idx = i

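# (editor's note) despite the max_* names, these variables track the
# running *minimum* distance; if the occlusion test rejected every
# candidate (max_idx == -1), fall back to the unconditionally nearest
# training camera.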
if max_idx == -1:
max_idx = true_max_idx

if crop_data is not None:
with renderers.background_color_override_context(
crop_data.background_color.to(pipeline.device)
@@ -158,6 +222,23 @@
.numpy()
)
render_image.append(output_image)

# Add closest training image to the right of the rendered image
if render_nearest_camera:
img = train_dataset.get_image(max_idx)
resized_image = torch.nn.functional.interpolate(
img.permute(2, 0, 1)[None], size=(int(cameras.image_height[0]), int(cameras.image_width[0]))
Collaborator:
this doesn't work properly if the aspect ratio of the render camera differs from that of the train camera; it should resize the height to the render height and automatically calculate the width from the aspect ratio

)[0].permute(1, 2, 0)
resized_image = (
colormaps.apply_colormap(
image=resized_image,
colormap_options=colormap_options,
)
.cpu()
.numpy()
)
render_image.append(resized_image)

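A sketch of the fix suggested above: lock the height to the render height and derive the width from the training image's own aspect ratio (assumes img is an H x W x C tensor, as returned by get_image):

    target_h = int(cameras.image_height[0])
    target_w = int(round(target_h * img.shape[1] / img.shape[0]))
    resized_image = torch.nn.functional.interpolate(
        img.permute(2, 0, 1)[None], size=(target_h, target_w)
    )[0].permute(1, 2, 0)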
render_image = np.concatenate(render_image, axis=1)
if output_format == "images":
if image_format == "png":
@@ -315,6 +396,10 @@ class BaseRender:
"""Specifies number of rays per chunk during eval. If None, use the value in the config file."""
colormap_options: colormaps.ColormapOptions = colormaps.ColormapOptions()
"""Colormap options."""
render_nearest_camera: bool = False
"""Whether to render the nearest training camera to the rendered camera."""
check_occlusions: bool = False
"""Whether to check occlusions."""
Collaborator:
this is kinda vague: "if true, checks line-of-sight occlusions when computing camera distance and rejects cameras not visible to each other"

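Applying the reviewer's wording, the field documentation might read:

    check_occlusions: bool = False
    """If true, checks line-of-sight occlusions when computing camera distance
    and rejects cameras not visible to each other."""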


@dataclass
@@ -372,6 +457,8 @@ def main(self) -> None:
image_format=self.image_format,
jpeg_quality=self.jpeg_quality,
colormap_options=self.colormap_options,
render_nearest_camera=self.render_nearest_camera,
check_occlusions=self.check_occlusions,
)

if camera_path.camera_type[0] == CameraType.OMNIDIRECTIONALSTEREO_L.value:
@@ -396,6 +483,8 @@ def main(self) -> None:
image_format=self.image_format,
jpeg_quality=self.jpeg_quality,
colormap_options=self.colormap_options,
render_nearest_camera=self.render_nearest_camera,
check_occlusions=self.check_occlusions,
)

# stack the left and right eye renders for final output
@@ -471,6 +560,8 @@ def main(self) -> None:
output_format=self.output_format,
image_format=self.image_format,
colormap_options=self.colormap_options,
render_nearest_camera=self.render_nearest_camera,
check_occlusions=self.check_occlusions,
)


@@ -514,6 +605,8 @@ def main(self) -> None:
output_format=self.output_format,
image_format=self.image_format,
colormap_options=self.colormap_options,
render_nearest_camera=self.render_nearest_camera,
check_occlusions=self.check_occlusions,
)

