diff --git a/src/accelerate/big_modeling.py b/src/accelerate/big_modeling.py
index 4e6ab7ef8ec..c29075ac996 100644
--- a/src/accelerate/big_modeling.py
+++ b/src/accelerate/big_modeling.py
@@ -451,6 +451,8 @@ def wrapper(*args, **kwargs):
         model.to = add_warning(model.to, model)
         if is_npu_available():
             model.npu = add_warning(model.npu, model)
+        elif is_xpu_available():
+            model.xpu = add_warning(model.xpu, model)
         else:
             model.cuda = add_warning(model.cuda, model)
 
@@ -459,6 +461,8 @@ def wrapper(*args, **kwargs):
         # `torch.Tensor.to(<int num>)` is not supported by `torch_npu` (see this [issue](https://github.com/Ascend/pytorch/issues/16)).
         if is_npu_available() and isinstance(device, int):
             device = f"npu:{device}"
+        elif is_xpu_available() and isinstance(device, int):
+            device = f"xpu:{device}"
         if device != "disk":
             model.to(device)
         else:
diff --git a/src/accelerate/utils/modeling.py b/src/accelerate/utils/modeling.py
index ebaa56250e6..86da0517677 100644
--- a/src/accelerate/utils/modeling.py
+++ b/src/accelerate/utils/modeling.py
@@ -367,6 +367,8 @@ def set_module_tensor_to_device(
         # `torch.Tensor.to(<int num>)` is not supported by `torch_npu` (see this [issue](https://github.com/Ascend/pytorch/issues/16)).
         if is_npu_available() and isinstance(device, int):
             device = f"npu:{device}"
+        if is_xpu_available() and isinstance(device, int):
+            device = f"xpu:{device}"
         if value is None:
             new_value = old_value.to(device)
             if dtype is not None and device in ["meta", torch.device("meta")]:
@@ -427,6 +429,8 @@ def set_module_tensor_to_device(
     # clean pre and post foward hook
     if is_npu_available():
         torch.npu.empty_cache()
+    elif is_xpu_available():
+        torch.xpu.empty_cache()
     else:
         torch.cuda.empty_cache()
 
@@ -1351,7 +1355,12 @@ def load_state_dict(checkpoint_file, device_map=None):
             else:
                 progress_bar = None
             for device in devices:
-                with safe_open(checkpoint_file, framework="pt", device=device) as f:
+                target_device = device
+
+                if is_xpu_available() and isinstance(device, int):
+                    target_device = f"xpu:{device}"
+
+                with safe_open(checkpoint_file, framework="pt", device=target_device) as f:
                     for key in device_weights[device]:
                         if progress_bar is not None:
                             progress_bar.set_postfix(dev=device, refresh=False)
diff --git a/src/accelerate/utils/operations.py b/src/accelerate/utils/operations.py
index 4d4dedff447..e42becf5b4a 100644
--- a/src/accelerate/utils/operations.py
+++ b/src/accelerate/utils/operations.py
@@ -26,7 +26,7 @@
 from ..state import PartialState
 from .constants import TORCH_DISTRIBUTED_OPERATION_TYPES
 from .dataclasses import DistributedType, TensorInformation
-from .imports import is_npu_available, is_torch_distributed_available, is_torch_version, is_tpu_available
+from .imports import is_npu_available, is_torch_distributed_available, is_torch_version, is_tpu_available, is_xpu_available
 
 
 if is_tpu_available(check_device=False):
@@ -171,6 +171,9 @@ def send_to_device(tensor, device, non_blocking=False, skip_keys=None):
             # `torch.Tensor.to("npu")` could not find context when called for the first time (see this [issue](https://gitee.com/ascend/pytorch/issues/I8KECW?from=project-issue)).
             elif device == torch.device("npu"):
                 device = "npu:0"
+        elif is_xpu_available():
+            if isinstance(device, int):
+                device = f"xpu:{device}"
         try:
             return tensor.to(device, non_blocking=non_blocking)
         except TypeError:  # .to() doesn't accept non_blocking as kwarg
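
The same normalization pattern recurs across these hunks: a bare integer device index is fine for CUDA, but the NPU and XPU backends need an explicit "npu:N" / "xpu:N" string before the value reaches `torch.Tensor.to()` or `safetensors.safe_open()`. A minimal standalone sketch of that pattern follows, using accelerate's real `is_npu_available` / `is_xpu_available` helpers but a hypothetical `normalize_device` wrapper that is not part of this patch:

    # Sketch only: mirrors the device-normalization logic this diff repeats inline.
    # `normalize_device` is a hypothetical helper, not a function added by the patch.
    from accelerate.utils import is_npu_available, is_xpu_available


    def normalize_device(device):
        """Turn a bare integer index into a backend-qualified device string."""
        if isinstance(device, int):
            # NPU is checked first, matching the `if npu / elif xpu / else cuda` order above.
            if is_npu_available():
                return f"npu:{device}"
            if is_xpu_available():
                return f"xpu:{device}"
        return device


    # e.g. tensor.to(normalize_device(0)) resolves to "npu:0" on Ascend, "xpu:0" on
    # Intel XPU, and stays the plain integer 0 on CUDA, where `.to(int)` is supported.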