diff --git a/examples/language/bert/finetune.py b/examples/language/bert/finetune.py index ef1e9afda447..588d2e5e36a5 100644 --- a/examples/language/bert/finetune.py +++ b/examples/language/bert/finetune.py @@ -70,8 +70,6 @@ def evaluate_subset(dataloader: DataLoader): current_pp_group_ranks = pg_mesh.get_ranks_in_group(pp_group) current_rank = dist.get_rank() - print(current_pp_group_ranks) - batch = iter([batch]) outputs = booster.execute_pipeline(batch, @@ -84,11 +82,7 @@ def evaluate_subset(dataloader: DataLoader): if booster.plugin.stage_manager.is_last_stage(): val_loss = outputs["loss"] - #TODO get merged output - #logits = outputs["outputs"].logits - logits = outputs["outputs"][0].logits - logits = logits.repeat((2, 1)) - #### + logits = outputs["outputs"].logits accum_loss.add_(val_loss)