Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
ji-huazhong committed Dec 4, 2024
1 parent 1b2ad57 commit 0a923ff
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions src/accelerate/utils/modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -1706,9 +1706,11 @@ def load_state_dict(checkpoint_file, device_map=None):
progress_bar = None
for device in devices:
target_device = device
if is_xpu_available():
if isinstance(device, int):
if isinstance(device, int):
if is_xpu_available():
target_device = f"xpu:{device}"
elif is_npu_available():
target_device = f"npu:{device}"

with safe_open(checkpoint_file, framework="pt", device=target_device) as f:
for key in device_weights[device]:
Expand Down

0 comments on commit 0a923ff

Please sign in to comment.