From 3cf4eb415f75a22c0345db8fcfd297a045a69600 Mon Sep 17 00:00:00 2001
From: Travis Driver
Date: Tue, 17 Sep 2024 16:11:51 -0400
Subject: [PATCH 1/2] AstroVision working for instant-ngp

---
 .vscode/settings.json                         |   3 +-
 compute_metrics.ipynb                         |  75 +++
 nerfstudio/configs/dataparser_configs.py      |  46 +-
 .../dataparsers/astrovision_dataparser.py     | 524 ++++++++++++++++++
 nerfstudio/data/utils/data_utils.py           |   4 +
 nerfstudio/models/nerfacto.py                 |  35 +-
 nerfstudio/models/splatfacto.py               |  21 +-
 7 files changed, 669 insertions(+), 39 deletions(-)
 create mode 100644 compute_metrics.ipynb
 create mode 100644 nerfstudio/data/dataparsers/astrovision_dataparser.py

diff --git a/.vscode/settings.json b/.vscode/settings.json
index c365ae07b4..ea0491df4e 100644
--- a/.vscode/settings.json
+++ b/.vscode/settings.json
@@ -35,7 +35,7 @@
     "python.envFile": "${workspaceFolder}/.env",
     "python.formatting.provider": "none",
     "python.linting.pylintEnabled": false,
-    "python.linting.flake8Enabled": false,
+    "python.linting.flake8Enabled": true,
     "python.linting.enabled": true,
     "python.testing.unittestEnabled": false,
     "python.testing.pytestEnabled": true,
@@ -123,4 +123,5 @@
     "python.analysis.typeCheckingMode": "basic",
     "python.analysis.diagnosticMode": "workspace",
     "eslint.packageManager": "yarn",
+    "python.linting.mypyEnabled": false,
 }
diff --git a/compute_metrics.ipynb b/compute_metrics.ipynb
new file mode 100644
index 0000000000..b756bfb573
--- /dev/null
+++ b/compute_metrics.ipynb
@@ -0,0 +1,75 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import os\n",
+    "\n",
+    "import numpy as np\n",
+    "import matplotlib.pyplot as plt\n",
+    "from PIL import Image\n",
+    "\n",
+    "\n",
+    "# Compare rendered images against their ground-truth counterparts and report PSNR.\n",
+    "\n",
+    "\n",
+    "RENDER_ROOT = \"/home/tdriver6/Documents/nerfstudio/renders/train/rgb\"\n",
+    "GT_ROOT = \"/home/tdriver6/Documents/nerfstudio/data/nerfstudio/aspct_ahunamons/images\"\n",
+    "\n",
+    "render_paths = os.listdir(RENDER_ROOT)\n",
+    "psnrs = []\n",
+    "for pth in render_paths:\n",
+    "    if not pth.endswith(\".png\"):\n",
+    "        continue\n",
+    "    print(os.path.join(RENDER_ROOT, pth))\n",
+    "    rend = np.asarray(Image.open(os.path.join(RENDER_ROOT, pth)).convert(\"L\"), dtype=np.float64)  # float avoids uint8 wrap-around\n",
+    "    gt = np.asarray(Image.open(os.path.join(GT_ROOT, pth)).convert(\"L\"), dtype=np.float64)\n",
+    "\n",
+    "    plt.imshow((gt - rend) ** 2, cmap=\"jet\")\n",
+    "    plt.colorbar()\n",
+    "    plt.show()\n",
+    "\n",
+    "    plt.imshow(rend, cmap=\"gray\")\n",
+    "    plt.colorbar()\n",
+    "    plt.show()\n",
+    "\n",
+    "\n",
+    "    mse = np.mean((gt - rend) ** 2)\n",
+    "    psnrs.append(10 * np.log10(gt.max() ** 2 / mse))\n",
+    "    print(psnrs[-1])\n",
+    "\n",
+    "print(\"Mean PSNR: %.2f\" % np.mean(psnrs))\n"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3.8.19 ('nerfstudio')",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.8.19"
+  },
+  "orig_nbformat": 4,
+  "vscode": {
+   "interpreter": {
+    "hash": "d1bb290f7ef6eb618af297cbac6e2293d0f49bb9d3bac6de88ea6e7f4e095596"
+   }
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/nerfstudio/configs/dataparser_configs.py b/nerfstudio/configs/dataparser_configs.py
index d9e11b3045..2d3c14b4c8 100644
--- a/nerfstudio/configs/dataparser_configs.py
+++ b/nerfstudio/configs/dataparser_configs.py
@@ -19,27 +19,43 @@ from typing import TYPE_CHECKING
 
 import tyro
-
-from nerfstudio.data.dataparsers.arkitscenes_dataparser import ARKitScenesDataParserConfig
+from nerfstudio.data.dataparsers.arkitscenes_dataparser import \
+    ARKitScenesDataParserConfig
+from nerfstudio.data.dataparsers.astrovision_dataparser import \
+    AstroVisionDataParserConfig
 from nerfstudio.data.dataparsers.base_dataparser import DataParserConfig
-from nerfstudio.data.dataparsers.blender_dataparser import BlenderDataParserConfig
-from nerfstudio.data.dataparsers.colmap_dataparser import ColmapDataParserConfig
+from nerfstudio.data.dataparsers.blender_dataparser import \
+    BlenderDataParserConfig
+from nerfstudio.data.dataparsers.colmap_dataparser import \
+    ColmapDataParserConfig
 from nerfstudio.data.dataparsers.dnerf_dataparser import DNeRFDataParserConfig
-from nerfstudio.data.dataparsers.dycheck_dataparser import DycheckDataParserConfig
-from nerfstudio.data.dataparsers.instant_ngp_dataparser import InstantNGPDataParserConfig
-from nerfstudio.data.dataparsers.minimal_dataparser import MinimalDataParserConfig
-from nerfstudio.data.dataparsers.nerfosr_dataparser import NeRFOSRDataParserConfig
-from nerfstudio.data.dataparsers.nerfstudio_dataparser import NerfstudioDataParserConfig
-from nerfstudio.data.dataparsers.nuscenes_dataparser import NuScenesDataParserConfig
-from nerfstudio.data.dataparsers.phototourism_dataparser import PhototourismDataParserConfig
-from nerfstudio.data.dataparsers.scannet_dataparser import ScanNetDataParserConfig
-from nerfstudio.data.dataparsers.scannetpp_dataparser import ScanNetppDataParserConfig
-from nerfstudio.data.dataparsers.sdfstudio_dataparser import SDFStudioDataParserConfig
-from nerfstudio.data.dataparsers.sitcoms3d_dataparser import Sitcoms3DDataParserConfig
+from nerfstudio.data.dataparsers.dycheck_dataparser import \
+    DycheckDataParserConfig
+from nerfstudio.data.dataparsers.instant_ngp_dataparser import \
+    InstantNGPDataParserConfig
+from nerfstudio.data.dataparsers.minimal_dataparser import \
+    MinimalDataParserConfig
+from nerfstudio.data.dataparsers.nerfosr_dataparser import \
+    NeRFOSRDataParserConfig
+from nerfstudio.data.dataparsers.nerfstudio_dataparser import \
+    NerfstudioDataParserConfig
+from nerfstudio.data.dataparsers.nuscenes_dataparser import \
+    NuScenesDataParserConfig
+from nerfstudio.data.dataparsers.phototourism_dataparser import \
+    PhototourismDataParserConfig
+from nerfstudio.data.dataparsers.scannet_dataparser import \
+    ScanNetDataParserConfig
+from nerfstudio.data.dataparsers.scannetpp_dataparser import \
+    ScanNetppDataParserConfig
+from nerfstudio.data.dataparsers.sdfstudio_dataparser import \
+    SDFStudioDataParserConfig
+from nerfstudio.data.dataparsers.sitcoms3d_dataparser import \
+    Sitcoms3DDataParserConfig
 from nerfstudio.plugins.registry_dataparser import discover_dataparsers
 
 dataparsers = {
     "nerfstudio-data": NerfstudioDataParserConfig(),
+    "astrovision-data": AstroVisionDataParserConfig(),
     "minimal-parser": MinimalDataParserConfig(),
     "arkit-data": ARKitScenesDataParserConfig(),
     "blender-data": BlenderDataParserConfig(),
diff --git a/nerfstudio/data/dataparsers/astrovision_dataparser.py b/nerfstudio/data/dataparsers/astrovision_dataparser.py
new file mode 100644
index 0000000000..d0993764d1
--- /dev/null
+++ b/nerfstudio/data/dataparsers/astrovision_dataparser.py
@@ -0,0 +1,524 @@
+# Copyright 2022 the Regents of the University of California, Nerfstudio Team and contributors. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Data parser for AstroVision datasets."""
+
+from __future__ import annotations
+
+import os
+from dataclasses import dataclass, field
+from pathlib import Path
+from typing import Literal, Optional, Tuple, Type
+
+import numpy as np
+import open3d as o3d
+import torch
+from nerfstudio.cameras import camera_utils
+from nerfstudio.cameras.cameras import (CAMERA_MODEL_TO_TYPE, Cameras,
+                                        CameraType)
+from nerfstudio.data.dataparsers.base_dataparser import (DataParser,
+                                                         DataParserConfig,
+                                                         DataparserOutputs)
+from nerfstudio.data.scene_box import SceneBox
+from nerfstudio.data.utils.dataparsers_utils import (
+    get_train_eval_split_all, get_train_eval_split_filename,
+    get_train_eval_split_fraction, get_train_eval_split_interval)
+from nerfstudio.utils.io import load_from_json
+from nerfstudio.utils.rich_utils import CONSOLE
+from PIL import Image
+
+MAX_AUTO_RESOLUTION = 1600
+
+
+@dataclass
+class AstroVisionDataParserConfig(DataParserConfig):
+    """AstroVision dataset config"""
+
+    _target: Type = field(default_factory=lambda: AstroVision)
+    """target class to instantiate"""
+    data: Path = Path()
+    """Directory or explicit json file path specifying location of data."""
+    scale_factor: float = 0.25
+    """How much to scale the camera origins by."""
+    downscale_factor: Optional[int] = None
+    """How much to downscale images. If not set, images are chosen such that the max dimension is <1600px."""
+    scene_scale: float = 1.0
+    """How much to scale the region of interest by."""
+    orientation_method: Literal["pca", "up", "vertical", "none"] = "vertical"
+    """The method to use for orientation."""
+    center_method: Literal["poses", "points", "focus", "none"] = "points"
+    """The method to use to center the poses."""
+    auto_scale_poses: bool = False
+    """Whether to automatically scale the poses to fit in +/- 1 bounding box."""
+    eval_mode: Literal["fraction", "filename", "interval", "all"] = "fraction"
+    """
+    The method to use for splitting the dataset into train and eval.
+    Fraction splits based on a percentage for train and the remaining for eval.
+    Filename splits based on filenames containing train/eval.
+    Interval uses every nth frame for eval.
+    All uses all the images for any split.
+    """
+    train_split_fraction: float = 0.9
+    """The fraction of the dataset to use for training. Only used when eval_mode is "fraction"."""
+    eval_interval: int = 8
+    """The interval between frames to use for eval. Only used when eval_mode is "interval"."""
+    depth_unit_scale_factor: float = 1e-3
+    """Scales the depth values to meters. Default value is 0.001 for a millimeter to meter conversion."""
+    mask_color: Optional[Tuple[float, float, float]] = None
+    """Replace the unknown pixels with this color. Relevant if you have a mask but still sample everywhere."""
+    load_3D_points: bool = False
+    """Whether to load the 3D points from the colmap reconstruction."""
+
+
+@dataclass
+class AstroVision(DataParser):
+    """AstroVision DatasetParser"""
+
+    config: AstroVisionDataParserConfig
+    downscale_factor: Optional[int] = None
+
+    def _generate_dataparser_outputs(self, split="train"):
+        assert self.config.data.exists(), f"Data directory {self.config.data} does not exist."
+
+        if self.config.data.suffix == ".json":
+            meta = load_from_json(self.config.data)
+            data_dir = self.config.data.parent
+        else:
+            meta = load_from_json(self.config.data / "transforms.json")
+            data_dir = self.config.data
+
+        image_filenames = []
+        mask_filenames = []
+        depth_filenames = []
+        poses = []
+
+        fx_fixed = "fl_x" in meta
+        fy_fixed = "fl_y" in meta
+        cx_fixed = "cx" in meta
+        cy_fixed = "cy" in meta
+        height_fixed = "h" in meta
+        width_fixed = "w" in meta
+        distort_fixed = False
+        for distort_key in ["k1", "k2", "k3", "p1", "p2", "distortion_params"]:
+            if distort_key in meta:
+                distort_fixed = True
+                break
+        fisheye_crop_radius = meta.get("fisheye_crop_radius", None)
+        fx = []
+        fy = []
+        cx = []
+        cy = []
+        height = []
+        width = []
+        distort = []
+
+        # sort the frames by fname
+        fnames = []
+        for frame in meta["frames"]:
+            filepath = Path(frame["file_path"])
+            fname = self._get_fname(filepath, data_dir)
+            fnames.append(fname)
+        inds = np.argsort(fnames)
+        frames = [meta["frames"][ind] for ind in inds]
+
+        for frame in frames:
+            filepath = Path(frame["file_path"])
+            fname = self._get_fname(filepath, data_dir)
+
+            if not fx_fixed:
+                assert "fl_x" in frame, "fx not specified in frame"
+                fx.append(float(frame["fl_x"]))
+            if not fy_fixed:
+                assert "fl_y" in frame, "fy not specified in frame"
+                fy.append(float(frame["fl_y"]))
+            if not cx_fixed:
+                assert "cx" in frame, "cx not specified in frame"
+                cx.append(float(frame["cx"]))
+            if not cy_fixed:
+                assert "cy" in frame, "cy not specified in frame"
+                cy.append(float(frame["cy"]))
+            if not height_fixed:
+                assert "h" in frame, "height not specified in frame"
+                height.append(int(frame["h"]))
+            if not width_fixed:
+                assert "w" in frame, "width not specified in frame"
+                width.append(int(frame["w"]))
+            if not distort_fixed:
+                distort.append(
+                    torch.tensor(frame["distortion_params"], dtype=torch.float32)
+                    if "distortion_params" in frame
+                    else camera_utils.get_distortion_params(
+                        k1=float(frame["k1"]) if "k1" in frame else 0.0,
+                        k2=float(frame["k2"]) if "k2" in frame else 0.0,
+                        k3=float(frame["k3"]) if "k3" in frame else 0.0,
+                        k4=float(frame["k4"]) if "k4" in frame else 0.0,
+                        p1=float(frame["p1"]) if "p1" in frame else 0.0,
+                        p2=float(frame["p2"]) if "p2" in frame else 0.0,
+                    )
+                )
+
+            image_filenames.append(fname)
+            poses.append(np.array(frame["transform_matrix"]))
+            if "mask_path" in frame:
+                mask_filepath = Path(frame["mask_path"])
+                mask_fname = self._get_fname(
+                    mask_filepath,
+                    data_dir,
+                    downsample_folder_prefix="masks_",
+                )
+                mask_filenames.append(mask_fname)
+
+            if "depth_file_path" in frame:
+                depth_filepath = Path(frame["depth_file_path"])
+                depth_fname = self._get_fname(depth_filepath, data_dir, downsample_folder_prefix="depths_")
+                depth_filenames.append(depth_fname)
+
+        assert len(mask_filenames) == 0 or (len(mask_filenames) == len(image_filenames)), """
+        Different number of image and mask filenames.
+        You should check that mask_path is specified for every frame (or zero frames) in transforms.json.
+        """
+        assert len(depth_filenames) == 0 or (len(depth_filenames) == len(image_filenames)), """
+        Different number of image and depth filenames.
+        You should check that depth_file_path is specified for every frame (or zero frames) in transforms.json.
+        """
+
+        has_split_files_spec = any(f"{split}_filenames" in meta for split in ("train", "val", "test"))
+        if f"{split}_filenames" in meta:
+            # Validate split first
+            split_filenames = set(self._get_fname(Path(x), data_dir) for x in meta[f"{split}_filenames"])
+            unmatched_filenames = split_filenames.difference(image_filenames)
+            if unmatched_filenames:
+                raise RuntimeError(f"Some filenames for split {split} were not found: {unmatched_filenames}.")
+
+            indices = [i for i, path in enumerate(image_filenames) if path in split_filenames]
+            CONSOLE.log(f"[yellow] Dataset is overriding {split}_indices to {indices}")
+            indices = np.array(indices, dtype=np.int32)
+        elif has_split_files_spec:
+            raise RuntimeError(f"The dataset's list of filenames for split {split} is missing.")
+        else:
+            # find train and eval indices based on the eval_mode specified
+            if self.config.eval_mode == "fraction":
+                i_train, i_eval = get_train_eval_split_fraction(image_filenames, self.config.train_split_fraction)
+            elif self.config.eval_mode == "filename":
+                i_train, i_eval = get_train_eval_split_filename(image_filenames)
+            elif self.config.eval_mode == "interval":
+                i_train, i_eval = get_train_eval_split_interval(image_filenames, self.config.eval_interval)
+            elif self.config.eval_mode == "all":
+                CONSOLE.log(
+                    "[yellow] Be careful with '--eval-mode=all'. If using camera optimization, the cameras may diverge in the current implementation, giving unpredictable results."
+                )
+                i_train, i_eval = get_train_eval_split_all(image_filenames)
+            else:
+                raise ValueError(f"Unknown eval mode {self.config.eval_mode}")
+
+            if split == "train":
+                indices = i_train
+            elif split in ["val", "test"]:
+                indices = i_eval
+            else:
+                raise ValueError(f"Unknown dataparser split {split}")
+
+        if "orientation_override" in meta:
+            orientation_method = meta["orientation_override"]
+            CONSOLE.log(f"[yellow] Dataset is overriding orientation method to {orientation_method}")
+        else:
+            orientation_method = self.config.orientation_method
+
+        # Center poses.
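+        # For center_method="points", the poses are oriented first without centering;
+        # the poses and the sparse point cloud are then both translated so that the
+        # cloud's centroid sits at the origin, and that translation is folded into
+        # transform_matrix so downstream consumers stay consistent.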
+        poses = torch.from_numpy(np.array(poses).astype(np.float32))
+        if self.config.center_method == "points":
+            poses, transform_matrix = camera_utils.auto_orient_and_center_poses(
+                poses,
+                method=orientation_method,
+                center_method="none",
+            )
+            pcd = np.asarray(o3d.io.read_point_cloud(os.path.join(self.config.data, "sparse_pc.ply")).points)
+            vertices = torch.from_numpy(pcd).type(torch.float32)  # (N, 3)
+            vertices = (transform_matrix[:, :3] @ vertices.T + transform_matrix[:, 3].reshape((3, 1))).T
+            translation = torch.mean(vertices, axis=0)
+            poses[:, :3, 3] -= translation
+            vertices -= translation
+            CONSOLE.log(f"Point cloud center: {translation}")
+            transform_matrix[:, 3] = -translation
+        else:
+            poses, transform_matrix = camera_utils.auto_orient_and_center_poses(
+                poses,
+                method=orientation_method,
+                center_method=self.config.center_method,
+            )
+
+        # Scale poses
+        scale_factor = 1.0
+        if self.config.center_method == "points":
+            min_vertices, _ = vertices.min(axis=0)
+            max_vertices, _ = vertices.max(axis=0)
+            scale_factor *= 2.0 / (torch.max(max_vertices - min_vertices))
+            aabb_scale_x, aabb_scale_y, aabb_scale_z = (max_vertices - min_vertices) * scale_factor / 2
+            CONSOLE.log(f"Point cloud AABB: {min_vertices} to {max_vertices}")
+            CONSOLE.log(f"Scaled AABB half-extents: {aabb_scale_x}, {aabb_scale_y}, {aabb_scale_z}")
+        else:
+            if self.config.auto_scale_poses:
+                scale_factor /= float(torch.max(torch.abs(poses[:, :3, 3])))
+            scale_factor *= self.config.scale_factor
+            CONSOLE.log(f"Pose scale factor: {scale_factor}")
+        poses[:, :3, 3] *= scale_factor
+
+        # Choose image_filenames and poses based on split, but after auto orient and scaling the poses.
+        image_filenames = [image_filenames[i] for i in indices]
+        mask_filenames = [mask_filenames[i] for i in indices] if len(mask_filenames) > 0 else []
+        depth_filenames = [depth_filenames[i] for i in indices] if len(depth_filenames) > 0 else []
+
+        idx_tensor = torch.tensor(indices, dtype=torch.long)
+        poses = poses[idx_tensor]
+
+        # in x,y,z order
+        # assumes that the scene is centered at the origin
+        aabb_scale = self.config.scene_scale
+        scene_box = SceneBox(
+            aabb=torch.tensor(
+                [
+                    [-aabb_scale, -aabb_scale, -aabb_scale],
+                    [aabb_scale, aabb_scale, aabb_scale],
+                ],
+                dtype=torch.float32,
+            )
+        )
+
+        if "camera_model" in meta:
+            camera_type = CAMERA_MODEL_TO_TYPE[meta["camera_model"]]
+        else:
+            camera_type = CameraType.PERSPECTIVE
+
+        fx = float(meta["fl_x"]) if fx_fixed else torch.tensor(fx, dtype=torch.float32)[idx_tensor]
+        fy = float(meta["fl_y"]) if fy_fixed else torch.tensor(fy, dtype=torch.float32)[idx_tensor]
+        cx = float(meta["cx"]) if cx_fixed else torch.tensor(cx, dtype=torch.float32)[idx_tensor]
+        cy = float(meta["cy"]) if cy_fixed else torch.tensor(cy, dtype=torch.float32)[idx_tensor]
+        height = int(meta["h"]) if height_fixed else torch.tensor(height, dtype=torch.int32)[idx_tensor]
+        width = int(meta["w"]) if width_fixed else torch.tensor(width, dtype=torch.int32)[idx_tensor]
+        if distort_fixed:
+            distortion_params = (
+                torch.tensor(meta["distortion_params"], dtype=torch.float32)
+                if "distortion_params" in meta
+                else camera_utils.get_distortion_params(
+                    k1=float(meta["k1"]) if "k1" in meta else 0.0,
+                    k2=float(meta["k2"]) if "k2" in meta else 0.0,
+                    k3=float(meta["k3"]) if "k3" in meta else 0.0,
+                    k4=float(meta["k4"]) if "k4" in meta else 0.0,
+                    p1=float(meta["p1"]) if "p1" in meta else 0.0,
+                    p2=float(meta["p2"]) if "p2" in meta else 0.0,
+                )
+            )
+        else:
+            distortion_params = torch.stack(distort, dim=0)[idx_tensor]
+
+        # Only add fisheye crop radius parameter if the images are actually fisheye, to allow the same config to be used
+        # for both fisheye and non-fisheye datasets.
+        metadata = {}
+        if (camera_type in [CameraType.FISHEYE, CameraType.FISHEYE624]) and (fisheye_crop_radius is not None):
+            metadata["fisheye_crop_radius"] = fisheye_crop_radius
+
+        cameras = Cameras(
+            fx=fx,
+            fy=fy,
+            cx=cx,
+            cy=cy,
+            distortion_params=distortion_params,
+            height=height,
+            width=width,
+            camera_to_worlds=poses[:, :3, :4],
+            camera_type=camera_type,
+            metadata=metadata,
+        )
+
+        assert self.downscale_factor is not None
+        cameras.rescale_output_resolution(scaling_factor=1.0 / self.downscale_factor)
+
+        # The naming is somewhat confusing, but:
+        # - transform_matrix contains the transformation to dataparser output coordinates from saved coordinates.
+        # - dataparser_transform_matrix contains the transformation to dataparser output coordinates from original data coordinates.
+        # - applied_transform contains the transformation to saved coordinates from original data coordinates.
+        applied_transform = None
+        colmap_path = self.config.data / "colmap/sparse/0"
+        if "applied_transform" in meta:
+            applied_transform = torch.tensor(meta["applied_transform"], dtype=transform_matrix.dtype)
+        elif colmap_path.exists():
+            # For converting from colmap, this was the effective value of applied_transform that was being
+            # used before we added the applied_transform field to the output dataformat.
+            meta["applied_transform"] = [[0, 1, 0, 0], [1, 0, 0, 0], [0, 0, -1, 0]]
+            applied_transform = torch.tensor(meta["applied_transform"], dtype=transform_matrix.dtype)
+
+        if applied_transform is not None:
+            dataparser_transform_matrix = transform_matrix @ torch.cat(
+                [applied_transform, torch.tensor([[0, 0, 0, 1]], dtype=transform_matrix.dtype)], 0
+            )
+        else:
+            dataparser_transform_matrix = transform_matrix
+        CONSOLE.log(f"dataparser_transform_matrix = {dataparser_transform_matrix}")
+
+        if "applied_scale" in meta:
+            applied_scale = float(meta["applied_scale"])
+            scale_factor *= applied_scale
+
+        # reinitialize metadata for dataparser_outputs
+        metadata = {}
+
+        # _generate_dataparser_outputs might be called more than once so we check if we already loaded the point cloud
+        try:
+            self.prompted_user
+        except AttributeError:
+            self.prompted_user = False
+
+        # Load 3D points
+        if self.config.load_3D_points:
+            if "ply_file_path" in meta:
+                ply_file_path = data_dir / meta["ply_file_path"]
+
+            elif colmap_path.exists():
+                from rich.prompt import Confirm
+
+                # check if user wants to make a point cloud from colmap points
+                if not self.prompted_user:
+                    self.create_pc = Confirm.ask(
+                        "load_3D_points is true, but the dataset was processed with an outdated ns-process-data that didn't convert colmap points to .ply! Update the colmap dataset automatically?"
+                    )
+
+                if self.create_pc:
+                    import json
+
+                    from nerfstudio.process_data.colmap_utils import create_ply_from_colmap
+
+                    with open(self.config.data / "transforms.json") as f:
+                        transforms = json.load(f)
+
+                    # Update dataset if missing the applied_transform field.
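+                    # Older ns-process-data exports did not write the applied_transform
+                    # field into transforms.json, so it is backfilled here before the
+                    # file is rewritten below with the new ply_file_path entry.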
+                    if "applied_transform" not in transforms:
+                        transforms["applied_transform"] = meta["applied_transform"]
+
+                    ply_filename = "sparse_pc.ply"
+                    create_ply_from_colmap(
+                        filename=ply_filename,
+                        recon_dir=colmap_path,
+                        output_dir=self.config.data,
+                        applied_transform=applied_transform,
+                    )
+                    ply_file_path = data_dir / ply_filename
+                    transforms["ply_file_path"] = ply_filename
+
+                    # Write the updated transforms.json back to disk.
+                    with open(self.config.data / "transforms.json", "w", encoding="utf-8") as f:
+                        json.dump(transforms, f, indent=4)
+                else:
+                    ply_file_path = None
+            else:
+                if not self.prompted_user:
+                    CONSOLE.print(
+                        "[bold yellow]Warning: load_3D_points set to true but no point cloud found. splatfacto will use random point cloud initialization."
+                    )
+                ply_file_path = None
+
+            if ply_file_path:
+                sparse_points = self._load_3D_points(ply_file_path, transform_matrix, scale_factor)
+                if sparse_points is not None:
+                    metadata.update(sparse_points)
+            self.prompted_user = True
+
+        dataparser_outputs = DataparserOutputs(
+            image_filenames=image_filenames,
+            cameras=cameras,
+            scene_box=scene_box,
+            mask_filenames=mask_filenames if len(mask_filenames) > 0 else None,
+            dataparser_scale=scale_factor,
+            dataparser_transform=dataparser_transform_matrix,
+            metadata={
+                "depth_filenames": depth_filenames if len(depth_filenames) > 0 else None,
+                "depth_unit_scale_factor": self.config.depth_unit_scale_factor,
+                "mask_color": self.config.mask_color,
+                **metadata,
+            },
+        )
+        return dataparser_outputs
+
+    def _load_3D_points(self, ply_file_path: Path, transform_matrix: torch.Tensor, scale_factor: float):
+        """Loads point clouds positions and colors from .ply
+
+        Args:
+            ply_file_path: Path to .ply file
+            transform_matrix: Matrix to transform world coordinates
+            scale_factor: How much to scale the camera origins by.
+
+        Returns:
+            A dictionary of points: points3D_xyz and colors: points3D_rgb
+        """
+        pcd = o3d.io.read_point_cloud(str(ply_file_path))
+
+        # if no points found don't read in an initial point cloud
+        if len(pcd.points) == 0:
+            return None
+
+        points3D = torch.from_numpy(np.asarray(pcd.points, dtype=np.float32))
+        points3D = (
+            torch.cat(
+                (
+                    points3D,
+                    torch.ones_like(points3D[..., :1]),
+                ),
+                -1,
+            )
+            @ transform_matrix.T
+        )
+        points3D *= scale_factor
+        points3D_rgb = torch.from_numpy((np.asarray(pcd.colors) * 255).astype(np.uint8))
+
+        out = {
+            "points3D_xyz": points3D,
+            "points3D_rgb": points3D_rgb,
+        }
+        return out
+
+    def _get_fname(self, filepath: Path, data_dir: Path, downsample_folder_prefix="images_") -> Path:
+        """Get the filename of the image file.
+        downsample_folder_prefix can be used to point to auxiliary image data, e.g. masks
+
+        filepath: the base file name given in the transforms file
+        data_dir: the directory of the data that contains the transform file
+        downsample_folder_prefix: prefix of the newly generated downsampled images
+        """
+
+        if self.downscale_factor is None:
+            if self.config.downscale_factor is None:
+                test_img = Image.open(data_dir / filepath)
+                w, h = test_img.size  # PIL's Image.size is (width, height)
+                max_res = max(w, h)
+                df = 0
+                while True:
+                    if (max_res / 2 ** (df)) <= MAX_AUTO_RESOLUTION:
+                        break
+                    if not (data_dir / f"{downsample_folder_prefix}{2**(df+1)}" / filepath.name).exists():
+                        break
+                    df += 1
+
+                self.downscale_factor = 2**df
+                CONSOLE.log(f"Auto image downscale factor of {self.downscale_factor}")
+            else:
+                self.downscale_factor = self.config.downscale_factor
+
+        if self.downscale_factor > 1:
+            return data_dir / f"{downsample_folder_prefix}{self.downscale_factor}" / filepath.name
+        return data_dir / filepath
diff --git a/nerfstudio/data/utils/data_utils.py b/nerfstudio/data/utils/data_utils.py
index 11ce74d9da..16438ff5b9 100644
--- a/nerfstudio/data/utils/data_utils.py
+++ b/nerfstudio/data/utils/data_utils.py
@@ -33,6 +33,10 @@ def get_image_mask_tensor_from_path(filepath: Path, scale_factor: float = 1.0) -> torch.Tensor:
         newsize = (int(width * scale_factor), int(height * scale_factor))
         pil_mask = pil_mask.resize(newsize, resample=Image.Resampling.NEAREST)
     mask_tensor = torch.from_numpy(np.array(pil_mask)).unsqueeze(-1).bool()
+    # An all-False mask leaves no valid pixels to sample, so mark a single pixel
+    # as valid to keep downstream pixel samplers from failing.
+    if torch.sum(mask_tensor) == 0:
+        mask_tensor[0, 0, 0] = True
     if len(mask_tensor.shape) != 3:
         raise ValueError("The mask image should have 1 channel")
     return mask_tensor
diff --git a/nerfstudio/models/nerfacto.py b/nerfstudio/models/nerfacto.py
index bfccfd8797..5408fd4fe8 100644
--- a/nerfstudio/models/nerfacto.py
+++ b/nerfstudio/models/nerfacto.py
@@ -23,29 +23,31 @@
 import numpy as np
 import torch
-from torch.nn import Parameter
-
-from nerfstudio.cameras.camera_optimizers import CameraOptimizer, CameraOptimizerConfig
+from nerfstudio.cameras.camera_optimizers import (CameraOptimizer,
+                                                  CameraOptimizerConfig)
 from nerfstudio.cameras.rays import RayBundle, RaySamples
-from nerfstudio.engine.callbacks import TrainingCallback, TrainingCallbackAttributes, TrainingCallbackLocation
+from nerfstudio.engine.callbacks import (TrainingCallback,
+                                         TrainingCallbackAttributes,
+                                         TrainingCallbackLocation)
 from nerfstudio.field_components.field_heads import FieldHeadNames
 from nerfstudio.field_components.spatial_distortions import SceneContraction
 from nerfstudio.fields.density_fields import HashMLPDensityField
 from nerfstudio.fields.nerfacto_field import NerfactoField
 from nerfstudio.model_components.losses import (
-    MSELoss,
-    distortion_loss,
-    interlevel_loss,
-    orientation_loss,
-    pred_normal_loss,
-    scale_gradients_by_distance_squared,
-)
-from nerfstudio.model_components.ray_samplers import ProposalNetworkSampler, UniformSampler
-from nerfstudio.model_components.renderers import AccumulationRenderer, DepthRenderer, NormalsRenderer, RGBRenderer
-from nerfstudio.model_components.scene_colliders import NearFarCollider
+    MSELoss, distortion_loss, interlevel_loss, orientation_loss,
+    pred_normal_loss, scale_gradients_by_distance_squared)
+from nerfstudio.model_components.ray_samplers import (ProposalNetworkSampler,
+                                                      UniformSampler)
+from nerfstudio.model_components.renderers import (AccumulationRenderer,
+                                                   DepthRenderer,
+                                                   NormalsRenderer,
+                                                   RGBRenderer)
+from nerfstudio.model_components.scene_colliders import (AABBBoxCollider,
+                                                         NearFarCollider)
 from nerfstudio.model_components.shaders import NormalsShader
 from nerfstudio.models.base_model import Model, ModelConfig
 from nerfstudio.utils import colormaps
+from torch.nn import Parameter
 
 
 @dataclass
@@ -228,6 +230,8 @@ def update_schedule(step):
 
         # Collider
         self.collider = NearFarCollider(near_plane=self.config.near_plane, far_plane=self.config.far_plane)
+        # Alternative collider for bounded scenes, clipping rays to the scene box:
+        # self.collider = AABBBoxCollider(scene_box=self.scene_box)
 
         # renderers
         self.renderer_rgb = RGBRenderer(background_color=self.config.background_color)
@@ -245,7 +249,8 @@ def update_schedule(step):
         # metrics
         from torchmetrics.functional import structural_similarity_index_measure
         from torchmetrics.image import PeakSignalNoiseRatio
-        from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity
+        from torchmetrics.image.lpip import \
+            LearnedPerceptualImagePatchSimilarity
 
         self.psnr = PeakSignalNoiseRatio(data_range=1.0)
         self.ssim = structural_similarity_index_measure
diff --git a/nerfstudio/models/splatfacto.py b/nerfstudio/models/splatfacto.py
index 61d9eda19f..43e4e388f5 100644
--- a/nerfstudio/models/splatfacto.py
+++ b/nerfstudio/models/splatfacto.py
@@ -31,19 +31,23 @@
 except ImportError:
     print("Please install gsplat>=1.0.0")
 from gsplat.cuda_legacy._wrapper import num_sh_bases
-from pytorch_msssim import SSIM
-from torch.nn import Parameter
-
-from nerfstudio.cameras.camera_optimizers import CameraOptimizer, CameraOptimizerConfig
+from nerfstudio.cameras.camera_optimizers import (CameraOptimizer,
+                                                  CameraOptimizerConfig)
 from nerfstudio.cameras.cameras import Cameras
 from nerfstudio.data.scene_box import OrientedBox
-from nerfstudio.engine.callbacks import TrainingCallback, TrainingCallbackAttributes, TrainingCallbackLocation
+from nerfstudio.engine.callbacks import (TrainingCallback,
+                                         TrainingCallbackAttributes,
+                                         TrainingCallbackLocation)
 from nerfstudio.engine.optimizers import Optimizers
-from nerfstudio.model_components.lib_bilagrid import BilateralGrid, color_correct, slice, total_variation_loss
+from nerfstudio.model_components.lib_bilagrid import (BilateralGrid,
+                                                      color_correct, slice,
+                                                      total_variation_loss)
 from nerfstudio.models.base_model import Model, ModelConfig
 from nerfstudio.utils.colors import get_color
 from nerfstudio.utils.misc import torch_compile
 from nerfstudio.utils.rich_utils import CONSOLE
+from pytorch_msssim import SSIM
+from torch.nn import Parameter
 
 
 def quat_to_rotmat(quat):
@@ -116,7 +120,7 @@ def resize_image(image: torch.Tensor, d: int):
     return tf.conv2d(image.permute(2, 0, 1)[:, None, ...], weight, stride=d).squeeze(1).permute(1, 2, 0)
 
 
-@torch_compile()
+# @torch_compile()
 def get_viewmat(optimized_camera_to_world):
     """
     function that converts c2w to gsplat world2camera matrix, using compile for some speed
@@ -283,7 +287,8 @@ def populate_modules(self):
 
         # metrics
         from torchmetrics.image import PeakSignalNoiseRatio
-        from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity
+        from torchmetrics.image.lpip import \
+            LearnedPerceptualImagePatchSimilarity
 
         self.psnr = PeakSignalNoiseRatio(data_range=1.0)
         self.ssim = SSIM(data_range=1.0, size_average=True, channel=3)

From 5d62aa0234dc6f1ed9ea8a3ea1902fe972ebb9fe Mon Sep 17 00:00:00 2001
From: Travis Driver
Date: Tue, 17 Sep 2024 16:31:37 -0400
Subject: [PATCH 2/2] Cleanup

---
 .vscode/settings.json                    |  3 +-
 nerfstudio/configs/dataparser_configs.py | 45 ++++++++----------------
 2 files changed, 16 insertions(+), 32 deletions(-)

diff --git a/.vscode/settings.json b/.vscode/settings.json
index ea0491df4e..c365ae07b4 100644
--- a/.vscode/settings.json
+++ b/.vscode/settings.json
@@ -35,7 +35,7 @@
     "python.envFile": "${workspaceFolder}/.env",
     "python.formatting.provider": "none",
     "python.linting.pylintEnabled": false,
-    "python.linting.flake8Enabled": true,
+    "python.linting.flake8Enabled": false,
     "python.linting.enabled": true,
     "python.testing.unittestEnabled": false,
     "python.testing.pytestEnabled": true,
@@ -123,5 +123,4 @@
     "python.analysis.typeCheckingMode": "basic",
     "python.analysis.diagnosticMode": "workspace",
     "eslint.packageManager": "yarn",
-    "python.linting.mypyEnabled": false,
 }
diff --git a/nerfstudio/configs/dataparser_configs.py b/nerfstudio/configs/dataparser_configs.py
index 2d3c14b4c8..8e26c2b507 100644
--- a/nerfstudio/configs/dataparser_configs.py
+++ b/nerfstudio/configs/dataparser_configs.py
@@ -19,38 +19,23 @@ from typing import TYPE_CHECKING
 
 import tyro
-from nerfstudio.data.dataparsers.arkitscenes_dataparser import \
-    ARKitScenesDataParserConfig
-from nerfstudio.data.dataparsers.astrovision_dataparser import \
-    AstroVisionDataParserConfig
+from nerfstudio.data.dataparsers.arkitscenes_dataparser import ARKitScenesDataParserConfig
+from nerfstudio.data.dataparsers.astrovision_dataparser import AstroVisionDataParserConfig
 from nerfstudio.data.dataparsers.base_dataparser import DataParserConfig
-from nerfstudio.data.dataparsers.blender_dataparser import \
-    BlenderDataParserConfig
-from nerfstudio.data.dataparsers.colmap_dataparser import \
-    ColmapDataParserConfig
+from nerfstudio.data.dataparsers.blender_dataparser import BlenderDataParserConfig
+from nerfstudio.data.dataparsers.colmap_dataparser import ColmapDataParserConfig
 from nerfstudio.data.dataparsers.dnerf_dataparser import DNeRFDataParserConfig
-from nerfstudio.data.dataparsers.dycheck_dataparser import \
-    DycheckDataParserConfig
-from nerfstudio.data.dataparsers.instant_ngp_dataparser import \
-    InstantNGPDataParserConfig
-from nerfstudio.data.dataparsers.minimal_dataparser import \
-    MinimalDataParserConfig
-from nerfstudio.data.dataparsers.nerfosr_dataparser import \
-    NeRFOSRDataParserConfig
-from nerfstudio.data.dataparsers.nerfstudio_dataparser import \
-    NerfstudioDataParserConfig
-from nerfstudio.data.dataparsers.nuscenes_dataparser import \
-    NuScenesDataParserConfig
-from nerfstudio.data.dataparsers.phototourism_dataparser import \
-    PhototourismDataParserConfig
-from nerfstudio.data.dataparsers.scannet_dataparser import \
-    ScanNetDataParserConfig
-from nerfstudio.data.dataparsers.scannetpp_dataparser import \
-    ScanNetppDataParserConfig
-from nerfstudio.data.dataparsers.sdfstudio_dataparser import \
-    SDFStudioDataParserConfig
-from nerfstudio.data.dataparsers.sitcoms3d_dataparser import \
-    Sitcoms3DDataParserConfig
+from nerfstudio.data.dataparsers.dycheck_dataparser import DycheckDataParserConfig
+from nerfstudio.data.dataparsers.instant_ngp_dataparser import InstantNGPDataParserConfig
+from nerfstudio.data.dataparsers.minimal_dataparser import MinimalDataParserConfig
+from nerfstudio.data.dataparsers.nerfosr_dataparser import NeRFOSRDataParserConfig
+from nerfstudio.data.dataparsers.nerfstudio_dataparser import NerfstudioDataParserConfig
+from nerfstudio.data.dataparsers.nuscenes_dataparser import NuScenesDataParserConfig
+from nerfstudio.data.dataparsers.phototourism_dataparser import PhototourismDataParserConfig
+from nerfstudio.data.dataparsers.scannet_dataparser import ScanNetDataParserConfig
+from nerfstudio.data.dataparsers.scannetpp_dataparser import ScanNetppDataParserConfig
+from nerfstudio.data.dataparsers.sdfstudio_dataparser import SDFStudioDataParserConfig
+from nerfstudio.data.dataparsers.sitcoms3d_dataparser import Sitcoms3DDataParserConfig
 from nerfstudio.plugins.registry_dataparser import discover_dataparsers
 
 dataparsers = {