-
Notifications
You must be signed in to change notification settings - Fork 132
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
1 changed file
with
90 additions
and
46 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -101,6 +101,16 @@ def launch_slurm_job(launch_file_contents, *args): | |
logs_path=f"/fsx/elie_bakouch/nanotron/debug", | ||
conda_path="/fsx/elie_bakouch/miniconda3/etc/profile.d/conda.sh", | ||
conda_env_path="/fsx/elie_bakouch/miniconda3/envs/smollm", | ||
exclude_nodes=["ip-26-0-161-138", "ip-26-0-161-178"], | ||
torchrun_args={ | ||
"rdzv_backend": "etcd-v2", | ||
"rdzv_endpoint": "etcd.hpc-cluster-hopper.hpc.internal.huggingface.tech:2379", | ||
"rdzv_id": "$SLURM_JOB_ID" | ||
}, | ||
qos="normal", | ||
mail_type="FAIL", | ||
mail_user="[email protected]", | ||
begin="now+0minutes" | ||
) | ||
|
||
model_config = LlamaConfig( | ||
|
@@ -135,8 +145,6 @@ def launch_slurm_job(launch_file_contents, *args): | |
) | ||
).replace(".", "p") | ||
|
||
print(f"🏋️ Model has {num_params} parameters") | ||
|
||
# Do we have a SLURM task ID? | ||
# You can SLURM_ARRAY_TASK_ID to run multiple runs with predefined HP | ||
task_id = int(os.environ.get("SLURM_ARRAY_TASK_ID", -1)) | ||
|
@@ -197,7 +205,6 @@ def launch_slurm_job(launch_file_contents, *args): | |
tp_linear_async_communication=True, | ||
) | ||
#Add sanity check for the number of GPUs and the number of nodes ? | ||
print(f"🤖 {slurm.nodes} Nodes | {parallelism.dp*parallelism.pp*parallelism.tp} GPUs | 3D Config : DP {parallelism.dp} / PP {parallelism.pp} / TP {parallelism.tp}") | ||
|
||
tokens = TokensArgs( | ||
batch_accumulation_per_replica=8, | ||
|
@@ -211,9 +218,6 @@ def launch_slurm_job(launch_file_contents, *args): | |
|
||
total_tokens = tokens.train_steps * GBS | ||
total_tokens_billions = total_tokens / 1e9 | ||
print(f"📙 Number of tokens: {total_tokens_billions:.2f} billion") | ||
|
||
|
||
|
||
model = ModelArgs( | ||
model_config=model_config, | ||
|
@@ -248,15 +252,6 @@ def launch_slurm_job(launch_file_contents, *args): | |
lr_decay_start = learning_rate_scheduler.lr_decay_starting_step | ||
lr_decay_style = learning_rate_scheduler.lr_decay_style | ||
|
||
print(f"📊 Learning Rate Schedule:") | ||
print(f" Initial LR: {lr_initial:.2e}") | ||
print(f" Warmup: {learning_rate_scheduler.lr_warmup_style} increase over {lr_warmup_steps} steps") | ||
if lr_decay_start != lr_warmup_steps: | ||
print(f" Constant LR until step {lr_decay_start}") | ||
print(f" {lr_decay_style.capitalize()} decay from step {lr_decay_start} to {lr_decay_start + lr_decay_steps}") | ||
print(f" Final LR: {lr_min:.2e}") | ||
|
||
print(f"🚚 Global Batch Size: {GBS:,} tokens") | ||
optimizer = OptimizerArgs( | ||
zero_stage=0, | ||
weight_decay=0.01, | ||
|
@@ -311,6 +306,54 @@ def launch_slurm_job(launch_file_contents, *args): | |
lighteval=lighteval, | ||
slurm=slurm, | ||
) | ||
|
||
print(f""" | ||
🏋️ Model Parameters: | ||
┌───────────────────────┬───────────────────────────┐ | ||
│ Total Parameters │ {num_params:>25} │ | ||
│ Layers │ {model_config.num_hidden_layers:>25d} │ | ||
│ Attention Heads │ {model_config.num_attention_heads:>25d} │ | ||
│ Hidden Size │ {model_config.hidden_size:>25d} │ | ||
│ Intermediate Size │ {model_config.intermediate_size:>25d} │ | ||
│ Context Length │ {model_config.max_position_embeddings:>25d} │ | ||
│ Tokenizer │ {tokenizer.tokenizer_name_or_path[:25]:>25} │ | ||
│ Vocab Size │ {model_config.vocab_size:>25d} │ | ||
└───────────────────────┴───────────────────────────┘ | ||
""") | ||
|
||
print(f""" | ||
🤖 Parallelism Configuration: | ||
┌───────────────────────┬───────────────────┐ | ||
│ Nodes │ {slurm.nodes:>17d} │ | ||
│ Total GPUs │ {parallelism.dp*parallelism.pp*parallelism.tp:>17d} │ | ||
│ Data Parallel (DP) │ {parallelism.dp:>17d} │ | ||
│ Pipeline Parallel (PP)│ {parallelism.pp:>17d} │ | ||
│ Tensor Parallel (TP) │ {parallelism.tp:>17d} │ | ||
└───────────────────────┴───────────────────┘ | ||
""") | ||
|
||
print(f""" | ||
📙 Training Configuration: | ||
┌───────────────────────┬───────────────────┐ | ||
│ Total Tokens │ {total_tokens_billions:>16.2f}B │ | ||
│ Global Batch Size │ {GBS:>17,d} │ | ||
│ Batch Size (per GPU) │ {BS:>17,d} │ | ||
└───────────────────────┴───────────────────┘ | ||
""") | ||
|
||
print(f""" | ||
📊 Learning Rate Schedule: | ||
┌───────────────────────┬───────────────────┐ | ||
│ Initial LR │ {lr_initial:>17.2e} │ | ||
│ Warmup Style │ {learning_rate_scheduler.lr_warmup_style[:17]:>17} │ | ||
│ Warmup Steps │ {lr_warmup_steps:>17d} │ | ||
│ Decay Style │ {lr_decay_style[:17]:>17} │ | ||
│ Decay Start Step │ {lr_decay_start:>17d} │ | ||
│ Decay Steps │ {lr_decay_steps:>17d} │ | ||
│ Final LR │ {lr_min:>17.2e} │ | ||
└───────────────────────┴───────────────────┘ | ||
""") | ||
|
||
if slurm is not None: | ||
dir = os.path.dirname(__file__) | ||
|
||
|
@@ -321,37 +364,40 @@ def launch_slurm_job(launch_file_contents, *args): | |
|
||
os.makedirs(f"{config.slurm.slurm_logs_path}/", exist_ok=True) | ||
|
||
#SBATCH --job-name={slurm.job_name} | ||
#SBATCH --nodes={slurm.nodes} | ||
#SBATCH --ntasks-per-node={slurm.n_tasks_per_node} # crucial - only 1 task per dist per node! | ||
#SBATCH --cpus-per-task={slurm.cpus_per_task} | ||
#SBATCH --gres=gpu:{slurm.gpu_per_node} | ||
#SBATCH --partition={slurm.gpu_partition} | ||
#SBATCH --output={slurm.slurm_logs_path}/train-{timestamp}-%x-%j.out | ||
#SBATCH --array={slurm.array} | ||
#SBATCH --qos={slurm.qos} | ||
#SBATCH --begin=now+0minutes | ||
#SBATCH --mail-type=ALL | ||
#SBATCH --mail-user={slurm.mail} | ||
#SBATCH --requeue | ||
def format_sbatch_option(option, value): | ||
return f"#SBATCH --{option}={value}" if value is not None else "" | ||
|
||
torchrun_args = "" | ||
if hasattr(slurm, 'torchrun_args') and slurm.torchrun_args: | ||
torchrun_args = " ".join([f"--{k} {v}" for k, v in slurm.torchrun_args.items()]) | ||
|
||
sbatch_script = f"""#!/bin/bash | ||
#SBATCH --job-name=test | ||
#SBATCH --nodes=2 | ||
#SBATCH --ntasks-per-node=1 # crucial - only 1 task per dist per node! | ||
#SBATCH --cpus-per-task=32 | ||
#SBATCH --gres=gpu:8 | ||
#SBATCH --partition=hopper-prod | ||
#SBATCH --output=/fsx/elie_bakouch/nanotron/debug/main/train-{timestamp}-%x-%j.out | ||
#SBATCH --qos=high | ||
#SBATCH --begin=now+0minutes | ||
#SBATCH --mail-type=ALL | ||
{format_sbatch_option("job-name", slurm.job_name)} | ||
{format_sbatch_option("nodes", slurm.nodes)} | ||
{format_sbatch_option("ntasks-per-node", slurm.n_tasks_per_node)} | ||
{format_sbatch_option("cpus-per-task", slurm.cpus_per_task)} | ||
{format_sbatch_option("gres", f"gpu:{slurm.gpu_per_node}")} | ||
{format_sbatch_option("partition", slurm.gpu_partition)} | ||
{format_sbatch_option("output", f"{slurm.slurm_logs_path}/train-{timestamp}-%x-%j.out")} | ||
{format_sbatch_option("array", slurm.array)} | ||
{format_sbatch_option("qos", slurm.qos)} | ||
{format_sbatch_option("mail-type", slurm.mail_type)} | ||
{format_sbatch_option("mail-user", slurm.mail_user)} | ||
{format_sbatch_option("exclude", ",".join(slurm.exclude_nodes) if slurm.exclude_nodes else None)} | ||
{format_sbatch_option("time", slurm.time)} | ||
{format_sbatch_option("mem", slurm.mem)} | ||
{format_sbatch_option("constraint", slurm.constraint)} | ||
{format_sbatch_option("account", slurm.account)} | ||
{format_sbatch_option("reservation", slurm.reservation)} | ||
{format_sbatch_option("begin", slurm.begin)} | ||
set -x -e | ||
TRAINER_PYTHON_FILE=/fsx/elie_bakouch/nanotron/run_train.py | ||
nvidia-smi | ||
source ~/.bashrc | ||
source /fsx/elie_bakouch/miniconda3/etc/profile.d/conda.sh | ||
conda activate /fsx/elie_bakouch/miniconda3/envs/smollm #Modify this line if you use something different than conda | ||
conda activate {slurm.conda_env_path} #Modify this line if you use something different than conda | ||
#Show some environment variables | ||
|
@@ -387,13 +433,10 @@ def launch_slurm_job(launch_file_contents, *args): | |
$TRAINER_PYTHON_FILE \ | ||
--config-file {config_path_yaml} | ||
" | ||
export LAUNCHER="python -u -m torch.distributed.run \ | ||
--nproc_per_node 8 \ | ||
export LAUNCHER="torchrun \ | ||
--nproc_per_node {slurm.gpu_per_node} \ | ||
--nnodes $COUNT_NODE \ | ||
--rdzv-backend etcd-v2 \ | ||
--rdzv-endpoint etcd.hpc-cluster-hopper.hpc.internal.huggingface.tech:2379 \ | ||
--rdzv-id $SLURM_JOB_ID \ | ||
{torchrun_args} \ | ||
--node_rank $SLURM_PROCID \ | ||
--role $SLURMD_NODENAME: \ | ||
--max_restarts 0 \ | ||
|
@@ -412,5 +455,6 @@ def launch_slurm_job(launch_file_contents, *args): | |
echo "END TIME: $(date)" | ||
""" | ||
""" | ||
|
||
print(f"Slurm job launched with id={launch_slurm_job(sbatch_script)}") |