diff --git a/pvnet/models/base_model.py b/pvnet/models/base_model.py index 6d03cab6..c5c40082 100644 --- a/pvnet/models/base_model.py +++ b/pvnet/models/base_model.py @@ -433,7 +433,7 @@ def _training_accumulate_log(self, batch, batch_idx, losses, y_hat): def training_step(self, batch, batch_idx): """Run training step""" y_hat = self(batch) - y = batch[self._target_key][:,0, -self.forecast_len_30 :] + y = batch[self._target_key][:, 0, -self.forecast_len_30 :] losses = self._calculate_common_losses(y, y_hat) losses = {f"{k}/train": v for k, v in losses.items()} @@ -454,7 +454,7 @@ def validation_step(self, batch: dict, batch_idx): print(f"{batch[self._target_key][:, -self.forecast_len_30 :, 0].shape=}") print(f"{self.forecast_len_30=}") # Sensor seems to be in batch, station, time order - y = batch[self._target_key][:,0, -self.forecast_len_30 :] + y = batch[self._target_key][:, 0, -self.forecast_len_30 :] losses = self._calculate_common_losses(y, y_hat) losses.update(self._calculate_val_losses(y, y_hat))