Skip to content

Commit

Permalink
feat: track number of batches
Browse files Browse the repository at this point in the history
  • Loading branch information
sehoffmann committed May 10, 2024
1 parent a9ce5ac commit 78755d7
Showing 1 changed file with 5 additions and 0 deletions.
5 changes: 5 additions & 0 deletions dmlcloud/stage.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,8 @@ def train_epoch(self):
loss = self.train_step(batch)
self.optimize(loss)
self.track_reduce(self.loss_metric_name(), loss)
self.track_reduce('misc/total_train_batches', torch.tensor(1), reduction=Reduction.SUM, prefixed=False)
self.track_reduce('misc/worker_train_batches', torch.tensor(1), reduction=Reduction.SUM, reduce_globally=False, prefixed=False)

for name, scheduler in self.pipeline.schedulers.items():
self.track(f'misc/lr_{name}', scheduler.get_last_lr()[0], prefixed=False)
Expand All @@ -314,6 +316,9 @@ def val_epoch(self):
for batch in self.val_dataset():
loss = self.val_step(batch)
self.track_reduce('loss', loss)
self.track_reduce('misc/total_val_batches', torch.tensor(1), reduction=Reduction.SUM, prefixed=False)
self.track_reduce('misc/worker_val_batches', torch.tensor(1), reduction=Reduction.SUM, reduce_globally=False, prefixed=False)


def table_columns(self):
columns = super().table_columns()
Expand Down

0 comments on commit 78755d7

Please sign in to comment.