From 4e99a7fdbc88e398255d63a9b22854b5ded5deb3 Mon Sep 17 00:00:00 2001 From: huangting4201 <1538303371@qq.com> Date: Tue, 17 Oct 2023 11:30:44 +0800 Subject: [PATCH] feat(train/training_internlm.py): remove abnormal tgs when calculating avg tgs --- internlm/train/training_internlm.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/internlm/train/training_internlm.py b/internlm/train/training_internlm.py index 24040a02..cc310a21 100644 --- a/internlm/train/training_internlm.py +++ b/internlm/train/training_internlm.py @@ -576,4 +576,8 @@ def record_current_batch_training_metrics( tgs_list.append(tgs_origin) if batch_count == gpc.config.data.total_steps - 1: print(tgs_list, flush=True) + avg_tgs = sum(tgs_list) / len(tgs_list) + for tgs in tgs_list.copy(): + if abs(tgs - avg_tgs) > 1000: + tgs_list.remove(tgs) print(f"avg_tgs: {sum(tgs_list)/len(tgs_list)}", flush=True)