From d501107e10ecf6a339cd739162f673c9788a88e6 Mon Sep 17 00:00:00 2001 From: peterdudfield Date: Fri, 6 Sep 2024 08:45:39 +0100 Subject: [PATCH] fix and add comments --- pvnet/models/base_model.py | 22 ++++++++++++++-------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/pvnet/models/base_model.py b/pvnet/models/base_model.py index d9bcdb3c..5cece4fc 100644 --- a/pvnet/models/base_model.py +++ b/pvnet/models/base_model.py @@ -616,31 +616,37 @@ def _log_forecast_plot(self, batch, y_hat, accum_batch_num, timesteps_to_plot, p def _log_validation_results(self, batch, y_hat, accum_batch_num): """Append validation results to self.validation_epoch_results""" - y = batch[self._target_key][:, -self.forecast_len :, 0] - batch_size = y.shape[0] + # get truth values, shape (b, forecast_len) + y = batch[self._target_key][:, -self.forecast_len:, 0] y = y.detach().cpu().numpy() + batch_size = y.shape[0] + + # get truth values, shape (b, forecast_len) y_hat = y_hat.detach().cpu().numpy() + + # get time_utc, shape (b, forecast_len) time_utc_key = BatchKey[f"{self._target_key_name}_time_utc"] time_utc = batch[time_utc_key][:, -self.forecast_len:].detach().cpu().numpy() + + # get target id and change from (b,1) to (b,) id_key = BatchKey[f"{self._target_key_name}_id"] target_id = batch[id_key].detach().cpu().numpy() + target_id = target_id.squeeze() for i in range(batch_size): y_i = y[i] y_hat_i = y_hat[i] - time_utc = time_utc[i] - target_id = target_id[i] - if target_id.ndim > 0: - target_id = target_id[0] + time_utc_i = time_utc[i] + target_id_i = target_id[i] results_df = pd.DataFrame( { "y": y_i, "y_hat": y_hat_i, - "time_utc": time_utc, + "time_utc": time_utc_i, } ) - results_df["id"] = target_id + results_df["id"] = target_id_i results_df["batch_idx"] = accum_batch_num results_df["example_idx"] = i