Skip to content

Commit

Permalink
Add gradient scaling option to more methods (#2555)
Browse files Browse the repository at this point in the history
add gradient scaling to instant-ngp, mipnerf, tensorf, vanilla-nerf
  • Loading branch information
kobejean authored Oct 26, 2023
1 parent 8a44e46 commit e4a0050
Show file tree
Hide file tree
Showing 4 changed files with 22 additions and 4 deletions.
6 changes: 5 additions & 1 deletion nerfstudio/models/instant_ngp.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
from nerfstudio.field_components.field_heads import FieldHeadNames
from nerfstudio.field_components.spatial_distortions import SceneContraction
from nerfstudio.fields.nerfacto_field import NerfactoField
from nerfstudio.model_components.losses import MSELoss
from nerfstudio.model_components.losses import MSELoss, scale_gradients_by_distance_squared
from nerfstudio.model_components.ray_samplers import VolumetricSampler
from nerfstudio.model_components.renderers import (
AccumulationRenderer,
Expand Down Expand Up @@ -78,6 +78,8 @@ class InstantNGPModelConfig(ModelConfig):
"""How far along ray to start sampling."""
far_plane: float = 1e3
"""How far along ray to stop sampling."""
use_gradient_scaling: bool = False
"""Use gradient scaler where the gradients are lower for points closer to the camera."""
use_appearance_embedding: bool = False
"""Whether to use an appearance embedding."""
background_color: Literal["random", "black", "white"] = "random"
Expand Down Expand Up @@ -187,6 +189,8 @@ def get_outputs(self, ray_bundle: RayBundle):
)

field_outputs = self.field(ray_samples)
if self.config.use_gradient_scaling:
field_outputs = scale_gradients_by_distance_squared(field_outputs, ray_samples)

# accumulation
packed_info = nerfacc.pack_info(ray_indices, num_rays)
Expand Down
6 changes: 5 additions & 1 deletion nerfstudio/models/mipnerf.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
from nerfstudio.field_components.encodings import NeRFEncoding
from nerfstudio.field_components.field_heads import FieldHeadNames
from nerfstudio.fields.vanilla_nerf_field import NeRFField
from nerfstudio.model_components.losses import MSELoss
from nerfstudio.model_components.losses import MSELoss, scale_gradients_by_distance_squared
from nerfstudio.model_components.ray_samplers import PDFSampler, UniformSampler
from nerfstudio.model_components.renderers import (
AccumulationRenderer,
Expand Down Expand Up @@ -109,6 +109,8 @@ def get_outputs(self, ray_bundle: RayBundle):

# First pass:
field_outputs_coarse = self.field.forward(ray_samples_uniform)
if self.config.use_gradient_scaling:
field_outputs_coarse = scale_gradients_by_distance_squared(field_outputs_coarse, ray_samples_uniform)
weights_coarse = ray_samples_uniform.get_weights(field_outputs_coarse[FieldHeadNames.DENSITY])
rgb_coarse = self.renderer_rgb(
rgb=field_outputs_coarse[FieldHeadNames.RGB],
Expand All @@ -122,6 +124,8 @@ def get_outputs(self, ray_bundle: RayBundle):

# Second pass:
field_outputs_fine = self.field.forward(ray_samples_pdf)
if self.config.use_gradient_scaling:
field_outputs_fine = scale_gradients_by_distance_squared(field_outputs_fine, ray_samples_pdf)
weights_fine = ray_samples_pdf.get_weights(field_outputs_fine[FieldHeadNames.DENSITY])
rgb_fine = self.renderer_rgb(
rgb=field_outputs_fine[FieldHeadNames.RGB],
Expand Down
6 changes: 5 additions & 1 deletion nerfstudio/models/tensorf.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
)
from nerfstudio.field_components.field_heads import FieldHeadNames
from nerfstudio.fields.tensorf_field import TensoRFField
from nerfstudio.model_components.losses import MSELoss, tv_loss
from nerfstudio.model_components.losses import MSELoss, tv_loss, scale_gradients_by_distance_squared
from nerfstudio.model_components.ray_samplers import PDFSampler, UniformSampler
from nerfstudio.model_components.renderers import (
AccumulationRenderer,
Expand Down Expand Up @@ -92,6 +92,8 @@ class TensoRFModelConfig(ModelConfig):
"""Regularization method used in tensorf paper"""
camera_optimizer: CameraOptimizerConfig = CameraOptimizerConfig(mode="SO3xR3")
"""Config of the camera optimizer to use"""
use_gradient_scaling: bool = False
"""Use gradient scaler where the gradients are lower for points closer to the camera."""
background_color: Literal["random", "last_sample", "black", "white"] = "white"
"""Whether to randomize the background color."""

Expand Down Expand Up @@ -296,6 +298,8 @@ def get_outputs(self, ray_bundle: RayBundle):
field_outputs_fine = self.field.forward(
ray_samples_pdf, mask=acc_mask, bg_color=colors.WHITE.to(weights.device)
)
if self.config.use_gradient_scaling:
field_outputs_fine = scale_gradients_by_distance_squared(field_outputs_fine, ray_samples_pdf)

weights_fine = ray_samples_pdf.get_weights(field_outputs_fine[FieldHeadNames.DENSITY])

Expand Down
8 changes: 7 additions & 1 deletion nerfstudio/models/vanilla_nerf.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
from nerfstudio.field_components.field_heads import FieldHeadNames
from nerfstudio.field_components.temporal_distortions import TemporalDistortionKind
from nerfstudio.fields.vanilla_nerf_field import NeRFField
from nerfstudio.model_components.losses import MSELoss
from nerfstudio.model_components.losses import MSELoss, scale_gradients_by_distance_squared
from nerfstudio.model_components.ray_samplers import PDFSampler, UniformSampler
from nerfstudio.model_components.renderers import (
AccumulationRenderer,
Expand All @@ -58,6 +58,8 @@ class VanillaModelConfig(ModelConfig):
"""Specifies whether or not to include ray warping based on time."""
temporal_distortion_params: Dict[str, Any] = to_immutable_dict({"kind": TemporalDistortionKind.DNERF})
"""Parameters to instantiate temporal distortion with"""
use_gradient_scaling: bool = False
"""Use gradient scaler where the gradients are lower for points closer to the camera."""
background_color: Literal["random", "last_sample", "black", "white"] = "white"
"""Whether to randomize the background color."""

Expand Down Expand Up @@ -154,6 +156,8 @@ def get_outputs(self, ray_bundle: RayBundle):

# coarse field:
field_outputs_coarse = self.field_coarse.forward(ray_samples_uniform)
if self.config.use_gradient_scaling:
field_outputs_coarse = scale_gradients_by_distance_squared(field_outputs_coarse, ray_samples_uniform)
weights_coarse = ray_samples_uniform.get_weights(field_outputs_coarse[FieldHeadNames.DENSITY])
rgb_coarse = self.renderer_rgb(
rgb=field_outputs_coarse[FieldHeadNames.RGB],
Expand All @@ -172,6 +176,8 @@ def get_outputs(self, ray_bundle: RayBundle):

# fine field:
field_outputs_fine = self.field_fine.forward(ray_samples_pdf)
if self.config.use_gradient_scaling:
field_outputs_fine = scale_gradients_by_distance_squared(field_outputs_fine, ray_samples_pdf)
weights_fine = ray_samples_pdf.get_weights(field_outputs_fine[FieldHeadNames.DENSITY])
rgb_fine = self.renderer_rgb(
rgb=field_outputs_fine[FieldHeadNames.RGB],
Expand Down

0 comments on commit e4a0050

Please sign in to comment.