Add OOM observer with memory visualizations (mosaicml#2958)
* add oomobserver

* update docstring

* Update composer/callbacks/oom_observer.py

Co-authored-by: Mihir Patel <[email protected]>

* use pyskip

* call trainer fit

* fix ci

* Update composer/callbacks/oom_observer.py

Co-authored-by: Charles Tang <[email protected]>

* address comments

* Update composer/callbacks/oom_observer.py

Co-authored-by: Mihir Patel <[email protected]>

* add test with snapshot

* update doc

* fix typo

* use log info

* fix format

* fix format

* fix ci

* fix cpu test

* fix ci

* Update tests/callbacks/test_oom_observer.py

Co-authored-by: Mihir Patel <[email protected]>

* Update composer/callbacks/oom_observer.py

Co-authored-by: Mihir Patel <[email protected]>

* Update composer/callbacks/oom_observer.py

Co-authored-by: Mihir Patel <[email protected]>

* Update composer/callbacks/oom_observer.py

Co-authored-by: Mihir Patel <[email protected]>

* update test

* Update tests/callbacks/test_oom_observer.py

Co-authored-by: Mihir Patel <[email protected]>

* Update tests/callbacks/test_oom_observer.py

Co-authored-by: Mihir Patel <[email protected]>

* Update tests/callbacks/test_oom_observer.py

Co-authored-by: Mihir Patel <[email protected]>

* Update composer/callbacks/oom_observer.py

Co-authored-by: Mihir Patel <[email protected]>

* Update composer/callbacks/oom_observer.py

Co-authored-by: Mihir Patel <[email protected]>

* use warnings

* add pytest filter user warnings in cpu callback tests

* fix typo

---------

Co-authored-by: Mihir Patel <[email protected]>
Co-authored-by: Charles Tang <[email protected]>
3 people authored Feb 2, 2024
1 parent e4ee99e commit 21bc3db
Showing 10 changed files with 287 additions and 5 deletions.
2 changes: 2 additions & 0 deletions composer/callbacks/__init__.py
@@ -18,6 +18,7 @@
from composer.callbacks.memory_snapshot import MemorySnapshot
from composer.callbacks.mlperf import MLPerfCallback
from composer.callbacks.nan_monitor import NaNMonitor
+from composer.callbacks.oom_observer import OOMObserver
from composer.callbacks.optimizer_monitor import OptimizerMonitor
from composer.callbacks.runtime_estimator import RuntimeEstimator
from composer.callbacks.speed_monitor import SpeedMonitor
@@ -42,4 +43,5 @@
    'Generate',
    'FreeOutputs',
    'MemorySnapshot',
+    'OOMObserver',
]
178 changes: 178 additions & 0 deletions composer/callbacks/oom_observer.py
@@ -0,0 +1,178 @@
# Copyright 2022 MosaicML Composer authors
# SPDX-License-Identifier: Apache-2.0

"""Generate a memory snapshot during an OutOfMemory exception."""

import logging
import os
import pickle
import warnings
from typing import Optional

import torch.cuda
from packaging import version

from composer.core import Callback, State
from composer.loggers import Logger
from composer.utils import ensure_folder_is_empty, format_name_with_dist, format_name_with_dist_and_time, parse_uri

log = logging.getLogger(__name__)

__all__ = ['OOMObserver']


class OOMObserver(Callback):
    """Generate visualizations of the state of allocated memory during an OutOfMemory exception.

    This callback registers an observer with the CUDA caching allocator that is called every time
    the allocator is about to raise an OutOfMemoryError, before any memory has been released while
    unwinding the exception. OOMObserver is attached to the Trainer during the INIT event. The
    visualizations include a snapshot of the memory state, a trace plot, a segment plot, a segment
    flamegraph, and a memory flamegraph.

    Example:
        .. doctest::

            >>> from composer import Trainer
            >>> from composer.callbacks import OOMObserver
            >>> # constructing trainer object with this callback
            >>> trainer = Trainer(
            ...     model=model,
            ...     train_dataloader=train_dataloader,
            ...     eval_dataloader=eval_dataloader,
            ...     optimizers=optimizer,
            ...     max_duration="1ep",
            ...     callbacks=[OOMObserver()],
            ... )

    .. note::
        OOMObserver is only supported on GPU devices.

    Args:
        max_entries (int, optional): Maximum number of memory alloc/free events to record. Defaults to 100000.
        folder (str, optional): A format string describing the folder containing the memory visualization files.
            Defaults to ``'{{run_name}}/torch_traces'``.
        filename (str, optional): A format string describing the prefix used to name the memory visualization files.
            Defaults to ``'rank{{rank}}_oom'``.
        remote_file_name (str, optional): A format string describing the prefix for the memory visualization
            remote file names. Defaults to ``'{{run_name}}/oom_traces/rank{{rank}}_oom'``.

            Whenever a trace file is saved, it is also uploaded as a file according to this format string.
            The same format variables as for ``filename`` are available.

            .. seealso:: :doc:`Uploading Files</trainer/file_uploading>` for notes on file uploading.

            Leading slashes (``'/'``) will be stripped.

            To disable uploading trace files, set this parameter to ``None``.
        overwrite (bool, optional): Whether to overwrite existing memory snapshots. Defaults to False.
            If False, the trace folder as determined by ``folder`` must be empty.
    """

    def __init__(
        self,
        max_entries: int = 100000,
        folder: str = '{run_name}/torch_traces',
        filename: str = 'rank{rank}_oom',
        remote_file_name: Optional[str] = '{run_name}/oom_traces/rank{rank}_oom',
        overwrite: bool = False,
    ) -> None:
        self.max_entries = max_entries
        self.folder = folder
        self.folder_name = None
        self.filename = filename
        self.remote_file_name = remote_file_name
        self.overwrite = overwrite
        if remote_file_name:
            _, _, self.remote_path_in_bucket = parse_uri(remote_file_name)
        else:
            self.remote_path_in_bucket = None

        if version.parse(torch.__version__.split('.dev')[0]) >= version.parse('2.1.0'):  # type: ignore
            # OOMObserver is only supported in torch v2.1.0 or higher
            self._enabled = True
        else:
            self._enabled = False
            warnings.warn('OOMObserver is only supported with PyTorch 2.1.0 or higher. Disabling OOMObserver callback.')

    def init(self, state: State, logger: Logger) -> None:
        if not self._enabled:
            return
        # Not relying on `torch.cuda.is_available()` since the model could be on CPU.
        model_device = next(state.model.parameters()).device

        if model_device.type not in ('cuda', 'meta'):
            warnings.warn(
                f'OOMObserver only works on CUDA devices, but the model is on {model_device.type}. Disabling OOMObserver.')
            self._enabled = False
        else:
            self.folder_name = format_name_with_dist(self.folder, state.run_name)
            os.makedirs(self.folder_name, exist_ok=True)
            if not self.overwrite:
                ensure_folder_is_empty(self.folder_name)

        def oom_observer(device: int, alloc: int, device_alloc: int, device_free: int):
            # Snapshot right after an OOM happened
            log.warning('Out Of Memory (OOM) observed')

            assert self.filename
            assert self.folder_name, 'folder_name must be set in init'
            filename = os.path.join(
                self.folder_name,
                format_name_with_dist_and_time(self.filename, run_name=state.run_name, timestamp=state.timestamp))

            try:
                snapshot_file = filename + '_snapshot.pickle'
                trace_plot_file = filename + '_trace_plot.html'
                segment_plot_file = filename + '_segment_plot.html'
                segment_flamegraph_file = filename + '_segment_flamegraph.svg'
                memory_flamegraph_file = filename + '_memory_flamegraph.svg'
                log.info('Dumping OOMObserver visualizations')

                snapshot = torch.cuda.memory._snapshot()
                # No data was recorded - avoids a `ValueError` in `trace_plot`
                if all(len(t) == 0 for t in snapshot['device_traces']):
                    log.info('No allocation is recorded in memory snapshot')
                    return

                with open(snapshot_file, 'wb') as fd:
                    pickle.dump(snapshot, fd)

                with open(trace_plot_file, 'w+') as fd:
                    fd.write(torch.cuda._memory_viz.trace_plot(snapshot))  # type: ignore

                with open(segment_plot_file, 'w+') as fd:
                    fd.write(torch.cuda._memory_viz.segment_plot(snapshot))  # type: ignore

                with open(segment_flamegraph_file, 'w+') as fd:
                    fd.write(torch.cuda._memory_viz.segments(snapshot))  # type: ignore

                with open(memory_flamegraph_file, 'w+') as fd:
                    fd.write(torch.cuda._memory_viz.memory(snapshot))  # type: ignore

                log.info(f'Saved memory visualizations to local files with prefix = {filename} during OOM')

                if self.remote_path_in_bucket is not None:
                    for f in [
                            snapshot_file, trace_plot_file, segment_plot_file, segment_flamegraph_file,
                            memory_flamegraph_file
                    ]:
                        remote_file_name = (self.remote_path_in_bucket + os.path.basename(f)).lstrip('/')
                        log.info(f'Uploading memory visualization to remote: {remote_file_name} from {f}')
                        try:
                            logger.upload_file(remote_file_name=remote_file_name, file_path=f, overwrite=self.overwrite)
                        except FileExistsError as e:
                            raise FileExistsError(
                                f'Uploading memory visualizations failed with error: {e}. overwrite was set to '
                                f'{self.overwrite}. To overwrite memory visualizations with Trainer, set save_overwrite to True.'
                            ) from e

            except Exception as e:
                log.error(f'Failed to capture memory snapshot {e}')

        if self._enabled:
            torch.cuda.memory._record_memory_history(
                True,  # type: ignore
                trace_alloc_max_entries=self.max_entries,
                trace_alloc_record_context=True)
            torch._C._cuda_attach_out_of_memory_observer(oom_observer)  # type: ignore
            log.info('OOMObserver is enabled and registered')
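As a usage note, a saved ``_snapshot.pickle`` can be re-rendered offline with the same ``torch.cuda._memory_viz`` helpers the callback calls. A minimal sketch — the snapshot path below is a hypothetical example; substitute the file OOMObserver actually wrote:

# Sketch: re-render an OOM snapshot offline.
import pickle

from torch.cuda import _memory_viz

# Hypothetical path; OOMObserver writes files under `folder` with prefix `filename`.
with open('my-run/torch_traces/rank0_oom_snapshot.pickle', 'rb') as fd:
    snapshot = pickle.load(fd)

# trace_plot returns an HTML string, exactly as in the callback above.
with open('rank0_oom_trace_plot.html', 'w') as fd:
    fd.write(_memory_viz.trace_plot(snapshot))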
6 changes: 3 additions & 3 deletions composer/trainer/trainer.py
@@ -36,7 +36,7 @@
from torch.utils.data import DataLoader, DistributedSampler
from torchmetrics import Metric

-from composer.callbacks import CheckpointSaver, MemorySnapshot, OptimizerMonitor
+from composer.callbacks import CheckpointSaver, MemorySnapshot, OOMObserver, OptimizerMonitor
from composer.core import (Algorithm, AlgorithmPass, Batch, Callback, DataSpec, Engine, Evaluator, Event, Precision,
                           State, Time, Timestamp, TimeUnit, TrainerMode, ensure_data_spec, ensure_evaluator,
                           ensure_time, get_precision_context, validate_eval_automicrobatching)
@@ -1072,9 +1072,9 @@ def __init__(
                loggers.append(remote_ud)
            self.state.profiler.bind_to_state(self.state)

-        # MemorySnapshot
+        # MemorySnapshot, OOMObserver
        for cb in self.state.callbacks:
-            if isinstance(cb, MemorySnapshot):
+            if isinstance(cb, MemorySnapshot) or isinstance(cb, OOMObserver):
                if cb.remote_file_name:
                    remote_ud = maybe_create_remote_uploader_downloader_from_uri(uri=cb.remote_file_name,
                                                                                 loggers=loggers)
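For illustration, a sketch of how the hookup above plays out: when ``OOMObserver.remote_file_name`` is a URI, the Trainer attaches a matching remote uploader automatically, so visualizations are uploaded on OOM without further wiring. The bucket name and durations here are hypothetical examples:

# Sketch, under the assumption that the s3 bucket below exists and is writable.
from torch.utils.data import DataLoader

from composer import Trainer
from composer.callbacks import OOMObserver
from tests.common import RandomClassificationDataset, SimpleModel

# The Trainer parses this URI and creates a remote uploader for it,
# per the `maybe_create_remote_uploader_downloader_from_uri` loop above.
oom_observer = OOMObserver(remote_file_name='s3://my-bucket/{run_name}/oom_traces/rank{rank}_oom')

trainer = Trainer(
    model=SimpleModel(),
    train_dataloader=DataLoader(RandomClassificationDataset()),
    callbacks=[oom_observer],
    max_duration='1ba',
)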
1 change: 1 addition & 0 deletions docs/source/trainer/callbacks.rst
@@ -51,6 +51,7 @@ components of training.
    ~optimizer_monitor.OptimizerMonitor
    ~memory_monitor.MemoryMonitor
    ~memory_snapshot.MemorySnapshot
+    ~oom_observer.OOMObserver
    ~nan_monitor.NaNMonitor
    ~image_visualizer.ImageVisualizer
    ~mlperf.MLPerfCallback
8 changes: 6 additions & 2 deletions tests/callbacks/callback_settings.py
@@ -12,8 +12,8 @@
import composer.profiler
from composer import Callback
from composer.callbacks import (EarlyStopper, ExportForInferenceCallback, FreeOutputs, Generate, ImageVisualizer,
-                                MemoryMonitor, MemorySnapshot, MLPerfCallback, SpeedMonitor, SystemMetricsMonitor,
-                                ThresholdStopper)
+                                MemoryMonitor, MemorySnapshot, MLPerfCallback, OOMObserver, SpeedMonitor,
+                                SystemMetricsMonitor, ThresholdStopper)
from composer.loggers import (CometMLLogger, ConsoleLogger, LoggerDestination, MLFlowLogger, ProgressBarLogger,
                              RemoteUploaderDownloader, TensorboardLogger, WandBLogger)
from composer.models.base import ComposerModel
@@ -132,6 +132,10 @@
        pytest.mark.filterwarnings(
            r'ignore:The memory snapshot only works on CUDA devices, but the model is on cpu:UserWarning')
    ],
+    OOMObserver: [
+        pytest.mark.filterwarnings(
+            r'ignore:The oom observer only works on CUDA devices, but the model is on cpu:UserWarning')
+    ],
    MLPerfCallback: [pytest.mark.skipif(not _MLPERF_INSTALLED, reason='MLPerf is optional')],
    WandBLogger: [
        pytest.mark.filterwarnings(r'ignore:unclosed file:ResourceWarning'),
6 changes: 6 additions & 0 deletions tests/callbacks/test_callbacks.py
@@ -43,12 +43,14 @@ class TestCallbacks:
    def setup_class(cls):
        pytest.importorskip('wandb', reason='WandB is optional.')

+    @pytest.mark.filterwarnings('ignore::UserWarning')
    def test_callback_is_constructable(self, cb_cls: Type[Callback]):
        cb_kwargs = get_cb_kwargs(cb_cls)
        cb = cb_cls(**cb_kwargs)
        assert isinstance(cb_cls, type)
        assert isinstance(cb, cb_cls)

+    @pytest.mark.filterwarnings('ignore::UserWarning')
    def test_multiple_fit_start_and_end(self, cb_cls: Type[Callback], dummy_state: State):
        """Test that callbacks do not crash when Event.FIT_START and Event.FIT_END is called multiple times."""
        cb_kwargs = get_cb_kwargs(cb_cls)
@@ -69,6 +71,7 @@ def test_multiple_fit_start_and_end(self, cb_cls: Type[Callback], dummy_state: State):
        engine.run_event(Event.FIT_START)
        engine.run_event(Event.FIT_END)

+    @pytest.mark.filterwarnings('ignore::UserWarning')
    def test_idempotent_close(self, cb_cls: Type[Callback], dummy_state: State):
        """Test that callbacks do not crash when .close() and .post_close() are called multiple times."""
        cb_kwargs = get_cb_kwargs(cb_cls)
@@ -85,6 +88,7 @@ def test_idempotent_close(self, cb_cls: Type[Callback], dummy_state: State):
        engine.close()
        engine.close()

+    @pytest.mark.filterwarnings('ignore::UserWarning')
    def test_multiple_init_and_close(self, cb_cls: Type[Callback], dummy_state: State):
        """Test that callbacks do not crash when INIT/.close()/.post_close() are called multiple times in that order."""
        cb_kwargs = get_cb_kwargs(cb_cls)
@@ -136,13 +140,15 @@ def _get_trainer(self, cb: Callback, device_train_microbatch_size: int):
                torch_prof_memory_filename=None),
        )

+    @pytest.mark.filterwarnings('ignore::UserWarning')
    def test_trains(self, cb_cls: Type[Callback], device_train_microbatch_size: int, _remote: bool):
        del _remote  # unused. `_remote` must be passed through to parameterize the test markers.
        cb_kwargs = get_cb_kwargs(cb_cls)
        cb = cb_cls(**cb_kwargs)
        trainer = self._get_trainer(cb, device_train_microbatch_size)
        trainer.fit()

+    @pytest.mark.filterwarnings('ignore::UserWarning')
    def test_trains_multiple_calls(self, cb_cls: Type[Callback], device_train_microbatch_size: int, _remote: bool):
        """
        Tests that training with multiple fits complete.
1 change: 1 addition & 0 deletions tests/callbacks/test_loggers_across_callbacks.py
@@ -15,6 +15,7 @@

@pytest.mark.parametrize('logger_cls', get_cbs_and_marks(loggers=True))
@pytest.mark.parametrize('callback_cls', get_cbs_and_marks(callbacks=True))
+@pytest.mark.filterwarnings('ignore::UserWarning')
def test_loggers_on_callbacks(logger_cls: Type[LoggerDestination], callback_cls: Type[Callback]):
    if logger_cls in [ProgressBarLogger, ConsoleLogger, SlackLogger]:
        pytest.skip()
88 changes: 88 additions & 0 deletions tests/callbacks/test_oom_observer.py
@@ -0,0 +1,88 @@
# Copyright 2022 MosaicML Composer authors
# SPDX-License-Identifier: Apache-2.0

import pathlib

import pytest
import torch
from packaging import version
from torch.utils.data import DataLoader

from composer import State, Trainer
from composer.callbacks import MemorySnapshot, OOMObserver
from composer.loggers import LoggerDestination
from tests.common import RandomClassificationDataset, SimpleModel


@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('2.1.0'),
                    reason='OOM Observer requires PyTorch 2.1 or higher')
def test_oom_observer_warnings_on_cpu_models():
    ob = OOMObserver()
    with pytest.warns(UserWarning):
        Trainer(
            model=SimpleModel(),
            callbacks=ob,
            train_dataloader=DataLoader(RandomClassificationDataset()),
            max_duration='1ba',
            device='cpu',
        )
    assert ob._enabled is False


class FileUploaderTracker(LoggerDestination):

    def __init__(self) -> None:
        self.uploaded_files = []

    def upload_file(self, state: State, remote_file_name: str, file_path: pathlib.Path, *, overwrite: bool):
        del state, overwrite  # unused
        self.uploaded_files.append((remote_file_name, file_path))


@pytest.mark.gpu
@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('2.1.0'),
                    reason='OOM Observer requires PyTorch 2.1 or higher')
def test_oom_observer():
    # Construct the callbacks
    oom_observer = OOMObserver()
    simple_model = SimpleModel()
    file_tracker_destination = FileUploaderTracker()

    with pytest.raises(torch.cuda.OutOfMemoryError):
        trainer = Trainer(
            model=simple_model,
            loggers=file_tracker_destination,
            callbacks=oom_observer,
            train_dataloader=DataLoader(RandomClassificationDataset()),
            max_duration='2ba',
        )

        # trigger OOM
        torch.empty(1024 * 1024 * 1024 * 1024, device='cuda')

        trainer.fit()

    assert len(file_tracker_destination.uploaded_files) == 5


@pytest.mark.gpu
@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('2.1.0'),
                    reason='OOM Observer requires PyTorch 2.1 or higher')
def test_oom_observer_with_memory_snapshot():
    # Construct the callbacks
    oom_observer = OOMObserver()
    memory_snapshot = MemorySnapshot(skip_batches=0, interval='1ba')
    simple_model = SimpleModel()
    file_tracker_destination = FileUploaderTracker()

    trainer = Trainer(
        model=simple_model,
        loggers=file_tracker_destination,
        callbacks=[oom_observer, memory_snapshot],
        train_dataloader=DataLoader(RandomClassificationDataset()),
        max_duration='2ba',
    )

    trainer.fit()

    assert len(file_tracker_destination.uploaded_files) == 1
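Note: the two GPU-marked tests above are skipped on CPU-only machines. Assuming the repository's standard pytest markers, they can be run on a GPU host with:

pytest -m gpu tests/callbacks/test_oom_observer.py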
1 change: 1 addition & 0 deletions tests/loggers/test_mosaicml_logger.py
@@ -85,6 +85,7 @@ def test_format_data_to_json_serializable():

@pytest.mark.parametrize('callback_cls', get_cbs_and_marks(callbacks=True))
@world_size(1, 2)
+@pytest.mark.filterwarnings('ignore::UserWarning')
def test_logged_data_is_json_serializable(monkeypatch, callback_cls: Type[Callback], world_size):
    """Test that all logged data is json serializable, which is a requirement to use MAPI."""
