From 275e393d60ab278f5f38e15789318e09ab4276b2 Mon Sep 17 00:00:00 2001
From: guangtai
Date: Thu, 9 May 2024 17:03:09 -0700
Subject: [PATCH] use bf16 allgather

---
 torch_xla/distributed/zero_redundancy_optimizer.py | 7 +------
 1 file changed, 1 insertion(+), 6 deletions(-)

diff --git a/torch_xla/distributed/zero_redundancy_optimizer.py b/torch_xla/distributed/zero_redundancy_optimizer.py
index 7e83ef5ccde..63fb7ad8df2 100644
--- a/torch_xla/distributed/zero_redundancy_optimizer.py
+++ b/torch_xla/distributed/zero_redundancy_optimizer.py
@@ -411,8 +411,7 @@ def allgather_weights_and_update_full_parameter(self):
         if param.grad is not None or (self.use_grad_acc_hook and
                                       hasattr(shard, 'main_grad')):
           shard_data = shard.data
-          if not self.higher_cc_precision:
-            shard_data = shard_data.to(dtype=param.dtype)
+          shard_data = shard_data.to(dtype=param.dtype)
           if self.coalesce_cc:
             sharded_data.append(shard_data)
           else:
@@ -422,8 +421,6 @@
                 pin_layout=self.pin_layout,
                 groups=self.sharding_groups,
             )
-            if padded_param.dtype != param.dtype:
-              padded_param = padded_param.to(dtype=param.dtype)
             param.data.copy_(padded_param.data[:param.size(0)])

     if self.coalesce_cc:
@@ -441,8 +438,6 @@
           if param.grad is not None or (self.use_grad_acc_hook and
                                         hasattr(shard, 'main_grad')):
             padded_param = padded_params[index]
-            if padded_param.dtype != param.dtype:
-              padded_param = padded_params[index].to(dtype=param.dtype)
             param.data.copy_(padded_param.data[:param.size(0)])
             index += 1
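
Note on the resulting behavior: with the higher_cc_precision branch and the
post-gather casts removed, the local shard is always cast to the full
parameter's dtype before the collective, so the all-gather itself runs in
that dtype (e.g. bf16). The sketch below illustrates the non-coalesced path
after this patch; the function name demo_allgather_param and the dim=0
argument are assumptions for illustration, while the xm.all_gather call and
the attribute values come from the hunk context above.

    import torch_xla.core.xla_model as xm

    def demo_allgather_param(param, shard, sharding_groups, pin_layout):
      # Cast the local shard to the full parameter's dtype (e.g. bf16) before
      # the collective, so the all-gather itself runs in that dtype.
      shard_data = shard.data.to(dtype=param.dtype)
      padded_param = xm.all_gather(
          shard_data,
          dim=0,  # assumed gather dimension; not shown in the hunk context
          pin_layout=pin_layout,
          groups=sharding_groups,
      )
      # padded_param already has param.dtype, so the post-gather casts that
      # this patch removes are no longer needed; trim the padding and copy
      # the gathered weights into the full parameter.
      param.data.copy_(padded_param.data[:param.size(0)])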