diff --git a/pvnet/models/base_model.py b/pvnet/models/base_model.py index 5cece4fc..29b909b2 100644 --- a/pvnet/models/base_model.py +++ b/pvnet/models/base_model.py @@ -617,7 +617,7 @@ def _log_validation_results(self, batch, y_hat, accum_batch_num): """Append validation results to self.validation_epoch_results""" # get truth values, shape (b, forecast_len) - y = batch[self._target_key][:, -self.forecast_len:, 0] + y = batch[self._target_key][:, -self.forecast_len :, 0] y = y.detach().cpu().numpy() batch_size = y.shape[0] @@ -626,7 +626,7 @@ def _log_validation_results(self, batch, y_hat, accum_batch_num): # get time_utc, shape (b, forecast_len) time_utc_key = BatchKey[f"{self._target_key_name}_time_utc"] - time_utc = batch[time_utc_key][:, -self.forecast_len:].detach().cpu().numpy() + time_utc = batch[time_utc_key][:, -self.forecast_len :].detach().cpu().numpy() # get target id and change from (b,1) to (b,) id_key = BatchKey[f"{self._target_key_name}_id"]