Commit
Merge branch 'dev' into composer_lora
dakinggg committed Jan 18, 2024
2 parents e9d4c4c + 1b393b6 commit 794fc7c
Showing 9 changed files with 97 additions and 411 deletions.
2 changes: 1 addition & 1 deletion composer/_version.py
@@ -3,4 +3,4 @@
 
 """The Composer Version."""
 
-__version__ = '0.17.2'
+__version__ = '0.18.0'
10 changes: 9 additions & 1 deletion composer/loggers/mlflow_logger.py
@@ -5,6 +5,7 @@
 
 from __future__ import annotations
 
+import fnmatch
 import os
 import pathlib
 import textwrap
@@ -54,6 +55,7 @@ class MLFlowLogger(LoggerDestination):
             synchronously to the MLflow backend. If ``False``, Mlflow will log asynchronously. (default: ``False``)
         log_system_metrics (bool, optional): Whether to log system metrics. If ``True``, Mlflow will
             log system metrics (CPU/GPU/memory/network usage) during training. (default: ``True``)
+        ignore_metrics (List[str], optional): A list of glob patterns for metrics to ignore when logging. (default: ``None``)
     """
 
     def __init__(
@@ -68,6 +70,7 @@ def __init__(
         model_registry_uri: Optional[str] = None,
         synchronous: bool = False,
         log_system_metrics: bool = True,
+        ignore_metrics: Optional[List[str]] = None,
     ) -> None:
         try:
             import mlflow
@@ -85,6 +88,7 @@ def __init__(
         self.model_registry_uri = model_registry_uri
         self.synchronous = synchronous
         self.log_system_metrics = log_system_metrics
+        self.ignore_metrics = [] if ignore_metrics is None else ignore_metrics
         if self.model_registry_uri == 'databricks-uc':
             if len(self.model_registry_prefix.split('.')) != 2:
                 raise ValueError(f'When registering to Unity Catalog, model_registry_prefix must be in the format ' +
@@ -198,7 +202,11 @@ def log_metrics(self, metrics: Dict[str, Any], step: Optional[int] = None) -> None:
         from mlflow import log_metrics
         if self._enabled:
             # Convert all metrics to floats to placate mlflow.
-            metrics = {k: float(v) for k, v in metrics.items()}
+            metrics = {
+                k: float(v)
+                for k, v in metrics.items()
+                if not any(fnmatch.fnmatch(k, pattern) for pattern in self.ignore_metrics)
+            }
             log_metrics(
                 metrics=metrics,
                 step=step,
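The new ignore_metrics option filters metrics by glob pattern before they reach MLflow. Below is a minimal, standalone sketch of that behavior; the patterns and metric names are hypothetical, and only the fnmatch-based filtering expression comes from the hunk above.

import fnmatch

# Hypothetical glob patterns, as a user might pass via MLFlowLogger(ignore_metrics=[...]).
ignore_metrics = ['metrics/train/*_unused', 'throughput/*']

# Hypothetical metrics dict, as handed to log_metrics().
metrics = {
    'loss/train/total': 0.42,
    'metrics/train/accuracy_unused': 0.91,
    'throughput/samples_per_sec': 512.0,
}

# Same filtering expression as the patched log_metrics(): drop any key that
# matches one of the ignore patterns, and coerce the remaining values to float.
filtered = {
    k: float(v)
    for k, v in metrics.items()
    if not any(fnmatch.fnmatch(k, pattern) for pattern in ignore_metrics)
}

print(filtered)  # {'loss/train/total': 0.42}

In the logger itself, the same patterns are stored on self.ignore_metrics at construction time and applied automatically inside log_metrics.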
25 changes: 12 additions & 13 deletions composer/trainer/dist_strategy.py
@@ -661,19 +661,18 @@ def _check_fn(module: torch.nn.Module) -> bool:
 
     # Print FSDP wrapped model and FSDP config if `verbose=True`
     if fsdp_config['verbose']:
-        print(f'FSDP: Wrapped Model:')
-        print(model)
-        print(f'FSDP: Using sharding_strategy={sharding_strategy}')
-        print(f'FSDP: Using cpu_offload={cpu_offload}')
-        print(f'FSDP: Using mixed_precision={mixed_precision}')
-        print(f'FSDP: Using backward_prefetch={backward_prefetch}')
-        print(f'FSDP: Using activation_checkpointing={activation_checkpointing}')
-        print(f'FSDP: Using activation_cpu_offload={activation_cpu_offload}')
-        print(f'FSDP: Using sync_module_states={sync_module_states}')
-        print(f'FSDP: Using forward_prefetch={forward_prefetch}')
-        print(f'FSDP: Using limit_all_gathers={limit_all_gathers}')
-        print(f'FSDP: Using state_dict_type={state_dict_type}')
-        print(f'FSDP: Using sharded_ckpt_prefix_dir={sharded_ckpt_prefix_dir}')
+        log.info(f'FSDP: Wrapped model: {model}')
+        log.info(f'FSDP: Using sharding_strategy={sharding_strategy}')
+        log.info(f'FSDP: Using cpu_offload={cpu_offload}')
+        log.info(f'FSDP: Using mixed_precision={mixed_precision}')
+        log.info(f'FSDP: Using backward_prefetch={backward_prefetch}')
+        log.info(f'FSDP: Using activation_checkpointing={activation_checkpointing}')
+        log.info(f'FSDP: Using activation_cpu_offload={activation_cpu_offload}')
+        log.info(f'FSDP: Using sync_module_states={sync_module_states}')
+        log.info(f'FSDP: Using forward_prefetch={forward_prefetch}')
+        log.info(f'FSDP: Using limit_all_gathers={limit_all_gathers}')
+        log.info(f'FSDP: Using state_dict_type={state_dict_type}')
+        log.info(f'FSDP: Using sharded_ckpt_prefix_dir={sharded_ckpt_prefix_dir}')
 
     # Rebuild optimizer now that parameters are sharded
     if optimizers:
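Because the FSDP summary now goes through log.info instead of print, it is only visible when INFO-level logging is enabled. A minimal sketch of one way to surface it; the logger name below assumes the module uses the usual logging.getLogger(__name__) convention, which the diff does not show.

import logging

# Show INFO-level records (the new log.info calls) on stderr.
logging.basicConfig(level=logging.INFO, format='%(name)s: %(levelname)s: %(message)s')

# Or narrow the verbosity to just this module (logger name is an assumption).
logging.getLogger('composer.trainer.dist_strategy').setLevel(logging.INFO)

With fsdp_config['verbose'] set, the 'FSDP: Using ...' lines above then appear in the configured log output rather than on stdout.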
28 changes: 0 additions & 28 deletions composer/trainer/mosaic_fsdp.py
@@ -62,31 +62,13 @@ def patch_pytorch():
         from torch.distributed.fsdp import _runtime_utils
         _runtime_utils._validate_and_get_hybrid_shard_state = lambda *args, **kwargs: None
 
-        # # Better overlap communication and computation
-        # from composer.trainer.mosaic_fsdp_utils import (_root_pre_forward, _share_state_and_init_handle_attrs_t2p1,
-        # _wait_for_computation_stream, forward)
-        # _runtime_utils._share_state_and_init_handle_attrs = _share_state_and_init_handle_attrs_t2p1
-        # _runtime_utils._wait_for_computation_stream = _wait_for_computation_stream
-        # _runtime_utils._root_pre_forward = _root_pre_forward
-        # FullyShardedDataParallel.forward = forward
-
     elif version.parse(torch.__version__) < version.parse('2.2.1'):
         # Monkey patch for torch < 2.2.1 ie torch == 2.2.0
 
         # Allow 2D HSDP
         from torch.distributed.fsdp import _runtime_utils
         _runtime_utils._validate_and_get_hybrid_shard_state = lambda *args, **kwargs: None
 
-        # # Better overlap communication and computation
-        # from torch.distributed.fsdp import _runtime_utils
-
-        # from composer.trainer.mosaic_fsdp_utils import (_root_pre_forward, _share_state_and_init_handle_attrs_t2p2,
-        # _wait_for_computation_stream, forward)
-        # _runtime_utils._share_state_and_init_handle_attrs = _share_state_and_init_handle_attrs_t2p2
-        # _runtime_utils._wait_for_computation_stream = _wait_for_computation_stream
-        # _runtime_utils._root_pre_forward = _root_pre_forward
-        # FullyShardedDataParallel.forward = forward
-
         # Monkeypatch dtensor support
         from composer.trainer.mosaic_fsdp_utils import init_fn_t2p2p0
         FullyShardedDataParallel.__init__ = init_fn_t2p2p0  # type: ignore
@@ -106,16 +88,6 @@ def patch_pytorch():
         from torch.distributed.fsdp import _runtime_utils
         _runtime_utils._validate_and_get_hybrid_shard_state = lambda *args, **kwargs: None
 
-        # # Better overlap communication and computation
-        # from torch.distributed.fsdp import _runtime_utils
-
-        # from composer.trainer.mosaic_fsdp_utils import (_root_pre_forward, _share_state_and_init_handle_attrs_t2p2,
-        # _wait_for_computation_stream, forward)
-        # _runtime_utils._share_state_and_init_handle_attrs = _share_state_and_init_handle_attrs_t2p2
-        # _runtime_utils._wait_for_computation_stream = _wait_for_computation_stream
-        # _runtime_utils._root_pre_forward = _root_pre_forward
-        # FullyShardedDataParallel.forward = forward
-
         # Monkeypath state_dict
         from composer.trainer.mosaic_fsdp_utils import init_fn_t2p2p0
         FullyShardedDataParallel.__init__ = init_fn_t2p2p0
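This file drops the commented-out "better overlap" patches and keeps only the version-gated monkeypatches. For illustration, a minimal sketch of that gating pattern; the function name is made up, and the single branch shown mirrors the torch < 2.2.1 case kept above (the real patch_pytorch() has several version branches).

import torch
from packaging import version

def patch_torch_sketch():
    """Illustrative sketch of one branch of the version-gated patching in patch_pytorch()."""
    if version.parse(torch.__version__) < version.parse('2.2.1'):
        # Allow 2D HSDP by disabling the hybrid shard state validation,
        # as in the kept lines of the hunks above.
        from torch.distributed.fsdp import _runtime_utils
        _runtime_utils._validate_and_get_hybrid_shard_state = lambda *args, **kwargs: None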
(The remaining 5 changed files are not shown.)

0 comments on commit 794fc7c
