From c0b2e41d286b0fd04e4f6ab26cb7deb171a7ff8e Mon Sep 17 00:00:00 2001 From: Mike Walmsley Date: Mon, 15 Jan 2024 13:24:57 -0500 Subject: [PATCH] add test metrics --- zoobot/pytorch/estimators/define_model.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/zoobot/pytorch/estimators/define_model.py b/zoobot/pytorch/estimators/define_model.py index 6e958819..53331422 100755 --- a/zoobot/pytorch/estimators/define_model.py +++ b/zoobot/pytorch/estimators/define_model.py @@ -64,11 +64,12 @@ def setup_metrics(self, nan_strategy='error'): # may sometimes want to ignore n self.loss_metrics = torch.nn.ModuleDict({ 'train/supervised_loss': torchmetrics.MeanMetric(nan_strategy=nan_strategy), 'validation/supervised_loss': torchmetrics.MeanMetric(nan_strategy=nan_strategy), + 'test/supervised_loss': torchmetrics.MeanMetric(nan_strategy=nan_strategy), }) # TODO handle when schema doesn't exist question_metric_dict = {} - for step_name in ['train', 'validation']: # TODO test + for step_name in ['train', 'validation', 'test']: question_metric_dict.update({ step_name + '/question_loss/' + question.text: torchmetrics.MeanMetric(nan_strategy='ignore') for question in self.schema.questions @@ -77,7 +78,7 @@ def setup_metrics(self, nan_strategy='error'): # may sometimes want to ignore n campaigns = schema_to_campaigns(self.schema) campaign_metric_dict = {} - for step_name in ['train', 'validation']: + for step_name in ['train', 'validation', 'test']: campaign_metric_dict.update({ step_name + '/campaign_loss/' + campaign: torchmetrics.MeanMetric(nan_strategy='ignore') for campaign in campaigns