Skip to content

Commit

Permalink
feat: upgraded progress-table to v2; closes #3
Browse files Browse the repository at this point in the history
  • Loading branch information
sehoffmann committed Oct 23, 2024
1 parent a78a471 commit 2d8437c
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 11 deletions.
17 changes: 7 additions & 10 deletions dmlcloud/stage.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
from progress_table import ProgressTable

from .metrics import MetricTracker, Reduction
from .util.distributed import is_root, root_only
from .util.distributed import is_root
from .util.logging import DevNullIO, flush_log_handlers


class Stage:
Expand Down Expand Up @@ -139,22 +140,19 @@ def run(self):

def _pre_stage(self):
    """
    Prepares the stage before its first epoch.

    Records the start time, creates the progress table and registers its columns, logs a
    stage banner when the pipeline has more than one stage, runs the user-defined
    ``pre_stage()`` hook, flushes the logger's handlers and then calls the pipeline barrier.
    """
    self.start_time = datetime.now()

    # is_root is a function and has to be called: the bare function object is always truthy,
    # which would hand sys.stdout to every rank. Non-root ranks get a sink that discards output.
    table_file = sys.stdout if is_root() else DevNullIO()
    self.table = ProgressTable(file=table_file)
    self._setup_table()

    if len(self.pipeline.stages) > 1:
        self.logger.info(f'\n========== STAGE: {self.name} ==========')

    self.pre_stage()

    # Flush pending log output before calling the barrier.
    flush_log_handlers(self.logger)

    self.pipeline.barrier(self.barrier_timeout)

def _post_stage(self):
if is_root():
self.table.close()
self.table.close()
self.post_stage()
self.pipeline.barrier(self.barrier_timeout)
self.stop_time = datetime.now()
Expand All @@ -163,6 +161,7 @@ def _post_stage(self):

def _pre_epoch(self):
self.epoch_start_time = datetime.now()
self.table['Epoch'] = self.current_epoch
self.pre_epoch()
self.pipeline._pre_epoch()

Expand All @@ -182,14 +181,12 @@ def _reduce_metrics(self):
self.tracker.next_epoch()
pass

@root_only
def _setup_table(self):
    """
    Registers one progress-table column per entry returned by ``self._metrics()``.

    Each entry is a dict: its ``'name'`` value becomes the column name, its ``'metric'`` key is
    dropped, and every remaining key is forwarded to ``self.table.add_column`` as a keyword
    argument. The dicts are modified in place.
    """
    for spec in self._metrics():
        title = spec.pop('name')
        # 'metric' is not a column option, so it must not reach add_column.
        del spec['metric']
        self.table.add_column(title, **spec)

@root_only
def _update_table(self):
self.table.update('Epoch', self.current_epoch)
self.table.update('Time/Epoch', (datetime.now() - self.start_time) / self.current_epoch)
Expand Down
18 changes: 18 additions & 0 deletions dmlcloud/util/logging.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import io
import logging
import os
import subprocess
Expand Down Expand Up @@ -80,6 +81,15 @@ def __exit__(self, exc_type, exc_value, traceback):
self.uninstall()


class DevNullIO(io.TextIOBase):
    """
    Dummy TextIOBase that silently discards anything written to it, similar to /dev/null.

    The stream reports itself as writable and ``write`` returns the length of the text it was
    given, so callers that check ``writable()`` or the return value of ``write`` see the same
    behaviour as with a real text stream.
    """

    def writable(self):
        """Returns True; the IOBase default would report the stream as not writable."""
        return True

    def write(self, msg):
        """
        Discards ``msg``.

        Returns the number of characters in ``msg`` as if all of it had been written. Non-string
        values are ignored as well and count as zero characters.
        """
        return len(msg) if isinstance(msg, str) else 0


def add_log_handlers(logger: logging.Logger):
if logger.hasHandlers():
return
Expand All @@ -98,6 +108,14 @@ def add_log_handlers(logger: logging.Logger):
logger.addHandler(stderr_handler)


def flush_log_handlers(logger: logging.Logger):
    """
    Calls ``flush()`` on every handler in ``logger.handlers``, in order.

    Only the handlers attached to the given logger itself are visited.
    """
    attached_handlers = iter(logger.handlers)
    for attached in attached_handlers:
        attached.flush()


def experiment_header(
name: str | None,
checkpoint_dir: str | None,
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
torch
numpy
xarray
progress_table>=0.1.20,<1.0.0
progress_table>=2.2.0
omegaconf

0 comments on commit 2d8437c

Please sign in to comment.