Commit

Add allgather check for xpu (#2199)
* add allgather check for xpu

* style fix

* fix test

* fix test and review
abhilash1910 authored Dec 5, 2023
1 parent 47c1445 commit 47e6c36
Showing 1 changed file with 5 additions and 7 deletions.
12 changes: 5 additions & 7 deletions src/accelerate/utils/operations.py
@@ -298,13 +298,6 @@ def _gpu_gather_one(tensor):
         if not tensor.is_contiguous():
             tensor = tensor.contiguous()
 
-        # Check if `tensor` is not on CUDA
-        if state.device.type == "cuda" and tensor.device.type != "cuda":
-            raise RuntimeError(
-                "One or more of the tensors passed to `gather` were not on the GPU while the `Accelerator` is configured for CUDA. "
-                "Please move it to the GPU before calling `gather`."
-            )
-
         if state.backend is not None and state.backend != "gloo":
             # We use `empty` as `all_gather_into_tensor` slightly
             # differs from `all_gather` for better efficiency,
@@ -351,6 +344,11 @@ def wrapper(*args, **kwargs):
             tensor = kwargs["tensor"]
         else:
             tensor = args[0]
+        if PartialState().device.type != find_device(tensor).type:
+            raise DistributedOperationException(
+                f"One or more of the tensors passed to {operation} were not on the {tensor.device.type} while the `Accelerator` is configured for {PartialState().device.type}. "
+                f"Please move it to the {PartialState().device.type} before calling {operation}."
+            )
         shapes = get_shape(tensor)
         output = gather_object([shapes])
         if output[0] is not None:
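
In short, the commit removes the CUDA-only device check from `_gpu_gather_one` and relocates the guard into the debug-mode wrapper, where `PartialState().device.type` is compared against `find_device(tensor).type` for any backend. The same `DistributedOperationException` therefore now fires for XPU (or any other accelerator) device mismatches, not only CUDA. Below is a minimal sketch of how the check would surface, assuming a multi-process `accelerate launch` on CUDA or XPU with debug mode enabled (for example via `ACCELERATE_DEBUG_MODE="1"`); the script is illustrative and not part of the commit:

# repro_sketch.py -- hypothetical example, not from the repository.
# Assumes a multi-process launch on CUDA or XPU with accelerate's debug
# mode enabled; the wrapped check is skipped entirely outside debug mode.
import torch

from accelerate import Accelerator
from accelerate.utils.operations import DistributedOperationException

accelerator = Accelerator()  # device.type may be "cuda", "xpu", ...

# Deliberately leave the tensor on the CPU while the Accelerator is
# configured for an accelerator device.
tensor = torch.arange(4)

try:
    # gather() is wrapped, so device types are compared before the
    # collective runs; an XPU mismatch is now caught the same way a
    # CUDA one is.
    accelerator.gather(tensor)
except DistributedOperationException as err:
    # e.g. "One or more of the tensors passed to ... were not on the
    # cpu while the `Accelerator` is configured for xpu. ..."
    print(err)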
