Skip to content

Commit

Permalink
ZeRO1: Add bucketting logic to control the size of tensors for all-ga…
Browse files Browse the repository at this point in the history
…ther/reduce-scatter (#6025)

Co-authored-by: Rahul Solanki <[email protected]>
Co-authored-by: guangtai <[email protected]>
Co-authored-by: Amithrajith Mamidala <[email protected]>
  • Loading branch information
4 people authored Mar 22, 2024
1 parent 782f05d commit e75677f
Show file tree
Hide file tree
Showing 4 changed files with 400 additions and 26 deletions.
60 changes: 60 additions & 0 deletions test/test_mp_all_gather.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,66 @@ def _mp_fn(index):
file=sys.stderr)
print(f'[{index}] {cpu_result}', file=sys.stderr)
sys.exit(1)

# Testing with a single replica group and tensor list as input (Bucketized)
# TODO: add support for list input with pin_layout=True and output=None
result_list = xm.all_gather_bucketized(
ordinal_tensors, dim=0, pin_layout=False)

for i, result in enumerate(result_list):
cpu_result = result.cpu()
expected = i * 1000 + torch.arange(world_size, dtype=torch.float)
if not cpu_result.allclose(expected):
print(
'xm.all_gather_bucketized() produced wrong reductions for item {i} in result list',
file=sys.stderr)
print(f'[{index}] {cpu_result}', file=sys.stderr)
sys.exit(1)

# Testing with a single replica group and tensor list as input and output!=None (out-of-place) (Bucketized)
# Reuse ordinal_tensors from previous test
output_tensors = [
torch.zeros([world_size], dtype=torch.float).to(device)
for i in range(input_list_size)
]
# TODO: add support for list input with pin_layout=True and output!=None
result_list = xm.all_gather_bucketized(
ordinal_tensors, dim=0, output=output_tensors, pin_layout=False)

for i, result in enumerate(result_list):
cpu_result = result.cpu()
expected = i * 1000 + torch.arange(world_size, dtype=torch.float)
if not cpu_result.allclose(expected):
print(
'xm.all_gather() produced wrong reductions for item {i} in result list',
file=sys.stderr)
print(f'[{index}] {cpu_result}', file=sys.stderr)
sys.exit(1)

# Testing with a single replica group and tensor list as input and output!=None (out-of-place) (Bucketized, zero bucket size)
# Reuse ordinal_tensors from previous test
output_tensors = [
torch.zeros([world_size], dtype=torch.float).to(device)
for i in range(input_list_size)
]
# TODO: add support for list input with pin_layout=True and output!=None
result_list = xm.all_gather_bucketized(
ordinal_tensors,
dim=0,
output=output_tensors,
pin_layout=False,
bucket_cap_mb=0)

for i, result in enumerate(result_list):
cpu_result = result.cpu()
expected = i * 1000 + torch.arange(world_size, dtype=torch.float)
if not cpu_result.allclose(expected):
print(
'xm.all_gather() produced wrong reductions for item {i} in result list',
file=sys.stderr)
print(f'[{index}] {cpu_result}', file=sys.stderr)
sys.exit(1)

# TODO: add test for torch.compile when support for list input is ready

else:
Expand Down
87 changes: 87 additions & 0 deletions test/test_mp_reduce_scatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,33 @@ def _mp_fn(index):

xm.rendezvous('test_reduce_scatter_list_input')

# Testing reduce-scatter with list input bucketized
rand_list = [
torch.rand((32, shard_size * world_size, 32))
for _ in range(input_list_size)
]
xrand_list = [rand.to(device) for rand in rand_list]

# TODO: fix the broken case with pin_layout=True
res_list = xm.reduce_scatter_bucketized(
xm.REDUCE_SUM,
xrand_list,
scale,
scatter_dim,
world_size,
pin_layout=False)

for i, res in enumerate(res_list):
expected_world = xm.all_reduce(xm.REDUCE_SUM, xrand_list[i], scale)
xm.mark_step()

slice_idx = torch.tensor(
list(range(index * shard_size, (index + 1) * shard_size)))
expected = expected_world.cpu().index_select(scatter_dim, slice_idx)
assert res.cpu().allclose(expected)

xm.rendezvous('test_reduce_scatter_list_input_bucketized')

# Testing reduce-scatter with list input and output
output_list = [
torch.rand((32, shard_size * world_size, 32))
Expand Down Expand Up @@ -83,6 +110,66 @@ def _mp_fn(index):
assert res.cpu().allclose(expected)

xm.rendezvous('test_reduce_scatter_list_input_output')

# Testing reduce-scatter with list input and output (buckettized)
output_list = [
torch.rand((32, shard_size * world_size, 32))
for _ in range(input_list_size)
]
xoutput_list = [output.to(device) for output in output_list]

# TODO: fix the broken case with pin_layout=True
res_list = xm.reduce_scatter_bucketized(
xm.REDUCE_SUM,
xrand_list,
scale,
scatter_dim,
world_size,
output=xoutput_list,
pin_layout=False)

assert (xoutput_list == res_list)
for i, res in enumerate(xoutput_list):
expected_world = xm.all_reduce(xm.REDUCE_SUM, xrand_list[i], scale)
xm.mark_step()

slice_idx = torch.tensor(
list(range(index * shard_size, (index + 1) * shard_size)))
expected = expected_world.cpu().index_select(scatter_dim, slice_idx)
assert res.cpu().allclose(expected)

xm.rendezvous('test_reduce_scatter_list_input_output_bucketized')

# Testing reduce-scatter with list input and output (buckettized, but zero bucket size)
output_list = [
torch.rand((32, shard_size * world_size, 32))
for _ in range(input_list_size)
]
xoutput_list = [output.to(device) for output in output_list]

# TODO: fix the broken case with pin_layout=True
res_list = xm.reduce_scatter_bucketized(
xm.REDUCE_SUM,
xrand_list,
scale,
scatter_dim,
world_size,
output=xoutput_list,
bucket_cap_mb=0,
pin_layout=False)

assert (xoutput_list == res_list)
for i, res in enumerate(xoutput_list):
expected_world = xm.all_reduce(xm.REDUCE_SUM, xrand_list[i], scale)
xm.mark_step()

slice_idx = torch.tensor(
list(range(index * shard_size, (index + 1) * shard_size)))
expected = expected_world.cpu().index_select(scatter_dim, slice_idx)
assert res.cpu().allclose(expected)

xm.rendezvous(
'test_reduce_scatter_list_input_output_bucketized, zero bucket size')
else:
print(
'Default device {} is not a TPU device'.format(device), file=sys.stderr)
Expand Down
148 changes: 148 additions & 0 deletions torch_xla/core/xla_model.py
100755 → 100644
Original file line number Diff line number Diff line change
Expand Up @@ -634,6 +634,110 @@ def all_gather(value, dim=0, groups=None, output=None, pin_layout=True):
f"given {type(value)}.")


class CoalescingBuckets(object):

def __init__(self, func, input_list, output_list=None, bucket_cap_mb=160):
if not isinstance(input_list, list) or any(
not isinstance(v, torch.Tensor) for v in input_list):
raise TypeError(
f"`input_list` needs to be a list of Tensors, but given {type(input_list)}."
)
if output_list != None:
if not isinstance(output_list, list) or any(
not isinstance(v, torch.Tensor) for v in output_list):
raise TypeError(
f"`output_list` needs to be a list of Tensors, but given {type(output_list)}."
)
if len(output_list) != len(input_list):
raise ValueError(
"`output_list` length doesn't match `input_list` length: "
f"{len(output_list)} vs {len(input_list)}.")
self._func = func
self._input_list = input_list
self._output_list = output_list
self._total = 0
self._tensor_bucket = []
self._output_bucket = [] if output_list else None
self._bucket_cap = bucket_cap_mb * 1024 * 1024
self._out_tensors = []

def flush(self):
if len(self._tensor_bucket) == 1:
# Use non-coalesced CCOp if its just one tensor
output = self._output_bucket[0] if self._output_bucket else None
self._out_tensors.append(self._func(self._tensor_bucket[0], output))
elif len(self._tensor_bucket):
self._out_tensors.extend(
self._func(self._tensor_bucket, self._output_bucket))
self._total = 0
self._tensor_bucket = []
self._output_bucket = [] if self._output_list else None

def add(self, tensor, idx):
self._total += tensor.numel() * tensor.element_size()
self._tensor_bucket.append(tensor)
if self._output_list != None:
self._output_bucket.append(self._output_list[idx])

def __call__(self):
for idx, tensor in enumerate(self._input_list):
tensor_bytes = tensor.numel() * tensor.element_size()

# Aim for target bucket_cap_mb: flush new tensor with bucket if bucket content
# is small (1/2 cap) but don't combine if combined total is over 2x cap
total_new = self._total + tensor_bytes
if tensor_bytes > self._bucket_cap and self._total < 0.5 * self._bucket_cap and total_new <= 2 * self._bucket_cap:
self.add(tensor, idx)
self.flush()
else:
# Bucketize till the total spills over
if total_new > self._bucket_cap:
self.flush()
self.add(tensor, idx)

# Flush the last remaining bucket
self.flush()

assert len(self._out_tensors) == len(self._input_list)

return self._out_tensors


def all_gather_bucketized(input_list,
dim=0,
groups=None,
output=None,
pin_layout=False,
bucket_cap_mb=160):
"""Performs an all-gather operation along a given dimension, with bucketization.
Args:
See all_gather for the args: dim, groups, output, pin_layout
input_list: List of input tensors
bucket_cap_mb: Number of MegaBytes of the tensor bucket to fill before doing all-gather.
Returns:
A list of tensors each of which has, in the ``dim`` dimension, all the values from the
participating replicas.
"""
# sanity checks
if pin_layout:
raise RuntimeError(
"For xm.all_gather_bucketized, pin_layout=True is not yet supported.")

def _all_gather_coalesced(_input_list, _output_list=None):
return all_gather(
value=_input_list,
dim=dim,
groups=groups,
output=_output_list,
pin_layout=pin_layout)

buckets = CoalescingBuckets(
_all_gather_coalesced, input_list, output, bucket_cap_mb=bucket_cap_mb)
return buckets()


def all_to_all(value,
split_dimension,
concat_dimension,
Expand Down Expand Up @@ -847,6 +951,50 @@ def reduce_scatter(reduce_type,
f"given {type(input)}.")


def reduce_scatter_bucketized(reduce_type,
input_list,
scale,
scatter_dim,
shard_count,
groups=None,
output=None,
pin_layout=False,
bucket_cap_mb=160):
"""Performs a XLA `ReduceScatter()` operation on a list of tensors (bucketized).
See: https://www.tensorflow.org/xla/operation_semantics#reducescatter
Args:
see reduce_scatter for reduce_type, scale, scatter_dim, shard_count, groups, pin_layout
input_list: List of input tensors
output: Optional list of output torch.Tensor
bucket_cap_mb: Number of MegaBytes of the tensor bucket to fill before doing all-gather.
Returns:
A list of `torch.Tensors` with all the values reduced across replicas. Each process
gets a shard split along the `scatter_dim`. All other dimensions are
the same as the input.
"""

def _reduce_scatter_coalesced(_input_list, _output_list=None):
return reduce_scatter(
reduce_type=reduce_type,
input=_input_list,
scale=scale,
scatter_dim=scatter_dim,
shard_count=shard_count,
groups=groups,
output=_output_list,
pin_layout=pin_layout)

buckets = CoalescingBuckets(
_reduce_scatter_coalesced,
input_list,
output,
bucket_cap_mb=bucket_cap_mb)
return buckets()


def add_step_closure(closure, args=(), run_async=False):
"""Adds a closure to the list of the ones to be run at the end of the step.
Expand Down
Loading

0 comments on commit e75677f

Please sign in to comment.