diff --git a/README.md b/README.md
index a3cd0f4..0bd38e1 100644
--- a/README.md
+++ b/README.md
@@ -22,7 +22,7 @@ pip install dmlcloud
 Alternatively, you can install the latest development version directly from Github:
 
 ```
-pip install git+https://github.com/tangentlabs/django-oscar-paypal.git@issue/34/oscar-0.6
+pip install git+https://github.com/sehoffmann/dmlcloud.git
 ```
 
 ## Minimal Example
diff --git a/dmlcloud/core/metrics.py b/dmlcloud/core/metrics.py
index 3d9dc5c..b054465 100644
--- a/dmlcloud/core/metrics.py
+++ b/dmlcloud/core/metrics.py
@@ -1,326 +1,214 @@
+from collections import namedtuple
 from enum import Enum
+from typing import Any, Union
+
+import numpy as np
 import torch
-import torch.distributed as dist
+import torchmetrics
+from numpy.typing import ArrayLike
 
 
 __all__ = [
-    'Reduction',
-    'reduce_tensor',
-    'MetricReducer',
-    'MetricTracker',
+    'TrainingHistory',
+    'Tracker',
 ]
 
 
-class Reduction(Enum):
+class TrainingHistory:
     """
-    Reduction operation
-    """
-
-    MEAN = 'MEAN'
-    SUM = 'SUM'
-    MIN = 'MIN'
-    MAX = 'MAX'
-
-    def as_torch(self):
-        """
-        Returns the corresponding torch.distribution.ReduceOp
-        """
+    Stores the training history of a model.
 
-        if self == Reduction.SUM:
-            return dist.ReduceOp.SUM
-        elif self == Reduction.MIN:
-            return dist.ReduceOp.MIN
-        elif self == Reduction.MAX:
-            return dist.ReduceOp.MAX
-        else:
-            raise ValueError(f'Reduction {self} is not supported by torch')
+    Metrics can be either ArrayLike objects or any pickleable object.
 
+    Usage:
+        history = TrainingHistory()
+        history.append_metric('loss', 0.5)
+        history.append_metric('accuracy', 0.99)
+        history.next_step()
 
-def reduce_tensor(tensor: torch.Tensor, reduction: Reduction, dim=None):
+        for metric in history:
+            print(f'{metric}: {history[metric]}')
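+
+        # indexing stacks the recorded values into a single numpy array, e.g.
+        loss_curve = history['loss']   # shape: (num_steps, ...)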
- """ - value = torch.as_tensor(value) - value = value.detach().cpu() - self.values.append(value) - - def extend(self, values): - for value in values: - self.append(value) + def __init__(self): + self.num_steps = 0 + self._current_values = {} + self._metrics = {} + self._dtypes = {} - def __iadd__(self, value): - self.append(value) - return self + def __getitem__(self, name: str): + if name not in self._metrics: + raise KeyError(f'Metric {name} does not exist') - def __setitem__(self, idx, value): - value = torch.as_tensor(value) - value = value.detach().cpu() - self.values[idx] = value + return np.stack(self._metrics[name], axis=0, dtype=self._dtypes[name]) - def __getitem__(self, idx): - return self.values[idx] + def __delattr__(self, name): + del self._metrics[name] - def __delitem__(self, idx): - del self.values[idx] + def __contains__(self, name: str): + return name in self._metrics def __len__(self): - return len(self.values) + return len(self._metrics) def __iter__(self): - return iter(self.values) + return iter(self._metrics) - def clear(self): - self.values.clear() - - def reduce_and_append(self, value): - value = reduce_tensor(value, self.reduction, dim=self.dim) - self.values.append(value) - - def reduce_locally(self): - if len(self.values) == 0: - return None - - if isinstance(self.dim, list): - dim = [0] + [d + 1 for d in self.dim] - elif isinstance(self.dim, int): - dim = [0, self.dim + 1] - else: - dim = None - tensor = torch.stack(self.values) - tensor = reduce_tensor(tensor, reduction=self.reduction, dim=dim) - return tensor - - def reduce_globally(self, group=None): - # if the list of values is empty, the result is None - if self.globally: - empty_workers = [None] * dist.get_world_size(group) - dist.all_gather_object(empty_workers, len(self.values) == 0, group=group) - if any(empty_workers): - if len(empty_workers) > 1 and not all(empty_workers): - raise ValueError('Some workers tracked values this epoch and some did not. This is likely a bug.') - else: - return None - elif len(self.values) == 0: - return None - - tensor = self.reduce_locally() - if self.globally: - if self.reduction == Reduction.MEAN: - dist.all_reduce(tensor, op=dist.ReduceOp.SUM, group=group) - tensor /= dist.get_world_size(group) - else: - dist.all_reduce(tensor, op=self.reduction.as_torch(), group=group) - return tensor + def keys(self): + return self._metrics.keys() - def state_dict(self): - return { - 'reduction': self.reduction, - 'dim': self.dim, - 'globally': self.globally, - 'values': self.values, - } + def values(self): + return [self[name] for name in self._metrics] - def load_state_dict(self, state): - self.reduction = state['reduction'] - self.dim = state['dim'] - self.globally = state['globally'] - self.values = state['values'] + def items(self): + return [(name, self[name]) for name in self._metrics] + def append_metric(self, name: str, value: Union[ArrayLike, Any]): + """ + Adds a value for a metric at the current step. -class MetricTracker: - """ - This class keeps track of multiple metrics and their history. + Args: + name (str): The name of the metric. + value (ArrayLike, Any): The value of the metric. Must be a ArrayLike or pickleable object. 
+ """ + if name in self._current_values: + raise ValueError(f'Metric {name} already has a value for step {self.num_steps}') - Usage: - tracker = MetricTracker() - tracker.register_metric('loss', reduction=Reduction.MEAN) - tracker.track('loss', torch.randn(10, 1)) - tracker.next_epoch() + if name not in self._metrics and self.num_steps > 0: + raise ValueError(f'Cannot add metric {name} after the first step') - print(tracker['loss'].last()) - """ + if isinstance(value, torch.Tensor): + value = value.detach().to('cpu', non_blocking=True) - def __init__(self): - self.histories = {} - self.reducers = {} - self.epoch = 1 + self._current_values[name] = value - def __getitem__(self, name): + def append_metrics(self, **metrics): """ - Returns the history of a metric up to the current epoch. - Values for the current epoch that have been reduced already are not included. + Adds multiple metrics at the current step. + + Args: + **metrics: The metrics to add. """ - if name not in self: - raise ValueError(f'Metric {name} does not exist') - return list(self.histories[name])[: self.epoch - 1] + for name, value in metrics.items(): + self.append_metric(name, value) - def __contains__(self, name): - return name in self.histories + def next_step(self): + """ + Advances the step counter. + """ - def __len__(self): - return len(self.histories) + for name in self._metrics: + if name not in self._current_values: + raise ValueError(f'Metric {name} does not have a value for step {self.num_steps}') - def __iter__(self): - return iter(self.histories) + for name, value in self._current_values.items(): + if type(value) == ArrayLike: + value = np.as_array(value) - def current_value(self, name): - """ - If the metric already has an reduced value for the current epoch, it is returned. Otherwise, None is returned. + if name not in self._metrics: + self._metrics[name] = [value] + self._dtypes[name] = value.dtype if type(value) == ArrayLike else object + else: + self._metrics[name].append(value) + + self._current_values = {} + self.num_steps += 1 + + def last(self) -> dict[str, Any]: """ - if name not in self: - raise ValueError(f'Metric {name} does not exist') - if self.has_value(name): - return self.histories[name][-1] - else: - return None - - def is_reduced_metric(self, name): + Returns the last value for each metric. + + Returns: + dict[str, Any]: The last value for each metric. """ - Returns True if the metric gets (all)reduced at the end of each epoch. + + return {name: values[-1] for name, values in self._metrics.items()} + + def current(self) -> dict[str, Any]: """ - if name not in self: - raise ValueError(f'Metric {name} does not exist') - return name in self.reducers + Returns the current, but not yet saved, value for each metric. - def has_value(self, name): + Returns: + dict[str, Any]: The current value for each metric. """ - Returns True if the metric has a final value for the current epoch. + + return {name: self._current_values[name] for name in self._current_values} + + def min(self) -> dict[str, min_return_type]: """ - if name not in self: - raise ValueError(f'Metric {name} does not exist') - return len(self.histories[name]) >= self.epoch + Returns a namedtuple (value, step) containing the minimum value and the corresponding step for each metric across all steps. - def register_metric(self, name, reduction=None, dim=None, globally=True): - if name in self: - raise ValueError(f'Metric {name} already exists') + Returns: + dict[str, namedtuple]: The minimum value and the corresponding step for each metric. 
+ """ + argmin = {name: np.argmin(values, axis=0) for name, values in self._metrics.items()} + return {name: self.min_return_type(self._metrics[name][idx], idx) for name, idx in argmin.items()} - if dim is not None and reduction is None: - raise ValueError('If dim is specified, reduction must be specified as well') + def max(self) -> dict[str, max_return_type]: + """ + Returns a namedtuple (value, step) containing the maximum value and the corresponding step for each metric across all steps. - self.histories[name] = [] + [None] * (self.epoch - 1) - if reduction is not None: - self.reducers[name] = MetricReducer(reduction=reduction, dim=dim, globally=globally) + Returns: + dict[str, namedtuple]: The maximum value and the corresponding step for each metric. + """ + argmax = {name: np.argmax(values, axis=0) for name, values in self._metrics.items()} + return {name: self.max_return_type(self._metrics[name][idx], idx) for name, idx in argmax.items()} - def track(self, name, value): - if isinstance(value, torch.Tensor): - value = value.detach().to('cpu', non_blocking=True) - if name not in self: - raise ValueError(f'Metric {name} does not exist') +class Tracker(torch.nn.Module): + """ + Keeps track of multiple metrics and reduces them at the end of each epoch. + """ - if self.has_value(name): - raise ValueError(f'History for {name} already has a value for epoch {self.epoch}') + def __init__(self): + super().__init__() - history = self.histories[name] - reducer = self.reducers.get(name) - if reducer is not None: - reducer.append(value) - else: - history.append(value) + self.metrics = torch.nn.ModuleDict() + self.external_metrics = torch.nn.ModuleDict() - def reduce_all(self, prefix=None, strict=True): - """ - Reduces all metrics and appends their reduced values to the history. - If prefix is specified, only metrics with the specified prefix are reduced. - If strict is True, an error is raised if a metric has already been reduced for the current epoch. + def add_metric(self, name: str, metric: torchmetrics.Metric): + if name in self.external_metrics or name in self.metrics: + raise ValueError(f'Metric {name} already exists') - After this method has been called, no more values for the reduced metrics can be tracked for the current epoch, - and next_epoch() must be called to be able to track new values. - """ - for name, history in self.histories.items(): - if prefix is not None and not name.startswith(prefix): - continue - - if self.has_value(name): - if strict: - raise ValueError(f'History for {name} has already been reduced for epoch {self.epoch}') - else: - continue - - reducer = self.reducers.get(name) - if reducer is not None: - history.append(reducer.reduce_globally()) - reducer.clear() - else: - history.append(None) + self.external_metrics[name] = metric + + def log(self, name: str, value: Any, reduction: str = 'mean'): + if reduction not in ['mean', 'sum', 'min', 'max']: + raise ValueError(f'Invalid reduction {reduction}. Must be one of mean, sum, min, max') + + if name in self.external_metrics: + raise ValueError(f'Metric {name} is a external metric. 
+    """
 
-        if self.has_value(name):
-            raise ValueError(f'History for {name} already has a value for epoch {self.epoch}')
+    def __init__(self):
+        super().__init__()
 
-        history = self.histories[name]
-        reducer = self.reducers.get(name)
-        if reducer is not None:
-            reducer.append(value)
-        else:
-            history.append(value)
+        self.metrics = torch.nn.ModuleDict()
+        self.external_metrics = torch.nn.ModuleDict()
 
-    def reduce_all(self, prefix=None, strict=True):
-        """
-        Reduces all metrics and appends their reduced values to the history.
-        If prefix is specified, only metrics with the specified prefix are reduced.
-        If strict is True, an error is raised if a metric has already been reduced for the current epoch.
+    def add_metric(self, name: str, metric: torchmetrics.Metric):
+        if name in self.external_metrics or name in self.metrics:
+            raise ValueError(f'Metric {name} already exists')
 
-        After this method has been called, no more values for the reduced metrics can be tracked for the current epoch,
-        and next_epoch() must be called to be able to track new values.
-        """
-        for name, history in self.histories.items():
-            if prefix is not None and not name.startswith(prefix):
-                continue
-
-            if self.has_value(name):
-                if strict:
-                    raise ValueError(f'History for {name} has already been reduced for epoch {self.epoch}')
-                else:
-                    continue
-
-            reducer = self.reducers.get(name)
-            if reducer is not None:
-                history.append(reducer.reduce_globally())
-                reducer.clear()
-            else:
-                history.append(None)
+        self.external_metrics[name] = metric
+
+    def log(self, name: str, value: Any, reduction: str = 'mean'):
+        if reduction not in ['mean', 'sum', 'min', 'max']:
+            raise ValueError(f'Invalid reduction {reduction}. Must be one of mean, sum, min, max')
+
+        if name in self.external_metrics:
+            raise ValueError(f'Metric {name} is an external metric. Please use the .update() method yourself.')
+
+        if name not in self.metrics:
+            if reduction == 'mean':
+                metric = torchmetrics.MeanMetric()
+            elif reduction == 'sum':
+                metric = torchmetrics.SumMetric()
+            elif reduction == 'min':
+                metric = torchmetrics.MinMetric()
+            elif reduction == 'max':
+                metric = torchmetrics.MaxMetric()
+            # plain Python numbers have no .device attribute, so only move the
+            # metric when a tensor is logged
+            if isinstance(value, torch.Tensor):
+                metric = metric.to(value.device)
+            self.metrics[name] = metric
+
+        self.metrics[name].update(value)
+
+    def reduce(self):
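+        """
+        Computes the reduced value of every metric, resets the underlying
+        torchmetrics objects, and returns the results as a dict.
+        """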
+        values = {}
+        for name, metric in self.metrics.items():
+            values[name] = metric.compute()
+            metric.reset()
+        for name, metric in self.external_metrics.items():
+            values[name] = metric.compute()
+            metric.reset()
+        return values
 
-    def next_epoch(self):
-        """
-        Reduces all metrics (if not already reduced) and advances the epoch counter.
-        """
-        self.reduce_all(strict=False)
-        self.epoch += 1
-
-    def state_dict(self):
-        state = {
-            'epoch': self.epoch,
-            'histories': dict(self.histories),
-            'reducers': {name: reducer.state_dict() for name, reducer in self.reducers.items()},
-        }
-        return state
-
-    def load_state_dict(self, state):
-        self.epoch = state['epoch']
-        self.histories = state['histories']
-        self.reducers = {}
-        for name, reducer_state in state['reducers'].items():
-            self.reducers[name] = MetricReducer()
-            self.reducers[name].load_state_dict(reducer_state)
-
-    def __str__(self):
-        s = 'MetricTracker('
-        for name, history in self.histories.items():
-            s += f'\n    {name}: {history}'
-        if len(self.histories) > 0:
-            s += '\n)'
-        else:
-            s += ')'
-        return s
+    def clear(self):
+        for metric in self.metrics.values():
+            metric.reset()
+        for metric in self.external_metrics.values():
+            metric.reset()
+        self.metrics.clear()
+        self.external_metrics.clear()
diff --git a/dmlcloud/core/pipeline.py b/dmlcloud/core/pipeline.py
index c09ec05..023c3c8 100644
--- a/dmlcloud/core/pipeline.py
+++ b/dmlcloud/core/pipeline.py
@@ -15,7 +15,6 @@
 from . import logging as dml_logging
 from .checkpoint import CheckpointDir, find_slurm_checkpoint, generate_checkpoint_path
 from .distributed import all_gather_object, broadcast_object, init, is_root, local_rank, root_only
-from .metrics import MetricTracker, Reduction
 from .stage import Stage
 
 
@@ -39,7 +38,6 @@ def __init__(self, config: Optional[Union[OmegaConf, Dict]] = None, name: Optional[str] = None):
         self.gloo_group = None
         self.io_redirector = None
         self.resumed = None
-        self.tracker = MetricTracker()
         self.start_time = None
         self.stop_time = None
         self.current_stage = None
@@ -139,31 +137,6 @@ def initializer():
         self._wandb_initalizer = initializer
         self.wandb = True
 
-    def track_reduce(
-        self,
-        name: str,
-        value: torch.Tensor,
-        step: Optional[int] = None,
-        reduction: Reduction = Reduction.MEAN,
-        dim: Optional[List[int]] = None,
-        reduce_globally: bool = True,
-    ):
-        if name not in self.tracker:
-            self.tracker.register_metric(name, reduction, dim, reduce_globally)
-
-        self.tracker.track(name, value)
-
-    def track(
-        self,
-        name: str,
-        value: Any,
-        step: Optional[int] = None,
-    ):
-        if name not in self.tracker:
-            self.tracker.register_metric(name)
-
-        self.tracker.track(name, value)
-
     def barrier(self, timeout=None):
         if self.gloo_group is None:
             dist.barrier()
@@ -266,14 +239,6 @@ def _post_run(self):
         dml_logging.info(f'Outputs have been saved to {self.checkpoint_dir}')
         self.post_run()
 
-    def _pre_epoch(self):
-        pass
-
-    def _post_epoch(self):
-        if self.wandb and is_root():
-            metrics = {name: self.tracker[name][-1] for name in self.tracker}
-            wandb.log(metrics)
-
     def _cleanup(self, exc_type, exc_value, traceback):
         """
         Called by _RunGuard to ensure that the pipeline is properly cleaned up
diff --git a/dmlcloud/core/stage.py b/dmlcloud/core/stage.py
index 71d7b08..f9efb91 100644
--- a/dmlcloud/core/stage.py
+++ b/dmlcloud/core/stage.py
@@ -1,19 +1,16 @@
 import sys
-import time
 from datetime import datetime, timedelta
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Optional
 
-import torch
 from progress_table import ProgressTable
 
 from ..util.logging import DevNullIO
 from . import logging as dml_logging
 from .distributed import is_root
-from .metrics import MetricTracker, Reduction
+from .metrics import Tracker, TrainingHistory
 
 
 __all__ = [
     'Stage',
-    'TrainValStage',
 ]
 
 
@@ -31,12 +28,13 @@ def __init__(self):
         self.max_epochs = None  # set by the pipeline
         self.name = None  # set by the pipeline
 
+        self.history = TrainingHistory()
+        self.tracker = Tracker()
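+        # the tracker accumulates per-batch values during an epoch; at the end of
+        # each epoch they are reduced and appended to the per-epoch history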
+
         self.start_time = None
         self.stop_time = None
         self.epoch_start_time = None
         self.epoch_stop_time = None
-        self.current_epoch = 1
-        self._stop_requested = False
 
         self.metric_prefix = None
         self.barrier_timeout = None
@@ -44,10 +42,6 @@
         self.table = None
         self.columns = {}
 
-    @property
-    def tracker(self) -> MetricTracker:
-        return self.pipeline.tracker
-
     @property
     def device(self):
         return self.pipeline.device
 
@@ -56,24 +50,19 @@ def device(self):
     def config(self):
         return self.pipeline.config
 
-    def track_reduce(
-        self,
-        name: str,
-        value: torch.Tensor,
-        step: Optional[int] = None,
-        reduction: Reduction = Reduction.MEAN,
-        dim: Optional[List[int]] = None,
-        reduce_globally: bool = True,
-        prefixed: bool = True,
-    ):
-        if prefixed and self.metric_prefix:
-            name = f'{self.metric_prefix}/{name}'
-        self.pipeline.track_reduce(name, value, step, reduction, dim, reduce_globally)
+    @property
+    def current_epoch(self):
+        return self.history.num_steps
 
-    def track(self, name: str, value, step: Optional[int] = None, prefixed: bool = True):
+    def log(self, name: str, value: Any, reduction: str = 'mean', prefixed: bool = True):
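+        """
+        Logs a per-batch value for the given metric; all values logged during an
+        epoch are reduced with the given reduction at the end of the epoch.
+        """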
""" self._pre_stage() - while self.max_epochs is None or self.current_epoch <= self.max_epochs: + while self.max_epochs is None or self.current_epoch < self.max_epochs: self._pre_epoch() self.run_epoch() self._post_epoch() - if self._stop_requested: - break self._post_stage() def _pre_stage(self): @@ -140,7 +124,7 @@ def _pre_stage(self): dml_logging.info(f'\n========== STAGE: {self.name} ==========') self.table = ProgressTable(file=sys.stdout if is_root() else DevNullIO()) - self.add_column('Epoch', 'misc/epoch', color='bright', width=5) + self.add_column('Epoch', None, color='bright', width=5) self.add_column('Took', None, width=7) self.add_column('ETA', None, width=7) @@ -162,27 +146,21 @@ def _pre_epoch(self): self.epoch_start_time = datetime.now() self.table['Epoch'] = self.current_epoch self.pre_epoch() - self.pipeline._pre_epoch() def _post_epoch(self): self.epoch_stop_time = datetime.now() self._reduce_metrics() self.post_epoch() - self.pipeline._post_epoch() self._update_table() - self.current_epoch += 1 def _reduce_metrics(self): - self.track(name='misc/epoch', value=self.current_epoch, prefixed=False) - self.track( - name='misc/epoch_time', value=(self.epoch_stop_time - self.epoch_start_time).total_seconds(), prefixed=False - ) - self.tracker.next_epoch() - pass + # self.log('misc/epoch', self.current_epoch, prefixed=False) + # self.log('misc/epoch_time', (self.epoch_stop_time - self.epoch_start_time).total_seconds()) + metrics = self.tracker.reduce() + self.history.append_metrics(**metrics) + self.history.next_step() def _update_table(self): - self.table.update('Epoch', self.current_epoch) - time = datetime.now() - self.epoch_start_time self.table.update('Took', str(time - timedelta(microseconds=time.microseconds))) @@ -190,129 +168,9 @@ def _update_table(self): eta = per_epoch * (self.max_epochs - self.current_epoch) self.table.update('ETA', str(eta - timedelta(microseconds=eta.microseconds))) + last_metrics = self.history.last() for name, metric in self.columns.items(): if metric is not None: - self.table.update(name, self.tracker[metric][-1]) + self.table.update(name, last_metrics[metric]) self.table.next_row() - - -class TrainValStage(Stage): - def __init__(self): - super().__init__() - self.is_train = True - - def train_dataset(self): - train_ds = self.pipeline.datasets.get('train') - if train_ds is None: - raise ValueError( - 'No "train" dataset found in pipeline. Use register_dataset("train", ...) to register a dataset.' - ) - return train_ds - - def val_dataset(self): - val_ds = self.pipeline.datasets.get('val') - if val_ds is None: - raise ValueError( - 'No "val" dataset found in pipeline. Use register_dataset("val", ...) to register a dataset.' 
+        metric = metric.to(self.device)
+        self.tracker.add_metric(name, metric)
+        return metric
 
     def add_column(
         self,
@@ -86,9 +75,6 @@ def add_column(
         self.columns[name] = metric
         self.table.add_column(name, width=width, color=color, alignment=alignment)
 
-    def stop_stage(self):
-        self._stop_requested = True
-
     def pre_stage(self):
         """
         Executed before the stage starts.
@@ -126,12 +112,10 @@ def run(self):
-        Runs this stage. Either until max_epochs are reached, or until stop_stage() is called.
+        Runs this stage until max_epochs epochs have been completed.
         """
         self._pre_stage()
 
-        while self.max_epochs is None or self.current_epoch <= self.max_epochs:
+        while self.max_epochs is None or self.current_epoch < self.max_epochs:
             self._pre_epoch()
             self.run_epoch()
             self._post_epoch()
-            if self._stop_requested:
-                break
 
         self._post_stage()
 
     def _pre_stage(self):
@@ -140,7 +124,7 @@
         dml_logging.info(f'\n========== STAGE: {self.name} ==========')
 
         self.table = ProgressTable(file=sys.stdout if is_root() else DevNullIO())
-        self.add_column('Epoch', 'misc/epoch', color='bright', width=5)
+        self.add_column('Epoch', None, color='bright', width=5)
         self.add_column('Took', None, width=7)
         self.add_column('ETA', None, width=7)
@@ -162,27 +146,21 @@ def _pre_epoch(self):
         self.epoch_start_time = datetime.now()
         self.table['Epoch'] = self.current_epoch
         self.pre_epoch()
-        self.pipeline._pre_epoch()
 
     def _post_epoch(self):
         self.epoch_stop_time = datetime.now()
         self._reduce_metrics()
         self.post_epoch()
-        self.pipeline._post_epoch()
         self._update_table()
-        self.current_epoch += 1
 
     def _reduce_metrics(self):
-        self.track(name='misc/epoch', value=self.current_epoch, prefixed=False)
-        self.track(
-            name='misc/epoch_time', value=(self.epoch_stop_time - self.epoch_start_time).total_seconds(), prefixed=False
-        )
-        self.tracker.next_epoch()
-        pass
+        # self.log('misc/epoch', self.current_epoch, prefixed=False)
+        # self.log('misc/epoch_time', (self.epoch_stop_time - self.epoch_start_time).total_seconds())
+        metrics = self.tracker.reduce()
+        self.history.append_metrics(**metrics)
+        self.history.next_step()
 
     def _update_table(self):
-        self.table.update('Epoch', self.current_epoch)
-
         time = datetime.now() - self.epoch_start_time
         self.table.update('Took', str(time - timedelta(microseconds=time.microseconds)))
@@ -190,129 +168,9 @@
         eta = per_epoch * (self.max_epochs - self.current_epoch)
         self.table.update('ETA', str(eta - timedelta(microseconds=eta.microseconds)))
 
+        last_metrics = self.history.last()
         for name, metric in self.columns.items():
             if metric is not None:
-                self.table.update(name, self.tracker[metric][-1])
+                self.table.update(name, last_metrics[metric])
 
         self.table.next_row()
-
-
-class TrainValStage(Stage):
-    def __init__(self):
-        super().__init__()
-        self.is_train = True
-
-    def train_dataset(self):
-        train_ds = self.pipeline.datasets.get('train')
-        if train_ds is None:
-            raise ValueError(
-                'No "train" dataset found in pipeline. Use register_dataset("train", ...) to register a dataset.'
-            )
-        return train_ds
-
-    def val_dataset(self):
-        val_ds = self.pipeline.datasets.get('val')
-        if val_ds is None:
-            raise ValueError(
-                'No "val" dataset found in pipeline. Use register_dataset("val", ...) to register a dataset.'
-            )
-        return val_ds
-
-    def optimizers(self):
-        return self.pipeline.optimizers.values()
-
-    def loss_metric_name(self):
-        return 'loss'
-
-    def train_metric_prefix(self):
-        return 'train'
-
-    def val_metric_prefix(self):
-        return 'val'
-
-    def gradient_clip(self):
-        return 0.0
-
-    def run_epoch(self):
-        self.train_epoch()
-        self.val_epoch()
-
-    def step(self, batch) -> torch.Tensor:
-        raise NotImplementedError()
-
-    def train_step(self, batch):
-        return self.step(batch)
-
-    def val_step(self, batch):
-        return self.step(batch)
-
-    def zero_grad(self):
-        for optimizer in self.optimizers():
-            optimizer.zero_grad()
-
-    def clip_gradients(self):
-        for optimizer in self.optimizers():
-            for group in optimizer.param_groups:
-                torch.nn.utils.clip_grad_norm_(group['params'], self.gradient_clip())
-
-    def optimize(self, loss):
-        loss.backward()
-
-        if self.gradient_clip():
-            self.clip_gradients()
-
-        for optimizer in self.optimizers():
-            optimizer.step()
-
-    def train_epoch(self):
-        self.is_train = True
-        self.metric_prefix = self.train_metric_prefix()
-
-        train_ds = self.train_dataset()
-        if hasattr(train_ds, 'sampler') and hasattr(train_ds.sampler, 'set_epoch'):
-            train_ds.sampler.set_epoch(self.current_epoch)
-
-        for batch in train_ds:
-            step_start_time = time.perf_counter_ns()
-            self.zero_grad()
-            loss = self.train_step(batch)
-            self.optimize(loss)
-            step_end_time = time.perf_counter_ns()
-
-            self.track_reduce(self.loss_metric_name(), loss)
-            self.track_reduce('misc/total_train_batches', torch.tensor(1), reduction=Reduction.SUM, prefixed=False)
-            self.track_reduce(
-                'misc/worker_train_batches',
-                torch.tensor(1),
-                reduction=Reduction.SUM,
-                reduce_globally=False,
-                prefixed=False,
-            )
-            self.track_reduce('misc/step_time_ms', torch.tensor(step_end_time - step_start_time) / 1e6, prefixed=False)
-
-        for name, scheduler in self.pipeline.schedulers.items():
-            self.track(f'misc/lr_{name}', scheduler.get_last_lr()[0], prefixed=False)
-            scheduler.step()
-
-    @torch.no_grad()
-    def val_epoch(self):
-        self.is_train = False
-        self.metric_prefix = self.val_metric_prefix()
-
-        for batch in self.val_dataset():
-            loss = self.val_step(batch)
-            self.track_reduce(self.loss_metric_name(), loss)
-            self.track_reduce('misc/total_val_batches', torch.tensor(1), reduction=Reduction.SUM, prefixed=False)
-            self.track_reduce(
-                'misc/worker_val_batches',
-                torch.tensor(1),
-                reduction=Reduction.SUM,
-                reduce_globally=False,
-                prefixed=False,
-            )
-
-    def table_columns(self):
-        columns = super().table_columns()
-        columns.insert(1, {'name': '[Train] Loss', 'metric': f'{self.train_metric_prefix()}/{self.loss_metric_name()}'})
-        columns.insert(2, {'name': '[Val] Loss', 'metric': f'{self.val_metric_prefix()}/{self.loss_metric_name()}'})
-        return columns
diff --git a/examples/barebone_mnist.py b/examples/barebone_mnist.py
index 942f335..949fadb 100644
--- a/examples/barebone_mnist.py
+++ b/examples/barebone_mnist.py
@@ -4,6 +4,7 @@
 
 import dmlcloud as dml
 import torch
+import torchmetrics
 from torch import nn
 from torch.utils.data import DataLoader
 from torchvision import datasets, transforms
@@ -56,6 +57,9 @@ def pre_stage(self):
         self.add_column('[Val] Loss', 'val/loss', color='blue')
         self.add_column('[Val] Acc.', 'val/accuracy', color='blue')
 
+        self.train_acc = self.add_metric('train/accuracy', torchmetrics.Accuracy('multiclass', num_classes=10))
+        self.val_acc = self.add_metric('val/accuracy', torchmetrics.Accuracy('multiclass', num_classes=10))
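+        # illustrative: these torchmetrics objects are updated manually in the
+        # train/val loops below and reduced automatically at the end of each epoch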
+
     # The run_epoch method is called once per epoch
     def run_epoch(self):
         self._train_epoch()
@@ -75,7 +79,9 @@
             loss.backward()
             self.optimizer.step()
 
-            self._log_metrics(img, target, output, loss)
+            self.log('loss', loss)
+            # self.log('accuracy', (output.argmax(1) == target).float().mean())
+            self.train_acc(output, target)
 
     @torch.no_grad()
     def _val_epoch(self):
@@ -88,11 +94,9 @@
             output = self.model(img)
             loss = self.loss(output, target)
 
-            self._log_metrics(img, target, output, loss)
-
-    def _log_metrics(self, img, target, output, loss):
-        self.track_reduce('loss', loss)
-        self.track_reduce('accuracy', (output.argmax(1) == target).float().mean())
+            self.log('loss', loss)
+            # self.log('accuracy', (output.argmax(1) == target).float().mean())
+            self.val_acc(output, target)
 
 
 def main():
diff --git a/examples/mnist.py b/examples/mnist.py
deleted file mode 100644
index 11b93f6..0000000
--- a/examples/mnist.py
+++ /dev/null
@@ -1,67 +0,0 @@
-import sys
-
-sys.path.insert(0, './')
-
-import dmlcloud as dml
-import torch
-from torch import nn
-from torch.utils.data import DataLoader
-from torchvision import datasets, transforms
-
-
-class MNISTStage(dml.TrainValStage):
-    def pre_stage(self):
-        transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
-
-        with dml.root_first():
-            train_dataset = datasets.MNIST(root='data', train=True, download=dml.is_root(), transform=transform)
-            train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
-            self.pipeline.register_dataset('train', DataLoader(train_dataset, batch_size=32, sampler=train_sampler))
-
-            val_dataset = datasets.MNIST(root='data', train=False, download=dml.is_root(), transform=transform)
-            val_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset, shuffle=False)
-            self.pipeline.register_dataset('val', DataLoader(val_dataset, batch_size=32, sampler=val_sampler))
-
-        model = nn.Sequential(
-            nn.Conv2d(1, 16, 3, padding=1),
-            nn.ReLU(),
-            nn.MaxPool2d(2),
-            nn.Conv2d(16, 16, 3, padding=1),
-            nn.ReLU(),
-            nn.MaxPool2d(2),
-            nn.Flatten(),
-            nn.Linear(784, 10),
-        )
-        self.pipeline.register_model('cnn', model)
-
-        self.pipeline.register_optimizer('adam', torch.optim.Adam(model.parameters(), lr=1e-3))
-
-        self.loss = nn.CrossEntropyLoss()
-
-    def step(self, batch) -> torch.Tensor:
-        img, target = batch
-        img, target = img.to(self.device), target.to(self.device)
-
-        output = self.pipeline.models['cnn'](img)
-        loss = self.loss(output, target)
-
-        self.track_reduce('accuracy', (output.argmax(1) == target).float().mean())
-        return loss
-
-    def table_columns(self):
-        columns = super().table_columns()
-        columns.insert(-2, {'name': '[Val] Acc.', 'metric': 'val/accuracy'})
-        columns.insert(-2, {'name': '[Train] Acc.', 'metric': 'train/accuracy'})
-        return columns
-
-
-def main():
-    pipeline = dml.TrainingPipeline(name='mnist')
-    pipeline.enable_checkpointing('checkpoints', resume=False)
-    pipeline.enable_wandb()
-    pipeline.append_stage(MNISTStage(), max_epochs=3)
-    pipeline.run()
-
-
-if __name__ == '__main__':
-    main()
diff --git a/requirements.txt b/requirements.txt
index ff58da4..d7ee9fc 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -3,3 +3,4 @@ numpy
 xarray
 progress_table>=2.2.0
 omegaconf
+torchmetrics