Fix XPU inference (#2383)
Accelerate still complains that "Device xpu is not recognized, available devices are integers (for GPU/XPU), 'mps', 'cpu' and 'disk'", but you cannot simply pass 0 as the device either: without this fix, 0 is treated as a CUDA device, and torch then complains that it is not compiled with CUDA enabled.

You will need safetensors >= 0.4.2 if using safetensors files.
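
As a usage sketch (not part of the commit): loading a safetensors checkpoint onto an Intel GPU after this fix might look like the following. The model and checkpoint path are placeholders, and it assumes a PyTorch build with XPU support plus safetensors >= 0.4.2.

    import torch.nn as nn
    from accelerate import load_checkpoint_in_model

    # Placeholder model and checkpoint file.
    model = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 4))

    # With this fix, an integer device id in the device_map resolves to
    # "xpu:<id>" when an XPU is available, instead of being misread as CUDA.
    load_checkpoint_in_model(model, "checkpoint.safetensors", device_map={"": 0})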
notsyncing authored Feb 2, 2024
1 parent cd7ff5e commit 46f1391
Showing 3 changed files with 38 additions and 2 deletions.
5 changes: 5 additions & 0 deletions src/accelerate/big_modeling.py
@@ -38,6 +38,7 @@
     infer_auto_device_map,
     is_npu_available,
     is_torch_version,
+    is_xpu_available,
     load_checkpoint_in_model,
     offload_state_dict,
     parse_flag_from_env,
@@ -451,6 +452,8 @@ def wrapper(*args, **kwargs):
         model.to = add_warning(model.to, model)
         if is_npu_available():
             model.npu = add_warning(model.npu, model)
+        elif is_xpu_available():
+            model.xpu = add_warning(model.xpu, model)
         else:
             model.cuda = add_warning(model.cuda, model)

@@ -459,6 +462,8 @@ def wrapper(*args, **kwargs):
         # `torch.Tensor.to(<int num>)` is not supported by `torch_npu` (see this [issue](https://github.com/Ascend/pytorch/issues/16)).
         if is_npu_available() and isinstance(device, int):
             device = f"npu:{device}"
+        elif is_xpu_available() and isinstance(device, int):
+            device = f"xpu:{device}"
         if device != "disk":
             model.to(device)
         else:
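
Both hunks above apply the same rule: prefer the platform's own method and device string, and normalize an integer device id rather than letting torch default it to CUDA. A condensed sketch of that normalization rule (the helper name is ours; the real code inlines these branches):

    from accelerate.utils.imports import is_npu_available, is_xpu_available

    def normalize_device(device):
        # Integer ids default to CUDA in torch, so map them to the
        # platform-specific device string when NPU/XPU is the backend.
        if is_npu_available() and isinstance(device, int):
            return f"npu:{device}"
        if is_xpu_available() and isinstance(device, int):
            return f"xpu:{device}"
        return device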
22 changes: 21 additions & 1 deletion src/accelerate/utils/modeling.py
@@ -14,6 +14,7 @@

 import contextlib
 import gc
+import importlib
 import inspect
 import json
 import logging
@@ -24,6 +25,7 @@
 from collections import OrderedDict, defaultdict
 from typing import Dict, List, Optional, Tuple, Union
 
+import packaging
 import torch
 import torch.nn as nn

@@ -33,6 +35,7 @@
 from .imports import is_mps_available, is_npu_available, is_peft_available, is_xpu_available
 from .offload import load_offloaded_weight, offload_weight, save_offload_index
 from .tqdm import is_tqdm_available, tqdm
+from .versions import compare_versions


 if is_npu_available(check_device=False):
@@ -367,6 +370,8 @@ def set_module_tensor_to_device(
     # `torch.Tensor.to(<int num>)` is not supported by `torch_npu` (see this [issue](https://github.com/Ascend/pytorch/issues/16)).
     if is_npu_available() and isinstance(device, int):
         device = f"npu:{device}"
+    if is_xpu_available() and isinstance(device, int):
+        device = f"xpu:{device}"
     if value is None:
         new_value = old_value.to(device)
         if dtype is not None and device in ["meta", torch.device("meta")]:
@@ -427,6 +432,8 @@ def set_module_tensor_to_device(
         # clean pre and post foward hook
         if is_npu_available():
             torch.npu.empty_cache()
+        elif is_xpu_available():
+            torch.xpu.empty_cache()
         else:
             torch.cuda.empty_cache()
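
The cache clearing added here follows the same backend dispatch; a standalone sketch (the helper name is hypothetical, and torch.npu/torch.xpu exist only in the corresponding backend builds, which is what the availability checks guard):

    import torch
    from accelerate.utils.imports import is_npu_available, is_xpu_available

    def empty_device_cache():
        # Free cached blocks on whichever accelerator backend is active.
        if is_npu_available():
            torch.npu.empty_cache()
        elif is_xpu_available():
            torch.xpu.empty_cache()
        else:
            torch.cuda.empty_cache()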

@@ -1351,7 +1358,20 @@ def load_state_dict(checkpoint_file, device_map=None):
         else:
             progress_bar = None
         for device in devices:
-            with safe_open(checkpoint_file, framework="pt", device=device) as f:
+            target_device = device
+
+            if is_xpu_available():
+                current_safetensors_version = packaging.version.parse(importlib.metadata.version("safetensors"))
+
+                if compare_versions(current_safetensors_version, "<", "0.4.2"):
+                    raise ModuleNotFoundError(
+                        f"You need at least safetensors 0.4.2 for Intel GPU, while you have {current_safetensors_version}"
+                    )
+
+                if isinstance(device, int):
+                    target_device = f"xpu:{device}"
+
+            with safe_open(checkpoint_file, framework="pt", device=target_device) as f:
                 for key in device_weights[device]:
                     if progress_bar is not None:
                         progress_bar.set_postfix(dev=device, refresh=False)
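
The version guard in this hunk enforces the "safetensors >= 0.4.2" requirement from the commit message, since older safetensors releases do not accept an "xpu:<id>" device in safe_open. An equivalent standalone check, using packaging's comparison directly rather than accelerate's compare_versions helper:

    import importlib.metadata
    from packaging import version

    def require_safetensors_for_xpu(minimum="0.4.2"):
        # Hypothetical helper mirroring the check added in load_state_dict.
        current = version.parse(importlib.metadata.version("safetensors"))
        if current < version.parse(minimum):
            raise ModuleNotFoundError(
                f"You need at least safetensors {minimum} for Intel GPU, while you have {current}"
            )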
13 changes: 12 additions & 1 deletion src/accelerate/utils/operations.py
@@ -26,7 +26,13 @@
 from ..state import PartialState
 from .constants import TORCH_DISTRIBUTED_OPERATION_TYPES
 from .dataclasses import DistributedType, TensorInformation
-from .imports import is_npu_available, is_torch_distributed_available, is_torch_version, is_tpu_available
+from .imports import (
+    is_npu_available,
+    is_torch_distributed_available,
+    is_torch_version,
+    is_tpu_available,
+    is_xpu_available,
+)


 if is_tpu_available(check_device=False):
@@ -171,6 +177,11 @@ def send_to_device(tensor, device, non_blocking=False, skip_keys=None):
             # `torch.Tensor.to("npu")` could not find context when called for the first time (see this [issue](https://gitee.com/ascend/pytorch/issues/I8KECW?from=project-issue)).
             elif device == torch.device("npu"):
                 device = "npu:0"
+        elif is_xpu_available():
+            if isinstance(device, int):
+                device = f"xpu:{device}"
+            elif device == torch.device("xpu"):
+                device = "xpu:0"
         try:
             return tensor.to(device, non_blocking=non_blocking)
         except TypeError:  # .to() doesn't accept non_blocking as kwarg
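
With that branch in place, send_to_device accepts either an integer id or a bare torch.device("xpu") on an Intel GPU machine. A usage sketch (tensor contents are placeholders):

    import torch
    from accelerate.utils import send_to_device

    batch = {"input_ids": torch.ones(2, 8, dtype=torch.long)}
    batch = send_to_device(batch, 0)                    # resolved to "xpu:0" when an XPU is present
    batch = send_to_device(batch, torch.device("xpu"))  # resolved to "xpu:0"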
