diff --git a/street_gaussians_ns/sgn_splatfacto.py b/street_gaussians_ns/sgn_splatfacto.py index 1459fb1..7e2a36a 100644 --- a/street_gaussians_ns/sgn_splatfacto.py +++ b/street_gaussians_ns/sgn_splatfacto.py @@ -8,10 +8,8 @@ from dataclasses import dataclass, field from typing import Dict, List, Optional, Tuple, Type, Union -from gsplat._torch_impl import quat_to_rotmat -from gsplat.project_gaussians import project_gaussians -from gsplat.rasterize import rasterize_gaussians -from gsplat.sh import num_sh_bases, spherical_harmonics +from gsplat.cuda_legacy._torch_impl import quat_to_rotmat +from gsplat import project_gaussians, rasterize_gaussians, spherical_harmonics from pytorch_msssim import SSIM from torch.nn import Parameter from typing_extensions import Literal @@ -36,6 +34,18 @@ from street_gaussians_ns.data.utils.data_utils import SemanticType +def num_sh_bases(degree: int): + if degree == 0: + return 1 + if degree == 1: + return 4 + if degree == 2: + return 9 + if degree == 3: + return 16 + return 25 + + def random_quat_tensor(N): """ Defines a random quaternion tensor of shape (N, 4) diff --git a/street_gaussians_ns/sgn_splatfacto_scene_graph.py b/street_gaussians_ns/sgn_splatfacto_scene_graph.py index 18526fc..8396c34 100644 --- a/street_gaussians_ns/sgn_splatfacto_scene_graph.py +++ b/street_gaussians_ns/sgn_splatfacto_scene_graph.py @@ -5,7 +5,7 @@ import copy import math -from gsplat.sh import spherical_harmonics +from gsplat import spherical_harmonics from pytorch3d.transforms import quaternion_multiply from torch.nn import Parameter import mediapy as media