diff --git a/torch_cluster/testing.py b/torch_cluster/testing.py index f60291b..68949fa 100644 --- a/torch_cluster/testing.py +++ b/torch_cluster/testing.py @@ -6,7 +6,10 @@ torch.half, torch.bfloat16, torch.float, torch.double, torch.int, torch.long ] -grad_dtypes = [torch.half, torch.float, torch.double] +if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available(): + grad_dtypes = [torch.float, torch.double] +else: + grad_dtypes = [torch.half, torch.float, torch.double] floating_dtypes = grad_dtypes + [torch.bfloat16] devices = [torch.device('cpu')]