
Commit

Fix linter issues
amithrm committed May 29, 2024
1 parent 3915b77 commit 7c9b15d
Showing 2 changed files with 26 additions and 22 deletions.
@@ -16,12 +16,15 @@ def _mp_fn(index):
 
     dist.init_process_group('xla', world_size=world_size, rank=rank)
 
-    tensor_list = [torch.empty((i, i), device=device) for i in range(1, 1000, 101)]
+    tensor_list = [
+        torch.empty((i, i), device=device) for i in range(1, 1000, 101)
+    ]
     for j, t in enumerate(tensor_list):
       t.fill_(float(j))
     dist.bucketed_allreduce(tensor_list)
     for j, t in enumerate(tensor_list):
-      assert torch.all(torch.eq(t.cpu(), float(j)*world_size)) == torch.tensor(True)
+      assert torch.all(torch.eq(t.cpu(),
+                                float(j) * world_size)) == torch.tensor(True)
   else:
     print(
         'Default device {} is not a TPU or GPU device'.format(device),
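The assertion above encodes what a sum all-reduce should produce: every rank fills tensor j with the value j, so after the reduce each element equals float(j) * world_size. A single-process sketch of that invariant with plain CPU tensors (the ranks are simulated with a Python list, and world_size is an illustrative constant rather than the value the real test reads from the runtime):

import torch

world_size = 4  # illustrative; the test gets this from the XLA runtime

# Each simulated rank holds the same tensors: tensor j is filled with j.
per_rank = [[torch.full((i, i), float(j))
             for j, i in enumerate(range(1, 1000, 101))]
            for _ in range(world_size)]

# A sum all-reduce adds the j-th tensor across ranks element-wise.
reduced = [
    sum(rank_tensors[j] for rank_tensors in per_rank)
    for j in range(len(per_rank[0]))
]

for j, t in enumerate(reduced):
  # Mirrors the test's check: every element equals j * world_size.
  assert torch.all(torch.eq(t, float(j) * world_size))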
41 changes: 21 additions & 20 deletions torch_xla/core/xla_model.py
@@ -1126,33 +1126,33 @@ def wait_device_ops(devices=[]):
(The hunk only reformats bucketed_allreduce — same statements, new indentation and line breaks:)

def bucketed_allreduce(gradients):
  total = 0
  tensor_bucket = []

  for grad in gradients:
    grad_bytes = grad.numel() * grad.element_size()

    # Bucketize till the total spills over
    total += grad_bytes
    if total > bucket_cap and len(tensor_bucket) > 0:
      all_reduce(
          REDUCE_SUM,
          tensor_bucket,
          scale=1.0 / count,
          groups=groups,
          pin_layout=pin_layout)
      total = grad_bytes
      tensor_bucket = []
    tensor_bucket.append(grad)

  # Flush the last remaining bucket
  if len(tensor_bucket):
    all_reduce(
        REDUCE_SUM,
        tensor_bucket,
        scale=1.0 / count,
        groups=groups,
        pin_layout=pin_layout)


def reduce_gradients(optimizer, groups=None, pin_layout=True):
@@ -1179,6 +1179,7 @@ def reduce_gradients(optimizer, groups=None, pin_layout=True):
     gradients = reversed(_fetch_gradients(optimizer))
     bucketed_allreduce(gradients)
 
+
 def optimizer_step(optimizer,
                    barrier=False,
                    optimizer_args={},
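The bucketing pattern above is easy to illustrate outside the XLA runtime. Below is a minimal sketch of the same size-capped bucketing, with plain CPU tensors and the buckets merely collected rather than all-reduced; BUCKET_CAP_BYTES and the helper name bucketize are illustrative stand-ins, since in xla_model.py the cap, the 1/count scale, groups and pin_layout come from the surrounding module and every flushed bucket goes through all_reduce(REDUCE_SUM, ...).

import torch

BUCKET_CAP_BYTES = 50 * 1024 * 1024  # assumed cap, for illustration only


def bucketize(tensors, bucket_cap=BUCKET_CAP_BYTES):
  """Group tensors into buckets whose byte total stays under bucket_cap.

  Mirrors the control flow of bucketed_allreduce: keep a running byte total,
  and when it would spill over, flush the current bucket and start the next
  one with the tensor that caused the spill.
  """
  buckets, bucket, total = [], [], 0
  for t in tensors:
    nbytes = t.numel() * t.element_size()
    total += nbytes
    if total > bucket_cap and bucket:
      buckets.append(bucket)  # in xla_model.py this flush is one all_reduce call
      total = nbytes
      bucket = []
    bucket.append(t)
  if bucket:  # flush the last remaining bucket
    buckets.append(bucket)
  return buckets


tensors = [torch.empty(i, i) for i in range(1, 1000, 101)]
for b in bucketize(tensors):
  print(len(b), 'tensors,', sum(t.numel() * t.element_size() for t in b), 'bytes')

Grouping many small gradients into one all_reduce per bucket amortizes per-call overhead, while the byte cap bounds how much data any single call has to hold at once.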
