Gradient bucketing using a pre-defined bucket size cap
amithrm committed Mar 1, 2024
1 parent 3bee0a7 · commit fdb0f9e
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion torch_xla/core/xla_model.py
@@ -41,6 +41,8 @@
 _ALLREDUCE_BUCKET_CAP_MB = 50
 
 XLA_LIB = Library("xla", "DEF")
+# Default bucket size for all-reduce
+_ALLREDUCE_BUCKET_CAP_MB = 50
 
 
 def _init_world_size_ordinal():
@@ -1052,7 +1054,6 @@ def reduce_gradients(optimizer, groups=None, pin_layout=True):
     gradients = reversed(_fetch_gradients(optimizer))
     bucketed_allreduce(gradients)
 
-
 def optimizer_step(optimizer,
                    barrier=False,
                    optimizer_args={},
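
The bucketing itself is not shown in these hunks: reduce_gradients now hands the fetched gradients to bucketed_allreduce, which is defined elsewhere in the file. Based on the commit subject and the _ALLREDUCE_BUCKET_CAP_MB constant added here, the idea is to group gradient tensors until their combined size reaches the cap and issue one all-reduce per bucket. Below is a minimal sketch of that idea, not the code from this commit; it assumes the module-level names from torch_xla/core/xla_model.py (all_reduce, REDUCE_SUM, _ALLREDUCE_BUCKET_CAP_MB), and the byte-size accounting is an illustrative assumption.

# Sketch only (not the code from this commit): bucket gradients under the
# _ALLREDUCE_BUCKET_CAP_MB cap, then all-reduce each bucket in one collective.
# Assumes module-level all_reduce, REDUCE_SUM and _ALLREDUCE_BUCKET_CAP_MB
# from torch_xla/core/xla_model.py.
def bucketed_allreduce(gradients):
  bucket_cap_bytes = _ALLREDUCE_BUCKET_CAP_MB * 1024 * 1024
  bucket = []
  bucket_bytes = 0
  for grad in gradients:
    grad_bytes = grad.numel() * grad.element_size()
    # Flush the current bucket once adding this tensor would exceed the cap.
    if bucket and bucket_bytes + grad_bytes > bucket_cap_bytes:
      all_reduce(REDUCE_SUM, bucket)
      bucket = []
      bucket_bytes = 0
    bucket.append(grad)
    bucket_bytes += grad_bytes
  # Reduce any remaining gradients in the final, partially filled bucket.
  if bucket:
    all_reduce(REDUCE_SUM, bucket)

Grouping many small gradients into buckets of roughly 50 MB amortizes the per-collective launch overhead compared with issuing one all-reduce per tensor, at the cost of holding gradients until a bucket fills.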
