Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add OOM observer with memory visualizations #2958

Merged
merged 34 commits into from
Feb 2, 2024
Merged
Show file tree
Hide file tree
Changes from 21 commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
bb75f65
add oomobserver
cli99 Feb 1, 2024
0054b52
update docstring
cli99 Feb 1, 2024
2e96d9c
Merge branch 'dev' into oom-observer
cli99 Feb 1, 2024
92faf4f
Update composer/callbacks/oom_observer.py
cli99 Feb 1, 2024
a0c6696
use pyskip
cli99 Feb 1, 2024
99395a8
call trainer fit
cli99 Feb 1, 2024
875546a
fix ci
cli99 Feb 1, 2024
66d093c
Update composer/callbacks/oom_observer.py
cli99 Feb 2, 2024
f2f94d3
addresss comments
cli99 Feb 2, 2024
43118ca
Merge branch 'dev' into oom-observer
cli99 Feb 2, 2024
5a74d34
Update composer/callbacks/oom_observer.py
cli99 Feb 2, 2024
ba6c859
add test wiht snapshot
cli99 Feb 2, 2024
637208d
update doc
cli99 Feb 2, 2024
cc23887
fix typo
cli99 Feb 2, 2024
c314a5c
use log info
cli99 Feb 2, 2024
1d0553b
fix format
cli99 Feb 2, 2024
1f2bf43
fix format
cli99 Feb 2, 2024
95104b5
fix ci
cli99 Feb 2, 2024
1faf75b
fix cpu test
cli99 Feb 2, 2024
bddca6c
Merge branch 'dev' into oom-observer
cli99 Feb 2, 2024
9d4e02d
fix ci
cli99 Feb 2, 2024
1e8c98f
Update tests/callbacks/test_oom_observer.py
cli99 Feb 2, 2024
f5d6db7
Update composer/callbacks/oom_observer.py
cli99 Feb 2, 2024
b48a720
Update composer/callbacks/oom_observer.py
cli99 Feb 2, 2024
b860bc0
Update composer/callbacks/oom_observer.py
cli99 Feb 2, 2024
07a8bec
update test
cli99 Feb 2, 2024
f91b854
Update tests/callbacks/test_oom_observer.py
cli99 Feb 2, 2024
7b7f30c
Update tests/callbacks/test_oom_observer.py
cli99 Feb 2, 2024
78bce44
Update tests/callbacks/test_oom_observer.py
cli99 Feb 2, 2024
818f772
Update composer/callbacks/oom_observer.py
cli99 Feb 2, 2024
c0ca7aa
Update composer/callbacks/oom_observer.py
cli99 Feb 2, 2024
5a07ae4
use warnings
cli99 Feb 2, 2024
74c66ce
add pytest filter user warnings in cpu callback tests
cli99 Feb 2, 2024
fe3dd2c
fix typo
cli99 Feb 2, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions composer/callbacks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from composer.callbacks.memory_snapshot import MemorySnapshot
from composer.callbacks.mlperf import MLPerfCallback
from composer.callbacks.nan_monitor import NaNMonitor
from composer.callbacks.oom_observer import OOMObserver
from composer.callbacks.optimizer_monitor import OptimizerMonitor
from composer.callbacks.runtime_estimator import RuntimeEstimator
from composer.callbacks.speed_monitor import SpeedMonitor
Expand All @@ -42,4 +43,5 @@
'Generate',
'FreeOutputs',
'MemorySnapshot',
'OOMObserver',
]
177 changes: 177 additions & 0 deletions composer/callbacks/oom_observer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,177 @@
# Copyright 2022 MosaicML Composer authors
# SPDX-License-Identifier: Apache-2.0

"""Generate a memory snapshot during an OutOfMemory exception."""

import logging
import os
import pickle
from typing import Optional

import torch.cuda
from packaging import version

from composer import State
from composer.core import Callback, State
from composer.loggers import Logger
from composer.utils import ensure_folder_is_empty, format_name_with_dist, format_name_with_dist_and_time, parse_uri

log = logging.getLogger(__name__)

__all__ = ['OOMObserver']


class OOMObserver(Callback):
    """Generate visualizations of the state of allocated memory during an OutOfMemory exception.

    This callback registers an observer with the allocator that will be called every time it is about to raise an
    OutOfMemoryError, before any memory has been released while unwinding the exception. OOMObserver is attached to
    the Trainer at init stage. The visualizations include a snapshot of the memory state, a trace plot, a segment
    plot, a segment flamegraph, and a memory flamegraph.

    Example:
        .. doctest::

            >>> from composer import Trainer
            >>> from composer.callbacks import OOMObserver
            >>> # constructing trainer object with this callback
            >>> trainer = Trainer(
            ...     model=model,
            ...     train_dataloader=train_dataloader,
            ...     eval_dataloader=eval_dataloader,
            ...     optimizers=optimizer,
            ...     max_duration="1ep",
            ...     callbacks=[OOMObserver()],
            ... )

    .. note::
        OOMObserver is only supported for GPU devices.

    Args:
        max_entries (int, optional): Maximum number of memory alloc/free events to record. Defaults to 100000.
        folder (str, optional): A format string describing the folder containing the memory visualization files.
            Defaults to ``'{{run_name}}/torch_traces'``.
        filename (str, optional): A format string describing the prefix used to name the memory visualization files.
            Defaults to ``'rank{{rank}}_oom'``.
        remote_file_name (str, optional): A format string describing the prefix for the memory visualization remote
            file name. Defaults to ``'{{run_name}}/oom_traces/rank{{rank}}_oom'``.

            Whenever a trace file is saved, it is also uploaded as a file according to this format string.
            The same format variables as for ``filename`` are available.

            .. seealso:: :doc:`Uploading Files</trainer/file_uploading>` for notes for file uploading.

            Leading slashes (``'/'``) will be stripped.

            To disable uploading trace files, set this parameter to ``None``.
        overwrite (bool, optional): Whether to override existing memory snapshots. Defaults to False.

            If False, then the trace folder as determined by ``folder`` must be empty.
    """

    def __init__(
        self,
        max_entries: int = 100000,
        folder: str = '{run_name}/torch_traces',
        filename: str = 'rank{rank}_oom',
        remote_file_name: Optional[str] = '{run_name}/oom_traces/rank{rank}_oom',
        overwrite: bool = False,
    ) -> None:
        self.max_entries = max_entries
        self.folder = folder
        # Resolved to a concrete path in ``init``, where the run name is available.
        self.folder_name = None
        self.filename = filename
        self.remote_file_name = remote_file_name
        self.overwrite = overwrite
        if remote_file_name:
            # Only the path component of the URI is used to build remote file names.
            _, _, self.remote_path_in_bucket = parse_uri(remote_file_name)
        else:
            self.remote_path_in_bucket = None

        # Strip any ``.dev`` suffix before comparing so nightly builds parse as their base version.
        if version.parse(torch.__version__.split('.dev')[0]) >= version.parse('2.1.0'):  # type: ignore
            self._enabled = True
        else:
            self._enabled = False
            log.warning('OOMObserver is supported after PyTorch 2.1.0. Skipping oom observer callback.')

    def init(self, state: State, logger: Logger) -> None:
        """Prepare the output folder and register the OOM observer with the CUDA allocator.

        Does nothing if the callback was disabled at construction time. If the model is not on a ``cuda`` or
        ``meta`` device, the callback logs a warning and disables itself instead of registering the observer.

        Args:
            state (State): The training state; provides the model, run name and timestamp.
            logger (Logger): The logger used to upload the generated files.
        """
        if not self._enabled:
            return
        # Not relying on `torch.cuda.is_available()` since the model could be on CPU.
        model_device = next(state.model.parameters()).device

        if model_device.type not in ('cuda', 'meta'):
            log.warning(
                f'OOMObserver only works on CUDA devices, but the model is on {model_device.type}. OOMObserver is disabled'
            )
            self._enabled = False
        else:
            self.folder_name = format_name_with_dist(self.folder, state.run_name)
            os.makedirs(self.folder_name, exist_ok=True)
            if not self.overwrite:
                ensure_folder_is_empty(self.folder_name)

        def oom_observer(device: int, alloc: int, device_alloc: int, device_free: int):
            """Dump and optionally upload memory visualizations; the four arguments are unused."""
            # Snapshot right after an OOM happened
            log.warning('Out Of Memory (OOM) observed')

            assert self.filename
            assert self.folder_name, 'folder_name must be set in init'
            filename = os.path.join(
                self.folder_name,
                format_name_with_dist_and_time(self.filename, run_name=state.run_name, timestamp=state.timestamp),
            )

            # Best effort: any failure below is logged rather than raised from inside the observer.
            try:
                snapshot_file = filename + '_snapshot.pickle'
                trace_plot_file = filename + '_trace_plot.html'
                segment_plot_file = filename + '_segment_plot.html'
                segment_flamegraph_file = filename + '_segment_flamegraph.svg'
                memory_flamegraph_file = filename + '_memory_flamegraph.svg'
                log.info('Dumping OOMObserver visualizations')

                snapshot = torch.cuda.memory._snapshot()
                # No data was recorded - avoids a `ValueError` in `trace_plot`
                if all(len(t) == 0 for t in snapshot['device_traces']):
                    log.info('No allocation is recorded in memory snapshot')
                    return

                with open(snapshot_file, 'wb') as fd:
                    pickle.dump(snapshot, fd)

                with open(trace_plot_file, 'w+') as fd:
                    fd.write(torch.cuda._memory_viz.trace_plot(snapshot))  # type: ignore

                with open(segment_plot_file, 'w+') as fd:
                    fd.write(torch.cuda._memory_viz.segment_plot(snapshot))  # type: ignore

                with open(segment_flamegraph_file, 'w+') as fd:
                    fd.write(torch.cuda._memory_viz.segments(snapshot))  # type: ignore

                with open(memory_flamegraph_file, 'w+') as fd:
                    fd.write(torch.cuda._memory_viz.memory(snapshot))  # type: ignore

                log.info(f'Saved memory visualizations to local files with prefix = {filename} during OOM')

                if self.remote_path_in_bucket is not None:
                    for f in [
                            snapshot_file,
                            trace_plot_file,
                            segment_plot_file,
                            segment_flamegraph_file,
                            memory_flamegraph_file,
                    ]:
                        remote_file_name = (self.remote_path_in_bucket + os.path.basename(f)).lstrip('/')
                        log.info(f'Uploading memory visualization to remote: {remote_file_name} from {f}')
                        try:
                            logger.upload_file(remote_file_name=remote_file_name, file_path=f, overwrite=self.overwrite)
                        except FileExistsError as e:
                            # Re-raised with guidance; the enclosing handler logs it.
                            raise FileExistsError(
                                f'Uploading memory visualizations failed with error: {e}. overwrite was set to {self.overwrite}. To overwrite memory visualizations, set overwrite to True.'
                            ) from e

            except Exception as e:
                log.error(f'Failed to capture memory snapshot {e}')

        # ``_enabled`` may have been cleared above when the model is not on a supported device.
        if self._enabled:
            torch.cuda.memory._record_memory_history(
                True,  # type: ignore
                trace_alloc_max_entries=self.max_entries,
                trace_alloc_record_context=True)
            torch._C._cuda_attach_out_of_memory_observer(oom_observer)  # type: ignore
            log.info('OOMObserver is enabled and registered')
6 changes: 3 additions & 3 deletions composer/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
from torch.utils.data import DataLoader, DistributedSampler
from torchmetrics import Metric

from composer.callbacks import CheckpointSaver, MemorySnapshot, OptimizerMonitor
from composer.callbacks import CheckpointSaver, MemorySnapshot, OOMObserver, OptimizerMonitor
from composer.core import (Algorithm, AlgorithmPass, Batch, Callback, DataSpec, Engine, Evaluator, Event, Precision,
State, Time, Timestamp, TimeUnit, TrainerMode, ensure_data_spec, ensure_evaluator,
ensure_time, get_precision_context, validate_eval_automicrobatching)
Expand Down Expand Up @@ -1072,9 +1072,9 @@ def __init__(
loggers.append(remote_ud)
self.state.profiler.bind_to_state(self.state)

# MemorySnapshot
# MemorySnapshot, OOMObserver
for cb in self.state.callbacks:
if isinstance(cb, MemorySnapshot):
if isinstance(cb, MemorySnapshot) or isinstance(cb, OOMObserver):
if cb.remote_file_name:
remote_ud = maybe_create_remote_uploader_downloader_from_uri(uri=cb.remote_file_name,
loggers=loggers)
Expand Down
1 change: 1 addition & 0 deletions docs/source/trainer/callbacks.rst
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ components of training.
~optimizer_monitor.OptimizerMonitor
~memory_monitor.MemoryMonitor
~memory_snapshot.MemorySnapshot
~oom_observer.OOMObserver
~nan_monitor.NaNMonitor
~image_visualizer.ImageVisualizer
~mlperf.MLPerfCallback
Expand Down
8 changes: 6 additions & 2 deletions tests/callbacks/callback_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@
import composer.profiler
from composer import Callback
from composer.callbacks import (EarlyStopper, ExportForInferenceCallback, FreeOutputs, Generate, ImageVisualizer,
MemoryMonitor, MemorySnapshot, MLPerfCallback, SpeedMonitor, SystemMetricsMonitor,
ThresholdStopper)
MemoryMonitor, MemorySnapshot, MLPerfCallback, OOMObserver, SpeedMonitor,
SystemMetricsMonitor, ThresholdStopper)
from composer.loggers import (CometMLLogger, ConsoleLogger, LoggerDestination, MLFlowLogger, ProgressBarLogger,
RemoteUploaderDownloader, TensorboardLogger, WandBLogger)
from composer.models.base import ComposerModel
Expand Down Expand Up @@ -132,6 +132,10 @@
pytest.mark.filterwarnings(
r'ignore:The memory snapshot only works on CUDA devices, but the model is on cpu:UserWarning')
],
OOMObserver: [
pytest.mark.filterwarnings(
r'ignore:The oom observer only works on CUDA devices, but the model is on cpu:UserWarning')
],
MLPerfCallback: [pytest.mark.skipif(not _MLPERF_INSTALLED, reason='MLPerf is optional')],
WandBLogger: [
pytest.mark.filterwarnings(r'ignore:unclosed file:ResourceWarning'),
Expand Down
97 changes: 97 additions & 0 deletions tests/callbacks/test_oom_observer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
# Copyright 2022 MosaicML Composer authors
# SPDX-License-Identifier: Apache-2.0

import pathlib

import pytest
import torch
from packaging import version
from torch.utils.data import DataLoader

from composer import State, Trainer
from composer.callbacks import MemorySnapshot, OOMObserver
from composer.loggers import LoggerDestination
from composer.trainer import Trainer
from tests.common import RandomClassificationDataset, SimpleModel, device


@device('cpu', 'gpu')
@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('2.1.0'),
                    reason='OOM Observer requires PyTorch 2.1 or higher')
def test_oom_observer_warnings_on_cpu_models(device: str):
    """OOMObserver disables itself when the trainer is forced onto the CPU."""
    # The fixture value is deliberately ignored: the trainer below is always
    # built with device='cpu', even on runs where CUDA is available.
    del device

    observer = OOMObserver()
    trainer_kwargs = dict(
        model=SimpleModel(),
        callbacks=observer,
        device='cpu',
        train_dataloader=DataLoader(RandomClassificationDataset()),
        max_duration='1ba',
    )
    # Constructing the Trainer is what exercises the callback; no training is run.
    Trainer(**trainer_kwargs)

    assert observer._enabled is False
cli99 marked this conversation as resolved.
Show resolved Hide resolved


class FileUploaderTracker(LoggerDestination):
    """Logger destination that records upload requests instead of performing them."""

    def __init__(self) -> None:
        # One ``(remote_file_name, file_path)`` tuple per ``upload_file`` call, in call order.
        self.uploaded_files = []

    def upload_file(self, state: State, remote_file_name: str, file_path: pathlib.Path, *, overwrite: bool):
        """Remember the requested upload; ``state`` and ``overwrite`` are ignored."""
        del state, overwrite  # unused
        record = (remote_file_name, file_path)
        self.uploaded_files.append(record)


@pytest.mark.gpu
@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('2.1.0'),
                    reason='OOM Observer requires PyTorch 2.1 or higher')
def test_oom_observer():
    """A CUDA OOM makes OOMObserver hand five visualization files to the logger."""
    # Construct the callbacks
    oom_observer = OOMObserver()

    simple_model = SimpleModel()

    file_tracker_destination = FileUploaderTracker()

    # Built outside ``pytest.raises`` so that an OOM during setup cannot satisfy the
    # expectation below. The reference is kept for the duration of the test
    # (presumably the trainer must stay alive for the observer to remain usable -- TODO confirm).
    trainer = Trainer(
        model=simple_model,
        loggers=file_tracker_destination,
        callbacks=oom_observer,
        train_dataloader=DataLoader(RandomClassificationDataset()),
        max_duration='2ba',
    )

    with pytest.raises(torch.cuda.OutOfMemoryError):
        # trigger OOM
        torch.empty(1024 * 1024 * 1024 * 1024, device='cuda')

    # snapshot pickle + trace plot + segment plot + segment flamegraph + memory flamegraph
    assert len(file_tracker_destination.uploaded_files) == 5
    del trainer


@pytest.mark.gpu
@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('2.1.0'),
                    reason='OOM Observer requires PyTorch 2.1 or higher')
def test_oom_observer_with_memory_snapshot():
    """OOMObserver and MemorySnapshot can be attached to the same Trainer."""
    # Both callbacks are created with OOMObserver first, matching the order passed to the Trainer.
    observer = OOMObserver()
    snapshotter = MemorySnapshot(skip_batches=0, interval='1ba')

    model = SimpleModel()
    upload_tracker = FileUploaderTracker()

    trainer = Trainer(
        model=model,
        loggers=upload_tracker,
        callbacks=[observer, snapshotter],
        train_dataloader=DataLoader(RandomClassificationDataset()),
        max_duration='2ba',
    )
    trainer.fit()

    # Exactly one file is expected to have reached the logger destination after the run.
    assert len(upload_tracker.uploaded_files) == 1
Loading