Add torch 2.3 support (#3209)
* update tests
* remove 2.0.1 monkeypatches
* lint
* update images
* lint ignore
* fix lint
* remove import
* rework typing
* add type check
* add type check
* type ignore
* fix docs build
* gating
* test named wrong
* lint
* fix test names
mvpatel2000 authored and Chuck Tang committed May 16, 2024
1 parent b5fda07 commit b4e692c
Showing 16 changed files with 167 additions and 793 deletions.
23 changes: 4 additions & 19 deletions .github/workflows/daily.yaml
@@ -18,11 +18,6 @@ jobs:
   strategy:
     matrix:
       include:
-      - name: cpu-3.10-2.0
-        container: mosaicml/pytorch:2.0.1_cpu-python3.10-ubuntu20.04
-        markers: not daily and (remote or not remote) and not gpu and not doctest
-        pytest_command: coverage run -m pytest
-        composer_package_name: mosaicml
       - name: cpu-3.10-2.1
         container: mosaicml/pytorch:2.1.2_cpu-python3.10-ubuntu20.04
         markers: not daily and (remote or not remote) and not gpu and not doctest
@@ -34,7 +29,7 @@ jobs:
         pytest_command: coverage run -m pytest
         composer_package_name: composer
       - name: cpu-3.11-2.2
-        container: mosaicml/pytorch:2.2.0_cpu-python3.11-ubuntu20.04
+        container: mosaicml/pytorch:2.2.1_cpu-python3.11-ubuntu20.04
         markers: not daily and (remote or not remote) and not gpu and not doctest
         pytest_command: coverage run -m pytest
         composer_package_name: mosaicml
@@ -43,11 +38,6 @@ jobs:
         markers: not daily and (remote or not remote) and not gpu and doctest
         pytest_command: coverage run -m pytest tests/test_docs.py
         composer_package_name: mosaicml
-      - name: daily-cpu-3.10-2.0
-        container: mosaicml/pytorch:2.0.1_cpu-python3.10-ubuntu20.04
-        markers: daily and (remote or not remote) and not gpu and not doctest
-        pytest_command: coverage run -m pytest
-        composer_package_name: mosaicml
       - name: daily-cpu-3.10-2.1
         container: mosaicml/pytorch:2.1.2_cpu-python3.10-ubuntu20.04
         markers: daily and (remote or not remote) and not gpu and not doctest
@@ -59,12 +49,12 @@ jobs:
         pytest_command: coverage run -m pytest
         composer_package_name: composer
       - name: daily-cpu-3.11-2.2
-        container: mosaicml/pytorch:2.2.0_cpu-python3.11-ubuntu20.04
+        container: mosaicml/pytorch:2.2.1_cpu-python3.11-ubuntu20.04
         markers: daily and (remote or not remote) and not gpu and not doctest
         pytest_command: coverage run -m pytest
         composer_package_name: mosaicml
       - name: daily-cpu-doctest
-        container: mosaicml/pytorch:2.1.2_cpu-python3.10-ubuntu20.04
+        container: mosaicml/pytorch:2.2.1_cpu-python3.11-ubuntu20.04
         markers: daily and (remote or not remote) and not gpu and doctest
         pytest_command: coverage run -m pytest tests/test_docs.py
         composer_package_name: mosaicml
@@ -106,17 +96,12 @@ jobs:
       # Unlike CPU tests, we run daily tests together with GPU tests to minimize launch time
       # on MCLOUD and not eat up all GPUs at once
       include:
-      - name: "gpu-3.10-2.0"
-        container: mosaicml/pytorch_vision:2.0.1_cu117-python3.10-ubuntu20.04
-        markers: "(daily or not daily) and (remote or not remote) and gpu and (doctest or not doctest)"
-        pytest_command: "coverage run -m pytest"
-        composer_package_name: "mosaicml"
       - name: "gpu-3.10-2.1"
         container: mosaicml/pytorch:2.1.2_cu121-python3.10-ubuntu20.04
         markers: "(daily or not daily) and (remote or not remote) and gpu and (doctest or not doctest)"
         pytest_command: "coverage run -m pytest"
         composer_package_name: "mosaicml"
-      - name: "gpu-3.10-2.2"
+      - name: "gpu-3.11-2.2"
         container: mosaicml/pytorch:2.2.0_cu121-python3.11-ubuntu20.04
         markers: "(daily or not daily) and (remote or not remote) and gpu and (doctest or not doctest)"
         pytest_command: "coverage run -m pytest"
8 changes: 4 additions & 4 deletions .github/workflows/pr-cpu.yaml
@@ -13,14 +13,14 @@ jobs:
   strategy:
     matrix:
       include:
-      - name: cpu-3.10-2.0
-        container: mosaicml/pytorch:2.0.1_cpu-python3.10-ubuntu20.04
-        markers: not daily and not remote and not gpu and not doctest
-        pytest_command: coverage run -m pytest
       - name: cpu-3.10-2.1
         container: mosaicml/pytorch:2.1.2_cpu-python3.10-ubuntu20.04
         markers: not daily and not remote and not gpu and not doctest
         pytest_command: coverage run -m pytest
+      - name: cpu-3.11-2.2
+        container: mosaicml/pytorch:2.2.1_cpu-python3.11-ubuntu20.04
+        markers: not daily and not remote and not gpu and not doctest
+        pytest_command: coverage run -m pytest
       - name: cpu-doctest
         container: mosaicml/pytorch:2.2.1_cpu-python3.11-ubuntu20.04
         markers: not daily and not remote and not gpu and doctest
4 changes: 2 additions & 2 deletions .github/workflows/pr-gpu.yaml
@@ -13,8 +13,8 @@ jobs:
   strategy:
     matrix:
       include:
-      - name: gpu-3.10-2.2
-        container: mosaicml/pytorch:2.2.1_cpu-python3.11-ubuntu20.04
+      - name: gpu-3.11-2.2
+        container: mosaicml/pytorch:2.2.1_cu121-python3.10-ubuntu20.04
         markers: not daily and not remote and gpu and (doctest or not doctest)
         pytest_command: coverage run -m pytest
         composer_package_name: mosaicml
2 changes: 1 addition & 1 deletion composer/algorithms/utils/augmentation_common.py
@@ -43,7 +43,7 @@ def image_as_type(image: _InputImgT, typ: Type[_OutputImgT]) -> _OutputImgT:
         raise TypeError(f'Only typ={{torch.Tensor, Image}} is supported; got {typ}')

     if typ is torch.Tensor:
-        return cast(_OutputImgT, torchvision.transforms.functional.to_tensor(image))  # PIL -> Tensor
+        return cast(_OutputImgT, torchvision.transforms.functional.to_tensor(image))  # type: ignore PIL -> Tensor
     return cast(_OutputImgT, torchvision.transforms.functional.to_pil_image(image))  # Tensor -> PIL

19 changes: 12 additions & 7 deletions composer/core/state.py
@@ -29,6 +29,11 @@
 from torch.utils.data import DataLoader, Dataset
 from torchmetrics import Metric

+if version.parse(torch.__version__) >= version.parse('2.3.0'):
+    from torch.amp.grad_scaler import GradScaler  # type: ignore
+else:
+    from torch.cuda.amp.grad_scaler import GradScaler  # type: ignore
+
 from composer.core.data_spec import DataSpec
 from composer.core.event import Event
 from composer.core.precision import Precision
@@ -242,7 +247,7 @@ class State(Serializable):
            train the model. Multiple optimizers are not currently supported.
        schedulers (LRScheduler | Sequence[LRScheduler], optional):
            The learning rate scheduler (can also be a list or tuple of schedulers).
-       scaler (torch.cuda.amp.GradScaler, optional): The gradient scaler in use for mixed precision training.
+       scaler (torch.amp.GradScaler, optional): The gradient scaler in use for mixed precision training.
        save_metrics (bool, optional): Whether to save metrics in state_dict.
        algorithms (Algorithm | Sequence[Algorithm], optional): The algorithms used for training.
        callbacks (Callback | Sequence[Callback], optional): The callbacks used for training.
@@ -326,7 +331,7 @@ class State(Serializable):
        profiler (Profiler): The profiler (if profiling is enabled), or ``None`` if not profiling.
        rank_zero_seed (int): The seed of the rank zero process.
        run_name (str): The name for this training run.
-       scaler (torch.amp.GradScaler): The gradient scaler if using mixed-precision training, or
+       scaler (torch.amp.GradScaler): The gradient scaler if using mixed-precision training, or
            ``None`` if not using mixed-precision training.
        serialized_attributes (List[str]): The names of the attribute which are serialized in a checkpoint.
@@ -404,7 +409,7 @@ def __init__(
        optimizers: Optional[Union[Optimizer, Sequence[Optimizer]]] = None,

        # scaler
-       scaler: Optional[torch.cuda.amp.grad_scaler.GradScaler] = None,
+       scaler: Optional[GradScaler] = None,

        # state_dict
        save_metrics: bool = False,
@@ -868,7 +873,7 @@ def get_model_state_dict(self) -> Dict[str, Any]:
        Returns:
            Dict[str, Any]: The state dict for the model.
        """
-       if version.parse(torch.__version__) > version.parse('2.2.9'):
+       if version.parse(torch.__version__) >= version.parse('2.3.0'):
            from torch.distributed.checkpoint.state_dict import StateDictOptions, get_model_state_dict
            if self.fsdp_state_dict_type not in [None, 'full', 'sharded']:
                raise NotImplementedError(
@@ -906,7 +911,7 @@ def get_optim_state_dict(self) -> Dict[str, Any]:
        Returns:
            Dict[str, Any]: The state dict for the optimizer.
        """
-       if version.parse(torch.__version__) > version.parse('2.2.9'):
+       if version.parse(torch.__version__) >= version.parse('2.3.0'):
            from torch.distributed.checkpoint.state_dict import StateDictOptions, get_optimizer_state_dict
            if self.fsdp_state_dict_type not in [None, 'full', 'sharded']:
                raise NotImplementedError(
@@ -1228,7 +1233,7 @@ def load_model_state(
        model_on_rank = state_dict['model'] is not None

        if model_on_rank:
-           if version.parse(torch.__version__) > version.parse('2.2.9'):
+           if version.parse(torch.__version__) >= version.parse('2.3.0'):
                from torch.distributed.checkpoint.state_dict import StateDictOptions, set_model_state_dict
                set_model_state_dict(
                    model=self.model,
@@ -1292,7 +1297,7 @@ def load_optim_state(self, state_dict: Dict[str, Any], strict: bool = True):
            strict (bool): Whether the keys (i.e., optimizer parameter names) in the optimizer
                state dict should perfectly match the keys in the optimizer instance.
        """
-       if version.parse(torch.__version__) > version.parse('2.2.9'):
+       if version.parse(torch.__version__) >= version.parse('2.3.0'):
            from torch.distributed.checkpoint.state_dict import StateDictOptions, set_optimizer_state_dict
            optimizer = self.optimizers[0]
            set_optimizer_state_dict(
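
For context, the torch.distributed.checkpoint.state_dict API that these branches now gate on for torch >= 2.3.0 can be exercised as below. This is a minimal sketch, assuming torch >= 2.3; the nn.Linear model is a stand-in for Composer's (possibly FSDP-wrapped) model, and the options shown are illustrative rather than the exact ones Composer passes.

import torch.nn as nn
from torch.distributed.checkpoint.state_dict import StateDictOptions, get_model_state_dict

# Stand-in model; Composer calls this on self.model, which may be FSDP-wrapped.
model = nn.Linear(8, 8)
# full_state_dict=True requests a complete, unsharded state dict ('full');
# full_state_dict=False would yield a sharded one ('sharded').
options = StateDictOptions(full_state_dict=True, cpu_offload=False)
state_dict = get_model_state_dict(model=model, options=options)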
10 changes: 8 additions & 2 deletions composer/trainer/_scaler.py
@@ -5,9 +5,15 @@
 from typing import Optional, Union

 import torch
-from torch.cuda.amp.grad_scaler import GradScaler, OptState, _refresh_per_optimizer_state
+from packaging import version
+from torch.cuda.amp.grad_scaler import GradScaler, OptState
 from torch.optim import Optimizer

+if version.parse(torch.__version__) >= version.parse('2.3.0'):
+    from torch.amp.grad_scaler import _refresh_per_optimizer_state  # type: ignore
+else:
+    from torch.cuda.amp.grad_scaler import _refresh_per_optimizer_state  # type: ignore
+
 from composer.utils import dist

 __all__ = ['ClosureGradScaler']
@@ -78,7 +84,7 @@ def _amp_closure(**kwargs):
        return optimizer.step(closure=_amp_closure)  # type: ignore

    # Mostly copied from original grad_scaler implementation
-   # See: https://pytorch.org/docs/stable/_modules/torch/cuda/amp/grad_scaler.html#GradScaler
+   # See: https://pytorch.org/docs/stable/_modules/torch/amp/grad_scaler.html#GradScaler
    def update(self, new_scale: Optional[Union[float, torch.FloatTensor]] = None):
        """Updates the scale factor.
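
Shown in isolation, the version gate this file now shares with composer/core/state.py is a simple conditional import. A minimal sketch, assuming only that the packaging library is installed: GradScaler (and the private helper _refresh_per_optimizer_state) moved from torch.cuda.amp to torch.amp in PyTorch 2.3, while the old path still resolves on earlier releases.

import torch
from packaging import version

if version.parse(torch.__version__) >= version.parse('2.3.0'):
    from torch.amp.grad_scaler import GradScaler
else:
    from torch.cuda.amp.grad_scaler import GradScaler

# The class behaves the same under either import path.
scaler = GradScaler(enabled=torch.cuda.is_available())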
49 changes: 3 additions & 46 deletions composer/trainer/mosaic_fsdp.py
@@ -10,25 +10,11 @@
 import torch
 from packaging import version
 from torch.distributed._shard.sharding_spec import ChunkShardingSpec
-from torch.distributed.fsdp import FullyShardedDataParallel


 def patch_pytorch():
     """Monkey patches pytorch functions based on pytorch version."""
-    if version.parse(torch.__version__) < version.parse('2.0.2'):
-        # Monkey patch for torch == 2.0.1
-
-        # Monkey patch __init__ where __init__ calls the custom _auto_wrap fn
-        from composer.trainer.mosaic_fsdp_utils import init_fn_t2p0p1
-
-        FullyShardedDataParallel.__init__ = init_fn_t2p0p1  # type: ignore
-
-        # Monkey patch sharding method
-        from composer.trainer.mosaic_fsdp_utils import build_metadata
-
-        ChunkShardingSpec.build_metadata = build_metadata
-
-    elif version.parse(torch.__version__) < version.parse('2.1.1'):
+    if version.parse(torch.__version__) < version.parse('2.1.1'):
         # Monkey patch for torch < 2.1.1 ie torch == 2.1.0

         # Monkey patch sharding method
Expand Down Expand Up @@ -61,8 +47,8 @@ def patch_pytorch():
from torch.distributed.fsdp import _runtime_utils
_runtime_utils._validate_and_get_hybrid_shard_state = lambda *args, **kwargs: None

elif version.parse(torch.__version__) < version.parse('2.2.9'):
# Monkey patch for torch < 2.3.0 ie torch == 2.2.1/2.2.2 currently
elif version.parse(torch.__version__) < version.parse('2.2.3'):
# Monkey patch for torch < 2.2.3 ie torch == 2.2.1/2.2.2 currently

# Fix memory leak for FSDP.optim_state_dict_to_load
# https://github.com/pytorch/pytorch/issues/116553
@@ -73,35 +59,6 @@ def patch_pytorch():

     elif version.parse(torch.__version__) < version.parse('2.3.1'):
         # Monkey patch for torch < 2.3.1 ie torch == 2.3.0
-        # Note: this is the same patch as 2.2.0, we are just making a new if branch
-        # for clarity and modularity of changes.
-
-        # Allow 2D HSDP
-        from torch.distributed.fsdp import _runtime_utils
-        _runtime_utils._validate_and_get_hybrid_shard_state = lambda *args, **kwargs: None
-
-        # Monkeypatch state_dict
-        from composer.trainer.mosaic_fsdp_utils import init_fn_t2p3p0
-        FullyShardedDataParallel.__init__ = init_fn_t2p3p0
-
-        # Monkeypatch state_dict
-        from torch.distributed.checkpoint import state_dict  # type: ignore
-
-        from composer.trainer.mosaic_fsdp_utils import _verify_options_t2p3p0
-        state_dict._verify_options = _verify_options_t2p3p0
-
-        # Monkeypatch sharding optim state
-        from torch.distributed.fsdp import _optim_utils
-
-        from composer.trainer.mosaic_fsdp_utils import _shard_orig_param_state
-        _optim_utils._shard_orig_param_state = _shard_orig_param_state
-
-        # Monkeypatch checkpointing full state dict
-        from torch.distributed.fsdp import _state_dict_utils
-
-        from composer.trainer.mosaic_fsdp_utils import _full_pre_state_dict_hook, _set_use_dtensor
-        _state_dict_utils._full_pre_state_dict_hook = _full_pre_state_dict_hook
-        _state_dict_utils._set_use_dtensor = _set_use_dtensor

         # Monkeypatch _flat_param.py to fix 2D with SHARD_GRAD_OP
         # Issue: https://github.com/pytorch/pytorch/issues/123272
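
After this commit, patch_pytorch reduces to a chain of version-bounded branches. A condensed sketch of that dispatch shape, with patch bodies elided and the intermediate branches for the 2.1.x/2.2.0 series omitted:

import torch
from packaging import version

def patch_pytorch_sketch():
    # Each branch applies only to releases that still need the workaround;
    # upper bounds tighten as fixes land upstream in PyTorch.
    if version.parse(torch.__version__) < version.parse('2.1.1'):
        pass  # patches for torch 2.1.0
    elif version.parse(torch.__version__) < version.parse('2.2.3'):
        pass  # patches for torch 2.2.1 / 2.2.2
    elif version.parse(torch.__version__) < version.parse('2.3.1'):
        pass  # remaining patches for torch 2.3.0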
[Diff truncated: the remaining 9 of the 16 changed files are not shown.]
