Skip to content

Commit

Permalink
Update _utils.py to fix merge batch
Browse files Browse the repository at this point in the history
  • Loading branch information
CjhHa1 authored Aug 28, 2023
1 parent 80e9e96 commit 5f4984f
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions colossalai/pipeline/schedule/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,14 +119,14 @@ def merge_batch(data: List[Any]) -> Any:
tree_spec = None
for d in data:
elems, tree_spec = tree_flatten(d)
flattened_data.append(elems)
flattened_data.append(elems[0])
merged_data = []
for elem_batch in zip(*flattened_data):
for elem_batch in zip(*(fd.items() for fd in flattened_data)):
if isinstance(elem_batch[0], torch.Tensor):
if len(elem_batch[0].shape) == 0: # set loss to None in pipeline outputs
merged_data.append(None)
else:
merged_data.append(torch.cat(elem_batch, dim=0))
else:
merged_data.append(list(elem_batch))
return tree_unflatten(merged_data, tree_spec)
return tree_unflatten([merged_data], tree_spec)

0 comments on commit 5f4984f

Please sign in to comment.