diff --git a/src/accelerate/utils/modeling.py b/src/accelerate/utils/modeling.py index 5f88e54e3c9..73ada66544f 100644 --- a/src/accelerate/utils/modeling.py +++ b/src/accelerate/utils/modeling.py @@ -1670,9 +1670,11 @@ def load_state_dict(checkpoint_file, device_map=None): if len(set(device_map.values())) == 1: device = list(device_map.values())[0] target_device = device - if is_xpu_available(): - if isinstance(device, int): + if isinstance(device, int): + if is_xpu_available(): target_device = f"xpu:{device}" + elif is_npu_available(): + target_device = f"npu:{device}" return safe_load_file(checkpoint_file, device=target_device) @@ -1704,9 +1706,11 @@ def load_state_dict(checkpoint_file, device_map=None): progress_bar = None for device in devices: target_device = device - if is_xpu_available(): - if isinstance(device, int): + if isinstance(device, int): + if is_xpu_available(): target_device = f"xpu:{device}" + elif is_npu_available(): + target_device = f"npu:{device}" with safe_open(checkpoint_file, framework="pt", device=target_device) as f: for key in device_weights[device]: