From 5404a8c7aa8e6af7cb5b33b9e7addb8402d04c9f Mon Sep 17 00:00:00 2001
From: Baizhou Zhang <eddiezhang@pku.edu.cn>
Date: Mon, 28 Aug 2023 15:33:27 +0800
Subject: [PATCH] [shardformer] fix emerged bugs after updating transformers

---
 colossalai/pipeline/schedule/_utils.py      | 5 ++++-
 tests/test_shardformer/test_model/_utils.py | 6 +++++-
 2 files changed, 9 insertions(+), 2 deletions(-)

diff --git a/colossalai/pipeline/schedule/_utils.py b/colossalai/pipeline/schedule/_utils.py
index 3ed9239272f1..5cd934b76822 100644
--- a/colossalai/pipeline/schedule/_utils.py
+++ b/colossalai/pipeline/schedule/_utils.py
@@ -123,7 +123,10 @@ def merge_batch(data: List[Any]) -> Any:
     merged_data = []
     for elem_batch in zip(*flattened_data):
         if isinstance(elem_batch[0], torch.Tensor):
-            merged_data.append(torch.cat(elem_batch, dim=0))
+            if len(elem_batch[0].shape) == 0:    # set loss to None in pipeline outputs
+                merged_data.append(None)
+            else:
+                merged_data.append(torch.cat(elem_batch, dim=0))
         else:
             merged_data.append(list(elem_batch))
     return tree_unflatten(merged_data, tree_spec)
diff --git a/tests/test_shardformer/test_model/_utils.py b/tests/test_shardformer/test_model/_utils.py
index 811471bec3c8..803afc48ac09 100644
--- a/tests/test_shardformer/test_model/_utils.py
+++ b/tests/test_shardformer/test_model/_utils.py
@@ -195,7 +195,11 @@ def check_output_hidden_state(org_output: Tensor,
         sharded_hidden_state = sharded_output.last_hidden_state
 
     if stage_manager and stage_manager.is_last_stage():
-        sharded_hidden_state = torch.cat([output.last_hidden_state for output in sharded_output['outputs']], dim=dim)
+        pipeline_output = sharded_output['outputs']
+        if isinstance(pipeline_output, List):
+            sharded_hidden_state = torch.cat([output.last_hidden_state for output in pipeline_output], dim=dim)
+        else:
+            sharded_hidden_state = pipeline_output.last_hidden_state
 
     assert torch.allclose(org_hidden_state.float(), sharded_hidden_state.float(), atol=atol, rtol=rtol), \
         f"shard model's output hidden state is not equal to origin model's last hidden state\n{org_hidden_state}\n{sharded_hidden_state}"