
Accommodate FSDP full-precision param_dtype training with PyTorch < 2.0 (#18278)

(cherry picked from commit c081b48)
speediedan authored and Borda committed Aug 28, 2023
1 parent 754ae68 commit f97058d
Showing 7 changed files with 38 additions and 15 deletions.
3 changes: 3 additions & 0 deletions src/lightning/fabric/CHANGELOG.md
@@ -30,6 +30,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 ### Fixed

+- Fixed FSDP full-precision `param_dtype` training (`16-mixed`, `bf16-mixed` and `32-true` configurations) to avoid FSDP assertion errors with PyTorch < 2.0 ([#18278](https://github.com/Lightning-AI/lightning/pull/18278))
+
+
 - Fixed issue where DDP subprocesses that used Hydra would set hydra's working directory to current directory ([#18145](https://github.com/Lightning-AI/lightning/pull/18145))
 - Fixed an issue that would prevent the user to set the multiprocessing start method after importing lightning ([#18177](https://github.com/Lightning-AI/lightning/pull/18177))
 - Fixed an issue with `Fabric.all_reduce()` not performing an inplace operation for all backends consistently ([#18235](https://github.com/Lightning-AI/lightning/pull/18235))
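For orientation, the following is a minimal, hypothetical Fabric sketch (not part of this commit) of the kind of run these fixed configurations apply to; the model, data, and device counts are placeholders, and the precision string can be any of the affected values:

# Hypothetical sketch, not from this commit: running Fabric with the FSDP
# strategy under one of the precision settings covered by the fix
# ("16-mixed", "bf16-mixed", or "32-true"). Assumes a multi-GPU CUDA host.
import torch
from lightning.fabric import Fabric

fabric = Fabric(accelerator="cuda", devices=2, strategy="fsdp", precision="16-mixed")
fabric.launch()

model = fabric.setup_module(torch.nn.Linear(32, 2))  # FSDP wrapping happens here
optimizer = fabric.setup_optimizers(torch.optim.SGD(model.parameters(), lr=0.1))

batch = torch.randn(4, 32, device=fabric.device)
loss = model(batch).sum()
fabric.backward(loss)  # previously could trip FSDP assertions on torch < 2.0
optimizer.step()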
9 changes: 6 additions & 3 deletions src/lightning/fabric/plugins/precision/fsdp.py
@@ -16,7 +16,7 @@
 import torch

 from lightning.fabric.plugins.precision.amp import MixedPrecision
-from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_1_12
+from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_1_12, _TORCH_GREATER_EQUAL_2_0

 if TYPE_CHECKING:
     from torch.distributed.fsdp.fully_sharded_data_parallel import MixedPrecision as TorchMixedPrecision
@@ -48,11 +48,14 @@ def __init__(
     def mixed_precision_config(self) -> "TorchMixedPrecision":
         from torch.distributed.fsdp.fully_sharded_data_parallel import MixedPrecision as TorchMixedPrecision

+        # With PyTorch < 2.0, FSDP uses the noneness of `param_dtype` as a proxy for the `_uses_param_mixed_precision`
+        # property. In order to avoid FSDP assertion failures, we therefore avoid setting `param_dtype` to
+        # `torch.float32` here with PyTorch < 2.0.
         if self.precision == "16-mixed":
-            param_dtype = torch.float32
+            param_dtype = None if not _TORCH_GREATER_EQUAL_2_0 else torch.float32
             reduce_dtype = buffer_dtype = torch.float16
         elif self.precision == "bf16-mixed":
-            param_dtype = torch.float32
+            param_dtype = None if not _TORCH_GREATER_EQUAL_2_0 else torch.float32
             reduce_dtype = buffer_dtype = torch.bfloat16
         elif self.precision == "16-true":
             param_dtype = reduce_dtype = buffer_dtype = torch.float16
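Paraphrased as a standalone function, the selection logic after this change looks roughly like the sketch below; this is illustrative only, and the real plugin is a property on the precision class that also handles settings not visible in this excerpt (e.g. "32-true"):

# Rough standalone paraphrase of the diff above, for illustration only.
from typing import Optional

import torch
from torch.distributed.fsdp.fully_sharded_data_parallel import MixedPrecision


def _fsdp_mixed_precision(precision: str, torch_greater_equal_2_0: bool) -> MixedPrecision:
    param_dtype: Optional[torch.dtype]
    if precision == "16-mixed":
        # torch < 2.0 treats a non-None param_dtype as "param mixed precision",
        # so leave it as None there to keep parameters in full precision.
        param_dtype = torch.float32 if torch_greater_equal_2_0 else None
        reduce_dtype = buffer_dtype = torch.float16
    elif precision == "bf16-mixed":
        param_dtype = torch.float32 if torch_greater_equal_2_0 else None
        reduce_dtype = buffer_dtype = torch.bfloat16
    elif precision == "16-true":
        param_dtype = reduce_dtype = buffer_dtype = torch.float16
    else:
        raise ValueError(f"Unsupported precision: {precision!r}")
    return MixedPrecision(param_dtype=param_dtype, reduce_dtype=reduce_dtype, buffer_dtype=buffer_dtype)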
3 changes: 1 addition & 2 deletions src/lightning/fabric/strategies/deepspeed.py
@@ -33,8 +33,7 @@
 from lightning.fabric.strategies.registry import _StrategyRegistry
 from lightning.fabric.strategies.strategy import _Sharded
 from lightning.fabric.utilities.distributed import log
-from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_0
-from lightning.fabric.utilities.rank_zero import rank_zero_info, rank_zero_only, rank_zero_warn
+from lightning.fabric.utilities.rank_zero import rank_zero_info, rank_zero_warn
 from lightning.fabric.utilities.seed import reset_seed
 from lightning.fabric.utilities.types import _PATH
3 changes: 3 additions & 0 deletions src/lightning/pytorch/CHANGELOG.md
@@ -12,6 +12,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - On XLA, avoid setting the global rank before processes have been launched as this will initialize the PJRT computation client in the main process ([#16966](https://github.com/Lightning-AI/lightning/pull/16966))


+- Fixed FSDP full-precision `param_dtype` training (`16-mixed`, `bf16-mixed` and `32-true` configurations) to avoid FSDP assertion errors with PyTorch < 2.0 ([#18278](https://github.com/Lightning-AI/lightning/pull/18278))
+
+
 ## [2.0.7] - 2023-08-14

 ### Added
9 changes: 6 additions & 3 deletions src/lightning/pytorch/plugins/precision/fsdp.py
@@ -15,7 +15,7 @@

 import torch

-from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_1_12
+from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_1_12, _TORCH_GREATER_EQUAL_2_0
 from lightning.pytorch.plugins.precision.amp import MixedPrecisionPlugin
 from lightning.pytorch.utilities.exceptions import MisconfigurationException

@@ -57,11 +57,14 @@ def clip_grad_by_norm(self, *_: Any, **__: Any) -> None:
     def mixed_precision_config(self) -> Optional[MixedPrecision]:
         assert MixedPrecision is not None

+        # With PyTorch < 2.0, FSDP uses the noneness of `param_dtype` as a proxy for the `_uses_param_mixed_precision`
+        # property. In order to avoid FSDP assertion failures, we therefore avoid setting `param_dtype` to
+        # `torch.float32` here with PyTorch < 2.0.
         if self.precision == "16-mixed":
-            param_dtype = torch.float32
+            param_dtype = None if not _TORCH_GREATER_EQUAL_2_0 else torch.float32
             reduce_dtype = buffer_dtype = torch.float16
         elif self.precision == "bf16-mixed":
-            param_dtype = torch.float32
+            param_dtype = None if not _TORCH_GREATER_EQUAL_2_0 else torch.float32
             reduce_dtype = buffer_dtype = torch.bfloat16
         elif self.precision == "16-true":
             param_dtype = reduce_dtype = buffer_dtype = torch.float16
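On the Trainer side, the same configurations are exercised through the FSDP strategy. A hypothetical sketch (not from this commit, model definition omitted) of such a run:

# Hypothetical sketch, not from this commit: a Trainer run that routes through
# FSDPMixedPrecisionPlugin. `MyLightningModule` is a placeholder for any
# LightningModule; on torch < 2.0 these configurations previously raised FSDP
# assertion errors.
from lightning.pytorch import Trainer

trainer = Trainer(
    accelerator="cuda",
    devices=2,
    strategy="fsdp",
    precision="bf16-mixed",  # also affected: "16-mixed" and "32-true"
    max_epochs=1,
)
# trainer.fit(MyLightningModule())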
18 changes: 15 additions & 3 deletions tests/tests_fabric/plugins/precision/test_fsdp.py
@@ -30,9 +30,21 @@ def test_fsdp_precision_support(*_):
 @pytest.mark.parametrize(
     ("precision", "expected"),
     [
-        ("16-mixed", (torch.float32, torch.float16, torch.float16)),
-        ("bf16-mixed", (torch.float32, torch.bfloat16, torch.bfloat16)),
-        # TODO: add 16-true and bf16-true once supported
+        pytest.param(
+            "16-mixed", (torch.float32, torch.float16, torch.float16), marks=RunIf(min_torch="2.0"), id="16-mixed-ge2_0"
+        ),
+        pytest.param(
+            "16-mixed", (None, torch.float16, torch.float16), marks=RunIf(max_torch="2.0"), id="16-mixed-lt2_0"
+        ),
+        pytest.param(
+            "bf16-mixed",
+            (torch.float32, torch.bfloat16, torch.bfloat16),
+            marks=RunIf(min_torch="2.0"),
+            id="bf16-mixed-ge2_0",
+        ),
+        pytest.param(
+            "bf16-mixed", (None, torch.bfloat16, torch.bfloat16), marks=RunIf(max_torch="2.0"), id="bf16-mixed-lt2_0"
+        ),
     ],
 )
 def test_fsdp_precision_config(precision, expected):
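The `RunIf(min_torch=...)` / `RunIf(max_torch=...)` marks are Lightning's own test helpers; expressed with plain pytest, the same version gating might look like the sketch below (names and the test body are illustrative, not the repository's):

# Illustrative-only sketch of version-gated parametrization with plain pytest;
# the real tests use Lightning's RunIf helper instead of these skipif marks.
import pytest
import torch
from packaging.version import Version

_TORCH_GE_2_0 = Version(torch.__version__.split("+")[0]) >= Version("2.0.0")
needs_torch_2_0 = pytest.mark.skipif(not _TORCH_GE_2_0, reason="requires torch>=2.0")
needs_torch_lt_2_0 = pytest.mark.skipif(_TORCH_GE_2_0, reason="requires torch<2.0")


@pytest.mark.parametrize(
    ("precision", "expected_param_dtype"),
    [
        pytest.param("16-mixed", torch.float32, marks=needs_torch_2_0, id="16-mixed-ge2_0"),
        pytest.param("16-mixed", None, marks=needs_torch_lt_2_0, id="16-mixed-lt2_0"),
    ],
)
def test_param_dtype_selection(precision, expected_param_dtype):
    ...  # assert the plugin's mixed_precision_config.param_dtype matches expected_param_dtype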
8 changes: 4 additions & 4 deletions tests/tests_pytorch/strategies/test_fsdp.py
@@ -74,10 +74,10 @@ def _assert_layer_fsdp_instance(self) -> None:
         assert isinstance(self.trainer.strategy.precision_plugin, FSDPMixedPrecisionPlugin)

         if self.trainer.precision == "16-mixed":
-            param_dtype = torch.float32
+            param_dtype = None if not _TORCH_GREATER_EQUAL_2_0 else torch.float32
             reduce_dtype = buffer_dtype = torch.float16
         elif self.trainer.precision == "bf16-mixed":
-            param_dtype = torch.float32
+            param_dtype = None if not _TORCH_GREATER_EQUAL_2_0 else torch.float32
             reduce_dtype = buffer_dtype = torch.bfloat16
         elif self.trainer.precision == "16-true":
             param_dtype = reduce_dtype = buffer_dtype = torch.float16
@@ -122,10 +122,10 @@ def _assert_layer_fsdp_instance(self) -> None:
         assert isinstance(self.trainer.strategy.precision_plugin, FSDPMixedPrecisionPlugin)

         if self.trainer.precision == "16-mixed":
-            param_dtype = torch.float32
+            param_dtype = None if not _TORCH_GREATER_EQUAL_2_0 else torch.float32
             reduce_dtype = buffer_dtype = torch.float16
         elif self.trainer.precision == "bf16-mixed":
-            param_dtype = torch.float32
+            param_dtype = None if not _TORCH_GREATER_EQUAL_2_0 else torch.float32
             reduce_dtype = buffer_dtype = torch.bfloat16
         elif self.trainer.precision == "16-true":
             param_dtype = reduce_dtype = buffer_dtype = torch.float16
