From 5ad8ba14bd1fa8babb5785e9a0a1abb70fdf365a Mon Sep 17 00:00:00 2001 From: jetaba Date: Mon, 10 Jun 2024 08:05:16 -0500 Subject: [PATCH] refactor runner main --- bclaw_runner/src/runner/qc_check.py | 41 ++++--------- bclaw_runner/src/runner/repo.py | 51 +++++++++++++++- bclaw_runner/src/runner/runner_main.py | 83 ++++++++++++++++++++++++-- bclaw_runner/src/runner/workspace.py | 15 +++-- bclaw_runner/tests/test_qc_check.py | 65 +++----------------- bclaw_runner/tests/test_repo.py | 43 +++++++------ bclaw_runner/tests/test_runner_main.py | 1 + bclaw_runner/tests/test_workspace.py | 16 ++--- cloudformation/bc_core.yaml | 3 +- 9 files changed, 191 insertions(+), 127 deletions(-) diff --git a/bclaw_runner/src/runner/qc_check.py b/bclaw_runner/src/runner/qc_check.py index 66cb34b..6283cfa 100644 --- a/bclaw_runner/src/runner/qc_check.py +++ b/bclaw_runner/src/runner/qc_check.py @@ -5,37 +5,18 @@ import boto3 -logger = logging.getLogger(__name__) +logger = logging.getLogger(__name__) -# def run_qc_checks(checks: list) -> None: -# if checks: -# logger.info("starting QC checks") -# for item in checks: -# qc_file = item["qc_result_file"] -# logger.info(f"{qc_file=}") -# -# with open(qc_file) as fp: -# qc_data = json.load(fp) -# -# for qc_expression in item["stop_early_if"]: -# run_qc_check(qc_data, qc_expression) -# -# logger.info("QC checks finished") -# else: -# logger.info("no QC checks requested") - - -# def run_qc_check(qc_data: dict, qc_expression: str) -> None: -# result = eval(qc_expression, globals(), qc_data) -# if result: -# logger.warning(f"failed QC check: {qc_expression}; aborting workflow execution") -# abort_execution(qc_expression) -# else: -# logger.info(f"passed QC check: {qc_expression}") +class QCFailure(Exception): + def __init__(self, message: str, failures: list): + super().__init__(message) + self.failures = failures def abort_execution(failed_expressions: list) -> None: + logger.warning("aborting workflow execution") + region = os.environ["AWS_DEFAULT_REGION"] acct = os.environ["AWS_ACCOUNT_ID"] wf_name = os.environ["BC_WORKFLOW_NAME"] @@ -43,7 +24,7 @@ def abort_execution(failed_expressions: list) -> None: step_name = os.environ["BC_STEP_NAME"] execution_arn = f"arn:aws:states:{region}:{acct}:execution:{wf_name}:{exec_id}" - cause = "\n".join(["failed QC conditions:"] + failed_expressions) + cause = "failed QC conditions: " + "; ".join(failed_expressions) sfn = boto3.client("stepfunctions") sfn.stop_execution( @@ -78,10 +59,8 @@ def run_all_qc_checks(checks: list) -> Generator[str, None, None]: def do_checks(checks: list) -> None: if checks: logger.info("starting QC checks") - failures = list(run_all_qc_checks(checks)) - if failures: - logger.warning(f"aborting workflow execution") - abort_execution(failures) + if (failures := list(run_all_qc_checks(checks))): + raise QCFailure("QC checks failed", failures) logger.info("QC checks finished") else: logger.info("no QC checks requested") diff --git a/bclaw_runner/src/runner/repo.py b/bclaw_runner/src/runner/repo.py index 837aed9..200c683 100644 --- a/bclaw_runner/src/runner/repo.py +++ b/bclaw_runner/src/runner/repo.py @@ -15,6 +15,10 @@ logger = logging.getLogger(__name__) +class SkipExecution(Exception): + pass + + def _file_metadata(): ret = {"execution_id": os.environ.get("BC_EXECUTION_ID", "undefined")} return ret @@ -80,7 +84,29 @@ def _s3_file_exists(self, key: str) -> bool: else: raise - def files_exist(self, filenames: List[str]) -> bool: + + def check_files_exist(self, filenames: List[str]) -> None: + """ + Raises SkipExecution if this step has been run before + """ + # this is for backward compatibility. Note that if you have a step that produces + # no outputs (i.e. being run for side effects only), it will always be skipped + # if run with skip_if_files_exist + # if len(filenames) == 0: + # raise SkipExecution("found output files; skipping") + + # there's no way to know if all the files included in a glob were uploaded in + # a previous run, so always rerun to be safe + if any(_is_glob(f) for f in filenames): + return + + # note: all([]) = True + keys = (self.qualify(os.path.basename(f)) for f in filenames) + if all(self._s3_file_exists(k) for k in keys): + raise SkipExecution("found output files; skipping") + + + def files_exist0(self, filenames: List[str]) -> bool: # this is for backward compatibility. Note that if you have a step that produces # no outputs (i.e. being run for side effects only), it will always be skipped # if run with skip_if_files_exist @@ -96,6 +122,7 @@ def files_exist(self, filenames: List[str]) -> bool: ret = all(self._s3_file_exists(k) for k in keys) return ret + def _inputerator(self, input_spec: Dict[str, str]) -> Generator[str, None, None]: for symbolic_name, filename in input_spec.items(): optional = symbolic_name.endswith("?") @@ -170,7 +197,27 @@ def upload_outputs(self, output_spec: Dict[str, str]) -> None: result = list(executor.map(self._upload_that, self._outputerator(output_spec.values()))) logger.info(f"{len(result)} files uploaded") - def check_for_previous_run(self) -> bool: + def check_for_previous_run(self) -> None: + """ + Raises SkipExecution if this step has been run before + """ + try: + result = self._s3_file_exists(self.qualify(self.run_status_obj)) + except Exception: + logger.warning("unable to query previous run status, assuming none") + else: + if result: + raise SkipExecution("found previous run; skipping") + # pharque!!! + # try: + # if self._s3_file_exists(self.qualify(self.run_status_obj)): + # raise SkipExecution("found previous run; skipping") + # # pharque!!! + # except Exception: + # logger.warning("unable to query previous run status, assuming none") + + + def check_for_previous_run0(self) -> bool: """ Returns: True if this step has been run before diff --git a/bclaw_runner/src/runner/runner_main.py b/bclaw_runner/src/runner/runner_main.py index 623c4b2..489ca7f 100644 --- a/bclaw_runner/src/runner/runner_main.py +++ b/bclaw_runner/src/runner/runner_main.py @@ -29,17 +29,16 @@ from .cache import get_reference_inputs from .custom_logs import LOGGING_CONFIG from .string_subs import substitute, substitute_image_tag -from .qc_check import do_checks -from .repo import Repository +from .qc_check import do_checks, abort_execution, QCFailure +from .repo import Repository, SkipExecution from .tagging import tag_this_instance from .termination import spot_termination_checker -from .workspace import workspace, write_job_data_file, run_commands +from .workspace import workspace, write_job_data_file, run_commands, UserCommandsFailed logging.config.dictConfig(LOGGING_CONFIG) logger = logging.getLogger(__name__) - def main(commands: List[str], image: str, inputs: Dict[str, str], @@ -49,6 +48,80 @@ def main(commands: List[str], repo_path: str, shell: str, skip: str) -> int: + exit_code = 0 + try: + repo = Repository(repo_path) + + if skip == "rerun": + repo.check_for_previous_run() + elif skip == "output": + repo.check_files_exist(list(outputs.values())) + + repo.clear_run_status() + + job_data_obj = repo.read_job_data() + + jobby_commands = substitute(commands, job_data_obj) + jobby_inputs = substitute(inputs, job_data_obj) + jobby_outputs = substitute(outputs, job_data_obj) + jobby_references = substitute(references, job_data_obj) + + jobby_image = substitute_image_tag(image, job_data_obj) + + with workspace() as wrk: + # download references, link to workspace + local_references = get_reference_inputs(jobby_references) + + # download inputs -> returns local filenames + local_inputs = repo.download_inputs(jobby_inputs) + local_outputs = jobby_outputs + + subbed_commands = substitute(jobby_commands, + local_inputs | + local_outputs | + local_references) + + local_job_data = write_job_data_file(job_data_obj, wrk) + + try: + run_commands(jobby_image, subbed_commands, wrk, local_job_data, shell) + do_checks(qc) + + finally: + repo.upload_outputs(jobby_outputs) + + except UserCommandsFailed as uce: + exit_code = uce.exit_code + logger.error(str(uce)) + + except QCFailure as qcf: + logger.error(str(qcf)) + abort_execution(qcf.failures) + + except SkipExecution as se: + logger.info(str(se)) + pass + + except Exception as e: + logger.exception("bclaw_runner error: ") + exit_code = 255 + + else: + repo.put_run_status() + logger.info("runner finished") + + return exit_code + + +def main0(commands: List[str], + image: str, + inputs: Dict[str, str], + outputs: Dict[str, str], + qc: List[dict], + references: Dict[str, str], + repo_path: str, + shell: str, + skip: str) -> int: repo = Repository(repo_path) @@ -126,7 +199,7 @@ def cli() -> int: with spot_termination_checker(): args = docopt(__doc__, version=os.environ["BC_VERSION"]) - logger.info(f"{args = }") + logger.info(f"{args=}") commands = json.loads(args["--cmd"]) image = args["--image"] diff --git a/bclaw_runner/src/runner/workspace.py b/bclaw_runner/src/runner/workspace.py index dc0d3e8..8740bc4 100644 --- a/bclaw_runner/src/runner/workspace.py +++ b/bclaw_runner/src/runner/workspace.py @@ -10,6 +10,12 @@ logger = logging.getLogger(__name__) +class UserCommandsFailed(Exception): + def __init__(self, message, exit_code): + super().__init__(message) + self.exit_code = exit_code + + @contextmanager def workspace() -> str: orig_path = os.getcwd() @@ -34,7 +40,7 @@ def write_job_data_file(job_data: dict, dest_dir: str) -> str: return fp.name -def run_commands(image_tag: str, commands: list, work_dir: str, job_data_file: str, shell_opt: str) -> int: +def run_commands(image_tag: str, commands: list, work_dir: str, job_data_file: str, shell_opt: str) -> None: script_file = "_commands.sh" with open(script_file, "w") as fp: @@ -53,6 +59,7 @@ def run_commands(image_tag: str, commands: list, work_dir: str, job_data_file: s os.chmod(script_file, 0o700) command = f"{shell_cmd} {script_file}" - ret = run_child_container(image_tag, command, work_dir, job_data_file) - - return ret + if (exit_code := run_child_container(image_tag, command, work_dir, job_data_file)) == 0: + logger.info("command block succeeded") + else: + raise UserCommandsFailed(f"command block failed with exit code {exit_code}", exit_code) diff --git a/bclaw_runner/tests/test_qc_check.py b/bclaw_runner/tests/test_qc_check.py index a27154d..b4d6674 100644 --- a/bclaw_runner/tests/test_qc_check.py +++ b/bclaw_runner/tests/test_qc_check.py @@ -4,7 +4,7 @@ import moto import pytest -from ..src.runner.qc_check import abort_execution, run_one_qc_check, run_all_qc_checks, do_checks +from ..src.runner.qc_check import abort_execution, run_one_qc_check, run_all_qc_checks, do_checks, QCFailure QC_DATA_1 = { "a": 1, @@ -43,56 +43,6 @@ def mock_qc_data_files(mocker, request): ret.side_effect = [qc_file1.return_value, qc_file2.return_value] -# def test_run_qc_checks(mock_qc_data_files, mocker): -# mock_run_qc_check = mocker.patch("bclaw_runner.src.runner.qc_check.run_qc_check") -# -# spec = [ -# { -# "qc_result_file": "fake1", -# "stop_early_if": [ -# "b == 1", -# ], -# }, -# { -# "qc_result_file": "fake2", -# "stop_early_if": [ -# "x == 99", -# "y == 98", -# ], -# }, -# ] -# run_qc_checks(spec) -# result = mock_run_qc_check.call_args_list -# expect = [ -# mocker.call(QC_DATA_1, "b == 1"), -# mocker.call(QC_DATA_2, "x == 99"), -# mocker.call(QC_DATA_2, "y == 98"), -# ] -# assert result == expect - - -# def test_run_qc_checks_empty(mock_qc_data_files, mocker): -# mock_run_qc_check = mocker.patch("bclaw_runner.src.runner.qc_check.run_qc_check") -# run_qc_checks([]) -# mock_run_qc_check.assert_not_called() - - -# @pytest.mark.parametrize("expression, expect_abort", [ -# ("x == 1", False), -# ("x > 1", True), -# ]) -# def test_run_qc_check(expression, expect_abort, mocker): -# mock_abort_execution = mocker.patch("bclaw_runner.src.runner.qc_check.abort_execution") -# -# qc_data = {"x": 1} -# run_qc_check(qc_data, expression) -# -# if expect_abort: -# mock_abort_execution.assert_not_called() -# else: -# mock_abort_execution.assert_called() - - def test_abort_execution(mock_state_machine, monkeypatch): sfn = boto3.client("stepfunctions", region_name="us-east-1") sfn_execution = sfn.start_execution( @@ -145,14 +95,14 @@ def test_run_all_qc_checks(fake1_cond, fake2_cond, expect, mock_qc_data_files): assert result == expect -@pytest.mark.parametrize("fake1_cond, fake2_cond, expect", [ +@pytest.mark.parametrize("fake1_cond, fake2_cond, expect_qc_fail", [ (None, None, False), # no checks (["a>1"], ["x<99"], False), # all pass (["a>1", "b==2"], ["y<98"], True), # one fail (["b==1"], ["x==99", "y==98"], True), # multi fail (["a==1", "b==2"], ["x==99", "y==98"], True), # all fail ]) -def test_do_checks(fake1_cond, fake2_cond, expect, mock_qc_data_files, mocker): +def test_do_checks(fake1_cond, fake2_cond, expect_qc_fail, mock_qc_data_files, mocker): mock_abort_execution = mocker.patch("bclaw_runner.src.runner.qc_check.abort_execution") if fake1_cond is None: @@ -169,8 +119,9 @@ def test_do_checks(fake1_cond, fake2_cond, expect, mock_qc_data_files, mocker): }, ] - do_checks(spec) - if expect: - mock_abort_execution.assert_called_once() + if expect_qc_fail: + # todo: should probably check the contents of qcf.failures + with pytest.raises(QCFailure) as qcf: + do_checks(spec) else: - mock_abort_execution.assert_not_called() + do_checks(spec) diff --git a/bclaw_runner/tests/test_repo.py b/bclaw_runner/tests/test_repo.py index 4f2b85f..2ca89f5 100644 --- a/bclaw_runner/tests/test_repo.py +++ b/bclaw_runner/tests/test_repo.py @@ -7,7 +7,7 @@ import moto import pytest -from ..src.runner.repo import _is_glob, _expand_s3_glob, Repository +from ..src.runner.repo import _is_glob, _expand_s3_glob, Repository, SkipExecution TEST_BUCKET = "test-bucket" JOB_DATA = {"job": "data"} @@ -132,17 +132,20 @@ def test_s3_file_exists(monkeypatch, key, expect, mock_buckets): assert result == expect -@pytest.mark.parametrize("files, expect", [ +@pytest.mark.parametrize("files, expect_skip", [ (["file1", "file2", "subdir/file3"], True), (["file1", "file99", "subdir/file3"], False), (["file1", "file*", "subdir/file3"], False), ([], True), ]) -def test_files_exist(monkeypatch, mock_buckets, files, expect): +def test_check_files_exist(monkeypatch, mock_buckets, files, expect_skip): monkeypatch.setenv("BC_STEP_NAME", "test_step") repo = Repository(f"s3://{TEST_BUCKET}/repo/path") - result = repo.files_exist(files) - assert result == expect + if expect_skip: + with pytest.raises(SkipExecution): + repo.check_files_exist(files) + else: + repo.check_files_exist(files) def test_inputerator(monkeypatch, mock_buckets): @@ -440,15 +443,18 @@ def test_upload_outputs_empty_outputs(monkeypatch, mock_buckets): repo.upload_outputs(file_spec) -@pytest.mark.parametrize("step_name, expect", [ +@pytest.mark.parametrize("step_name, expect_skip", [ ("test_step", True), ("non_step", False), ]) -def test_check_for_previous_run(monkeypatch, mock_buckets, step_name, expect): +def test_check_for_previous_run(monkeypatch, mock_buckets, step_name, expect_skip): monkeypatch.setenv("BC_STEP_NAME", step_name) repo = Repository(f"s3://{TEST_BUCKET}/repo/path") - result = repo.check_for_previous_run() - assert result == expect + if expect_skip: + with pytest.raises(SkipExecution): + repo.check_for_previous_run() + else: + repo.check_for_previous_run() def failing_thing(*args, **kwargs): @@ -459,21 +465,19 @@ def test_check_for_previous_run_fail(monkeypatch, mock_buckets, caplog): monkeypatch.setenv("BC_STEP_NAME", "nothing") repo = Repository("s3://non_bucket/repo/path") monkeypatch.setattr(repo, "_s3_file_exists", failing_thing) - result = repo.check_for_previous_run() + repo.check_for_previous_run() assert "unable to query previous run status" in caplog.text - assert result == False -@pytest.mark.parametrize("step_name, expect", [ - ("test_step", True), - ("un_step", False), +@pytest.mark.parametrize("step_name", [ + "test_step", + "un_step", ]) -def test_clear_run_status(monkeypatch, mock_buckets, step_name, expect): +def test_clear_run_status(monkeypatch, mock_buckets, step_name): monkeypatch.setenv("BC_STEP_NAME", step_name) repo = Repository(f"s3://{TEST_BUCKET}/repo/path") - assert repo.check_for_previous_run() == expect repo.clear_run_status() - assert repo.check_for_previous_run() is False + repo.check_for_previous_run() def test_clear_run_status_fail(monkeypatch, mock_buckets, caplog): @@ -487,9 +491,10 @@ def test_clear_run_status_fail(monkeypatch, mock_buckets, caplog): def test_put_run_status(monkeypatch, mock_buckets): monkeypatch.setenv("BC_STEP_NAME", "test_step_two") repo = Repository(f"s3://{TEST_BUCKET}/repo/path") - assert repo.check_for_previous_run() is False + repo.check_for_previous_run() repo.put_run_status() - assert repo.check_for_previous_run() is True + with pytest.raises(SkipExecution): + repo.check_for_previous_run() def test_put_run_status_fail(monkeypatch, mock_buckets, caplog): diff --git a/bclaw_runner/tests/test_runner_main.py b/bclaw_runner/tests/test_runner_main.py index 9bb6c40..8297c4c 100644 --- a/bclaw_runner/tests/test_runner_main.py +++ b/bclaw_runner/tests/test_runner_main.py @@ -286,6 +286,7 @@ def test_main_skip(monkeypatch, tmp_path, mock_bucket, skip, expect): monkeypatch.setenv("BC_SCRATCH_PATH", str(tmp_path)) monkeypatch.setattr(runner.workspace, "run_child_container", fake_container) + # note: this only tests skip = output with empty outputs references = {} inputs = {} outputs = {} diff --git a/bclaw_runner/tests/test_workspace.py b/bclaw_runner/tests/test_workspace.py index b24606e..83e0543 100644 --- a/bclaw_runner/tests/test_workspace.py +++ b/bclaw_runner/tests/test_workspace.py @@ -5,7 +5,7 @@ import pytest from ..src import runner -from ..src.runner.workspace import workspace, write_job_data_file, run_commands, run_commands +from ..src.runner.workspace import workspace, write_job_data_file, run_commands, run_commands, UserCommandsFailed def test_workspace(monkeypatch, tmp_path): @@ -43,7 +43,7 @@ def fake_container(image_tag: str, command: str, work_dir: str, job_data_file) - return response.returncode -def test_run_commands(tmp_path, monkeypatch): +def test_run_commands(tmp_path, monkeypatch, caplog): monkeypatch.setattr(runner.workspace, "run_child_container", fake_container) f = tmp_path / "test_success.out" @@ -56,7 +56,7 @@ def test_run_commands(tmp_path, monkeypatch): os.chdir(tmp_path) response = run_commands("fake/image:tag", commands, tmp_path, "fake/job/data/file.json", "sh") - assert response == 0 + assert "command block succeeded" in caplog.text assert f.exists() with f.open() as fp: lines = fp.readlines() @@ -74,8 +74,9 @@ def test_exit_on_command_fail1(tmp_path, monkeypatch): ] os.chdir(tmp_path) - response = run_commands("fake/image:tag", commands, tmp_path, "fake/job/data/file.json", "sh") - assert response != 0 + with pytest.raises(UserCommandsFailed) as ucf: + run_commands("fake/image:tag", commands, tmp_path, "fake/job/data/file.json", "sh") + assert ucf.value.exit_code != 0 assert f.exists() with f.open() as fp: @@ -94,8 +95,9 @@ def test_exit_on_undef_var1(tmp_path, monkeypatch): ] os.chdir(tmp_path) - response = run_commands("fake/image:tag", commands, tmp_path, "fake/job/data/file.json", "sh") - assert response != 0 + with pytest.raises(UserCommandsFailed) as ucf: + run_commands("fake/image:tag", commands, tmp_path, "fake/job/data/file.json", "sh") + assert ucf.value.exit_code != 0 assert f.exists() with f.open() as fp: diff --git a/cloudformation/bc_core.yaml b/cloudformation/bc_core.yaml index c13f43b..80bc426 100644 --- a/cloudformation/bc_core.yaml +++ b/cloudformation/bc_core.yaml @@ -302,7 +302,6 @@ Resources: LOGGING_DESTINATION_ARN: !Ref LoggingDestinationArn ON_DEMAND_GPU_QUEUE_ARN: !GetAtt OnDemandGpuQueueStack.Outputs.BatchQueueArn ON_DEMAND_QUEUE_ARN: !GetAtt OnDemandQueueStack.Outputs.BatchQueueArn -# QC_CHECKER_LAMBDA_ARN: !Ref QCCheckerLambda.Version RUNNER_REPO_URI: !GetAtt RunnerRepo.RepositoryUri SCATTER_INIT_LAMBDA_ARN: !Ref ScatterInitLambda.Version SCATTER_LAMBDA_ARN: !Ref ScatterLambda.Version @@ -385,7 +384,7 @@ Resources: Handler: register.lambda_handler Runtime: python3.10 CodeUri: lambda/src/job_def - # do not enable AutoPublishAlias + # do not enable AutoPublishAlias for custom resource lambdas # https://advancedweb.hu/custom-resources-in-cloudformation-templates-lessons-learned/#cant-change-the-servicetoken Environment: Variables: