Skip to content

Commit

Permalink
Better error when device mismatches when calling gather() on CUDA (#2180
Browse files Browse the repository at this point in the history
)

* Better err

* Update src/accelerate/utils/operations.py

Co-authored-by: Benjamin Bossan <[email protected]>

---------

Co-authored-by: Benjamin Bossan <[email protected]>
  • Loading branch information
muellerzr and BenjaminBossan authored Nov 29, 2023
1 parent 0ba3e9b commit 1516379
Showing 1 changed file with 7 additions and 0 deletions.
7 changes: 7 additions & 0 deletions src/accelerate/utils/operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,13 @@ def _gpu_gather_one(tensor):
if not tensor.is_contiguous():
tensor = tensor.contiguous()

# Check if `tensor` is not on CUDA
if state.device.type == "cuda" and tensor.device.type != "cuda":
raise RuntimeError(
"One or more of the tensors passed to `gather` were not on the GPU while the `Accelerator` is configured for CUDA. "
"Please move it to the GPU before calling `gather`."
)

if state.backend is not None and state.backend != "gloo":
# We use `empty` as `all_gather_into_tensor` slightly
# differs from `all_gather` for better efficiency,
Expand Down

0 comments on commit 1516379

Please sign in to comment.