Gradient bucketing using a pre-defined bucket size cap #6417
Changes from all commits: e6b6122, d2989c6, 3a5f92f, 3a93868, 5ad177d, 9dbc2a7, 3915b77, 7c9b15d, f5ffcd4, 05e2367
New test file (hunk `@@ -0,0 +1,35 @@`):

```python
import os
import sys
import torch
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp
import torch_xla.distributed.xla_backend
import torch.distributed as dist


def _mp_fn(index):
  device = xm.xla_device()
  if xm.xla_device_hw(device) in ('TPU', 'CUDA'):
    world_size = xm.xrt_world_size()
    rank = xm.get_ordinal()

    dist.init_process_group('xla', world_size=world_size, rank=rank)

    # Build tensors of increasing size and fill each with its list index.
    tensor_list = [
        torch.empty((i, i), device=device) for i in range(1, 1000, 101)
    ]
    for j, t in enumerate(tensor_list):
      t.fill_(float(j))
    dist.bucketed_allreduce(tensor_list)
    # Verify each tensor now holds index * world_size in every element.
    for j, t in enumerate(tensor_list):
      assert torch.all(torch.eq(t.cpu(),
                                float(j) * world_size)) == torch.tensor(True)
  else:
    print(
        'Default device {} is not a TPU or GPU device'.format(device),
        file=sys.stderr)


if __name__ == '__main__':
  xmp.spawn(_mp_fn, args=())
```
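As a side note on what the assertions check: each replica fills tensor `j` with the value `j`, so a sum all-reduce across `world_size` identical replicas yields `j * world_size` in every element. A minimal, XLA-free sketch of that arithmetic (the replica count here is a made-up example, not something the test hardcodes):

```python
import torch

# Hypothetical replica count; on real hardware this comes from
# xm.xrt_world_size().
world_size = 4

# Same tensor shapes and fill values as the test above.
tensor_list = [torch.full((i, i), float(j))
               for j, i in enumerate(range(1, 1000, 101))]

# A sum all-reduce over replicas holding identical values is equivalent
# to multiplying each tensor by the replica count.
reduced = [t * world_size for t in tensor_list]

for j, t in enumerate(reduced):
  assert torch.all(t == float(j) * world_size)
```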
The second file in the diff is torch_xla/core/xla_model.py (judging by its contents). The first hunk adds the `os` import:

```diff
@@ -18,6 +18,7 @@
 import torch_xla.debug.metrics_saver as ms
 import torch_xla.utils.utils as xu
 import torch_xla.utils.closures as xc
+import os

 _DEVICES = xu.LazyProperty(lambda: torch_xla._XLAC._xla_get_devices())
```
The second hunk adds the new `bucketed_allreduce` function:

```diff
@@ -1123,6 +1124,36 @@ def wait_device_ops(devices=[]):
   torch_xla._XLAC._xla_wait_device_ops(devices=devices)


+def bucketed_allreduce(gradients):
+  total = 0
+  tensor_bucket = []
+
+  for grad in gradients:
+    grad_bytes = grad.numel() * grad.element_size()
+
+    # Bucketize till the total spills over
+    total += grad_bytes
+    if total > bucket_cap and len(tensor_bucket) > 0:
+      all_reduce(
+          REDUCE_SUM,
+          tensor_bucket,
+          scale=1.0 / count,
+          groups=groups,
+          pin_layout=pin_layout)
+      total = grad_bytes
+      tensor_bucket = []
+    tensor_bucket.append(grad)
+
+  # Flush the last remaining bucket
+  if len(tensor_bucket):
+    all_reduce(
+        REDUCE_SUM,
+        tensor_bucket,
+        scale=1.0 / count,
+        groups=groups,
+        pin_layout=pin_layout)
+
+
 def reduce_gradients(optimizer, groups=None, pin_layout=True):
   """Reduces all the gradients handled by an optimizer.
```
The third hunk wires bucketing into `reduce_gradients`:

```diff
@@ -1140,12 +1171,20 @@ def reduce_gradients(optimizer, groups=None, pin_layout=True):
   count = xrt_world_size()
   if count > 1:
     gradients = _fetch_gradients(optimizer)
-    all_reduce(
-        REDUCE_SUM,
-        gradients,
-        scale=1.0 / count,
-        groups=groups,
-        pin_layout=pin_layout)
+    bucket_cap = int(os.getenv('ALLREDUCE_BUCKET_SIZE_MB', 0)) * 1024 * 1024
+    # Reverse the gradients list so that we start allreduce from the last
+    # layer onwards. This allows allreduce to trigger as soon as the bucket
+    # fills up and overlap with the backward pass.
+    if bucket_cap > 0:
+      gradients = reversed(gradients)
+      bucketed_allreduce(gradients)
+    else:
+      all_reduce(
+          REDUCE_SUM,
+          gradients,
+          scale=1.0 / count,
+          groups=groups,
+          pin_layout=pin_layout)


 def optimizer_step(optimizer,
```

Review discussion on this hunk:

Reviewer: Can we keep the original behavior? And maybe use a flag to turn this feature on?

Author: OK, let me work on that.

Reviewer: Maybe we should introduce an argument `bucket_cap_mb` that turns this on, instead of an environment variable? `bucket_cap_mb=0` turns off bucketing and is the default?

Author: Done.
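As the hunk stands, the feature is opt-in through the `ALLREDUCE_BUCKET_SIZE_MB` environment variable that `reduce_gradients` reads. A sketch of enabling it, with an arbitrary 50 MB cap:

```python
import os

# Must be set before reduce_gradients()/optimizer_step() runs. A value
# of 0 (the default) skips bucketing and issues a single all_reduce over
# all gradients, preserving the original behavior.
os.environ['ALLREDUCE_BUCKET_SIZE_MB'] = '50'
```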
Review comment on the new `bucketed_allreduce` function:

Reviewer: Maybe name it similarly to the original function `all_reduce`? How about `all_reduce_bucketized`? Also, do you need to pass `groups` and `pin_layout` as well?
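For illustration, a sketch of what that suggestion might look like: a self-contained `all_reduce_bucketized` that threads `reduce_type`, `scale`, `groups`, and `pin_layout` through explicitly and takes a `bucket_cap_mb` argument instead of reading enclosing-scope variables. The name, signature, and defaults follow the review comments and are assumptions, not the API that was eventually merged:

```python
import torch_xla.core.xla_model as xm


def all_reduce_bucketized(reduce_type,
                          tensors,
                          scale=1.0,
                          groups=None,
                          pin_layout=True,
                          bucket_cap_mb=0):
  """Hypothetical bucketed variant of xm.all_reduce (names are assumptions)."""
  bucket_cap = bucket_cap_mb * 1024 * 1024
  total, tensor_bucket = 0, []
  for t in tensors:
    total += t.numel() * t.element_size()
    # Flush the current bucket once adding this tensor spills over the cap.
    if total > bucket_cap and tensor_bucket:
      xm.all_reduce(reduce_type, tensor_bucket, scale=scale,
                    groups=groups, pin_layout=pin_layout)
      total = t.numel() * t.element_size()
      tensor_bucket = []
    tensor_bucket.append(t)
  # Reduce whatever remains in the final bucket.
  if tensor_bucket:
    xm.all_reduce(reduce_type, tensor_bucket, scale=scale,
                  groups=groups, pin_layout=pin_layout)
```

With a signature like this, `reduce_gradients` could call `all_reduce_bucketized(REDUCE_SUM, gradients, scale=1.0 / count, groups=groups, pin_layout=pin_layout, bucket_cap_mb=...)` and drop the environment variable entirely, which is the direction the review discussion points toward.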