From f8ce8e63e4080078bcb84b753b02dbea520ac215 Mon Sep 17 00:00:00 2001 From: Sebastian Hoffmann Date: Fri, 10 May 2024 20:13:00 +0200 Subject: [PATCH] feat: track step time --- dmlcloud/stage.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/dmlcloud/stage.py b/dmlcloud/stage.py index 50981d7..41079d6 100644 --- a/dmlcloud/stage.py +++ b/dmlcloud/stage.py @@ -1,5 +1,6 @@ import sys from datetime import datetime +import time from typing import Any, Dict, List, Optional, Union import torch @@ -296,12 +297,16 @@ def train_epoch(self): train_ds.sampler.set_epoch(self.current_epoch) for batch in train_ds: + step_start_time = time.perf_counter_ns() self.zero_grad() loss = self.train_step(batch) self.optimize(loss) + step_end_time = time.perf_counter_ns() + self.track_reduce(self.loss_metric_name(), loss) self.track_reduce('misc/total_train_batches', torch.tensor(1), reduction=Reduction.SUM, prefixed=False) self.track_reduce('misc/worker_train_batches', torch.tensor(1), reduction=Reduction.SUM, reduce_globally=False, prefixed=False) + self.track_reduce('misc/step_time_ms', torch.tensor(step_end_time-step_start_time)/1e6, prefixed=False) for name, scheduler in self.pipeline.schedulers.items(): self.track(f'misc/lr_{name}', scheduler.get_last_lr()[0], prefixed=False)