diff --git a/pvnet/models/base_model.py b/pvnet/models/base_model.py index 7d4a61a1..459337b2 100644 --- a/pvnet/models/base_model.py +++ b/pvnet/models/base_model.py @@ -521,7 +521,7 @@ def validation_step(self, batch: dict, batch_idx): # Store these to make horizon accuracy plot self._horizon_maes.append( - {i: losses[f"MAE_horizon/step_{i:03}"] for i in range(self.forecast_len)} + {i: losses[f"MAE_horizon/step_{i:03}"].cpu().numpy() for i in range(self.forecast_len)} ) logged_losses = {f"{k}/val": v for k, v in losses.items()}