diff --git a/torch_xla/core/xla_model.py b/torch_xla/core/xla_model.py
index 720bace8698..8f28ab8a499 100644
--- a/torch_xla/core/xla_model.py
+++ b/torch_xla/core/xla_model.py
@@ -38,8 +38,6 @@ _ORDINAL = None
 
 XLA_LIB = Library("xla", "DEF")
 
-# Default bucket size for all-reduce
-_ALLREDUCE_BUCKET_CAP_MB = 50
 
 
 def _init_world_size_ordinal():
@@ -1171,13 +1169,21 @@ def reduce_gradients(optimizer, groups=None, pin_layout=True):
   """
   count = xrt_world_size()
   if count > 1:
-    bucket_cap = int(os.getenv('BUCKET_CAP_MB',
-                               _ALLREDUCE_BUCKET_CAP_MB)) * 1024 * 1024
+    gradients = _fetch_gradients(optimizer)
+    bucket_cap = int(os.getenv('ALLREDUCE_BUCKET_SIZE_MB', 0)) * 1024 * 1024
     # Reverse the gradients list so that we start allreduce from the last layer
     # onwards. This allows allreduce to trigger as soon as the bucket fills up and
     # overlap with backward pass.
-    gradients = reversed(_fetch_gradients(optimizer))
-    bucketed_allreduce(gradients)
+    if bucket_cap > 0:
+      gradients = reversed(gradients)
+      bucketed_allreduce(gradients)
+    else:
+      all_reduce(
+          REDUCE_SUM,
+          gradients,
+          scale=1.0 / count,
+          groups=groups,
+          pin_layout=pin_layout)
 
 
 def optimizer_step(optimizer,
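
The following is a minimal training-step sketch (not part of the patch) showing how the changed code path would be exercised. It assumes an XLA device is available and a multi-process launch (with world size 1, reduce_gradients is a no-op); the only names taken from the diff are reduce_gradients, optimizer_step, and the new ALLREDUCE_BUCKET_SIZE_MB environment variable.

# Sketch only: with ALLREDUCE_BUCKET_SIZE_MB unset or 0, reduce_gradients()
# issues a single all_reduce over all gradients; a positive value opts back
# into the bucketed all-reduce path.
import os
os.environ['ALLREDUCE_BUCKET_SIZE_MB'] = '50'  # opt into bucketing (assumed value)

import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()
model = torch.nn.Linear(8, 2).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

data = torch.randn(4, 8, device=device)
target = torch.randn(4, 2, device=device)

optimizer.zero_grad()
loss = torch.nn.functional.mse_loss(model(data), target)
loss.backward()
# optimizer_step() calls reduce_gradients(), which chooses between the bucketed
# all-reduce and a single all_reduce based on ALLREDUCE_BUCKET_SIZE_MB.
xm.optimizer_step(optimizer)
xm.mark_step()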