Skip to content

Commit

Permalink
Merge branch 'main' into main
Browse files: browse the repository at this point in the history
  • Loading branch information
jb-ye authored Sep 12, 2024
2 parents 452df25 + 194b5d4 commit 9751d0b
Show file tree
Hide file tree
Showing 5 changed files with 31 additions and 17 deletions.
4 changes: 3 additions & 1 deletion .github/workflows/build_docker_image.yml
Original file line number | Diff line number | Diff line change
Expand Up @@ -47,6 +47,7 @@ jobs:
- name: Build and push Docker image
id: push
uses: docker/build-push-action@3b5e8027fcad23fda98b2e3ac259d8d67585f671
if:
with:
context: .
file: ./Dockerfile
Expand All @@ -55,8 +56,9 @@ jobs:
labels: ${{ steps.meta.outputs.labels }}
- name: Generate artifact attestation
uses: actions/attest-build-provenance@v1
if: ${{ github.event_name != 'pull_request' }}
with:
subject-name: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME}}
subject-digest: ${{ steps.push.outputs.digest }}
push-to-registry: ${{ github.event_name != 'pull_request' }}
push-to-registry: true

18 changes: 9 additions & 9 deletions nerfstudio/cameras/cameras.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -778,15 +778,15 @@ def _compute_rays_for_vr180(

return vr180_origins, directions_stack

for cam in cam_types:
if CameraType.PERSPECTIVE.value in cam_types:
for cam_type in cam_types:
if CameraType.PERSPECTIVE.value == cam_type:
mask = (self.camera_type[true_indices] == CameraType.PERSPECTIVE.value).squeeze(-1) # (num_rays)
mask = torch.stack([mask, mask, mask], dim=0)
directions_stack[..., 0][mask] = torch.masked_select(coord_stack[..., 0], mask).float()
directions_stack[..., 1][mask] = torch.masked_select(coord_stack[..., 1], mask).float()
directions_stack[..., 2][mask] = -1.0

elif CameraType.FISHEYE.value in cam_types:
elif CameraType.FISHEYE.value == cam_type:
mask = (self.camera_type[true_indices] == CameraType.FISHEYE.value).squeeze(-1) # (num_rays)
mask = torch.stack([mask, mask, mask], dim=0)

Expand All @@ -803,7 +803,7 @@ def _compute_rays_for_vr180(
).float()
directions_stack[..., 2][mask] = -torch.masked_select(torch.cos(theta), mask).float()

elif CameraType.EQUIRECTANGULAR.value in cam_types:
elif CameraType.EQUIRECTANGULAR.value == cam_type:
mask = (self.camera_type[true_indices] == CameraType.EQUIRECTANGULAR.value).squeeze(-1) # (num_rays)
mask = torch.stack([mask, mask, mask], dim=0)

Expand All @@ -816,22 +816,22 @@ def _compute_rays_for_vr180(
directions_stack[..., 1][mask] = torch.masked_select(torch.cos(phi), mask).float()
directions_stack[..., 2][mask] = torch.masked_select(-torch.cos(theta) * torch.sin(phi), mask).float()

elif CameraType.OMNIDIRECTIONALSTEREO_L.value in cam_types:
elif CameraType.OMNIDIRECTIONALSTEREO_L.value == cam_type:
ods_origins_circle, directions_stack = _compute_rays_for_omnidirectional_stereo("left")
# assign final camera origins
c2w[..., :3, 3] = ods_origins_circle

elif CameraType.OMNIDIRECTIONALSTEREO_R.value in cam_types:
elif CameraType.OMNIDIRECTIONALSTEREO_R.value == cam_type:
ods_origins_circle, directions_stack = _compute_rays_for_omnidirectional_stereo("right")
# assign final camera origins
c2w[..., :3, 3] = ods_origins_circle

elif CameraType.VR180_L.value in cam_types:
elif CameraType.VR180_L.value == cam_type:
vr180_origins, directions_stack = _compute_rays_for_vr180("left")
# assign final camera origins
c2w[..., :3, 3] = vr180_origins

elif CameraType.VR180_R.value in cam_types:
elif CameraType.VR180_R.value == cam_type:
vr180_origins, directions_stack = _compute_rays_for_vr180("right")
# assign final camera origins
c2w[..., :3, 3] = vr180_origins
Expand Down Expand Up @@ -880,7 +880,7 @@ def _compute_rays_for_vr180(
directions_stack[coord_mask] = camera_utils.fisheye624_unproject(masked_coords, camera_params)

else:
raise ValueError(f"Camera type {cam} not supported.")
raise ValueError(f"Camera type {cam_type} not supported.")

assert directions_stack.shape == (3,) + num_rays_shape + (3,)

Expand Down
8 changes: 6 additions & 2 deletions nerfstudio/engine/trainer.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -86,6 +86,8 @@ class TrainerConfig(ExperimentConfig):
"""Optionally log gradients during training"""
gradient_accumulation_steps: Dict[str, int] = field(default_factory=lambda: {})
"""Number of steps to accumulate gradients over. Contains a mapping of {param_group:num}"""
start_paused: bool = False
"""Whether to start the training in a paused state."""


class Trainer:
Expand Down Expand Up @@ -121,7 +123,9 @@ def __init__(self, config: TrainerConfig, local_rank: int = 0, world_size: int =
self.device += f":{local_rank}"
self.mixed_precision: bool = self.config.mixed_precision
self.use_grad_scaler: bool = self.mixed_precision or self.config.use_grad_scaler
self.training_state: Literal["training", "paused", "completed"] = "training"
self.training_state: Literal["training", "paused", "completed"] = (
"paused" if self.config.start_paused else "training"
)
self.gradient_accumulation_steps: DefaultDict = defaultdict(lambda: 1)
self.gradient_accumulation_steps.update(self.config.gradient_accumulation_steps)

Expand Down Expand Up @@ -361,7 +365,7 @@ def _init_viewer_state(self) -> None:
assert self.viewer_state and self.pipeline.datamanager.train_dataset
self.viewer_state.init_scene(
train_dataset=self.pipeline.datamanager.train_dataset,
train_state="training",
train_state=self.training_state,
eval_dataset=self.pipeline.datamanager.eval_dataset,
)

Expand Down
12 changes: 9 additions & 3 deletions nerfstudio/viewer/viewer.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -103,8 +103,10 @@ def __init__(
self.output_type_changed = True
self.output_split_type_changed = True
self.step = 0
self.train_btn_state: Literal["training", "paused", "completed"] = "training"
self._prev_train_state: Literal["training", "paused", "completed"] = "training"
self.train_btn_state: Literal["training", "paused", "completed"] = (
"training" if self.trainer is None else self.trainer.training_state
)
self._prev_train_state: Literal["training", "paused", "completed"] = self.train_btn_state
self.last_move_time = 0
# track the camera index that last being clicked
self.current_camera_idx = 0
Expand Down Expand Up @@ -174,7 +176,11 @@ def __init__(
)
self.resume_train.on_click(lambda _: self.toggle_pause_button())
self.resume_train.on_click(lambda han: self._toggle_training_state(han))
self.resume_train.visible = False
if self.train_btn_state == "training":
self.resume_train.visible = False
else:
self.pause_train.visible = False

# Add buttons to toggle training image visibility
self.hide_images = self.viser_server.gui.add_button(
label="Hide Train Cams", disabled=False, icon=viser.Icon.EYE_OFF, color=None
Expand Down
6 changes: 4 additions & 2 deletions nerfstudio/viewer_legacy/server/viewer_state.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -116,8 +116,10 @@ def __init__(
self.output_type_changed = True
self.output_split_type_changed = True
self.step = 0
self.train_btn_state: Literal["training", "paused", "completed"] = "training"
self._prev_train_state: Literal["training", "paused", "completed"] = "training"
self.train_btn_state: Literal["training", "paused", "completed"] = (
"training" if self.trainer is None else self.trainer.training_state
)
self._prev_train_state: Literal["training", "paused", "completed"] = self.train_btn_state

self.camera_message = None

Expand Down

0 comments on commit 9751d0b

Please sign in to comment.