From 78755d7890f22525ecc78122e78a9f731be27dfc Mon Sep 17 00:00:00 2001
From: Sebastian Hoffmann
Date: Fri, 10 May 2024 18:28:43 +0200
Subject: [PATCH] feat: track number of batches

---
 dmlcloud/stage.py | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/dmlcloud/stage.py b/dmlcloud/stage.py
index 6aa6ce6..50981d7 100644
--- a/dmlcloud/stage.py
+++ b/dmlcloud/stage.py
@@ -300,6 +300,8 @@ def train_epoch(self):
             loss = self.train_step(batch)
             self.optimize(loss)
             self.track_reduce(self.loss_metric_name(), loss)
+            self.track_reduce('misc/total_train_batches', torch.tensor(1), reduction=Reduction.SUM, prefixed=False)
+            self.track_reduce('misc/worker_train_batches', torch.tensor(1), reduction=Reduction.SUM, reduce_globally=False, prefixed=False)
 
         for name, scheduler in self.pipeline.schedulers.items():
             self.track(f'misc/lr_{name}', scheduler.get_last_lr()[0], prefixed=False)
@@ -314,6 +316,9 @@ def val_epoch(self):
         for batch in self.val_dataset():
             loss = self.val_step(batch)
             self.track_reduce('loss', loss)
+            self.track_reduce('misc/total_val_batches', torch.tensor(1), reduction=Reduction.SUM, prefixed=False)
+            self.track_reduce('misc/worker_val_batches', torch.tensor(1), reduction=Reduction.SUM, reduce_globally=False, prefixed=False)
+
 
     def table_columns(self):
         columns = super().table_columns()