Skip to content

Commit

Permalink
move WeightLoss class + test, use common_metrics
Browse files Browse the repository at this point in the history
  • Loading branch information
peterdudfield committed Nov 27, 2023
1 parent 842a8c3 commit 5c5c9f1
Show file tree
Hide file tree
Showing 3 changed files with 136 additions and 7 deletions.
13 changes: 6 additions & 7 deletions pvnet/models/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,18 +17,16 @@
from huggingface_hub.file_download import hf_hub_download
from huggingface_hub.hf_api import HfApi
from huggingface_hub.utils._deprecation import _deprecate_positional_args
from nowcasting_utils.models.loss import WeightedLosses
from nowcasting_utils.models.metrics import (
mae_each_forecast_horizon,
mse_each_forecast_horizon,
)

from ocf_datapipes.utils.consts import BatchKey
from ocf_ml_metrics.evaluation.evaluation import evaluation
from ocf_ml_metrics.metrics.errors import common_metrics

from pvnet.models.utils import (
BatchAccumulator,
MetricAccumulator,
PredAccumulator,
WeightedLosses,
)
from pvnet.optimizers import AbstractOptimizer
from pvnet.utils import construct_ocf_ml_metrics_batch_df, plot_batch_forecasts
Expand Down Expand Up @@ -373,8 +371,9 @@ def _calculate_val_losses(self, y, y_hat):
# Take median value for remaining metric calculations
y_hat = self._quantiles_to_prediction(y_hat)

mse_each_step = mse_each_forecast_horizon(output=y_hat, target=y)
mae_each_step = mae_each_forecast_horizon(output=y_hat, target=y)
common_metrics_each_step = common_metrics(predictions=y_hat, targets=y)
mse_each_step = common_metrics_each_step["rmse"]**2
mae_each_step = common_metrics_each_step["mae"]

Check warning on line 376 in pvnet/models/base_model.py

View check run for this annotation

Codecov / codecov/patch

pvnet/models/base_model.py#L374-L376

Added lines #L374 - L376 were not covered by tests

losses.update({f"MSE_horizon/step_{i:02}": m for i, m in enumerate(mse_each_step)})
losses.update({f"MAE_horizon/step_{i:02}": m for i, m in enumerate(mae_each_step)})
Expand Down
61 changes: 61 additions & 0 deletions pvnet/models/utils.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,15 @@
"""Utility functions"""

import logging
import math
from typing import Optional

import numpy as np
import torch
from ocf_datapipes.utils.consts import BatchKey

logger = logging.getLogger(__name__)


class PredAccumulator:
"""A class for accumulating y-predictions using grad accumulation and small batch size.
Expand Down Expand Up @@ -110,3 +116,58 @@ def flush(self) -> dict[BatchKey, list[torch.Tensor]]:
batch[k] = torch.cat(v, dim=0)
self._batches = {}
return batch


class WeightedLosses:
"""Class: Weighted loss depending on the forecast horizon."""

def __init__(self, decay_rate: Optional[int] = None, forecast_length: int = 6):
"""
Want to set up the MSE loss function so the weights only have to be calculated once.
Args:
decay_rate: The weights exponentially decay depending on the 'decay_rate'.
forecast_length: The forecast length is needed to make sure the weights sum to 1
"""
self.decay_rate = decay_rate
self.forecast_length = forecast_length

logger.debug(
f"Setting up weights with decay rate {decay_rate} and of length {forecast_length}"
)

# set default rate of ln(2) if not set
if self.decay_rate is None:
self.decay_rate = math.log(2)

# make weights from decay rate
weights = torch.FloatTensor(
[math.exp(-self.decay_rate * i) for i in range(0, self.forecast_length)]
)

# normalized the weights, so there mean is 1.
# To calculate the loss, we times the weights by the differences between truth
# and predictions and then take the mean across all forecast horizons and the batch
self.weights = weights / weights.sum() * len(weights)

# move weights to gpu is needed
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.weights = self.weights.to(device)

def get_mse_exp(self, output, target):
"""Loss function weighted MSE"""

# get the differences weighted by the forecast horizon weights
diff_with_weights = self.weights * ((output - target) ** 2)

# average across batches
return torch.mean(diff_with_weights)

def get_mae_exp(self, output, target):
"""Loss function weighted MAE"""

# get the differences weighted by the forecast horizon weights
diff_with_weights = self.weights * torch.abs(output - target)

# average across batches
return torch.mean(diff_with_weights)
69 changes: 69 additions & 0 deletions tests/models/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
import pytest
import torch

from pvnet.models.utils import WeightedLosses


def test_weight_losses_weights():
"""Test weighted loss"""
forecast_length = 2
w = WeightedLosses(forecast_length=forecast_length)

assert w.weights.cpu().numpy()[0] == pytest.approx(4 / 3)
assert w.weights.cpu().numpy()[1] == pytest.approx(2 / 3)


def test_mae_exp():
"""Test MAE exp with weighted loss"""
forecast_length = 2
w = WeightedLosses(forecast_length=forecast_length)

output = torch.Tensor([[1, 3], [1, 3]])
target = torch.Tensor([[1, 5], [1, 9]])

loss = w.get_mae_exp(output=output, target=target)

# 0.5((1-1)*2/3 + (5-3)*1/3) + 0.5((1-1)*2/3 + (9-3)*1/3) = 1/3 + 3/3
assert loss == pytest.approx(4 / 3)


def test_mse_exp():
"""Test MSE exp with weighted loss"""
forecast_length = 2
w = WeightedLosses(forecast_length=forecast_length)

output = torch.Tensor([[1, 3], [1, 3]])
target = torch.Tensor([[1, 5], [1, 9]])

loss = w.get_mse_exp(output=output, target=target)

# 0.5((1-1)^2*2/3 + (5-3)^2*1/3) + 0.5((1-1)^2*2/3 + (9-3)^2*1/3) = 2/3 + 18/3
assert loss == pytest.approx(20 / 3)


def test_mae_exp_rand():
"""Test MAE exp with weighted loss with random tensors"""
forecast_length = 6
batch_size = 32

w = WeightedLosses(forecast_length=6)

output = torch.randn(batch_size, forecast_length)
target = torch.randn(batch_size, forecast_length)

loss = w.get_mae_exp(output=output, target=target)
assert loss > 0


def test_mse_exp_rand():
"""Test MSE exp with weighted loss with random tensors"""
forecast_length = 6
batch_size = 32

w = WeightedLosses(forecast_length=6)

output = torch.randn(batch_size, forecast_length)
target = torch.randn(batch_size, forecast_length)

loss = w.get_mse_exp(output=output, target=target)
assert loss > 0

0 comments on commit 5c5c9f1

Please sign in to comment.