diff --git a/pvnet/models/base_model.py b/pvnet/models/base_model.py index d9bcdb3c..a9fca650 100644 --- a/pvnet/models/base_model.py +++ b/pvnet/models/base_model.py @@ -621,7 +621,7 @@ def _log_validation_results(self, batch, y_hat, accum_batch_num): y = y.detach().cpu().numpy() y_hat = y_hat.detach().cpu().numpy() time_utc_key = BatchKey[f"{self._target_key_name}_time_utc"] - time_utc = batch[time_utc_key][:, -self.forecast_len:].detach().cpu().numpy() + time_utc = batch[time_utc_key][:, -self.forecast_len :].detach().cpu().numpy() id_key = BatchKey[f"{self._target_key_name}_id"] target_id = batch[id_key].detach().cpu().numpy()