Make output end up on all GPUs at the end (#2423)
* Make output end up on the cpu at the end

* Rework a bit

* Remove the CPU part

* Update to include a new util to copy tensors across devices

* Update test

* Update doc

* Update docstring

* Make False by default and change if community feedback says yes

* Apply suggestions from code review

Co-authored-by: Marc Sun <[email protected]>

* Update default to False in doc and make a tip

* Update typing

* Defaults

* Explain

---------

Co-authored-by: Marc Sun <[email protected]>
muellerzr and SunMarc authored Feb 9, 2024
1 parent 86228e3 commit 9467a62
Showing 5 changed files with 102 additions and 15 deletions.
11 changes: 9 additions & 2 deletions docs/source/usage_guides/distributed_inference.md
@@ -216,13 +216,20 @@ with torch.no_grad():
    output = model(*args)
```

-When finished, all the data will be on the last GPU, which you can use the [`PartialState`] to find and extract:
+When finished, all the data will be on the last process only:

```{python}
from accelerate import PartialState
if PartialState().is_last_process:
    print(output)
```

+<Tip>
+
+If you pass in `gather_output=True` to [`inference.prepare_pippy`], the output will be sent
+across to all the GPUs afterwards without needing the `is_last_process` check. This is
+`False` by default as it incurs a communication call.
+
+</Tip>

And that's it! To explore more, please check out the examples in [this repository](https://github.com/muellerzr/pippy-device-map-playground/) and our documentation as we work to improve this integration.
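As a quick illustration of the workflow the updated doc describes, here is a minimal sketch; the toy model and input shapes are placeholders, and it assumes the script is run under `accelerate launch` on multiple GPUs with pippy installed:

```python
import torch
from accelerate.inference import prepare_pippy

# Placeholder model and example input, purely for illustration.
model = torch.nn.Sequential(torch.nn.Linear(128, 256), torch.nn.ReLU(), torch.nn.Linear(256, 10))
example_input = torch.randn(8, 128)

# With gather_output=True every process receives a copy of the final output,
# so the `is_last_process` check from the snippet above is not needed.
model = prepare_pippy(model, example_args=(example_input,), gather_output=True)

with torch.no_grad():
    output = model(example_input)
print(output)  # defined on every rank, at the cost of one extra communication call
```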
33 changes: 20 additions & 13 deletions src/accelerate/inference.py
@@ -1,11 +1,12 @@
import math
from types import MethodType
-from typing import Any, Dict, Optional
+from typing import Any, Dict, List, Optional, Tuple, Union

from .state import PartialState
from .utils import (
    calculate_maximum_sizes,
    convert_bytes,
+    copy_tensor_to_devices,
    ignorant_find_batch_size,
    infer_auto_device_map,
    is_pippy_available,
@@ -82,7 +83,7 @@ def build_pipeline(model, split_points, args, kwargs, num_chunks):
    return stage


-def pippy_forward(forward, num_chunks, *args, **kwargs):
+def pippy_forward(forward, num_chunks, gather_output, *args, **kwargs):
    state = PartialState()
    output = None

@@ -101,37 +102,43 @@ def pippy_forward(forward, num_chunks, *args, **kwargs):
        output = forward()
    else:
        forward()
+    if gather_output:
+        # Each node will get a copy of the full output which is only on the last GPU
+        output = copy_tensor_to_devices(output)
    return output


def prepare_pippy(
    model,
-    split_points="auto",
-    no_split_module_classes=None,
-    example_args=(),
+    split_points: Optional[Union[str, List[str]]] = "auto",
+    no_split_module_classes: Optional[List[str]] = None,
+    example_args: Optional[Tuple[Any]] = (),
    example_kwargs: Optional[Dict[str, Any]] = None,
-    num_chunks=None,
+    num_chunks: Optional[int] = None,
+    gather_output: Optional[bool] = False,
):
"""
Wraps `model` for PipelineParallelism
Wraps `model` for pipeline parallel inference.
Args:
model (`torch.nn.Module`):
A model we want to split for pipeline-parallel inference
split_points (`str`, defaults to 'auto'):
split_points (`str` or `List[str]`, defaults to 'auto'):
How to generate the split points and chunk the model across each GPU. 'auto' will find the best balanced
split given any model.
split given any model. Should be a list of layer names in the model to split by otherwise.
no_split_module_classes (`List[str]`):
A list of class names for layers we don't want to be split.
example_args (tuple of `torch.Tensor`):
example_args (tuple of model inputs):
The expected inputs for the model that uses order-based inputs. Recommended to use this method if possible.
example_kwargs (dict of `torch.Tensor`)
example_kwargs (dict of model inputs)
The expected inputs for the model that uses dictionary-based inputs. This is a *highly* limiting structure
that requires the same keys be present at *all* inference calls. Not recommended unless the prior condition
is true for all cases.
num_chunks (`int`):
num_chunks (`int`, defaults to the number of available GPUs):
The number of different stages the Pipeline will have. By default it will assign one chunk per GPU, but
this can be tuned and played with. In general one should have num_chunks >= num_gpus.
gather_output (`bool`, defaults to `False`):
If `True`, the output from the last GPU (which holds the true outputs) is sent across to all GPUs.
"""
    if not is_pippy_available():
        raise ImportError(
@@ -156,7 +163,7 @@ def prepare_pippy(
    model.hf_split_points = split_points

    def forward(*args, **kwargs):
-        return pippy_forward(stage.forward, num_chunks, *args, **kwargs)
+        return pippy_forward(stage.forward, num_chunks, gather_output, *args, **kwargs)

    # To act like a decorator so that it can be popped when doing `extract_model_from_parallel`
    # Note: creates an infinite recursion loop with `generate`
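To make the new `prepare_pippy` signature concrete, here is a hedged sketch of a call that exercises the documented parameters; the tiny model, its layer names, and the chunk count are invented for illustration and assume a multi-GPU `accelerate launch` run with pippy installed:

```python
import torch
from torch import nn
from accelerate.inference import prepare_pippy

# Hypothetical model purely for illustration; the split point below names its second submodule.
class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Linear(64, 64)
        self.decoder = nn.Linear(64, 8)

    def forward(self, x):
        return self.decoder(self.encoder(x))

model = prepare_pippy(
    TinyModel(),
    split_points=["decoder"],            # explicit layer names instead of "auto"
    no_split_module_classes=None,        # e.g. class names of blocks that must stay whole
    example_args=(torch.randn(4, 64),),  # order-based example inputs used for tracing
    num_chunks=2,                        # defaults to the number of available GPUs
    gather_output=False,                 # set True to copy the result to every GPU
)
```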
14 changes: 14 additions & 0 deletions src/accelerate/test_utils/scripts/test_ops.py
@@ -22,6 +22,7 @@
from accelerate.utils.operations import (
    DistributedOperationException,
    broadcast,
+    copy_tensor_to_devices,
    gather,
    gather_object,
    pad_across_processes,
@@ -129,6 +130,17 @@ def test_op_checker(state):
    state.debug = False


+def test_copy_tensor_to_devices(state):
+    if state.distributed_type not in [DistributedType.MULTI_GPU, DistributedType.TPU]:
+        return
+    if state.is_main_process:
+        tensor = torch.tensor([1, 2, 3], dtype=torch.int).to(state.device)
+    else:
+        tensor = None
+    tensor = copy_tensor_to_devices(tensor)
+    assert torch.allclose(tensor, torch.tensor([1, 2, 3], dtype=torch.int, device="cuda"))


def _mp_fn(index):
    # For xla_spawn (TPUs)
    main()
@@ -153,6 +165,8 @@ def main():
    test_reduce_mean(state)
    state.print("testing op_checker")
    test_op_checker(state)
+    state.print("testing sending tensors across devices")
+    test_copy_tensor_to_devices(state)


if __name__ == "__main__":
1 change: 1 addition & 0 deletions src/accelerate/utils/__init__.py
@@ -121,6 +121,7 @@
    concatenate,
    convert_outputs_to_fp32,
    convert_to_fp32,
+    copy_tensor_to_devices,
    find_batch_size,
    find_device,
    gather,
58 changes: 58 additions & 0 deletions src/accelerate/utils/operations.py
@@ -481,6 +481,64 @@ def _tpu_broadcast(tensor, src=0, name="broadcast tensor"):
    return xm.mesh_reduce(name, tensor, lambda x: x[src])


+TENSOR_TYPE_TO_INT = {
+    torch.float: 1,
+    torch.double: 2,
+    torch.half: 3,
+    torch.bfloat16: 4,
+    torch.uint8: 5,
+    torch.int8: 6,
+    torch.int16: 7,
+    torch.int32: 8,
+    torch.int64: 9,
+    torch.bool: 10,
+}

+TENSOR_INT_TO_DTYPE = {v: k for k, v in TENSOR_TYPE_TO_INT.items()}


+def gather_tensor_shape(tensor):
+    """
+    Grabs the shape of `tensor`, which is only available on one process, and returns a tensor of its shape
+    """
+    # Allocate a fixed-size int buffer (2**20 elements) to hold the shape plus a coded dtype
+    max_tensor_dimension = 2**20
+    state = PartialState()
+    base_tensor = torch.empty(max_tensor_dimension, dtype=torch.int, device=state.device)
+
+    # Since PyTorch can't just send a tensor to another GPU without
+    # knowing its size, we store the size of the tensor with data
+    # in an allocation
+    if tensor is not None:
+        shape = tensor.shape
+        tensor_dtype = TENSOR_TYPE_TO_INT[tensor.dtype]
+        base_tensor[: len(shape) + 1] = torch.tensor(list(shape) + [tensor_dtype], dtype=int)
+    # Perform a reduction to copy the size data onto all GPUs
+    base_tensor = reduce(base_tensor, reduction="sum")
+    base_tensor = base_tensor[base_tensor.nonzero()]
+    # The last non-zero entry encodes the dtype of the source tensor
+    dtype = int(base_tensor[-1:][0])
+    base_tensor = base_tensor[:-1]
+    return base_tensor, dtype


+def copy_tensor_to_devices(tensor=None) -> torch.Tensor:
+    """
+    Copies a tensor that only exists on a single device and broadcasts it to other devices. Differs from `broadcast` as
+    each worker doesn't need to know its shape when used (and tensor can be `None`)
+    Args:
+        tensor (`torch.tensor`):
+            The tensor that should be sent to all devices. It must only be defined on a single device; on all other
+            devices it should be `None`.
+    """
+    state = PartialState()
+    shape, dtype = gather_tensor_shape(tensor)
+    if tensor is None:
+        tensor = torch.zeros(shape, dtype=TENSOR_INT_TO_DTYPE[dtype]).to(state.device)
+    return reduce(tensor, reduction="sum")


@verify_operation
def broadcast(tensor, from_process: int = 0):
"""
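As a closing illustration, here is a minimal sketch of using the new `copy_tensor_to_devices` utility on its own, mirroring `test_copy_tensor_to_devices` above; it assumes the script is launched on multiple processes:

```python
import torch
from accelerate import PartialState
from accelerate.utils import copy_tensor_to_devices

state = PartialState()

# The tensor is defined on the main process only; every other rank passes None.
tensor = torch.arange(3, device=state.device) if state.is_main_process else None

# Afterwards every process holds an identical copy on its own device,
# without needing to know the shape or dtype up front.
tensor = copy_tensor_to_devices(tensor)
```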
