diff --git a/datajob/stepfunctions/stepfunctions_workflow.py b/datajob/stepfunctions/stepfunctions_workflow.py index 487cc96..a38b015 100644 --- a/datajob/stepfunctions/stepfunctions_workflow.py +++ b/datajob/stepfunctions/stepfunctions_workflow.py @@ -218,6 +218,7 @@ def __enter__(self): def __exit__(self, exc_type, exc_value, traceback) -> None: """steps we have to do when exiting the context manager.""" + self.build_workflow() _set_workflow(None) logger.info(f"step functions workflow {self.unique_name} created") @@ -270,4 +271,3 @@ def _get_workflow(): def connect(self, other: DataJobBase) -> None: work_flow = _get_workflow() work_flow.directed_graph[other].add(self) - work_flow.build_workflow() diff --git a/datajob_tests/stepfunctions/test_stepfunctions_workflow.py b/datajob_tests/stepfunctions/test_stepfunctions_workflow.py index 5dd9e9e..7038bec 100644 --- a/datajob_tests/stepfunctions/test_stepfunctions_workflow.py +++ b/datajob_tests/stepfunctions/test_stepfunctions_workflow.py @@ -166,11 +166,16 @@ def test_update_stepfunctions_continuously(self): test written based on ticket https://github.com/vincentclaes/datajob/issues/116 + + Update: + this continous update causes duplicate states. removing it for now. + https://github.com/vincentclaes/datajob/pull/126 """ task1 = stepfunctions_workflow.task(SomeMockedClass("task1")) task2 = stepfunctions_workflow.task(SomeMockedClass("task2")) task3 = stepfunctions_workflow.task(SomeMockedClass("task3")) + task4 = stepfunctions_workflow.task(SomeMockedClass("task4")) djs = DataJobStack( scope=self.app, @@ -181,13 +186,12 @@ def test_update_stepfunctions_continuously(self): account="3098726354", ) with StepfunctionsWorkflow(djs, "some-name") as a_step_functions_workflow: - self.assertIsNone(a_step_functions_workflow.workflow) - self.assertIsNone(a_step_functions_workflow.chain_of_tasks) task1 >> task2 - self.assertIsNotNone(a_step_functions_workflow.workflow) - self.assertEqual(len(a_step_functions_workflow.chain_of_tasks.steps), 2) task2 >> task3 - self.assertEqual(len(a_step_functions_workflow.chain_of_tasks.steps), 3) + self.assertIsNone(a_step_functions_workflow.workflow) + self.assertIsNone(a_step_functions_workflow.chain_of_tasks) + self.assertIsNotNone(a_step_functions_workflow.workflow) + self.assertEqual(len(a_step_functions_workflow.chain_of_tasks.steps), 3) expected_workflow_definition = { "StartAt": "task1",