From 2d8437cb45bd2d805a75ef7ac64914fd4a748c40 Mon Sep 17 00:00:00 2001 From: Sebastian Hoffmann Date: Wed, 23 Oct 2024 16:14:13 +0200 Subject: [PATCH] feat: upgraded progress-table to v2; closes #3 --- dmlcloud/stage.py | 17 +++++++---------- dmlcloud/util/logging.py | 18 ++++++++++++++++++ requirements.txt | 2 +- 3 files changed, 26 insertions(+), 11 deletions(-) diff --git a/dmlcloud/stage.py b/dmlcloud/stage.py index 5f49d80..db5eab1 100644 --- a/dmlcloud/stage.py +++ b/dmlcloud/stage.py @@ -7,7 +7,8 @@ from progress_table import ProgressTable from .metrics import MetricTracker, Reduction -from .util.distributed import is_root, root_only +from .util.distributed import is_root +from .util.logging import DevNullIO, flush_log_handlers class Stage: @@ -139,22 +140,19 @@ def run(self): def _pre_stage(self): self.start_time = datetime.now() - self.table = ProgressTable(file=sys.stdout) + self.table = ProgressTable(file=sys.stdout if is_root else DevNullIO()) self._setup_table() if len(self.pipeline.stages) > 1: self.logger.info(f'\n========== STAGE: {self.name} ==========') self.pre_stage() - for handler in self.logger.handlers: - handler.flush() - if is_root(): - self.table._print_header() + flush_log_handlers(self.logger) + self.pipeline.barrier(self.barrier_timeout) def _post_stage(self): - if is_root(): - self.table.close() + self.table.close() self.post_stage() self.pipeline.barrier(self.barrier_timeout) self.stop_time = datetime.now() @@ -163,6 +161,7 @@ def _post_stage(self): def _pre_epoch(self): self.epoch_start_time = datetime.now() + self.table['Epoch'] = self.current_epoch self.pre_epoch() self.pipeline._pre_epoch() @@ -182,14 +181,12 @@ def _reduce_metrics(self): self.tracker.next_epoch() pass - @root_only def _setup_table(self): for column_dct in self._metrics(): display_name = column_dct.pop('name') column_dct.pop('metric') self.table.add_column(display_name, **column_dct) - @root_only def _update_table(self): self.table.update('Epoch', self.current_epoch) self.table.update('Time/Epoch', (datetime.now() - self.start_time) / self.current_epoch) diff --git a/dmlcloud/util/logging.py b/dmlcloud/util/logging.py index 5e6c56b..c0a9aac 100644 --- a/dmlcloud/util/logging.py +++ b/dmlcloud/util/logging.py @@ -1,3 +1,4 @@ +import io import logging import os import subprocess @@ -80,6 +81,15 @@ def __exit__(self, exc_type, exc_value, traceback): self.uninstall() +class DevNullIO(io.TextIOBase): + """ + Dummy TextIOBase that will simply ignore anything written to it similar to /dev/null + """ + + def write(self, msg): + pass + + def add_log_handlers(logger: logging.Logger): if logger.hasHandlers(): return @@ -98,6 +108,14 @@ def add_log_handlers(logger: logging.Logger): logger.addHandler(stderr_handler) +def flush_log_handlers(logger: logging.Logger): + """ + Flushes all handlers of the given logger. + """ + for handler in logger.handlers: + handler.flush() + + def experiment_header( name: str | None, checkpoint_dir: str | None, diff --git a/requirements.txt b/requirements.txt index 36db973..ff58da4 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ torch numpy xarray -progress_table>=0.1.20,<1.0.0 +progress_table>=2.2.0 omegaconf