Skip to content

Commit

Permalink
Fix device
Browse files Browse the repository at this point in the history
  • Loading branch information
jacobbieker committed Mar 20, 2024
1 parent b0c7639 commit 19f3c31
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion pvnet/models/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,7 +354,7 @@ def _calculate_qauntile_loss(self, y_quantiles, y):
losses.append(torch.max((q - 1) * errors, q * errors).unsqueeze(-1))
losses = 2 * torch.cat(losses, dim=2)
if self.use_weighted_loss:
weights = self.weighted_losses.weights.unsqueeze(1).unsqueeze(0)
weights = self.weighted_losses.weights.unsqueeze(1).unsqueeze(0).to(y.device)
losses = losses * weights
return losses.mean()

Expand Down

0 comments on commit 19f3c31

Please sign in to comment.