From 7c56ad3bd3c9ef394e45504c972e43bcec8ac115 Mon Sep 17 00:00:00 2001 From: Jacob Bieker Date: Tue, 9 Jan 2024 10:56:41 +0000 Subject: [PATCH 1/5] Add option for non-GSP forecast --- pvnet/models/base_model.py | 8 +++++--- pvnet/models/multimodal/multimodal.py | 20 +++++++++++++++++--- pvnet/utils.py | 19 ++++++++++++------- 3 files changed, 34 insertions(+), 13 deletions(-) diff --git a/pvnet/models/base_model.py b/pvnet/models/base_model.py index b056a90c..5cdae844 100644 --- a/pvnet/models/base_model.py +++ b/pvnet/models/base_model.py @@ -234,6 +234,7 @@ def __init__( forecast_minutes: int, optimizer: AbstractOptimizer, output_quantiles: Optional[list[float]] = None, + target_key: BatchKey = BatchKey.gsp, ): """Abtstract base class for PVNet submodels. @@ -247,6 +248,7 @@ def __init__( super().__init__() self._optimizer = optimizer + self._target_key = target_key # Model must have lr to allow tuning # This setting is only used when lr is tuned with callback @@ -424,7 +426,7 @@ def _training_accumulate_log(self, batch, batch_idx, losses, y_hat): def training_step(self, batch, batch_idx): """Run training step""" y_hat = self(batch) - y = batch[BatchKey.gsp][:, -self.forecast_len_30 :, 0] + y = batch[self._target_key][:, -self.forecast_len_30 :, 0] losses = self._calculate_common_losses(y, y_hat) losses = {f"{k}/train": v for k, v in losses.items()} @@ -440,7 +442,7 @@ def training_step(self, batch, batch_idx): def validation_step(self, batch: dict, batch_idx): """Run validation step""" y_hat = self(batch) - y = batch[BatchKey.gsp][:, -self.forecast_len_30 :, 0] + y = batch[self._target_key][:, -self.forecast_len_30 :, 0] losses = self._calculate_common_losses(y, y_hat) losses.update(self._calculate_val_losses(y, y_hat)) @@ -484,7 +486,7 @@ def validation_step(self, batch: dict, batch_idx): def test_step(self, batch, batch_idx): """Run test step""" y_hat = self(batch) - y = batch[BatchKey.gsp][:, -self.forecast_len_30 :, 0] + y = batch[self._target_key][:, -self.forecast_len_30 :, 0] losses = self._calculate_common_losses(y, y_hat) losses.update(self._calculate_val_losses(y, y_hat)) diff --git a/pvnet/models/multimodal/multimodal.py b/pvnet/models/multimodal/multimodal.py index e711c9af..bdf2b339 100644 --- a/pvnet/models/multimodal/multimodal.py +++ b/pvnet/models/multimodal/multimodal.py @@ -54,6 +54,7 @@ def __init__( nwp_history_minutes: Optional[int] = None, pv_history_minutes: Optional[int] = None, optimizer: AbstractOptimizer = pvnet.optimizers.Adam(), + target_key: str = "gsp", ): """Neural network which combines information from different sources. @@ -94,6 +95,7 @@ def __init__( pv_history_minutes: Length of recent site-level PV data data used as input. Defaults to `history_minutes` if not provided. optimizer: Optimizer factory function used for network. + target_key: The key of the target variable in the batch. """ self.include_gsp_yield_history = include_gsp_yield_history @@ -103,8 +105,15 @@ def __init__( self.include_sun = include_sun self.embedding_dim = embedding_dim self.add_image_embedding_channel = add_image_embedding_channel - - super().__init__(history_minutes, forecast_minutes, optimizer, output_quantiles) + self.target_key_name = target_key + + super().__init__( + history_minutes=history_minutes, + forecast_minutes=forecast_minutes, + optimizer=optimizer, + output_quantiles=output_quantiles, + target_key=BatchKey.gsp if target_key == "gsp" else BatchKey.pv, + ) # Number of features expected by the output_network # Add to this as network pices are constructed @@ -228,7 +237,12 @@ def forward(self, x): # *********************** PV Data ************************************* # Add site-level PV yield if self.include_pv: - modes["pv"] = self.pv_encoder(x) + if self.target_key_name != "pv": + modes["pv"] = self.pv_encoder(x) + else: + # Target is PV, so only take the history + pv_history = x[BatchKey.pv][:, : self.history_len_30].float() + modes["pv"] = self.pv_encoder(pv_history) # *********************** GSP Data ************************************ # add gsp yield history diff --git a/pvnet/utils.py b/pvnet/utils.py index 33e26372..efe58fd8 100644 --- a/pvnet/utils.py +++ b/pvnet/utils.py @@ -242,19 +242,24 @@ def finish( wandb.finish() -def plot_batch_forecasts(batch, y_hat, batch_idx=None, quantiles=None): +def plot_batch_forecasts(batch, y_hat, batch_idx=None, quantiles=None, key_to_plot: str = "gsp"): """Plot a batch of data and the forecast from that batch""" def _get_numpy(key): return batch[key].cpu().numpy().squeeze() - y = batch[BatchKey.gsp].cpu().numpy() + y_key = BatchKey.gsp if key_to_plot == "gsp" else BatchKey.wind + y_id_key = BatchKey.gsp_id if key_to_plot == "gsp" else BatchKey.wind_id + t0_idx_key = BatchKey.gsp_t0_idx if key_to_plot == "gsp" else BatchKey.wind_t0_idx + time_utc_key = BatchKey.gsp_time_utc if key_to_plot == "gsp" else BatchKey.wind_time_utc + plotting_name = key_to_plot.upper() + y = batch[y_key].cpu().numpy() y_hat = y_hat.cpu().numpy() - gsp_ids = batch[BatchKey.gsp_id].cpu().numpy().squeeze() - t0_idx = batch[BatchKey.gsp_t0_idx] + gsp_ids = batch[y_id_key].cpu().numpy().squeeze() + t0_idx = batch[t0_idx_key] - times_utc = batch[BatchKey.gsp_time_utc].cpu().numpy().squeeze().astype("datetime64[s]") + times_utc = batch[time_utc_key].cpu().numpy().squeeze().astype("datetime64[s]") times_utc = [pd.to_datetime(t) for t in times_utc] len(times_utc[0]) - t0_idx - 1 @@ -295,9 +300,9 @@ def _get_numpy(key): ax.set_xlabel("Time (hour of day)") if batch_idx is not None: - title = f"Normed GSP output : batch_idx={batch_idx}" + title = f"Normed {plotting_name} output : batch_idx={batch_idx}" else: - title = "Normed GSP output" + title = f"Normed {plotting_name} output" plt.suptitle(title) plt.tight_layout() From b8695a6b673bb2909d2e5540d8cbb2aa9dc6c5c6 Mon Sep 17 00:00:00 2001 From: Jacob Bieker Date: Tue, 9 Jan 2024 10:56:41 +0000 Subject: [PATCH 2/5] Add option for non-GSP forecast --- pvnet/models/base_model.py | 8 +++++--- pvnet/models/multimodal/multimodal.py | 20 +++++++++++++++++--- pvnet/utils.py | 19 ++++++++++++------- 3 files changed, 34 insertions(+), 13 deletions(-) diff --git a/pvnet/models/base_model.py b/pvnet/models/base_model.py index b056a90c..5cdae844 100644 --- a/pvnet/models/base_model.py +++ b/pvnet/models/base_model.py @@ -234,6 +234,7 @@ def __init__( forecast_minutes: int, optimizer: AbstractOptimizer, output_quantiles: Optional[list[float]] = None, + target_key: BatchKey = BatchKey.gsp, ): """Abtstract base class for PVNet submodels. @@ -247,6 +248,7 @@ def __init__( super().__init__() self._optimizer = optimizer + self._target_key = target_key # Model must have lr to allow tuning # This setting is only used when lr is tuned with callback @@ -424,7 +426,7 @@ def _training_accumulate_log(self, batch, batch_idx, losses, y_hat): def training_step(self, batch, batch_idx): """Run training step""" y_hat = self(batch) - y = batch[BatchKey.gsp][:, -self.forecast_len_30 :, 0] + y = batch[self._target_key][:, -self.forecast_len_30 :, 0] losses = self._calculate_common_losses(y, y_hat) losses = {f"{k}/train": v for k, v in losses.items()} @@ -440,7 +442,7 @@ def training_step(self, batch, batch_idx): def validation_step(self, batch: dict, batch_idx): """Run validation step""" y_hat = self(batch) - y = batch[BatchKey.gsp][:, -self.forecast_len_30 :, 0] + y = batch[self._target_key][:, -self.forecast_len_30 :, 0] losses = self._calculate_common_losses(y, y_hat) losses.update(self._calculate_val_losses(y, y_hat)) @@ -484,7 +486,7 @@ def validation_step(self, batch: dict, batch_idx): def test_step(self, batch, batch_idx): """Run test step""" y_hat = self(batch) - y = batch[BatchKey.gsp][:, -self.forecast_len_30 :, 0] + y = batch[self._target_key][:, -self.forecast_len_30 :, 0] losses = self._calculate_common_losses(y, y_hat) losses.update(self._calculate_val_losses(y, y_hat)) diff --git a/pvnet/models/multimodal/multimodal.py b/pvnet/models/multimodal/multimodal.py index e711c9af..bdf2b339 100644 --- a/pvnet/models/multimodal/multimodal.py +++ b/pvnet/models/multimodal/multimodal.py @@ -54,6 +54,7 @@ def __init__( nwp_history_minutes: Optional[int] = None, pv_history_minutes: Optional[int] = None, optimizer: AbstractOptimizer = pvnet.optimizers.Adam(), + target_key: str = "gsp", ): """Neural network which combines information from different sources. @@ -94,6 +95,7 @@ def __init__( pv_history_minutes: Length of recent site-level PV data data used as input. Defaults to `history_minutes` if not provided. optimizer: Optimizer factory function used for network. + target_key: The key of the target variable in the batch. """ self.include_gsp_yield_history = include_gsp_yield_history @@ -103,8 +105,15 @@ def __init__( self.include_sun = include_sun self.embedding_dim = embedding_dim self.add_image_embedding_channel = add_image_embedding_channel - - super().__init__(history_minutes, forecast_minutes, optimizer, output_quantiles) + self.target_key_name = target_key + + super().__init__( + history_minutes=history_minutes, + forecast_minutes=forecast_minutes, + optimizer=optimizer, + output_quantiles=output_quantiles, + target_key=BatchKey.gsp if target_key == "gsp" else BatchKey.pv, + ) # Number of features expected by the output_network # Add to this as network pices are constructed @@ -228,7 +237,12 @@ def forward(self, x): # *********************** PV Data ************************************* # Add site-level PV yield if self.include_pv: - modes["pv"] = self.pv_encoder(x) + if self.target_key_name != "pv": + modes["pv"] = self.pv_encoder(x) + else: + # Target is PV, so only take the history + pv_history = x[BatchKey.pv][:, : self.history_len_30].float() + modes["pv"] = self.pv_encoder(pv_history) # *********************** GSP Data ************************************ # add gsp yield history diff --git a/pvnet/utils.py b/pvnet/utils.py index 33e26372..efe58fd8 100644 --- a/pvnet/utils.py +++ b/pvnet/utils.py @@ -242,19 +242,24 @@ def finish( wandb.finish() -def plot_batch_forecasts(batch, y_hat, batch_idx=None, quantiles=None): +def plot_batch_forecasts(batch, y_hat, batch_idx=None, quantiles=None, key_to_plot: str = "gsp"): """Plot a batch of data and the forecast from that batch""" def _get_numpy(key): return batch[key].cpu().numpy().squeeze() - y = batch[BatchKey.gsp].cpu().numpy() + y_key = BatchKey.gsp if key_to_plot == "gsp" else BatchKey.wind + y_id_key = BatchKey.gsp_id if key_to_plot == "gsp" else BatchKey.wind_id + t0_idx_key = BatchKey.gsp_t0_idx if key_to_plot == "gsp" else BatchKey.wind_t0_idx + time_utc_key = BatchKey.gsp_time_utc if key_to_plot == "gsp" else BatchKey.wind_time_utc + plotting_name = key_to_plot.upper() + y = batch[y_key].cpu().numpy() y_hat = y_hat.cpu().numpy() - gsp_ids = batch[BatchKey.gsp_id].cpu().numpy().squeeze() - t0_idx = batch[BatchKey.gsp_t0_idx] + gsp_ids = batch[y_id_key].cpu().numpy().squeeze() + t0_idx = batch[t0_idx_key] - times_utc = batch[BatchKey.gsp_time_utc].cpu().numpy().squeeze().astype("datetime64[s]") + times_utc = batch[time_utc_key].cpu().numpy().squeeze().astype("datetime64[s]") times_utc = [pd.to_datetime(t) for t in times_utc] len(times_utc[0]) - t0_idx - 1 @@ -295,9 +300,9 @@ def _get_numpy(key): ax.set_xlabel("Time (hour of day)") if batch_idx is not None: - title = f"Normed GSP output : batch_idx={batch_idx}" + title = f"Normed {plotting_name} output : batch_idx={batch_idx}" else: - title = "Normed GSP output" + title = f"Normed {plotting_name} output" plt.suptitle(title) plt.tight_layout() From 6f4320be117dba7b9252681fc3722b57bc1b1efc Mon Sep 17 00:00:00 2001 From: Jacob Bieker Date: Tue, 9 Jan 2024 11:02:26 +0000 Subject: [PATCH 3/5] Update PV keys --- pvnet/utils.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/pvnet/utils.py b/pvnet/utils.py index efe58fd8..7ce26d50 100644 --- a/pvnet/utils.py +++ b/pvnet/utils.py @@ -248,10 +248,10 @@ def plot_batch_forecasts(batch, y_hat, batch_idx=None, quantiles=None, key_to_pl def _get_numpy(key): return batch[key].cpu().numpy().squeeze() - y_key = BatchKey.gsp if key_to_plot == "gsp" else BatchKey.wind - y_id_key = BatchKey.gsp_id if key_to_plot == "gsp" else BatchKey.wind_id - t0_idx_key = BatchKey.gsp_t0_idx if key_to_plot == "gsp" else BatchKey.wind_t0_idx - time_utc_key = BatchKey.gsp_time_utc if key_to_plot == "gsp" else BatchKey.wind_time_utc + y_key = BatchKey.gsp if key_to_plot == "gsp" else BatchKey.pv + y_id_key = BatchKey.gsp_id if key_to_plot == "gsp" else BatchKey.pv_id + t0_idx_key = BatchKey.gsp_t0_idx if key_to_plot == "gsp" else BatchKey.pv_t0_idx + time_utc_key = BatchKey.gsp_time_utc if key_to_plot == "gsp" else BatchKey.pv_time_utc plotting_name = key_to_plot.upper() y = batch[y_key].cpu().numpy() y_hat = y_hat.cpu().numpy() From e86ea9f87a5b8d3ab006193d5cd3fca8c2140f57 Mon Sep 17 00:00:00 2001 From: Jacob Bieker Date: Tue, 9 Jan 2024 11:04:26 +0000 Subject: [PATCH 4/5] Lint fix --- pvnet/models/base_model.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pvnet/models/base_model.py b/pvnet/models/base_model.py index 5cdae844..4c665a50 100644 --- a/pvnet/models/base_model.py +++ b/pvnet/models/base_model.py @@ -244,6 +244,7 @@ def __init__( optimizer (AbstractOptimizer): Optimizer output_quantiles: A list of float (0.0, 1.0) quantiles to predict values for. If set to None the output is a single value. + target_key: The key of the target variable in the batch """ super().__init__() From 9c6917e3ef75a2e1306681c8f0f07a6f4a975d59 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 9 Jan 2024 12:06:46 +0000 Subject: [PATCH 5/5] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pvnet/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pvnet/utils.py b/pvnet/utils.py index cbcf04ec..a88fae08 100644 --- a/pvnet/utils.py +++ b/pvnet/utils.py @@ -257,7 +257,7 @@ def _get_numpy(key): y_hat = y_hat.cpu().numpy() gsp_ids = batch[y_id_key].cpu().numpy().squeeze() - t0_idx = batch[t0_idx_key] + batch[t0_idx_key] times_utc = batch[time_utc_key].cpu().numpy().squeeze().astype("datetime64[s]") times_utc = [pd.to_datetime(t) for t in times_utc]