Commit: refactor runner main

jack-e-tabaska committed Jun 10, 2024
1 parent 62b5365 commit 5ad8ba1

Showing 9 changed files with 191 additions and 127 deletions.
41 changes: 10 additions & 31 deletions bclaw_runner/src/runner/qc_check.py
@@ -5,45 +5,26 @@

import boto3

logger = logging.getLogger(__name__)

logger = logging.getLogger(__name__)

# def run_qc_checks(checks: list) -> None:
# if checks:
# logger.info("starting QC checks")
# for item in checks:
# qc_file = item["qc_result_file"]
# logger.info(f"{qc_file=}")
#
# with open(qc_file) as fp:
# qc_data = json.load(fp)
#
# for qc_expression in item["stop_early_if"]:
# run_qc_check(qc_data, qc_expression)
#
# logger.info("QC checks finished")
# else:
# logger.info("no QC checks requested")


# def run_qc_check(qc_data: dict, qc_expression: str) -> None:
# result = eval(qc_expression, globals(), qc_data)
# if result:
# logger.warning(f"failed QC check: {qc_expression}; aborting workflow execution")
# abort_execution(qc_expression)
# else:
# logger.info(f"passed QC check: {qc_expression}")
class QCFailure(Exception):
def __init__(self, message: str, failures: list):
super().__init__(message)
self.failures = failures


def abort_execution(failed_expressions: list) -> None:
logger.warning("aborting workflow execution")

region = os.environ["AWS_DEFAULT_REGION"]
acct = os.environ["AWS_ACCOUNT_ID"]
wf_name = os.environ["BC_WORKFLOW_NAME"]
exec_id = os.environ["BC_EXECUTION_ID"]
step_name = os.environ["BC_STEP_NAME"]
execution_arn = f"arn:aws:states:{region}:{acct}:execution:{wf_name}:{exec_id}"

cause = "\n".join(["failed QC conditions:"] + failed_expressions)
cause = "failed QC conditions: " + "; ".join(failed_expressions)

sfn = boto3.client("stepfunctions")
sfn.stop_execution(
@@ -78,10 +59,8 @@ def run_all_qc_checks(checks: list) -> Generator[str, None, None]:
def do_checks(checks: list) -> None:
if checks:
logger.info("starting QC checks")
failures = list(run_all_qc_checks(checks))
if failures:
logger.warning(f"aborting workflow execution")
abort_execution(failures)
if (failures := list(run_all_qc_checks(checks))):
raise QCFailure("QC checks failed", failures)
logger.info("QC checks finished")
else:
logger.info("no QC checks requested")
51 changes: 49 additions & 2 deletions bclaw_runner/src/runner/repo.py
@@ -15,6 +15,10 @@
logger = logging.getLogger(__name__)


class SkipExecution(Exception):
pass


def _file_metadata():
ret = {"execution_id": os.environ.get("BC_EXECUTION_ID", "undefined")}
return ret
@@ -80,7 +84,29 @@ def _s3_file_exists(self, key: str) -> bool:
else:
raise

def files_exist(self, filenames: List[str]) -> bool:

def check_files_exist(self, filenames: List[str]) -> None:
"""
Raises SkipExecution if this step has been run before
"""
# this is for backward compatibility. Note that if you have a step that produces
# no outputs (i.e. being run for side effects only), it will always be skipped
# if run with skip_if_files_exist
# if len(filenames) == 0:
# raise SkipExecution("found output files; skipping")

# there's no way to know if all the files included in a glob were uploaded in
# a previous run, so always rerun to be safe
if any(_is_glob(f) for f in filenames):
return

# note: all([]) = True
keys = (self.qualify(os.path.basename(f)) for f in filenames)
if all(self._s3_file_exists(k) for k in keys):
raise SkipExecution("found output files; skipping")


def files_exist0(self, filenames: List[str]) -> bool:
# this is for backward compatibility. Note that if you have a step that produces
# no outputs (i.e. being run for side effects only), it will always be skipped
# if run with skip_if_files_exist
@@ -96,6 +122,7 @@ def files_exist(self, filenames: List[str]) -> bool:
ret = all(self._s3_file_exists(k) for k in keys)
return ret


def _inputerator(self, input_spec: Dict[str, str]) -> Generator[str, None, None]:
for symbolic_name, filename in input_spec.items():
optional = symbolic_name.endswith("?")
@@ -170,7 +197,27 @@ def upload_outputs(self, output_spec: Dict[str, str]) -> None:
result = list(executor.map(self._upload_that, self._outputerator(output_spec.values())))
logger.info(f"{len(result)} files uploaded")

def check_for_previous_run(self) -> bool:
def check_for_previous_run(self) -> None:
"""
Raises SkipExecution if this step has been run before
"""
try:
result = self._s3_file_exists(self.qualify(self.run_status_obj))
except Exception:
logger.warning("unable to query previous run status, assuming none")
else:
if result:
raise SkipExecution("found previous run; skipping")
# pharque!!!
# try:
# if self._s3_file_exists(self.qualify(self.run_status_obj)):
# raise SkipExecution("found previous run; skipping")
# # pharque!!!
# except Exception:
# logger.warning("unable to query previous run status, assuming none")


def check_for_previous_run0(self) -> bool:
"""
Returns:
True if this step has been run before
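
Both check_files_exist and files_exist0 rely on an _is_glob helper that sits outside these hunks. A hypothetical sketch of such a check (the real helper in repo.py may differ):

def _is_glob(filename: str) -> bool:
    # hypothetical: treat any shell wildcard character as a glob pattern
    return any(c in filename for c in "*?[")

Raising SkipExecution instead of returning a bool lets runner_main.py treat "skip this step" as just another exceptional exit path.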
83 changes: 78 additions & 5 deletions bclaw_runner/src/runner/runner_main.py
@@ -29,17 +29,16 @@
from .cache import get_reference_inputs
from .custom_logs import LOGGING_CONFIG
from .string_subs import substitute, substitute_image_tag
from .qc_check import do_checks
from .repo import Repository
from .qc_check import do_checks, abort_execution, QCFailure
from .repo import Repository, SkipExecution
from .tagging import tag_this_instance
from .termination import spot_termination_checker
from .workspace import workspace, write_job_data_file, run_commands
from .workspace import workspace, write_job_data_file, run_commands, UserCommandsFailed


logging.config.dictConfig(LOGGING_CONFIG)
logger = logging.getLogger(__name__)


def main(commands: List[str],
image: str,
inputs: Dict[str, str],
@@ -49,6 +48,80 @@ def main(commands: List[str],
repo_path: str,
shell: str,
skip: str) -> int:
exit_code = 0
try:
repo = Repository(repo_path)

if skip == "rerun":
repo.check_for_previous_run()
elif skip == "output":
repo.check_files_exist(list(outputs.values()))

repo.clear_run_status()

job_data_obj = repo.read_job_data()

jobby_commands = substitute(commands, job_data_obj)
jobby_inputs = substitute(inputs, job_data_obj)
jobby_outputs = substitute(outputs, job_data_obj)
jobby_references = substitute(references, job_data_obj)

jobby_image = substitute_image_tag(image, job_data_obj)

with workspace() as wrk:
# download references, link to workspace
local_references = get_reference_inputs(jobby_references)

# download inputs -> returns local filenames
local_inputs = repo.download_inputs(jobby_inputs)
local_outputs = jobby_outputs

subbed_commands = substitute(jobby_commands,
local_inputs |
local_outputs |
local_references)

local_job_data = write_job_data_file(job_data_obj, wrk)

try:
run_commands(jobby_image, subbed_commands, wrk, local_job_data, shell)
do_checks(qc)

finally:
repo.upload_outputs(jobby_outputs)

except UserCommandsFailed as uce:
exit_code = uce.exit_code
logger.error(str(uce))

except QCFailure as qcf:
logger.error(str(qcf))
abort_execution(qcf.failures)

except SkipExecution as se:
logger.info(str(se))
pass

except Exception as e:
logger.exception("bclaw_runner error: ")
exit_code = 255

else:
repo.put_run_status()
logger.info("runner finished")

return exit_code


def main0(commands: List[str],
image: str,
inputs: Dict[str, str],
outputs: Dict[str, str],
qc: List[dict],
references: Dict[str, str],
repo_path: str,
shell: str,
skip: str) -> int:

repo = Repository(repo_path)

@@ -126,7 +199,7 @@ def cli() -> int:
with spot_termination_checker():
args = docopt(__doc__, version=os.environ["BC_VERSION"])

logger.info(f"{args = }")
logger.info(f"{args=}")

commands = json.loads(args["--cmd"])
image = args["--image"]
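
The refactored main() funnels every outcome through exceptions and maps them to exit codes. A condensed sketch of that mapping — do_work stands in for the real body of main(), and the import paths follow the module paths the tests patch:

from bclaw_runner.src.runner.qc_check import QCFailure, abort_execution
from bclaw_runner.src.runner.repo import SkipExecution
from bclaw_runner.src.runner.workspace import UserCommandsFailed

def run_step(do_work) -> int:
    exit_code = 0
    try:
        do_work()                      # stands in for the body of main()
    except UserCommandsFailed as uce:
        exit_code = uce.exit_code      # propagate the container's exit code
    except QCFailure as qcf:
        abort_execution(qcf.failures)  # stop the Step Functions execution; exit code stays 0
    except SkipExecution:
        pass                           # previous run or outputs found; nothing to do
    except Exception:
        exit_code = 255                # any other runner error
    return exit_code

Only when no exception is raised does the real main() reach its else branch and call repo.put_run_status().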
15 changes: 11 additions & 4 deletions bclaw_runner/src/runner/workspace.py
@@ -10,6 +10,12 @@
logger = logging.getLogger(__name__)


class UserCommandsFailed(Exception):
def __init__(self, message, exit_code):
super().__init__(message)
self.exit_code = exit_code


@contextmanager
def workspace() -> str:
orig_path = os.getcwd()
@@ -34,7 +40,7 @@ def write_job_data_file(job_data: dict, dest_dir: str) -> str:
return fp.name


def run_commands(image_tag: str, commands: list, work_dir: str, job_data_file: str, shell_opt: str) -> int:
def run_commands(image_tag: str, commands: list, work_dir: str, job_data_file: str, shell_opt: str) -> None:
script_file = "_commands.sh"

with open(script_file, "w") as fp:
@@ -53,6 +59,7 @@ def run_commands(image_tag: str, commands: list, work_dir: str, job_data_file: s
os.chmod(script_file, 0o700)
command = f"{shell_cmd} {script_file}"

ret = run_child_container(image_tag, command, work_dir, job_data_file)

return ret
if (exit_code := run_child_container(image_tag, command, work_dir, job_data_file)) == 0:
logger.info("command block succeeded")
else:
raise UserCommandsFailed(f"command block failed with exit code {exit_code}", exit_code)
65 changes: 8 additions & 57 deletions bclaw_runner/tests/test_qc_check.py
@@ -4,7 +4,7 @@
import moto
import pytest

from ..src.runner.qc_check import abort_execution, run_one_qc_check, run_all_qc_checks, do_checks
from ..src.runner.qc_check import abort_execution, run_one_qc_check, run_all_qc_checks, do_checks, QCFailure

QC_DATA_1 = {
"a": 1,
@@ -43,56 +43,6 @@ def mock_qc_data_files(mocker, request):
ret.side_effect = [qc_file1.return_value, qc_file2.return_value]


# def test_run_qc_checks(mock_qc_data_files, mocker):
# mock_run_qc_check = mocker.patch("bclaw_runner.src.runner.qc_check.run_qc_check")
#
# spec = [
# {
# "qc_result_file": "fake1",
# "stop_early_if": [
# "b == 1",
# ],
# },
# {
# "qc_result_file": "fake2",
# "stop_early_if": [
# "x == 99",
# "y == 98",
# ],
# },
# ]
# run_qc_checks(spec)
# result = mock_run_qc_check.call_args_list
# expect = [
# mocker.call(QC_DATA_1, "b == 1"),
# mocker.call(QC_DATA_2, "x == 99"),
# mocker.call(QC_DATA_2, "y == 98"),
# ]
# assert result == expect


# def test_run_qc_checks_empty(mock_qc_data_files, mocker):
# mock_run_qc_check = mocker.patch("bclaw_runner.src.runner.qc_check.run_qc_check")
# run_qc_checks([])
# mock_run_qc_check.assert_not_called()


# @pytest.mark.parametrize("expression, expect_abort", [
# ("x == 1", False),
# ("x > 1", True),
# ])
# def test_run_qc_check(expression, expect_abort, mocker):
# mock_abort_execution = mocker.patch("bclaw_runner.src.runner.qc_check.abort_execution")
#
# qc_data = {"x": 1}
# run_qc_check(qc_data, expression)
#
# if expect_abort:
# mock_abort_execution.assert_not_called()
# else:
# mock_abort_execution.assert_called()


def test_abort_execution(mock_state_machine, monkeypatch):
sfn = boto3.client("stepfunctions", region_name="us-east-1")
sfn_execution = sfn.start_execution(
@@ -145,14 +95,14 @@ def test_run_all_qc_checks(fake1_cond, fake2_cond, expect, mock_qc_data_files):
assert result == expect


@pytest.mark.parametrize("fake1_cond, fake2_cond, expect", [
@pytest.mark.parametrize("fake1_cond, fake2_cond, expect_qc_fail", [
(None, None, False), # no checks
(["a>1"], ["x<99"], False), # all pass
(["a>1", "b==2"], ["y<98"], True), # one fail
(["b==1"], ["x==99", "y==98"], True), # multi fail
(["a==1", "b==2"], ["x==99", "y==98"], True), # all fail
])
def test_do_checks(fake1_cond, fake2_cond, expect, mock_qc_data_files, mocker):
def test_do_checks(fake1_cond, fake2_cond, expect_qc_fail, mock_qc_data_files, mocker):
mock_abort_execution = mocker.patch("bclaw_runner.src.runner.qc_check.abort_execution")

if fake1_cond is None:
@@ -169,8 +119,9 @@ def test_do_checks(fake1_cond, fake2_cond, expect, mock_qc_data_files, mocker):
},
]

do_checks(spec)
if expect:
mock_abort_execution.assert_called_once()
if expect_qc_fail:
# todo: should probably check the contents of qcf.failures
with pytest.raises(QCFailure) as qcf:
do_checks(spec)
else:
mock_abort_execution.assert_not_called()
do_checks(spec)
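
On the todo about checking qcf.failures: pytest.raises returns an ExceptionInfo whose .value attribute is the raised exception, so the assertion could look like the sketch below. It assumes the same module and imports as the test file above; run_all_qc_checks is patched so the sketch does not depend on the mock_qc_data_files fixture, and the expressions are illustrative:

def test_do_checks_reports_failures(mocker):
    mocker.patch("bclaw_runner.src.runner.qc_check.run_all_qc_checks",
                 return_value=["b==1", "x==99"])
    spec = [{"qc_result_file": "fake1", "stop_early_if": ["b==1", "x==99"]}]
    with pytest.raises(QCFailure) as qcf:
        do_checks(spec)
    # ExceptionInfo exposes the raised exception as .value
    assert qcf.value.failures == ["b==1", "x==99"]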