Skip to content

Commit

Permalink
Merge branch 'main' into main
Browse files Browse the repository at this point in the history
  • Loading branch information
jb-ye authored May 29, 2024
2 parents 6672de7 + 5491df9 commit 95b847e
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 13 deletions.
21 changes: 15 additions & 6 deletions nerfstudio/cameras/camera_optimizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,15 +157,24 @@ def apply_to_camera(self, camera: Cameras) -> torch.Tensor:
if self.config.mode == "off":
return camera.camera_to_worlds

assert camera.metadata is not None, "Must provide id of camera in its metadata"
if "cam_idx" not in camera.metadata:
# Evalutaion cams?
if camera.metadata is None or "cam_idx" not in camera.metadata:
# Viser cameras
return camera.camera_to_worlds

camera_idx = camera.metadata["cam_idx"]
adj = self(torch.tensor([camera_idx], dtype=torch.long, device=camera.device)) # type: ignore
adj = torch.cat([adj, torch.Tensor([0, 0, 0, 1])[None, None].to(adj)], dim=1)
return torch.bmm(camera.camera_to_worlds, adj)
adj = self(torch.tensor([camera_idx], dtype=torch.long)).to(camera.device) # type: ignore

return torch.cat(
[
# Apply rotation to directions in world coordinates, without touching the origin.
# Equivalent to: directions -> correction[:3,:3] @ directions
torch.bmm(adj[..., :3, :3], camera.camera_to_worlds[..., :3, :3]),
# Apply translation in world coordinate, independently of rotation.
# Equivalent to: origins -> origins + correction[:3,3]
camera.camera_to_worlds[..., :3, 3:] + adj[..., :3, 3:],
],
dim=-1,
)

def get_loss_dict(self, loss_dict: dict) -> None:
"""Add regularization"""
Expand Down
23 changes: 20 additions & 3 deletions nerfstudio/exporter/exporter_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from rich.progress import BarColumn, Progress, TaskProgressColumn, TextColumn, TimeRemainingColumn
from torch import Tensor

from nerfstudio.cameras.camera_optimizers import CameraOptimizer
from nerfstudio.cameras.cameras import Cameras
from nerfstudio.cameras.rays import RayBundle
from nerfstudio.data.datasets.base_dataset import InputDataset
Expand Down Expand Up @@ -283,11 +284,14 @@ def render_trajectory(
return images, depths


def collect_camera_poses_for_dataset(dataset: Optional[InputDataset]) -> List[Dict[str, Any]]:
def collect_camera_poses_for_dataset(
dataset: Optional[InputDataset], camera_optimizer: Optional[CameraOptimizer] = None
) -> List[Dict[str, Any]]:
"""Collects rescaled, translated and optimised camera poses for a dataset.
Args:
dataset: Dataset to collect camera poses for.
camera_optimizer: Camera optimizer that has been used for adjusting the poses
Returns:
List of dicts containing camera poses.
Expand All @@ -304,7 +308,15 @@ def collect_camera_poses_for_dataset(dataset: Optional[InputDataset]) -> List[Di
# new cameras are in cameras, whereas image paths are stored in a private member of the dataset
for idx in range(len(cameras)):
image_filename = image_filenames[idx]
transform = cameras.camera_to_worlds[idx].tolist()
if camera_optimizer is None:
transform = cameras.camera_to_worlds[idx].tolist()
else:
# print('exporting optimized camera pose for camera %d' % idx)
camera = cameras[idx : idx + 1]
assert camera.metadata is not None
camera.metadata["cam_idx"] = idx
transform = camera_optimizer.apply_to_camera(camera).tolist()[0]

frames.append(
{
"file_path": str(image_filename),
Expand All @@ -331,7 +343,12 @@ def collect_camera_poses(pipeline: VanillaPipeline) -> Tuple[List[Dict[str, Any]
eval_dataset = pipeline.datamanager.eval_dataset
assert isinstance(eval_dataset, InputDataset)

train_frames = collect_camera_poses_for_dataset(train_dataset)
camera_optimizer = None
if hasattr(pipeline.model, "camera_optimizer"):
camera_optimizer = pipeline.model.camera_optimizer

train_frames = collect_camera_poses_for_dataset(train_dataset, camera_optimizer)
# Note: returning original poses, even if --eval-mode=all
eval_frames = collect_camera_poses_for_dataset(eval_dataset)

return train_frames, eval_frames
6 changes: 2 additions & 4 deletions nerfstudio/models/splatfacto.py
Original file line number Diff line number Diff line change
Expand Up @@ -672,10 +672,10 @@ def get_outputs(self, camera: Cameras) -> Dict[str, Union[torch.Tensor, List]]:
return {}
assert camera.shape[0] == 1, "Only one camera at a time"

optimized_camera_to_world = self.camera_optimizer.apply_to_camera(camera)[0, ...]

# get the background color
if self.training:
optimized_camera_to_world = self.camera_optimizer.apply_to_camera(camera)[0, ...]

if self.config.background_color == "random":
background = torch.rand(3, device=self.device)
elif self.config.background_color == "white":
Expand All @@ -685,8 +685,6 @@ def get_outputs(self, camera: Cameras) -> Dict[str, Union[torch.Tensor, List]]:
else:
background = self.background_color.to(self.device)
else:
optimized_camera_to_world = camera.camera_to_worlds[0, ...]

if renderers.BACKGROUND_COLOR_OVERRIDE is not None:
background = renderers.BACKGROUND_COLOR_OVERRIDE.to(self.device)
else:
Expand Down

0 comments on commit 95b847e

Please sign in to comment.