diff --git a/pvnet/models/base_model.py b/pvnet/models/base_model.py index b65f22d8..14e4d48d 100644 --- a/pvnet/models/base_model.py +++ b/pvnet/models/base_model.py @@ -17,7 +17,6 @@ from huggingface_hub.file_download import hf_hub_download from huggingface_hub.hf_api import HfApi from huggingface_hub.utils._deprecation import _deprecate_positional_args - from ocf_datapipes.utils.consts import BatchKey from ocf_ml_metrics.evaluation.evaluation import evaluation from ocf_ml_metrics.metrics.errors import common_metrics @@ -371,8 +370,8 @@ def _calculate_val_losses(self, y, y_hat): # Take median value for remaining metric calculations y_hat = self._quantiles_to_prediction(y_hat) - common_metrics_each_step = common_metrics(predictions=y_hat, targets=y) - mse_each_step = common_metrics_each_step["rmse"]**2 + common_metrics_each_step = common_metrics(predictions=y_hat, targets=y) + mse_each_step = common_metrics_each_step["rmse"] ** 2 mae_each_step = common_metrics_each_step["mae"] losses.update({f"MSE_horizon/step_{i:02}": m for i, m in enumerate(mse_each_step)})