Commit
Merge branch 'dev' into anna/watchdog
aspfohl authored Feb 27, 2024
2 parents 21ca379 + 0814e01 commit 7e43f68
Showing 30 changed files with 884 additions and 431 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -3,7 +3,7 @@ default_language_version:
repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
  # Ruff version.
  rev: v0.0.282
  rev: v0.2.2
  hooks:
  - id: ruff
    args: [--fix, --exit-non-zero-on-fix]
2 changes: 1 addition & 1 deletion composer/_version.py
@@ -3,4 +3,4 @@

"""The Composer Version."""

__version__ = '0.19.1'
__version__ = '0.20.0'
52 changes: 51 additions & 1 deletion composer/callbacks/memory_monitor.py
@@ -7,7 +7,9 @@
import warnings
from typing import Dict, Optional, Union

import torch
import torch.cuda
from torch import distributed

from composer.core import Callback, State
from composer.loggers import Logger
@@ -17,6 +19,37 @@
__all__ = ['MemoryMonitor']


def reduce_value(
    value: Union[int, float],
    model_device: torch.device,
    reduce_op: str = 'mean',
):
    """Reduce a value across distributed processes.

    Args:
        value (Union[int, float]): The value to reduce.
        model_device (torch.device): The device on which the model is located.
        reduce_op (str, optional): The reduction operation to perform. One of 'mean', 'avg', 'sum', 'min', 'max'.
            Defaults to 'mean'.
    """
    tensor_value = torch.tensor(value, device=model_device)

    if reduce_op in ['mean', 'avg', 'sum']:
        op = distributed.ReduceOp.SUM
    elif reduce_op == 'min':
        op = distributed.ReduceOp.MIN
    elif reduce_op == 'max':
        op = distributed.ReduceOp.MAX
    else:
        raise ValueError(f'{reduce_op=} not supported.')

    distributed.all_reduce(tensor_value, op=op)
    if reduce_op in ['mean', 'avg']:
        tensor_value = tensor_value / distributed.get_world_size()

    return tensor_value.item()


class MemoryMonitor(Callback):
"""Logs the memory usage of the model.
@@ -73,6 +106,9 @@ class MemoryMonitor(Callback):
| alloc_retries | Number of failed cudaMalloc calls that result in a cache flush and retry. |
+------------------------+-------------------------------------------------------------------------------------------+
Additionally, if `dist_aggregate_batch_interval` is enabled, the `avg`, `min`, and `max` of the
aforementioned statistics are also logged.
.. note::
Memory usage monitoring is only supported for GPU devices.
@@ -81,10 +117,17 @@ class MemoryMonitor(Callback):
are the names of memory statistics to log from `torch.cuda.memory_stats()`, and values
are the names they will be logged under. If not provided, the above statistics are
logged. Defaults to None.
dist_aggregate_batch_interval (int, optional): Interval, in batches, at which memory stats are
aggregated across all ranks. Defaults to None (aggregation disabled).
"""

    def __init__(self, memory_keys: Optional[Dict[str, str]] = None) -> None:
    def __init__(
        self,
        memory_keys: Optional[Dict[str, str]] = None,
        dist_aggregate_batch_interval: Optional[int] = None,
    ) -> None:
        self.memory_keys = memory_keys
        self.dist_aggregate_batch_interval = dist_aggregate_batch_interval

    def init(self, state: State, logger: Logger) -> None:
        # Not relying on `torch.cuda.is_available()` since the model could be on CPU.
@@ -101,6 +144,13 @@ def after_train_batch(self, state: State, logger: Logger):
            return

        memory_report = _get_memory_report(self.memory_keys)
        if self.dist_aggregate_batch_interval is not None and state.timestamp.batch.value % self.dist_aggregate_batch_interval == 0:
            dist_memory_report = {}
            for (mem_stat, val) in memory_report.items():
                dist_memory_report[mem_stat + '_avg'] = reduce_value(val, model_device, 'avg')
                dist_memory_report[mem_stat + '_min'] = reduce_value(val, model_device, 'min')
                dist_memory_report[mem_stat + '_max'] = reduce_value(val, model_device, 'max')
            memory_report.update(dist_memory_report)

        logger.log_metrics({f'memory/{mem_stat}': val for (mem_stat, val) in memory_report.items()})

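
The new `dist_aggregate_batch_interval` option and the `reduce_value` helper add cross-rank aggregation to the memory monitor: every N batches the per-rank statistics are all-reduced and their avg/min/max variants are logged alongside the per-rank values. Below is a minimal sketch of how the option might be enabled, assuming Composer's standard `Trainer` API; the model, dataloader, and interval are placeholders, not part of this commit.

from composer import Trainer
from composer.callbacks import MemoryMonitor

# Placeholder interval: aggregate avg/min/max of each memory statistic across
# ranks every 100 train batches; per-rank stats are still logged every batch.
memory_monitor = MemoryMonitor(dist_aggregate_batch_interval=100)

trainer = Trainer(
    model=model,                        # placeholder: a ComposerModel defined elsewhere
    train_dataloader=train_dataloader,  # placeholder: defined elsewhere
    max_duration='1ep',
    callbacks=[memory_monitor],
)
trainer.fit()

Gating the aggregation on a batch interval keeps the extra all-reduce traffic infrequent relative to the per-batch logging.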
13 changes: 7 additions & 6 deletions composer/cli/launcher.py
@@ -22,7 +22,8 @@
import torch

import composer
from composer.loggers.mosaicml_logger import MOSAICML_LOG_DIR_ENV_VAR, MOSAICML_PLATFORM_ENV_VAR
from composer.loggers.mosaicml_logger import (MOSAICML_GPU_LOG_FILE_PREFIX_ENV_VAR, MOSAICML_LOG_DIR_ENV_VAR,
                                              MOSAICML_PLATFORM_ENV_VAR)
from composer.utils import get_free_tcp_port

CLEANUP_TIMEOUT = datetime.timedelta(seconds=30)
@@ -470,7 +471,7 @@ def main():
    args = _parse_args()

    logging.basicConfig()
    log.setLevel(logging.INFO if args.verbose else logging.WARN)
    log.setLevel(logging.INFO if args.verbose else logging.WARNING)

    processes = {}

@@ -481,11 +482,11 @@ def main():
        args.stderr = f'{log_tmpdir.name}/rank{{rank}}.stderr.txt'

    # If running on the Mosaic platform, log all gpu ranks' stderr and stdout to Mosaic platform
    if os.environ.get(
            MOSAICML_PLATFORM_ENV_VAR,
            'false').lower() == 'true' and str(os.environ.get(MOSAICML_LOG_DIR_ENV_VAR, 'false')).lower() != 'false':
    if os.environ.get(MOSAICML_PLATFORM_ENV_VAR, 'false').lower() == 'true' and str(
            os.environ.get(MOSAICML_LOG_DIR_ENV_VAR, 'false')).lower() != 'false' and os.environ.get(
                MOSAICML_GPU_LOG_FILE_PREFIX_ENV_VAR, 'false').lower() != 'false':
        log.info('Logging all GPU ranks to Mosaic Platform.')
        log_file_format = f'{os.environ.get(MOSAICML_LOG_DIR_ENV_VAR)}/gpu_{{rank}}.txt'
        log_file_format = f'{os.environ.get(MOSAICML_LOG_DIR_ENV_VAR)}/{os.environ.get(MOSAICML_GPU_LOG_FILE_PREFIX_ENV_VAR)}{{local_rank}}.txt'
        if args.stderr is not None or args.stdout is not None:
            warnings.warn(
                'Logging to Mosaic Platform. Ignoring provided stdout and stderr args. To use provided stdout and stderr, set MOSAICML_LOG_DIR=false.'
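
The launcher change tightens the gate for forwarding per-GPU logs to the Mosaic platform: besides the platform flag (MOSAICML_PLATFORM_ENV_VAR) being 'true' and the log directory (MOSAICML_LOG_DIR_ENV_VAR) being set, the new GPU log file prefix (MOSAICML_GPU_LOG_FILE_PREFIX_ENV_VAR) must now also be set, and the log file name is built from that prefix plus the local rank rather than the previous hard-coded gpu_{rank} pattern. Below is a standalone sketch of that condition, restated as a helper purely for readability; `should_forward_gpu_logs` is a hypothetical name, not part of launcher.py.

import os

from composer.loggers.mosaicml_logger import (MOSAICML_GPU_LOG_FILE_PREFIX_ENV_VAR, MOSAICML_LOG_DIR_ENV_VAR,
                                              MOSAICML_PLATFORM_ENV_VAR)


def should_forward_gpu_logs() -> bool:
    """Mirror the launcher's check: all three environment variables must be set."""
    on_platform = os.environ.get(MOSAICML_PLATFORM_ENV_VAR, 'false').lower() == 'true'
    has_log_dir = str(os.environ.get(MOSAICML_LOG_DIR_ENV_VAR, 'false')).lower() != 'false'
    has_file_prefix = os.environ.get(MOSAICML_GPU_LOG_FILE_PREFIX_ENV_VAR, 'false').lower() != 'false'
    return on_platform and has_log_dir and has_file_prefix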