diff --git a/src/accelerate/utils/dataclasses.py b/src/accelerate/utils/dataclasses.py
index 954850a2df6..8638a36d41e 100644
--- a/src/accelerate/utils/dataclasses.py
+++ b/src/accelerate/utils/dataclasses.py
@@ -32,7 +32,7 @@
 from .constants import FSDP_AUTO_WRAP_POLICY, FSDP_BACKWARD_PREFETCH, FSDP_STATE_DICT_TYPE
 from .environment import str_to_bool
-from .imports import is_xpu_available
+from .imports import is_xpu_available, is_npu_available, is_cuda_available
 from .versions import compare_versions
@@ -932,7 +932,14 @@ def __post_init__(self):
         self.activation_checkpointing = str_to_bool(os.environ.get(prefix + "ACTIVATION_CHECKPOINTING", "False")) == 1
 
         if self.sync_module_states:
-            device = torch.cuda.current_device() if not is_xpu_available() else torch.xpu.current_device()
+            if is_npu_available():
+                device = torch.npu.current_device()
+            elif is_cuda_available():
+                device = torch.cuda.current_device()
+            elif is_xpu_available():
+                device = torch.xpu.current_device()
+            else:
+                raise RuntimeError("There are currently no available devices found, must be one of 'XPU', 'CUDA', or 'NPU'.")
             self.param_init_fn = lambda x: x.to_empty(device=device, recurse=False)
 
     @staticmethod
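
For reviewers who want to try the fallback locally, here is a minimal standalone sketch of the device-selection logic the second hunk introduces. It assumes the three availability checks are importable from `accelerate.utils` (they are re-exported from `accelerate.utils.imports`) and that `torch.npu` only exists when the `torch_npu` extension is installed; `current_accelerator_device` is a hypothetical helper name for illustration, not part of the PR.

```python
# Sketch of the NPU -> CUDA -> XPU fallback from the diff above.
import torch
from accelerate.utils import is_cuda_available, is_npu_available, is_xpu_available


def current_accelerator_device() -> int:
    """Return the index of the current accelerator device, preferring NPU."""
    if is_npu_available():
        # Ascend NPU; `torch.npu` is provided by the torch_npu plugin.
        return torch.npu.current_device()
    elif is_cuda_available():
        return torch.cuda.current_device()
    elif is_xpu_available():
        return torch.xpu.current_device()
    raise RuntimeError("There are currently no available devices found, must be one of 'XPU', 'CUDA', or 'NPU'.")


# As in the diff, the resolved device feeds FSDP's `param_init_fn`, so modules
# created on the meta device are materialized directly on the accelerator:
param_init_fn = lambda module: module.to_empty(device=current_accelerator_device(), recurse=False)
```

Checking NPU before CUDA matters because `torch.cuda` APIs can be shimmed on some NPU stacks; probing the most specific backend first keeps `sync_module_states` from initializing parameters on the wrong device.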