
Commit

update
hgt312 committed May 28, 2024
1 parent 91a8083 commit f71df15
Showing 1 changed file with 1 addition and 6 deletions.
7 changes: 1 addition & 6 deletions torch_xla/distributed/zero_redundancy_optimizer.py
@@ -419,8 +419,7 @@ def allgather_weights_and_update_full_parameter(self):
         if param.grad is not None or (self.use_grad_acc_hook and
                                       hasattr(shard, 'main_grad')):
           shard_data = shard.data
-          if not self.higher_cc_precision:
-            shard_data = shard_data.to(dtype=param.dtype)
+          shard_data = shard_data.to(dtype=param.dtype)
           if self.coalesce_cc_all_gather:
             sharded_data.append(shard_data)
           else:
@@ -430,8 +429,6 @@ def allgather_weights_and_update_full_parameter(self):
                 pin_layout=self.pin_layout,
                 groups=self.sharding_groups,
             )
-            if padded_param.dtype != param.dtype:
-              padded_param = padded_param.to(dtype=param.dtype)
             param.data.copy_(padded_param.data[:param.size(0)])

     if self.coalesce_cc_all_gather:
@@ -450,8 +447,6 @@ def allgather_weights_and_update_full_parameter(self):
           if param.grad is not None or (self.use_grad_acc_hook and
                                         hasattr(shard, 'main_grad')):
             padded_param = padded_params[index]
-            if padded_param.dtype != param.dtype:
-              padded_param = padded_params[index].to(dtype=param.dtype)
             param.data.copy_(padded_param.data[:param.size(0)])
             index += 1

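Note on the change: the net effect of this diff is that the sharded weights are now always cast to param.dtype before the all-gather, which makes the post-gather dtype conversions redundant, so they are removed. Below is a minimal standalone sketch of that before/after flow; torch.cat stands in for xm.all_gather, and the fp32 shard / bf16 parameter dtypes are illustrative assumptions, not values taken from the repository.

# Minimal standalone sketch, not torch_xla code: torch.cat stands in for
# xm.all_gather, and the fp32/bf16 dtypes below are illustrative assumptions.
import torch


def gather_old(shard_fp32, param_dtype, higher_cc_precision):
  # Pre-commit flow: when higher_cc_precision was set, the cast was skipped,
  # the (simulated) all-gather ran on the fp32 shard, and the gathered result
  # was cast back to the parameter dtype afterwards.
  shard_data = shard_fp32
  if not higher_cc_precision:
    shard_data = shard_data.to(dtype=param_dtype)
  padded_param = torch.cat([shard_data, shard_data], dim=0)  # stand-in all_gather
  if padded_param.dtype != param_dtype:
    padded_param = padded_param.to(dtype=param_dtype)
  return padded_param


def gather_new(shard_fp32, param_dtype):
  # Post-commit flow: the shard is always cast to the parameter dtype before
  # the collective, so no post-gather conversion is needed.
  shard_data = shard_fp32.to(dtype=param_dtype)
  return torch.cat([shard_data, shard_data], dim=0)  # stand-in all_gather


if __name__ == "__main__":
  shard = torch.randn(4, dtype=torch.float32)
  old = gather_old(shard, torch.bfloat16, higher_cc_precision=True)
  new = gather_new(shard, torch.bfloat16)
  print(old.dtype, new.dtype)  # both torch.bfloat16

Concatenation involves no arithmetic, so both paths in this sketch produce the same values; what differs is the dtype in which the stand-in collective runs.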
