From 00a3634fc2ff97e6521dbe5a804d5a132bc72ced Mon Sep 17 00:00:00 2001
From: Mingyan Jiang <1829166702@qq.com>
Date: Thu, 10 Aug 2023 16:35:35 +0800
Subject: [PATCH] [shardformer] gpt2 tests fix

---
 tests/test_shardformer/test_model/_utils.py | 8 ++------
 1 file changed, 2 insertions(+), 6 deletions(-)

diff --git a/tests/test_shardformer/test_model/_utils.py b/tests/test_shardformer/test_model/_utils.py
index e4755256190c..bea0587e646b 100644
--- a/tests/test_shardformer/test_model/_utils.py
+++ b/tests/test_shardformer/test_model/_utils.py
@@ -206,8 +206,7 @@ def check_weight(org_model: Module,
 
         if is_distributed_tensor(sharded_weight) or is_customized_distributed_tensor(sharded_weight):
             sharded_weight_list = [
-                torch.zeros([*sharded_weight.shape]).to(sharded_weight.dtype).to('cuda')
-                for _ in range(dist.get_world_size(tp_group))
+                torch.zeros_like(sharded_weight).to('cuda') for _ in range(dist.get_world_size(tp_group))
             ]
             dist.all_gather(sharded_weight_list, sharded_weight, tp_group)
             sharded_weight = torch.cat(sharded_weight_list, dim=dim)
@@ -234,10 +233,7 @@ def check_grad(org_model: Module,
         shard_weight = getattr_(sharded_model, suffix).weight
 
         if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight):
-            shard_grad_list = [
-                torch.zeros([*shard_grad.shape]).to(shard_grad.dtype).to('cuda')
-                for _ in range(dist.get_world_size(tp_group))
-            ]
+            shard_grad_list = [torch.zeros_like(shard_grad).to('cuda') for _ in range(dist.get_world_size(tp_group))]
             dist.all_gather(shard_grad_list, shard_grad, tp_group)
             shard_grad = torch.cat(shard_grad_list, dim=dim)
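
Note on the change: torch.zeros_like(t) allocates a buffer with t's shape and
dtype in a single call, replacing the torch.zeros([*t.shape]).to(t.dtype).to('cuda')
chain in both hunks. A minimal sketch of the gather pattern these test helpers
rely on is below; the helper name gather_tp_shard and its arguments are
illustrative stand-ins, not part of this patch, and it assumes torch.distributed
is already initialized (e.g. via torchrun) with a CUDA-capable backend.

    import torch
    import torch.distributed as dist

    def gather_tp_shard(shard: torch.Tensor, tp_group, dim: int = 0) -> torch.Tensor:
        # zeros_like carries over shape and dtype in one call; .to('cuda')
        # ensures the gather buffers live on the GPU as all_gather expects.
        buffers = [
            torch.zeros_like(shard).to('cuda') for _ in range(dist.get_world_size(tp_group))
        ]
        dist.all_gather(buffers, shard.to('cuda'), tp_group)
        # Reassemble the full tensor along the dimension the weight was sharded on,
        # so it can be compared elementwise against the unsharded original.
        return torch.cat(buffers, dim=dim)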