Skip to content

Commit

Permalink
Merge pull request #263 from aws-samples/bugfix/jax_sbatch
Browse files: browse the repository at this point in the history
Bugfix/jax sbatch
  • Loading branch information
KeitaW authored Apr 12, 2024
2 parents 65b3d59 + 93c5c9b commit 2736418
Showing 1 changed file with 13 additions and 8 deletions.
21 changes: 13 additions & 8 deletions 3.test_cases/jax/jax.sbatch
Original file line number | Diff line number | Diff line change
Expand Up @@ -9,10 +9,9 @@
# Number of GPUs assumed per node, and the total across the Slurm allocation
# (node count times GPUs per node).
GPU_PER_NODE=8
TOTAL_NB_GPUS=$(($SLURM_JOB_NUM_NODES * $GPU_PER_NODE))

# NOTE(review): the hunk header (-9,10 +9,9) implies that the next four lines
# (the CHECKPOINT_DIR assignment and its mkdir guard) are the removed side of
# the diff and the three lines after them are the added side; the +/- markers
# appear to have been lost in extraction -- confirm against the commit.
CHECKPOINT_DIR=/data/700/$SLURM_JOBID
if [ ! -d ${CHECKPOINT_DIR} ]; then
mkdir -p ${CHECKPOINT_DIR}
fi
# Shared file system and container directories
# These two values are used as the source:destination pair of the
# --container-mounts option on the srun line later in this file.
export SHARED_FS_DIR=/fsx/data
export CONTAINER_DIR=/data

# EFA Flags
export FI_PROVIDER=efa
Expand All @@ -34,12 +33,18 @@ export XLA_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=true --xla_gpu_enabl
# NOTE(review): this is part of a diff hunk (-34,12 +33,18) whose +/- markers
# appear to have been lost. Where two near-identical lines are adjacent below,
# the first is presumably the removed version and the second its replacement
# -- confirm against the commit.
export TPU_TYPE=gpu
export TF_FORCE_GPU_ALLOW_GROWTH=true

# Setup and checkpoint directory
# Setup and results directory
# LEAD_NODE is taken from SLURMD_NODENAME, i.e. the node this script runs on.
export LEAD_NODE=${SLURMD_NODENAME}
export BASE_DIR=${CHECKPOINT_DIR}
# Replacement form: BASE_DIR is a path under CONTAINER_DIR, with a fixed "700"
# component followed by the Slurm job id.
export BASE_DIR=${CONTAINER_DIR}/700/$SLURM_JOBID

# Create results directory on shared file system.
# BASE_DIR is a path under CONTAINER_DIR, and srun mounts SHARED_FS_DIR at
# CONTAINER_DIR, so the host-side location of BASE_DIR is obtained by swapping
# the CONTAINER_DIR prefix for SHARED_FS_DIR. Deriving it this way (instead of
# hard-coding /fsx) keeps the host path in sync with the mount if either
# variable is changed, and avoids the doubled slash of "/fsx/${BASE_DIR}".
CHECKPOINT_DIR="${SHARED_FS_DIR}${BASE_DIR#"${CONTAINER_DIR}"}"
# Quoted so that paths containing spaces or glob characters are not split.
mkdir -p -- "${CHECKPOINT_DIR}/checkpoints"
mkdir -p -- "${CHECKPOINT_DIR}/LOG_DIR"


# JAX Configuration
# NOTE(review): the two TRAINING_CONFIG lines differ only in the capitalisation
# of the final class-name component ("Limitsteps" vs "LimitSteps"); presumably
# the first is the removed line and the second the corrected one -- confirm
# against the commit.
export TRAINING_CONFIG=paxml.tasks.lm.params.lm_cloud.LmCloudSpmd2BLimitsteps
export TRAINING_CONFIG=paxml.tasks.lm.params.lm_cloud.LmCloudSpmd2BLimitSteps
# The middle ICI mesh dimension is set to the total GPU count; the per-core
# batch size is fixed at 32.
export JAX_FLAGS="--fdl.ICI_MESH_SHAPE=[1,${TOTAL_NB_GPUS},1] --fdl.PERCORE_BATCH_SIZE=32"

# Launch run_paxml.sh inside the container image with one task per GPU
# (-n ${TOTAL_NB_GPUS}) across all allocated nodes (-N ${SLURM_JOB_NUM_NODES}).
# NOTE(review): the two srun lines differ only in --container-mounts: the first
# hard-codes /fsx/data:/data, the second uses ${SHARED_FS_DIR}:${CONTAINER_DIR};
# presumably the first is the removed line -- confirm against the commit.
srun --container-image /fsx/paxml_jax-0.4.18-1.2.0.sqsh --container-mounts /fsx/data:/data -n ${TOTAL_NB_GPUS} -N ${SLURM_JOB_NUM_NODES} /bin/bash run_paxml.sh
srun --container-image /fsx/paxml_jax-0.4.18-1.2.0.sqsh --container-mounts ${SHARED_FS_DIR}:${CONTAINER_DIR} -n ${TOTAL_NB_GPUS} -N ${SLURM_JOB_NUM_NODES} /bin/bash run_paxml.sh

0 comments on commit 2736418

Please sign in to comment.