Skip to content

Commit

Permalink
Kill the TPU process if any before running the model
Browse files Browse the repository at this point in the history
  • Loading branch information
RissyRan committed Feb 13, 2024
1 parent 693ee3f commit 19607a0
Showing 1 changed file with 28 additions and 0 deletions.
28 changes: 28 additions & 0 deletions xlml/utils/tpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import airflow
from airflow.decorators import task, task_group
from airflow.utils.task_group import TaskGroup
from airflow.operators.python import get_current_context
from xlml.apis import gcp_config, test_config
import fabric
import google.api_core.exceptions
Expand Down Expand Up @@ -272,6 +273,19 @@ def wait_for_queued_resource_deletion(op_name: Optional[str]):
wait_for_queued_resource_deletion(qr_op_name)


def kill_process_by_pid() -> str:
  """Return a bash script that kills any process holding the TPU device open.

  The generated script expects the accelerator type as its first positional
  argument. When that value matches `^v5.*` it targets `/dev/vfio/*`,
  otherwise `/dev/accel*`; it then lists the PIDs with `sudo lsof -t` and
  pipes them to `xargs -r kill -9` (`-r` skips `kill` when nothing is found).

  Every `$` and `"` in the returned text is preceded by a literal backslash.
  The caller embeds this text inside a double-quoted `echo "..." > file`
  command, so the backslashes keep the outer shell from expanding the
  variables or closing the quote before the script is written to disk.

  Returns:
    The script text, terminated by a newline.
  """
  # Plain (non-f) literals with explicit `\\` escapes: the previous f-string
  # relied on the invalid escape sequence `\$`, which newer Python versions
  # warn about, and needed doubled braces although nothing was interpolated.
  script_lines = (
      "accelerator_type=\\${1}",
      "if [[ \\${accelerator_type} =~ ^v5.* ]]",
      "then",
      "device_name=vfio/*",
      "else",
      "device_name=accel*",
      "fi",
      'echo \\"Terminating all processes utilizing the TPU (if any).\\"',
      "sudo lsof -t /dev/\\${device_name} | xargs -r kill -9",
  )
  return "\n".join(script_lines) + "\n"


@task
def ssh_tpu(
qualified_name: str,
Expand Down Expand Up @@ -315,4 +329,18 @@ def ssh_tpu(
)
},
)

context = get_current_context()
if context['task_instance'].try_number > 1:
# kill TPU process by pid (if any) to avoid `TPU in use` error in retry
tmp_file = '/tmp/kill_process.sh'
accelerator_type = nodes[0].accelerator_type
script = kill_process_by_pid()
kill_process_cmds = (
f'set -x; sudo echo "{script}" > {tmp_file}',
f'bash {tmp_file} {accelerator_type}',
)
ssh_group.run(';'.join(kill_process_cmds))

# run provided commands
ssh_group.run(cmds)

0 comments on commit 19607a0

Please sign in to comment.