diff --git a/Dockerfile b/Dockerfile index aac97673cc..f1b29ea7aa 100644 --- a/Dockerfile +++ b/Dockerfile @@ -127,18 +127,18 @@ WORKDIR /home/${USERNAME} ENV PATH="${PATH}:/home/${USERNAME}/.local/bin" # Upgrade pip and install packages. -RUN python3.10 -m pip install --no-cache-dir --upgrade pip setuptools pathtools promise pybind11 omegaconf +RUN python3.10 -m pip install --no-cache-dir --upgrade pip setuptools==69.5.1 pathtools promise pybind11 omegaconf # Install pytorch and submodules # echo "${CUDA_VERSION}" | sed 's/.$//' | tr -d '.' -- CUDA_VERSION -> delete last digit -> delete all '.' RUN CUDA_VER=$(echo "${CUDA_VERSION}" | sed 's/.$//' | tr -d '.') && python3.10 -m pip install --no-cache-dir \ - torch==2.0.1+cu${CUDA_VER} \ - torchvision==0.15.2+cu${CUDA_VER} \ + torch==2.1.2+cu${CUDA_VER} \ + torchvision==0.16.2+cu${CUDA_VER} \ --extra-index-url https://download.pytorch.org/whl/cu${CUDA_VER} # Install tiny-cuda-nn (we need to set the target architectures as environment variable first). ENV TCNN_CUDA_ARCHITECTURES=${CUDA_ARCHITECTURES} -RUN python3.10 -m pip install --no-cache-dir git+https://github.com/NVlabs/tiny-cuda-nn.git@v1.6#subdirectory=bindings/torch +RUN python3.10 -m pip install --no-cache-dir git+https://github.com/NVlabs/tiny-cuda-nn.git#subdirectory=bindings/torch # Install pycolmap, required by hloc. 
RUN git clone --branch v0.4.0 --recursive https://github.com/colmap/pycolmap.git && \ diff --git a/docs/nerfology/methods/index.md b/docs/nerfology/methods/index.md index d08c7a4ab3..a8ecc7f099 100644 --- a/docs/nerfology/methods/index.md +++ b/docs/nerfology/methods/index.md @@ -28,6 +28,7 @@ The following methods are supported in nerfstudio: :maxdepth: 1 Instant-NGP Splatfacto + Splatfacto-W Instruct-NeRF2NeRF Instruct-GS2GS SIGNeRF diff --git a/docs/nerfology/methods/splat.md b/docs/nerfology/methods/splat.md index 35a88a8f7d..de37763749 100644 --- a/docs/nerfology/methods/splat.md +++ b/docs/nerfology/methods/splat.md @@ -41,7 +41,9 @@ We provide a few additional variants: | `depth-splatfacto` | Default Model, Depth Supervision | ~6GB | Fast | | `splatfacto-big` | More Gaussians, Higher Quality | ~12GB | Slower | -A full evaluation of Nerfstudio's implementation of Gaussian Splatting against the original Inria method can be found [here](https://docs.gsplat.studio/tests/eval.html). + +A full evaluation of Nerfstudio's implementation of Gaussian Splatting against the original Inria method can be found [here](https://docs.gsplat.studio/main/tests/eval.html). + #### Quality and Regularization diff --git a/docs/nerfology/methods/splatw.md b/docs/nerfology/methods/splatw.md new file mode 100644 index 0000000000..21c772a39e --- /dev/null +++ b/docs/nerfology/methods/splatw.md @@ -0,0 +1,51 @@ +# Splatfacto in the Wild + +This is the implementation of [Splatfacto in the Wild: A Nerfstudio Implementation of Gaussian Splatting for Unconstrained Photo Collections](https://kevinxu02.github.io/splatfactow). The official code can be found [here](https://github.com/KevinXu02/splatfacto-w). + + + +## Installation +This repository follows the nerfstudio method [template](https://github.com/nerfstudio-project/nerfstudio-method-template/tree/main). + +### 1. 
Install Nerfstudio dependencies +Please follow the Nerfstudio [installation guide](https://docs.nerf.studio/quickstart/installation.html) to create an environment and install dependencies. + +### 2. Install the repository +Run the following command: +`pip install git+https://github.com/KevinXu02/splatfacto-w` + +Then, run `ns-install-cli`. + +### 3. Check installation +Run `ns-train splatfacto-w --help`. You should see the help message for the splatfacto-w method. + +## Downloading data +You can download the phototourism dataset by running: +``` +ns-download-data phototourism --capture-name +``` + +## Running Splatfacto-W +To train with it, download the train/test tsv file from the bottom of the [nerf-w](https://nerf-w.github.io/) page and put it under the data folder (or copy it from `./splatfacto-w/dataset_split`). For instance, for Brandenburg Gate the path would be `your-data-folder/brandenburg_gate/brandenburg.tsv`. You should have the following structure in your data folder: +``` +|---brandenburg_gate +| |---dense +| | |---images +| | |---sparse +| | |---stereo +| |---brandenburg.tsv +``` + +Then, run the command: +``` +ns-train splatfacto-w --data [PATH] +``` + +If you want to train on datasets without nerf-w's train/test split, or on your own datasets, we provide a lightweight version of the method for general cases. To train with it, run the following command: +``` +ns-train splatfacto-w-light --data [PATH] [dataparser] +``` +For phototourism, the `dataparser` should be `colmap`, and you need to change the colmap path through the CLI because the phototourism dataparser does not load 3D points. 
\ No newline at end of file diff --git a/docs/quickstart/existing_dataset.md b/docs/quickstart/existing_dataset.md index 81ae0ebcd3..1aea4f2857 100644 --- a/docs/quickstart/existing_dataset.md +++ b/docs/quickstart/existing_dataset.md @@ -17,6 +17,7 @@ ns-download-data blender ns-download-data nerfstudio --capture-name nerfstudio-dataset # Download a few room-scale scenes from the EyefulTower dataset at different resolutions +pip install awscli # Install `awscli` for EyefulTower downloads. ns-download-data eyefultower --capture-name riverview seating_area apartment --resolution-name jpeg_1k jpeg_2k # Download the full D-NeRF dataset of dynamic synthetic scenes @@ -87,3 +88,10 @@ In the tables below, each dataset was placed into a bucket based on the table's [record3d]: https://record3d.app/ [sdfstudio]: https://github.com/autonomousvision/sdfstudio/blob/master/docs/sdfstudio-data.md#Existing-dataset [sitcoms3d]: https://github.com/ethanweber/sitcoms3D/blob/master/METADATA.md + +### Eyeful Tower +Downloading Eyeful Tower scenes requires installing the AWS CLI, an optional dependency. To do so, run: +```bash +conda activate nerfstudio +pip install awscli +``` diff --git a/docs/quickstart/installation.md b/docs/quickstart/installation.md index b7d2580ef9..d3472835f3 100644 --- a/docs/quickstart/installation.md +++ b/docs/quickstart/installation.md @@ -14,6 +14,8 @@ Install [Git](https://git-scm.com/downloads). Install Visual Studio 2022. This must be done before installing CUDA. The necessary components are included in the `Desktop Development with C++` workflow (also called `C++ Build Tools` in the BuildTools edition). +Install Visual Studio Build Tools. If MSVC 143 does not work (usually will fail if your version > 17.10), you may also need to install MSVC 142 for Visual Studio 2019. Ensure your CUDA environment is set up properly. + Nerfstudio requires `python >= 3.8`. We recommend using conda to manage dependencies. 
Make sure to install [Conda](https://docs.conda.io/en/latest/miniconda.html) before proceeding. ::::: @@ -76,14 +78,55 @@ conda install -c "nvidia/label/cuda-11.7.1" cuda-toolkit ::: :::: -### tiny-cuda-nn +### Install tiny-cuda-nn/gsplat -After pytorch and ninja, install the torch bindings for tiny-cuda-nn: +::::::{tab-set} +:::::{tab-item} Linux +After pytorch and ninja, install the torch bindings for tiny-cuda-nn: ```bash pip install ninja git+https://github.com/NVlabs/tiny-cuda-nn/#subdirectory=bindings/torch ``` +::::: +:::::{tab-item} Windows + +Activate your Visual C++ environment: +Navigate to the directory where `vcvars64.bat` is located. This path might vary depending on your installation. A common path is: + +``` +C:\Program Files (x86)\Microsoft Visual Studio\2022\BuildTools\VC\Auxiliary\Build +``` + +Run the following command: +```bash +./vcvars64.bat +``` + +If the above command does not work, try activating an older version of VC: +```bash +./vcvarsall.bat x64 -vcvars_ver= +``` +Replace `` with the version of your VC++ compiler toolset. The version number should appear in the same folder. 
+ +For example: +```bash +./vcvarsall.bat x64 -vcvars_ver=14.29 +``` + +Install `gsplat` from source: +```bash +pip install git+https://github.com/nerfstudio-project/gsplat.git +``` + +Install the torch bindings for tiny-cuda-nn: +```bash +pip install git+https://github.com/NVlabs/tiny-cuda-nn/#subdirectory=bindings/torch +``` + +::::: +:::::: + ## Installing nerfstudio **From pip** @@ -137,19 +180,33 @@ curl -fsSL https://pixi.sh/install.sh | bash ### Install Pixi Environmnent After Pixi is installed, you can run ```bash +git clone https://github.com/nerfstudio-project/nerfstudio.git +cd nerfstudio pixi run post-install pixi shell ``` -This will install all enviroment dependancies including colmap, tinycudann and hloc, and the active the conda environment +This will fetch the latest Nerfstudio code, install all environment dependencies including colmap, tinycudann and hloc, and then activate the pixi environment (similar to conda). +From now on, each time you want to run Nerfstudio in a new shell, you have to navigate to the nerfstudio folder and run `pixi shell` again. -you could also run +You could also run ```bash pixi run post-install pixi run train-example-nerf ``` -to download an example dataset and run nerfacto straight away +to download an example dataset and run nerfacto straight away. + +Note that this method gets you the very latest upstream Nerfstudio version. If you want to use a specific release, you first have to check out that version or commit in the nerfstudio folder, e.g.: +``` +git checkout tags/v1.1.3 +``` + +Similarly, to update, pull the latest changes in your nerfstudio folder: +``` +git pull +``` +Remember that if you checked out a specific tag before, you have to manually check out a new tag or run `git checkout main` to see the new changes. 
## Use docker image diff --git a/nerfstudio/cameras/cameras.py b/nerfstudio/cameras/cameras.py index 614cd4a0d2..cc04b17876 100644 --- a/nerfstudio/cameras/cameras.py +++ b/nerfstudio/cameras/cameras.py @@ -15,6 +15,7 @@ """ Camera Models """ + import base64 import math from dataclasses import dataclass diff --git a/nerfstudio/cameras/lie_groups.py b/nerfstudio/cameras/lie_groups.py index bba85d4dea..c47ee1119f 100644 --- a/nerfstudio/cameras/lie_groups.py +++ b/nerfstudio/cameras/lie_groups.py @@ -15,6 +15,7 @@ """ Helper for Lie group operations. Currently only used for pose optimization. """ + import torch from jaxtyping import Float from torch import Tensor diff --git a/nerfstudio/cameras/rays.py b/nerfstudio/cameras/rays.py index 33476f92eb..a9c38d8e61 100644 --- a/nerfstudio/cameras/rays.py +++ b/nerfstudio/cameras/rays.py @@ -15,6 +15,7 @@ """ Some ray datastructures. """ + import random from dataclasses import dataclass, field from typing import Callable, Dict, Literal, Optional, Tuple, Union, overload @@ -153,15 +154,13 @@ def get_weights(self, densities: Float[Tensor, "*batch num_samples 1"]) -> Float @staticmethod def get_weights_and_transmittance_from_alphas( alphas: Float[Tensor, "*batch num_samples 1"], weights_only: Literal[True] - ) -> Float[Tensor, "*batch num_samples 1"]: - ... + ) -> Float[Tensor, "*batch num_samples 1"]: ... @overload @staticmethod def get_weights_and_transmittance_from_alphas( alphas: Float[Tensor, "*batch num_samples 1"], weights_only: Literal[False] = False - ) -> Tuple[Float[Tensor, "*batch num_samples 1"], Float[Tensor, "*batch num_samples 1"]]: - ... + ) -> Tuple[Float[Tensor, "*batch num_samples 1"], Float[Tensor, "*batch num_samples 1"]]: ... 
@staticmethod def get_weights_and_transmittance_from_alphas( diff --git a/nerfstudio/configs/base_config.py b/nerfstudio/configs/base_config.py index f5ffddda90..e7f71d8bab 100644 --- a/nerfstudio/configs/base_config.py +++ b/nerfstudio/configs/base_config.py @@ -14,7 +14,6 @@ """Base Configs""" - from __future__ import annotations from dataclasses import dataclass, field diff --git a/nerfstudio/configs/external_methods.py b/nerfstudio/configs/external_methods.py index 8bad790eea..b25a9f16a6 100644 --- a/nerfstudio/configs/external_methods.py +++ b/nerfstudio/configs/external_methods.py @@ -268,6 +268,21 @@ class ExternalMethod: ) ) +# Splatfacto-W +external_methods.append( + ExternalMethod( + """[bold yellow]Splatfacto-W[/bold yellow] +For more information visit: https://docs.nerf.studio/nerfology/methods/splatw.html + +To enable Splatfacto-W, you must install it first by running: + [grey]pip install git+https://github.com/KevinXu02/splatfacto-w[/grey]""", + configurations=[ + ("splatfacto-w", "Splatfacto in the wild"), + ], + pip_package="git+https://github.com/KevinXu02/splatfacto-w", + ) +) + @dataclass class ExternalMethodDummyTrainerConfig: diff --git a/nerfstudio/configs/method_configs.py b/nerfstudio/configs/method_configs.py index 877ef7a9b0..dca5d7412c 100644 --- a/nerfstudio/configs/method_configs.py +++ b/nerfstudio/configs/method_configs.py @@ -69,6 +69,7 @@ method_configs: Dict[str, Union[TrainerConfig, ExternalMethodDummyTrainerConfig]] = {} descriptions = { "nerfacto": "Recommended real-time model tuned for real captures. This model will be continually updated.", + "nerfacto-huge": "Larger version of Nerfacto with higher quality.", "depth-nerfacto": "Nerfacto with depth supervision.", "instant-ngp": "Implementation of Instant-NGP. Recommended real-time model for unbounded scenes.", "instant-ngp-bounded": "Implementation of Instant-NGP. Recommended for bounded real and synthetic scenes", @@ -83,6 +84,7 @@ "neus-facto": "Implementation of NeuS-Facto. 
(slow)", "splatfacto": "Gaussian Splatting model", "depth-splatfacto": "Depth supervised Gaussian Splatting model", + "splatfacto-big": "Larger version of Splatfacto with higher quality.", } method_configs["nerfacto"] = TrainerConfig( @@ -301,8 +303,6 @@ viewer=ViewerConfig(num_rays_per_chunk=1 << 12), vis="viewer", ) -# -# method_configs["mipnerf"] = TrainerConfig( method_name="mipnerf", pipeline=VanillaPipelineConfig( diff --git a/nerfstudio/data/datamanagers/full_images_datamanager.py b/nerfstudio/data/datamanagers/full_images_datamanager.py index 0b723ea8d1..a67d492992 100644 --- a/nerfstudio/data/datamanagers/full_images_datamanager.py +++ b/nerfstudio/data/datamanagers/full_images_datamanager.py @@ -30,6 +30,7 @@ from typing import Dict, ForwardRef, Generic, List, Literal, Optional, Tuple, Type, Union, cast, get_args, get_origin import cv2 +import fpsample import numpy as np import torch from rich.progress import track @@ -62,12 +63,22 @@ class FullImageDatamanagerConfig(DataManagerConfig): new images. If -1, never pick new images.""" eval_image_indices: Optional[Tuple[int, ...]] = (0,) """Specifies the image indices to use during eval; if None, uses all.""" - cache_images: Literal["cpu", "gpu"] = "cpu" + cache_images: Literal["cpu", "gpu"] = "gpu" """Whether to cache images in memory. If "cpu", caches on cpu. If "gpu", caches on device.""" cache_images_type: Literal["uint8", "float32"] = "float32" """The image type returned from manager, caching images in uint8 saves memory""" max_thread_workers: Optional[int] = None """The maximum number of threads to use for caching images. 
If None, uses all available threads.""" + train_cameras_sampling_strategy: Literal["random", "fps"] = "random" + """Specifies which sampling strategy is used to generate train cameras. 'random' samples uniformly at + random without replacement; 'fps' uses farthest point sampling, which helps reduce the artifacts + due to oversampling subsets of cameras that are very close to each other.""" + train_cameras_sampling_seed: int = 42 + """Random seed for sampling train cameras. Fixing the seed may help reduce variance of trained models across + different runs.""" + fps_reset_every: int = 100 + """The number of iterations before the fps sampler is reset. Each round draws fps_reset_every + samples from the pool of all training cameras without replacement before a new round of sampling starts.""" class FullImageDatamanager(DataManager, Generic[TDataset]): @@ -123,12 +134,48 @@ def __init__( self.exclude_batch_keys_from_device.remove("image") # Some logic to make sure we sample every camera in equal amounts - self.train_unseen_cameras = [i for i in range(len(self.train_dataset))] + self.train_unseen_cameras = self.sample_train_cameras() self.eval_unseen_cameras = [i for i in range(len(self.eval_dataset))] assert len(self.train_unseen_cameras) > 0, "No data found in dataset" super().__init__() + def sample_train_cameras(self): + """Return a list of camera indices sampled using the strategy specified by + self.config.train_cameras_sampling_strategy""" + num_train_cameras = len(self.train_dataset) + if self.config.train_cameras_sampling_strategy == "random": + if not hasattr(self, "random_generator"): + self.random_generator = random.Random(self.config.train_cameras_sampling_seed) + indices = list(range(num_train_cameras)) + self.random_generator.shuffle(indices) + return indices + elif self.config.train_cameras_sampling_strategy == "fps": + if not hasattr(self, "train_unsampled_epoch_count"): + np.random.seed(self.config.train_cameras_sampling_seed) 
# fix random seed of fpsample + self.train_unsampled_epoch_count = np.zeros(num_train_cameras) + camera_origins = self.train_dataset.cameras.camera_to_worlds[..., 3].numpy() + # We concatenate camera origins with weighted train_unsampled_epoch_count because we want to + # increase the chance of sampling cameras that haven't been sampled in previous consecutive epochs. + # We assume the camera origins are also rescaled, so the weight 0.1 is relative to the scale of the scene + data = np.concatenate( + (camera_origins, 0.1 * np.expand_dims(self.train_unsampled_epoch_count, axis=-1)), axis=-1 + ) + n = self.config.fps_reset_every + if num_train_cameras < n: + CONSOLE.log( + f"num_train_cameras={num_train_cameras} is smaller than fps_reset_every={n}, the behavior of " + "camera sampler will be very similar to sampling random without replacement (default setting)." + ) + n = num_train_cameras + kdline_fps_samples_idx = fpsample.bucket_fps_kdline_sampling(data, n, h=3) + + self.train_unsampled_epoch_count += 1 + self.train_unsampled_epoch_count[kdline_fps_samples_idx] = 0 + return kdline_fps_samples_idx.tolist() + else: + raise ValueError(f"Unknown train camera sampling strategy: {self.config.train_cameras_sampling_strategy}") + @cached_property def cached_train(self) -> List[Dict[str, torch.Tensor]]: """Get the training images. 
Will load and undistort the images the @@ -200,11 +247,15 @@ def undistort_idx(idx: int) -> Dict[str, torch.Tensor]: cache["image"] = cache["image"].to(self.device) if "mask" in cache: cache["mask"] = cache["mask"].to(self.device) + if "depth" in cache: + cache["depth"] = cache["depth"].to(self.device) + self.train_cameras = self.train_dataset.cameras.to(self.device) elif cache_images_device == "cpu": for cache in undistorted_images: cache["image"] = cache["image"].pin_memory() if "mask" in cache: cache["mask"] = cache["mask"].pin_memory() + self.train_cameras = self.train_dataset.cameras else: assert_never(cache_images_device) @@ -288,16 +339,16 @@ def next_train(self, step: int) -> Tuple[Cameras, Dict]: """Returns the next training batch Returns a Camera instead of raybundle""" - image_idx = self.train_unseen_cameras.pop(random.randint(0, len(self.train_unseen_cameras) - 1)) + image_idx = self.train_unseen_cameras.pop(0) # Make sure to re-populate the unseen cameras list if we have exhausted it if len(self.train_unseen_cameras) == 0: - self.train_unseen_cameras = [i for i in range(len(self.train_dataset))] + self.train_unseen_cameras = self.sample_train_cameras() - data = deepcopy(self.cached_train[image_idx]) + data = self.cached_train[image_idx] data["image"] = data["image"].to(self.device) - assert len(self.train_dataset.cameras.shape) == 1, "Assumes single batch dimension" - camera = self.train_dataset.cameras[image_idx : image_idx + 1].to(self.device) + assert len(self.train_cameras.shape) == 1, "Assumes single batch dimension" + camera = self.train_cameras[image_idx : image_idx + 1].to(self.device) if camera.metadata is None: camera.metadata = {} camera.metadata["cam_idx"] = image_idx diff --git a/nerfstudio/data/datamanagers/parallel_datamanager.py b/nerfstudio/data/datamanagers/parallel_datamanager.py index b28e530f91..0b56c988c0 100644 --- a/nerfstudio/data/datamanagers/parallel_datamanager.py +++ b/nerfstudio/data/datamanagers/parallel_datamanager.py @@ 
-15,6 +15,7 @@ """ Parallel data manager that generates training data in multiple python processes. """ + from __future__ import annotations import concurrent.futures diff --git a/nerfstudio/data/dataparsers/arkitscenes_dataparser.py b/nerfstudio/data/dataparsers/arkitscenes_dataparser.py index 8dbd7d8bb9..484424f3d6 100644 --- a/nerfstudio/data/dataparsers/arkitscenes_dataparser.py +++ b/nerfstudio/data/dataparsers/arkitscenes_dataparser.py @@ -13,6 +13,7 @@ # limitations under the License. """Data parser for ARKitScenes dataset""" + import math from dataclasses import dataclass, field from pathlib import Path diff --git a/nerfstudio/data/dataparsers/blender_dataparser.py b/nerfstudio/data/dataparsers/blender_dataparser.py index 1027d7504d..fa57f4d42d 100644 --- a/nerfstudio/data/dataparsers/blender_dataparser.py +++ b/nerfstudio/data/dataparsers/blender_dataparser.py @@ -13,6 +13,7 @@ # limitations under the License. """Data parser for blender dataset""" + from __future__ import annotations from dataclasses import dataclass, field diff --git a/nerfstudio/data/dataparsers/colmap_dataparser.py b/nerfstudio/data/dataparsers/colmap_dataparser.py index b96d9b830f..a709361807 100644 --- a/nerfstudio/data/dataparsers/colmap_dataparser.py +++ b/nerfstudio/data/dataparsers/colmap_dataparser.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -""" Data parser for nerfstudio datasets. 
""" +"""Data parser for nerfstudio datasets.""" from __future__ import annotations @@ -478,12 +478,13 @@ def calculate_scaled_size(original_width, original_height, downscale_factor, mod with status(msg="[bold yellow]Downscaling images...", spinner="growVertical"): assert downscale_factor > 1 assert isinstance(downscale_factor, int) - filepath = next(iter(paths)) - img = Image.open(filepath) - w, h = img.size - w_scaled, h_scaled = calculate_scaled_size(w, h, downscale_factor, downscale_rounding_mode) # Using %05d ffmpeg commands appears to be unreliable (skips images). for path in paths: + # Compute image-wise rescaled width/height. + img = Image.open(path) + w, h = img.size + w_scaled, h_scaled = calculate_scaled_size(w, h, downscale_factor, downscale_rounding_mode) + # Downscale images using ffmpeg. nn_flag = "" if not nearest_neighbor else ":flags=neighbor" path_out = get_fname(path) path_out.parent.mkdir(parents=True, exist_ok=True) diff --git a/nerfstudio/data/dataparsers/dnerf_dataparser.py b/nerfstudio/data/dataparsers/dnerf_dataparser.py index 478b1fc38c..7e890edff8 100644 --- a/nerfstudio/data/dataparsers/dnerf_dataparser.py +++ b/nerfstudio/data/dataparsers/dnerf_dataparser.py @@ -13,6 +13,7 @@ # limitations under the License. """Data parser for blender dataset""" + from __future__ import annotations from dataclasses import dataclass, field diff --git a/nerfstudio/data/dataparsers/dycheck_dataparser.py b/nerfstudio/data/dataparsers/dycheck_dataparser.py index 90f5e9e978..c92cad339b 100644 --- a/nerfstudio/data/dataparsers/dycheck_dataparser.py +++ b/nerfstudio/data/dataparsers/dycheck_dataparser.py @@ -13,6 +13,7 @@ # limitations under the License. 
"""Data parser for DyCheck (https://arxiv.org/abs/2210.13445) dataset of `iphone` subset""" + from __future__ import annotations import math diff --git a/nerfstudio/data/dataparsers/nerfosr_dataparser.py b/nerfstudio/data/dataparsers/nerfosr_dataparser.py index db884cb736..a1ad876810 100644 --- a/nerfstudio/data/dataparsers/nerfosr_dataparser.py +++ b/nerfstudio/data/dataparsers/nerfosr_dataparser.py @@ -11,9 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -""" Data parser for NeRF-OSR datasets +"""Data parser for NeRF-OSR datasets - Presented in the paper: https://4dqv.mpi-inf.mpg.de/NeRF-OSR/ +Presented in the paper: https://4dqv.mpi-inf.mpg.de/NeRF-OSR/ """ diff --git a/nerfstudio/data/dataparsers/nerfstudio_dataparser.py b/nerfstudio/data/dataparsers/nerfstudio_dataparser.py index 910e0bc15e..e11902c094 100644 --- a/nerfstudio/data/dataparsers/nerfstudio_dataparser.py +++ b/nerfstudio/data/dataparsers/nerfstudio_dataparser.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -""" Data parser for nerfstudio datasets. """ +"""Data parser for nerfstudio datasets.""" from __future__ import annotations diff --git a/nerfstudio/data/dataparsers/nuscenes_dataparser.py b/nerfstudio/data/dataparsers/nuscenes_dataparser.py index f1fecc8889..69ec263153 100644 --- a/nerfstudio/data/dataparsers/nuscenes_dataparser.py +++ b/nerfstudio/data/dataparsers/nuscenes_dataparser.py @@ -13,6 +13,7 @@ # limitations under the License. 
"""Data parser for NuScenes dataset""" + import math import os from dataclasses import dataclass, field diff --git a/nerfstudio/data/dataparsers/phototourism_dataparser.py b/nerfstudio/data/dataparsers/phototourism_dataparser.py index 538bad2f0a..83c35bbb40 100644 --- a/nerfstudio/data/dataparsers/phototourism_dataparser.py +++ b/nerfstudio/data/dataparsers/phototourism_dataparser.py @@ -13,6 +13,7 @@ # limitations under the License. """Phototourism dataset parser. Datasets and documentation here: http://phototour.cs.washington.edu/datasets/""" + from __future__ import annotations import math diff --git a/nerfstudio/data/dataparsers/scannet_dataparser.py b/nerfstudio/data/dataparsers/scannet_dataparser.py index eb0c9ae19b..1a780e6875 100644 --- a/nerfstudio/data/dataparsers/scannet_dataparser.py +++ b/nerfstudio/data/dataparsers/scannet_dataparser.py @@ -13,6 +13,7 @@ # limitations under the License. """Data parser for ScanNet dataset""" + import math from dataclasses import dataclass, field from pathlib import Path @@ -42,11 +43,12 @@ class ScanNetDataParserConfig(DataParserConfig): ├── depth/ ├── intrinsic/ ├── pose/ + |── ply/ """ _target: Type = field(default_factory=lambda: ScanNet) """target class to instantiate""" - data: Path = Path("data/scannet/scene0423_02") + data: Path = Path("./nvsmask3d/data/scene_example") """Path to ScanNet folder with densely extracted scenes.""" scale_factor: float = 1.0 """How much to scale the camera origins by.""" @@ -60,6 +62,12 @@ class ScanNetDataParserConfig(DataParserConfig): """The fraction of images to use for training. The remaining images are for eval.""" depth_unit_scale_factor: float = 1e-3 """Scales the depth values to meters. 
Default value is 0.001 for a millimeter to meter conversion.""" + load_3D_points: bool = True + """Whether to load the 3D points from the .ply""" + point_cloud_color: bool = True + """read point cloud colors from .ply files or not """ + ply_file_path: Path = data / (data.name + ".ply") + """path to the .ply file containing the 3D points""" @dataclass @@ -158,15 +166,70 @@ def _generate_dataparser_outputs(self, split="train"): camera_type=CameraType.PERSPECTIVE, ) + metadata = { + "depth_filenames": depth_filenames if len(depth_filenames) > 0 else None, + "depth_unit_scale_factor": self.config.depth_unit_scale_factor, + } + + if self.config.load_3D_points: + point_color = self.config.point_cloud_color + ply_file_path = self.config.ply_file_path + point_cloud_data = self._load_3D_points(ply_file_path, transform_matrix, scale_factor, point_color) + if point_cloud_data is not None: + metadata.update(point_cloud_data) + dataparser_outputs = DataparserOutputs( image_filenames=image_filenames, cameras=cameras, scene_box=scene_box, dataparser_scale=scale_factor, dataparser_transform=transform_matrix, - metadata={ - "depth_filenames": depth_filenames if len(depth_filenames) > 0 else None, - "depth_unit_scale_factor": self.config.depth_unit_scale_factor, - }, + metadata=metadata, ) return dataparser_outputs + + def _load_3D_points( + self, ply_file_path: Path, transform_matrix: torch.Tensor, scale_factor: float, points_color: bool + ) -> dict: + """Loads point clouds positions and colors from .ply + + Args: + ply_file_path: Path to .ply file + transform_matrix: Matrix to transform world coordinates + scale_factor: How much to scale the camera origins by. + points_color: Whether to load the point cloud colors or not + + Returns: + A dictionary of points: points3D_xyz and colors: points3D_rgb + or + A dictionary of points: points3D_xyz if points_color is False + """ + import open3d as o3d # Importing open3d is slow, so we only do it if we need it. 
+ + pcd = o3d.io.read_point_cloud(str(ply_file_path)) + + # if no points found don't read in an initial point cloud + if len(pcd.points) == 0: + return {} + + points3D = torch.from_numpy(np.asarray(pcd.points, dtype=np.float32)) + points3D = ( + torch.cat( + ( + points3D, + torch.ones_like(points3D[..., :1]), + ), + -1, + ) + @ transform_matrix.T + ) + points3D *= scale_factor + out = { + "points3D_xyz": points3D, + } + + if points_color: + points3D_rgb = torch.from_numpy((np.asarray(pcd.colors) * 255).astype(np.uint8)) + out["points3D_rgb"] = points3D_rgb + + return out diff --git a/nerfstudio/data/dataparsers/scannetpp_dataparser.py b/nerfstudio/data/dataparsers/scannetpp_dataparser.py index 18afa4f423..3eed922e07 100644 --- a/nerfstudio/data/dataparsers/scannetpp_dataparser.py +++ b/nerfstudio/data/dataparsers/scannetpp_dataparser.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -""" Data parser for ScanNet++ datasets. """ +"""Data parser for ScanNet++ datasets.""" from __future__ import annotations diff --git a/nerfstudio/data/datasets/base_dataset.py b/nerfstudio/data/datasets/base_dataset.py index e16ea33482..449720bd36 100644 --- a/nerfstudio/data/datasets/base_dataset.py +++ b/nerfstudio/data/datasets/base_dataset.py @@ -15,6 +15,7 @@ """ Dataset. 
""" + from __future__ import annotations from copy import deepcopy diff --git a/nerfstudio/data/pixel_samplers.py b/nerfstudio/data/pixel_samplers.py index ad11ee4094..080b4016ad 100644 --- a/nerfstudio/data/pixel_samplers.py +++ b/nerfstudio/data/pixel_samplers.py @@ -337,7 +337,7 @@ def collate_image_dataset_batch_list(self, batch: Dict, num_rays_per_batch: int, collated_batch = { key: value[c, y, x] for key, value in batch.items() - if key != "image_idx" and key != "image" and key != "mask" and key != "depth_image" and value is not None + if key not in ("image_idx", "image", "mask", "depth_image") and value is not None } collated_batch["image"] = torch.cat(all_images, dim=0) diff --git a/nerfstudio/data/utils/colmap_parsing_utils.py b/nerfstudio/data/utils/colmap_parsing_utils.py index 818798ec89..0eeaf6911b 100644 --- a/nerfstudio/data/utils/colmap_parsing_utils.py +++ b/nerfstudio/data/utils/colmap_parsing_utils.py @@ -235,11 +235,12 @@ def read_images_binary(path_to_model_file): qvec = np.array(binary_image_properties[1:5]) tvec = np.array(binary_image_properties[5:8]) camera_id = binary_image_properties[8] - image_name = "" + image_name = b"" current_char = read_next_bytes(fid, 1, "c")[0] while current_char != b"\x00": # look for the ASCII 0 entry - image_name += current_char.decode("utf-8") + image_name += current_char current_char = read_next_bytes(fid, 1, "c")[0] + image_name = image_name.decode("utf-8") num_points2D = read_next_bytes(fid, num_bytes=8, format_char_sequence="Q")[0] x_y_id_s = read_next_bytes(fid, num_bytes=24 * num_points2D, format_char_sequence="ddq" * num_points2D) xys = np.column_stack([tuple(map(float, x_y_id_s[0::3])), tuple(map(float, x_y_id_s[1::3]))]) diff --git a/nerfstudio/data/utils/data_utils.py b/nerfstudio/data/utils/data_utils.py index c81101c4f3..11ce74d9da 100644 --- a/nerfstudio/data/utils/data_utils.py +++ b/nerfstudio/data/utils/data_utils.py @@ -13,6 +13,7 @@ # limitations under the License. 
"""Utility functions to allow easy re-use of common operations across dataloaders""" + from pathlib import Path from typing import List, Tuple, Union diff --git a/nerfstudio/data/utils/dataparsers_utils.py b/nerfstudio/data/utils/dataparsers_utils.py index b48323f21e..0c79cbde18 100644 --- a/nerfstudio/data/utils/dataparsers_utils.py +++ b/nerfstudio/data/utils/dataparsers_utils.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -""" Data parser utils for nerfstudio datasets. """ +"""Data parser utils for nerfstudio datasets.""" import math import os diff --git a/nerfstudio/data/utils/nerfstudio_collate.py b/nerfstudio/data/utils/nerfstudio_collate.py index 8c8a633fb8..b5b391f543 100644 --- a/nerfstudio/data/utils/nerfstudio_collate.py +++ b/nerfstudio/data/utils/nerfstudio_collate.py @@ -101,7 +101,7 @@ def nerfstudio_collate(batch: Any, extra_mappings: Union[Dict[type, Callable], N storage = elem.storage()._new_shared(numel, device=elem.device) out = elem.new(storage).resize_(len(batch), *list(elem.size())) return torch.stack(batch, 0, out=out) - elif elem_type.__module__ == "numpy" and elem_type.__name__ != "str_" and elem_type.__name__ != "string_": + elif elem_type.__module__ == "numpy" and elem_type.__name__ not in ("str_", "string_"): if elem_type.__name__ in ("ndarray", "memmap"): # array of string classes and object if np_str_obj_array_pattern.search(elem.dtype.str) is not None: diff --git a/nerfstudio/engine/callbacks.py b/nerfstudio/engine/callbacks.py index f37cc75011..776fd829b6 100644 --- a/nerfstudio/engine/callbacks.py +++ b/nerfstudio/engine/callbacks.py @@ -15,6 +15,7 @@ """ Callback code used for training iterations """ + from __future__ import annotations from dataclasses import dataclass diff --git a/nerfstudio/engine/optimizers.py b/nerfstudio/engine/optimizers.py index 0a87947863..830ac9a3ad 
100644 --- a/nerfstudio/engine/optimizers.py +++ b/nerfstudio/engine/optimizers.py @@ -15,6 +15,7 @@ """ Optimizers class. """ + from __future__ import annotations from dataclasses import dataclass diff --git a/nerfstudio/engine/trainer.py b/nerfstudio/engine/trainer.py index 0634f62e53..871e383757 100644 --- a/nerfstudio/engine/trainer.py +++ b/nerfstudio/engine/trainer.py @@ -15,6 +15,7 @@ """ Code to train model. """ + from __future__ import annotations import dataclasses diff --git a/nerfstudio/exporter/exporter_utils.py b/nerfstudio/exporter/exporter_utils.py index bad46f517c..b87078bc9c 100644 --- a/nerfstudio/exporter/exporter_utils.py +++ b/nerfstudio/exporter/exporter_utils.py @@ -16,7 +16,6 @@ Export utils such as structs, point cloud generation, and rendering code. """ - from __future__ import annotations import sys diff --git a/nerfstudio/exporter/texture_utils.py b/nerfstudio/exporter/texture_utils.py index 120696752e..5859fcb477 100644 --- a/nerfstudio/exporter/texture_utils.py +++ b/nerfstudio/exporter/texture_utils.py @@ -16,7 +16,6 @@ Texture utils. """ - from __future__ import annotations import math diff --git a/nerfstudio/exporter/tsdf_utils.py b/nerfstudio/exporter/tsdf_utils.py index 301591316d..9be7b3d9fe 100644 --- a/nerfstudio/exporter/tsdf_utils.py +++ b/nerfstudio/exporter/tsdf_utils.py @@ -16,7 +16,6 @@ TSDF utils. """ - from __future__ import annotations from dataclasses import dataclass, field diff --git a/nerfstudio/field_components/__init__.py b/nerfstudio/field_components/__init__.py index 5f99dbcc8b..365718aba8 100644 --- a/nerfstudio/field_components/__init__.py +++ b/nerfstudio/field_components/__init__.py @@ -13,6 +13,7 @@ # limitations under the License. 
"""init field modules""" + from .base_field_component import FieldComponent as FieldComponent from .encodings import Encoding as Encoding, ScalingAndOffset as ScalingAndOffset from .mlp import MLP as MLP diff --git a/nerfstudio/field_components/base_field_component.py b/nerfstudio/field_components/base_field_component.py index 8d84b782b5..c961136c8c 100644 --- a/nerfstudio/field_components/base_field_component.py +++ b/nerfstudio/field_components/base_field_component.py @@ -15,6 +15,7 @@ """ The field module baseclass. """ + from abc import abstractmethod from typing import Optional diff --git a/nerfstudio/field_components/embedding.py b/nerfstudio/field_components/embedding.py index cc934b0994..4b568bbcb8 100644 --- a/nerfstudio/field_components/embedding.py +++ b/nerfstudio/field_components/embedding.py @@ -16,7 +16,6 @@ Code for embeddings. """ - import torch from jaxtyping import Shaped from torch import Tensor diff --git a/nerfstudio/field_components/field_heads.py b/nerfstudio/field_components/field_heads.py index 44df9f6021..7c86d7c5fa 100644 --- a/nerfstudio/field_components/field_heads.py +++ b/nerfstudio/field_components/field_heads.py @@ -15,6 +15,7 @@ """ Collection of render heads """ + from enum import Enum from typing import Callable, Optional, Union diff --git a/nerfstudio/field_components/mlp.py b/nerfstudio/field_components/mlp.py index f1df2baa33..c27c7c996d 100644 --- a/nerfstudio/field_components/mlp.py +++ b/nerfstudio/field_components/mlp.py @@ -15,6 +15,7 @@ """ Multi Layer Perceptron """ + from typing import Literal, Optional, Set, Tuple, Union import numpy as np diff --git a/nerfstudio/fields/density_fields.py b/nerfstudio/fields/density_fields.py index d4f6f28f6d..e8eba7b9d2 100644 --- a/nerfstudio/fields/density_fields.py +++ b/nerfstudio/fields/density_fields.py @@ -16,7 +16,6 @@ Proposal network field. 
""" - from typing import Literal, Optional, Tuple import torch diff --git a/nerfstudio/fields/generfacto_field.py b/nerfstudio/fields/generfacto_field.py index ee31bcb0b6..d2597d9a51 100644 --- a/nerfstudio/fields/generfacto_field.py +++ b/nerfstudio/fields/generfacto_field.py @@ -16,7 +16,6 @@ Field for Generfacto model """ - from typing import Dict, Literal, Optional, Tuple import numpy as np diff --git a/nerfstudio/fields/nerfacto_field.py b/nerfstudio/fields/nerfacto_field.py index 52a4de4b34..5b3af8b55f 100644 --- a/nerfstudio/fields/nerfacto_field.py +++ b/nerfstudio/fields/nerfacto_field.py @@ -16,7 +16,6 @@ Field for compound nerf model, adds scene contraction and image embeddings to instant ngp """ - from typing import Dict, Literal, Optional, Tuple import torch diff --git a/nerfstudio/fields/semantic_nerf_field.py b/nerfstudio/fields/semantic_nerf_field.py index 62b75094f4..d36174472e 100644 --- a/nerfstudio/fields/semantic_nerf_field.py +++ b/nerfstudio/fields/semantic_nerf_field.py @@ -15,6 +15,7 @@ """ Semantic NeRF field implementation. 
""" + from typing import Dict, Optional, Tuple import torch diff --git a/nerfstudio/fields/tensorf_field.py b/nerfstudio/fields/tensorf_field.py index b3aba4e542..a6c5b6c9e2 100644 --- a/nerfstudio/fields/tensorf_field.py +++ b/nerfstudio/fields/tensorf_field.py @@ -14,7 +14,6 @@ """TensoRF Field""" - from typing import Dict, Optional import torch diff --git a/nerfstudio/fields/vanilla_nerf_field.py b/nerfstudio/fields/vanilla_nerf_field.py index efb3589ebe..02565fbd94 100644 --- a/nerfstudio/fields/vanilla_nerf_field.py +++ b/nerfstudio/fields/vanilla_nerf_field.py @@ -14,7 +14,6 @@ """Classic NeRF field""" - from typing import Dict, Optional, Tuple, Type import torch diff --git a/nerfstudio/model_components/losses.py b/nerfstudio/model_components/losses.py index 8f3ab17dbb..5c66a39e1c 100644 --- a/nerfstudio/model_components/losses.py +++ b/nerfstudio/model_components/losses.py @@ -15,6 +15,7 @@ """ Collection of Losses. """ + from enum import Enum from typing import Dict, Literal, Optional, Tuple, cast diff --git a/nerfstudio/model_components/ray_generators.py b/nerfstudio/model_components/ray_generators.py index fab9e39bba..33ab2c8d1c 100644 --- a/nerfstudio/model_components/ray_generators.py +++ b/nerfstudio/model_components/ray_generators.py @@ -15,6 +15,7 @@ """ Ray generator. """ + from jaxtyping import Int from torch import Tensor, nn diff --git a/nerfstudio/model_components/ray_samplers.py b/nerfstudio/model_components/ray_samplers.py index 7a2052b639..e0d15db924 100644 --- a/nerfstudio/model_components/ray_samplers.py +++ b/nerfstudio/model_components/ray_samplers.py @@ -379,8 +379,7 @@ class DensityFn(Protocol): def __call__( self, positions: Float[Tensor, "*batch 3"], times: Optional[Float[Tensor, "*batch 1"]] = None - ) -> Float[Tensor, "*batch 1"]: - ... + ) -> Float[Tensor, "*batch 1"]: ... 
class VolumetricSampler(Sampler): diff --git a/nerfstudio/model_components/renderers.py b/nerfstudio/model_components/renderers.py index 1fde0d693c..1f3ee17499 100644 --- a/nerfstudio/model_components/renderers.py +++ b/nerfstudio/model_components/renderers.py @@ -26,6 +26,7 @@ rgb = rgb_renderer(rgb=field_outputs[FieldHeadNames.RGB], weights=weights) """ + import contextlib import math from typing import Generator, Literal, Optional, Tuple, Union diff --git a/nerfstudio/model_components/shaders.py b/nerfstudio/model_components/shaders.py index f6cc227219..015dbfd41e 100644 --- a/nerfstudio/model_components/shaders.py +++ b/nerfstudio/model_components/shaders.py @@ -13,6 +13,7 @@ # limitations under the License. """Shaders for rendering.""" + from typing import Optional from jaxtyping import Float diff --git a/nerfstudio/models/mipnerf.py b/nerfstudio/models/mipnerf.py index e033681f73..503fb4b031 100644 --- a/nerfstudio/models/mipnerf.py +++ b/nerfstudio/models/mipnerf.py @@ -15,6 +15,7 @@ """ Implementation of mip-NeRF. """ + from __future__ import annotations from typing import Dict, List, Tuple diff --git a/nerfstudio/models/splatfacto.py b/nerfstudio/models/splatfacto.py index 9b29eca629..37856e3891 100644 --- a/nerfstudio/models/splatfacto.py +++ b/nerfstudio/models/splatfacto.py @@ -14,35 +14,35 @@ # limitations under the License. """ -NeRF implementation that combines many recent advancements. +Gaussian Splatting implementation that combines many recent advancements. 
""" from __future__ import annotations import math from dataclasses import dataclass, field -from typing import Dict, List, Optional, Tuple, Type, Union +from typing import Dict, List, Literal, Optional, Tuple, Type, Union import numpy as np import torch -from gsplat._torch_impl import quat_to_rotmat -from gsplat.project_gaussians import project_gaussians -from gsplat.rasterize import rasterize_gaussians -from gsplat.sh import num_sh_bases, spherical_harmonics +from gsplat.cuda_legacy._torch_impl import quat_to_rotmat + +try: + from gsplat.rendering import rasterization +except ImportError: + print("Please install gsplat>=1.0.0") +from gsplat.cuda_legacy._wrapper import num_sh_bases from pytorch_msssim import SSIM from torch.nn import Parameter -from typing_extensions import Literal from nerfstudio.cameras.camera_optimizers import CameraOptimizer, CameraOptimizerConfig from nerfstudio.cameras.cameras import Cameras from nerfstudio.data.scene_box import OrientedBox from nerfstudio.engine.callbacks import TrainingCallback, TrainingCallbackAttributes, TrainingCallbackLocation from nerfstudio.engine.optimizers import Optimizers - -# need following import for background color override -from nerfstudio.model_components import renderers from nerfstudio.models.base_model import Model, ModelConfig from nerfstudio.utils.colors import get_color +from nerfstudio.utils.misc import torch_compile from nerfstudio.utils.rich_utils import CONSOLE @@ -96,6 +96,25 @@ def resize_image(image: torch.Tensor, d: int): return tf.conv2d(image.permute(2, 0, 1)[:, None, ...], weight, stride=d).squeeze(1).permute(1, 2, 0) +@torch_compile() +def get_viewmat(optimized_camera_to_world): + """ + function that converts c2w to gsplat world2camera matrix, using compile for some speed + """ + R = optimized_camera_to_world[:, :3, :3] # 3 x 3 + T = optimized_camera_to_world[:, :3, 3:4] # 3 x 1 + # flip the z and y axes to align with gsplat conventions + R = R * torch.tensor([[[1, -1, -1]]], 
device=R.device, dtype=R.dtype) + # analytic matrix inverse to get world2camera matrix + R_inv = R.transpose(1, 2) + T_inv = -torch.bmm(R_inv, T) + viewmat = torch.zeros(R.shape[0], 4, 4, device=R.device, dtype=R.dtype) + viewmat[:, 3, 3] = 1.0 # homogeneous + viewmat[:, :3, :3] = R_inv + viewmat[:, :3, 3:4] = T_inv + return viewmat + + @dataclass class SplatfactoModelConfig(ModelConfig): """Splatfacto Model Config, nerfstudio's implementation of Gaussian Splatting""" @@ -395,17 +414,14 @@ def after_train(self, step: int): with torch.no_grad(): # keep track of a moving average of grad norms visible_mask = (self.radii > 0).flatten() - assert self.xys.absgrad is not None # type: ignore - grads = self.xys.absgrad.detach().norm(dim=-1) # type: ignore + grads = self.xys.absgrad[0][visible_mask].norm(dim=-1) # type: ignore # print(f"grad norm min {grads.min().item()} max {grads.max().item()} mean {grads.mean().item()} size {grads.shape}") if self.xys_grad_norm is None: - self.xys_grad_norm = grads - self.vis_counts = torch.ones_like(self.xys_grad_norm) - else: - assert self.vis_counts is not None - self.vis_counts[visible_mask] = self.vis_counts[visible_mask] + 1 - self.xys_grad_norm[visible_mask] = grads[visible_mask] + self.xys_grad_norm[visible_mask] - + self.xys_grad_norm = torch.zeros(self.num_points, device=self.device, dtype=torch.float32) + self.vis_counts = torch.ones(self.num_points, device=self.device, dtype=torch.float32) + assert self.vis_counts is not None + self.vis_counts[visible_mask] += 1 + self.xys_grad_norm[visible_mask] += grads # update the max screen size, as a ratio of number of pixels if self.max_2Dsize is None: self.max_2Dsize = torch.zeros_like(self.radii, dtype=torch.float32) @@ -455,7 +471,6 @@ def refinement_after(self, optimizers: Optimizers, step): self.gauss_params[name] = torch.nn.Parameter( torch.cat([param.detach(), split_params[name], dup_params[name]], dim=0) ) - # append zeros to the max_2Dsize tensor self.max_2Dsize = torch.cat( [ 
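Note on the `get_viewmat` helper added in this file: it avoids a general matrix inverse by using the fact that a rigid transform [R | t] has the inverse [R^T | -R^T t]. Below is a minimal pure-Python sketch of that step for a single camera. It is an illustration under stated assumptions, not the PR's code: the names `matmul` and `viewmat_from_c2w` are hypothetical, nested lists stand in for tensors, and the real helper is batched and runs through `torch_compile`.

```python
from typing import List

Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    """Plain nested-list matrix product (hypothetical helper for this sketch)."""
    return [
        [sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
        for i in range(len(a))
    ]


def viewmat_from_c2w(c2w: Matrix) -> Matrix:
    """Single-camera analogue of get_viewmat; accepts a 3x4 or 4x4 camera-to-world matrix."""
    # Flip the y and z camera axes: scale columns 1 and 2 of R by -1.
    flip = [1.0, -1.0, -1.0]
    R = [[c2w[i][j] * flip[j] for j in range(3)] for i in range(3)]
    t = [c2w[i][3] for i in range(3)]
    # The inverse of a rotation matrix is its transpose.
    R_inv = [[R[j][i] for j in range(3)] for i in range(3)]
    # The inverse translation is -R^T t.
    t_inv = [-sum(R_inv[i][k] * t[k] for k in range(3)) for i in range(3)]
    viewmat = [R_inv[i] + [t_inv[i]] for i in range(3)]
    viewmat.append([0.0, 0.0, 0.0, 1.0])  # homogeneous row
    return viewmat


# Example camera: rotated 90 degrees about z, located at (1, 2, 3).
example_c2w = [
    [0.0, -1.0, 0.0, 1.0],
    [1.0, 0.0, 0.0, 2.0],
    [0.0, 0.0, 1.0, 3.0],
]
example_viewmat = viewmat_from_c2w(example_c2w)
# The camera centre in world coordinates should land on the camera-space origin.
centre_in_camera = matmul(example_viewmat, [[1.0], [2.0], [3.0], [1.0]])
```

A quick sanity check for any such helper is that the camera position maps to the origin of camera space, which is what `centre_in_camera` computes here. The transpose shortcut is only valid when the 3x3 block is a pure rotation, so it would not hold for a camera-to-world matrix that also carries scale or shear.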
@@ -529,8 +544,8 @@ def cull_gaussians(self, extra_cull_mask: Optional[torch.Tensor] = None): toobigs = (torch.exp(self.scales).max(dim=-1).values > self.config.cull_scale_thresh).squeeze() if self.step < self.config.stop_screen_size_at: # cull big screen space - assert self.max_2Dsize is not None - toobigs = toobigs | (self.max_2Dsize > self.config.cull_screen_size).squeeze() + if self.max_2Dsize is not None: + toobigs = toobigs | (self.max_2Dsize > self.config.cull_screen_size).squeeze() culls = culls | toobigs toobigs_count = torch.sum(toobigs).item() for name, param in self.gauss_params.items(): @@ -657,12 +672,26 @@ def get_empty_outputs(width: int, height: int, background: torch.Tensor) -> Dict accumulation = background.new_zeros(*rgb.shape[:2], 1) return {"rgb": rgb, "depth": depth, "accumulation": accumulation, "background": background} + def _get_background_color(self): + if self.config.background_color == "random": + if self.training: + background = torch.rand(3, device=self.device) + else: + background = self.background_color.to(self.device) + elif self.config.background_color == "white": + background = torch.ones(3, device=self.device) + elif self.config.background_color == "black": + background = torch.zeros(3, device=self.device) + else: + raise ValueError(f"Unknown background color {self.config.background_color}") + return background + def get_outputs(self, camera: Cameras) -> Dict[str, Union[torch.Tensor, List]]: - """Takes in a Ray Bundle and returns a dictionary of outputs. + """Takes in a camera and returns a dictionary of outputs. Args: - ray_bundle: Input bundle of rays. This raybundle should have all the - needed information to compute the outputs. + camera: The camera(s) for which output images are rendered. It should have + all the needed information to compute the outputs. Returns: Outputs of model. (ie. 
rendered colors) @@ -670,52 +699,22 @@ def get_outputs(self, camera: Cameras) -> Dict[str, Union[torch.Tensor, List]]: if not isinstance(camera, Cameras): print("Called get_outputs with not a camera") return {} - assert camera.shape[0] == 1, "Only one camera at a time" - - optimized_camera_to_world = self.camera_optimizer.apply_to_camera(camera)[0, ...] - # get the background color if self.training: - if self.config.background_color == "random": - background = torch.rand(3, device=self.device) - elif self.config.background_color == "white": - background = torch.ones(3, device=self.device) - elif self.config.background_color == "black": - background = torch.zeros(3, device=self.device) - else: - background = self.background_color.to(self.device) + assert camera.shape[0] == 1, "Only one camera at a time" + optimized_camera_to_world = self.camera_optimizer.apply_to_camera(camera) else: - if renderers.BACKGROUND_COLOR_OVERRIDE is not None: - background = renderers.BACKGROUND_COLOR_OVERRIDE.to(self.device) - else: - background = self.background_color.to(self.device) + optimized_camera_to_world = camera.camera_to_worlds + # cropping if self.crop_box is not None and not self.training: crop_ids = self.crop_box.within(self.means).squeeze() if crop_ids.sum() == 0: - return self.get_empty_outputs(int(camera.width.item()), int(camera.height.item()), background) + return self.get_empty_outputs( + int(camera.width.item()), int(camera.height.item()), self.background_color + ) else: crop_ids = None - camera_downscale = self._get_downscale_factor() - camera.rescale_output_resolution(1 / camera_downscale) - # shift the camera to center of scene looking at center - R = optimized_camera_to_world[:3, :3] # 3 x 3 - T = optimized_camera_to_world[:3, 3:4] # 3 x 1 - - # flip the z and y axes to align with gsplat conventions - R_edit = torch.diag(torch.tensor([1, -1, -1], device=self.device, dtype=R.dtype)) - R = R @ R_edit - # analytic matrix inverse to get world2camera matrix - R_inv = 
R.T - T_inv = -R_inv @ T - viewmat = torch.eye(4, device=R.device, dtype=R.dtype) - viewmat[:3, :3] = R_inv - viewmat[:3, 3:4] = T_inv - # calculate the FOV of the camera given fx and fy, width and height - cx = camera.cx.item() - cy = camera.cy.item() - W, H = int(camera.width.item()), int(camera.height.item()) - self.last_size = (H, W) if crop_ids is not None: opacities_crop = self.opacities[crop_ids] @@ -733,80 +732,78 @@ def get_outputs(self, camera: Cameras) -> Dict[str, Union[torch.Tensor, List]]: quats_crop = self.quats colors_crop = torch.cat((features_dc_crop[:, None, :], features_rest_crop), dim=1) - BLOCK_WIDTH = 16 # this controls the tile size of rasterization, 16 is a good default - self.xys, depths, self.radii, conics, comp, num_tiles_hit, cov3d = project_gaussians( # type: ignore - means_crop, - torch.exp(scales_crop), - 1, - quats_crop / quats_crop.norm(dim=-1, keepdim=True), - viewmat.squeeze()[:3, :], - camera.fx.item(), - camera.fy.item(), - cx, - cy, - H, - W, - BLOCK_WIDTH, - ) # type: ignore - - # rescale the camera back to original dimensions before returning - camera.rescale_output_resolution(camera_downscale) - - if (self.radii).sum() == 0: - return self.get_empty_outputs(W, H, background) - - if self.config.sh_degree > 0: - viewdirs = means_crop.detach() - optimized_camera_to_world.detach()[:3, 3] # (N, 3) - n = min(self.step // self.config.sh_degree_interval, self.config.sh_degree) - rgbs = spherical_harmonics(n, viewdirs, colors_crop) # input unnormalized viewdirs - rgbs = torch.clamp(rgbs + 0.5, min=0.0) # type: ignore - else: - rgbs = torch.sigmoid(colors_crop[:, 0, :]) - assert (num_tiles_hit > 0).any() # type: ignore + BLOCK_WIDTH = 16 # this controls the tile size of rasterization, 16 is a good default + camera_scale_fac = self._get_downscale_factor() + camera.rescale_output_resolution(1 / camera_scale_fac) + viewmat = get_viewmat(optimized_camera_to_world) + K = camera.get_intrinsics_matrices().cuda() + W, H = 
int(camera.width.item()), int(camera.height.item()) + self.last_size = (H, W) + camera.rescale_output_resolution(camera_scale_fac) # type: ignore # apply the compensation of screen space blurring to gaussians - if self.config.rasterize_mode == "antialiased": - opacities = torch.sigmoid(opacities_crop) * comp[:, None] - elif self.config.rasterize_mode == "classic": - opacities = torch.sigmoid(opacities_crop) - else: + if self.config.rasterize_mode not in ["antialiased", "classic"]: raise ValueError("Unknown rasterize_mode: %s", self.config.rasterize_mode) - rgb, alpha = rasterize_gaussians( # type: ignore - self.xys, - depths, - self.radii, - conics, - num_tiles_hit, # type: ignore - rgbs, - opacities, - H, - W, - BLOCK_WIDTH, - background=background, - return_alpha=True, - ) # type: ignore - alpha = alpha[..., None] - rgb = torch.clamp(rgb, max=1.0) # type: ignore - depth_im = None if self.config.output_depth_during_training or not self.training: - depth_im = rasterize_gaussians( # type: ignore - self.xys, - depths, - self.radii, - conics, - num_tiles_hit, # type: ignore - depths[:, None].repeat(1, 3), - opacities, - H, - W, - BLOCK_WIDTH, - background=torch.zeros(3, device=self.device), - )[..., 0:1] # type: ignore - depth_im = torch.where(alpha > 0, depth_im / alpha, depth_im.detach().max()) - - return {"rgb": rgb, "depth": depth_im, "accumulation": alpha, "background": background} # type: ignore + render_mode = "RGB+ED" + else: + render_mode = "RGB" + + if self.config.sh_degree > 0: + sh_degree_to_use = min(self.step // self.config.sh_degree_interval, self.config.sh_degree) + else: + colors_crop = torch.sigmoid(colors_crop).squeeze(1) # [N, 1, 3] -> [N, 3] + sh_degree_to_use = None + + render, alpha, info = rasterization( + means=means_crop, + quats=quats_crop / quats_crop.norm(dim=-1, keepdim=True), + scales=torch.exp(scales_crop), + opacities=torch.sigmoid(opacities_crop).squeeze(-1), + colors=colors_crop, + viewmats=viewmat, # [1, 4, 4] + Ks=K, # [1, 3, 3] + 
width=W, + height=H, + tile_size=BLOCK_WIDTH, + packed=False, + near_plane=0.01, + far_plane=1e10, + render_mode=render_mode, + sh_degree=sh_degree_to_use, + sparse_grad=False, + absgrad=True, + rasterize_mode=self.config.rasterize_mode, + # set some threshold to disregard small gaussians for faster rendering. + # radius_clip=3.0, + ) + if self.training and info["means2d"].requires_grad: + info["means2d"].retain_grad() + self.xys = info["means2d"] # [1, N, 2] + self.radii = info["radii"][0] # [N] + alpha = alpha[:, ...] + + background = self._get_background_color() + rgb = render[:, ..., :3] + (1 - alpha) * background + rgb = torch.clamp(rgb, 0.0, 1.0) + + if render_mode == "RGB+ED": + depth_im = render[:, ..., 3:4] + depth_im = torch.where(alpha > 0, depth_im, depth_im.detach().max()).squeeze(0) + else: + depth_im = None + + if background.shape[0] == 3 and not self.training: + background = background.expand(H, W, 3) + + return { + "rgb": rgb.squeeze(0), # type: ignore + "depth": depth_im, # type: ignore + "accumulation": alpha.squeeze(0), # type: ignore + "background": background, # type: ignore + } # type: ignore def get_gt_img(self, image: torch.Tensor): """Compute groundtruth image with iteration dependent downscale factor for evaluation purpose diff --git a/nerfstudio/pipelines/base_pipeline.py b/nerfstudio/pipelines/base_pipeline.py index 651a2c3008..20392dcb62 100644 --- a/nerfstudio/pipelines/base_pipeline.py +++ b/nerfstudio/pipelines/base_pipeline.py @@ -15,6 +15,7 @@ """ Abstracts for the Pipeline class. 
""" + from __future__ import annotations import typing @@ -34,9 +35,7 @@ from torch.nn.parallel import DistributedDataParallel as DDP from nerfstudio.configs.base_config import InstantiateConfig -from nerfstudio.data.datamanagers.base_datamanager import DataManager, DataManagerConfig, VanillaDataManager -from nerfstudio.data.datamanagers.full_images_datamanager import FullImageDatamanager -from nerfstudio.data.datamanagers.parallel_datamanager import ParallelDataManager +from nerfstudio.data.datamanagers.base_datamanager import DataManager, DataManagerConfig from nerfstudio.engine.callbacks import TrainingCallback, TrainingCallbackAttributes from nerfstudio.models.base_model import Model, ModelConfig from nerfstudio.utils import profiler @@ -345,12 +344,19 @@ def get_eval_image_metrics_and_images(self, step: int): return metrics_dict, images_dict @profiler.time_function - def get_average_eval_image_metrics( - self, step: Optional[int] = None, output_path: Optional[Path] = None, get_std: bool = False + def get_average_image_metrics( + self, + data_loader, + image_prefix: str, + step: Optional[int] = None, + output_path: Optional[Path] = None, + get_std: bool = False, ): - """Iterate over all the images in the eval dataset and get the average. + """Iterate over all the images in the dataset and get the average. Args: + data_loader: the data loader to iterate over + image_prefix: prefix to use for the saved image filenames step: current training step output_path: optional path to save rendered images to get_std: Set True if you want to return std with the mean metric. 
@@ -360,8 +366,7 @@ def get_average_eval_image_metrics( """ self.eval() metrics_dict_list = [] - assert isinstance(self.datamanager, (VanillaDataManager, ParallelDataManager, FullImageDatamanager)) - num_images = len(self.datamanager.fixed_indices_eval_dataloader) + num_images = len(data_loader) if output_path is not None: output_path.mkdir(exist_ok=True, parents=True) with Progress( @@ -371,9 +376,9 @@ def get_average_eval_image_metrics( MofNCompleteColumn(), transient=True, ) as progress: - task = progress.add_task("[green]Evaluating all eval images...", total=num_images) + task = progress.add_task("[green]Evaluating all images...", total=num_images) idx = 0 - for camera, batch in self.datamanager.fixed_indices_eval_dataloader: + for camera, batch in data_loader: # time this the following line inner_start = time() outputs = self.model.get_outputs_for_camera(camera=camera) @@ -383,7 +388,9 @@ def get_average_eval_image_metrics( if output_path is not None: for key in image_dict.keys(): image = image_dict[key] # [H, W, C] order - vutils.save_image(image.permute(2, 0, 1).cpu(), output_path / f"eval_{key}_{idx:04d}.png") + vutils.save_image( + image.permute(2, 0, 1).cpu(), output_path / f"{image_prefix}_{key}_{idx:04d}.png" + ) assert "num_rays_per_sec" not in metrics_dict metrics_dict["num_rays_per_sec"] = (num_rays / (time() - inner_start)).item() @@ -393,7 +400,7 @@ def get_average_eval_image_metrics( metrics_dict_list.append(metrics_dict) progress.advance(task) idx = idx + 1 - # average the metrics list + metrics_dict = {} for key in metrics_dict_list[0].keys(): if get_std: @@ -406,9 +413,23 @@ def get_average_eval_image_metrics( metrics_dict[key] = float( torch.mean(torch.tensor([metrics_dict[key] for metrics_dict in metrics_dict_list])) ) + self.train() return metrics_dict + @profiler.time_function + def get_average_eval_image_metrics( + self, step: Optional[int] = None, output_path: Optional[Path] = None, get_std: bool = False + ): + """Get the average metrics 
for evaluation images.""" + assert hasattr( + self.datamanager, "fixed_indices_eval_dataloader" + ), "datamanager must have 'fixed_indices_eval_dataloader' attribute" + image_prefix = "eval" + return self.get_average_image_metrics( + self.datamanager.fixed_indices_eval_dataloader, image_prefix, step, output_path, get_std + ) + def load_pipeline(self, loaded_state: Dict[str, Any], step: int) -> None: """Load the checkpoint from the given path diff --git a/nerfstudio/plugins/types.py b/nerfstudio/plugins/types.py index cfdb54b55c..7ad878f18d 100644 --- a/nerfstudio/plugins/types.py +++ b/nerfstudio/plugins/types.py @@ -15,6 +15,7 @@ """ This package contains specifications used to register plugins. """ + from dataclasses import dataclass from nerfstudio.engine.trainer import TrainerConfig diff --git a/nerfstudio/process_data/colmap_converter_to_nerfstudio_dataset.py b/nerfstudio/process_data/colmap_converter_to_nerfstudio_dataset.py index 4429730368..ec82f0ad7f 100644 --- a/nerfstudio/process_data/colmap_converter_to_nerfstudio_dataset.py +++ b/nerfstudio/process_data/colmap_converter_to_nerfstudio_dataset.py @@ -246,7 +246,7 @@ def _run_colmap(self, mask_path: Optional[Path] = None): def __post_init__(self) -> None: super().__post_init__() install_checks.check_ffmpeg_installed() - install_checks.check_colmap_installed() + install_checks.check_colmap_installed(self.colmap_cmd) if self.crop_bottom < 0.0 or self.crop_bottom > 1: raise RuntimeError("crop_bottom must be set between 0 and 1.") diff --git a/nerfstudio/scripts/docs/build_docs.py b/nerfstudio/scripts/docs/build_docs.py index 3d060670b6..ff7c576775 100644 --- a/nerfstudio/scripts/docs/build_docs.py +++ b/nerfstudio/scripts/docs/build_docs.py @@ -14,6 +14,7 @@ #!/usr/bin/env python """Simple yaml debugger""" + import subprocess import sys diff --git a/nerfstudio/scripts/downloads/download_data.py b/nerfstudio/scripts/downloads/download_data.py index 4d89c1b484..5c976ba167 100644 --- 
a/nerfstudio/scripts/downloads/download_data.py +++ b/nerfstudio/scripts/downloads/download_data.py @@ -13,6 +13,7 @@ # limitations under the License. """Download datasets and specific captures from the datasets.""" + from __future__ import annotations import json @@ -31,7 +32,6 @@ from typing_extensions import Annotated from nerfstudio.process_data import process_data_utils -from nerfstudio.scripts.downloads.eyeful_tower import EyefulTowerDownload from nerfstudio.scripts.downloads.utils import DatasetDownload from nerfstudio.utils import install_checks from nerfstudio.utils.scripts import run_command @@ -550,10 +550,40 @@ def download(self, save_dir: Path) -> None: Annotated[SDFstudioDemoDownload, tyro.conf.subcommand(name="sdfstudio")], Annotated[NeRFOSRDownload, tyro.conf.subcommand(name="nerfosr")], Annotated[Mill19Download, tyro.conf.subcommand(name="mill19")], - Annotated[EyefulTowerDownload, tyro.conf.subcommand(name="eyefultower")], ] +@dataclass +class NotInstalled(DatasetDownload): + def main(self) -> None: ... + + +# Add eyefultower subcommand if awscli is installed. 
+try: + import awscli +except ImportError: + awscli = None + +if awscli is not None: + from nerfstudio.scripts.downloads.eyeful_tower import EyefulTowerDownload + + Commands = Union[ + Commands, + Annotated[EyefulTowerDownload, tyro.conf.subcommand(name="eyefultower")], + ] +else: + Commands = Union[ + Commands, + Annotated[ + NotInstalled, + tyro.conf.subcommand( + name="eyefultower", + description="**Not installed.** Downloading EyefulTower data requires `pip install awscli`.", + ), + ], + ] + + def main( dataset: DatasetDownload, ): diff --git a/nerfstudio/scripts/downloads/eyeful_tower.py b/nerfstudio/scripts/downloads/eyeful_tower.py index 3672e2c027..d4b1d7628a 100644 --- a/nerfstudio/scripts/downloads/eyeful_tower.py +++ b/nerfstudio/scripts/downloads/eyeful_tower.py @@ -16,15 +16,21 @@ import collections import copy import json +import sys import xml.etree.ElementTree as ET from dataclasses import dataclass from pathlib import Path from typing import TYPE_CHECKING, Tuple -import awscli.clidriver import numpy as np import tyro +try: + import awscli.clidriver +except ImportError: + print("awscli is required for EyefulTower download. 
Please install it with `pip install awscli`.") + sys.exit(1) + from nerfstudio.scripts.downloads.utils import DatasetDownload from nerfstudio.utils.rich_utils import CONSOLE @@ -41,6 +47,8 @@ "seating_area", "table", "workshop", + "raf_emptyroom", + "raf_furnishedroom", ] # Crop radii empirically chosen to try to avoid hitting the rig base or go out of bounds diff --git a/nerfstudio/scripts/eval.py b/nerfstudio/scripts/eval.py index 103467bbd1..64b4403503 100644 --- a/nerfstudio/scripts/eval.py +++ b/nerfstudio/scripts/eval.py @@ -16,6 +16,7 @@ """ eval.py """ + from __future__ import annotations import json diff --git a/nerfstudio/scripts/exporter.py b/nerfstudio/scripts/exporter.py index 7d90e0708d..5ae6037009 100644 --- a/nerfstudio/scripts/exporter.py +++ b/nerfstudio/scripts/exporter.py @@ -16,7 +16,6 @@ Script for exporting NeRF into other formats. """ - from __future__ import annotations import json diff --git a/nerfstudio/scripts/github/run_actions.py b/nerfstudio/scripts/github/run_actions.py index 2fa3aa54c6..289876768f 100644 --- a/nerfstudio/scripts/github/run_actions.py +++ b/nerfstudio/scripts/github/run_actions.py @@ -14,6 +14,7 @@ #!/usr/bin/env python """Simple yaml debugger""" + import subprocess import sys diff --git a/nerfstudio/scripts/process_data.py b/nerfstudio/scripts/process_data.py index b2c2fa13fc..5e1d869f32 100644 --- a/nerfstudio/scripts/process_data.py +++ b/nerfstudio/scripts/process_data.py @@ -152,6 +152,9 @@ def main(self) -> None: zip_ref.extractall(self.output_dir) extracted_folder = zip_ref.namelist()[0].split("/")[0] self.data = self.output_dir / extracted_folder + if not (self.data / "keyframes").exists(): + # new versions of polycam data have a different structure, strip the last dir off + self.data = self.output_dir if (self.data / "keyframes" / "corrected_images").exists() and not self.use_uncorrected_images: polycam_image_dir = self.data / "keyframes" / "corrected_images" @@ -503,8 +506,7 @@ def main(self) -> None: 
@dataclass class NotInstalled: - def main(self) -> None: - ... + def main(self) -> None: ... Commands = Union[ diff --git a/nerfstudio/scripts/render.py b/nerfstudio/scripts/render.py index c2d6d83ce6..60bcaa9eb8 100644 --- a/nerfstudio/scripts/render.py +++ b/nerfstudio/scripts/render.py @@ -16,6 +16,7 @@ """ render.py """ + from __future__ import annotations import gzip diff --git a/nerfstudio/scripts/viewer/sync_viser_message_defs.py b/nerfstudio/scripts/viewer/sync_viser_message_defs.py index a178aae4cb..3b3e2ea5b6 100644 --- a/nerfstudio/scripts/viewer/sync_viser_message_defs.py +++ b/nerfstudio/scripts/viewer/sync_viser_message_defs.py @@ -13,6 +13,7 @@ # limitations under the License. """Generate viser message definitions for TypeScript, by parsing Python dataclasses.""" + import json import pathlib from datetime import datetime diff --git a/nerfstudio/utils/colormaps.py b/nerfstudio/utils/colormaps.py index 4d5284e77e..0a790e1237 100644 --- a/nerfstudio/utils/colormaps.py +++ b/nerfstudio/utils/colormaps.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -""" Helper functions for visualizing outputs """ +"""Helper functions for visualizing outputs""" from dataclasses import dataclass from typing import Literal, Optional diff --git a/nerfstudio/utils/colors.py b/nerfstudio/utils/colors.py index 66ac8d2435..1208659bc0 100644 --- a/nerfstudio/utils/colors.py +++ b/nerfstudio/utils/colors.py @@ -13,6 +13,7 @@ # limitations under the License. """Common Colors""" + from typing import Union import torch diff --git a/nerfstudio/utils/comms.py b/nerfstudio/utils/comms.py index 03aa445816..0a3880cff3 100644 --- a/nerfstudio/utils/comms.py +++ b/nerfstudio/utils/comms.py @@ -13,6 +13,7 @@ # limitations under the License. 
"""functionality to handle multiprocessing syncing and communicating""" + import torch.distributed as dist LOCAL_PROCESS_GROUP = None diff --git a/nerfstudio/utils/decorators.py b/nerfstudio/utils/decorators.py index cc78be8a50..9e439e7039 100644 --- a/nerfstudio/utils/decorators.py +++ b/nerfstudio/utils/decorators.py @@ -15,6 +15,7 @@ """ Decorator definitions """ + from typing import Callable, List from nerfstudio.utils import comms diff --git a/nerfstudio/utils/eval_utils.py b/nerfstudio/utils/eval_utils.py index 2c4bfcd154..11a8b23416 100644 --- a/nerfstudio/utils/eval_utils.py +++ b/nerfstudio/utils/eval_utils.py @@ -15,6 +15,7 @@ """ Evaluation utils """ + from __future__ import annotations import os diff --git a/nerfstudio/utils/install_checks.py b/nerfstudio/utils/install_checks.py index e9298cbb90..f9768ce24d 100644 --- a/nerfstudio/utils/install_checks.py +++ b/nerfstudio/utils/install_checks.py @@ -15,6 +15,7 @@ """Helpers for checking if programs are installed""" import shutil +import subprocess import sys from nerfstudio.utils.rich_utils import CONSOLE @@ -29,10 +30,10 @@ def check_ffmpeg_installed(): sys.exit(1) -def check_colmap_installed(): +def check_colmap_installed(colmap_cmd: str): """Checks if colmap is installed.""" - colmap_path = shutil.which("colmap") - if colmap_path is None: + out = subprocess.run(f"{colmap_cmd} -h", capture_output=True, shell=True, check=False) + if out.returncode != 0: CONSOLE.print("[bold red]Could not find COLMAP. Please install COLMAP.") print("See https://colmap.github.io/install.html for installation instructions.") sys.exit(1) diff --git a/nerfstudio/utils/math.py b/nerfstudio/utils/math.py index 0ba9e6a51c..d71907bee3 100644 --- a/nerfstudio/utils/math.py +++ b/nerfstudio/utils/math.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-""" Math Helper Functions """ +"""Math Helper Functions""" import itertools import math diff --git a/nerfstudio/utils/misc.py b/nerfstudio/utils/misc.py index ae3c4d5573..f55e1259a3 100644 --- a/nerfstudio/utils/misc.py +++ b/nerfstudio/utils/misc.py @@ -16,7 +16,6 @@ Miscellaneous helper code. """ - import platform import typing import warnings diff --git a/nerfstudio/utils/profiler.py b/nerfstudio/utils/profiler.py index ff0ae5bb93..b3046aa73f 100644 --- a/nerfstudio/utils/profiler.py +++ b/nerfstudio/utils/profiler.py @@ -15,6 +15,7 @@ """ Profiler base class and functionality """ + from __future__ import annotations import functools @@ -41,13 +42,11 @@ @overload -def time_function(name_or_func: CallableT) -> CallableT: - ... +def time_function(name_or_func: CallableT) -> CallableT: ... @overload -def time_function(name_or_func: str) -> ContextManager[Any]: - ... +def time_function(name_or_func: str) -> ContextManager[Any]: ... def time_function(name_or_func: Union[CallableT, str]) -> Union[CallableT, ContextManager[Any]]: diff --git a/nerfstudio/utils/writer.py b/nerfstudio/utils/writer.py index d460e9411b..a986618cad 100644 --- a/nerfstudio/utils/writer.py +++ b/nerfstudio/utils/writer.py @@ -15,6 +15,7 @@ """ Generic Writer class """ + from __future__ import annotations import enum diff --git a/nerfstudio/viewer/control_panel.py b/nerfstudio/viewer/control_panel.py index 6524c53093..73a136af92 100644 --- a/nerfstudio/viewer/control_panel.py +++ b/nerfstudio/viewer/control_panel.py @@ -12,7 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-""" Control panel for the viewer """ +"""Control panel for the viewer""" + from collections import defaultdict from typing import Callable, DefaultDict, List, Tuple, get_args @@ -49,7 +50,7 @@ class ControlPanel: def __init__( self, - viser_server: ViserServer, + server: ViserServer, time_enabled: bool, scale_ratio: float, rerender_cb: Callable[[], None], @@ -59,7 +60,7 @@ def __init__( ): self.viser_scale_ratio = scale_ratio # elements holds a mapping from tag: [elements] - self.viser_server = viser_server + self.server = server self._elements_by_tag: DefaultDict[str, List[ViewerElement]] = defaultdict(lambda: []) self.default_composite_depth = default_composite_depth @@ -151,7 +152,7 @@ def __init__( self._background_color = ViewerRGB( "Background color", (38, 42, 55), cb_hook=lambda _: rerender_cb(), hint="Color of the background" ) - self._crop_handle = self.viser_server.add_transform_controls("Crop", depth_test=False, line_width=4.0) + self._crop_handle = self.server.scene.add_transform_controls("Crop", depth_test=False, line_width=4.0) def update_center(han): self._crop_handle.position = tuple(p * self.viser_scale_ratio for p in han.value) # type: ignore @@ -192,7 +193,7 @@ def _update_crop_handle(han): self.add_element(self._train_speed) self.add_element(self._train_util) - with self.viser_server.add_gui_folder("Render Options"): + with self.server.gui.add_folder("Render Options"): self.add_element(self._max_res) self.add_element(self._output_render) self.add_element(self._colormap) @@ -204,7 +205,7 @@ def _update_crop_handle(han): self.add_element(self._max, additional_tags=("colormap",)) # split options - with self.viser_server.add_gui_folder("Split Screen"): + with self.server.gui.add_folder("Split Screen"): self.add_element(self._split) self.add_element(self._split_percentage, additional_tags=("split",)) @@ -216,7 +217,7 @@ def _update_crop_handle(han): self.add_element(self._split_min, additional_tags=("split_colormap",)) 
self.add_element(self._split_max, additional_tags=("split_colormap",)) - with self.viser_server.add_gui_folder("Crop Viewport"): + with self.server.gui.add_folder("Crop Viewport"): self.add_element(self._crop_viewport) # Crop options self.add_element(self._background_color, additional_tags=("crop",)) @@ -225,7 +226,7 @@ def _update_crop_handle(han): self.add_element(self._crop_rot, additional_tags=("crop",)) self.add_element(self._time, additional_tags=("time",)) - self._reset_camera = viser_server.add_gui_button( + self._reset_camera = server.gui.add_button( label="Reset Up Direction", icon=viser.Icon.ARROW_BIG_UP_LINES, color="gray", @@ -248,7 +249,7 @@ def _train_speed_cb(self) -> None: self._max_res.value = 1024 def _reset_camera_cb(self, _) -> None: - for client in self.viser_server.get_clients().values(): + for client in self.server.get_clients().values(): client.camera.up_direction = vtf.SO3(client.camera.wxyz) @ np.array([0.0, -1.0, 0.0]) def update_output_options(self, new_options: List[str]): @@ -270,7 +271,7 @@ def add_element(self, e: ViewerElement, additional_tags: Tuple[str, ...] 
= tuple self._elements_by_tag["all"].append(e) for t in additional_tags: self._elements_by_tag[t].append(e) - e.install(self.viser_server) + e.install(self.server) def update_control_panel(self) -> None: """ diff --git a/nerfstudio/viewer/export_panel.py b/nerfstudio/viewer/export_panel.py index 16201ba299..2bb3969cd5 100644 --- a/nerfstudio/viewer/export_panel.py +++ b/nerfstudio/viewer/export_panel.py @@ -34,17 +34,17 @@ def populate_export_tab( ) -> None: viewing_gsplat = isinstance(viewer_model, SplatfactoModel) if not viewing_gsplat: - crop_output = server.add_gui_checkbox("Use Crop", False) + crop_output = server.gui.add_checkbox("Use Crop", False) @crop_output.on_update def _(_) -> None: control_panel.crop_viewport = crop_output.value - with server.add_gui_folder("Splat"): + with server.gui.add_folder("Splat"): populate_splat_tab(server, control_panel, config_path, viewing_gsplat) - with server.add_gui_folder("Point Cloud"): + with server.gui.add_folder("Point Cloud"): populate_point_cloud_tab(server, control_panel, config_path, viewing_gsplat) - with server.add_gui_folder("Mesh"): + with server.gui.add_folder("Mesh"): populate_mesh_tab(server, control_panel, config_path, viewing_gsplat) @@ -54,8 +54,8 @@ def show_command_modal(client: viser.ClientHandle, what: Literal["mesh", "point In the future, we should only show the modal to the client that pushes the generation button. 
""" - with client.add_gui_modal(what.title() + " Export") as modal: - client.add_gui_markdown( + with client.gui.add_modal(what.title() + " Export") as modal: + client.gui.add_markdown( "\n".join( [ f"To export a {what}, run the following from the command line:", @@ -66,7 +66,7 @@ def show_command_modal(client: viser.ClientHandle, what: Literal["mesh", "point ] ) ) - close_button = client.add_gui_button("Close") + close_button = client.gui.add_button("Close") @close_button.on_click def _(_) -> None: @@ -80,6 +80,7 @@ def get_crop_string(obb: OrientedBox, crop_viewport: bool): if not crop_viewport: return "" rpy = vtf.SO3.from_matrix(obb.R.numpy(force=True)).as_rpy_radians() + rpy = [rpy.roll, rpy.pitch, rpy.yaw] pos = obb.T.squeeze().tolist() scale = obb.S.squeeze().tolist() rpystring = " ".join([f"{x:.10f}" for x in rpy]) @@ -95,9 +96,9 @@ def populate_point_cloud_tab( viewing_gsplat: bool, ) -> None: if not viewing_gsplat: - server.add_gui_markdown("Render depth, project to an oriented point cloud, and filter ") - num_points = server.add_gui_number("# Points", initial_value=1_000_000, min=1, max=None, step=1) - world_frame = server.add_gui_checkbox( + server.gui.add_markdown("Render depth, project to an oriented point cloud, and filter ") + num_points = server.gui.add_number("# Points", initial_value=1_000_000, min=1, max=None, step=1) + world_frame = server.gui.add_checkbox( "Save in world frame", False, hint=( @@ -105,16 +106,16 @@ def populate_point_cloud_tab( "scaled and reoriented coordinate space expected by the NeRF models." ), ) - remove_outliers = server.add_gui_checkbox("Remove outliers", True) - normals = server.add_gui_dropdown( + remove_outliers = server.gui.add_checkbox("Remove outliers", True) + normals = server.gui.add_dropdown( "Normals", # TODO: options here could depend on what's available to the model. 
("open3d", "model_output"), initial_value="open3d", hint="Normal map source.", ) - output_dir = server.add_gui_text("Output Directory", initial_value="exports/pcd/") - generate_command = server.add_gui_button("Generate Command", icon=viser.Icon.TERMINAL_2) + output_dir = server.gui.add_text("Output Directory", initial_value="exports/pcd/") + generate_command = server.gui.add_button("Generate Command", icon=viser.Icon.TERMINAL_2) @generate_command.on_click def _(event: viser.GuiEvent) -> None: @@ -134,7 +135,7 @@ def _(event: viser.GuiEvent) -> None: show_command_modal(event.client, "point cloud", command) else: - server.add_gui_markdown("Point cloud export is not currently supported with Gaussian Splatting") + server.gui.add_markdown("Point cloud export is not currently supported with Gaussian Splatting") def populate_mesh_tab( @@ -144,23 +145,23 @@ def populate_mesh_tab( viewing_gsplat: bool, ) -> None: if not viewing_gsplat: - server.add_gui_markdown( + server.gui.add_markdown( "Render depth, project to an oriented point cloud, and run Poisson surface reconstruction" ) - normals = server.add_gui_dropdown( + normals = server.gui.add_dropdown( "Normals", ("open3d", "model_output"), initial_value="open3d", hint="Source for normal maps.", ) - num_faces = server.add_gui_number("# Faces", initial_value=50_000, min=1) - texture_resolution = server.add_gui_number("Texture Resolution", min=8, initial_value=2048) - output_directory = server.add_gui_text("Output Directory", initial_value="exports/mesh/") - num_points = server.add_gui_number("# Points", initial_value=1_000_000, min=1, max=None, step=1) - remove_outliers = server.add_gui_checkbox("Remove outliers", True) + num_faces = server.gui.add_number("# Faces", initial_value=50_000, min=1) + texture_resolution = server.gui.add_number("Texture Resolution", min=8, initial_value=2048) + output_directory = server.gui.add_text("Output Directory", initial_value="exports/mesh/") + num_points = server.gui.add_number("# Points", 
initial_value=1_000_000, min=1, max=None, step=1) + remove_outliers = server.gui.add_checkbox("Remove outliers", True) - generate_command = server.add_gui_button("Generate Command", icon=viser.Icon.TERMINAL_2) + generate_command = server.gui.add_button("Generate Command", icon=viser.Icon.TERMINAL_2) @generate_command.on_click def _(event: viser.GuiEvent) -> None: @@ -181,7 +182,7 @@ def _(event: viser.GuiEvent) -> None: show_command_modal(event.client, "mesh", command) else: - server.add_gui_markdown("Mesh export is not currently supported with Gaussian Splatting") + server.gui.add_markdown("Mesh export is not currently supported with Gaussian Splatting") def populate_splat_tab( @@ -191,10 +192,10 @@ def populate_splat_tab( viewing_gsplat: bool, ) -> None: if viewing_gsplat: - server.add_gui_markdown("Generate ply export of Gaussian Splat") + server.gui.add_markdown("Generate ply export of Gaussian Splat") - output_directory = server.add_gui_text("Output Directory", initial_value="exports/splat/") - generate_command = server.add_gui_button("Generate Command", icon=viser.Icon.TERMINAL_2) + output_directory = server.gui.add_text("Output Directory", initial_value="exports/splat/") + generate_command = server.gui.add_button("Generate Command", icon=viser.Icon.TERMINAL_2) @generate_command.on_click def _(event: viser.GuiEvent) -> None: @@ -210,4 +211,4 @@ def _(event: viser.GuiEvent) -> None: show_command_modal(event.client, "splat", command) else: - server.add_gui_markdown("Splat export is only supported with Gaussian Splatting methods") + server.gui.add_markdown("Splat export is only supported with Gaussian Splatting methods") diff --git a/nerfstudio/viewer/render_panel.py b/nerfstudio/viewer/render_panel.py index 4cfe380d9e..10d263f8c2 100644 --- a/nerfstudio/viewer/render_panel.py +++ b/nerfstudio/viewer/render_panel.py @@ -101,7 +101,7 @@ def add_camera(self, keyframe: Keyframe, keyframe_index: Optional[int] = None) - keyframe_index = self._keyframe_counter 
self._keyframe_counter += 1 - frustum_handle = server.add_camera_frustum( + frustum_handle = server.scene.add_camera_frustum( f"/render_cameras/{keyframe_index}", fov=keyframe.override_fov_rad if keyframe.override_fov_enabled else self.default_fov, aspect=keyframe.aspect, @@ -111,7 +111,7 @@ def add_camera(self, keyframe: Keyframe, keyframe_index: Optional[int] = None) - position=keyframe.position, visible=self._keyframes_visible, ) - self._server.add_icosphere( + self._server.scene.add_icosphere( f"/render_cameras/{keyframe_index}/sphere", radius=0.03, color=(200, 10, 30), @@ -123,13 +123,13 @@ def _(_) -> None: self._camera_edit_panel.remove() self._camera_edit_panel = None - with server.add_3d_gui_container( + with server.scene.add_3d_gui_container( "/camera_edit_panel", position=keyframe.position, ) as camera_edit_panel: self._camera_edit_panel = camera_edit_panel - override_fov = server.add_gui_checkbox("Override FOV", initial_value=keyframe.override_fov_enabled) - override_fov_degrees = server.add_gui_slider( + override_fov = server.gui.add_checkbox("Override FOV", initial_value=keyframe.override_fov_enabled) + override_fov_degrees = server.gui.add_slider( "Override FOV (degrees)", 5.0, 175.0, @@ -138,10 +138,10 @@ def _(_) -> None: disabled=not keyframe.override_fov_enabled, ) if self.time_enabled: - override_time = server.add_gui_checkbox( + override_time = server.gui.add_checkbox( "Override Time", initial_value=keyframe.override_time_enabled ) - override_time_val = server.add_gui_slider( + override_time_val = server.gui.add_slider( "Override Time", 0.0, 1.0, @@ -161,9 +161,9 @@ def _(_) -> None: keyframe.override_time_val = override_time_val.value self.add_camera(keyframe, keyframe_index) - delete_button = server.add_gui_button("Delete", color="red", icon=viser.Icon.TRASH) - go_to_button = server.add_gui_button("Go to") - close_button = server.add_gui_button("Close") + delete_button = server.gui.add_button("Delete", color="red", icon=viser.Icon.TRASH) + 
go_to_button = server.gui.add_button("Go to") + close_button = server.gui.add_button("Close") @override_fov.on_update def _(_) -> None: @@ -179,10 +179,10 @@ def _(_) -> None: @delete_button.on_click def _(event: viser.GuiEvent) -> None: assert event.client is not None - with event.client.add_gui_modal("Confirm") as modal: - event.client.add_gui_markdown("Delete keyframe?") - confirm_button = event.client.add_gui_button("Yes", color="red", icon=viser.Icon.TRASH) - exit_button = event.client.add_gui_button("Cancel") + with event.client.gui.add_modal("Confirm") as modal: + event.client.gui.add_markdown("Delete keyframe?") + confirm_button = event.client.gui.add_button("Yes", color="red", icon=viser.Icon.TRASH) + exit_button = event.client.gui.add_button("Cancel") @confirm_button.on_click def _(_) -> None: @@ -374,7 +374,7 @@ def update_spline(self) -> None: self._spline_nodes.clear() self._spline_nodes.append( - self._server.add_spline_catmull_rom( + self._server.scene.add_spline_catmull_rom( "/render_camera_spline", positions=points_array, color=(220, 220, 220), @@ -384,7 +384,7 @@ def update_spline(self) -> None: ) ) self._spline_nodes.append( - self._server.add_point_cloud( + self._server.scene.add_point_cloud( "/render_camera_spline/points", points=points_array, colors=colors_array, @@ -401,7 +401,7 @@ def make_transition_handle(i: int) -> None: ) ) ) - transition_sphere = self._server.add_icosphere( + transition_sphere = self._server.scene.add_icosphere( f"/render_camera_spline/transition_{i}", radius=0.04, color=(255, 0, 0), @@ -420,16 +420,16 @@ def _(_) -> None: keyframe_index = (i + 1) % len(self._keyframes) keyframe = keyframes[keyframe_index][0] - with server.add_3d_gui_container( + with server.scene.add_3d_gui_container( "/camera_edit_panel", position=transition_pos, ) as camera_edit_panel: self._camera_edit_panel = camera_edit_panel - override_transition_enabled = server.add_gui_checkbox( + override_transition_enabled = server.gui.add_checkbox( "Override 
transition", initial_value=keyframe.override_transition_enabled, ) - override_transition_sec = server.add_gui_number( + override_transition_sec = server.gui.add_number( "Override transition (sec)", initial_value=keyframe.override_transition_sec if keyframe.override_transition_sec is not None @@ -439,7 +439,7 @@ def _(_) -> None: step=0.001, disabled=not override_transition_enabled.value, ) - close_button = server.add_gui_button("Close") + close_button = server.gui.add_button("Close") @override_transition_enabled.on_update def _(_) -> None: @@ -532,7 +532,7 @@ def populate_render_tab( preview_camera_type="Perspective", ) - fov_degrees = server.add_gui_slider( + fov_degrees = server.gui.add_slider( "Default FOV", initial_value=75.0, min=0.1, @@ -543,7 +543,7 @@ def populate_render_tab( render_time = None if control_panel is not None and control_panel._time_enabled: - render_time = server.add_gui_slider( + render_time = server.gui.add_slider( "Default Time", initial_value=0.0, min=0.0, @@ -568,7 +568,7 @@ def _(_) -> None: camera_path.update_aspect(resolution.value[0] / resolution.value[1]) compute_and_update_preview_camera_state() - resolution = server.add_gui_vector2( + resolution = server.gui.add_vector2( "Resolution", initial_value=(1920, 1080), min=(50, 50), @@ -582,13 +582,13 @@ def _(_) -> None: camera_path.update_aspect(resolution.value[0] / resolution.value[1]) compute_and_update_preview_camera_state() - camera_type = server.add_gui_dropdown( + camera_type = server.gui.add_dropdown( "Camera type", ("Perspective", "Fisheye", "Equirectangular"), initial_value="Perspective", hint="Camera model to render with. 
This is applied to all keyframes.", ) - add_button = server.add_gui_button( + add_button = server.gui.add_button( "Add Keyframe", icon=viser.Icon.PLUS, hint="Add a new keyframe at the current pose.", @@ -609,7 +609,7 @@ def _(event: viser.GuiEvent) -> None: duration_number.value = camera_path.compute_duration() camera_path.update_spline() - clear_keyframes_button = server.add_gui_button( + clear_keyframes_button = server.gui.add_button( "Clear Keyframes", icon=viser.Icon.TRASH, hint="Remove all keyframes from the render path.", @@ -619,10 +619,10 @@ def _(event: viser.GuiEvent) -> None: def _(event: viser.GuiEvent) -> None: assert event.client_id is not None client = server.get_clients()[event.client_id] - with client.atomic(), client.add_gui_modal("Confirm") as modal: - client.add_gui_markdown("Clear all keyframes?") - confirm_button = client.add_gui_button("Yes", color="red", icon=viser.Icon.TRASH) - exit_button = client.add_gui_button("Cancel") + with client.atomic(), client.gui.add_modal("Confirm") as modal: + client.gui.add_markdown("Clear all keyframes?") + confirm_button = client.gui.add_button("Yes", color="red", icon=viser.Icon.TRASH) + exit_button = client.gui.add_button("Cancel") @confirm_button.on_click def _(_) -> None: @@ -642,14 +642,14 @@ def _(_) -> None: def _(_) -> None: modal.close() - loop = server.add_gui_checkbox("Loop", False, hint="Add a segment between the first and last keyframes.") + loop = server.gui.add_checkbox("Loop", False, hint="Add a segment between the first and last keyframes.") @loop.on_update def _(_) -> None: camera_path.loop = loop.value duration_number.value = camera_path.compute_duration() - tension_slider = server.add_gui_slider( + tension_slider = server.gui.add_slider( "Spline tension", min=0.0, max=1.0, @@ -663,7 +663,7 @@ def _(_) -> None: camera_path.tension = tension_slider.value camera_path.update_spline() - move_checkbox = server.add_gui_checkbox( + move_checkbox = server.gui.add_checkbox( "Move keyframes", 
initial_value=False, hint="Toggle move handles for keyframes in the scene.", @@ -697,7 +697,7 @@ def _(_) -> None: # Show move handles. assert event.client is not None for keyframe_index, keyframe in camera_path._keyframes.items(): - controls = event.client.add_transform_controls( + controls = event.client.scene.add_transform_controls( f"/keyframe_move/{keyframe_index}", scale=0.4, wxyz=keyframe[0].wxyz, @@ -706,7 +706,7 @@ def _(_) -> None: transform_controls.append(controls) _make_transform_controls_callback(keyframe, controls) - show_keyframe_checkbox = server.add_gui_checkbox( + show_keyframe_checkbox = server.gui.add_checkbox( "Show keyframes", initial_value=True, hint="Show keyframes in the scene.", @@ -716,7 +716,7 @@ def _(_) -> None: def _(_: viser.GuiEvent) -> None: camera_path.set_keyframes_visible(show_keyframe_checkbox.value) - show_spline_checkbox = server.add_gui_checkbox( + show_spline_checkbox = server.gui.add_checkbox( "Show spline", initial_value=True, hint="Show camera path spline in the scene.", @@ -727,16 +727,16 @@ def _(_) -> None: camera_path.show_spline = show_spline_checkbox.value camera_path.update_spline() - playback_folder = server.add_gui_folder("Playback") + playback_folder = server.gui.add_folder("Playback") with playback_folder: - play_button = server.add_gui_button("Play", icon=viser.Icon.PLAYER_PLAY) - pause_button = server.add_gui_button("Pause", icon=viser.Icon.PLAYER_PAUSE, visible=False) - preview_render_button = server.add_gui_button( + play_button = server.gui.add_button("Play", icon=viser.Icon.PLAYER_PLAY) + pause_button = server.gui.add_button("Pause", icon=viser.Icon.PLAYER_PAUSE, visible=False) + preview_render_button = server.gui.add_button( "Preview Render", hint="Show a preview of the render in the viewport." 
) - preview_render_stop_button = server.add_gui_button("Exit Render Preview", color="red", visible=False) + preview_render_stop_button = server.gui.add_button("Exit Render Preview", color="red", visible=False) - transition_sec_number = server.add_gui_number( + transition_sec_number = server.gui.add_number( "Transition (sec)", min=0.001, max=30.0, @@ -744,9 +744,9 @@ def _(_) -> None: initial_value=2.0, hint="Time in seconds between each keyframe, which can also be overridden on a per-transition basis.", ) - framerate_number = server.add_gui_number("FPS", min=0.1, max=240.0, step=1e-2, initial_value=30.0) - framerate_buttons = server.add_gui_button_group("", ("24", "30", "60")) - duration_number = server.add_gui_number( + framerate_number = server.gui.add_number("FPS", min=0.1, max=240.0, step=1e-2, initial_value=30.0) + framerate_buttons = server.gui.add_button_group("", ("24", "30", "60")) + duration_number = server.gui.add_number( "Duration (sec)", min=0.0, max=1e8, @@ -807,7 +807,7 @@ def add_preview_frame_slider() -> Optional[viser.GuiInputHandle[int]]: re-added anytime the `max` value changes.""" with playback_folder: - preview_frame_slider = server.add_gui_slider( + preview_frame_slider = server.gui.add_slider( "Preview frame", min=0, max=get_max_frame_index(), @@ -831,7 +831,7 @@ def _(_) -> None: else: pose, fov_rad = maybe_pose_and_fov_rad - preview_camera_handle = server.add_camera_frustum( + preview_camera_handle = server.scene.add_camera_frustum( "/preview_camera", fov=fov_rad, aspect=resolution.value[0] / resolution.value[1], @@ -867,7 +867,7 @@ def _(_) -> None: del fov # Hide all scene nodes when we're previewing the render. - server.set_global_scene_node_visibility(False) + server.scene.set_global_visibility(False) # Back up and then set camera poses. for client in server.get_clients().values(): @@ -896,7 +896,7 @@ def _(_) -> None: client.flush() # Un-hide scene nodes. 
- server.set_global_scene_node_visibility(True) + server.scene.set_global_visibility(True) preview_frame_slider = add_preview_frame_slider() @@ -942,7 +942,7 @@ def _(_) -> None: pause_button.visible = False # add button for loading existing path - load_camera_path_button = server.add_gui_button( + load_camera_path_button = server.gui.add_button( "Load Path", icon=viser.Icon.FOLDER_OPEN, hint="Load an existing camera path." ) @@ -954,17 +954,17 @@ def _(event: viser.GuiEvent) -> None: preexisting_camera_paths = list(camera_path_dir.glob("*.json")) preexisting_camera_filenames = [p.name for p in preexisting_camera_paths] - with event.client.add_gui_modal("Load Path") as modal: + with event.client.gui.add_modal("Load Path") as modal: if len(preexisting_camera_filenames) == 0: - event.client.add_gui_markdown("No existing paths found") + event.client.gui.add_markdown("No existing paths found") else: - event.client.add_gui_markdown("Select existing camera path:") - camera_path_dropdown = event.client.add_gui_dropdown( + event.client.gui.add_markdown("Select existing camera path:") + camera_path_dropdown = event.client.gui.add_dropdown( label="Camera Path", options=[str(p) for p in preexisting_camera_filenames], initial_value=str(preexisting_camera_filenames[0]), ) - load_button = event.client.add_gui_button("Load") + load_button = event.client.gui.add_button("Load") @load_button.on_click def _(_) -> None: @@ -1006,7 +1006,7 @@ def _(_) -> None: camera_path.update_spline() modal.close() - cancel_button = event.client.add_gui_button("Cancel") + cancel_button = event.client.gui.add_button("Cancel") @cancel_button.on_click def _(_) -> None: @@ -1014,19 +1014,19 @@ def _(_) -> None: # set the initial value to the current date-time string now = datetime.datetime.now() - render_name_text = server.add_gui_text( + render_name_text = server.gui.add_text( "Render name", initial_value=now.strftime("%Y-%m-%d-%H-%M-%S"), hint="Name of the render", ) - render_button = 
server.add_gui_button( + render_button = server.gui.add_button( "Generate Command", color="green", icon=viser.Icon.FILE_EXPORT, hint="Generate the ns-render command for rendering the camera path.", ) - reset_up_button = server.add_gui_button( + reset_up_button = server.gui.add_button( "Reset Up Direction", icon=viser.Icon.ARROW_BIG_UP_LINES, color="gray", @@ -1135,7 +1135,7 @@ def _(event: viser.GuiEvent) -> None: with open(json_outfile.absolute(), "w") as outfile: json.dump(json_data, outfile) # now show the command - with event.client.add_gui_modal("Render Command") as modal: + with event.client.gui.add_modal("Render Command") as modal: dataname = datapath.name command = " ".join( [ @@ -1145,7 +1145,7 @@ def _(event: viser.GuiEvent) -> None: f"--output-path renders/{dataname}/{render_name_text.value}.mp4", ] ) - event.client.add_gui_markdown( + event.client.gui.add_markdown( "\n".join( [ "To render the trajectory, run the following from the command line:", @@ -1156,7 +1156,7 @@ def _(event: viser.GuiEvent) -> None: ] ) ) - close_button = event.client.add_gui_button("Close") + close_button = event.client.gui.add_button("Close") @close_button.on_click def _(_) -> None: @@ -1166,6 +1166,7 @@ def _(_) -> None: camera_path = CameraPath(server, duration_number, control_panel._time_enabled) else: camera_path = CameraPath(server, duration_number) + camera_path.tension = tension_slider.value camera_path.default_fov = fov_degrees.value / 180.0 * np.pi camera_path.default_transition_sec = transition_sec_number.value diff --git a/nerfstudio/viewer/render_state_machine.py b/nerfstudio/viewer/render_state_machine.py index c950644648..cc8002ec9f 100644 --- a/nerfstudio/viewer/render_state_machine.py +++ b/nerfstudio/viewer/render_state_machine.py @@ -12,7 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-""" This file contains the render state machine, which is responsible for deciding when to render the image """ +"""This file contains the render state machine, which is responsible for deciding when to render the image""" + from __future__ import annotations import contextlib @@ -302,7 +303,7 @@ def _send_output_to_viewer(self, outputs: Dict[str, Any], static_render: bool = if self.viewer.render_tab_state.preview_render else 40 ) - self.client.set_background_image( + self.client.scene.set_background_image( selected_output, format=self.viewer.config.image_format, jpeg_quality=jpg_quality, diff --git a/nerfstudio/viewer/server/viewer_elements.py b/nerfstudio/viewer/server/viewer_elements.py index 2689e432b3..597c89c01f 100644 --- a/nerfstudio/viewer/server/viewer_elements.py +++ b/nerfstudio/viewer/server/viewer_elements.py @@ -16,4 +16,5 @@ Resolves issues like: https://github.com/ayaanzhaque/instruct-nerf2nerf/pull/88 """ + from ..viewer_elements import * # noqa diff --git a/nerfstudio/viewer/utils.py b/nerfstudio/viewer/utils.py index 0f5ead912f..2fbb5af14f 100644 --- a/nerfstudio/viewer/utils.py +++ b/nerfstudio/viewer/utils.py @@ -42,6 +42,8 @@ class CameraState: """Type of camera to render.""" time: float = 0.0 """The rendering time of the camera state.""" + idx: int = 0 + """The index of the current camera.""" def get_camera( @@ -78,6 +80,7 @@ def get_camera( camera_type=camera_state.camera_type, camera_to_worlds=camera_state.c2w.to(torch.float32)[None, ...], times=torch.tensor([camera_state.time], dtype=torch.float32), + metadata={"cam_idx": camera_state.idx}, ) return camera diff --git a/nerfstudio/viewer/viewer.py b/nerfstudio/viewer/viewer.py index a5093f0dad..3b01c2e5be 100644 --- a/nerfstudio/viewer/viewer.py +++ b/nerfstudio/viewer/viewer.py @@ -106,6 +106,8 @@ def __init__( self.train_btn_state: Literal["training", "paused", "completed"] = "training" self._prev_train_state: Literal["training", "paused", "completed"] = "training" self.last_move_time = 
0 + # track the index of the camera that was last clicked + self.current_camera_idx = 0 self.viser_server = viser.ViserServer(host=config.websocket_host, port=websocket_port) # Set the name of the URL either to the share link if available, or the localhost @@ -150,7 +152,7 @@ def __init__( href="https://docs.nerf.studio/", ) titlebar_theme = viser.theme.TitlebarConfig(buttons=buttons, image=image) - self.viser_server.configure_theme( + self.viser_server.gui.configure_theme( titlebar_content=titlebar_theme, control_layout="collapsible", dark_mode=True, @@ -162,32 +164,32 @@ def __init__( self.viser_server.on_client_connect(self.handle_new_client) # Populate the header, which includes the pause button, train cam button, and stats - self.pause_train = self.viser_server.add_gui_button( + self.pause_train = self.viser_server.gui.add_button( label="Pause Training", disabled=False, icon=viser.Icon.PLAYER_PAUSE_FILLED ) self.pause_train.on_click(lambda _: self.toggle_pause_button()) self.pause_train.on_click(lambda han: self._toggle_training_state(han)) - self.resume_train = self.viser_server.add_gui_button( + self.resume_train = self.viser_server.gui.add_button( label="Resume Training", disabled=False, icon=viser.Icon.PLAYER_PLAY_FILLED ) self.resume_train.on_click(lambda _: self.toggle_pause_button()) self.resume_train.on_click(lambda han: self._toggle_training_state(han)) self.resume_train.visible = False # Add buttons to toggle training image visibility - self.hide_images = self.viser_server.add_gui_button( + self.hide_images = self.viser_server.gui.add_button( label="Hide Train Cams", disabled=False, icon=viser.Icon.EYE_OFF, color=None ) self.hide_images.on_click(lambda _: self.set_camera_visibility(False)) self.hide_images.on_click(lambda _: self.toggle_cameravis_button()) - self.show_images = self.viser_server.add_gui_button( + self.show_images = self.viser_server.gui.add_button( label="Show Train Cams", disabled=False, icon=viser.Icon.EYE, color=None )
self.show_images.on_click(lambda _: self.set_camera_visibility(True)) self.show_images.on_click(lambda _: self.toggle_cameravis_button()) self.show_images.visible = False mkdown = self.make_stats_markdown(0, "0x0px") - self.stats_markdown = self.viser_server.add_gui_markdown(mkdown) - tabs = self.viser_server.add_gui_tab_group() + self.stats_markdown = self.viser_server.gui.add_markdown(mkdown) + tabs = self.viser_server.gui.add_tab_group() control_tab = tabs.add_tab("Control", viser.Icon.SETTINGS) with control_tab: self.control_panel = ControlPanel( @@ -242,7 +244,7 @@ def nested_folder_install(folder_labels: List[str], prev_labels: List[str], elem # Otherwise, use the existing folder as context manager. folder_path = "/".join(prev_labels + [folder_labels[0]]) if folder_path not in viewer_gui_folders: - viewer_gui_folders[folder_path] = self.viser_server.add_gui_folder(folder_labels[0]) + viewer_gui_folders[folder_path] = self.viser_server.gui.add_folder(folder_labels[0]) with viewer_gui_folders[folder_path]: nested_folder_install(folder_labels[1:], prev_labels + [folder_labels[0]], element) @@ -272,7 +274,7 @@ def nested_folder_install(folder_labels: List[str], prev_labels: List[str], elem # Diagnostics for Gaussian Splatting: where the points are at the start of training. # This is hidden by default, it can be shown from the Viser UI's scene tree table. 
if isinstance(pipeline.model, SplatfactoModel): - self.viser_server.add_point_cloud( + self.viser_server.scene.add_point_cloud( "/gaussian_splatting_initial_points", points=pipeline.model.means.numpy(force=True) * VISER_NERFSTUDIO_SCALE_RATIO, colors=(255, 0, 0), @@ -325,6 +327,7 @@ def get_camera_state(self, client: viser.ClientHandle) -> CameraState: else CameraType.EQUIRECTANGULAR if camera_type == "Equirectangular" else assert_never(camera_type), + idx=self.current_camera_idx, ) else: camera_state = CameraState( @@ -332,6 +335,7 @@ def get_camera_state(self, client: viser.ClientHandle) -> CameraState: aspect=client.camera.aspect, c2w=c2w, camera_type=CameraType.PERSPECTIVE, + idx=self.current_camera_idx, ) return camera_state @@ -452,7 +456,7 @@ def init_scene( c2w = camera.camera_to_worlds.cpu().numpy() R = vtf.SO3.from_matrix(c2w[:3, :3]) R = R @ vtf.SO3.from_x_radians(np.pi) - camera_handle = self.viser_server.add_camera_frustum( + camera_handle = self.viser_server.scene.add_camera_frustum( name=f"/cameras/camera_{idx:05d}", fov=float(2 * np.arctan(camera.cx / camera.fx[0])), scale=self.config.camera_frustum_scale, @@ -462,11 +466,16 @@ def init_scene( position=c2w[:3, 3] * VISER_NERFSTUDIO_SCALE_RATIO, ) - @camera_handle.on_click - def _(event: viser.SceneNodePointerEvent[viser.CameraFrustumHandle]) -> None: - with event.client.atomic(): - event.client.camera.position = event.target.position - event.client.camera.wxyz = event.target.wxyz + def create_on_click_callback(capture_idx): + def on_click_callback(event: viser.SceneNodePointerEvent[viser.CameraFrustumHandle]) -> None: + with event.client.atomic(): + event.client.camera.position = event.target.position + event.client.camera.wxyz = event.target.wxyz + self.current_camera_idx = capture_idx + + return on_click_callback + + camera_handle.on_click(create_on_click_callback(idx)) self.camera_handles[idx] = camera_handle self.original_c2w[idx] = c2w diff --git a/nerfstudio/viewer/viewer_elements.py 
b/nerfstudio/viewer/viewer_elements.py index 654503a3c4..1ccfab3b45 100644 --- a/nerfstudio/viewer/viewer_elements.py +++ b/nerfstudio/viewer/viewer_elements.py @@ -13,8 +13,7 @@ # limitations under the License. -""" Viewer GUI elements for the nerfstudio viewer """ - +"""Viewer GUI elements for the nerfstudio viewer""" from __future__ import annotations @@ -172,8 +171,7 @@ def register_pointer_cb( event_type: Literal["click"], cb: Callable[[ViewerClick], None], removed_cb: Optional[Callable[[], None]] = None, - ): - ... + ): ... @overload def register_pointer_cb( @@ -181,8 +179,7 @@ def register_pointer_cb( event_type: Literal["rect-select"], cb: Callable[[ViewerRectSelect], None], removed_cb: Optional[Callable[[], None]] = None, - ): - ... + ): ... def register_pointer_cb( self, @@ -230,7 +227,7 @@ def wrapped_cb(scene_pointer_msg: ScenePointerEvent): cb_overriden = False with warnings.catch_warnings(record=True) as w: # Register the callback with the viser server. - self.viser_server.on_scene_pointer(event_type=event_type)(wrapped_cb) + self.viser_server.scene.on_pointer_event(event_type=event_type)(wrapped_cb) # If there exists a warning, it's because a callback was overriden. cb_overriden = len(w) > 0 @@ -242,7 +239,7 @@ def wrapped_cb(scene_pointer_msg: ScenePointerEvent): # If there exists a cleanup callback after the pointer event is done, register it. if removed_cb is not None: - self.viser_server.on_scene_pointer_removed(removed_cb) + self.viser_server.scene.on_pointer_callback_removed(removed_cb) def unregister_click_cb(self, cb: Optional[Callable] = None): """Deprecated, use unregister_pointer_cb instead. 
`cb` is ignored.""" @@ -260,7 +257,7 @@ def unregister_pointer_cb(self): Args: cb: The callback to remove """ - self.viser_server.remove_scene_pointer_callback() + self.viser_server.scene.remove_pointer_callback() @property def server(self): @@ -342,7 +339,7 @@ def __init__(self, name: str, cb_hook: Callable[[ViewerButton], Any], disabled: super().__init__(name, disabled=disabled, visible=visible, cb_hook=cb_hook) def _create_gui_handle(self, viser_server: ViserServer) -> None: - self.gui_handle = viser_server.add_gui_button(label=self.name, disabled=self.disabled, visible=self.visible) + self.gui_handle = viser_server.gui.add_button(label=self.name, disabled=self.disabled, visible=self.visible) def install(self, viser_server: ViserServer) -> None: self._create_gui_handle(viser_server) @@ -388,8 +385,7 @@ def install(self, viser_server: ViserServer) -> None: self.gui_handle.on_update(lambda _: self.cb_hook(self)) @abstractmethod - def _create_gui_handle(self, viser_server: ViserServer) -> None: - ... + def _create_gui_handle(self, viser_server: ViserServer) -> None: ... 
@property def value(self) -> TValue: @@ -445,7 +441,7 @@ def __init__( def _create_gui_handle(self, viser_server: ViserServer) -> None: assert self.gui_handle is None, "gui_handle should be initialized once" - self.gui_handle = viser_server.add_gui_slider( + self.gui_handle = viser_server.gui.add_slider( self.name, self.min, self.max, @@ -484,7 +480,7 @@ def __init__( def _create_gui_handle(self, viser_server: ViserServer) -> None: assert self.gui_handle is None, "gui_handle should be initialized once" - self.gui_handle = viser_server.add_gui_text( + self.gui_handle = viser_server.gui.add_text( self.name, self.default_value, disabled=self.disabled, visible=self.visible, hint=self.hint ) @@ -518,7 +514,7 @@ def __init__( def _create_gui_handle(self, viser_server: ViserServer) -> None: assert self.gui_handle is None, "gui_handle should be initialized once" - self.gui_handle = viser_server.add_gui_number( + self.gui_handle = viser_server.gui.add_number( self.name, self.default_value, disabled=self.disabled, visible=self.visible, hint=self.hint ) @@ -550,7 +546,7 @@ def __init__( def _create_gui_handle(self, viser_server: ViserServer) -> None: assert self.gui_handle is None, "gui_handle should be initialized once" - self.gui_handle = viser_server.add_gui_checkbox( + self.gui_handle = viser_server.gui.add_checkbox( self.name, self.default_value, disabled=self.disabled, visible=self.visible, hint=self.hint ) @@ -590,7 +586,7 @@ def __init__( def _create_gui_handle(self, viser_server: ViserServer) -> None: assert self.gui_handle is None, "gui_handle should be initialized once" - self.gui_handle = viser_server.add_gui_dropdown( + self.gui_handle = viser_server.gui.add_dropdown( self.name, self.options, self.default_value, @@ -636,7 +632,7 @@ def __init__( def _create_gui_handle(self, viser_server: ViserServer) -> None: assert self.gui_handle is None, "gui_handle should be initialized once" - self.gui_handle = viser_server.add_gui_button_group(self.name, self.options, 
visible=self.visible) + self.gui_handle = viser_server.gui.add_button_group(self.name, self.options, visible=self.visible) def install(self, viser_server: ViserServer) -> None: self._create_gui_handle(viser_server) @@ -672,7 +668,7 @@ def __init__( self.hint = hint def _create_gui_handle(self, viser_server: ViserServer) -> None: - self.gui_handle = viser_server.add_gui_rgb( + self.gui_handle = viser_server.gui.add_rgb( self.name, self.default_value, disabled=self.disabled, visible=self.visible, hint=self.hint ) @@ -707,6 +703,6 @@ def __init__( self.hint = hint def _create_gui_handle(self, viser_server: ViserServer) -> None: - self.gui_handle = viser_server.add_gui_vector3( + self.gui_handle = viser_server.gui.add_vector3( self.name, self.default_value, step=self.step, disabled=self.disabled, visible=self.visible, hint=self.hint ) diff --git a/nerfstudio/viewer_legacy/app/run_deploy.py b/nerfstudio/viewer_legacy/app/run_deploy.py index 8a5fce02d8..6f1f9c38d7 100644 --- a/nerfstudio/viewer_legacy/app/run_deploy.py +++ b/nerfstudio/viewer_legacy/app/run_deploy.py @@ -16,6 +16,7 @@ Code for deploying the built viewer folder to a server and handing versioning. We use the library sshconf (https://github.com/sorend/sshconf) for working with the ssh config file. """ + import json import subprocess from os.path import expanduser diff --git a/nerfstudio/viewer_legacy/server/control_panel.py b/nerfstudio/viewer_legacy/server/control_panel.py index 43fb51b466..f2caa793eb 100644 --- a/nerfstudio/viewer_legacy/server/control_panel.py +++ b/nerfstudio/viewer_legacy/server/control_panel.py @@ -12,7 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-""" Control panel for the viewer """ +"""Control panel for the viewer""" + from collections import defaultdict from typing import Callable, DefaultDict, List, Tuple, get_args diff --git a/nerfstudio/viewer_legacy/server/gui_utils.py b/nerfstudio/viewer_legacy/server/gui_utils.py index cd17172cbb..2bd34e1be5 100644 --- a/nerfstudio/viewer_legacy/server/gui_utils.py +++ b/nerfstudio/viewer_legacy/server/gui_utils.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -""" Utilities for generating custom gui elements in the viewer """ +"""Utilities for generating custom gui elements in the viewer""" from __future__ import annotations diff --git a/nerfstudio/viewer_legacy/server/path.py b/nerfstudio/viewer_legacy/server/path.py index 1dc8bdfeac..12494e5d08 100644 --- a/nerfstudio/viewer_legacy/server/path.py +++ b/nerfstudio/viewer_legacy/server/path.py @@ -12,9 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Path class -""" - +"""Path class""" from typing import Tuple diff --git a/nerfstudio/viewer_legacy/server/render_state_machine.py b/nerfstudio/viewer_legacy/server/render_state_machine.py index a3c0524906..9b27cf361a 100644 --- a/nerfstudio/viewer_legacy/server/render_state_machine.py +++ b/nerfstudio/viewer_legacy/server/render_state_machine.py @@ -12,7 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-""" This file contains the render state machine, which is responsible for deciding when to render the image """ +"""This file contains the render state machine, which is responsible for deciding when to render the image""" + from __future__ import annotations import contextlib diff --git a/nerfstudio/viewer_legacy/server/utils.py b/nerfstudio/viewer_legacy/server/utils.py index 0b26741dbd..7f49aeec73 100644 --- a/nerfstudio/viewer_legacy/server/utils.py +++ b/nerfstudio/viewer_legacy/server/utils.py @@ -12,8 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Generic utility functions -""" +"""Generic utility functions""" from typing import List, Optional, Tuple, Union diff --git a/nerfstudio/viewer_legacy/server/viewer_elements.py b/nerfstudio/viewer_legacy/server/viewer_elements.py index b8065f8863..a8562d5b89 100644 --- a/nerfstudio/viewer_legacy/server/viewer_elements.py +++ b/nerfstudio/viewer_legacy/server/viewer_elements.py @@ -13,8 +13,7 @@ # limitations under the License. -""" Viewer GUI elements for the nerfstudio viewer """ - +"""Viewer GUI elements for the nerfstudio viewer""" from __future__ import annotations @@ -263,8 +262,7 @@ def install(self, viser_server: ViserServer) -> None: self.gui_handle.on_update(lambda _: self.cb_hook(self)) @abstractmethod - def _create_gui_handle(self, viser_server: ViserServer) -> None: - ... + def _create_gui_handle(self, viser_server: ViserServer) -> None: ... @property def value(self) -> TValue: diff --git a/nerfstudio/viewer_legacy/server/viewer_state.py b/nerfstudio/viewer_legacy/server/viewer_state.py index ee02d2a4b7..cfb3bff7b1 100644 --- a/nerfstudio/viewer_legacy/server/viewer_state.py +++ b/nerfstudio/viewer_legacy/server/viewer_state.py @@ -12,7 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-""" Manage the state of the viewer """ +"""Manage the state of the viewer""" + from __future__ import annotations import threading diff --git a/nerfstudio/viewer_legacy/server/viewer_utils.py b/nerfstudio/viewer_legacy/server/viewer_utils.py index 3a2c016084..6e3b28e868 100644 --- a/nerfstudio/viewer_legacy/server/viewer_utils.py +++ b/nerfstudio/viewer_legacy/server/viewer_utils.py @@ -14,6 +14,7 @@ """Code to interface with the `vis/` (the JS viewer).""" + from __future__ import annotations import os diff --git a/nerfstudio/viewer_legacy/viser/__init__.py b/nerfstudio/viewer_legacy/viser/__init__.py index 7560a4cc80..228d658882 100644 --- a/nerfstudio/viewer_legacy/viser/__init__.py +++ b/nerfstudio/viewer_legacy/viser/__init__.py @@ -12,8 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -""" Viser is used for the nerfstudio viewer backend """ - +"""Viser is used for the nerfstudio viewer backend""" from .message_api import GuiHandle as GuiHandle, GuiSelectHandle as GuiSelectHandle from .messages import NerfstudioMessage as NerfstudioMessage diff --git a/nerfstudio/viewer_legacy/viser/gui.py b/nerfstudio/viewer_legacy/viser/gui.py index 75c18decae..11536db3f7 100644 --- a/nerfstudio/viewer_legacy/viser/gui.py +++ b/nerfstudio/viewer_legacy/viser/gui.py @@ -13,10 +13,11 @@ # limitations under the License. -""" Manages GUI communication. +"""Manages GUI communication. Should be almost identical to: https://github.com/brentyi/viser/blob/main/viser/_gui.py """ + from __future__ import annotations import dataclasses diff --git a/nerfstudio/viewer_legacy/viser/message_api.py b/nerfstudio/viewer_legacy/viser/message_api.py index bb6449cf8e..982c41fcc1 100644 --- a/nerfstudio/viewer_legacy/viser/message_api.py +++ b/nerfstudio/viewer_legacy/viser/message_api.py @@ -12,8 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-""" This module contains the MessageApi class, which is the interface for sending messages to the Viewer""" - +"""This module contains the MessageApi class, which is the interface for sending messages to the Viewer""" from __future__ import annotations @@ -274,8 +273,7 @@ def add_gui_select( options: List[TLiteralString], initial_value: Optional[TLiteralString] = None, hint: Optional[str] = None, - ) -> GuiSelectHandle[TLiteralString]: - ... + ) -> GuiSelectHandle[TLiteralString]: ... @overload def add_gui_select( @@ -284,8 +282,7 @@ def add_gui_select( options: List[str], initial_value: Optional[str] = None, hint: Optional[str] = None, - ) -> GuiSelectHandle[str]: - ... + ) -> GuiSelectHandle[str]: ... def add_gui_select( self, @@ -325,8 +322,7 @@ def add_gui_button_group( name: str, options: List[TLiteralString], initial_value: Optional[TLiteralString] = None, - ) -> GuiHandle[TLiteralString]: - ... + ) -> GuiHandle[TLiteralString]: ... @overload def add_gui_button_group( @@ -334,8 +330,7 @@ def add_gui_button_group( name: str, options: List[str], initial_value: Optional[str] = None, - ) -> GuiHandle[str]: - ... + ) -> GuiHandle[str]: ... def add_gui_button_group( self, diff --git a/nerfstudio/viewer_legacy/viser/server.py b/nerfstudio/viewer_legacy/viser/server.py index 31b1a78a5b..6984b1c305 100644 --- a/nerfstudio/viewer_legacy/viser/server.py +++ b/nerfstudio/viewer_legacy/viser/server.py @@ -12,15 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-""" Core Viser Server """ - +"""Core Viser Server""" from __future__ import annotations from typing import Callable, Type -import viser.infra from typing_extensions import override +from viser.infra import WebsockServer from .message_api import MessageApi from .messages import GuiUpdateMessage, NerfstudioMessage @@ -44,7 +43,7 @@ def __init__( ): super().__init__() - self._ws_server = viser.infra.Server(host, port, http_server_root=None, verbose=False) + self._ws_server = WebsockServer(host, port, http_server_root=None, verbose=False) self._ws_server.register_handler(GuiUpdateMessage, self._handle_gui_updates) self._ws_server.start() @@ -53,7 +52,7 @@ def _queue(self, message: NerfstudioMessage) -> None: """Implements message enqueue required by MessageApi. Pushes a message onto a broadcast queue.""" - self._ws_server.broadcast(message) + self._ws_server.queue_message(message) def register_handler( self, message_type: Type[NerfstudioMessage], handler: Callable[[NerfstudioMessage], None] diff --git a/pixi.lock b/pixi.lock index 72631ecaa0..d4036678c0 100644 --- a/pixi.lock +++ b/pixi.lock @@ -369,7 +369,7 @@ environments: - pypi: https://files.pythonhosted.org/packages/fd/5b/8f0c4a5bb9fd491c277c21eff7ccae71b47d43c4446c9d0c6cff2fe8c2c4/gitdb-4.0.11-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/e9/bd/cc3a402a6439c15c3d4294333e13042b915bbeab54edc457c723931fed3f/GitPython-3.1.43-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/47/82/5f51b0ac0e670aa6551f351c6c8a479149a36c413dd76db4b98d26dddbea/grpcio-1.63.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl - - pypi: https://files.pythonhosted.org/packages/b4/b2/0c3fe3a11a2e8cdf9216ba92e97172d08f769082181f6f10807517db9295/gsplat-0.1.11-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/53/71/d9bf12b11f608f0ad078fa962a9ab61a2cf28fa9739293a1e842656bc419/gsplat-1.0.0-py3-none-any.whl - pypi: 
https://files.pythonhosted.org/packages/95/04/ff642e65ad6b90db43e668d70ffb6736436c7ce41fcc549f4e9472234127/h11-0.14.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/94/00/94bf8573e7487b7c37f2b613fc381880d48ec2311f2e859b8a5817deb4df/h5py-3.11.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl - pypi: https://files.pythonhosted.org/packages/78/d4/e5d7e4f2174f8a4d63c8897d79eb8fe2503f7ecc03282fee1fa2719c2704/httpcore-1.0.5-py3-none-any.whl @@ -2710,9 +2710,9 @@ packages: requires_python: '>=3.8' - kind: pypi name: gsplat - version: 0.1.11 - url: https://files.pythonhosted.org/packages/b4/b2/0c3fe3a11a2e8cdf9216ba92e97172d08f769082181f6f10807517db9295/gsplat-0.1.11-py3-none-any.whl - sha256: 2d47c5d4c245b46d85b7eeae8ca4a96df3f1e2354d8daf254871516cb251b75c + version: 1.0.0 + url: https://files.pythonhosted.org/packages/53/71/d9bf12b11f608f0ad078fa962a9ab61a2cf28fa9739293a1e842656bc419/gsplat-1.0.0-py3-none-any.whl + sha256: a21eead19150e80a0531dd24e5d717c67892cb381657c8411ec8b318b293a032 requires_dist: - jaxtyping - rich >=12 @@ -6085,7 +6085,7 @@ packages: - xatlas - trimesh >=3.20.2 - timm ==0.6.7 - - gsplat >=0.1.11 + - gsplat ==1.0.0 - pytorch-msssim - pathos - packaging diff --git a/pixi.toml b/pixi.toml index 1706a27e3a..ad2eee6aec 100644 --- a/pixi.toml +++ b/pixi.toml @@ -1,6 +1,5 @@ [project] name = "nerfstudio" -version = "1.0.3" description = "All-in-one repository for state-of-the-art NeRFs" channels = ["nvidia/label/cuda-11.8.0", "nvidia", "conda-forge", "pytorch"] platforms = ["linux-64"] diff --git a/pyproject.toml b/pyproject.toml index 62cfbe0913..499815b3eb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "nerfstudio" -version = "1.1.0" +version = "1.1.3" description = "All-in-one repository for state-of-the-art NeRFs" readme = "README.md" license = { text="Apache 2.0"} @@ -16,7 +16,6 @@ classifiers = [ dependencies = [ "appdirs>=1.4", "av>=9.2.0", - 
"awscli>=1.31.10", "comet_ml>=3.33.8", "cryptography>=38", "tyro>=0.6.6", @@ -34,7 +33,7 @@ dependencies = [ "msgpack_numpy>=0.4.8", "nerfacc==0.5.2", "open3d>=0.16.0", - "opencv-python==4.8.0.76", + "opencv-python==4.10.0.84", "Pillow>=10.3.0", "plotly>=5.7.0", "protobuf<=3.20.3,!=3.20.0", @@ -57,16 +56,17 @@ dependencies = [ "torchvision>=0.14.1", "torchmetrics[image]>=1.0.1", "typing_extensions>=4.4.0", - "viser==0.1.27", + "viser==0.2.3", "nuscenes-devkit>=1.1.1", "wandb>=0.13.3", "xatlas", "trimesh>=3.20.2", "timm==0.6.7", - "gsplat>=0.1.11", + "gsplat==1.0.0", "pytorch-msssim", "pathos", - "packaging" + "packaging", + "fpsample" ] [project.urls] @@ -91,7 +91,7 @@ dev = [ "pytest==7.1.2", "pytest-xdist==2.5.0", "typeguard==2.13.3", - "ruff==0.1.13", + "ruff>=0.4.8", "sshconf==0.2.5", "pycolmap>=0.3.0", # NOTE: pycolmap==0.3.0 is not available on newer python versions "diffusers==0.16.1", @@ -103,6 +103,7 @@ dev = [ "projectaria-tools>=1.3.1; sys_platform != 'win32'", # pin torch to <=2.1 to fix https://github.com/pytorch/pytorch/issues/118736 "torch>=1.13.1,<2.2", + "awscli==1.33.18" ] # Documentation related packages @@ -164,7 +165,7 @@ pythonPlatform = "Linux" [tool.ruff] line-length = 120 respect-gitignore = false -select = [ +lint.select = [ "E", # pycodestyle errors. "F", # Pyflakes rules. "I", # isort formatting. @@ -172,8 +173,9 @@ select = [ "PLE", # Pylint errors. "PLR", # Pylint refactor recommendations. "PLW", # Pylint warnings. + "NPY201" # NumPY 2.0 migration https://numpy.org/devdocs/numpy_2_0_migration_guide.html#ruff-plugin ] -ignore = [ +lint.ignore = [ "E501", # Line too long. "F722", # Forward annotation false positive from jaxtyping. Should be caught by pyright. "F821", # Forward annotation false positive from jaxtyping. Should be caught by pyright. 
diff --git a/tests/cameras/test_cameras.py b/tests/cameras/test_cameras.py index 2c7d0b8403..99f79da3f1 100644 --- a/tests/cameras/test_cameras.py +++ b/tests/cameras/test_cameras.py @@ -1,6 +1,7 @@ """ Test the camera classes. """ + import dataclasses from itertools import product diff --git a/tests/field_components/test_embedding.py b/tests/field_components/test_embedding.py index 11ba7d5a98..569773745a 100644 --- a/tests/field_components/test_embedding.py +++ b/tests/field_components/test_embedding.py @@ -1,6 +1,7 @@ """ Embedding tests """ + from nerfstudio.field_components.embedding import Embedding diff --git a/tests/field_components/test_encodings.py b/tests/field_components/test_encodings.py index a7dbb7f2cb..a241fc52a1 100644 --- a/tests/field_components/test_encodings.py +++ b/tests/field_components/test_encodings.py @@ -1,6 +1,7 @@ """ Encoding Tests """ + import pytest import torch diff --git a/tests/field_components/test_field_outputs.py b/tests/field_components/test_field_outputs.py index 08aa236a24..8a585d86cc 100644 --- a/tests/field_components/test_field_outputs.py +++ b/tests/field_components/test_field_outputs.py @@ -1,6 +1,7 @@ """ Field output tests """ + import pytest import torch from torch import nn diff --git a/tests/field_components/test_fields.py b/tests/field_components/test_fields.py index fd5332776e..95d6239f76 100644 --- a/tests/field_components/test_fields.py +++ b/tests/field_components/test_fields.py @@ -1,6 +1,7 @@ """ Test the fields """ + import torch from nerfstudio.cameras.rays import Frustums, RaySamples diff --git a/tests/field_components/test_mlp.py b/tests/field_components/test_mlp.py index b52734654a..92a69d93ef 100644 --- a/tests/field_components/test_mlp.py +++ b/tests/field_components/test_mlp.py @@ -1,6 +1,7 @@ """ MLP Test """ + import torch from torch import nn diff --git a/tests/field_components/test_temporal_distortions.py b/tests/field_components/test_temporal_distortions.py index 8d5abd1914..24934da0d5 100644 
--- a/tests/field_components/test_temporal_distortions.py +++ b/tests/field_components/test_temporal_distortions.py @@ -1,6 +1,7 @@ """ Test if temporal distortions run properly """ + import torch from nerfstudio.field_components.temporal_distortions import DNeRFDistortion diff --git a/tests/model_components/test_renderers.py b/tests/model_components/test_renderers.py index a1e0983885..79d5e54a2f 100644 --- a/tests/model_components/test_renderers.py +++ b/tests/model_components/test_renderers.py @@ -1,6 +1,7 @@ """ Test renderers """ + import pytest import torch diff --git a/tests/pipelines/test_vanilla_pipeline.py b/tests/pipelines/test_vanilla_pipeline.py index 3bb73e06af..5e0ef367ec 100644 --- a/tests/pipelines/test_vanilla_pipeline.py +++ b/tests/pipelines/test_vanilla_pipeline.py @@ -1,6 +1,7 @@ """ Test pipeline """ + from pathlib import Path import torch @@ -41,7 +42,7 @@ def test_load_state_dict(): """Test pipeline load_state_dict calls model's load_state_dict""" was_called = False - class MockedModel(Model): # + class MockedModel(Model): """Mocked model""" def __init__(self, *args, **kwargs): diff --git a/tests/plugins/test_registry.py b/tests/plugins/test_registry.py index c17d88e150..ae704be4c5 100644 --- a/tests/plugins/test_registry.py +++ b/tests/plugins/test_registry.py @@ -1,6 +1,7 @@ """ Tests for the nerfstudio.plugins.registry module. 
""" + import os import sys from dataclasses import dataclass, field diff --git a/tests/process_data/test_process_images.py b/tests/process_data/test_process_images.py index fd506b1fba..b27346c68a 100644 --- a/tests/process_data/test_process_images.py +++ b/tests/process_data/test_process_images.py @@ -1,6 +1,7 @@ """ Process images test """ + import os from pathlib import Path diff --git a/tests/utils/test_aabb_intersection.py b/tests/utils/test_aabb_intersection.py index 625ae6edb6..c0c9169384 100644 --- a/tests/utils/test_aabb_intersection.py +++ b/tests/utils/test_aabb_intersection.py @@ -1,6 +1,7 @@ """ Test AABB intersection """ + import importlib import math import os diff --git a/tests/utils/test_tensor_dataclass.py b/tests/utils/test_tensor_dataclass.py index 10e21adc05..a884d7c051 100644 --- a/tests/utils/test_tensor_dataclass.py +++ b/tests/utils/test_tensor_dataclass.py @@ -1,6 +1,7 @@ """ Test tensor dataclass """ + from dataclasses import dataclass, field from typing import Generic, Optional, TypeVar diff --git a/tests/utils/test_visualization.py b/tests/utils/test_visualization.py index b3d9f53d1d..fa1b31be63 100644 --- a/tests/utils/test_visualization.py +++ b/tests/utils/test_visualization.py @@ -1,6 +1,7 @@ """ Test colormaps """ + import torch from nerfstudio.utils import colormaps, plotly_utils