Skip to content

Commit

Permalink
add
Browse files Browse the repository at this point in the history
  • Loading branch information
j316chuck committed Oct 10, 2023
1 parent 0540bb7 commit 96c5be1
Showing 1 changed file with 42 additions and 64 deletions.
106 changes: 42 additions & 64 deletions .github/mcp/mcp_pytest.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,12 @@
"""Run pytest using MCP."""

import argparse
import time

from mcli.sdk import (RunConfig, RunStatus, create_run, follow_run_logs,
wait_for_run_status)

if __name__ == '__main__':

def parse_arguments():
parser = argparse.ArgumentParser()
parser.add_argument('--name',
type=str,
Expand Down Expand Up @@ -54,26 +53,13 @@
type=int,
default=1800,
help='Timeout for run (in seconds)')
args = parser.parse_args()

name = args.name
git_integration = {
'integration_type': 'git_repo',
'git_repo': 'mosaicml/llm-foundry',
'ssh_clone': 'False',
}
if args.git_branch is not None and args.git_commit is None:
name += f'-branch-{args.git_branch}'
git_integration['git_branch'] = args.git_branch
if args.git_commit is not None:
name += f'-commit-{args.git_commit}'
git_integration['git_commit'] = args.git_commit
return parser.parse_args()

def construct_base_command(args):
command = 'cd llm-foundry'

# Checkout a specific PR if specified
if args.pr_number is not None:
name += f'-pr-{args.pr_number}'
command += f'''
git fetch origin pull/{args.pr_number}/head:pr_branch
Expand All @@ -82,32 +68,58 @@
'''

# Shorten name if too long
if len(name) > 56:
name = name[:56]
return command

def construct_run_command(args, distributed: bool):
clear_tmp_path_flag = '-o tmp_path_retention_policy=none'
if distributed:
make_command = f'''make test-dist PYTEST='{args.pytest_command}' EXTRA_ARGS="$COMMON_ARGS" WORLD_SIZE=2'''
else:
make_command = f'''make test PYTEST='{args.pytest_command}' EXTRA_ARGS="$COMMON_ARGS --codeblocks"'''


run_command = f'''
pip install --upgrade --user .[all]
export COMMON_ARGS="-v --durations=20 -m '{args.pytest_markers}' {clear_tmp_path_flag}"
make test-dist PYTEST='{args.pytest_command}' EXTRA_ARGS="$COMMON_ARGS" WORLD_SIZE=2
{make_command}
python -m coverage combine
python -m coverage report
'''
return run_command


def create_and_follow_run(args, command):
name = args.name
git_integration = {
'integration_type': 'git_repo',
'git_repo': 'mosaicml/llm-foundry',
'ssh_clone': 'False',
}
if args.git_branch is not None and args.git_commit is None:
name += f'-branch-{args.git_branch}'
git_integration['git_branch'] = args.git_branch
if args.git_commit is not None:
name += f'-commit-{args.git_commit}'
git_integration['git_commit'] = args.git_commit
if args.pr_number is not None:
name += f'-pr-{args.pr_number}'
# Shorten name if too long
if len(name) > 56:
name = name[:56]

config = RunConfig(
name=name,
cluster=args.cluster,
gpu_type=args.gpu_type,
gpu_num=args.gpu_num,
image=args.image,
integrations=[git_integration],
command=command + run_command,
command=command,
scheduling={'max_duration': args.timeout / 60 / 60},
)

Expand All @@ -117,7 +129,6 @@

# Wait until run starts before fetching logs
run = wait_for_run_status(run, status='running')
start_time = time.time()
print('[GHA] Run started. Following logs...')

# Print logs
Expand All @@ -130,44 +141,11 @@
# Fail if command exited with non-zero exit code or timed out
assert run.status == RunStatus.COMPLETED

run_command = f'''
pip install --upgrade --user .[all]
export COMMON_ARGS="-v --durations=20 -m '{args.pytest_markers}' {clear_tmp_path_flag}"
make test PYTEST='{args.pytest_command}' EXTRA_ARGS="$COMMON_ARGS --codeblocks"
python -m coverage combine

python -m coverage report
'''
config = RunConfig(
name=name,
cluster=args.cluster,
gpu_type=args.gpu_type,
gpu_num=args.gpu_num,
image=args.image,
integrations=[git_integration],
command=command + run_command,
scheduling={'max_duration': args.timeout / 60 / 60},
)

# Create run
run = create_run(config)
print(f'[GHA] Run created: {run.name}')

# Wait until run starts before fetching logs
run = wait_for_run_status(run, status='running')
start_time = time.time()
print('[GHA] Run started. Following logs...')

# Print logs
for line in follow_run_logs(run):
print(line, end='')

print('[GHA] Run completed. Waiting for run to finish...')
run = wait_for_run_status(run, status='completed')

# Fail if command exited with non-zero exit code or timed out
assert run.status == RunStatus.COMPLETED
if __name__ == '__main__':
args = parse_arguments()
command_base = construct_base_command(args)
distributed_gpu_test_run_command = construct_run_command(args, distributed=True)
create_and_follow_run(args, command_base + distributed_gpu_test_run_command)
single_node_gpu_test_run_command = construct_run_command(args, distributed=False)
create_and_follow_run(args, command_base + single_node_gpu_test_run_command)

0 comments on commit 96c5be1

Please sign in to comment.