From bd79763043c64eb2289309bb5cd0866045493c93 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 27 Nov 2023 11:32:05 +0000 Subject: [PATCH] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pvnet/models/base_model.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/pvnet/models/base_model.py b/pvnet/models/base_model.py index b65f22d8..14e4d48d 100644 --- a/pvnet/models/base_model.py +++ b/pvnet/models/base_model.py @@ -17,7 +17,6 @@ from huggingface_hub.file_download import hf_hub_download from huggingface_hub.hf_api import HfApi from huggingface_hub.utils._deprecation import _deprecate_positional_args - from ocf_datapipes.utils.consts import BatchKey from ocf_ml_metrics.evaluation.evaluation import evaluation from ocf_ml_metrics.metrics.errors import common_metrics @@ -371,8 +370,8 @@ def _calculate_val_losses(self, y, y_hat): # Take median value for remaining metric calculations y_hat = self._quantiles_to_prediction(y_hat) - common_metrics_each_step = common_metrics(predictions=y_hat, targets=y) - mse_each_step = common_metrics_each_step["rmse"]**2 + common_metrics_each_step = common_metrics(predictions=y_hat, targets=y) + mse_each_step = common_metrics_each_step["rmse"] ** 2 mae_each_step = common_metrics_each_step["mae"] losses.update({f"MSE_horizon/step_{i:02}": m for i, m in enumerate(mse_each_step)})