Skip to content

Commit

Permalink
fix CI pyright issues
Browse files Browse the repository at this point in the history
  • Loading branch information
Jing committed Dec 7, 2023
1 parent 4719567 commit 0ee50a8
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 3 deletions.
2 changes: 1 addition & 1 deletion nerfstudio/engine/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,7 +284,7 @@ def train(self) -> None:
# (https://pytorch.org/docs/stable/notes/cuda.html#cuda-memory-management)
# for more details about GPU memory management.
writer.put_scalar(
name="GPU Memory (MB)", scalar=torch.cuda.max_memory_allocated() / (1024 ** 2), step=step
name="GPU Memory (MB)", scalar=torch.cuda.max_memory_allocated() / (1024**2), step=step
)

# Do not perform evaluation if there are no validation images
Expand Down
4 changes: 2 additions & 2 deletions nerfstudio/scripts/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ def _distributed_worker(
dist_url: str,
config: TrainerConfig,
timeout: timedelta = DEFAULT_TIMEOUT,
device_type: Literal["cpu", "cuda", "mps"] = "cuda",
device_type: Literal["cpu", "cuda", "mps", "xpu"] = "cuda",
) -> Any:
"""Spawned distributed worker that handles the initialization of process group and handles the
training process on multiple processes.
Expand Down Expand Up @@ -165,7 +165,7 @@ def launch(
dist_url: str = "auto",
config: Optional[TrainerConfig] = None,
timeout: timedelta = DEFAULT_TIMEOUT,
device_type: Literal["cpu", "cuda", "mps"] = "cuda",
device_type: Literal["cpu", "cuda", "mps", "xpu"] = "cuda",
) -> None:
"""Function that spawns multiple processes to call on main_func
Expand Down

0 comments on commit 0ee50a8

Please sign in to comment.