diff --git a/pvnet/models/base_model.py b/pvnet/models/base_model.py index 46706c00..b4d42f22 100644 --- a/pvnet/models/base_model.py +++ b/pvnet/models/base_model.py @@ -245,7 +245,7 @@ def save_pretrained( class BaseModel(pl.LightningModule, PVNetModelHubMixin): - """Abtstract base class for PVNet submodels""" + """Abstract base class for PVNet submodels""" def __init__( self, @@ -257,6 +257,7 @@ def __init__( interval_minutes: int = 30, timestep_intervals_to_plot: Optional[list[int]] = None, use_weighted_loss: bool = False, + forecast_minutes_ignore: Optional[int] = 0, ): """Abtstract base class for PVNet submodels. @@ -270,6 +271,8 @@ def __init__( interval_minutes: The interval in minutes between each timestep in the data timestep_intervals_to_plot: Intervals, in timesteps, to plot during training use_weighted_loss: Whether to use a weighted loss function + forecast_minutes_ignore: Number of forecast minutes to ignore when calculating losses. + For example if set to 60, the model doesnt predict the first 60 minutes """ super().__init__() @@ -292,10 +295,11 @@ def __init__( self.forecast_minutes = forecast_minutes self.output_quantiles = output_quantiles self.interval_minutes = interval_minutes + self.forecast_minutes_ignore = forecast_minutes_ignore # Number of timestemps for 30 minutely data self.history_len = history_minutes // interval_minutes - self.forecast_len = forecast_minutes // interval_minutes + self.forecast_len = (forecast_minutes - forecast_minutes_ignore) // interval_minutes self.weighted_losses = WeightedLosses(forecast_length=self.forecast_len) @@ -334,7 +338,7 @@ def _quantiles_to_prediction(self, y_quantiles): y_median = y_quantiles[..., idx] return y_median - def _calculate_qauntile_loss(self, y_quantiles, y): + def _calculate_quantile_loss(self, y_quantiles, y): """Calculate quantile loss. Note: diff --git a/tests/conftest.py b/tests/conftest.py index 4ea9e3f1..990f8e3c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -303,3 +303,12 @@ def multimodal_weighted_quantile_model(multimodal_model_kwargs): output_quantiles=[0.1, 0.5, 0.9], **multimodal_model_kwargs, use_weighted_loss=True ) return model + + +@pytest.fixture() +def multimodal_quantile_model_ignore_minutes(multimodal_model_kwargs): + """ Only forecsat second half of the 8 hours""" + model = Model( + output_quantiles=[0.1, 0.5, 0.9], **multimodal_model_kwargs, ignore_minutes=240 + ) + return model diff --git a/tests/models/multimodal/test_multimodal.py b/tests/models/multimodal/test_multimodal.py index 0cc9a7df..c112e6a3 100644 --- a/tests/models/multimodal/test_multimodal.py +++ b/tests/models/multimodal/test_multimodal.py @@ -51,3 +51,15 @@ def test_weighted_quantile_model_backward(multimodal_weighted_quantile_model, sa # Backwards on sum drives sum to zero y_quantiles.sum().backward() + + +def test_weighted_quantile_model_forward(multimodal_quantile_model_ignore_minutes, sample_batch): + y_quantiles = multimodal_quantile_model_ignore_minutes(sample_batch) + + # check output is the correct shape + # batch size=2, forecast_len=8, num_quantiles=3 + assert tuple(y_quantiles.shape) == (2, 8, 3), y_quantiles.shape + + # Backwards on sum drives sum to zero + y_quantiles.sum().backward() +