Added ALLREDUCE_BUCKET_SIZE_MB to turn on bucketing for allreduce
amithrm committed May 29, 2024
1 parent 7c9b15d commit f5ffcd4
Showing 1 changed file with 12 additions and 6 deletions.
18 changes: 12 additions & 6 deletions torch_xla/core/xla_model.py
@@ -38,8 +38,6 @@
 _ORDINAL = None

 XLA_LIB = Library("xla", "DEF")
-# Default bucket size for all-reduce
-_ALLREDUCE_BUCKET_CAP_MB = 50


 def _init_world_size_ordinal():
@@ -1171,13 +1169,21 @@ def reduce_gradients(optimizer, groups=None, pin_layout=True):
   """
   count = xrt_world_size()
   if count > 1:
-    bucket_cap = int(os.getenv('BUCKET_CAP_MB',
-                               _ALLREDUCE_BUCKET_CAP_MB)) * 1024 * 1024
+    gradients = _fetch_gradients(optimizer)
+    bucket_cap = int(os.getenv('ALLREDUCE_BUCKET_SIZE_MB', 0)) * 1024 * 1024
     # Reverse the gradients list so that we start allreduce from the last layer
     # onwards. This allows allreduce to trigger as soon as the bucket fills up and
     # overlap with backward pass.
-    gradients = reversed(_fetch_gradients(optimizer))
-    bucketed_allreduce(gradients)
+    if bucket_cap > 0:
+      gradients = reversed(gradients)
+      bucketed_allreduce(gradients)
+    else:
+      all_reduce(
+          REDUCE_SUM,
+          gradients,
+          scale=1.0 / count,
+          groups=groups,
+          pin_layout=pin_layout)


 def optimizer_step(optimizer,
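For context, a minimal usage sketch of the new behavior (the model, optimizer, and the 50 MB value below are illustrative assumptions, not part of this commit): setting ALLREDUCE_BUCKET_SIZE_MB to a positive value makes reduce_gradients() reverse the gradient list and issue bucketed all-reduces, while leaving it unset (or 0) falls back to a single all_reduce call over all gradients.

import os
# Illustrative value; any positive size (in MB) enables bucketed all-reduce.
# Unset or 0 keeps the single all_reduce path.
os.environ['ALLREDUCE_BUCKET_SIZE_MB'] = '50'

import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()
model = torch.nn.Linear(1024, 1024).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

optimizer.zero_grad()
loss = model(torch.randn(8, 1024, device=device)).sum()
loss.backward()
xm.reduce_gradients(optimizer)  # reads ALLREDUCE_BUCKET_SIZE_MB at call time
optimizer.step()
xm.mark_step()

Since the bucket size is read from the environment via os.getenv when reduce_gradients runs, the variable can be set either in the launch environment or inside the process beforehand; on a single replica (world size 1) the call is a no-op.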
