Skip to content

Commit

Permalink
option to ignore first part of forecast in the model
Browse files Browse the repository at this point in the history
  • Loading branch information
peterdudfield committed May 30, 2024
1 parent 80dad59 commit 8adcc9b
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 3 deletions.
10 changes: 7 additions & 3 deletions pvnet/models/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,7 @@ def save_pretrained(


class BaseModel(pl.LightningModule, PVNetModelHubMixin):
"""Abtstract base class for PVNet submodels"""
"""Abstract base class for PVNet submodels"""

def __init__(
self,
Expand All @@ -257,6 +257,7 @@ def __init__(
interval_minutes: int = 30,
timestep_intervals_to_plot: Optional[list[int]] = None,
use_weighted_loss: bool = False,
forecast_minutes_ignore: Optional[int] = 0,
):
"""Abtstract base class for PVNet submodels.
Expand All @@ -270,6 +271,8 @@ def __init__(
interval_minutes: The interval in minutes between each timestep in the data
timestep_intervals_to_plot: Intervals, in timesteps, to plot during training
use_weighted_loss: Whether to use a weighted loss function
forecast_minutes_ignore: Number of forecast minutes to ignore when calculating losses.
For example if set to 60, the model doesnt predict the first 60 minutes
"""
super().__init__()

Expand All @@ -292,10 +295,11 @@ def __init__(
self.forecast_minutes = forecast_minutes
self.output_quantiles = output_quantiles
self.interval_minutes = interval_minutes
self.forecast_minutes_ignore = forecast_minutes_ignore

# Number of timestemps for 30 minutely data
self.history_len = history_minutes // interval_minutes
self.forecast_len = forecast_minutes // interval_minutes
self.forecast_len = (forecast_minutes - forecast_minutes_ignore) // interval_minutes

self.weighted_losses = WeightedLosses(forecast_length=self.forecast_len)

Expand Down Expand Up @@ -334,7 +338,7 @@ def _quantiles_to_prediction(self, y_quantiles):
y_median = y_quantiles[..., idx]
return y_median

def _calculate_qauntile_loss(self, y_quantiles, y):
def _calculate_quantile_loss(self, y_quantiles, y):
"""Calculate quantile loss.
Note:
Expand Down
9 changes: 9 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,3 +303,12 @@ def multimodal_weighted_quantile_model(multimodal_model_kwargs):
output_quantiles=[0.1, 0.5, 0.9], **multimodal_model_kwargs, use_weighted_loss=True
)
return model


@pytest.fixture()
def multimodal_quantile_model_ignore_minutes(multimodal_model_kwargs):
""" Only forecsat second half of the 8 hours"""
model = Model(
output_quantiles=[0.1, 0.5, 0.9], **multimodal_model_kwargs, ignore_minutes=240
)
return model
12 changes: 12 additions & 0 deletions tests/models/multimodal/test_multimodal.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,3 +51,15 @@ def test_weighted_quantile_model_backward(multimodal_weighted_quantile_model, sa

# Backwards on sum drives sum to zero
y_quantiles.sum().backward()


def test_weighted_quantile_model_forward(multimodal_quantile_model_ignore_minutes, sample_batch):
y_quantiles = multimodal_quantile_model_ignore_minutes(sample_batch)

# check output is the correct shape
# batch size=2, forecast_len=8, num_quantiles=3
assert tuple(y_quantiles.shape) == (2, 8, 3), y_quantiles.shape

# Backwards on sum drives sum to zero
y_quantiles.sum().backward()

0 comments on commit 8adcc9b

Please sign in to comment.