Skip to content

Commit

Permalink
Merge pull request #119 from openclimatefix/jacob/pv-generalize
Browse files Browse the repository at this point in the history
Add option for non-GSP forecast
  • Loading branch information
jacobbieker authored Jan 9, 2024
2 parents cbbcaf7 + 9c6917e commit 2c05bc2
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 13 deletions.
9 changes: 6 additions & 3 deletions pvnet/models/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,7 @@ def __init__(
forecast_minutes: int,
optimizer: AbstractOptimizer,
output_quantiles: Optional[list[float]] = None,
target_key: BatchKey = BatchKey.gsp,
):
"""Abtstract base class for PVNet submodels.
Expand All @@ -243,10 +244,12 @@ def __init__(
optimizer (AbstractOptimizer): Optimizer
output_quantiles: A list of float (0.0, 1.0) quantiles to predict values for. If set to
None the output is a single value.
target_key: The key of the target variable in the batch
"""
super().__init__()

self._optimizer = optimizer
self._target_key = target_key

# Model must have lr to allow tuning
# This setting is only used when lr is tuned with callback
Expand Down Expand Up @@ -424,7 +427,7 @@ def _training_accumulate_log(self, batch, batch_idx, losses, y_hat):
def training_step(self, batch, batch_idx):
"""Run training step"""
y_hat = self(batch)
y = batch[BatchKey.gsp][:, -self.forecast_len_30 :, 0]
y = batch[self._target_key][:, -self.forecast_len_30 :, 0]

losses = self._calculate_common_losses(y, y_hat)
losses = {f"{k}/train": v for k, v in losses.items()}
Expand All @@ -440,7 +443,7 @@ def training_step(self, batch, batch_idx):
def validation_step(self, batch: dict, batch_idx):
"""Run validation step"""
y_hat = self(batch)
y = batch[BatchKey.gsp][:, -self.forecast_len_30 :, 0]
y = batch[self._target_key][:, -self.forecast_len_30 :, 0]

losses = self._calculate_common_losses(y, y_hat)
losses.update(self._calculate_val_losses(y, y_hat))
Expand Down Expand Up @@ -484,7 +487,7 @@ def validation_step(self, batch: dict, batch_idx):
def test_step(self, batch, batch_idx):
"""Run test step"""
y_hat = self(batch)
y = batch[BatchKey.gsp][:, -self.forecast_len_30 :, 0]
y = batch[self._target_key][:, -self.forecast_len_30 :, 0]

losses = self._calculate_common_losses(y, y_hat)
losses.update(self._calculate_val_losses(y, y_hat))
Expand Down
20 changes: 17 additions & 3 deletions pvnet/models/multimodal/multimodal.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ def __init__(
nwp_history_minutes: Optional[int] = None,
pv_history_minutes: Optional[int] = None,
optimizer: AbstractOptimizer = pvnet.optimizers.Adam(),
target_key: str = "gsp",
):
"""Neural network which combines information from different sources.
Expand Down Expand Up @@ -94,6 +95,7 @@ def __init__(
pv_history_minutes: Length of recent site-level PV data data used as input. Defaults to
`history_minutes` if not provided.
optimizer: Optimizer factory function used for network.
target_key: The key of the target variable in the batch.
"""

self.include_gsp_yield_history = include_gsp_yield_history
Expand All @@ -103,8 +105,15 @@ def __init__(
self.include_sun = include_sun
self.embedding_dim = embedding_dim
self.add_image_embedding_channel = add_image_embedding_channel

super().__init__(history_minutes, forecast_minutes, optimizer, output_quantiles)
self.target_key_name = target_key

super().__init__(
history_minutes=history_minutes,
forecast_minutes=forecast_minutes,
optimizer=optimizer,
output_quantiles=output_quantiles,
target_key=BatchKey.gsp if target_key == "gsp" else BatchKey.pv,
)

# Number of features expected by the output_network
# Add to this as network pices are constructed
Expand Down Expand Up @@ -228,7 +237,12 @@ def forward(self, x):
# *********************** PV Data *************************************
# Add site-level PV yield
if self.include_pv:
modes["pv"] = self.pv_encoder(x)
if self.target_key_name != "pv":
modes["pv"] = self.pv_encoder(x)
else:
# Target is PV, so only take the history
pv_history = x[BatchKey.pv][:, : self.history_len_30].float()
modes["pv"] = self.pv_encoder(pv_history)

# *********************** GSP Data ************************************
# add gsp yield history
Expand Down
19 changes: 12 additions & 7 deletions pvnet/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,19 +242,24 @@ def finish(
wandb.finish()


def plot_batch_forecasts(batch, y_hat, batch_idx=None, quantiles=None):
def plot_batch_forecasts(batch, y_hat, batch_idx=None, quantiles=None, key_to_plot: str = "gsp"):
"""Plot a batch of data and the forecast from that batch"""

def _get_numpy(key):
return batch[key].cpu().numpy().squeeze()

y = batch[BatchKey.gsp].cpu().numpy()
y_key = BatchKey.gsp if key_to_plot == "gsp" else BatchKey.pv
y_id_key = BatchKey.gsp_id if key_to_plot == "gsp" else BatchKey.pv_id
t0_idx_key = BatchKey.gsp_t0_idx if key_to_plot == "gsp" else BatchKey.pv_t0_idx
time_utc_key = BatchKey.gsp_time_utc if key_to_plot == "gsp" else BatchKey.pv_time_utc
plotting_name = key_to_plot.upper()
y = batch[y_key].cpu().numpy()
y_hat = y_hat.cpu().numpy()

gsp_ids = batch[BatchKey.gsp_id].cpu().numpy().squeeze()
batch[BatchKey.gsp_t0_idx]
gsp_ids = batch[y_id_key].cpu().numpy().squeeze()
batch[t0_idx_key]

times_utc = batch[BatchKey.gsp_time_utc].cpu().numpy().squeeze().astype("datetime64[s]")
times_utc = batch[time_utc_key].cpu().numpy().squeeze().astype("datetime64[s]")
times_utc = [pd.to_datetime(t) for t in times_utc]

batch_size = y.shape[0]
Expand Down Expand Up @@ -294,9 +299,9 @@ def _get_numpy(key):
ax.set_xlabel("Time (hour of day)")

if batch_idx is not None:
title = f"Normed GSP output : batch_idx={batch_idx}"
title = f"Normed {plotting_name} output : batch_idx={batch_idx}"
else:
title = "Normed GSP output"
title = f"Normed {plotting_name} output"
plt.suptitle(title)
plt.tight_layout()

Expand Down

0 comments on commit 2c05bc2

Please sign in to comment.