diff --git a/src/transformers/pipelines/base.py b/src/transformers/pipelines/base.py
index fa1f2fcf5df..7225a6136e4 100644
--- a/src/transformers/pipelines/base.py
+++ b/src/transformers/pipelines/base.py
@@ -42,6 +42,7 @@
     is_torch_available,
     is_torch_cuda_available,
     is_torch_mlu_available,
+    is_torch_mps_available,
     is_torch_npu_available,
     is_torch_xpu_available,
     logging,
@@ -860,6 +861,8 @@ def __init__(
                 self.device = torch.device(f"npu:{device}")
             elif is_torch_xpu_available(check_device=True):
                 self.device = torch.device(f"xpu:{device}")
+            elif is_torch_mps_available():
+                self.device = torch.device(f"mps:{device}")
             else:
                 raise ValueError(f"{device} unrecognized or not available.")
         else:
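
For reference, a minimal usage sketch of what this change enables, assuming an Apple Silicon machine whose PyTorch build reports MPS as available and with no CUDA/MLU/NPU/XPU device present; the checkpoint name is only an example:

```python
# Sketch only: assumes torch.backends.mps.is_available() is True and that none
# of the other accelerator backends (CUDA/MLU/NPU/XPU) are detected, so an
# integer device index falls through to the new MPS branch.
from transformers import pipeline

classifier = pipeline(
    "text-classification",
    model="distilbert-base-uncased-finetuned-sst-2-english",  # example checkpoint
    device=0,  # resolved to torch.device("mps:0") by the new elif branch
)

print(classifier("Pipelines can now place the model on the MPS backend."))
print(classifier.device)  # expected: device(type='mps', index=0)
```

Before this change, the same call on such a machine would fall through to the final `else` and raise `ValueError("0 unrecognized or not available.")` instead of selecting MPS.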