Skip to content

Commit

Permalink
Add Invocation Context Support to MultiThreadedExecutor (#76)
Browse files Browse the repository at this point in the history
  • Loading branch information
peterallenwebb authored Feb 8, 2024
1 parent e4ecbf5 commit a37ea5b
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 4 deletions.
6 changes: 6 additions & 0 deletions .changes/unreleased/Under the Hood-20240208-100709.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
kind: Under the Hood
body: Add Invocation Context Support to MultiThreadedExecutor
time: 2024-02-08T10:07:09.584747-05:00
custom:
Author: peterallenwebb
Issue: "75"
6 changes: 3 additions & 3 deletions dbt_common/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def env_secrets(self) -> List[str]:
_INVOCATION_CONTEXT_VAR: ContextVar[InvocationContext] = ContextVar("DBT_INVOCATION_CONTEXT_VAR")


def _reliably_get_invocation_var() -> ContextVar:
def reliably_get_invocation_var() -> ContextVar:
invocation_var: Optional[ContextVar] = next(
(cv for cv in copy_context() if cv.name == _INVOCATION_CONTEXT_VAR.name), None
)
Expand All @@ -38,11 +38,11 @@ def _reliably_get_invocation_var() -> ContextVar:


def set_invocation_context(env: Mapping[str, str]) -> None:
invocation_var = _reliably_get_invocation_var()
invocation_var = reliably_get_invocation_var()
invocation_var.set(InvocationContext(env))


def get_invocation_context() -> InvocationContext:
invocation_var = _reliably_get_invocation_var()
invocation_var = reliably_get_invocation_var()
ctx = invocation_var.get()
return ctx
14 changes: 13 additions & 1 deletion dbt_common/utils/executor.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
import concurrent.futures
from contextlib import contextmanager
from contextvars import ContextVar
from typing import Protocol, Optional

from dbt_common.context import get_invocation_context, reliably_get_invocation_var


class ConnectingExecutor(concurrent.futures.Executor):
def submit_connected(self, adapter, conn_name, func, *args, **kwargs):
Expand Down Expand Up @@ -60,8 +63,17 @@ class HasThreadingConfig(Protocol):
threads: Optional[int]


def _thread_initializer(invocation_context: ContextVar) -> None:
invocation_var = reliably_get_invocation_var()
invocation_var.set(invocation_context)


def executor(config: HasThreadingConfig) -> ConnectingExecutor:
if config.args.single_threaded:
return SingleThreadedExecutor()
else:
return MultiThreadedExecutor(max_workers=config.threads)
return MultiThreadedExecutor(
max_workers=config.threads,
initializer=_thread_initializer,
initargs=(get_invocation_context(),),
)

0 comments on commit a37ea5b

Please sign in to comment.