update launcher.py
eliebak committed Aug 25, 2024
1 parent 6b58c25 commit 34b50a6
Showing 1 changed file with 56 additions and 17 deletions.
launcher.py
@@ -135,7 +135,7 @@ def launch_slurm_job(launch_file_contents, *args):
)
).replace(".", "p")

print(f"🏋️ Model has {num_params} parameters")

# Do we have a SLURM task ID?
# You can use SLURM_ARRAY_TASK_ID to run multiple runs with predefined HPs
@@ -206,6 +206,14 @@ def launch_slurm_job(launch_file_contents, *args):
train_steps=100,
val_check_interval=-1,
)
BS = tokens.micro_batch_size * tokens.batch_accumulation_per_replica * tokens.sequence_length
GBS = BS * parallelism.dp

total_tokens = tokens.train_steps * GBS
total_tokens_billions = total_tokens / 1e9
print(f"📙 Number of tokens: {total_tokens_billions:.2f} billion")
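# Worked example (hypothetical numbers, not taken from this config): with
# micro_batch_size=4, batch_accumulation_per_replica=8, sequence_length=2048
# and dp=64, BS = 4 * 8 * 2048 = 65,536 tokens per replica per step,
# GBS = 65,536 * 64 = 4,194,304 tokens per optimizer step, and 100 train
# steps consume 4,194,304 * 100 ≈ 0.42 billion tokens.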
model = ModelArgs(
model_config=model_config,
@@ -232,7 +240,23 @@ def launch_slurm_job(launch_file_contents, *args):
lr_decay_starting_step= 80,
min_decay_lr=0,
)

# Calculate and print learning rate and global batch size information
lr_initial = learning_rate_scheduler.learning_rate
lr_min = learning_rate_scheduler.min_decay_lr
lr_warmup_steps = learning_rate_scheduler.lr_warmup_steps
lr_decay_steps = learning_rate_scheduler.lr_decay_steps
lr_decay_start = learning_rate_scheduler.lr_decay_starting_step
lr_decay_style = learning_rate_scheduler.lr_decay_style

print("📊 Learning Rate Schedule:")
print(f" Initial LR: {lr_initial:.2e}")
print(f" Warmup: {learning_rate_scheduler.lr_warmup_style} increase over {lr_warmup_steps} steps")
if lr_decay_start != lr_warmup_steps:
print(f" Constant LR until step {lr_decay_start}")
print(f" {lr_decay_style.capitalize()} decay from step {lr_decay_start} to {lr_decay_start + lr_decay_steps}")
print(f" Final LR: {lr_min:.2e}")

print(f"🚚 Global Batch Size: {GBS:,} tokens")
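# Sketch of the schedule printed above, assuming linear warmup (the decay
# style comes from the config; cosine and linear are the common choices):
#   lr(t) = lr_initial * t / lr_warmup_steps     for t <= lr_warmup_steps
#   lr(t) = lr_initial                           until lr_decay_starting_step
#   lr(t) decays to min_decay_lr                 over the next lr_decay_steps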
optimizer = OptimizerArgs(
zero_stage=0,
weight_decay=0.01,
@@ -262,13 +286,11 @@ def launch_slurm_job(launch_file_contents, *args):
data_stages=[
DatasetStageArgs(
data=DataArgs(
dataset=NanosetDatasetsArgs(
dataset_folder={
"/fsx/elie_bakouch/nanotron/datasets/cosmopedia-v2":0.7,
"/fsx/elie_bakouch/nanotron/datasets/fineweb-edu-dedup":0.3,
},
),
seed=general.seed,
dataset=NanosetDatasetsArgs(
dataset_folder="/fsx/elie_bakouch/nanotron/datasets/cosmopedia-v2",
),
num_loading_workers=0,
seed=general.seed,
),
name="training stage",
start_training_step=1,
@@ -299,7 +321,6 @@ def launch_slurm_job(launch_file_contents, *args):

os.makedirs(f"{config.slurm.slurm_logs_path}/", exist_ok=True)

sbatch_script = f"""#!/bin/bash
#SBATCH --job-name={slurm.job_name}
#SBATCH --nodes={slurm.nodes}
#SBATCH --ntasks-per-node={slurm.n_tasks_per_node} # crucial - only 1 task per node for torch.distributed!
@@ -313,26 +334,39 @@ def launch_slurm_job(launch_file_contents, *args):
#SBATCH --mail-type=ALL
#SBATCH --mail-user={slurm.mail}
#SBATCH --requeue
sbatch_script = f"""#!/bin/bash
#SBATCH --job-name=test
#SBATCH --nodes=2
#SBATCH --ntasks-per-node=1 # crucial - only 1 task per node for torch.distributed!
#SBATCH --cpus-per-task=32
#SBATCH --gres=gpu:8
#SBATCH --partition=hopper-prod
#SBATCH --output=/fsx/elie_bakouch/nanotron/debug/main/train-{timestamp}-%x-%j.out
#SBATCH --qos=high
#SBATCH --begin=now+0minutes
#SBATCH --mail-type=ALL
set -x -e
TRAINER_PYTHON_FILE=/fsx/elie_bakouch/nanotron/run_train.py
nvidia-smi
set -x -e
source ~/.bashrc
source /fsx/elie_bakouch/miniconda3/etc/profile.d/conda.sh
conda activate {config.slurm.conda_env_path} # Modify this line if you use something other than conda
conda activate /fsx/elie_bakouch/miniconda3/envs/smollm # Modify this line if you use something other than conda
module load cuda/12.1
echo "START TIME: $(date)"
# Show some environment variables
echo python3 version = `python3 --version`
echo "NCCL version: $(python -c "import torch;print(torch.cuda.nccl.version())")"
echo "CUDA version: $(python -c "import torch;print(torch.version.cuda)")"
echo "START TIME: $(date)"
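# Helper: convert a duration in seconds to H:M:S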
secs_to_human(){{
echo "$(( ${{1}} / 3600 )):$(( (${{1}} / 60) % 60 )):$(( ${{1}} % 60 ))"
}}
start=$(date +%s)
echo "$(date -d @${{start}} "+%Y-%m-%d %H:%M:%S"): ${{SLURM_JOB_NAME}} start id=${{SLURM_JOB_ID}}\n"
# SLURM stuff
export HOSTNAMES=`scontrol show hostnames "$SLURM_JOB_NODELIST"`
export MASTER_ADDR=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1)
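# The first hostname in the allocation serves as the rendezvous MASTER_ADDR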
@@ -342,6 +376,8 @@ def launch_slurm_job(launch_file_contents, *args):
export TMPDIR=/scratch
export CUDA_DEVICE_MAX_CONNECTIONS="1"
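# Setting CUDA_DEVICE_MAX_CONNECTIONS=1 is the usual Megatron-style requirement
# so communication kernels are scheduled ahead of compute and can overlap with
# it; assumed to be the rationale for nanotron here as well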
module load cuda/12.1
echo go $COUNT_NODE
echo $HOSTNAMES
@@ -353,8 +389,11 @@ def launch_slurm_job(launch_file_contents, *args):
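# torchrun starts one agent per node; the etcd-v2 rendezvous flags on the
# command below let all $COUNT_NODE agents find each other through the shared
# etcd endpoint, keyed by $SLURM_JOB_ID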
"
export LAUNCHER="python -u -m torch.distributed.run \
--nproc_per_node {config.slurm.gpu_per_node} \
--nproc_per_node 8 \
--nnodes $COUNT_NODE \
--rdzv-backend etcd-v2 \
--rdzv-endpoint etcd.hpc-cluster-hopper.hpc.internal.huggingface.tech:2379 \
--rdzv-id $SLURM_JOB_ID \
--node_rank $SLURM_PROCID \
--role $SLURMD_NODENAME: \
--max_restarts 0 \
