From 5360394e3c46bea4e9cacc5d1c5e34d63db6408e Mon Sep 17 00:00:00 2001 From: Alexandra Udaltsova <43303448+AUdaltsova@users.noreply.github.com> Date: Fri, 25 Oct 2024 13:50:21 +0100 Subject: [PATCH] add flushing of val epoch resluts (#256) * add flushing of val epoch resluts * add docstring back base_model.py * log to csv only when end of accumbatch --- pvnet/models/base_model.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/pvnet/models/base_model.py b/pvnet/models/base_model.py index 8665bf3c..83e67e7b 100644 --- a/pvnet/models/base_model.py +++ b/pvnet/models/base_model.py @@ -666,7 +666,8 @@ def validation_step(self, batch: dict, batch_idx): # Sensor seems to be in batch, station, time order y = batch[self._target_key][:, -self.forecast_len :, 0] - self._log_validation_results(batch, y_hat, accum_batch_num) + if (batch_idx + 1) % self.trainer.accumulate_grad_batches == 0: + self._log_validation_results(batch, y_hat, accum_batch_num) # Expand persistence to be the same shape as y losses = self._calculate_common_losses(y, y_hat) @@ -743,6 +744,7 @@ def on_validation_epoch_end(self): print("Failed to log validation results to wandb") print(e) + self.validation_epoch_results = [] horizon_maes_dict = self._horizon_maes.flush() # Create the horizon accuracy curve