Add torch 2.3 support #3209

Merged
merged 16 commits on Apr 24, 2024
23 changes: 4 additions & 19 deletions .github/workflows/daily.yaml
@@ -18,11 +18,6 @@ jobs:
strategy:
matrix:
include:
- name: cpu-3.10-2.0
container: mosaicml/pytorch:2.0.1_cpu-python3.10-ubuntu20.04
markers: not daily and (remote or not remote) and not gpu and not doctest
pytest_command: coverage run -m pytest
composer_package_name: mosaicml
- name: cpu-3.10-2.1
container: mosaicml/pytorch:2.1.2_cpu-python3.10-ubuntu20.04
markers: not daily and (remote or not remote) and not gpu and not doctest
@@ -34,7 +29,7 @@ jobs:
pytest_command: coverage run -m pytest
composer_package_name: composer
- name: cpu-3.11-2.2
container: mosaicml/pytorch:2.2.0_cpu-python3.11-ubuntu20.04
container: mosaicml/pytorch:2.2.1_cpu-python3.11-ubuntu20.04
markers: not daily and (remote or not remote) and not gpu and not doctest
pytest_command: coverage run -m pytest
composer_package_name: mosaicml
@@ -43,11 +38,6 @@
markers: not daily and (remote or not remote) and not gpu and doctest
pytest_command: coverage run -m pytest tests/test_docs.py
composer_package_name: mosaicml
- name: daily-cpu-3.10-2.0
container: mosaicml/pytorch:2.0.1_cpu-python3.10-ubuntu20.04
markers: daily and (remote or not remote) and not gpu and not doctest
pytest_command: coverage run -m pytest
composer_package_name: mosaicml
- name: daily-cpu-3.10-2.1
container: mosaicml/pytorch:2.1.2_cpu-python3.10-ubuntu20.04
markers: daily and (remote or not remote) and not gpu and not doctest
@@ -59,12 +49,12 @@
pytest_command: coverage run -m pytest
composer_package_name: composer
- name: daily-cpu-3.11-2.2
container: mosaicml/pytorch:2.2.0_cpu-python3.11-ubuntu20.04
container: mosaicml/pytorch:2.2.1_cpu-python3.11-ubuntu20.04
markers: daily and (remote or not remote) and not gpu and not doctest
pytest_command: coverage run -m pytest
composer_package_name: mosaicml
- name: daily-cpu-doctest
container: mosaicml/pytorch:2.1.2_cpu-python3.10-ubuntu20.04
container: mosaicml/pytorch:2.2.1_cpu-python3.11-ubuntu20.04
markers: daily and (remote or not remote) and not gpu and doctest
pytest_command: coverage run -m pytest tests/test_docs.py
composer_package_name: mosaicml
@@ -106,17 +96,12 @@ jobs:
# Unlike CPU tests, we run daily tests together with GPU tests to minimize launch time
# on MCLOUD and not eat up all GPUs at once
include:
- name: "gpu-3.10-2.0"
container: mosaicml/pytorch_vision:2.0.1_cu117-python3.10-ubuntu20.04
markers: "(daily or not daily) and (remote or not remote) and gpu and (doctest or not doctest)"
pytest_command: "coverage run -m pytest"
composer_package_name: "mosaicml"
- name: "gpu-3.10-2.1"
container: mosaicml/pytorch:2.1.2_cu121-python3.10-ubuntu20.04
markers: "(daily or not daily) and (remote or not remote) and gpu and (doctest or not doctest)"
pytest_command: "coverage run -m pytest"
composer_package_name: "mosaicml"
- name: "gpu-3.10-2.2"
- name: "gpu-3.11-2.2"
container: mosaicml/pytorch:2.2.0_cu121-python3.11-ubuntu20.04
markers: "(daily or not daily) and (remote or not remote) and gpu and (doctest or not doctest)"
pytest_command: "coverage run -m pytest"
10 changes: 5 additions & 5 deletions .github/workflows/pr-cpu.yaml
@@ -13,16 +13,16 @@ jobs:
strategy:
matrix:
include:
- name: cpu-3.10-2.0
container: mosaicml/pytorch:2.0.1_cpu-python3.10-ubuntu20.04
markers: not daily and not remote and not gpu and not doctest
pytest_command: coverage run -m pytest
- name: cpu-3.10-2.1
container: mosaicml/pytorch:2.1.2_cpu-python3.10-ubuntu20.04
markers: not daily and not remote and not gpu and not doctest
pytest_command: coverage run -m pytest
- name: cpu-3.11-2.2
container: mosaicml/pytorch:2.2.1_cpu-python3.11-ubuntu20.04
markers: not daily and not remote and not gpu and not doctest
pytest_command: coverage run -m pytest
- name: cpu-doctest
container: mosaicml/pytorch:2.1.2_cpu-python3.10-ubuntu20.04
container: mosaicml/pytorch:2.2.1_cpu-python3.11-ubuntu20.04
markers: not daily and not remote and not gpu and doctest
pytest_command: coverage run -m pytest tests/test_docs.py
name: ${{ matrix.name }}
4 changes: 2 additions & 2 deletions .github/workflows/pr-gpu.yaml
@@ -13,8 +13,8 @@ jobs:
strategy:
matrix:
include:
- name: gpu-3.10-2.1
container: mosaicml/pytorch:2.1.2_cu121-python3.10-ubuntu20.04
- name: gpu-3.11-2.2
container: mosaicml/pytorch:2.2.1_cu121-python3.10-ubuntu20.04
markers: not daily and not remote and gpu and (doctest or not doctest)
pytest_command: coverage run -m pytest
composer_package_name: mosaicml
2 changes: 1 addition & 1 deletion composer/algorithms/utils/augmentation_common.py
@@ -43,7 +43,7 @@ def image_as_type(image: _InputImgT, typ: Type[_OutputImgT]) -> _OutputImgT:
raise TypeError(f'Only typ={{torch.Tensor, Image}} is supported; got {typ}')

if typ is torch.Tensor:
return cast(_OutputImgT, torchvision.transforms.functional.to_tensor(image)) # PIL -> Tensor
return cast(_OutputImgT, torchvision.transforms.functional.to_tensor(image)) # type: ignore PIL -> Tensor
return cast(_OutputImgT, torchvision.transforms.functional.to_pil_image(image)) # Tensor -> PIL


19 changes: 12 additions & 7 deletions composer/core/state.py
@@ -29,6 +29,11 @@
from torch.utils.data import DataLoader, Dataset
from torchmetrics import Metric

if version.parse(torch.__version__) >= version.parse('2.3.0'):
from torch.amp.grad_scaler import GradScaler # type: ignore
else:
from torch.cuda.amp.grad_scaler import GradScaler # type: ignore

from composer.core.data_spec import DataSpec
from composer.core.event import Event
from composer.core.precision import Precision
@@ -242,7 +247,7 @@ class State(Serializable):
train the model. Multiple optimizers are not currently supported.
schedulers (LRScheduler | Sequence[LRScheduler], optional):
The learning rate scheduler (can also be a list or tuple of schedulers).
scaler (torch.cuda.amp.GradScaler, optional): The gradient scaler in use for mixed precision training.
scaler (torch.amp.GradScaler, optional): The gradient scaler in use for mixed precision training.
save_metrics (bool, optional): Whether to save metrics in state_dict.
algorithms (Algorithm | Sequence[Algorithm], optional): The algorithms used for training.
callbacks (Callback | Sequence[Callback], optional): The callbacks used for training.
@@ -326,7 +331,7 @@ class State(Serializable):
profiler (Profiler): The profiler (if profiling is enabled), or ``None`` if not profiling.
rank_zero_seed (int): The seed of the rank zero process.
run_name (str): The name for this training run.
scaler (torch.cuda.amp.GradScaler): The gradient scaler if using mixed-precision training, or
scaler (torch.amp.GradScaler): The gradient scaler if using mixed-precision training, or
``None`` if not using mixed-precision training.
serialized_attributes (List[str]): The names of the attribute which are serialized in a checkpoint.

@@ -404,7 +409,7 @@ def __init__(
optimizers: Optional[Union[Optimizer, Sequence[Optimizer]]] = None,

# scaler
scaler: Optional[torch.cuda.amp.grad_scaler.GradScaler] = None,
scaler: Optional[GradScaler] = None,

# state_dict
save_metrics: bool = False,
@@ -868,7 +873,7 @@ def get_model_state_dict(self) -> Dict[str, Any]:
Returns:
Dict[str, Any]: The state dict for the model.
"""
if version.parse(torch.__version__) > version.parse('2.2.9'):
if version.parse(torch.__version__) >= version.parse('2.3.0'):
from torch.distributed.checkpoint.state_dict import StateDictOptions, get_model_state_dict
if self.fsdp_state_dict_type not in [None, 'full', 'sharded']:
raise NotImplementedError(
@@ -906,7 +911,7 @@ def get_optim_state_dict(self) -> Dict[str, Any]:
Returns:
Dict[str, Any]: The state dict for the optimizer.
"""
if version.parse(torch.__version__) > version.parse('2.2.9'):
if version.parse(torch.__version__) >= version.parse('2.3.0'):
from torch.distributed.checkpoint.state_dict import StateDictOptions, get_optimizer_state_dict
if self.fsdp_state_dict_type not in [None, 'full', 'sharded']:
raise NotImplementedError(
@@ -1228,7 +1233,7 @@ def load_model_state(
model_on_rank = state_dict['model'] is not None

if model_on_rank:
if version.parse(torch.__version__) > version.parse('2.2.9'):
if version.parse(torch.__version__) >= version.parse('2.3.0'):
from torch.distributed.checkpoint.state_dict import StateDictOptions, set_model_state_dict
set_model_state_dict(
model=self.model,
@@ -1292,7 +1297,7 @@ def load_optim_state(self, state_dict: Dict[str, Any], strict: bool = True):
strict (bool): Whether the keys (i.e., optimizer parameter names) in the optimizer
state dict should perfectly match the keys in the optimizer instance.
"""
if version.parse(torch.__version__) > version.parse('2.2.9'):
if version.parse(torch.__version__) >= version.parse('2.3.0'):
from torch.distributed.checkpoint.state_dict import StateDictOptions, set_optimizer_state_dict
optimizer = self.optimizers[0]
set_optimizer_state_dict(
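For reference, below is a minimal sketch of the torch >= 2.3 `torch.distributed.checkpoint.state_dict` API that the hunks above switch to. It is not part of this diff: the toy module and option values are illustrative assumptions, and real use assumes an initialized process group with an FSDP-wrapped model.

```python
# Sketch of the torch >= 2.3 state-dict API referenced above (illustrative,
# not composer's call sites; assumes torch >= 2.3).
import torch.nn as nn
from torch.distributed.checkpoint.state_dict import (
    StateDictOptions,
    get_model_state_dict,
    set_model_state_dict,
)

model = nn.Linear(4, 4)  # stand-in for the (possibly FSDP-wrapped) model
options = StateDictOptions(full_state_dict=True, cpu_offload=True)

# Gather a full, CPU-offloaded state dict regardless of how the model is sharded.
model_sd = get_model_state_dict(model=model, options=options)

# Load it back; with FSDP this re-shards the values onto each rank.
set_model_state_dict(model=model, model_state_dict=model_sd, options=options)
```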
10 changes: 8 additions & 2 deletions composer/trainer/_scaler.py
@@ -5,9 +5,15 @@
from typing import Optional, Union

import torch
from torch.cuda.amp.grad_scaler import GradScaler, OptState, _refresh_per_optimizer_state
from packaging import version
from torch.cuda.amp.grad_scaler import GradScaler, OptState
from torch.optim import Optimizer

if version.parse(torch.__version__) >= version.parse('2.3.0'):
from torch.amp.grad_scaler import _refresh_per_optimizer_state # type: ignore
else:
from torch.cuda.amp.grad_scaler import _refresh_per_optimizer_state # type: ignore

from composer.utils import dist

__all__ = ['ClosureGradScaler']
@@ -78,7 +84,7 @@ def _amp_closure(**kwargs):
return optimizer.step(closure=_amp_closure) # type: ignore

# Mostly copied from original grad_scaler implementation
# See: https://pytorch.org/docs/stable/_modules/torch/cuda/amp/grad_scaler.html#GradScaler
# See: https://pytorch.org/docs/stable/_modules/torch/amp/grad_scaler.html#GradScaler
def update(self, new_scale: Optional[Union[float, torch.FloatTensor]] = None):
"""Updates the scale factor.

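As a usage note, the gated import above is needed because `_refresh_per_optimizer_state` moved from `torch.cuda.amp.grad_scaler` to `torch.amp.grad_scaler` in torch 2.3, and a `GradScaler` subclass that overrides `update()` has to reset the per-optimizer bookkeeping the same way the stock scaler does. The sketch below is a hypothetical subclass under that assumption, not composer's `ClosureGradScaler`.

```python
# Minimal sketch; SketchScaler is a hypothetical subclass, not composer's code.
from collections import defaultdict
from typing import Optional, Union

import torch
from packaging import version
from torch.cuda.amp.grad_scaler import GradScaler

if version.parse(torch.__version__) >= version.parse('2.3.0'):
    from torch.amp.grad_scaler import _refresh_per_optimizer_state  # type: ignore
else:
    from torch.cuda.amp.grad_scaler import _refresh_per_optimizer_state  # type: ignore


class SketchScaler(GradScaler):

    def update(self, new_scale: Optional[Union[float, torch.FloatTensor]] = None):
        super().update(new_scale)
        # Reset per-optimizer state to OptState.READY for the next step,
        # mirroring the tail of the upstream GradScaler.update().
        self._per_optimizer_states = defaultdict(_refresh_per_optimizer_state)
```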
49 changes: 3 additions & 46 deletions composer/trainer/mosaic_fsdp.py
@@ -10,25 +10,11 @@
import torch
from packaging import version
from torch.distributed._shard.sharding_spec import ChunkShardingSpec
from torch.distributed.fsdp import FullyShardedDataParallel


def patch_pytorch():
"""Monkey patches pytorch functions based on pytorch version."""
if version.parse(torch.__version__) < version.parse('2.0.2'):
# Monkey patch for torch == 2.0.1

# Monkey patch __init__ where __init__ calls the custom _auto_wrap fn
from composer.trainer.mosaic_fsdp_utils import init_fn_t2p0p1

FullyShardedDataParallel.__init__ = init_fn_t2p0p1 # type: ignore

# Monkey patch sharding method
from composer.trainer.mosaic_fsdp_utils import build_metadata

ChunkShardingSpec.build_metadata = build_metadata

elif version.parse(torch.__version__) < version.parse('2.1.1'):
if version.parse(torch.__version__) < version.parse('2.1.1'):
# Monkey patch for torch < 2.1.1 ie torch == 2.1.0

# Monkey patch sharding method
@@ -61,8 +47,8 @@ def patch_pytorch():
from torch.distributed.fsdp import _runtime_utils
_runtime_utils._validate_and_get_hybrid_shard_state = lambda *args, **kwargs: None

elif version.parse(torch.__version__) < version.parse('2.2.9'):
# Monkey patch for torch < 2.3.0 ie torch == 2.2.1/2.2.2 currently
elif version.parse(torch.__version__) < version.parse('2.2.3'):
# Monkey patch for torch < 2.2.3 ie torch == 2.2.1/2.2.2 currently

# Fix memory leak for FSDP.optim_state_dict_to_load
# https://github.com/pytorch/pytorch/issues/116553
@@ -73,35 +59,6 @@

elif version.parse(torch.__version__) < version.parse('2.3.1'):
# Monkey patch for torch < 2.3.1 ie torch == 2.3.0
# Note: this is the same patch as 2.2.0, we are just making a new if branch
# for clarity and modularity of changes.

# Allow 2D HSDP
from torch.distributed.fsdp import _runtime_utils
_runtime_utils._validate_and_get_hybrid_shard_state = lambda *args, **kwargs: None

# Monkeypatch state_dict
from composer.trainer.mosaic_fsdp_utils import init_fn_t2p3p0
FullyShardedDataParallel.__init__ = init_fn_t2p3p0

# Monkeypatch state_dict
from torch.distributed.checkpoint import state_dict # type: ignore

from composer.trainer.mosaic_fsdp_utils import _verify_options_t2p3p0
state_dict._verify_options = _verify_options_t2p3p0

# Monkeypatch sharding optim state
from torch.distributed.fsdp import _optim_utils

from composer.trainer.mosaic_fsdp_utils import _shard_orig_param_state
_optim_utils._shard_orig_param_state = _shard_orig_param_state

# Monkeypatch checkpointing full state dict
from torch.distributed.fsdp import _state_dict_utils

from composer.trainer.mosaic_fsdp_utils import _full_pre_state_dict_hook, _set_use_dtensor
_state_dict_utils._full_pre_state_dict_hook = _full_pre_state_dict_hook
_state_dict_utils._set_use_dtensor = _set_use_dtensor

# Monkeypatch _flat_param.py to fix 2D with SHARD_GRAD_OP
# Issue: https://github.com/pytorch/pytorch/issues/123272