From 73324420b6cd9194b5ba1028a85ca4d50e724513 Mon Sep 17 00:00:00 2001 From: DESKTOP-42S1K65 Date: Wed, 22 Nov 2023 10:21:15 -0500 Subject: [PATCH] Better err --- src/accelerate/utils/operations.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/src/accelerate/utils/operations.py b/src/accelerate/utils/operations.py index 5d1df1d995c..15d7d52e54e 100644 --- a/src/accelerate/utils/operations.py +++ b/src/accelerate/utils/operations.py @@ -298,6 +298,13 @@ def _gpu_gather_one(tensor): if not tensor.is_contiguous(): tensor = tensor.contiguous() + # Check if `tensor` is not on CUDA + if state.device.type == "cuda" and tensor.device.type != "cuda": + raise RuntimeError( + "One or more of the tensors passed to `gather` was not on the GPU while the `Accelerator` is configured for CUDA. " + "Please move it to the GPU before calling `gather`." + ) + if state.backend is not None and state.backend != "gloo": # We use `empty` as `all_gather_into_tensor` slightly # differs from `all_gather` for better efficiency,