diff --git a/nerfstudio/models/splatfacto.py b/nerfstudio/models/splatfacto.py index f934b12639..61d9eda19f 100644 --- a/nerfstudio/models/splatfacto.py +++ b/nerfstudio/models/splatfacto.py @@ -314,7 +314,10 @@ def colors(self): @property def shs_0(self): - return self.features_dc + if self.config.sh_degree > 0: + return self.features_dc + else: + return RGB2SH(torch.sigmoid(self.features_dc)) @property def shs_rest(self): diff --git a/nerfstudio/scripts/exporter.py b/nerfstudio/scripts/exporter.py index 5ae6037009..970b5a9c7a 100644 --- a/nerfstudio/scripts/exporter.py +++ b/nerfstudio/scripts/exporter.py @@ -485,6 +485,9 @@ class ExportGaussianSplat(Exporter): """Rotation of the oriented bounding box. Expressed as RPY Euler angles in radians""" obb_scale: Optional[Tuple[float, float, float]] = None """Scale of the oriented bounding box along each axis.""" + ply_color_mode: Literal["sh_coeffs", "rgb"] = "sh_coeffs" + """If "rgb", export colors as red/green/blue fields. Otherwise, export colors as + spherical harmonics coefficients.""" @staticmethod def write_ply( @@ -504,7 +507,7 @@ def write_ply( """ # Ensure count matches the length of all tensors - if not all(len(tensor) == count for tensor in map_to_tensors.values()): + if not all(tensor.size == count for tensor in map_to_tensors.values()): raise ValueError("Count does not match the length of all tensors") # Type check for numpy arrays of type float or uint8 and non-empty @@ -552,7 +555,6 @@ def main(self) -> None: filename = self.output_dir / "splat.ply" - count = 0 map_to_tensors = OrderedDict() with torch.no_grad(): @@ -566,19 +568,28 @@ def main(self) -> None: map_to_tensors["ny"] = np.zeros(n, dtype=np.float32) map_to_tensors["nz"] = np.zeros(n, dtype=np.float32) - if model.config.sh_degree > 0: + if self.ply_color_mode == "rgb": + colors = torch.clamp(model.colors.clone(), 0.0, 1.0).data.cpu().numpy() + colors = (colors * 255).astype(np.uint8) + map_to_tensors["red"] = colors[:, 0] + map_to_tensors["green"] = colors[:, 1] + map_to_tensors["blue"] = colors[:, 2] + elif self.ply_color_mode == "sh_coeffs": shs_0 = model.shs_0.contiguous().cpu().numpy() for i in range(shs_0.shape[1]): map_to_tensors[f"f_dc_{i}"] = shs_0[:, i, None] - # transpose(1, 2) was needed to match the sh order in Inria version - shs_rest = model.shs_rest.transpose(1, 2).contiguous().cpu().numpy() - shs_rest = shs_rest.reshape((n, -1)) - for i in range(shs_rest.shape[-1]): - map_to_tensors[f"f_rest_{i}"] = shs_rest[:, i, None] - else: - colors = torch.clamp(model.colors.clone(), 0.0, 1.0).data.cpu().numpy() - map_to_tensors["colors"] = (colors * 255).astype(np.uint8) + if model.config.sh_degree > 0: + if self.ply_color_mode == "rgb": + CONSOLE.print( + "Warning: model has higher level of spherical harmonics, ignoring them and only export rgb." + ) + elif self.ply_color_mode == "sh_coeffs": + # transpose(1, 2) was needed to match the sh order in Inria version + shs_rest = model.shs_rest.transpose(1, 2).contiguous().cpu().numpy() + shs_rest = shs_rest.reshape((n, -1)) + for i in range(shs_rest.shape[-1]): + map_to_tensors[f"f_rest_{i}"] = shs_rest[:, i, None] map_to_tensors["opacity"] = model.opacities.data.cpu().numpy()