diff --git a/.github/actions/pytest-gpu/action.yaml b/.github/actions/pytest-gpu/action.yaml index 9822e8d..fd276c2 100644 --- a/.github/actions/pytest-gpu/action.yaml +++ b/.github/actions/pytest-gpu/action.yaml @@ -125,3 +125,22 @@ runs: --gpu_num ${{ inputs.gpu_num }} \ --git_ssh_clone ${{ inputs.git_ssh_clone }} \ ${REF_ARGS} + - name: Follow Run Logs + shell: bash + env: + MOSAICML_API_KEY: ${{ inputs.mcloud_api_key }} + run: | + set -ex + + python .github/mcli/follow_mcli_logs.py \ + --name '${{ steps.tests.outputs.RUN_NAME }}' + - name: Stop Run if Cancelled + if: ${{ cancelled() }} + shell: bash + env: + MOSAICML_API_KEY: ${{ inputs.mcloud_api_key }} + run: | + set -ex + + python .github/mcli/cancel_mcli_run.py \ + --name '${{ steps.tests.outputs.RUN_NAME }}' diff --git a/.github/mcli/cancel_mcli_run.py b/.github/mcli/cancel_mcli_run.py new file mode 100644 index 0000000..9632dc1 --- /dev/null +++ b/.github/mcli/cancel_mcli_run.py @@ -0,0 +1,23 @@ +# Copyright 2024 MosaicML CI-Testing authors +# SPDX-License-Identifier: Apache-2.0 + +import argparse + +from mcli import RunStatus, get_run, stop_run, wait_for_run_status + +"""Cancel an MCLI run.""" + +if __name__ == '__main__': + + parser = argparse.ArgumentParser() + parser.add_argument('--name', type=str, required=True, help='Name of run') + args = parser.parse_args() + + run = get_run(args.name) + + print('[GHA] Stopping run.') + stop_run(run) + + # Wait until run stops + run = wait_for_run_status(run, status=RunStatus.STOPPED) + print('[GHA] Run stopped.') diff --git a/.github/mcli/follow_mcli_logs.py b/.github/mcli/follow_mcli_logs.py new file mode 100644 index 0000000..485ddfd --- /dev/null +++ b/.github/mcli/follow_mcli_logs.py @@ -0,0 +1,30 @@ +# Copyright 2024 MosaicML CI-Testing authors +# SPDX-License-Identifier: Apache-2.0 + +import argparse + +from mcli import RunStatus, follow_run_logs, get_run, wait_for_run_status + +"""Follow MCLI run logs.""" + +if __name__ == '__main__': + + parser = argparse.ArgumentParser() + parser.add_argument('--name', type=str, required=True, help='Name of run') + args = parser.parse_args() + + run = get_run(args.name) + + # Wait until run starts before fetching logs + run = wait_for_run_status(run, status='running') + print('[GHA] Run started. Following logs...') + + # Print logs + for line in follow_run_logs(run): + print(line, end='') + + print('[GHA] Run completed. Waiting for run to finish...') + run = wait_for_run_status(run, status=RunStatus.COMPLETED) + + # Fail if command exited with non-zero exit code or timed out (didn't reach COMPLETED) + assert run.status == RunStatus.COMPLETED, f'Run {run.name} did not complete: {run.status} ({run.reason})' diff --git a/.github/mcli/mcli_pytest.py b/.github/mcli/mcli_pytest.py index b4f063c..d8b11cd 100644 --- a/.github/mcli/mcli_pytest.py +++ b/.github/mcli/mcli_pytest.py @@ -4,9 +4,9 @@ """Run pytest using MCLI.""" import argparse -import time +import os -from mcli import RunConfig, RunStatus, create_run, follow_run_logs, wait_for_run_status +from mcli import RunConfig, create_run if __name__ == '__main__': @@ -111,17 +111,5 @@ run = create_run(config) print(f'[GHA] Run created: {run.name}') - # Wait until run starts before fetching logs - run = wait_for_run_status(run, status='running') - start_time = time.time() - print('[GHA] Run started. Following logs...') - - # Print logs - for line in follow_run_logs(run): - print(line, end='') - - print('[GHA] Run completed. Waiting for run to finish...') - run = wait_for_run_status(run, status=RunStatus.COMPLETED) - - # Fail if command exited with non-zero exit code or timed out (didn't reach COMPLETED) - assert run.status == RunStatus.COMPLETED, f'Run {run.name} did not complete: {run.status} ({run.reason})' + with open(os.environ['GITHUB_OUTPUT'], 'a') as fh: + print(f'RUN_NAME={run.name}', file=fh)