diff --git a/pvnet/models/base_model.py b/pvnet/models/base_model.py index 20358362..ccda4abe 100644 --- a/pvnet/models/base_model.py +++ b/pvnet/models/base_model.py @@ -354,7 +354,7 @@ def _calculate_qauntile_loss(self, y_quantiles, y): losses.append(torch.max((q - 1) * errors, q * errors).unsqueeze(-1)) losses = 2 * torch.cat(losses, dim=2) if self.use_weighted_loss: - weights = self.weighted_losses.weights.unsqueeze(1).unsqueeze(0) + weights = self.weighted_losses.weights.unsqueeze(1).unsqueeze(0).to(y.device) losses = losses * weights return losses.mean()