Skip to content

Commit

Permalink
Fix XPU inference
Browse files Browse the repository at this point in the history
It will complain that "Device 0 is not recognized, available devices are integers(for GPU/XPU),
'mps', 'cpu' and 'disk'", but you cannot simply pass 0 as the device: 0 would be treated as a CUDA device,
and it would then complain that torch is not compiled with CUDA enabled.

You will need safetensors >= 0.4.2 if using safetensors files.
  • Loading branch information
notsyncing committed Jan 28, 2024
1 parent 7aafa25 commit 56d0a23
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 2 deletions.
11 changes: 10 additions & 1 deletion src/accelerate/utils/modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -367,6 +367,8 @@ def set_module_tensor_to_device(
# `torch.Tensor.to(<int num>)` is not supported by `torch_npu` (see this [issue](https://github.com/Ascend/pytorch/issues/16)).
if is_npu_available() and isinstance(device, int):
device = f"npu:{device}"
if is_xpu_available() and isinstance(device, int):
device = f"xpu:{device}"
if value is None:
new_value = old_value.to(device)
if dtype is not None and device in ["meta", torch.device("meta")]:
Expand Down Expand Up @@ -427,6 +429,8 @@ def set_module_tensor_to_device(
# clean pre and post foward hook
if is_npu_available():
torch.npu.empty_cache()
elif is_xpu_available():
torch.xpu.empty_cache()
else:
torch.cuda.empty_cache()

Expand Down Expand Up @@ -1351,7 +1355,12 @@ def load_state_dict(checkpoint_file, device_map=None):
else:
progress_bar = None
for device in devices:
with safe_open(checkpoint_file, framework="pt", device=device) as f:
target_device = device

if is_xpu_available() and isinstance(device, int):
target_device = f"xpu:{device}"

with safe_open(checkpoint_file, framework="pt", device=target_device) as f:
for key in device_weights[device]:
if progress_bar is not None:
progress_bar.set_postfix(dev=device, refresh=False)
Expand Down
4 changes: 3 additions & 1 deletion src/accelerate/utils/operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from ..state import PartialState
from .constants import TORCH_DISTRIBUTED_OPERATION_TYPES
from .dataclasses import DistributedType, TensorInformation
from .imports import is_npu_available, is_torch_distributed_available, is_torch_version, is_tpu_available
from .imports import is_npu_available, is_torch_distributed_available, is_torch_version, is_tpu_available, is_xpu_available


if is_tpu_available(check_device=False):
Expand Down Expand Up @@ -171,6 +171,8 @@ def send_to_device(tensor, device, non_blocking=False, skip_keys=None):
# `torch.Tensor.to("npu")` could not find context when called for the first time (see this [issue](https://gitee.com/ascend/pytorch/issues/I8KECW?from=project-issue)).
elif device == torch.device("npu"):
device = "npu:0"
elif is_xpu_available() and isinstance(device, int):
device = f"xpu:{device}"
try:
return tensor.to(device, non_blocking=non_blocking)
except TypeError: # .to() doesn't accept non_blocking as kwarg
Expand Down

0 comments on commit 56d0a23

Please sign in to comment.