Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Changes to reuse TPUs #769

Open
wants to merge 1 commit into
base: multipod-tests
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
125 changes: 102 additions & 23 deletions tests/multipods/experimental_multipod.libsonnet
Original file line number Diff line number Diff line change
Expand Up @@ -76,38 +76,64 @@ local volumes = import 'templates/volumes.libsonnet';
'create-tpu-slices': {
image: 'google/cloud-sdk',
local tpuCreateSettings = {
acceleratorName: std.escapeStringBash(config.accelerator.name),
acceleratorName: config.accelerator.name,
sliceCount: config.tpuSettings.slices,
softwareVersion: std.escapeStringBash(config.tpuSettings.softwareVersion),
startupScript: std.escapeStringBash(config.tpuSettings.tpuVmStartupScript),
sleepTime: config.tpuSettings.tpuVmCreateSleepSeconds,
testName: std.strReplace(config.testName, '.', '-'),
tpuExists: config.tpuExists,
tpuPrefix: config.tpuPrefix,
userName: config.userName,
},
command: utils.scriptCommand(|||
set +x
project=$(curl -sS "http://metadata.google.internal/computeMetadata/v1/project/project-id" -H "Metadata-Flavor: Google")
zone=$(curl -sS "http://metadata.google.internal/computeMetadata/v1/instance/zone" -H "Metadata-Flavor: Google" | awk -F'/' '{print $4}')
tpu_name_prefix=tpu-${POD_UID}
if [ %(tpuExists)s = true ]; then
tpu_name_prefix=%(tpuPrefix)s
fi
ssh-keygen -t rsa -f /scripts/id_rsa -q -N ""

echo "${project}:$(cat /scripts/id_rsa.pub)" > ssh-keys.txt
echo %(startupScript)s > startup-script.txt

echo %(sliceCount)s >> /scripts/slice_count
for (( i=0; i < %(sliceCount)s; i++ )); do
tpu_name=${tpu_name_prefix}-${i}
echo "
gcloud alpha compute tpus tpu-vm delete -q ${tpu_name} --zone=${zone}
" > /scripts/cleanup_${i}.sh

if [ %(tpuExists)s = false ]; then
for (( i=0; i < %(sliceCount)s; i++ )); do
tpu_name_delete=${tpu_name_prefix}-${i}
echo "
gcloud alpha compute tpus tpu-vm delete -q ${tpu_name_delete} --zone=${zone} --project=${project}
" > /scripts/cleanup_${i}.sh
echo "
bash /scripts/cleanup_${i}.sh
" >> /scripts/cleanup.sh
done
else
echo "
bash /scripts/cleanup_${i}.sh
true
" >> /scripts/cleanup.sh

fi
# Delete every TPU VM named in the TPU_LIST array.
# Reads the script-level variables TPU_LIST, zone and project.
# Stops the whole script with exit status 1 at the first deletion that gcloud
# reports as failed, after naming the TPU that has to be removed by hand.
delete_tpus() {
  echo -e "\n\nDeleting TPUs..."
  for tpu_id in "${TPU_LIST[@]}"; do
    echo -e "\n${tpu_id} is being deleted."
    # Test the command directly so the failure branch still runs if errexit is active.
    if ! gcloud alpha compute tpus tpu-vm delete -q "${tpu_id}" --zone=${zone} --project=${project}; then
      # Report the TPU from this iteration. TPU_NAME is a separate global set by
      # the caller and may refer to a different TPU than the one that failed here.
      echo "Failed to delete the TPU ${tpu_id}. Delete it manually."
      exit 1
    fi
  done
}
create_tpu() {
echo "Create TPU called"
TPU_NAME=$1
SLICE_ID=$2
# Retry every 30 seconds for 10 minutes
for j in {1..20}; do
set +e
gcloud alpha compute tpus tpu-vm create ${tpu_name} \
gcloud alpha compute tpus tpu-vm create ${TPU_NAME} \
--accelerator-type=%(acceleratorName)s \
--version=%(softwareVersion)s \
--metadata-from-file='ssh-keys=ssh-keys.txt,startup-script=startup-script.txt' \
Expand All @@ -120,19 +146,67 @@ local volumes = import 'templates/volumes.libsonnet';
done

if [ $exit_code -ne 0 ]; then
echo "TPU VM with name ${TPU_NAME} failed to create. So exiting the setup."
delete_tpus
exit $exit_code
fi

echo ${tpu_name} >> /scripts/tpu_name_${i}

if [ ${i} -eq 0 ]; then
gcloud compute tpus describe ${tpu_name} --project=${project} --zone=${zone} --format="value(networkEndpoints[0].ipAddress)" > /scripts/coordinator_ip
echo -e "Slice_${SLICE_ID}: TPU VM ${TPU_NAME} successfully created."
TPU_CREATED=true
}
create_tpu_slice_environment() {
echo -e "\n\nSetting %(sliceCount)s TPU Slices with %(acceleratorName)s in each slice..."
for (( i=0; i < %(sliceCount)s; i++ )); do
TPU_NAME=${tpu_name_prefix}-${i}
tpu_exist_with_same_type=false
tpu_exist_with_diff_type=false
echo "$TPU_NAME, $zone, $project, $(gcloud compute tpus list --zone=${zone} --project=${project} | grep "^$TPU_NAME ")"
if [[ -z "$(gcloud compute tpus list --zone=${zone} --project=${project} | grep "^$TPU_NAME ")" ]]; then
list_of_tpu_with_same_name=''
else
list_of_tpu_with_same_name=$(gcloud compute tpus list --zone=${zone} --project=${project} | grep "^$TPU_NAME ")
fi
if [[ ! -z "$(gcloud compute tpus list --zone=${zone} --project=${project} | grep "^$TPU_NAME ")" ]]; then
list_of_tpu_with_same_type=$(echo "$list_of_tpu_with_same_name" | grep "%(acceleratorName)s")
echo "$list_of_tpu_with_same_type"
if [[ ! -z "$list_of_tpu_with_same_type" ]]; then
tpu_exist_with_same_type=true
else
tpu_exist_with_diff_type=true
fi
fi
echo "$TPU_NAME, $tpu_exist_with_same_type, $tpu_exist_with_diff_type"
if [[ %(tpuExists)s = true ]]; then
if [[ "$tpu_exist_with_same_type" = false ]]; then
if [[ "$tpu_exist_with_diff_type" = false ]]; then
echo -e "\nYou chooses to use existing TPU. But TPU with name $TPU_NAME doesn't exist!"
else
echo -e "\nTPU with name $TPU_NAME already exists but with different configuration. So exiting."
fi
exit 1
fi
else
if [[ "$tpu_exist_with_same_type" = true ]] || [[ "$tpu_exist_with_diff_type" = true ]]; then
echo -e "\nTPU with name $TPU_NAME already exists and you choose USE_EXISTING_TPUS=%(tpuExists)s. So exiting."
exit 1
fi
create_tpu "$TPU_NAME" $i
fi
TPU_LIST+=(${TPU_NAME})
echo ${TPU_NAME} >> /scripts/tpu_name_${i}
if [ ${i} -eq 0 ]; then
gcloud compute tpus describe ${TPU_NAME} --project=${project} --zone=${zone} --format="value(networkEndpoints[0].ipAddress)" > /scripts/coordinator_ip
fi
gcloud compute tpus describe ${TPU_NAME} --project=${project} --zone=${zone} --format="value(networkEndpoints[0].ipAddress)" >> /scripts/tpu_ip_slice_${i}
gcloud compute tpus describe ${TPU_NAME} --project=${project} --zone=${zone} --flatten="networkEndpoints[]" --format="csv[no-heading](networkEndpoints.ipAddress)" >> /scripts/all_tpu_ips_slice_${i}
wc -l < /scripts/all_tpu_ips_slice_${i} >> /scripts/worker_count_slice_${i}
done
if [[ "$TPU_CREATED" = false ]]; then
echo -e "\nUsing already created %(sliceCount)s TPU Slices with %(acceleratorName)s in each slice..."
fi
gcloud compute tpus describe ${tpu_name} --project=${project} --zone=${zone} --format="value(networkEndpoints[0].ipAddress)" >> /scripts/tpu_ip_slice_${i}
gcloud compute tpus describe ${tpu_name} --project=${project} --zone=${zone} --flatten="networkEndpoints[]" --format="csv[no-heading](networkEndpoints.ipAddress)" >> /scripts/all_tpu_ips_slice_${i}
wc -l < /scripts/all_tpu_ips_slice_${i} >> /scripts/worker_count_slice_${i}
done

}
TPU_CREATED=false
create_tpu_slice_environment
echo "$TPU_LIST"
sleep %(sleepTime)d

COORDINATOR_IP=$(cat /scripts/coordinator_ip)
Expand All @@ -143,9 +217,13 @@ local volumes = import 'templates/volumes.libsonnet';
echo "export MEGASCALE_COORDINATOR_ADDRESS=${COORDINATOR_IP}:8080" >> ~/.profile
echo "export MEGASCALE_NUM_SLICES=${SLICE_COUNT}" >> ~/.profile
echo "export MEGASCALE_SLICE_ID=${i}" >> ~/.profile
echo "export MEGASCALE_TRANSPORT_TYPE=\"grpc\"" >> ~/.profile
echo "export MEGASCALE_PORT=8080" >> ~/.profile
echo "export MEGASCALE_AUTHENTICATION=\"insecure\"" >> ~/.profile
SCRIPT_EOF

gcloud alpha compute tpus tpu-vm ssh cloud-tpu-multipod-dev@$(cat /scripts/tpu_name_${i}) \
echo $(cat /scripts/tpu_name_${i})
echo "$(cat set_mxla_flags.sh)"
gcloud alpha compute tpus tpu-vm ssh %(userName)s@$(cat /scripts/tpu_name_${i}) \
--zone=${zone} \
--ssh-key-file=/scripts/id_rsa \
--strict-host-key-checking=no \
Expand All @@ -156,7 +234,7 @@ local volumes = import 'templates/volumes.libsonnet';

echo ${zone} > /scripts/zone

echo "LOGGER: TPU VMs created successfully."
echo "LOGGER: TPU VMs setup successful."
||| % tpuCreateSettings),
env: [
{
Expand Down Expand Up @@ -213,3 +291,4 @@ local volumes = import 'templates/volumes.libsonnet';
},
},
}

130 changes: 83 additions & 47 deletions tests/multipods/jax/common.libsonnet
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,9 @@ local tpus = import 'templates/tpus.libsonnet';
frameworkPrefix: 'mp-jax',
image: 'google/cloud-sdk',
accelerator: tpus.v4_16,

tpuExists: false,
tpuPrefix: 'test',
userName: 'cloud-tpu-multipod-dev',
metricConfig+: {
sourceMap+:: {
tensorboard+: {
Expand Down Expand Up @@ -71,52 +73,54 @@ local tpus = import 'templates/tpus.libsonnet';
|||
set +x
set -u
SLICE_COUNT=$(cat /scripts/slice_count)
ZONE=$(cat /scripts/zone)

cat > testsetup.sh << SCRIPT_EOF
set +x
set -u
set -e

# .bash_logout sometimes causes a spurious bad exit code, remove it.
rm .bash_logout

%(installPipPackages)s
%(installJax)s
%(installJaxlib)s
%(installLibtpu)s
if [ %(tpuExists)s = false ]; then
cat > testsetup.sh << SCRIPT_EOF
set +x
set -u
set -e

# .bash_logout sometimes causes a spurious bad exit code, remove it.
rm .bash_logout

%(installPipPackages)s
%(installJax)s
%(installJaxlib)s
%(installLibtpu)s
SCRIPT_EOF

setup_process_ids=()
setup_process_ids=()

SLICE_COUNT=$(cat /scripts/slice_count)
ZONE=$(cat /scripts/zone)

for (( i=0; i < ${SLICE_COUNT}; i++ )); do
gcloud alpha compute tpus tpu-vm ssh cloud-tpu-multipod-dev@$(cat /scripts/tpu_name_${i}) \
--zone=${ZONE} \
--ssh-key-file=/scripts/id_rsa \
--strict-host-key-checking=no \
--internal-ip \
--worker=all \
--command "$(cat testsetup.sh)" >> output_testsetup_${i}.txt 2>&1 &

setup_process_ids+=($!)
done
for (( i=0; i < ${SLICE_COUNT}; i++ )); do
gcloud alpha compute tpus tpu-vm ssh %(userName)s@$(cat /scripts/tpu_name_${i}) \
--zone=${ZONE} \
--ssh-key-file=/scripts/id_rsa \
--strict-host-key-checking=no \
--internal-ip \
--worker=all \
--command "$(cat testsetup.sh)" >> output_testsetup_${i}.txt 2>&1 &

echo "LOGGER: Waiting for test setup to be installed on all TPU VM hosts in ${SLICE_COUNT} slices."
setup_process_ids+=($!)
done

for i in "${!setup_process_ids[@]}"; do
wait ${setup_process_ids[$i]}
if [[ $? -ne 0 ]]; then
echo "LOGGER: Set up failed on slice_${i}. Here is the output:"
cat output_testsetup_${i}.txt
bash /scripts/cleanup.sh
exit 1
fi
done
echo "LOGGER: Waiting for test setup to be installed on all TPU VM hosts in ${SLICE_COUNT} slices."

echo "LOGGER: Test set up completed successfully on ${SLICE_COUNT} slices."
for i in "${!setup_process_ids[@]}"; do
wait ${setup_process_ids[$i]}
if [[ $? -ne 0 ]]; then
echo "LOGGER: Set up failed on slice_${i}. Here is the output:"
cat output_testsetup_${i}.txt
bash /scripts/cleanup.sh
exit 1
fi
done

echo "LOGGER: Test set up completed successfully on ${SLICE_COUNT} slices."
else
echo "LOGGER: Not installing anything"
fi
test_script_process_ids=()

cat > test_script.sh << TEST_SCRIPT_EOF
Expand All @@ -125,7 +129,7 @@ local tpus = import 'templates/tpus.libsonnet';

for (( i=0; i < ${SLICE_COUNT}; i++ )); do
for (( j=0; j < $(cat /scripts/worker_count_slice_${i}); j++ )); do
gcloud alpha compute tpus tpu-vm ssh cloud-tpu-multipod-dev@$(cat /scripts/tpu_name_${i}) \
gcloud alpha compute tpus tpu-vm ssh %(userName)s@$(cat /scripts/tpu_name_${i}) \
--zone=${ZONE} \
--ssh-key-file=/scripts/id_rsa \
--strict-host-key-checking=no \
Expand Down Expand Up @@ -153,17 +157,15 @@ local tpus = import 'templates/tpus.libsonnet';

echo "LOGGER: Test script completed successfully on all the TPU VM hosts of ${SLICE_COUNT} slices. Here is the output from Slice 0:"
cat output_slice_0_worker_0.txt

echo "LOGGER: Cleaning up the TPU VM resources:"

sleep 60


sleep 30
echo $(cat /scripts/cleanup.sh)
bash /scripts/cleanup.sh

exit_code=$?

exit $exit_code
||| % { testScript: config.testScript, installPipPackages: config.scriptConfig.installPipPackages, installJax: config.scriptConfig.installJax, installJaxlib: config.scriptConfig.installJaxlib, installLibtpu: config.scriptConfig.installLibtpu },
||| % { testScript: config.testScript, installPipPackages: config.scriptConfig.installPipPackages, installJax: config.scriptConfig.installJax, installJaxlib: config.scriptConfig.installJaxlib, installLibtpu: config.scriptConfig.installLibtpu, userName: config.userName, tpuExists: config.tpuExists },
],
},

Expand Down Expand Up @@ -241,7 +243,41 @@ local tpus = import 'templates/tpus.libsonnet';
|||,
},
},

// Mixin that pins the test to an older JAX stack: jax and jaxlib 0.3.25, plus a
// libtpu shared object copied out of a dated docker image.
jaxlibOldStable:: {
// Hidden field labelling this variant.
jaxlibVersion:: 'old',
// `+:` extends the inherited scriptConfig instead of replacing it.
scriptConfig+: {
// Shell snippet installing the pinned jax release.
installJax: |||
pip3 install jax==0.3.25
|||,
// Shell snippet installing the matching jaxlib release.
installJaxlib: |||
pip3 install jaxlib==0.3.25
|||,
// Shell snippet that configures docker credentials, creates a container from
// the libtpu_unsanitized image, copies _libtpu_next.so out of it to
// /lib/libtpu.so, removes the container, and appends TPU_LIBRARY_PATH to
// ~/.profile.
installLibtpu: |||
/usr/bin/docker-credential-gcr configure-docker
sudo bash /var/scripts/docker-login.sh

sudo docker create --name libtpu_next gcr.io/cloud-tpu-v2-images-dev/libtpu_unsanitized:libtpu_unsanitized_2022111705_RC00 "/bin/bash"
sudo docker cp libtpu_next:_libtpu_next.so /lib/libtpu.so

sudo docker rm libtpu_next
echo "export TPU_LIBRARY_PATH=/lib/libtpu.so" >> ~/.profile
|||,
},
},
// Mixin that turns every install step into a no-op: each snippet is just the
// shell builtin `true`, so nothing is installed.
// NOTE(review): presumably meant for runs against already-provisioned TPUs
// (tpuExists: true) — confirm against the callers.
noInstall:: {
// Hidden field labelling this variant.
jaxlibVersion:: 'not-installed',
// `+:` extends the inherited scriptConfig instead of replacing it.
scriptConfig+: {
// No-op in place of the jax install step.
installJax: |||
true
|||,
// No-op in place of the jaxlib install step.
installJaxlib: |||
true
|||,
// No-op in place of the libtpu install step.
installLibtpu: |||
true
|||,
},
},
tpuVmV4Base:: {
local config = self,
accelerator: tpus.v4_16,
Expand Down