Add all-gather/reduce-scatter coalescence for FSDP.
Allow using reduce-scatter's scale param in FSDP.
hjm-aws authored and jeffhataws committed Oct 12, 2023
1 parent ed7879d commit 78fc6d3
Showing 15 changed files with 670 additions and 275 deletions.
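In brief: `xm.all_gather` and `xm.reduce_scatter` now accept a list of tensors and emit a single coalesced HLO op for the whole list, which is what lets FSDP batch its per-parameter collectives; `reduce_scatter` also threads its `scale` argument through, so FSDP can fold gradient averaging into the collective itself. A minimal usage sketch under assumed shapes and world size (illustrative only, not part of the diff):

import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()
world_size = xm.xrt_world_size()

# One coalesced all-gather for both tensors instead of two separate ops.
p1 = torch.ones(2, device=device)
p2 = torch.ones(5, device=device)
gathered = xm.all_gather([p1, p2], dim=0)

# One coalesced reduce-scatter; scale=1/world_size averages the reduction,
# the way FSDP averages gradients.
g1 = torch.ones(4 * world_size, device=device)
g2 = torch.ones(6 * world_size, device=device)
shards = [torch.zeros(4, device=device), torch.zeros(6, device=device)]
reduced = xm.reduce_scatter(
    xm.REDUCE_SUM, [g1, g2],
    scale=1.0 / world_size,
    scatter_dim=0,
    shard_count=world_size,
    output=shards)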
84 changes: 74 additions & 10 deletions test/test_torch_distributed_xla_backend.py
@@ -87,6 +87,26 @@ def test_allgather(self):
hlo = torch_xla._XLAC._get_xla_tensors_hlo(output_tensors)
hlo_matches(hlo, all_gather_pattern)

def test_allgather_coalesced(self):
device = xm.xla_device()
tensor = torch.arange(2, device=device) + 1 + 2 * dist.get_rank()
tensor2 = torch.arange(5, device=device) + 1 + 2 * dist.get_rank()
pg_xla = get_process_group_xla(rank=3, size=8)
output_tensors = [torch.zeros_like(tensor)] * 8
output_tensors2 = [torch.zeros_like(tensor2)] * 8
# because we set os.environ[xenv.WORLD_SIZE] = '1', here the outputs'
# shapes will be the same as the inputs' shapes.
all_gather_pattern = (
r'%all-gather\.\d+ = \(s64\[2]\{0}, s64\[5]\{0}, s64\[]\) '
r'all-gather\(s64\[2]\{0} %.+\.\d+, s64\[5]\{0} %.+\.\d+, '
r's64\[] %.+\.\d+\)')
pg_xla.allgather_coalesced([output_tensors, output_tensors2],
[tensor, tensor2])
hlo = torch_xla._XLAC._get_xla_tensors_hlo(output_tensors)
hlo_matches(hlo, all_gather_pattern)
# purge all computations attached to the device.
xm.mark_step()

def test_broadcast(self):
device = xm.xla_device()
tensor = torch.arange(2, device=device) + 1 + 2 * dist.get_rank()
@@ -102,9 +122,40 @@ def test_reduce_scatter(self):
input_list = [tensor]
output = torch.zeros_like(tensor)
reduce_scatter_pattern = r'%reduce\-scatter\.\d+ = .+ reduce\-scatter\('
dist.reduce_scatter(output, input_list)
hlo = torch_xla._XLAC._get_xla_tensors_hlo([output])
hlo_matches(hlo, reduce_scatter_pattern)
with self.assertRaises(RuntimeError) as cm:
dist.reduce_scatter(output, input_list)
hlo = torch_xla._XLAC._get_xla_tensors_hlo([output])
hlo_matches(hlo, reduce_scatter_pattern)
# purge all computations attached to the device.
xm.mark_step()
assert 'UNIMPLEMENTED: ReduceScatter is not implemented on CPU.' in str(
cm.exception), str(cm.exception)
# reset token to clean up the mess after the RuntimeError.
xm.set_replication(device, [])

def test_reduce_scatter_coalesced(self):
device = xm.xla_device()
tensor = torch.arange(2, device=device) + 1 + 2 * dist.get_rank()
tensor2 = torch.arange(5, device=device) + 1 + 2 * dist.get_rank()
input_tensors_list = [[tensor, tensor], [tensor2, tensor2]]
output_list = [torch.zeros_like(tensor), torch.zeros_like(tensor2)]
pg_xla = get_process_group_xla(rank=0, size=len(input_tensors_list[0]))
opts = dist.ReduceScatterOptions()
opts.reduceOp = dist.ReduceOp.SUM
reduce_scatter_pattern = (
r'%reduce-scatter\.\d+ = \(s64\[2]\{0}, s64\[5]\{0}, s64\[]\) '
r'reduce-scatter\(s64\[4]\{0} %.+\.\d+, s64\[10]\{0} %.+\.\d+, '
r's64\[] %.+\.\d+\)')
with self.assertRaises(RuntimeError) as cm:
pg_xla.reduce_scatter_coalesced(output_list, input_tensors_list, opts)
hlo = torch_xla._XLAC._get_xla_tensors_hlo(output_list)
hlo_matches(hlo, reduce_scatter_pattern)
# purge all computations attached to the device.
xm.mark_step()
assert 'UNIMPLEMENTED: ReduceScatter is not implemented on CPU.' in str(
cm.exception), str(cm.exception)
# reset token to clean up the mess after the RuntimeError.
xm.set_replication(device, [])

@patch_world(0, 6)
def test_send(self):
@@ -120,9 +171,16 @@ def test_send(self):

send_pattern = r'%send\.\d+ = .+ send\(.+\), channel_id=2'
senddone_pattern = r'%send\-done\.\d+ = .+ send\-done\(.+\), channel_id=2'
hlo = torch_xla._XLAC._get_xla_tensors_hlo([tensor])
hlo_matches(hlo, send_pattern)
hlo_matches(hlo, senddone_pattern)
# seeing 'Send is not implemented on CPU' means we have successfully
# generated `send` in the HLO.
with self.assertRaises(RuntimeError) as cm:
pg_xla.send(input_list, 1)
hlo = torch_xla._XLAC._get_xla_tensors_hlo(input_list)
hlo_matches(hlo, send_pattern)
hlo_matches(hlo, senddone_pattern)
xm.mark_step()
assert 'UNIMPLEMENTED: Send is not implemented on CPU.' in str(
cm.exception), str(cm.exception)

# Don't try to run Send on CPU because it's not implemented
torch_xla._XLAC._clear_pending_irs(str(xm.xla_device()))
@@ -140,9 +198,16 @@ def test_recv(self):

recv_pattern = r'%recv\.\d+ = .+ recv\(.+\), channel_id=3'
recvdone_pattern = r'%recv\-done\.\d+ = .+ recv\-done\(.+\), channel_id=3'
hlo = torch_xla._XLAC._get_xla_tensors_hlo([tensor])
hlo_matches(hlo, recv_pattern)
hlo_matches(hlo, recvdone_pattern)
# seeing 'Recv is not implemented on CPU' means we have successfully
# generated `recv` in the HLO.
with self.assertRaises(RuntimeError) as cm:
pg_xla.recv(output_list, 1)
hlo = torch_xla._XLAC._get_xla_tensors_hlo(output_list)
hlo_matches(hlo, recv_pattern)
hlo_matches(hlo, recvdone_pattern)
xm.mark_step()
assert 'UNIMPLEMENTED: Recv is not implemented on CPU.' in str(
cm.exception), str(cm.exception)

# Don't try to run Recv on CPU because it's not implemented
torch_xla._XLAC._clear_pending_irs(str(xm.xla_device()))
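The send/recv and reduce-scatter tests above share one idiom: invoke the collective, let the CPU backend fail with UNIMPLEMENTED at execution time, and still assert that the op was lowered into the pending HLO. A sketch of that idiom as a standalone helper (the helper name is hypothetical, not part of this commit):

import re

import torch_xla
import torch_xla.core.xla_model as xm


def assert_lowered_despite_cpu_error(test, run_op, tensors, hlo_pattern,
                                     err_substr):
  # Hypothetical helper: tracing `run_op` and fetching the HLO succeed lazily;
  # only xm.mark_step(), which executes the graph, raises on CPU.
  with test.assertRaises(RuntimeError) as cm:
    run_op()
    hlo = torch_xla._XLAC._get_xla_tensors_hlo(tensors)
    test.assertTrue(re.search(hlo_pattern, hlo), hlo)
    xm.mark_step()
  test.assertIn(err_substr, str(cm.exception))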
@@ -291,7 +356,6 @@ def test_barrier(self):

@parameterized.parameters(
'reduce',
'allgather_coalesced',
'allreduce_coalesced',
'alltoall',
'alltoall_base',
93 changes: 64 additions & 29 deletions torch_xla/core/xla_model.py
@@ -524,7 +524,8 @@ def all_gather(value, dim=0, groups=None, output=None, pin_layout=True):
"""Performs an all-gather operation along a given dimension.
Args:
value (torch.Tensor): The input tensor.
value (torch.Tensor or a list of torch.Tensor): The input. If it's a list, then
it will also be the output.
dim (int): The gather dimension.
Default: 0
groups (list, optional): A list of lists, representing the replica groups for
@@ -560,17 +561,29 @@ def all_gather(value, dim=0, groups=None, output=None, pin_layout=True):
shard_count = xrt_world_size()

token, devctx = _get_all_reduce_token()
if output != None:
# Call the out of place version of the all_gather
new_token = torch_xla._XLAC._xla_all_gather_out(output, value, token, dim,
shard_count, groups or [],
pin_layout)
torch_xla._XLAC._set_all_reduce_token(devctx.device, new_token)
return output

result = torch_xla._XLAC._xla_all_gather(value, dim, shard_count, groups or
[], pin_layout)
return result
if isinstance(value, torch.Tensor):
if output != None:
# Call the out of place version of the all_gather
new_token = torch_xla._XLAC._xla_all_gather_out(output, value, token, dim,
shard_count, groups or [],
pin_layout)
torch_xla._XLAC._set_all_reduce_token(devctx.device, new_token)
return output

result = torch_xla._XLAC._xla_all_gather(value, dim, shard_count, groups or
[], pin_layout)
return result

# Now the input should be a list of Tensors.
if not isinstance(value, list) or any(
not isinstance(v, torch.Tensor) for v in value):
raise TypeError("`value` needs to be a Tensor or a list of Tensors, but "
f"given {type(value)}.")
result = torch_xla._XLAC._xla_all_gather_coalesced(value, token, dim,
shard_count, groups or [],
pin_layout)
torch_xla._XLAC._set_all_reduce_token(devctx.device, result[-1])
return result[:-1]
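
For reference, a sketch of the new list path's behavior (shapes assumed): each tensor in `value` is gathered along `dim`, the token is carried as the last element of the C++ result, and the wrapper strips it before returning.

import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()
a = torch.zeros(2, 3, device=device)
b = torch.zeros(5, device=device)

outs = xm.all_gather([a, b], dim=0)
# With N replicas, outs[0] has shape (2*N, 3) and outs[1] has shape (5*N,).
# A list containing a non-Tensor raises the TypeError above.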


def all_to_all(value,
@@ -716,16 +729,18 @@ def reduce_scatter(reduce_type,
reduce_type (string): One of ``xm.REDUCE_SUM``, ``xm.REDUCE_MUL``,
``xm.REDUCE_AND``, ``xm.REDUCE_OR``, ``xm.REDUCE_MIN`` and
``xm.REDUCE_MAX``.
input: A single `torch.Tensor` all reduce + scatter op to.
input (torch.Tensor or a list of torch.Tensor): The input. If it's a list, then
it will also be the output.
scale (float): A default scaling value to be applied after the reduce.
scatter_dim (int): Dimension number to which the scatter operation is applied.
shard_count (int): The number of ways to split up the scatter_dim.
groups (list): A list of lists, representing the replica groups for
the `all_reduce()` operation. Example: `[[0, 1, 2, 3], [4, 5, 6, 7]]`
the `reduce_scatter()` operation. Example: `[[0, 1, 2, 3], [4, 5, 6, 7]]`
defines two groups, one with the `[0, 1, 2, 3]` replicas and one with
the `[4, 5, 6, 7]` replicas. If `None` there will be only one group with
all the replicas in it.
output: Optional output tensor
output: Optional output tensor if `input` is a torch.Tensor or a list of
torch.Tensor if `input` is a list of torch.Tensor.
pin_layout (bool, optional): whether to pin the layout for this communication op.
Layout pinning can prevent potential data corruption when each process that
participates in the communication has a slightly different program, but it might
@@ -738,21 +753,41 @@
the same as the input.
"""
token, devctx = _get_all_reduce_token()
if isinstance(input, torch.Tensor):
if output != None:
# Call the out of place version of the reduce_scatter
new_token = torch_xla._XLAC._xla_reduce_scatter_out(
reduce_type, output, input, token, scale, scatter_dim, shard_count,
groups or [], pin_layout)
torch_xla._XLAC._set_all_reduce_token(devctx.device, new_token)
return output

result = torch_xla._XLAC._xla_reduce_scatter(reduce_type, input, token,
scale, scatter_dim,
shard_count, groups or [],
pin_layout)
torch_xla._XLAC._set_all_reduce_token(devctx.device, result[1])
return result[0]

# Now the input should be a list of Tensors.
if not isinstance(input, list) or any(
not isinstance(v, torch.Tensor) for v in input):
raise TypeError("`input` needs to be a Tensor or a list of Tensors, but "
f"given {type(input)}.")
if output != None:
# Call the out of place version of the reduce_scatter
new_token = torch_xla._XLAC._xla_reduce_scatter_out(reduce_type, output,
input, token, scale,
scatter_dim,
shard_count, groups or
[], pin_layout)
torch_xla._XLAC._set_all_reduce_token(devctx.device, new_token)
return output

result = torch_xla._XLAC._xla_reduce_scatter(reduce_type, input, token, scale,
scatter_dim, shard_count,
groups or [], pin_layout)
torch_xla._XLAC._set_all_reduce_token(devctx.device, result[1])
return result[0]
if not isinstance(output, list) or any(
not isinstance(v, torch.Tensor) for v in output):
raise TypeError(
f"`output` needs to be a list of Tensors, but given {type(output)}.")
if len(output) != len(input):
raise ValueError("`output` length doesn't match `input` length: "
f"{len(output)} vs {len(input)}.")

result = torch_xla._XLAC._xla_reduce_scatter_coalesced(
reduce_type, output or [], input, token, scale, scatter_dim, shard_count,
groups or [], pin_layout)
torch_xla._XLAC._set_all_reduce_token(devctx.device, result[-1])
return result[:-1]
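
A matching sketch of the list path's contract (shapes and shard count assumed): `output` must be a list of Tensors of the same length as `input`, `scale` is applied after the reduction, and the trailing token is stripped from the returned list.

import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()

grads = [torch.ones(8, device=device), torch.ones(12, device=device)]
shards = [torch.zeros(4, device=device), torch.zeros(6, device=device)]

# Each replica keeps one reduced, scaled shard of every input tensor.
reduced = xm.reduce_scatter(xm.REDUCE_SUM, grads, scale=0.5, scatter_dim=0,
                            shard_count=2, output=shards)
# Passing an `output` list of a different length would raise the ValueError
# above; a non-list `output` would raise the TypeError.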


def add_step_closure(closure, args=(), run_async=False):
