diff --git a/pvnet/models/base_model.py b/pvnet/models/base_model.py index d1d667a2..14e4d48d 100644 --- a/pvnet/models/base_model.py +++ b/pvnet/models/base_model.py @@ -17,18 +17,15 @@ from huggingface_hub.file_download import hf_hub_download from huggingface_hub.hf_api import HfApi from huggingface_hub.utils._deprecation import _deprecate_positional_args -from nowcasting_utils.models.loss import WeightedLosses -from nowcasting_utils.models.metrics import ( - mae_each_forecast_horizon, - mse_each_forecast_horizon, -) from ocf_datapipes.utils.consts import BatchKey from ocf_ml_metrics.evaluation.evaluation import evaluation +from ocf_ml_metrics.metrics.errors import common_metrics from pvnet.models.utils import ( BatchAccumulator, MetricAccumulator, PredAccumulator, + WeightedLosses, ) from pvnet.optimizers import AbstractOptimizer from pvnet.utils import construct_ocf_ml_metrics_batch_df, plot_batch_forecasts @@ -373,8 +370,9 @@ def _calculate_val_losses(self, y, y_hat): # Take median value for remaining metric calculations y_hat = self._quantiles_to_prediction(y_hat) - mse_each_step = mse_each_forecast_horizon(output=y_hat, target=y) - mae_each_step = mae_each_forecast_horizon(output=y_hat, target=y) + common_metrics_each_step = common_metrics(predictions=y_hat, targets=y) + mse_each_step = common_metrics_each_step["rmse"] ** 2 + mae_each_step = common_metrics_each_step["mae"] losses.update({f"MSE_horizon/step_{i:02}": m for i, m in enumerate(mse_each_step)}) losses.update({f"MAE_horizon/step_{i:02}": m for i, m in enumerate(mae_each_step)}) diff --git a/pvnet/models/utils.py b/pvnet/models/utils.py index 89dffe6a..c652735e 100644 --- a/pvnet/models/utils.py +++ b/pvnet/models/utils.py @@ -1,9 +1,15 @@ """Utility functions""" +import logging +import math +from typing import Optional + import numpy as np import torch from ocf_datapipes.utils.consts import BatchKey +logger = logging.getLogger(__name__) + class PredAccumulator: """A class for accumulating y-predictions using grad accumulation and small batch size. @@ -110,3 +116,58 @@ def flush(self) -> dict[BatchKey, list[torch.Tensor]]: batch[k] = torch.cat(v, dim=0) self._batches = {} return batch + + +class WeightedLosses: + """Class: Weighted loss depending on the forecast horizon.""" + + def __init__(self, decay_rate: Optional[int] = None, forecast_length: int = 6): + """ + Want to set up the MSE loss function so the weights only have to be calculated once. + + Args: + decay_rate: The weights exponentially decay depending on the 'decay_rate'. + forecast_length: The forecast length is needed to make sure the weights sum to 1 + """ + self.decay_rate = decay_rate + self.forecast_length = forecast_length + + logger.debug( + f"Setting up weights with decay rate {decay_rate} and of length {forecast_length}" + ) + + # set default rate of ln(2) if not set + if self.decay_rate is None: + self.decay_rate = math.log(2) + + # make weights from decay rate + weights = torch.FloatTensor( + [math.exp(-self.decay_rate * i) for i in range(0, self.forecast_length)] + ) + + # normalized the weights, so there mean is 1. + # To calculate the loss, we times the weights by the differences between truth + # and predictions and then take the mean across all forecast horizons and the batch + self.weights = weights / weights.sum() * len(weights) + + # move weights to gpu is needed + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + self.weights = self.weights.to(device) + + def get_mse_exp(self, output, target): + """Loss function weighted MSE""" + + # get the differences weighted by the forecast horizon weights + diff_with_weights = self.weights * ((output - target) ** 2) + + # average across batches + return torch.mean(diff_with_weights) + + def get_mae_exp(self, output, target): + """Loss function weighted MAE""" + + # get the differences weighted by the forecast horizon weights + diff_with_weights = self.weights * torch.abs(output - target) + + # average across batches + return torch.mean(diff_with_weights) diff --git a/requirements.txt b/requirements.txt index 96f7b893..005f9ec1 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,4 @@ ocf_datapipes>=2.2.5 -nowcasting_utils ocf_ml_metrics numpy pandas diff --git a/tests/models/test_utils.py b/tests/models/test_utils.py new file mode 100644 index 00000000..15a26a58 --- /dev/null +++ b/tests/models/test_utils.py @@ -0,0 +1,69 @@ +import pytest +import torch + +from pvnet.models.utils import WeightedLosses + + +def test_weight_losses_weights(): + """Test weighted loss""" + forecast_length = 2 + w = WeightedLosses(forecast_length=forecast_length) + + assert w.weights.cpu().numpy()[0] == pytest.approx(4 / 3) + assert w.weights.cpu().numpy()[1] == pytest.approx(2 / 3) + + +def test_mae_exp(): + """Test MAE exp with weighted loss""" + forecast_length = 2 + w = WeightedLosses(forecast_length=forecast_length) + + output = torch.Tensor([[1, 3], [1, 3]]) + target = torch.Tensor([[1, 5], [1, 9]]) + + loss = w.get_mae_exp(output=output, target=target) + + # 0.5((1-1)*2/3 + (5-3)*1/3) + 0.5((1-1)*2/3 + (9-3)*1/3) = 1/3 + 3/3 + assert loss == pytest.approx(4 / 3) + + +def test_mse_exp(): + """Test MSE exp with weighted loss""" + forecast_length = 2 + w = WeightedLosses(forecast_length=forecast_length) + + output = torch.Tensor([[1, 3], [1, 3]]) + target = torch.Tensor([[1, 5], [1, 9]]) + + loss = w.get_mse_exp(output=output, target=target) + + # 0.5((1-1)^2*2/3 + (5-3)^2*1/3) + 0.5((1-1)^2*2/3 + (9-3)^2*1/3) = 2/3 + 18/3 + assert loss == pytest.approx(20 / 3) + + +def test_mae_exp_rand(): + """Test MAE exp with weighted loss with random tensors""" + forecast_length = 6 + batch_size = 32 + + w = WeightedLosses(forecast_length=6) + + output = torch.randn(batch_size, forecast_length) + target = torch.randn(batch_size, forecast_length) + + loss = w.get_mae_exp(output=output, target=target) + assert loss > 0 + + +def test_mse_exp_rand(): + """Test MSE exp with weighted loss with random tensors""" + forecast_length = 6 + batch_size = 32 + + w = WeightedLosses(forecast_length=6) + + output = torch.randn(batch_size, forecast_length) + target = torch.randn(batch_size, forecast_length) + + loss = w.get_mse_exp(output=output, target=target) + assert loss > 0