diff --git a/3.test_cases/10.FSDP/1.distributed-training.sbatch b/3.test_cases/10.FSDP/1.distributed-training.sbatch index 0e20b147..e76d5129 100755 --- a/3.test_cases/10.FSDP/1.distributed-training.sbatch +++ b/3.test_cases/10.FSDP/1.distributed-training.sbatch @@ -39,11 +39,11 @@ export FI_EFA_SET_CUDA_SYNC_MEMOPS=0 ########################### declare -a TORCHRUN_ARGS=( - --nproc_per_node=$GPUS_PER_NODE \ - --nnodes=$SLURM_JOB_NUM_NODES \ - --rdzv_id=$SLURM_JOB_ID \ - --rdzv_backend=c10d \ - --rdzv_endpoint=$(hostname) \ + --nproc_per_node=$GPUS_PER_NODE + --nnodes=$SLURM_JOB_NUM_NODES + --rdzv_id=$SLURM_JOB_ID + --rdzv_backend=c10d + --rdzv_endpoint=$(hostname) ) export TORCHRUN=./pt_fsdp/bin/torchrun @@ -54,25 +54,31 @@ export TRAIN_SCRIPT=./train.py ############################ declare -a TRAINING_ARGS=( - --max_context_width=4096 \ - --num_key_value_heads=32 \ # 7b: 32 13b: 40 70b: 8 - --llama_intermediate_size=11008 \ # 7b: 11008 13b: 13824 70b: 28672 - --hidden_width=4096 \ # 7b: 4096 13b: 5120 70b: 8192 - --num_layers=32 \ # 7b: 32 13b: 40 70b: 80 - --num_heads=32 \ # 7b: 32 13b: 40 70b: 64 - --model_type=llama_v2 \ - --tokenizer="hf-internal-testing/llama-tokenizer" \ - --checkpoint_freq=5000 \ - --validation_freq=500 \ - --max_steps=5000 \ - --checkpoint_dir=./checkpoints \ - --dataset='c4' \ - --dataset_config_name='en' \ - --resume_from_checkpoint=./checkpoints \ - --train_batch_size=1 \ - --val_batch_size=1 \ - --sharding_strategy="full" \ # https://pytorch.org/docs/stable/fsdp.html + --max_context_width=4096 + --num_key_value_heads=32 # 7b: 32 13b: 40 70b: 8 + --llama_intermediate_size=11008 # 7b: 11008 13b: 13824 70b: 28672 + --hidden_width=4096 # 7b: 4096 13b: 5120 70b: 8192 + --num_layers=32 # 7b: 32 13b: 40 70b: 80 + --num_heads=32 # 7b: 32 13b: 40 70b: 64 + --model_type=llama_v2 + --tokenizer="hf-internal-testing/llama-tokenizer" + --checkpoint_freq=5000 + --validation_freq=500 + --max_steps=5000 + --checkpoint_dir=./checkpoints + --dataset='c4' + --dataset_config_name='en' + --resume_from_checkpoint=./checkpoints + --train_batch_size=1 + --val_batch_size=1 + --sharding_strategy="full" # https://pytorch.org/docs/stable/fsdp.html --offload_activations=1 ) -srun -l ${TORCHRUN} "${TORCHRUN_ARGS[@]}" $TRAIN_SCRIPT "${TRAINING_ARGS[@]}" +AUTO_RESUME="" +if [ -d "/opt/sagemaker_cluster" ]; then + echo "Detected Hyperpod cluster.. enabling --auto-resume=1" + AUTO_RESUME="--auto-resume=1" +fi + +srun ${AUTO_RESUME} -l ${TORCHRUN} "${TORCHRUN_ARGS[@]}" $TRAIN_SCRIPT "${TRAINING_ARGS[@]}"