Gradient bucketing using a pre-defined bucket size cap #6417

Closed · wants to merge 10 commits
test/run_tests.sh (7 additions, 6 deletions)

@@ -296,12 +296,13 @@ function run_mp_op_tests {
run_test "$CDIR/test_fsdp_auto_wrap.py"
run_torchrun "$CDIR/test_mp_early_exit.py"
run_pt_xla_debug "$CDIR/debug_tool/test_mp_pt_xla_debug.py"
run_test "$CDIR/torch_distributed/test_torch_distributed_all_gather_xla_backend.py"
run_test "$CDIR/torch_distributed/test_torch_distributed_all_reduce_xla_backend.py"
run_test "$CDIR/torch_distributed/test_torch_distributed_multi_all_reduce_xla_backend.py"
run_test "$CDIR/torch_distributed/test_torch_distributed_reduce_scatter_xla_backend.py"
run_test "$CDIR/torch_distributed/test_ddp.py"
run_test "$CDIR/torch_distributed/test_torch_distributed_fsdp_meta.py"
run_xla_backend_mp "$CDIR/test_torch_distributed_all_gather_xla_backend.py"
run_xla_backend_mp "$CDIR/test_torch_distributed_all_reduce_xla_backend.py"
run_xla_backend_mp "$CDIR/test_torch_distributed_bucketed_all_reduce_xla_backend.py"
run_xla_backend_mp "$CDIR/test_torch_distributed_multi_all_reduce_xla_backend.py"
run_xla_backend_mp "$CDIR/test_torch_distributed_reduce_scatter_xla_backend.py"
run_xla_backend_mp "$CDIR/test_ddp.py"
run_xla_backend_mp "$CDIR/test_torch_distributed_fsdp_meta.py"
}

function run_tests {
test/test_torch_distributed_bucketed_all_reduce_xla_backend.py (new file, 35 additions)

@@ -0,0 +1,35 @@
import os
import sys
import torch
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp
import torch_xla.distributed.xla_backend
import torch.distributed as dist


def _mp_fn(index):
  device = xm.xla_device()
  if xm.xla_device_hw(device) in ('TPU', 'CUDA'):
    world_size = xm.xrt_world_size()
    rank = xm.get_ordinal()

    dist.init_process_group('xla', world_size=world_size, rank=rank)

    # Create tensors of increasing size and fill tensor j with the value j.
    tensor_list = [
        torch.empty((i, i), device=device) for i in range(1, 1000, 101)
    ]
    for j, t in enumerate(tensor_list):
      t.fill_(float(j))
    dist.bucketed_allreduce(tensor_list)
    # After the SUM all-reduce, every element of tensor j equals j * world_size.
    for j, t in enumerate(tensor_list):
      assert torch.all(torch.eq(t.cpu(),
                                float(j) * world_size)) == torch.tensor(True)
  else:
    print(
        'Default device {} is not a TPU or GPU device'.format(device),
        file=sys.stderr)


if __name__ == '__main__':
  xmp.spawn(_mp_fn, args=())
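For context on the assertion above: every rank fills tensor j with the value j, so a SUM all-reduce across world_size processes leaves each element equal to j * world_size. A tiny standalone illustration of that arithmetic (plain PyTorch on CPU, no XLA; the world size of 4 is an assumed value for the example, the test itself uses xm.xrt_world_size()):

import torch

world_size = 4  # assumed for the illustration
j = 3.0
# What a SUM all-reduce produces when each of the world_size ranks contributes
# a tensor filled with the value j.
reduced = sum(torch.full((2, 2), j) for _ in range(world_size))
assert torch.all(torch.eq(reduced, j * world_size))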
torch_xla/core/xla_model.py (45 additions, 6 deletions)

@@ -18,6 +18,7 @@
 import torch_xla.debug.metrics_saver as ms
 import torch_xla.utils.utils as xu
 import torch_xla.utils.closures as xc
+import os

 _DEVICES = xu.LazyProperty(lambda: torch_xla._XLAC._xla_get_devices())

@@ -1123,6 +1124,36 @@ def wait_device_ops(devices=[]):
   torch_xla._XLAC._xla_wait_device_ops(devices=devices)


+def bucketed_allreduce(gradients, bucket_cap, count, groups=None, pin_layout=True):

Collaborator: Maybe name it similar to the original function all_reduce? How about all_reduce_bucketized? Also, do you need to pass "groups" and "pin_layout" also?

+  total = 0
+  tensor_bucket = []
+
+  for grad in gradients:
+    grad_bytes = grad.numel() * grad.element_size()
+
+    # Bucketize till the total spills over
+    total += grad_bytes
+    if total > bucket_cap and len(tensor_bucket) > 0:
+      all_reduce(
+          REDUCE_SUM,
+          tensor_bucket,
+          scale=1.0 / count,
+          groups=groups,
+          pin_layout=pin_layout)
+      total = grad_bytes
+      tensor_bucket = []
+    tensor_bucket.append(grad)
+
+  # Flush the last remaining bucket
+  if len(tensor_bucket):
+    all_reduce(
+        REDUCE_SUM,
+        tensor_bucket,
+        scale=1.0 / count,
+        groups=groups,
+        pin_layout=pin_layout)
+
+
 def reduce_gradients(optimizer, groups=None, pin_layout=True):
   """Reduces all the gradients handled by an optimizer.

@@ -1140,12 +1171,20 @@ def reduce_gradients(optimizer, groups=None, pin_layout=True):
   count = xrt_world_size()
   if count > 1:
     gradients = _fetch_gradients(optimizer)
Collaborator: Can we keep the original behavior? And maybe use a flag to turn this feature on?

Collaborator Author: OK, let me work on that.

Collaborator: Maybe we should introduce an argument "bucket_cap_mb" that turns this on, instead of an environment variable? bucket_cap_mb=0 turns off bucketing and is the default?

Collaborator Author: Done.

-    all_reduce(
-        REDUCE_SUM,
-        gradients,
-        scale=1.0 / count,
-        groups=groups,
-        pin_layout=pin_layout)
+    bucket_cap = int(os.getenv('ALLREDUCE_BUCKET_SIZE_MB', 0)) * 1024 * 1024
+    # Reverse the gradients list so that we start allreduce from the last layer
+    # onwards. This allows allreduce to trigger as soon as the bucket fills up
+    # and overlap with the backward pass.
+    if bucket_cap > 0:
+      gradients = reversed(gradients)
+      bucketed_allreduce(gradients, bucket_cap, count, groups=groups, pin_layout=pin_layout)
+    else:
+      all_reduce(
+          REDUCE_SUM,
+          gradients,
+          scale=1.0 / count,
+          groups=groups,
+          pin_layout=pin_layout)


 def optimizer_step(optimizer,
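The review thread above suggests renaming the helper to all_reduce_bucketized and gating the feature with a bucket_cap_mb argument rather than an environment variable. A rough sketch of what that API could look like (hypothetical signature, not the code in this diff; it only reuses xm.all_reduce and xm.REDUCE_SUM, which exist in torch_xla.core.xla_model):

import torch_xla.core.xla_model as xm


def all_reduce_bucketized(tensors,
                          bucket_cap_mb=0,
                          scale=1.0,
                          groups=None,
                          pin_layout=True):
  """Hypothetical sketch: sum-all-reduce `tensors` in size-capped buckets.

  bucket_cap_mb=0 (the default) falls back to a single xm.all_reduce call,
  matching the reviewer's suggestion that bucketing be opt-in.
  """
  if bucket_cap_mb <= 0:
    xm.all_reduce(
        xm.REDUCE_SUM, tensors, scale=scale, groups=groups,
        pin_layout=pin_layout)
    return

  bucket_cap = bucket_cap_mb * 1024 * 1024
  bucket, total = [], 0
  for t in tensors:
    nbytes = t.numel() * t.element_size()
    if bucket and total + nbytes > bucket_cap:
      # Flush the current bucket before it spills over the cap.
      xm.all_reduce(
          xm.REDUCE_SUM, bucket, scale=scale, groups=groups,
          pin_layout=pin_layout)
      bucket, total = [], 0
    bucket.append(t)
    total += nbytes
  if bucket:
    xm.all_reduce(
        xm.REDUCE_SUM, bucket, scale=scale, groups=groups,
        pin_layout=pin_layout)

Compared with the diff above, the only behavioral difference is that the cap arrives as an argument instead of ALLREDUCE_BUCKET_SIZE_MB, which keeps the default path identical to the existing single all_reduce.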
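For reference, a minimal training-step sketch showing how the change in this diff would be exercised: ALLREDUCE_BUCKET_SIZE_MB is read inside reduce_gradients(), so setting it before the step enables bucketing, while leaving it unset (the default of 0) keeps the original single all_reduce. The 50 MB cap and the toy loss are illustrative values, not part of the PR:

import os

import torch
import torch.nn as nn
import torch_xla.core.xla_model as xm

# Cap from this PR's env var; 0 / unset keeps the original un-bucketed path.
os.environ['ALLREDUCE_BUCKET_SIZE_MB'] = '50'


def train_step(model, optimizer, data, target):
  optimizer.zero_grad()
  loss = nn.functional.mse_loss(model(data), target)
  loss.backward()
  # optimizer_step() calls reduce_gradients(), which with this change buckets
  # the gradient all-reduce whenever the cap above is > 0.
  xm.optimizer_step(optimizer)
  xm.mark_step()

The model, data, and target are assumed to already live on the XLA device (xm.xla_device()).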