Skip to content

Commit

Permalink
fix and add comments
Browse files Browse the repository at this point in the history
  • Loading branch information
peterdudfield committed Sep 6, 2024
1 parent ffe9b15 commit d501107
Showing 1 changed file with 14 additions and 8 deletions.
22 changes: 14 additions & 8 deletions pvnet/models/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -616,31 +616,37 @@ def _log_forecast_plot(self, batch, y_hat, accum_batch_num, timesteps_to_plot, p
def _log_validation_results(self, batch, y_hat, accum_batch_num):
"""Append validation results to self.validation_epoch_results"""

y = batch[self._target_key][:, -self.forecast_len :, 0]
batch_size = y.shape[0]
# get truth values, shape (b, forecast_len)
y = batch[self._target_key][:, -self.forecast_len:, 0]
y = y.detach().cpu().numpy()
batch_size = y.shape[0]

# get truth values, shape (b, forecast_len)
y_hat = y_hat.detach().cpu().numpy()

# get time_utc, shape (b, forecast_len)
time_utc_key = BatchKey[f"{self._target_key_name}_time_utc"]
time_utc = batch[time_utc_key][:, -self.forecast_len:].detach().cpu().numpy()

# get target id and change from (b,1) to (b,)
id_key = BatchKey[f"{self._target_key_name}_id"]
target_id = batch[id_key].detach().cpu().numpy()
target_id = target_id.squeeze()

for i in range(batch_size):
y_i = y[i]
y_hat_i = y_hat[i]
time_utc = time_utc[i]
target_id = target_id[i]
if target_id.ndim > 0:
target_id = target_id[0]
time_utc_i = time_utc[i]
target_id_i = target_id[i]

results_df = pd.DataFrame(
{
"y": y_i,
"y_hat": y_hat_i,
"time_utc": time_utc,
"time_utc": time_utc_i,
}
)
results_df["id"] = target_id
results_df["id"] = target_id_i
results_df["batch_idx"] = accum_batch_num
results_df["example_idx"] = i

Expand Down

0 comments on commit d501107

Please sign in to comment.