# Review notes for this patch (preamble text; not part of the hunk itself).
#
# Target: internlm/train/training_internlm.py, one hunk inside
# record_current_batch_training_metrics.
#
# What the added lines do: compute avg_tgs as the mean of tgs_list, then walk
# a copy of tgs_list and remove from the real list every value whose absolute
# difference from avg_tgs exceeds 1000. The pre-existing "avg_tgs:" print then
# averages whatever remains in tgs_list.
#
# NOTE(review): presumably the added lines sit under the
# `if batch_count == gpc.config.data.total_steps - 1:` check, next to
# print(tgs_list, ...). The hunk is collapsed onto a single line in this view,
# so indentation and nesting cannot be confirmed here -- verify against the file.
# NOTE(review): avg_tgs is computed once, before any removal, so the mean used
# as the reference point still includes the outliers being filtered out.
# NOTE(review): if every value is more than 1000 away from the mean (for
# example [0, 3000], mean 1500), tgs_list ends up empty and the final
# sum(tgs_list)/len(tgs_list) raises ZeroDivisionError. Consider keeping the
# unfiltered list, or skipping the print, when nothing survives the filter.
# NOTE(review): tgs_list is mutated in place by remove(); any other reader of
# the same list object sees the filtered data afterwards -- confirm intended.
# NOTE(review): the 1000 threshold is a hard-coded absolute value; its unit
# and rationale are not visible in this hunk.
diff --git a/internlm/train/training_internlm.py b/internlm/train/training_internlm.py index 24040a02..cc310a21 100644 --- a/internlm/train/training_internlm.py +++ b/internlm/train/training_internlm.py @@ -576,4 +576,8 @@ def record_current_batch_training_metrics( tgs_list.append(tgs_origin) if batch_count == gpc.config.data.total_steps - 1: print(tgs_list, flush=True) + avg_tgs = sum(tgs_list) / len(tgs_list) + for tgs in tgs_list.copy(): + if abs(tgs - avg_tgs) > 1000: + tgs_list.remove(tgs) print(f"avg_tgs: {sum(tgs_list)/len(tgs_list)}", flush=True)