From 9a0027172225362c0ee02d78aef9c150bca190b0 Mon Sep 17 00:00:00 2001 From: Mike Walmsley Date: Mon, 15 Jan 2024 13:51:02 -0500 Subject: [PATCH] tweak logging --- zoobot/pytorch/estimators/define_model.py | 24 ++++++++++++++--------- 1 file changed, 15 insertions(+), 9 deletions(-) diff --git a/zoobot/pytorch/estimators/define_model.py b/zoobot/pytorch/estimators/define_model.py index 53331422..b5ba1e57 100755 --- a/zoobot/pytorch/estimators/define_model.py +++ b/zoobot/pytorch/estimators/define_model.py @@ -120,12 +120,13 @@ def on_train_epoch_end(self) -> None: # called *after* on_validation_epoch_end, confusingly # do NOT log_all_metrics here. # logging a metric resets it, and on_validation_epoch_end just logged and reset everything, so you will only log nans - pass + self.log_all_metrics(subset='train') def on_validation_epoch_end(self) -> None: - # raise ValueError('val epoch end') - # called at end of val epoch, but BEFORE on_train_epoch_end - self.log_all_metrics() # logs all metrics, so can do in val only + self.log_all_metrics(subset='validation') + + def on_test_epoch_end(self) -> None: + self.log_all_metrics(subset='test') def calculate_loss_and_update_loss_metrics(self, predictions, labels, step_name): raise NotImplementedError('Must be subclassed') @@ -133,11 +134,16 @@ def calculate_loss_and_update_loss_metrics(self, predictions, labels, step_name) def update_other_metrics(self, outputs, step_name): raise NotImplementedError('Must be subclassed') - def log_all_metrics(self): - - self.log_dict(self.loss_metrics, on_epoch=True, on_step=False, prog_bar=True, logger=True) - self.log_dict(self.question_loss_metrics, on_step=False, on_epoch=True, logger=True) - self.log_dict(self.campaign_loss_metrics, on_step=False, on_epoch=True, logger=True) + def log_all_metrics(self, subset=None): + if subset is not None: + for name, metric in self.loss_metrics.items(): + if subset in name: + print('logging', name) + self.log(name, metric, on_epoch=True, on_step=False, prog_bar=True, logger=True) + else: # just log everything + self.log_dict(self.loss_metrics, on_epoch=True, on_step=False, prog_bar=True, logger=True) + self.log_dict(self.question_loss_metrics, on_step=False, on_epoch=True, logger=True) + self.log_dict(self.campaign_loss_metrics, on_step=False, on_epoch=True, logger=True)