diff --git a/pvnet/models/base_model.py b/pvnet/models/base_model.py index 8665bf3c..83e67e7b 100644 --- a/pvnet/models/base_model.py +++ b/pvnet/models/base_model.py @@ -666,7 +666,8 @@ def validation_step(self, batch: dict, batch_idx): # Sensor seems to be in batch, station, time order y = batch[self._target_key][:, -self.forecast_len :, 0] - self._log_validation_results(batch, y_hat, accum_batch_num) + if (batch_idx + 1) % self.trainer.accumulate_grad_batches == 0: + self._log_validation_results(batch, y_hat, accum_batch_num) # Expand persistence to be the same shape as y losses = self._calculate_common_losses(y, y_hat) @@ -743,6 +744,7 @@ def on_validation_epoch_end(self): print("Failed to log validation results to wandb") print(e) + self.validation_epoch_results = [] horizon_maes_dict = self._horizon_maes.flush() # Create the horizon accuracy curve