Skip to content

Commit

Permalink
Small fix
Browse files Browse the repository at this point in the history
  • Loading branch information
jacobbieker committed Mar 20, 2024
1 parent 8bf47e4 commit 3c50eda
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion pvnet/models/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,7 +354,8 @@ def _calculate_qauntile_loss(self, y_quantiles, y):
losses.append(torch.max((q - 1) * errors, q * errors).unsqueeze(-1))
losses = 2 * torch.cat(losses, dim=2)
if self.use_weighted_loss:
losses = losses * self.weighted_losses.weights
weights = self.weighted_losses.weights.unsqueeze(1).unsqueeze(0)
losses = losses * weights
return losses.mean()

def _calculate_common_losses(self, y, y_hat):
Expand Down

0 comments on commit 3c50eda

Please sign in to comment.