diff --git a/pvnet/models/base_model.py b/pvnet/models/base_model.py index aa56e9dc..20358362 100644 --- a/pvnet/models/base_model.py +++ b/pvnet/models/base_model.py @@ -354,7 +354,8 @@ def _calculate_qauntile_loss(self, y_quantiles, y): losses.append(torch.max((q - 1) * errors, q * errors).unsqueeze(-1)) losses = 2 * torch.cat(losses, dim=2) if self.use_weighted_loss: - losses = losses * self.weighted_losses.weights + weights = self.weighted_losses.weights.unsqueeze(1).unsqueeze(0) + losses = losses * weights return losses.mean() def _calculate_common_losses(self, y, y_hat):