diff --git a/zoobot/pytorch/estimators/define_model.py b/zoobot/pytorch/estimators/define_model.py index 79aa551a..3471d668 100755 --- a/zoobot/pytorch/estimators/define_model.py +++ b/zoobot/pytorch/estimators/define_model.py @@ -139,10 +139,11 @@ def update_other_metrics(self, outputs, step_name): def log_all_metrics(self, subset=None): if subset is not None: for metric_collection in (self.loss_metrics, self.question_loss_metrics, self.campaign_loss_metrics): + prog_bar = metric_collection == self.loss_metrics for name, metric in metric_collection.items(): if subset in name: logging.info(name) - self.log(name, metric, on_epoch=True, on_step=False, prog_bar=True, logger=True) + self.log(name, metric, on_epoch=True, on_step=False, prog_bar=prog_bar, logger=True) else: # just log everything self.log_dict(self.loss_metrics, on_epoch=True, on_step=False, prog_bar=True, logger=True) self.log_dict(self.question_loss_metrics, on_step=False, on_epoch=True, logger=True)