diff --git a/nerfstudio/field_components/encodings.py b/nerfstudio/field_components/encodings.py index 8d63cd9279..b5f8bf4f0e 100644 --- a/nerfstudio/field_components/encodings.py +++ b/nerfstudio/field_components/encodings.py @@ -28,8 +28,9 @@ from nerfstudio.field_components.base_field_component import FieldComponent from nerfstudio.utils.external import TCNN_EXISTS, tcnn -from nerfstudio.utils.math import components_from_spherical_harmonics, expected_sin, generate_polyhedron_basis +from nerfstudio.utils.math import expected_sin, generate_polyhedron_basis from nerfstudio.utils.printing import print_tcnn_speed_warning +from nerfstudio.utils.spherical_harmonics import MAX_SH_DEGREE, components_from_spherical_harmonics class Encoding(FieldComponent): @@ -756,14 +757,16 @@ class SHEncoding(Encoding): """Spherical harmonic encoding Args: - levels: Number of spherical harmonic levels to encode. + levels: Number of spherical harmonic levels to encode. (level = sh degree + 1) """ def __init__(self, levels: int = 4, implementation: Literal["tcnn", "torch"] = "torch") -> None: super().__init__(in_dim=3) - if levels <= 0 or levels > 4: - raise ValueError(f"Spherical harmonic encoding only supports 1 to 4 levels, requested {levels}") + if levels <= 0 or levels > MAX_SH_DEGREE + 1: + raise ValueError( + f"Spherical harmonic encoding only supports 1 to {MAX_SH_DEGREE + 1} levels, requested {levels}" + ) self.levels = levels @@ -778,7 +781,7 @@ def __init__(self, levels: int = 4, implementation: Literal["tcnn", "torch"] = " ) @classmethod - def get_tcnn_encoding_config(cls, levels) -> dict: + def get_tcnn_encoding_config(cls, levels: int) -> dict: """Get the encoding configuration for tcnn if implemented""" encoding_config = { "otype": "SphericalHarmonics", @@ -792,7 +795,7 @@ def get_out_dim(self) -> int: @torch.no_grad() def pytorch_fwd(self, in_tensor: Float[Tensor, "*bs input_dim"]) -> Float[Tensor, "*bs output_dim"]: """Forward pass using pytorch. Significantly slower than TCNN implementation.""" - return components_from_spherical_harmonics(levels=self.levels, directions=in_tensor) + return components_from_spherical_harmonics(degree=self.levels - 1, directions=in_tensor) def forward(self, in_tensor: Float[Tensor, "*bs input_dim"]) -> Float[Tensor, "*bs output_dim"]: if self.tcnn_encoding is not None: diff --git a/nerfstudio/model_components/renderers.py b/nerfstudio/model_components/renderers.py index 1f3ee17499..99c14ca7d0 100644 --- a/nerfstudio/model_components/renderers.py +++ b/nerfstudio/model_components/renderers.py @@ -38,7 +38,8 @@ from nerfstudio.cameras.rays import RaySamples from nerfstudio.utils import colors -from nerfstudio.utils.math import components_from_spherical_harmonics, safe_normalize +from nerfstudio.utils.math import safe_normalize +from nerfstudio.utils.spherical_harmonics import components_from_spherical_harmonics BackgroundColor = Union[Literal["random", "last_sample", "black", "white"], Float[Tensor, "3"], Float[Tensor, "*bs 3"]] BACKGROUND_COLOR_OVERRIDE: Optional[Float[Tensor, "3"]] = None @@ -268,7 +269,7 @@ def forward( sh = sh.view(*sh.shape[:-1], 3, sh.shape[-1] // 3) levels = int(math.sqrt(sh.shape[-1])) - components = components_from_spherical_harmonics(levels=levels, directions=directions) + components = components_from_spherical_harmonics(degree=levels - 1, directions=directions) rgb = sh * components[..., None, :] # [..., num_samples, 3, sh_components] rgb = torch.sum(rgb, dim=-1) # [..., num_samples, 3] diff --git a/nerfstudio/models/splatfacto.py b/nerfstudio/models/splatfacto.py index 28b8f0a1de..a87c5518a2 100644 --- a/nerfstudio/models/splatfacto.py +++ b/nerfstudio/models/splatfacto.py @@ -19,11 +19,9 @@ from __future__ import annotations -import math from dataclasses import dataclass, field from typing import Dict, List, Literal, Optional, Tuple, Type, Union -import numpy as np import torch from gsplat.strategy import DefaultStrategy @@ -42,70 +40,10 @@ from nerfstudio.model_components.lib_bilagrid import BilateralGrid, color_correct, slice, total_variation_loss from nerfstudio.models.base_model import Model, ModelConfig from nerfstudio.utils.colors import get_color +from nerfstudio.utils.math import k_nearest_sklearn, random_quat_tensor from nerfstudio.utils.misc import torch_compile from nerfstudio.utils.rich_utils import CONSOLE - - -def num_sh_bases(degree: int) -> int: - """ - Returns the number of spherical harmonic bases for a given degree. - """ - assert degree <= 4, "We don't support degree greater than 4." - return (degree + 1) ** 2 - - -def quat_to_rotmat(quat): - assert quat.shape[-1] == 4, quat.shape - w, x, y, z = torch.unbind(quat, dim=-1) - mat = torch.stack( - [ - 1 - 2 * (y**2 + z**2), - 2 * (x * y - w * z), - 2 * (x * z + w * y), - 2 * (x * y + w * z), - 1 - 2 * (x**2 + z**2), - 2 * (y * z - w * x), - 2 * (x * z - w * y), - 2 * (y * z + w * x), - 1 - 2 * (x**2 + y**2), - ], - dim=-1, - ) - return mat.reshape(quat.shape[:-1] + (3, 3)) - - -def random_quat_tensor(N): - """ - Defines a random quaternion tensor of shape (N, 4) - """ - u = torch.rand(N) - v = torch.rand(N) - w = torch.rand(N) - return torch.stack( - [ - torch.sqrt(1 - u) * torch.sin(2 * math.pi * v), - torch.sqrt(1 - u) * torch.cos(2 * math.pi * v), - torch.sqrt(u) * torch.sin(2 * math.pi * w), - torch.sqrt(u) * torch.cos(2 * math.pi * w), - ], - dim=-1, - ) - - -def RGB2SH(rgb): - """ - Converts from RGB values [0,1] to the 0th spherical harmonic coefficient - """ - C0 = 0.28209479177387814 - return (rgb - 0.5) / C0 - - -def SH2RGB(sh): - """ - Converts from the 0th spherical harmonic coefficient to RGB values [0,1] - """ - C0 = 0.28209479177387814 - return sh * C0 + 0.5 +from nerfstudio.utils.spherical_harmonics import RGB2SH, SH2RGB, num_sh_bases def resize_image(image: torch.Tensor, d: int): @@ -243,8 +181,7 @@ def populate_modules(self): means = torch.nn.Parameter(self.seed_points[0]) # (Location, Color) else: means = torch.nn.Parameter((torch.rand((self.config.num_random, 3)) - 0.5) * self.config.random_scale) - distances, _ = self.k_nearest_sklearn(means.data, 3) - distances = torch.from_numpy(distances) + distances, _ = k_nearest_sklearn(means.data, 3) # find the average of the three nearest neighbors for each point and use that as the scale avg_dist = distances.mean(dim=-1, keepdim=True) scales = torch.nn.Parameter(torch.log(avg_dist.repeat(1, 3))) @@ -392,26 +329,6 @@ def load_state_dict(self, dict, **kwargs): # type: ignore self.gauss_params[name] = torch.nn.Parameter(torch.zeros(new_shape, device=self.device)) super().load_state_dict(dict, **kwargs) - def k_nearest_sklearn(self, x: torch.Tensor, k: int): - """ - Find k-nearest neighbors using sklearn's NearestNeighbors. - x: The data tensor of shape [num_samples, num_features] - k: The number of neighbors to retrieve - """ - # Convert tensor to numpy array - x_np = x.cpu().numpy() - - # Build the nearest neighbors model - from sklearn.neighbors import NearestNeighbors - - nn_model = NearestNeighbors(n_neighbors=k + 1, algorithm="auto", metric="euclidean").fit(x_np) - - # Find the k-nearest neighbors - distances, indices = nn_model.kneighbors(x_np) - - # Exclude the point itself from the result and return - return distances[:, 1:].astype(np.float32), indices[:, 1:].astype(np.float32) - def set_crop(self, crop_box: Optional[OrientedBox]): self.crop_box = crop_box diff --git a/nerfstudio/utils/math.py b/nerfstudio/utils/math.py index d71907bee3..7a80b9acc2 100644 --- a/nerfstudio/utils/math.py +++ b/nerfstudio/utils/math.py @@ -20,78 +20,12 @@ from typing import Literal, Tuple import torch -from jaxtyping import Bool, Float +from jaxtyping import Bool, Float, Int from torch import Tensor from nerfstudio.data.scene_box import OrientedBox -def components_from_spherical_harmonics( - levels: int, directions: Float[Tensor, "*batch 3"] -) -> Float[Tensor, "*batch components"]: - """ - Returns value for each component of spherical harmonics. - - Args: - levels: Number of spherical harmonic levels to compute. - directions: Spherical harmonic coefficients - """ - num_components = levels**2 - components = torch.zeros((*directions.shape[:-1], num_components), device=directions.device) - - assert 1 <= levels <= 5, f"SH levels must be in [1,4], got {levels}" - assert directions.shape[-1] == 3, f"Direction input should have three dimensions. Got {directions.shape[-1]}" - - x = directions[..., 0] - y = directions[..., 1] - z = directions[..., 2] - - xx = x**2 - yy = y**2 - zz = z**2 - - # l0 - components[..., 0] = 0.28209479177387814 - - # l1 - if levels > 1: - components[..., 1] = 0.4886025119029199 * y - components[..., 2] = 0.4886025119029199 * z - components[..., 3] = 0.4886025119029199 * x - - # l2 - if levels > 2: - components[..., 4] = 1.0925484305920792 * x * y - components[..., 5] = 1.0925484305920792 * y * z - components[..., 6] = 0.9461746957575601 * zz - 0.31539156525251999 - components[..., 7] = 1.0925484305920792 * x * z - components[..., 8] = 0.5462742152960396 * (xx - yy) - - # l3 - if levels > 3: - components[..., 9] = 0.5900435899266435 * y * (3 * xx - yy) - components[..., 10] = 2.890611442640554 * x * y * z - components[..., 11] = 0.4570457994644658 * y * (5 * zz - 1) - components[..., 12] = 0.3731763325901154 * z * (5 * zz - 3) - components[..., 13] = 0.4570457994644658 * x * (5 * zz - 1) - components[..., 14] = 1.445305721320277 * z * (xx - yy) - components[..., 15] = 0.5900435899266435 * x * (xx - 3 * yy) - - # l4 - if levels > 4: - components[..., 16] = 2.5033429417967046 * x * y * (xx - yy) - components[..., 17] = 1.7701307697799304 * y * z * (3 * xx - yy) - components[..., 18] = 0.9461746957575601 * x * y * (7 * zz - 1) - components[..., 19] = 0.6690465435572892 * y * z * (7 * zz - 3) - components[..., 20] = 0.10578554691520431 * (35 * zz * zz - 30 * zz + 3) - components[..., 21] = 0.6690465435572892 * x * z * (7 * zz - 3) - components[..., 22] = 0.47308734787878004 * (xx - yy) * (7 * zz - 1) - components[..., 23] = 1.7701307697799304 * x * z * (xx - 3 * yy) - components[..., 24] = 0.6258357354491761 * (xx * (xx - 3 * yy) - yy * (3 * xx - yy)) - - return components - - @dataclass class Gaussians: """Stores Gaussians @@ -323,7 +257,9 @@ def masked_reduction( def normalized_depth_scale_and_shift( - prediction: Float[Tensor, "1 32 mult"], target: Float[Tensor, "1 32 mult"], mask: Bool[Tensor, "1 32 mult"] + prediction: Float[Tensor, "1 32 mult"], + target: Float[Tensor, "1 32 mult"], + mask: Bool[Tensor, "1 32 mult"], ): """ More info here: https://arxiv.org/pdf/2206.00665.pdf supplementary section A2 Depth Consistency Loss @@ -405,7 +341,10 @@ def _compute_tesselation_weights(v: int) -> Tensor: def _tesselate_geodesic( - vertices: Float[Tensor, "N 3"], faces: Float[Tensor, "M 3"], v: int, eps: float = 1e-4 + vertices: Float[Tensor, "N 3"], + faces: Float[Tensor, "M 3"], + v: int, + eps: float = 1e-4, ) -> Tensor: """Tesselate the vertices of a geodesic polyhedron. @@ -518,3 +457,58 @@ def generate_polyhedron_basis( basis = verts.flip(-1) return basis + + +def random_quat_tensor(N: int) -> Float[Tensor, "*batch 4"]: + """ + Defines a random quaternion tensor. + + Args: + N: Number of quaternions to generate + + Returns: + a random quaternion tensor of shape (N, 4) + + """ + u = torch.rand(N) + v = torch.rand(N) + w = torch.rand(N) + return torch.stack( + [ + torch.sqrt(1 - u) * torch.sin(2 * math.pi * v), + torch.sqrt(1 - u) * torch.cos(2 * math.pi * v), + torch.sqrt(u) * torch.sin(2 * math.pi * w), + torch.sqrt(u) * torch.cos(2 * math.pi * w), + ], + dim=-1, + ) + + +def k_nearest_sklearn( + x: torch.Tensor, k: int, metric: str = "euclidean" +) -> Tuple[Float[Tensor, "*batch k"], Int[Tensor, "*batch k"]]: + """ + Find k-nearest neighbors using sklearn's NearestNeighbors. + + Args: + x: input tensor + k: number of neighbors to find + metric: metric to use for distance computation + + Returns: + distances: distances to the k-nearest neighbors + indices: indices of the k-nearest neighbors + """ + # Convert tensor to numpy array + x_np = x.cpu().numpy() + + # Build the nearest neighbors model + from sklearn.neighbors import NearestNeighbors + + nn_model = NearestNeighbors(n_neighbors=k + 1, algorithm="auto", metric=metric).fit(x_np) + + # Find the k-nearest neighbors + distances, indices = nn_model.kneighbors(x_np) + + # Exclude the point itself from the result and return + return torch.tensor(distances[:, 1:], dtype=torch.float32), torch.tensor(indices[:, 1:], dtype=torch.int64) diff --git a/nerfstudio/utils/spherical_harmonics.py b/nerfstudio/utils/spherical_harmonics.py new file mode 100644 index 0000000000..07936281cc --- /dev/null +++ b/nerfstudio/utils/spherical_harmonics.py @@ -0,0 +1,111 @@ +# Copyright 2022 the Regents of the University of California, Nerfstudio Team and contributors. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Sphecal Harmonics utils.""" + +import torch +from jaxtyping import Float +from torch import Tensor + +MAX_SH_DEGREE = 4 + + +def components_from_spherical_harmonics( + degree: int, directions: Float[Tensor, "*batch 3"] +) -> Float[Tensor, "*batch components"]: + """ + Returns value for each component of spherical harmonics. + + Args: + degree: Number of spherical harmonic degrees to compute. + directions: Spherical harmonic coefficients + """ + num_components = num_sh_bases(degree) + components = torch.zeros((*directions.shape[:-1], num_components), device=directions.device) + + assert 0 <= degree <= MAX_SH_DEGREE, f"SH degree must be in [0, {MAX_SH_DEGREE}], got {degree}" + assert directions.shape[-1] == 3, f"Direction input should have three dimensions. Got {directions.shape[-1]}" + + x = directions[..., 0] + y = directions[..., 1] + z = directions[..., 2] + + xx = x**2 + yy = y**2 + zz = z**2 + + # l0 + components[..., 0] = 0.28209479177387814 + + # l1 + if degree > 0: + components[..., 1] = 0.4886025119029199 * y + components[..., 2] = 0.4886025119029199 * z + components[..., 3] = 0.4886025119029199 * x + + # l2 + if degree > 1: + components[..., 4] = 1.0925484305920792 * x * y + components[..., 5] = 1.0925484305920792 * y * z + components[..., 6] = 0.9461746957575601 * zz - 0.31539156525251999 + components[..., 7] = 1.0925484305920792 * x * z + components[..., 8] = 0.5462742152960396 * (xx - yy) + + # l3 + if degree > 2: + components[..., 9] = 0.5900435899266435 * y * (3 * xx - yy) + components[..., 10] = 2.890611442640554 * x * y * z + components[..., 11] = 0.4570457994644658 * y * (5 * zz - 1) + components[..., 12] = 0.3731763325901154 * z * (5 * zz - 3) + components[..., 13] = 0.4570457994644658 * x * (5 * zz - 1) + components[..., 14] = 1.445305721320277 * z * (xx - yy) + components[..., 15] = 0.5900435899266435 * x * (xx - 3 * yy) + + # l4 + if degree > 3: + components[..., 16] = 2.5033429417967046 * x * y * (xx - yy) + components[..., 17] = 1.7701307697799304 * y * z * (3 * xx - yy) + components[..., 18] = 0.9461746957575601 * x * y * (7 * zz - 1) + components[..., 19] = 0.6690465435572892 * y * z * (7 * zz - 3) + components[..., 20] = 0.10578554691520431 * (35 * zz * zz - 30 * zz + 3) + components[..., 21] = 0.6690465435572892 * x * z * (7 * zz - 3) + components[..., 22] = 0.47308734787878004 * (xx - yy) * (7 * zz - 1) + components[..., 23] = 1.7701307697799304 * x * z * (xx - 3 * yy) + components[..., 24] = 0.6258357354491761 * (xx * (xx - 3 * yy) - yy * (3 * xx - yy)) + + return components + + +def num_sh_bases(degree: int) -> int: + """ + Returns the number of spherical harmonic bases for a given degree. + """ + assert degree <= MAX_SH_DEGREE, f"We don't support degree greater than {MAX_SH_DEGREE}." + return (degree + 1) ** 2 + + +def RGB2SH(rgb): + """ + Converts from RGB values [0,1] to the 0th spherical harmonic coefficient + """ + C0 = 0.28209479177387814 + return (rgb - 0.5) / C0 + + +def SH2RGB(sh): + """ + Converts from the 0th spherical harmonic coefficient to RGB values [0,1] + """ + C0 = 0.28209479177387814 + return sh * C0 + 0.5 diff --git a/tests/field_components/test_encodings.py b/tests/field_components/test_encodings.py index a241fc52a1..63dc9a0261 100644 --- a/tests/field_components/test_encodings.py +++ b/tests/field_components/test_encodings.py @@ -125,11 +125,11 @@ def test_tensor_cp_encoder(): def test_tensor_sh_encoder(): """Test Spherical Harmonic encoder""" - levels = 4 + levels = 5 out_dim = levels**2 with pytest.raises(ValueError): - encoder = encodings.SHEncoding(levels=5) + encoder = encodings.SHEncoding(levels=6) encoder = encodings.SHEncoding(levels=levels) assert encoder.get_out_dim() == out_dim diff --git a/tests/utils/test_math.py b/tests/utils/test_math.py deleted file mode 100644 index 952c6a1be5..0000000000 --- a/tests/utils/test_math.py +++ /dev/null @@ -1,16 +0,0 @@ -import pytest -import torch - -from nerfstudio.utils.math import components_from_spherical_harmonics - - -@pytest.mark.parametrize("components", list(range(1, 5 + 1))) -def test_spherical_harmonics(components): - torch.manual_seed(0) - N = 1000000 - - dx = torch.normal(0, 1, size=(N, 3)) - dx = dx / torch.linalg.norm(dx, dim=-1, keepdim=True) - sh = components_from_spherical_harmonics(components, dx) - matrix = (sh.T @ sh) / N * 4 * torch.pi - torch.testing.assert_close(matrix, torch.eye(components**2), rtol=0, atol=1.5e-2) diff --git a/tests/utils/test_spherical_harmonics.py b/tests/utils/test_spherical_harmonics.py new file mode 100644 index 0000000000..a8949891a1 --- /dev/null +++ b/tests/utils/test_spherical_harmonics.py @@ -0,0 +1,16 @@ +import pytest +import torch + +from nerfstudio.utils.spherical_harmonics import components_from_spherical_harmonics, num_sh_bases + + +@pytest.mark.parametrize("degree", list(range(0, 5))) +def test_spherical_harmonics(degree): + torch.manual_seed(0) + N = 1000000 + + dx = torch.normal(0, 1, size=(N, 3)) + dx = dx / torch.linalg.norm(dx, dim=-1, keepdim=True) + sh = components_from_spherical_harmonics(degree, dx) + matrix = (sh.T @ sh) / N * 4 * torch.pi + torch.testing.assert_close(matrix, torch.eye(num_sh_bases(degree)), rtol=0, atol=1.5e-2)