Update llama2-model.libsonnet with training using SPMD (GoogleCloudPlatform#979)

* Update llama2-model.libsonnet with training using SPMD
ManfeiBai authored Sep 4, 2023
1 parent fad390b commit 3fbf2a0
Showing 1 changed file with 72 additions and 18 deletions.
tests/pytorch/nightly/llama2-model.libsonnet
@@ -22,7 +22,7 @@ local utils = import 'templates/utils.libsonnet';
local llama2_inference = self.llama2_inference,
llama2_inference:: common.PyTorchTest {
local config = self,
- modelName: 'l2-i',
+ modelName: 'llama2-i',
paramsOverride:: {
scriptPath: 'llama/7B/llama2inference.sh',
trainCommand: [
@@ -35,19 +35,19 @@ local utils = import 'templates/utils.libsonnet';
local llama2_training = self.llama2_training,
llama2_training:: common.PyTorchTest {
local config = self,
- modelName: 'l2-t',
+ modelName: 'llama2-t',
paramsOverride:: {
- scriptPath: 'llama/transformers/7B/llama2training.sh',
+ scriptPath: 'transformers/7B/llama2training.sh',
trainCommand: [
'bash',
self.scriptPath,
],
},
command: self.paramsOverride.trainCommand,
},
- local pjrt = self.pjrt,
- pjrt:: common.PyTorchTpuVmMixin {
- modelName+: '-n-i',
+ local infer = self.infer,
+ infer:: common.PyTorchTpuVmMixin {
+ modelName+: '-infer',
tpuSettings+: {
tpuVmExtraSetup: |||
pip3 uninstall torch torch_xla torchvision libtpu-nightly -y
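The `command: self.paramsOverride.trainCommand` pattern in this hunk relies on Jsonnet's late binding of `self`: a mixin applied later can override `paramsOverride`, and `command` picks up the final value. A minimal sketch of that behavior, using simplified stand-in fields rather than the real test templates:

local test = {
  paramsOverride:: {
    scriptPath: 'transformers/7B/llama2training.sh',
  },
  command: ['bash', self.paramsOverride.scriptPath],
};

{
  // `self` is late-bound, so the override is visible to `command`;
  // this evaluates to ['bash', 'other.sh'].
  overridden: (test + { paramsOverride+:: { scriptPath: 'other.sh' } }).command,
}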
@@ -92,9 +92,9 @@ local utils = import 'templates/utils.libsonnet';
|||,
},
},
- local hf = self.hf,
- hf:: common.PyTorchTpuVmMixin {
- modelName+: '-h-f',
+ local fsdp = self.fsdp,
+ fsdp:: common.PyTorchTpuVmMixin {
+ modelName+: '-train-fsdp',
tpuSettings+: {
tpuVmExtraSetup: |||
pip3 uninstall torch torch_xla torchvision libtpu-nightly -y
@@ -114,12 +114,6 @@ local utils = import 'templates/utils.libsonnet';
# install tokenizer model
wget https://storage.googleapis.com/tpu-pytorch/lsiyuan-experiment/llama/spiece.model
- # git clone and build llama
- git clone --branch llama2-google-next-inference https://github.com/pytorch-tpu/llama.git
- cd llama
- pip3 install -r requirements.txt
- pip3 install -e .
# git clone and build transformers ### llama/transformers/
git clone -b lsiyuan/fsdp-data-aug https://github.com/pytorch-tpu/transformers.git
cd transformers
@@ -138,7 +132,66 @@ local utils = import 'templates/utils.libsonnet';
wget https://storage.googleapis.com/tpu-pytorch/lsiyuan-experiment/configs/hf_llama/7B.json
# save llama2 training
- echo -e 'python3 -u llama/transformers/examples/pytorch/xla_spawn.py --num_cores 64 llama/transformers/examples/pytorch/language-modeling/run_clm.py --num_train_epochs 2 --dataset_name wikitext --dataset_config_name wikitext-2-raw-v1 --per_device_train_batch_size 8 --do_train --output_dir . --overwrite_output_dir --config_name llama/transformers/7B/7B.json --cache_dir /tmp --tokenizer_name gpt2 --block_size 1024 --optim adafactor --adafactor true --save_strategy no --logging_strategy no' >> llama2training.sh
+ echo -e 'python3 -u transformers/examples/pytorch/xla_spawn.py --num_cores 64 transformers/examples/pytorch/language-modeling/run_clm.py --num_train_epochs 2 --dataset_name wikitext --dataset_config_name wikitext-2-raw-v1 --per_device_train_batch_size 8 --do_train --output_dir . --overwrite_output_dir --config_name transformers/7B/7B.json --cache_dir /tmp --tokenizer_name gpt2 --block_size 1024 --optim adafactor --adafactor true --save_strategy no --logging_strategy no' >> llama2training.sh
cat llama2training.sh
pwd
ls
|||,
},
},
+ local spmd = self.spmd,
+ spmd:: common.PyTorchTpuVmMixin {
+ modelName+: '-train-spmd',
+ tpuSettings+: {
+ tpuVmExports+: |||
+ export XLA_USE_BF16=1
+ export XLA_IR_DEBUG=1
+ export XLA_HLO_DEBUG=1
+ export BATCH_SIZE=32
+ export NUM_EPOCH=5
+ export PROFILE_EPOCH=2
+ export PROFILE_STEP=0
+ export PROFILE_DURATION_MS=20000
+ export XLA_USE_SPMD=1
+ export PJRT_DEVICE=TPU
+ export TPU_MEGACORE=megacore_dense
+ |||,
+ tpuVmExtraSetup: |||
+ pip3 uninstall torch torch_xla torchvision libtpu-nightly -y
+ sudo apt update -y
+ sudo apt-get update -y
+ pip install accelerate -U
+ sudo apt-get install libomp5 -y
+ pip3 install mkl mkl-include
+ pip3 install numpy
+ sudo apt-get install numactl -y
+ sudo apt-get install libopenblas-dev -y
+ pip3 install https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch-nightly-cp310-cp310-linux_x86_64.whl
+ pip3 install https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-nightly-cp310-cp310-linux_x86_64.whl
+ pip3 install torch_xla[tpuvm]
+ # install tokenizer model
+ wget https://storage.googleapis.com/tpu-pytorch/lsiyuan-experiment/llama/spiece.model
+ # git clone and build transformers ### llama/transformers/
+ git clone -b llama2-google-next-training https://github.com/pytorch-tpu/transformers.git
+ cd transformers
+ sudo pip3 uninstall transformers
+ sudo pip3 install -e .
+ pip3 install datasets
+ pip3 install evaluate
+ pip3 install scikit-learn
+ pip3 install accelerate
+ pwd
+ ls
+ # 7B config
+ mkdir 7B
+ cd 7B/
+ wget https://storage.mtls.cloud.google.com/hf-train-config/llama/2B.json
+ # save llama2 training
+ echo -e 'python transformers/examples/pytorch/language-modeling/run_clm.py --tokenizer_name gpt2 --dataset_name wikitext --dataset_config_name wikitext-2-raw-v1 --per_device_train_batch_size 32 --per_device_eval_batch_size 8 --num_train_epochs 1 --do_train --output_dir /tmp/output --overwrite_output_dir --config_name transformers/7B/2B.json --save_strategy no --logging_strategy no --remove_unused_columns no --spmd_fsdp_sharding --torch_dtype bfloat16 --dataloader_drop_last yes --spmd_grad_chkpt --report_to none' >> llama2training.sh
+ cat llama2training.sh
+ pwd
+ ls
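The `tpuVmExports+:` block added above is how the SPMD run gets its environment: XLA_USE_SPMD=1 selects PyTorch/XLA's SPMD execution path, PJRT_DEVICE=TPU selects the PJRT runtime, XLA_USE_BF16=1 runs floats in bfloat16, and the BATCH_SIZE/NUM_EPOCH/PROFILE_* variables are presumably consumed by the fork's training script. Because Jsonnet's `+:` on string fields concatenates, each mixin can contribute exports without clobbering the base. A minimal sketch, using a simplified stand-in for PyTorchTpuVmMixin:

local base = {
  tpuVmExports: |||
    export PJRT_DEVICE=TPU
  |||,
};
local spmd = {
  tpuVmExports+: |||
    export XLA_USE_SPMD=1
    export TPU_MEGACORE=megacore_dense
  |||,
};

// Evaluates to both blocks concatenated, so a test layering the spmd
// mixin inherits the base exports plus the SPMD-specific ones.
(base + spmd).tpuVmExports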
@@ -152,7 +205,8 @@ local utils = import 'templates/utils.libsonnet';
},

configs: [
- llama2_inference + v4_8 + common.Functional + timeouts.Hours(3) + pjrt,
- llama2_training + v4_8 + common.Functional + timeouts.Hours(3) + hf,
+ llama2_inference + v4_8 + common.Functional + timeouts.Hours(3) + infer,
+ // llama2_training + v4_8 + common.Functional + timeouts.Hours(3) + fsdp,
+ llama2_training + v4_8 + common.Functional + timeouts.Hours(3) + spmd,
],
}
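Each `configs` entry is plain Jsonnet object composition: `+` merges the test template, accelerator, mode, timeout, and setup mixins left to right, and every `modelName+:` fragment appends to the inherited name (so the SPMD entry's name starts from 'llama2-t' and ends with '-train-spmd', plus whatever the other mixins append). A minimal sketch with hypothetical simplified fields:

local llama2_training = { modelName: 'llama2-t' };
local v4_8 = { accelerator: 'v4-8' };
local spmd = { modelName+: '-train-spmd' };

{
  // modelName evaluates to 'llama2-t-train-spmd'
  test: llama2_training + v4_8 + spmd,
}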
