Skip to content

Commit

Permalink
Update exporter.py to export sh_degree 0 case nerfstudio-project#3371 (
Browse files Browse the repository at this point in the history
…nerfstudio-project#3374)

* Update exporter.py for sh_degree 0

Change to write sh coefficients instead of color values

* Add flag for use_sh0_renderer

Add sh0 renderer case for model.config.sh_degree == 0

* fix ruff

* add warning if use_sh0_renderer is used when higher order of SH is available

* fix rgb export for color-only training

* use ply_color_mode

* better handling ply_color_mode=='rgb' when sh_degree>0

* clean RGB2SH

* fix issues

* update description

---------

Co-authored-by: bell-one <[email protected]>
Co-authored-by: Jianbo Ye <[email protected]>
Co-authored-by: Brent Yi <[email protected]>
  • Loading branch information
4 people authored Aug 28, 2024
1 parent 22dae34 commit 96ca8b0
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 12 deletions.
5 changes: 4 additions & 1 deletion nerfstudio/models/splatfacto.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,7 +314,10 @@ def colors(self):

@property
def shs_0(self):
return self.features_dc
if self.config.sh_degree > 0:
return self.features_dc
else:
return RGB2SH(torch.sigmoid(self.features_dc))

@property
def shs_rest(self):
Expand Down
33 changes: 22 additions & 11 deletions nerfstudio/scripts/exporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -485,6 +485,9 @@ class ExportGaussianSplat(Exporter):
"""Rotation of the oriented bounding box. Expressed as RPY Euler angles in radians"""
obb_scale: Optional[Tuple[float, float, float]] = None
"""Scale of the oriented bounding box along each axis."""
ply_color_mode: Literal["sh_coeffs", "rgb"] = "sh_coeffs"
"""If "rgb", export colors as red/green/blue fields. Otherwise, export colors as
spherical harmonics coefficients."""

@staticmethod
def write_ply(
Expand All @@ -504,7 +507,7 @@ def write_ply(
"""

# Ensure count matches the length of all tensors
if not all(len(tensor) == count for tensor in map_to_tensors.values()):
if not all(tensor.size == count for tensor in map_to_tensors.values()):
raise ValueError("Count does not match the length of all tensors")

# Type check for numpy arrays of type float or uint8 and non-empty
Expand Down Expand Up @@ -552,7 +555,6 @@ def main(self) -> None:

filename = self.output_dir / "splat.ply"

count = 0
map_to_tensors = OrderedDict()

with torch.no_grad():
Expand All @@ -566,19 +568,28 @@ def main(self) -> None:
map_to_tensors["ny"] = np.zeros(n, dtype=np.float32)
map_to_tensors["nz"] = np.zeros(n, dtype=np.float32)

if model.config.sh_degree > 0:
if self.ply_color_mode == "rgb":
colors = torch.clamp(model.colors.clone(), 0.0, 1.0).data.cpu().numpy()
colors = (colors * 255).astype(np.uint8)
map_to_tensors["red"] = colors[:, 0]
map_to_tensors["green"] = colors[:, 1]
map_to_tensors["blue"] = colors[:, 2]
elif self.ply_color_mode == "sh_coeffs":
shs_0 = model.shs_0.contiguous().cpu().numpy()
for i in range(shs_0.shape[1]):
map_to_tensors[f"f_dc_{i}"] = shs_0[:, i, None]

# transpose(1, 2) was needed to match the sh order in Inria version
shs_rest = model.shs_rest.transpose(1, 2).contiguous().cpu().numpy()
shs_rest = shs_rest.reshape((n, -1))
for i in range(shs_rest.shape[-1]):
map_to_tensors[f"f_rest_{i}"] = shs_rest[:, i, None]
else:
colors = torch.clamp(model.colors.clone(), 0.0, 1.0).data.cpu().numpy()
map_to_tensors["colors"] = (colors * 255).astype(np.uint8)
if model.config.sh_degree > 0:
if self.ply_color_mode == "rgb":
CONSOLE.print(
"Warning: model has higher level of spherical harmonics, ignoring them and only export rgb."
)
elif self.ply_color_mode == "sh_coeffs":
# transpose(1, 2) was needed to match the sh order in Inria version
shs_rest = model.shs_rest.transpose(1, 2).contiguous().cpu().numpy()
shs_rest = shs_rest.reshape((n, -1))
for i in range(shs_rest.shape[-1]):
map_to_tensors[f"f_rest_{i}"] = shs_rest[:, i, None]

map_to_tensors["opacity"] = model.opacities.data.cpu().numpy()

Expand Down

0 comments on commit 96ca8b0

Please sign in to comment.