From 8aefb74e02d6083d308a15b4d90309a24e1a093b Mon Sep 17 00:00:00 2001
From: yingtongxiong <974106207@qq.com>
Date: Thu, 26 Oct 2023 20:33:12 +0800
Subject: [PATCH] add flash tflops

---
 internlm/train/training_internlm.py | 13 +++++++++++++
 train.py                            | 14 ++++++++++++++
 2 files changed, 27 insertions(+)

diff --git a/internlm/train/training_internlm.py b/internlm/train/training_internlm.py
index df3fa88d..a4b2e598 100644
--- a/internlm/train/training_internlm.py
+++ b/internlm/train/training_internlm.py
@@ -406,11 +406,13 @@ def initialize_llm_profile(profiling: bool = False, start_time: str = None):
 
 tgs_list = []
 tflops_list = []
+tflops_list_2 = []
 
 
 @llm_timeout(func_name="record_current_batch_training_metrics")
 def record_current_batch_training_metrics(
     get_tflops_func,
+    get_tflops_func_2,
     logger,
     writer,
     success_update,
@@ -495,6 +497,7 @@ def record_current_batch_training_metrics(
         tgs_SMA = round(tgs_statistic["SMA_tg_50"] / tgs_statistic["SMA_time_50"], 2)
 
         tflops = get_tflops_func((time.time() - start_time))
+        tflops_2 = get_tflops_func_2((time.time() - start_time))
 
         tgs_origin = round(
             num_tokens_in_batch
@@ -506,6 +509,7 @@
 
         infos = {
             "tflops": tflops,
+            "tflops2": tflops_2,
             "step": batch_count,
             "loss": loss.item() - moe_loss.item() if moe_loss is not None else loss.item(),
             "tgs (tokens/gpu/second)": tgs_origin,
@@ -599,6 +603,7 @@ def record_current_batch_training_metrics(
         if batch_count >= 5:
             tgs_list.append(tgs_origin)
             tflops_list.append(tflops)
+            tflops_list_2.append(tflops_2)
         if batch_count == gpc.config.data.total_steps - 1:
             print(tgs_list, flush=True)
             avg_tgs = sum(tgs_list) / len(tgs_list)
@@ -606,9 +611,17 @@
                 if abs(tgs - avg_tgs) > 400:
                     tgs_list.remove(tgs)
             print(f"avg_tgs: {sum(tgs_list)/len(tgs_list)}", flush=True)
+            print(tflops_list, flush=True)
             avg_tflops = sum(tflops_list) / len(tflops_list)
             for tf in tflops_list.copy():
                 if abs(tf - avg_tflops) > 10:
                     tflops_list.remove(tf)
             print(f"avg_tflops: {sum(tflops_list)/len(tflops_list)}", flush=True)
+
+            print(tflops_list_2, flush=True)
+            avg_tflops_2 = sum(tflops_list_2) / len(tflops_list_2)
+            for tf in tflops_list_2.copy():
+                if abs(tf - avg_tflops_2) > 10:
+                    tflops_list_2.remove(tf)
+            print(f"avg_tflops_2: {sum(tflops_list_2)/len(tflops_list_2)}", flush=True)
 
diff --git a/train.py b/train.py
index f4195964..45117623 100644
--- a/train.py
+++ b/train.py
@@ -33,6 +33,7 @@
 from internlm.utils.common import (
     BatchSkipper,
     get_megatron_flops,
+    get_megatron_flops_2,
     launch_time,
     parse_args,
 )
@@ -111,6 +112,18 @@ def main(args):
         global_world_size=gpc.get_world_size(ParallelMode.GLOBAL),
         mlp_ratio=gpc.config.MLP_RATIO,
     )
+
+    get_tflops_func_2 = partial(
+        get_megatron_flops_2,
+        checkpoint=gpc.config.model.checkpoint,
+        seq_len=gpc.config.SEQ_LEN,
+        hidden_size=gpc.config.model.hidden_size,
+        num_layers=gpc.config.model.num_layers,
+        vocab_size=gpc.config.model.vocab_size,
+        global_batch_size=gpc.config.data.micro_bsz * gpc.config.data.micro_num * gpc.get_world_size(ParallelMode.DATA),
+        global_world_size=gpc.get_world_size(ParallelMode.GLOBAL),
+        mlp_ratio=gpc.config.MLP_RATIO,
+    )
 
     # get and broadcast current time
     current_time = launch_time()
@@ -271,6 +284,7 @@ def main(args):
         # calculate and record the training metrics, eg. loss, accuracy and so on.
         record_current_batch_training_metrics(
             get_tflops_func=get_tflops_func,
+            get_tflops_func_2=get_tflops_func_2,
             logger=logger,
             writer=writer,
             success_update=success_update,
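---

Reviewer note (commentary only, not part of the patch): get_megatron_flops_2 is imported from internlm.utils.common, but its definition is not included in this diff. Its call signature can be read off the functools.partial(...) above (keyword arguments for the model/config shape, plus one positional elapsed-time argument at the call site). The body below is only a hedged sketch of what a "flash tflops" variant might look like: a Megatron-style estimate that also counts the attention-score and attention-over-value matmuls, which is the work FlashAttention actually executes. The exact formula in the author's tree may differ.

def get_megatron_flops_2(
    elapsed_time_per_iter,
    *,
    checkpoint,
    seq_len,
    hidden_size,
    num_layers,
    vocab_size,
    global_batch_size,
    global_world_size,
    mlp_ratio,
):
    # With activation checkpointing the forward pass runs twice:
    # fwd + recomputed fwd + bwd ~= 4x a forward; otherwise 3x.
    checkpoint_activations_factor = 4 if checkpoint else 3

    # Per layer: QKV + output projections (8*b*s*h^2) plus the MLP block
    # (4*mlp_ratio*b*s*h^2), counting each matmul as 2*m*n*k FLOPs.
    flops_per_layer = (8 + mlp_ratio * 4) * global_batch_size * seq_len * hidden_size**2
    # Attention scores (Q @ K^T) and attention-weighted values: 4*b*s^2*h.
    # This quadratic term is the one a "flash" accounting keeps.
    flops_per_layer += 4 * global_batch_size * seq_len**2 * hidden_size

    flops_per_iteration = (
        checkpoint_activations_factor * flops_per_layer * num_layers
        # Output head / logits: 6*b*s*h*V (forward plus backward).
        + 6 * global_batch_size * seq_len * hidden_size * vocab_size
    )
    # Average TFLOPS per GPU for this iteration.
    return flops_per_iteration / (elapsed_time_per_iter * global_world_size * 10**12)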
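Reviewer note (commentary only, not part of the patch): the end-of-run reporting now repeats the same trim-then-average pattern three times (tgs with an absolute threshold of 400, tflops and tflops_2 with 10). A hypothetical helper, not present in the InternLM codebase, could express it once without mutating the shared module-level lists:

def trimmed_mean(values, max_abs_dev):
    """Mean of `values` after dropping entries farther than `max_abs_dev`
    from the untrimmed mean, mirroring the one-pass trimming in this patch."""
    avg = sum(values) / len(values)
    kept = [v for v in values if abs(v - avg) <= max_abs_dev]
    return sum(kept) / len(kept) if kept else avg

# Usage corresponding to the three report lines:
#   print(f"avg_tgs: {trimmed_mean(tgs_list, 400)}", flush=True)
#   print(f"avg_tflops: {trimmed_mean(tflops_list, 10)}", flush=True)
#   print(f"avg_tflops_2: {trimmed_mean(tflops_list_2, 10)}", flush=True)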