From fdb0f9e62250cf39c6a73494e6b7a8df9ec9c76d Mon Sep 17 00:00:00 2001
From: Amithrajith Mamidala
Date: Tue, 16 Nov 2021 02:29:06 +0000
Subject: [PATCH] Gradient bucketing using a pre-defined bucket size cap

---
 torch_xla/core/xla_model.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/torch_xla/core/xla_model.py b/torch_xla/core/xla_model.py
index 91c2bf22b3b..284ea919563 100755
--- a/torch_xla/core/xla_model.py
+++ b/torch_xla/core/xla_model.py
@@ -41,6 +41,8 @@ _ALLREDUCE_BUCKET_CAP_MB = 50
 XLA_LIB = Library("xla", "DEF")
 
+# Default bucket size for all-reduce
+_ALLREDUCE_BUCKET_CAP_MB = 50
 
 def _init_world_size_ordinal():
@@ -1052,7 +1054,6 @@ def reduce_gradients(optimizer, groups=None, pin_layout=True):
   gradients = reversed(_fetch_gradients(optimizer))
   bucketed_allreduce(gradients)
 
-
 def optimizer_step(optimizer, barrier=False, optimizer_args={},
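
Note: the patch calls a bucketed_allreduce() helper whose body lies outside the
hunks shown above. For orientation only, here is a minimal sketch of how a
cap-based bucketed all-reduce could be structured, assuming the public
xm.all_reduce API; the name bucketed_allreduce_sketch and its exact bucketing
policy are illustrative assumptions, not the patch's actual implementation.

# Illustrative sketch only -- not the bucketed_allreduce added by this patch.
# Assumes `gradients` is an iterable of XLA tensors (possibly reversed, as in
# reduce_gradients above) and relies on the public xm.all_reduce API, which
# reduces a list of tensors in place.
import torch_xla.core.xla_model as xm

_ALLREDUCE_BUCKET_CAP_MB = 50  # default cap, matching the value in the patch


def bucketed_allreduce_sketch(gradients, groups=None, pin_layout=True):
  bucket_cap_bytes = _ALLREDUCE_BUCKET_CAP_MB * 1024 * 1024
  bucket, bucket_bytes = [], 0
  for grad in gradients:
    if grad is None:
      continue
    grad_bytes = grad.numel() * grad.element_size()
    # Flush the current bucket before it would exceed the size cap.
    if bucket and bucket_bytes + grad_bytes > bucket_cap_bytes:
      xm.all_reduce(
          xm.REDUCE_SUM, bucket, groups=groups, pin_layout=pin_layout)
      bucket, bucket_bytes = [], 0
    bucket.append(grad)
    bucket_bytes += grad_bytes
  # Reduce whatever remains in the last, possibly partial, bucket.
  if bucket:
    xm.all_reduce(
        xm.REDUCE_SUM, bucket, groups=groups, pin_layout=pin_layout)

The idea behind a size cap is that each coalesced all-reduce stays large enough
to amortize launch overhead while the memory held by any single reduction op
remains bounded by the cap.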