feat(configs): resolve conflicts from merging feat/fstp_refactor
huangting4201 committed Oct 26, 2023
2 parents 6389984 + aa3840f commit e04a61a
Showing 8 changed files with 77 additions and 9 deletions.
configs/13B_template.py (2 additions, 2 deletions)
@@ -56,15 +56,15 @@
valid_micro_num=4,
# defaults to 0, means disable evaluate
valid_every=50,
- pack_sample_into_one=False,
+ pack_sample_into_one=True,
total_steps=20,
skip_batches="",
rampup_batch_size="",
# Datasets with less than 50 rows will be discarded
min_length=50,
# train_folder=TRAIN_FOLDER,
# valid_folder=VALID_FOLDER,
- empty_cache_and_diag_interval=10,
+ empty_cache_and_diag_interval=100,
diag_outlier_ratio=1.1,
)

configs/30B_template.py (3 additions, 3 deletions)
@@ -1,7 +1,7 @@
DO_ALERT = False

SEQ_LEN = 4096
- JOB_NAME = "7b_train_" + str({micro_bsz}) + "_" + str({sp}) + "_" + str({intern_overlap}) + "_" + str({checkpoint})
+ JOB_NAME = "30b_train_" + str({micro_bsz}) + "_" + str({sp}) + "_" + str({intern_overlap}) + "_" + str({checkpoint})
HIDDEN_SIZE = 6144
NUM_ATTENTION_HEAD = 48
MLP_RATIO = 8 / 3
@@ -56,15 +56,15 @@
valid_micro_num=4,
# defaults to 0, means disable evaluate
valid_every=50,
- pack_sample_into_one=False,
+ pack_sample_into_one=True,
total_steps=20,
skip_batches="",
rampup_batch_size="",
# Datasets with less than 50 rows will be discarded
min_length=50,
# train_folder=TRAIN_FOLDER,
# valid_folder=VALID_FOLDER,
- empty_cache_and_diag_interval=10,
+ empty_cache_and_diag_interval=100,
diag_outlier_ratio=1.1,
)

configs/7B_sft.py (2 additions, 2 deletions)
@@ -57,7 +57,7 @@
# defaults to 0, means disable evaluate
valid_every=50,
pack_sample_into_one=True,
- total_steps=20,
+ total_steps=50,
skip_batches="",
rampup_batch_size="",
# Datasets with less than 50 rows will be discarded
@@ -163,7 +163,7 @@
"""
parallel = dict(
zero1=dict(size=-1, fsdp=False),
- tensor=dict(size=8, sp="none", intern_overlap=False),
+ tensor=dict(size=8, sp="intern", intern_overlap=True),
pipeline=dict(size=1, interleaved_overlap=True),
)

configs/generate.py (1 addition, 0 deletions)
@@ -47,6 +47,7 @@

log_name = root_name + "_" + output_file_name[:-3]

+ print(log_name)
command = f"srun -p llm_s -N 8 -n 64 --ntasks-per-node=8 --gpus-per-task=1 --time=30 python train.py --config {write_file} --profiling 2>&1 | tee ./fstp_logs/{log_name}.log"
process = subprocess.Popen(command, shell=True, executable="/bin/bash")
process.wait()
internlm/solver/optimizer/hybrid_zero_optim.py (2 additions, 0 deletions)
@@ -856,6 +856,8 @@ def broadcast_params(self):
for handle in handles:
handle.wait()

+ torch.cuda.synchronize()

##################
# FP16 Utilities #
##################
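A note on the hunk above: it adds a torch.cuda.synchronize() after every broadcast handle has been awaited. A minimal sketch of that pattern follows; it is illustrative only, with hypothetical names (broadcast_params_sketch, params, src_rank) standing in for the real hybrid_zero_optim.py logic, and it assumes the default process group is already initialized.

import torch
import torch.distributed as dist

def broadcast_params_sketch(params, src_rank=0):
    # Launch every broadcast as a non-blocking op so they can overlap.
    handles = [dist.broadcast(p.data, src=src_rank, async_op=True) for p in params]
    # Wait on each communication work handle.
    for handle in handles:
        handle.wait()
    # The line this commit adds: also block until the broadcast kernels
    # have finished on the GPU before the optimizer continues.
    torch.cuda.synchronize()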
internlm/train/training_internlm.py (13 additions, 0 deletions)
@@ -406,11 +406,13 @@ def initialize_llm_profile(profiling: bool = False, start_time: str = None):

tgs_list = []
tflops_list = []
+ tflops_list_2 = []


@llm_timeout(func_name="record_current_batch_training_metrics")
def record_current_batch_training_metrics(
get_tflops_func,
+ get_tflops_func_2,
logger,
writer,
success_update,
@@ -495,6 +497,7 @@ def record_current_batch_training_metrics(
tgs_SMA = round(tgs_statistic["SMA_tg_50"] / tgs_statistic["SMA_time_50"], 2)

tflops = get_tflops_func((time.time() - start_time))
+ tflops_2 = get_tflops_func_2((time.time() - start_time))

tgs_origin = round(
num_tokens_in_batch
@@ -506,6 +509,7 @@

infos = {
"tflops": tflops,
+ "tflops2": tflops_2,
"step": batch_count,
"loss": loss.item() - moe_loss.item() if moe_loss is not None else loss.item(),
"tgs (tokens/gpu/second)": tgs_origin,
@@ -599,16 +603,25 @@ def record_current_batch_training_metrics(
if batch_count >= 5:
tgs_list.append(tgs_origin)
tflops_list.append(tflops)
+ tflops_list_2.append(tflops_2)
if batch_count == gpc.config.data.total_steps - 1:
print(tgs_list, flush=True)
avg_tgs = sum(tgs_list) / len(tgs_list)
for tgs in tgs_list.copy():
if abs(tgs - avg_tgs) > 400:
tgs_list.remove(tgs)
print(f"avg_tgs: {sum(tgs_list)/len(tgs_list)}", flush=True)

print(tflops_list, flush=True)
avg_tflops = sum(tflops_list) / len(tflops_list)
for tf in tflops_list.copy():
if abs(tf - avg_tflops) > 10:
tflops_list.remove(tf)
print(f"avg_tflops: {sum(tflops_list)/len(tflops_list)}", flush=True)

+ print(tflops_list_2, flush=True)
+ avg_tflops_2 = sum(tflops_list_2) / len(tflops_list_2)
+ for tf in tflops_list_2.copy():
+ if abs(tf - avg_tflops_2) > 10:
+ tflops_list_2.remove(tf)
+ print(f"avg_tflops_2: {sum(tflops_list_2)/len(tflops_list_2)}", flush=True)
internlm/utils/common.py (37 additions, 0 deletions)
@@ -220,6 +220,43 @@ def get_megatron_flops(
return tflops


+ def get_megatron_flops_2(
+ elapsed_time_per_iter,
+ checkpoint=False,
+ seq_len=2048,
+ hidden_size=12,
+ num_layers=32,
+ vocab_size=12,
+ global_batch_size=4,
+ global_world_size=1,
+ mlp_ratio=4,
+ use_swiglu=True,
+ ):
+ """
+ Calc flops based on the paper of Megatron https://deepakn94.github.io/assets/papers/megatron-sc21.pdf
+ """
+
+ checkpoint_activations_factor = 4 if checkpoint else 3
+ flashattn_activations_factor = 4.5 if checkpoint else 3.5
+
+ if use_swiglu:
+ mlp_ratio = mlp_ratio * 3 / 2
+
+ flops_per_iteration = (
+ checkpoint_activations_factor
+ * (8 + mlp_ratio * 4)
+ * global_batch_size
+ * seq_len
+ * hidden_size**2
+ * num_layers
+ + 4 * global_batch_size * seq_len**2 * hidden_size * num_layers * flashattn_activations_factor
+ + 6 * global_batch_size * seq_len * hidden_size * vocab_size
+ )
+
+ tflops = flops_per_iteration / (elapsed_time_per_iter * global_world_size * (10**12))
+ return tflops


class DummyProfile:
"""
Dummy Profile.
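A note on the function above: get_megatron_flops_2 estimates achieved TFLOPS per GPU from the iteration time and the model/run dimensions. A minimal usage sketch follows; every argument value is a placeholder chosen for illustration, not a value taken from any config in this commit.

from internlm.utils.common import get_megatron_flops_2

# Hypothetical model and run dimensions, for illustration only.
tflops = get_megatron_flops_2(
    12.0,                   # elapsed_time_per_iter: seconds for one iteration
    checkpoint=False,       # no activation checkpointing
    seq_len=4096,
    hidden_size=6144,
    num_layers=40,
    vocab_size=103168,
    global_batch_size=512,  # sequences per global step
    global_world_size=64,   # total number of GPUs
    mlp_ratio=8 / 3,
    use_swiglu=True,
)
print(f"estimated TFLOPS per GPU: {tflops:.2f}")

In train.py this helper is bound with functools.partial to values from gpc.config, and its result is logged next to the existing metric under the key tflops2.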
train.py (17 additions, 2 deletions)
@@ -33,6 +33,7 @@
from internlm.utils.common import (
BatchSkipper,
get_megatron_flops,
+ get_megatron_flops_2,
launch_time,
parse_args,
)
@@ -111,6 +112,18 @@ def main(args):
global_world_size=gpc.get_world_size(ParallelMode.GLOBAL),
mlp_ratio=gpc.config.MLP_RATIO,
)

+ get_tflops_func_2 = partial(
+ get_megatron_flops_2,
+ checkpoint=gpc.config.model.checkpoint,
+ seq_len=gpc.config.SEQ_LEN,
+ hidden_size=gpc.config.model.hidden_size,
+ num_layers=gpc.config.model.num_layers,
+ vocab_size=gpc.config.model.vocab_size,
+ global_batch_size=gpc.config.data.micro_bsz * gpc.config.data.micro_num * gpc.get_world_size(ParallelMode.DATA),
+ global_world_size=gpc.get_world_size(ParallelMode.GLOBAL),
+ mlp_ratio=gpc.config.MLP_RATIO,
+ )

# get and broadcast current time
current_time = launch_time()
@@ -271,6 +284,7 @@ def main(args):
# calculate and record the training metrics, eg. loss, accuracy and so on.
record_current_batch_training_metrics(
get_tflops_func=get_tflops_func,
+ get_tflops_func_2=get_tflops_func_2,
logger=logger,
writer=writer,
success_update=success_update,
@@ -309,8 +323,9 @@ def main(args):

if memory_profiler is not None:
memory_profiler.step()

- prof.step()

+ if batch_count % 2 == 0:
+ prof.step()

if gpc.fstp_handler is not None:
gpc.fstp_handler.clear_memory_pool()
