Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add workflow.instance() API for obtaining current workflow instance #739

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions temporalio/worker/_workflow_instance.py
Original file line number Diff line number Diff line change
Expand Up @@ -1023,6 +1023,9 @@ def workflow_get_update_validator(self, name: Optional[str]) -> Optional[Callabl
def workflow_info(self) -> temporalio.workflow.Info:
return self._outbound.info()

def workflow_instance(self) -> Any:
return self._object

def workflow_is_continue_as_new_suggested(self) -> bool:
return self._continue_as_new_suggested

Expand Down
12 changes: 12 additions & 0 deletions temporalio/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -625,6 +625,9 @@ def workflow_get_update_validator(
@abstractmethod
def workflow_info(self) -> Info: ...

@abstractmethod
def workflow_instance(self) -> Any: ...

@abstractmethod
def workflow_is_continue_as_new_suggested(self) -> bool: ...

Expand Down Expand Up @@ -818,6 +821,15 @@ def info() -> Info:
return _Runtime.current().workflow_info()


def instance() -> Any:
"""Current workflow's instance.

Returns:
The currently running workflow instance.
"""
return _Runtime.current().workflow_instance()


def memo() -> Mapping[str, Any]:
"""Current workflow's memo values, converted without type hints.

Expand Down
41 changes: 41 additions & 0 deletions tests/worker/test_interceptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,3 +283,44 @@ def pop_trace(name: str, filter: Optional[Callable[[Any], bool]] = None) -> Any:

# Confirm no unexpected traces
assert not interceptor_traces


class WorkflowInstanceAccessInterceptor(Interceptor):
def workflow_interceptor_class(
self, input: WorkflowInterceptorClassInput
) -> Optional[Type[WorkflowInboundInterceptor]]:
return WorkflowInstanceAccessInboundInterceptor


class WorkflowInstanceAccessInboundInterceptor(WorkflowInboundInterceptor):
async def execute_workflow(self, input: ExecuteWorkflowInput) -> int:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you confirm that the workflow.instance() is non-None when this interceptor starts? I read through the code and I suspect it is, just wanted to confirm.

Copy link
Contributor Author

@dandavison dandavison Jan 23, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this do what you're suggesting? 0b8a0bd

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, thanks!

# Return integer difference between ids of workflow instance obtained from workflow run method and
# from workflow.instance(). They should be the same, so the difference should be 0.
from_workflow_instance_api = workflow.instance()
assert from_workflow_instance_api is not None
id_from_workflow_instance_api = id(from_workflow_instance_api)
id_from_workflow_run_method = await super().execute_workflow(input)
return id_from_workflow_run_method - id_from_workflow_instance_api


@workflow.defn
class WorkflowInstanceAccessWorkflow:
@workflow.run
async def run(self) -> int:
return id(self)


async def test_workflow_instance_access_from_interceptor(client: Client):
task_queue = f"task_queue_{uuid.uuid4()}"
async with Worker(
client,
task_queue=task_queue,
workflows=[WorkflowInstanceAccessWorkflow],
interceptors=[WorkflowInstanceAccessInterceptor()],
):
difference = await client.execute_workflow(
WorkflowInstanceAccessWorkflow.run,
id=f"workflow_{uuid.uuid4()}",
task_queue=task_queue,
)
assert difference == 0
Loading