
Commit: add flash tflops
yingtongxiong committed Oct 26, 2023
1 parent 4d83e10 commit 8aefb74
Showing 2 changed files with 27 additions and 0 deletions.
13 changes: 13 additions & 0 deletions internlm/train/training_internlm.py
@@ -406,11 +406,13 @@ def initialize_llm_profile(profiling: bool = False, start_time: str = None):

tgs_list = []
tflops_list = []
tflops_list_2 = []


@llm_timeout(func_name="record_current_batch_training_metrics")
def record_current_batch_training_metrics(
get_tflops_func,
get_tflops_func_2,
logger,
writer,
success_update,
@@ -495,6 +497,7 @@ def record_current_batch_training_metrics(
tgs_SMA = round(tgs_statistic["SMA_tg_50"] / tgs_statistic["SMA_time_50"], 2)

tflops = get_tflops_func((time.time() - start_time))
tflops_2 = get_tflops_func_2((time.time() - start_time))

tgs_origin = round(
num_tokens_in_batch
@@ -506,6 +509,7 @@

infos = {
"tflops": tflops,
"tflops2": tflops_2,
"step": batch_count,
"loss": loss.item() - moe_loss.item() if moe_loss is not None else loss.item(),
"tgs (tokens/gpu/second)": tgs_origin,
@@ -599,16 +603,25 @@ def record_current_batch_training_metrics(
if batch_count >= 5:
tgs_list.append(tgs_origin)
tflops_list.append(tflops)
tflops_list_2.append(tflops_2)
if batch_count == gpc.config.data.total_steps - 1:
print(tgs_list, flush=True)
avg_tgs = sum(tgs_list) / len(tgs_list)
for tgs in tgs_list.copy():
if abs(tgs - avg_tgs) > 400:
tgs_list.remove(tgs)
print(f"avg_tgs: {sum(tgs_list)/len(tgs_list)}", flush=True)

print(tflops_list, flush=True)
avg_tflops = sum(tflops_list) / len(tflops_list)
for tf in tflops_list.copy():
if abs(tf - avg_tflops) > 10:
tflops_list.remove(tf)
print(f"avg_tflops: {sum(tflops_list)/len(tflops_list)}", flush=True)

print(tflops_list_2, flush=True)
avg_tflops_2 = sum(tflops_list_2) / len(tflops_list_2)
for tf in tflops_list_2.copy():
if abs(tf - avg_tflops_2) > 10:
tflops_list_2.remove(tf)
print(f"avg_tflops: {sum(tflops_list_2)/len(tflops_list_2)}", flush=True)
14 changes: 14 additions & 0 deletions train.py
@@ -33,6 +33,7 @@
from internlm.utils.common import (
BatchSkipper,
get_megatron_flops,
get_megatron_flops_2,
launch_time,
parse_args,
)
@@ -111,6 +112,18 @@ def main(args):
global_world_size=gpc.get_world_size(ParallelMode.GLOBAL),
mlp_ratio=gpc.config.MLP_RATIO,
)

get_tflops_func_2 = partial(
get_megatron_flops_2,
checkpoint=gpc.config.model.checkpoint,
seq_len=gpc.config.SEQ_LEN,
hidden_size=gpc.config.model.hidden_size,
num_layers=gpc.config.model.num_layers,
vocab_size=gpc.config.model.vocab_size,
global_batch_size=gpc.config.data.micro_bsz * gpc.config.data.micro_num * gpc.get_world_size(ParallelMode.DATA),
global_world_size=gpc.get_world_size(ParallelMode.GLOBAL),
mlp_ratio=gpc.config.MLP_RATIO,
)

# get and broadcast current time
current_time = launch_time()
@@ -271,6 +284,7 @@ def main(args):
# calculate and record the training metrics, e.g. loss, accuracy and so on.
record_current_batch_training_metrics(
get_tflops_func=get_tflops_func,
get_tflops_func_2=get_tflops_func_2,
logger=logger,
writer=writer,
success_update=success_update,
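
get_megatron_flops_2 is imported from internlm.utils.common, but its definition is not part of this diff. Given the commit message ("add flash tflops"), the second estimator presumably also counts the attention score/context matmuls that flash attention performs, on top of the dense GEMMs. The sketch below is a minimal Megatron-style estimator under that assumption, written to accept the same keyword arguments that the partial call above binds; the function name and formula details are illustrative, and the real helper may differ.

def estimate_flash_tflops(
    elapsed_time_per_iter,
    *,
    checkpoint=False,
    seq_len,
    hidden_size,
    num_layers,
    vocab_size,
    global_batch_size,
    global_world_size,
    mlp_ratio,
):
    """Megatron-style model TFLOPS per GPU, including the attention matmuls."""
    # With activation checkpointing the forward pass is recomputed in the backward,
    # so each transformer layer costs ~4x its forward FLOPs instead of ~3x.
    layer_factor = 4 if checkpoint else 3
    B, s, h, l, V = global_batch_size, seq_len, hidden_size, num_layers, vocab_size

    attn_proj = 8 * B * s * h * h             # QKV projections (6Bsh^2) + output projection (2Bsh^2)
    attn_matmul = 4 * B * s * s * h           # QK^T scores and attention-weighted V
    mlp = 6 * B * s * h * int(h * mlp_ratio)  # three GEMMs, assuming a gated (SwiGLU-style) MLP
    per_layer_forward = attn_proj + attn_matmul + mlp

    logits_forward = 2 * B * s * h * V        # vocabulary projection; typically not recomputed

    total_flops = layer_factor * l * per_layer_forward + 3 * logits_forward
    return total_flops / elapsed_time_per_iter / global_world_size / 1e12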
