Skip to content

Commit

Permalink
feat: track step time
Browse files Browse the repository at this point in the history
  • Loading branch information
sehoffmann committed May 10, 2024
1 parent d7c11f1 commit f8ce8e6
Showing 1 changed file with 5 additions and 0 deletions.
5 changes: 5 additions & 0 deletions dmlcloud/stage.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import sys
from datetime import datetime
import time
from typing import Any, Dict, List, Optional, Union

import torch
Expand Down Expand Up @@ -296,12 +297,16 @@ def train_epoch(self):
train_ds.sampler.set_epoch(self.current_epoch)

for batch in train_ds:
step_start_time = time.perf_counter_ns()
self.zero_grad()
loss = self.train_step(batch)
self.optimize(loss)
step_end_time = time.perf_counter_ns()

self.track_reduce(self.loss_metric_name(), loss)
self.track_reduce('misc/total_train_batches', torch.tensor(1), reduction=Reduction.SUM, prefixed=False)
self.track_reduce('misc/worker_train_batches', torch.tensor(1), reduction=Reduction.SUM, reduce_globally=False, prefixed=False)
self.track_reduce('misc/step_time_ms', torch.tensor(step_end_time-step_start_time)/1e6, prefixed=False)

for name, scheduler in self.pipeline.schedulers.items():
self.track(f'misc/lr_{name}', scheduler.get_last_lr()[0], prefixed=False)
Expand Down

0 comments on commit f8ce8e6

Please sign in to comment.