Sync states for npu fsdp
jq460494839 committed Oct 27, 2023
1 parent 5440387 commit 8915d4f
Showing 1 changed file with 9 additions and 2 deletions.
11 changes: 9 additions & 2 deletions src/accelerate/utils/dataclasses.py
@@ -32,7 +32,7 @@

 from .constants import FSDP_AUTO_WRAP_POLICY, FSDP_BACKWARD_PREFETCH, FSDP_STATE_DICT_TYPE
 from .environment import str_to_bool
-from .imports import is_xpu_available
+from .imports import is_xpu_available, is_npu_available, is_cuda_available
 from .versions import compare_versions


@@ -932,7 +932,14 @@ def __post_init__(self):
         self.activation_checkpointing = str_to_bool(os.environ.get(prefix + "ACTIVATION_CHECKPOINTING", "False")) == 1

         if self.sync_module_states:
-            device = torch.cuda.current_device() if not is_xpu_available() else torch.xpu.current_device()
+            if is_npu_available():
+                device = torch.npu.current_device()
+            elif is_cuda_available():
+                device = torch.cuda.current_device()
+            elif is_xpu_available():
+                device = torch.xpu.current_device()
+            else:
+                raise RuntimeError("No available device found among ['NPU', 'CUDA', 'XPU']!")
             self.param_init_fn = lambda x: x.to_empty(device=device, recurse=False)

     @staticmethod
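
For context, here is a minimal standalone sketch of the device-selection order this commit introduces (NPU first, then CUDA, then XPU). The helper name current_accelerator_device is hypothetical and exists only for illustration; the commit inlines this logic directly in the plugin's __post_init__. The sketch assumes the is_*_available helpers are importable as accelerate.utils.imports (matching the relative import in the first hunk) and that torch.npu is provided by Ascend's torch_npu extension.

import torch

from accelerate.utils.imports import is_cuda_available, is_npu_available, is_xpu_available


def current_accelerator_device() -> int:
    """Return the index of the active accelerator, preferring NPU over CUDA over XPU."""
    # Hypothetical helper, not part of the commit. torch.npu only exists when the
    # torch_npu extension is installed, so the is_npu_available() guard keeps this
    # safe on CUDA- and XPU-only machines.
    if is_npu_available():
        return torch.npu.current_device()
    elif is_cuda_available():
        return torch.cuda.current_device()
    elif is_xpu_available():
        return torch.xpu.current_device()
    raise RuntimeError("No available device found among ['NPU', 'CUDA', 'XPU']!")


# With sync_module_states=True, FSDP calls param_init_fn on each submodule so that
# meta-device parameters get real storage on the local accelerator before rank 0
# broadcasts its weights; to_empty(recurse=False) allocates storage without copying values.
param_init_fn = lambda module: module.to_empty(device=current_accelerator_device(), recurse=False)

The branch order mirrors the diff: the NPU check runs first so Ascend machines resolve their own device instead of falling through to the CUDA path.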
