diff --git a/tests/pytorch/nightly/llama2-model.libsonnet b/tests/pytorch/nightly/llama2-model.libsonnet index 51935bafc..8b37a251a 100644 --- a/tests/pytorch/nightly/llama2-model.libsonnet +++ b/tests/pytorch/nightly/llama2-model.libsonnet @@ -115,7 +115,7 @@ local utils = import 'templates/utils.libsonnet'; wget https://storage.googleapis.com/tpu-pytorch/lsiyuan-experiment/llama/spiece.model # git clone and build transformers ### llama/transformers/ - git clone https://github.com/pytorch-tpu/transformers.git + git clone -b lsiyuan/fsdp-data-aug https://github.com/pytorch-tpu/transformers.git cd transformers sudo pip3 uninstall transformers sudo pip3 install -e . @@ -204,7 +204,7 @@ local utils = import 'templates/utils.libsonnet'; configs: [ llama2_inference + v4_8 + common.Functional + timeouts.Hours(3) + infer, - llama2_training + v4_8 + common.Functional + timeouts.Hours(3) + fsdp, + # llama2_training + v4_8 + common.Functional + timeouts.Hours(3) + fsdp, llama2_training + v4_8 + common.Functional + timeouts.Hours(3) + spmd, ], }