Skip to content

Commit

Permalink
add pearson corr coeff loss for depth
Browse files Browse the repository at this point in the history
  • Loading branch information
mattstrong-stanford committed Jun 29, 2024
1 parent b7c682d commit 8801f9a
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 6 deletions.
22 changes: 22 additions & 0 deletions nerfstudio/model_components/losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import torch
from jaxtyping import Bool, Float
from torch import Tensor, nn
from torchmetrics.functional.regression import pearson_corrcoef

from nerfstudio.cameras.rays import RaySamples
from nerfstudio.field_components.field_heads import FieldHeadNames
Expand All @@ -44,6 +45,7 @@ class DepthLossType(Enum):
URF = 2
SPARSENERF_RANKING = 3
MSE = 4
PEARSON_LOSS = 5


FORCE_PSEUDODEPTH_LOSS = False
Expand Down Expand Up @@ -240,6 +242,26 @@ def mse_depth_loss(
expected_depth_loss = expected_depth_loss * depth_mask
return torch.mean(expected_depth_loss)


def pearson_correlation_depth_loss(
termination_depth,
predicted_depth,
)-> Float[Tensor, "*batch 1"]:
"""Pearson correlation depth loss.
Args:
termination_depth: Ground truth depth of rays.
predicted_depth: Rendered depth from the radiance field
Returns:
Depth loss scalar.
"""
termination_depth = termination_depth.reshape(-1, 1)
predicted_depth = predicted_depth.reshape(-1, 1)

loss = (1 - pearson_corrcoef( predicted_depth, termination_depth))
return torch.mean(loss)


def ds_nerf_depth_loss(
weights: Float[Tensor, "*batch num_samples 1"],
termination_depth: Float[Tensor, "*batch 1"],
Expand Down
24 changes: 18 additions & 6 deletions nerfstudio/models/depth_splatfacto.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@


from nerfstudio.model_components import losses
from nerfstudio.model_components.losses import DepthLossType, mse_depth_loss, depth_ranking_loss
from nerfstudio.model_components.losses import DepthLossType, mse_depth_loss, depth_ranking_loss, pearson_correlation_depth_loss


from nerfstudio.models.splatfacto import SplatfactoModel, SplatfactoModelConfig
Expand All @@ -54,9 +54,17 @@ class DepthSplatfactoModel(SplatfactoModel):
"""

config: DepthSplatfactoModelConfig

def reshape_termination_depth(self, termination_depth, output_depth_shape):
termination_depth = F.interpolate(termination_depth.permute(2, 0, 1).unsqueeze(0), size=(output_depth_shape[0], output_depth_shape[1]), mode='bilinear', align_corners=False)
# Remove the extra dimensions added by unsqueeze and permute
termination_depth = termination_depth.squeeze(0).permute(1, 2, 0)
return termination_depth

def get_metrics_dict(self, outputs, batch):
metrics_dict = super().get_metrics_dict(outputs, batch)
output_depth_shape = outputs["depth"].shape[:2]

if self.training:
if (
losses.FORCE_PSEUDODEPTH_LOSS
Expand All @@ -68,15 +76,19 @@ def get_metrics_dict(self, outputs, batch):
if self.config.depth_loss_type in (DepthLossType.MSE,):
metrics_dict["depth_loss"] = torch.Tensor([0.0]).to(self.device)
termination_depth = batch["depth_image"].to(self.device)

output_depth_shape = outputs["depth"].shape[:2]
termination_depth = F.interpolate(termination_depth.permute(2, 0, 1).unsqueeze(0), size=(output_depth_shape[0], output_depth_shape[1]), mode='bilinear', align_corners=False)
# Remove the extra dimensions added by unsqueeze and permute
termination_depth = termination_depth.squeeze(0).permute(1, 2, 0)
termination_depth = self.reshape_termination_depth(termination_depth, output_depth_shape)

metrics_dict["depth_loss"] = mse_depth_loss(
termination_depth, outputs["depth"])

elif self.config.depth_loss_type in (DepthLossType.PEARSON_LOSS,):
metrics_dict["depth_loss"] = torch.Tensor([0.0]).to(self.device)
termination_depth = batch["depth_image"].to(self.device)
termination_depth = self.reshape_termination_depth(termination_depth, output_depth_shape)

metrics_dict["depth_loss"] = pearson_correlation_depth_loss(
termination_depth, outputs["depth"])

elif self.config.depth_loss_type in (DepthLossType.SPARSENERF_RANKING,):
metrics_dict["depth_ranking"] = depth_ranking_loss(
outputs["depth"], batch["depth_image"].to(self.device)
Expand Down

0 comments on commit 8801f9a

Please sign in to comment.