From 19f3c318ec6db2be8496c993ef05fe54619ab4b7 Mon Sep 17 00:00:00 2001 From: Jacob Bieker Date: Wed, 20 Mar 2024 11:02:26 +0000 Subject: [PATCH] Fix device --- pvnet/models/base_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pvnet/models/base_model.py b/pvnet/models/base_model.py index 20358362..ccda4abe 100644 --- a/pvnet/models/base_model.py +++ b/pvnet/models/base_model.py @@ -354,7 +354,7 @@ def _calculate_qauntile_loss(self, y_quantiles, y): losses.append(torch.max((q - 1) * errors, q * errors).unsqueeze(-1)) losses = 2 * torch.cat(losses, dim=2) if self.use_weighted_loss: - weights = self.weighted_losses.weights.unsqueeze(1).unsqueeze(0) + weights = self.weighted_losses.weights.unsqueeze(1).unsqueeze(0).to(y.device) losses = losses * weights return losses.mean()