Skip to content
This repository has been archived by the owner on Jun 17, 2024. It is now read-only.

Commit

Permalink
update tests (#93)
Browse files Browse the repository at this point in the history
Signed-off-by: Samhita Alla <[email protected]>
  • Loading branch information
samhita-alla authored Nov 8, 2023
1 parent 76b0a51 commit 71567e6
Showing 1 changed file with 39 additions and 19 deletions.
58 changes: 39 additions & 19 deletions boilerplate/flyte/end2end/run-tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,15 +50,17 @@
("basics.named_outputs.simple_wf_with_named_outputs", {}),
# # Getting a 403 for the wikipedia image
# # ("basics.reference_task.wf", {}),
("data_types_and_io.custom_objects.wf", {"x": 10, "y": 20}),
("data_types_and_io.dataclass.dataclass_wf", {"x": 10, "y": 20}),
# Enums are not supported in flyteremote
# ("type_system.enums.enum_wf", {"c": "red"}),
("data_types_and_io.schema.df_wf", {"a": 42}),
("data_types_and_io.typed_schema.wf", {}),
("data_types_and_io.structured_dataset.simple_sd_wf", {"a": 42}),
# ("my.imperative.workflow.example", {"in1": "hello", "in2": "foo"}),
],
"integrations-k8s-spark": [
("k8s_spark_plugin.pyspark_pi.my_spark", {"triggered_date": datetime.datetime.now()}),
(
"k8s_spark_plugin.pyspark_pi.my_spark",
{"triggered_date": datetime.datetime.now()},
),
],
"integrations-kfpytorch": [
("kfpytorch_plugin.pytorch_mnist.pytorch_training_wf", {}),
Expand Down Expand Up @@ -90,25 +92,29 @@


def execute_workflow(
remote: FlyteRemote,
version,
workflow_name,
inputs,
cluster_pool_name: Optional[str] = None,
remote: FlyteRemote,
version,
workflow_name,
inputs,
cluster_pool_name: Optional[str] = None,
):
print(f"Fetching workflow={workflow_name} and version={version}")
wf = remote.fetch_workflow(name=workflow_name, version=version)
return remote.execute(wf, inputs=inputs, wait=False, cluster_pool=cluster_pool_name)


def executions_finished(executions_by_wfgroup: Dict[str, List[FlyteWorkflowExecution]]) -> bool:
def executions_finished(
executions_by_wfgroup: Dict[str, List[FlyteWorkflowExecution]]
) -> bool:
for executions in executions_by_wfgroup.values():
if not all([execution.is_done for execution in executions]):
return False
return True


def sync_executions(remote: FlyteRemote, executions_by_wfgroup: Dict[str, List[FlyteWorkflowExecution]]):
def sync_executions(
remote: FlyteRemote, executions_by_wfgroup: Dict[str, List[FlyteWorkflowExecution]]
):
try:
for executions in executions_by_wfgroup.values():
for execution in executions:
Expand Down Expand Up @@ -148,9 +154,13 @@ def schedule_workflow_groups(

# Wait for all executions to finish
attempt = 0
while attempt == 0 or (not executions_finished(executions_by_wfgroup) and attempt < MAX_ATTEMPTS):
while attempt == 0 or (
not executions_finished(executions_by_wfgroup) and attempt < MAX_ATTEMPTS
):
attempt += 1
print(f"Not all executions finished yet. Sleeping for some time, will check again in {WAIT_TIME}s")
print(
f"Not all executions finished yet. Sleeping for some time, will check again in {WAIT_TIME}s"
)
time.sleep(WAIT_TIME)
sync_executions(remote, executions_by_wfgroup)

Expand All @@ -166,9 +176,13 @@ def schedule_workflow_groups(
if len(non_succeeded_executions) != 0:
print(f"Failed executions for {wf_group}:")
for execution in non_succeeded_executions:
print(f" workflow={execution.spec.launch_plan.name}, execution_id={execution.id.name}")
print(
f" workflow={execution.spec.launch_plan.name}, execution_id={execution.id.name}"
)
if terminate_workflow_on_failure:
remote.terminate(execution, "aborting execution scheduled in functional test")
remote.terminate(
execution, "aborting execution scheduled in functional test"
)
# A workflow group succeeds iff all of its executions succeed
results[wf_group] = len(non_succeeded_executions) == 0
return results
Expand Down Expand Up @@ -200,15 +214,20 @@ def run(
# For a given release tag and priority, this function filters the workflow groups from the flytesnacks
# manifest file. For example, for the release tag "v0.2.224" and the priority "P0" it returns [ "core" ].
manifest_url = (
"https://raw.githubusercontent.com/flyteorg/flytesnacks/" f"{flytesnacks_release_tag}/flyte_tests_manifest.json"
"https://raw.githubusercontent.com/flyteorg/flytesnacks/"
f"{flytesnacks_release_tag}/flyte_tests_manifest.json"
)
r = requests.get(manifest_url)
parsed_manifest = r.json()
workflow_groups = []
workflow_groups = (
["lite"]
if "lite" in priorities
else [group["name"] for group in parsed_manifest if group["priority"] in priorities]
else [
group["name"]
for group in parsed_manifest
if group["priority"] in priorities
]
)

results = []
Expand Down Expand Up @@ -281,7 +300,7 @@ def run(
default="flytesnacks",
type=str,
is_flag=False,
help="Name of project to run functional tests on"
help="Name of project to run functional tests on",
)
@click.option(
"--test_project_domain",
Expand Down Expand Up @@ -309,7 +328,8 @@ def cli(
print(f"return_non_zero_on_failure={return_non_zero_on_failure}")
results = run(
flytesnacks_release_tag,
priorities, config_file,
priorities,
config_file,
terminate_workflow_on_failure,
test_project_name,
test_project_domain,
Expand Down

0 comments on commit 71567e6

Please sign in to comment.