diff --git a/pvnet/models/base_model.py b/pvnet/models/base_model.py index 7f9f54f0..d9ebb482 100644 --- a/pvnet/models/base_model.py +++ b/pvnet/models/base_model.py @@ -623,18 +623,13 @@ def _log_validation_results(self, batch, y_hat, accum_batch_num): y_i = y[i].detach().cpu().numpy() y_hat_i = y_hat[i].detach().cpu().numpy() - print(BatchKey._member_map_) - try: - time_utc_key = BatchKey[f"{self._target_key}_time_utc"] - except Exception as e: - raise Exception( - f"Failed to find time_utc key for {self._target_key}, {BatchKey._member_map_}, {e}" - ) - + time_utc_key = BatchKey[f"{self._target_key_name}_time_utc"] time_utc = batch[time_utc_key][i, -self.forecast_len :].detach().cpu().numpy() - id_key = BatchKey[f"{self._target_key}_id"] + id_key = BatchKey[f"{self._target_key_name}_id"] target_id = batch[id_key][i].detach().cpu().numpy() + if target_id.ndim > 0: + target_id = target_id[0] results_df = pd.DataFrame( {