Skip to content

Commit

Permalink
[shardformer] gpt2 tests fix
Browse files Browse the repository at this point in the history
  • Loading branch information
flybird11111 committed Aug 10, 2023
1 parent b2766ca commit 00a3634
Showing 1 changed file with 2 additions and 6 deletions.
8 changes: 2 additions & 6 deletions tests/test_shardformer/test_model/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,8 +206,7 @@ def check_weight(org_model: Module,

if is_distributed_tensor(sharded_weight) or is_customized_distributed_tensor(sharded_weight):
sharded_weight_list = [
torch.zeros([*sharded_weight.shape]).to(sharded_weight.dtype).to('cuda')
for _ in range(dist.get_world_size(tp_group))
torch.zeros_like(sharded_weight).to('cuda') for _ in range(dist.get_world_size(tp_group))
]
dist.all_gather(sharded_weight_list, sharded_weight, tp_group)
sharded_weight = torch.cat(sharded_weight_list, dim=dim)
Expand All @@ -234,10 +233,7 @@ def check_grad(org_model: Module,
shard_weight = getattr_(sharded_model, suffix).weight

if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight):
shard_grad_list = [
torch.zeros([*shard_grad.shape]).to(shard_grad.dtype).to('cuda')
for _ in range(dist.get_world_size(tp_group))
]
shard_grad_list = [torch.zeros_like(shard_grad).to('cuda') for _ in range(dist.get_world_size(tp_group))]
dist.all_gather(shard_grad_list, shard_grad, tp_group)
shard_grad = torch.cat(shard_grad_list, dim=dim)

Expand Down

0 comments on commit 00a3634

Please sign in to comment.