Skip to content

Commit

Permalink
[shardformer] fix emerged bugs after updating transformers (#4526)
Browse files Browse the repository at this point in the history
  • Loading branch information
Fridge003 authored Aug 29, 2023
1 parent c554b7f commit 0387a47
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 2 deletions.
5 changes: 4 additions & 1 deletion colossalai/pipeline/schedule/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,10 @@ def merge_batch(data: List[Any]) -> Any:
merged_data = []
for elem_batch in zip(*flattened_data):
if isinstance(elem_batch[0], torch.Tensor):
merged_data.append(torch.cat(elem_batch, dim=0))
if len(elem_batch[0].shape) == 0: # set loss to None in pipeline outputs
merged_data.append(None)
else:
merged_data.append(torch.cat(elem_batch, dim=0))
else:
merged_data.append(list(elem_batch))
return tree_unflatten(merged_data, tree_spec)
6 changes: 5 additions & 1 deletion tests/test_shardformer/test_model/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,11 @@ def check_output_hidden_state(org_output: Tensor,
sharded_hidden_state = sharded_output.last_hidden_state

if stage_manager and stage_manager.is_last_stage():
sharded_hidden_state = torch.cat([output.last_hidden_state for output in sharded_output['outputs']], dim=dim)
pipeline_output = sharded_output['outputs']
if isinstance(pipeline_output, List):
sharded_hidden_state = torch.cat([output.last_hidden_state for output in pipeline_output], dim=dim)
else:
sharded_hidden_state = pipeline_output.last_hidden_state

assert torch.allclose(org_hidden_state.float(), sharded_hidden_state.float(), atol=atol, rtol=rtol), \
f"shard model's output hidden state is not equal to origin model's last hidden state\n{org_hidden_state}\n{sharded_hidden_state}"
Expand Down

0 comments on commit 0387a47

Please sign in to comment.