# Patch for 3.test_cases/17.SM-modelparallelv2/setup_conda_env.sh
# (blob 0d636b2a -> 4acb6750, file mode 100644).
#
# Hunk 1 (@@ -20,13 +20,14 @@): comment-only additions around the two
#   `conda install` commands:
#     + "# Install OFI nccl" ahead of the aws-ofi-nccl install (the lone "-"
#       marker before it presumably removes an empty line -- confirm),
#     + "# Install SMP V2 pytorch. We will install SMP with pytorch 2.2"
#       ahead of the pytorch install.
# Hunk 2 (@@ -52,7 +53,7 @@): rewords "# Install SMDDP wheel" to
#   "# Install SMDDP". This hunk is cut off in this view (its last visible
#   line ends with a trailing backslash).
#
# NOTE(review): the line breaks of this patch appear to have been collapsed
#   onto a single line; they presumably need to be restored before the patch
#   can be applied -- confirm against the original patch file.
# NOTE(review): the context line after the SMDDP comment starts with
#   `RUN SMDDP_WHL=...`. `RUN` looks like a Dockerfile instruction rather
#   than a shell command -- confirm against the target script.
# NOTE(review): the new "# Install SMP V2 pytorch" comment follows the
#   `-c conda-forge \` continuation line; verify whether an empty line
#   separates them in the target file, otherwise the comment is joined onto
#   the preceding `conda install` command line by the backslash-newline.
diff --git a/3.test_cases/17.SM-modelparallelv2/setup_conda_env.sh b/3.test_cases/17.SM-modelparallelv2/setup_conda_env.sh index 0d636b2a..4acb6750 100644 --- a/3.test_cases/17.SM-modelparallelv2/setup_conda_env.sh +++ b/3.test_cases/17.SM-modelparallelv2/setup_conda_env.sh @@ -20,13 +20,14 @@ conda create -p ${ENV_PATH} python=3.10 conda activate ${ENV_PATH} - +# Install OFI nccl conda install "aws-ofi-nccl >=1.7.1,<2.0" packaging --override-channels \ -c https://aws-ml-conda.s3.us-west-2.amazonaws.com \ -c pytorch -c numba/label/dev \ -c nvidia \ -c conda-forge \ +# Install SMP V2 pytorch. We will install SMP with pytorch 2.2 conda install pytorch="2.2.0=sm_py3.10_cuda12.1_cudnn8.9.5_nccl_pt_2.2_tsm_2.2_cuda12.1_0" packaging --override-channels \ -c https://sagemaker-distributed-model-parallel.s3.us-west-2.amazonaws.com/smp-v2/ \ -c pytorch -c numba/label/dev \ @@ -52,7 +53,7 @@ python -m pip install --no-cache-dir -U \ MAX_JOBS=128 pip install flash-attn==2.3.3 --no-build-isolation -# Install SMDDP wheel +# Install SMDDP RUN SMDDP_WHL="smdistributed_dataparallel-2.2.0-cp310-cp310-linux_x86_64.whl" \ && wget -q https://smdataparallel.s3.amazonaws.com/binary/pytorch/2.2.0/cu121/2024-03-04/${SMDDP_WHL} \