diff --git a/pvnet/models/base_model.py b/pvnet/models/base_model.py index 3d26174f..b056a90c 100644 --- a/pvnet/models/base_model.py +++ b/pvnet/models/base_model.py @@ -370,7 +370,9 @@ def _calculate_val_losses(self, y, y_hat): # Take median value for remaining metric calculations y_hat = self._quantiles_to_prediction(y_hat) - common_metrics_each_step = common_metrics(predictions=y_hat.numpy(), target=y.numpy()) + common_metrics_each_step = common_metrics( + predictions=y_hat.cpu().numpy(), target=y.cpu().numpy() + ) mse_each_step = common_metrics_each_step["rmse"] ** 2 mae_each_step = common_metrics_each_step["mae"] diff --git a/pvnet/training.py b/pvnet/training.py index 08b40354..5f28215e 100644 --- a/pvnet/training.py +++ b/pvnet/training.py @@ -103,6 +103,8 @@ def train(config: DictConfig) -> Optional[float]: for callback in callbacks: log.info(f"{callback}") if isinstance(callback, ModelCheckpoint): + # Need to call the .experiment property to initialise the logger + wandb_logger.experiment callback.dirpath = "/".join( callback.dirpath.split("/")[:-1] + [wandb_logger.version] )