From ddaf7b2165f5e0f986072a902184d619869ef661 Mon Sep 17 00:00:00 2001 From: Dan Davison Date: Fri, 17 Jan 2025 22:34:51 -0500 Subject: [PATCH 1/3] Failing test --- tests/worker/test_interceptor.py | 39 ++++++++++++++++++++++++++++++++ 1 file changed, 39 insertions(+) diff --git a/tests/worker/test_interceptor.py b/tests/worker/test_interceptor.py index a9e726d5..95b2c327 100644 --- a/tests/worker/test_interceptor.py +++ b/tests/worker/test_interceptor.py @@ -283,3 +283,42 @@ def pop_trace(name: str, filter: Optional[Callable[[Any], bool]] = None) -> Any: # Confirm no unexpected traces assert not interceptor_traces + + +class WorkflowInstanceAccessInterceptor(Interceptor): + def workflow_interceptor_class( + self, input: WorkflowInterceptorClassInput + ) -> Optional[Type[WorkflowInboundInterceptor]]: + return WorkflowInstanceAccessInboundInterceptor + + +class WorkflowInstanceAccessInboundInterceptor(WorkflowInboundInterceptor): + async def execute_workflow(self, input: ExecuteWorkflowInput) -> int: + # Return integer difference between ids of workflow instance obtained from workflow run method and + # from workflow.instance(). They should be the same, so the difference should be 0. + id_from_workflow_run_method = await super().execute_workflow(input) + id_from_workflow_instance_api = id(workflow.instance()) + return id_from_workflow_run_method - id_from_workflow_instance_api + + +@workflow.defn +class WorkflowInstanceAccessWorkflow: + @workflow.run + async def run(self) -> int: + return id(self) + + +async def test_workflow_instance_access_from_interceptor(client: Client): + task_queue = f"task_queue_{uuid.uuid4()}" + async with Worker( + client, + task_queue=task_queue, + workflows=[WorkflowInstanceAccessWorkflow], + interceptors=[WorkflowInstanceAccessInterceptor()], + ): + difference = await client.execute_workflow( + WorkflowInstanceAccessWorkflow.run, + id=f"workflow_{uuid.uuid4()}", + task_queue=task_queue, + ) + assert difference == 0 From def451b7d2ffe2b7225037955772e17a2dd63572 Mon Sep 17 00:00:00 2001 From: Dan Davison Date: Fri, 17 Jan 2025 22:44:12 -0500 Subject: [PATCH 2/3] Implement --- temporalio/worker/_workflow_instance.py | 3 +++ temporalio/workflow.py | 12 ++++++++++++ 2 files changed, 15 insertions(+) diff --git a/temporalio/worker/_workflow_instance.py b/temporalio/worker/_workflow_instance.py index c33796ef..5f24d684 100644 --- a/temporalio/worker/_workflow_instance.py +++ b/temporalio/worker/_workflow_instance.py @@ -1023,6 +1023,9 @@ def workflow_get_update_validator(self, name: Optional[str]) -> Optional[Callabl def workflow_info(self) -> temporalio.workflow.Info: return self._outbound.info() + def workflow_instance(self) -> Any: + return self._object + def workflow_is_continue_as_new_suggested(self) -> bool: return self._continue_as_new_suggested diff --git a/temporalio/workflow.py b/temporalio/workflow.py index 6351bace..bf57e928 100644 --- a/temporalio/workflow.py +++ b/temporalio/workflow.py @@ -625,6 +625,9 @@ def workflow_get_update_validator( @abstractmethod def workflow_info(self) -> Info: ... + @abstractmethod + def workflow_instance(self) -> Any: ... + @abstractmethod def workflow_is_continue_as_new_suggested(self) -> bool: ... @@ -818,6 +821,15 @@ def info() -> Info: return _Runtime.current().workflow_info() +def instance() -> Any: + """Current workflow's instance. + + Returns: + The currently running workflow instance. + """ + return _Runtime.current().workflow_instance() + + def memo() -> Mapping[str, Any]: """Current workflow's memo values, converted without type hints. From 87e80fe8e7cbac5db64982b3e46a89da9942ff39 Mon Sep 17 00:00:00 2001 From: Dan Davison Date: Thu, 23 Jan 2025 08:56:35 -0500 Subject: [PATCH 3/3] Add assertion that instance is non-None --- tests/worker/test_interceptor.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/worker/test_interceptor.py b/tests/worker/test_interceptor.py index 95b2c327..1392cd35 100644 --- a/tests/worker/test_interceptor.py +++ b/tests/worker/test_interceptor.py @@ -296,8 +296,10 @@ class WorkflowInstanceAccessInboundInterceptor(WorkflowInboundInterceptor): async def execute_workflow(self, input: ExecuteWorkflowInput) -> int: # Return integer difference between ids of workflow instance obtained from workflow run method and # from workflow.instance(). They should be the same, so the difference should be 0. + from_workflow_instance_api = workflow.instance() + assert from_workflow_instance_api is not None + id_from_workflow_instance_api = id(from_workflow_instance_api) id_from_workflow_run_method = await super().execute_workflow(input) - id_from_workflow_instance_api = id(workflow.instance()) return id_from_workflow_run_method - id_from_workflow_instance_api