From 391d2f8e4089a543c28eb36d39b43c6116a8a4ff Mon Sep 17 00:00:00 2001 From: Gerda Shank Date: Thu, 4 Apr 2024 15:00:17 -0400 Subject: [PATCH] Add "Flags" to args parameter in tasks --- core/dbt/cli/flags.py | 3 +++ core/dbt/task/base.py | 13 +++++++------ core/dbt/task/build.py | 3 ++- core/dbt/task/debug.py | 3 ++- core/dbt/task/list.py | 3 ++- core/dbt/task/retry.py | 2 +- core/dbt/task/run.py | 3 ++- core/dbt/task/runnable.py | 3 ++- 8 files changed, 21 insertions(+), 12 deletions(-) diff --git a/core/dbt/cli/flags.py b/core/dbt/cli/flags.py index 92a6cbc5a28..d40e0b3d28f 100644 --- a/core/dbt/cli/flags.py +++ b/core/dbt/cli/flags.py @@ -352,6 +352,9 @@ def set_common_global_flags(self): if getattr(self, "MACRO_DEBUGGING", None) is not None: jinja.MACRO_DEBUGGING = getattr(self, "MACRO_DEBUGGING") + def __getattr__(self, name: str) -> Any: + return super().__get_attribute__(name) # type: ignore + CommandParams = List[str] diff --git a/core/dbt/task/base.py b/core/dbt/task/base.py index 2ca6cb2e978..b000c70755e 100644 --- a/core/dbt/task/base.py +++ b/core/dbt/task/base.py @@ -12,6 +12,7 @@ import dbt_common.exceptions.base import dbt.exceptions from dbt import tracking +from dbt.cli.flags import Flags from dbt.config import RuntimeConfig, Project from dbt.config.profile import read_profile from dbt.constants import DBT_PROJECT_FILE_NAME @@ -51,7 +52,7 @@ class NoneConfig: @classmethod - def from_args(cls, args): + def from_args(cls, args: Flags): return None @@ -73,13 +74,13 @@ def read_profiles(profiles_dir=None): class BaseTask(metaclass=ABCMeta): ConfigType: Union[Type[NoneConfig], Type[Project]] = NoneConfig - def __init__(self, args, config, project=None) -> None: + def __init__(self, args: Flags, config, project=None) -> None: self.args = args self.config = config self.project = config if isinstance(config, Project) else project @classmethod - def pre_init_hook(cls, args): + def pre_init_hook(cls, args: Flags): """A hook called before the task is initialized.""" if args.log_format == "json": log_manager.format_json() @@ -155,7 +156,7 @@ def move_to_nearest_project_dir(project_dir: Optional[str]) -> Path: class ConfiguredTask(BaseTask): ConfigType = RuntimeConfig - def __init__(self, args, config, manifest: Optional[Manifest] = None) -> None: + def __init__(self, args: Flags, config, manifest: Optional[Manifest] = None) -> None: super().__init__(args, config) self.graph: Optional[Graph] = None self.manifest = manifest @@ -174,7 +175,7 @@ def compile_manifest(self): dbt.tracking.track_runnable_timing({"graph_compilation_elapsed": compile_time}) @classmethod - def from_args(cls, args, *pargs, **kwargs): + def from_args(cls, args: Flags, *pargs, **kwargs): move_to_nearest_project_dir(args.project_dir) return super().from_args(args, *pargs, **kwargs) @@ -487,7 +488,7 @@ def do_skip(self, cause=None): def resource_types_from_args( - args, all_resource_values: Set[NodeType], default_resource_values: Set[NodeType] + args: Flags, all_resource_values: Set[NodeType], default_resource_values: Set[NodeType] ) -> Set[NodeType]: if not args.resource_types: diff --git a/core/dbt/task/build.py b/core/dbt/task/build.py index 5d3a42b3b9f..6c6426ceefd 100644 --- a/core/dbt/task/build.py +++ b/core/dbt/task/build.py @@ -8,6 +8,7 @@ from dbt.artifacts.schemas.results import NodeStatus, RunStatus from dbt.artifacts.schemas.run import RunResult +from dbt.cli.flags import Flags from dbt.graph import ResourceTypeSelector, GraphQueue, Graph from dbt.node_types import NodeType from dbt.task.test import TestSelector @@ -74,7 +75,7 @@ class BuildTask(RunTask): } ALL_RESOURCE_VALUES = frozenset({x for x in RUNNER_MAP.keys()}) - def __init__(self, args, config, manifest) -> None: + def __init__(self, args: Flags, config, manifest) -> None: super().__init__(args, config, manifest) self.selected_unit_tests: Set = set() self.model_to_unit_test_map: Dict[str, List] = {} diff --git a/core/dbt/task/debug.py b/core/dbt/task/debug.py index ea0f636bd6c..51eabaea13e 100644 --- a/core/dbt/task/debug.py +++ b/core/dbt/task/debug.py @@ -19,6 +19,7 @@ import dbt.exceptions import dbt_common.exceptions from dbt.adapters.factory import get_adapter, register_adapter +from dbt.cli.flags import Flags from dbt.config import PartialProject, Project, Profile from dbt.config.renderer import DbtProjectYamlRenderer, ProfileRenderer from dbt.artifacts.schemas.results import RunStatus @@ -77,7 +78,7 @@ class DebugRunStatus(Flag): class DebugTask(BaseTask): - def __init__(self, args, config) -> None: + def __init__(self, args: Flags, config) -> None: super().__init__(args, config) self.profiles_dir = args.PROFILES_DIR self.profile_path = os.path.join(self.profiles_dir, "profiles.yml") diff --git a/core/dbt/task/list.py b/core/dbt/task/list.py index ff6fea3447e..3ea4b1e2d27 100644 --- a/core/dbt/task/list.py +++ b/core/dbt/task/list.py @@ -8,6 +8,7 @@ SemanticModel, UnitTestDefinition, ) +from dbt.cli.flags import Flags from dbt.flags import get_flags from dbt.graph import ResourceTypeSelector from dbt.task.base import resource_types_from_args @@ -54,7 +55,7 @@ class ListTask(GraphRunnableTask): ) ) - def __init__(self, args, config, manifest) -> None: + def __init__(self, args: Flags, config, manifest) -> None: super().__init__(args, config, manifest) if self.args.models: if self.args.select: diff --git a/core/dbt/task/retry.py b/core/dbt/task/retry.py index 9aadf9ead97..70dea9756f2 100644 --- a/core/dbt/task/retry.py +++ b/core/dbt/task/retry.py @@ -63,7 +63,7 @@ class RetryTask(ConfiguredTask): - def __init__(self, args, config) -> None: + def __init__(self, args: Flags, config) -> None: # load previous run results state_path = args.state or config.target_path self.previous_results = load_result_state( diff --git a/core/dbt/task/run.py b/core/dbt/task/run.py index 2a6031ad45c..4fc6ebc64cf 100644 --- a/core/dbt/task/run.py +++ b/core/dbt/task/run.py @@ -15,6 +15,7 @@ from dbt import tracking from dbt import utils from dbt.adapters.base import BaseRelation +from dbt.cli.flags import Flags from dbt.clients.jinja import MacroGenerator from dbt.context.providers import generate_runtime_model_context from dbt.contracts.graph.nodes import HookNode, ResultNode @@ -305,7 +306,7 @@ def execute(self, model, manifest): class RunTask(CompileTask): - def __init__(self, args, config, manifest) -> None: + def __init__(self, args: Flags, config, manifest) -> None: super().__init__(args, config, manifest) self.ran_hooks: List[HookNode] = [] self._total_executed = 0 diff --git a/core/dbt/task/runnable.py b/core/dbt/task/runnable.py index cff69e23e80..746c08bf656 100644 --- a/core/dbt/task/runnable.py +++ b/core/dbt/task/runnable.py @@ -15,6 +15,7 @@ import dbt.utils from dbt.adapters.base import BaseRelation from dbt.adapters.factory import get_adapter +from dbt.cli.flags import Flags from dbt.contracts.graph.manifest import Manifest from dbt.contracts.graph.nodes import ResultNode from dbt.artifacts.schemas.results import NodeStatus, RunningStatus, RunStatus, BaseResult @@ -65,7 +66,7 @@ class GraphRunnableTask(ConfiguredTask): MARK_DEPENDENT_ERRORS_STATUSES = [NodeStatus.Error] - def __init__(self, args, config, manifest) -> None: + def __init__(self, args: Flags, config, manifest) -> None: super().__init__(args, config, manifest) self._flattened_nodes: Optional[List[ResultNode]] = None self._raise_next_tick: Optional[DbtRuntimeError] = None