Issue/rm nowcasting utils #100

Merged: 3 commits, Nov 27, 2023
12 changes: 5 additions & 7 deletions pvnet/models/base_model.py
@@ -17,18 +17,15 @@
from huggingface_hub.file_download import hf_hub_download
from huggingface_hub.hf_api import HfApi
from huggingface_hub.utils._deprecation import _deprecate_positional_args
-from nowcasting_utils.models.loss import WeightedLosses
-from nowcasting_utils.models.metrics import (
-    mae_each_forecast_horizon,
-    mse_each_forecast_horizon,
-)
from ocf_datapipes.utils.consts import BatchKey
from ocf_ml_metrics.evaluation.evaluation import evaluation
+from ocf_ml_metrics.metrics.errors import common_metrics

from pvnet.models.utils import (
    BatchAccumulator,
    MetricAccumulator,
    PredAccumulator,
+    WeightedLosses,
)
from pvnet.optimizers import AbstractOptimizer
from pvnet.utils import construct_ocf_ml_metrics_batch_df, plot_batch_forecasts
@@ -373,8 +370,9 @@
# Take median value for remaining metric calculations
y_hat = self._quantiles_to_prediction(y_hat)

-mse_each_step = mse_each_forecast_horizon(output=y_hat, target=y)
-mae_each_step = mae_each_forecast_horizon(output=y_hat, target=y)
+common_metrics_each_step = common_metrics(predictions=y_hat, targets=y)
+mse_each_step = common_metrics_each_step["rmse"] ** 2
+mae_each_step = common_metrics_each_step["mae"]

Codecov (codecov/patch) warning: added lines pvnet/models/base_model.py#L373-L375 were not covered by tests.

losses.update({f"MSE_horizon/step_{i:02}": m for i, m in enumerate(mse_each_step)})
losses.update({f"MAE_horizon/step_{i:02}": m for i, m in enumerate(mae_each_step)})
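The metric swap above relies on the fact that the per-horizon MSE previously returned by mse_each_forecast_horizon is just the square of the per-horizon RMSE now taken from common_metrics. A minimal sketch of that identity, assuming y_hat and y are (batch_size, forecast_length) tensors; the helper name below is illustrative and not part of either library:

import torch


def per_step_mse_mae(y_hat: torch.Tensor, y: torch.Tensor):
    # Reduce over the batch dimension only, keeping one value per forecast step.
    rmse_each_step = torch.sqrt(torch.mean((y_hat - y) ** 2, dim=0))
    mae_each_step = torch.mean(torch.abs(y_hat - y), dim=0)
    # Squaring the per-step RMSE recovers the per-step MSE that gets logged
    # as MSE_horizon/step_XX above.
    return rmse_each_step ** 2, mae_each_step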
61 changes: 61 additions & 0 deletions pvnet/models/utils.py
@@ -1,9 +1,15 @@
"""Utility functions"""

import logging
import math
from typing import Optional

import numpy as np
import torch
from ocf_datapipes.utils.consts import BatchKey

logger = logging.getLogger(__name__)


class PredAccumulator:
"""A class for accumulating y-predictions using grad accumulation and small batch size.
@@ -110,3 +116,58 @@ def flush(self) -> dict[BatchKey, list[torch.Tensor]]:
batch[k] = torch.cat(v, dim=0)
self._batches = {}
return batch


class WeightedLosses:
    """Class: Weighted loss depending on the forecast horizon."""

    def __init__(self, decay_rate: Optional[float] = None, forecast_length: int = 6):
        """
        Set up the weighted loss so the weights only have to be calculated once.

        Args:
            decay_rate: The weights decay exponentially depending on the 'decay_rate'.
            forecast_length: The forecast length is needed so the weights can be normalized.
        """
        self.decay_rate = decay_rate
        self.forecast_length = forecast_length

        logger.debug(
            f"Setting up weights with decay rate {decay_rate} and forecast length {forecast_length}"
        )

        # set a default decay rate of ln(2) if none is given
        if self.decay_rate is None:
            self.decay_rate = math.log(2)

        # make the weights from the decay rate
        weights = torch.FloatTensor(
            [math.exp(-self.decay_rate * i) for i in range(0, self.forecast_length)]
        )

        # normalize the weights so their mean is 1.
        # To calculate the loss, we multiply the weights by the differences between truth
        # and prediction and then take the mean across all forecast horizons and the batch.
        self.weights = weights / weights.sum() * len(weights)

        # move the weights to the GPU if one is available
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.weights = self.weights.to(device)

    def get_mse_exp(self, output, target):
        """Exponentially weighted MSE loss"""

        # weight the squared differences by the forecast-horizon weights
        diff_with_weights = self.weights * ((output - target) ** 2)

        # average across the batch and forecast horizons
        return torch.mean(diff_with_weights)

    def get_mae_exp(self, output, target):
        """Exponentially weighted MAE loss"""

        # weight the absolute differences by the forecast-horizon weights
        diff_with_weights = self.weights * torch.abs(output - target)

        # average across the batch and forecast horizons
        return torch.mean(diff_with_weights)
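For a concrete feel for what the class above produces (an illustrative snippet, not part of the diff): with the default decay rate of ln(2) and forecast_length=2, the raw weights are [1, 0.5], and after normalizing them to a mean of 1 they become [4/3, 2/3], which is exactly what the tests added below check.

import torch

from pvnet.models.utils import WeightedLosses

w = WeightedLosses(forecast_length=2)
print(w.weights)  # tensor([1.3333, 0.6667]) on CPU, i.e. [4/3, 2/3]

output = torch.Tensor([[1, 3], [1, 3]])
target = torch.Tensor([[1, 5], [1, 9]])
print(w.get_mae_exp(output=output, target=target))  # tensor(1.3333) == 4/3
print(w.get_mse_exp(output=output, target=target))  # tensor(6.6667) == 20/3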
1 change: 0 additions & 1 deletion requirements.txt
@@ -1,5 +1,4 @@
ocf_datapipes>=2.2.5
-nowcasting_utils
ocf_ml_metrics
numpy
pandas
69 changes: 69 additions & 0 deletions tests/models/test_utils.py
@@ -0,0 +1,69 @@
import pytest
import torch

from pvnet.models.utils import WeightedLosses


def test_weight_losses_weights():
    """Test weighted loss"""
    forecast_length = 2
    w = WeightedLosses(forecast_length=forecast_length)

    assert w.weights.cpu().numpy()[0] == pytest.approx(4 / 3)
    assert w.weights.cpu().numpy()[1] == pytest.approx(2 / 3)


def test_mae_exp():
    """Test MAE exp with weighted loss"""
    forecast_length = 2
    w = WeightedLosses(forecast_length=forecast_length)

    output = torch.Tensor([[1, 3], [1, 3]])
    target = torch.Tensor([[1, 5], [1, 9]])

    loss = w.get_mae_exp(output=output, target=target)

    # 0.5*((1-1)*2/3 + (5-3)*1/3) + 0.5*((1-1)*2/3 + (9-3)*1/3) = 1/3 + 3/3
    assert loss == pytest.approx(4 / 3)


def test_mse_exp():
    """Test MSE exp with weighted loss"""
    forecast_length = 2
    w = WeightedLosses(forecast_length=forecast_length)

    output = torch.Tensor([[1, 3], [1, 3]])
    target = torch.Tensor([[1, 5], [1, 9]])

    loss = w.get_mse_exp(output=output, target=target)

    # 0.5*((1-1)^2*2/3 + (5-3)^2*1/3) + 0.5*((1-1)^2*2/3 + (9-3)^2*1/3) = 2/3 + 18/3
    assert loss == pytest.approx(20 / 3)


def test_mae_exp_rand():
    """Test MAE exp with weighted loss on random tensors"""
    forecast_length = 6
    batch_size = 32

    w = WeightedLosses(forecast_length=forecast_length)

    output = torch.randn(batch_size, forecast_length)
    target = torch.randn(batch_size, forecast_length)

    loss = w.get_mae_exp(output=output, target=target)
    assert loss > 0


def test_mse_exp_rand():
    """Test MSE exp with weighted loss on random tensors"""
    forecast_length = 6
    batch_size = 32

    w = WeightedLosses(forecast_length=forecast_length)

    output = torch.randn(batch_size, forecast_length)
    target = torch.randn(batch_size, forecast_length)

    loss = w.get_mse_exp(output=output, target=target)
    assert loss > 0