From a1f005789dd477ad165fcb0c9bdf0db1cec431c0 Mon Sep 17 00:00:00 2001 From: Gerda Shank Date: Wed, 10 Apr 2024 09:30:53 -0400 Subject: [PATCH] [Tidy first] Task config type (#9874) --- core/dbt/cli/main.py | 3 +-- core/dbt/deps/resolver.py | 15 +++++------ core/dbt/task/base.py | 57 +++++++++++++++------------------------ core/dbt/task/build.py | 4 ++- core/dbt/task/clean.py | 7 +++++ core/dbt/task/debug.py | 11 ++------ core/dbt/task/deps.py | 7 ++--- core/dbt/task/list.py | 4 ++- core/dbt/task/retry.py | 2 +- core/dbt/task/run.py | 4 ++- core/dbt/task/runnable.py | 12 ++++++--- 11 files changed, 61 insertions(+), 65 deletions(-) diff --git a/core/dbt/cli/main.py b/core/dbt/cli/main.py index deff7d8d341..07a9de861a7 100644 --- a/core/dbt/cli/main.py +++ b/core/dbt/cli/main.py @@ -399,7 +399,6 @@ def debug(ctx, **kwargs): task = DebugTask( ctx.obj["flags"], - None, ) results = task.run() @@ -464,7 +463,7 @@ def init(ctx, **kwargs): """Initialize a new dbt project.""" from dbt.task.init import InitTask - task = InitTask(ctx.obj["flags"], None) + task = InitTask(ctx.obj["flags"]) results = task.run() success = task.interpret_results(results) diff --git a/core/dbt/deps/resolver.py b/core/dbt/deps/resolver.py index 3d74bac4980..5f890109b0e 100644 --- a/core/dbt/deps/resolver.py +++ b/core/dbt/deps/resolver.py @@ -1,5 +1,5 @@ from dataclasses import dataclass, field -from typing import Dict, List, NoReturn, Union, Type, Iterator, Set, Any +from typing import Dict, List, NoReturn, Type, Iterator, Set, Any from dbt.exceptions import ( DuplicateDependencyToRootError, @@ -17,14 +17,13 @@ from dbt.deps.registry import RegistryUnpinnedPackage from dbt.contracts.project import ( + PackageSpec, LocalPackage, TarballPackage, GitPackage, RegistryPackage, ) -PackageContract = Union[LocalPackage, TarballPackage, GitPackage, RegistryPackage] - @dataclass class PackageListing: @@ -68,7 +67,7 @@ def incorporate(self, package: UnpinnedPackage): else: self.packages[key] = package - def update_from(self, src: List[PackageContract]) -> None: + def update_from(self, src: List[PackageSpec]) -> None: pkg: UnpinnedPackage for contract in src: if isinstance(contract, LocalPackage): @@ -84,9 +83,7 @@ def update_from(self, src: List[PackageContract]) -> None: self.incorporate(pkg) @classmethod - def from_contracts( - cls: Type["PackageListing"], src: List[PackageContract] - ) -> "PackageListing": + def from_contracts(cls: Type["PackageListing"], src: List[PackageSpec]) -> "PackageListing": self = cls({}) self.update_from(src) return self @@ -114,7 +111,7 @@ def _check_for_duplicate_project_names( def resolve_packages( - packages: List[PackageContract], + packages: List[PackageSpec], project: Project, cli_vars: Dict[str, Any], ) -> List[PinnedPackage]: @@ -137,7 +134,7 @@ def resolve_packages( return resolved -def resolve_lock_packages(packages: List[PackageContract]) -> List[PinnedPackage]: +def resolve_lock_packages(packages: List[PackageSpec]) -> List[PinnedPackage]: lock_packages = PackageListing.from_contracts(packages) final = PackageListing() diff --git a/core/dbt/task/base.py b/core/dbt/task/base.py index b000c70755e..d4c206b023c 100644 --- a/core/dbt/task/base.py +++ b/core/dbt/task/base.py @@ -6,14 +6,14 @@ from contextlib import nullcontext from datetime import datetime from pathlib import Path -from typing import Any, Dict, List, Optional, Type, Union, Set +from typing import Any, Dict, List, Optional, Set from dbt.compilation import Compiler import dbt_common.exceptions.base import dbt.exceptions from dbt import tracking from dbt.cli.flags import Flags -from dbt.config import RuntimeConfig, Project +from dbt.config import RuntimeConfig from dbt.config.profile import read_profile from dbt.constants import DBT_PROJECT_FILE_NAME from dbt.contracts.graph.manifest import Manifest @@ -50,12 +50,6 @@ from dbt.task.printer import print_run_result_error -class NoneConfig: - @classmethod - def from_args(cls, args: Flags): - return None - - def read_profiles(profiles_dir=None): """This is only used for some error handling""" if profiles_dir is None: @@ -72,12 +66,8 @@ def read_profiles(profiles_dir=None): class BaseTask(metaclass=ABCMeta): - ConfigType: Union[Type[NoneConfig], Type[Project]] = NoneConfig - - def __init__(self, args: Flags, config, project=None) -> None: + def __init__(self, args: Flags) -> None: self.args = args - self.config = config - self.project = config if isinstance(config, Project) else project @classmethod def pre_init_hook(cls, args: Flags): @@ -94,23 +84,6 @@ def set_log_format(cls): else: log_manager.format_text() - @classmethod - def from_args(cls, args, *pargs, **kwargs): - try: - # This is usually RuntimeConfig - config = cls.ConfigType.from_args(args) - except dbt.exceptions.DbtProjectError as exc: - fire_event(LogDbtProjectError(exc=str(exc))) - - tracking.track_invalid_invocation(args=args, result_type=exc.result_type) - raise dbt_common.exceptions.DbtRuntimeError("Could not run dbt") from exc - except dbt.exceptions.DbtProfileError as exc: - all_profile_names = list(read_profiles(get_flags().PROFILES_DIR).keys()) - fire_event(LogDbtProfileError(exc=str(exc), profiles=all_profile_names)) - tracking.track_invalid_invocation(args=args, result_type=exc.result_type) - raise dbt_common.exceptions.DbtRuntimeError("Could not run dbt") from exc - return cls(args, config, *pargs, **kwargs) - @abstractmethod def run(self): raise dbt_common.exceptions.base.NotImplementedError("Not Implemented") @@ -154,10 +127,11 @@ def move_to_nearest_project_dir(project_dir: Optional[str]) -> Path: # produce the same behavior. currently this class only contains manifest compilation, # holding a manifest, and moving direcories. class ConfiguredTask(BaseTask): - ConfigType = RuntimeConfig - - def __init__(self, args: Flags, config, manifest: Optional[Manifest] = None) -> None: - super().__init__(args, config) + def __init__( + self, args: Flags, config: RuntimeConfig, manifest: Optional[Manifest] = None + ) -> None: + super().__init__(args) + self.config = config self.graph: Optional[Graph] = None self.manifest = manifest self.compiler = Compiler(self.config) @@ -177,7 +151,20 @@ def compile_manifest(self): @classmethod def from_args(cls, args: Flags, *pargs, **kwargs): move_to_nearest_project_dir(args.project_dir) - return super().from_args(args, *pargs, **kwargs) + try: + # This is usually RuntimeConfig + config = RuntimeConfig.from_args(args) + except dbt.exceptions.DbtProjectError as exc: + fire_event(LogDbtProjectError(exc=str(exc))) + + tracking.track_invalid_invocation(args=args, result_type=exc.result_type) + raise dbt_common.exceptions.DbtRuntimeError("Could not run dbt") from exc + except dbt.exceptions.DbtProfileError as exc: + all_profile_names = list(read_profiles(get_flags().PROFILES_DIR).keys()) + fire_event(LogDbtProfileError(exc=str(exc), profiles=all_profile_names)) + tracking.track_invalid_invocation(args=args, result_type=exc.result_type) + raise dbt_common.exceptions.DbtRuntimeError("Could not run dbt") from exc + return cls(args, config, *pargs, **kwargs) class ExecutionContext: diff --git a/core/dbt/task/build.py b/core/dbt/task/build.py index 6c6426ceefd..57f11c71bd5 100644 --- a/core/dbt/task/build.py +++ b/core/dbt/task/build.py @@ -9,6 +9,8 @@ from dbt.artifacts.schemas.results import NodeStatus, RunStatus from dbt.artifacts.schemas.run import RunResult from dbt.cli.flags import Flags +from dbt.config.runtime import RuntimeConfig +from dbt.contracts.graph.manifest import Manifest from dbt.graph import ResourceTypeSelector, GraphQueue, Graph from dbt.node_types import NodeType from dbt.task.test import TestSelector @@ -75,7 +77,7 @@ class BuildTask(RunTask): } ALL_RESOURCE_VALUES = frozenset({x for x in RUNNER_MAP.keys()}) - def __init__(self, args: Flags, config, manifest) -> None: + def __init__(self, args: Flags, config: RuntimeConfig, manifest: Manifest) -> None: super().__init__(args, config, manifest) self.selected_unit_tests: Set = set() self.model_to_unit_test_map: Dict[str, List] = {} diff --git a/core/dbt/task/clean.py b/core/dbt/task/clean.py index efae26bf6e1..c4e98f5db2b 100644 --- a/core/dbt/task/clean.py +++ b/core/dbt/task/clean.py @@ -9,6 +9,8 @@ FinishedCleanPaths, ) from dbt_common.exceptions import DbtRuntimeError +from dbt.cli.flags import Flags +from dbt.config.project import Project from dbt.task.base import ( BaseTask, move_to_nearest_project_dir, @@ -16,6 +18,11 @@ class CleanTask(BaseTask): + def __init__(self, args: Flags, config: Project): + super().__init__(args) + self.config = config + self.project = config + def run(self): """ This function takes all the paths in the target file diff --git a/core/dbt/task/debug.py b/core/dbt/task/debug.py index 51eabaea13e..b388e4336ba 100644 --- a/core/dbt/task/debug.py +++ b/core/dbt/task/debug.py @@ -78,8 +78,8 @@ class DebugRunStatus(Flag): class DebugTask(BaseTask): - def __init__(self, args: Flags, config) -> None: - super().__init__(args, config) + def __init__(self, args: Flags) -> None: + super().__init__(args) self.profiles_dir = args.PROFILES_DIR self.profile_path = os.path.join(self.profiles_dir, "profiles.yml") try: @@ -98,13 +98,6 @@ def __init__(self, args: Flags, config) -> None: self.profile: Optional[Profile] = None self.raw_profile_data: Optional[Dict[str, Any]] = None self.profile_name: Optional[str] = None - self.project: Optional[Project] = None - - @property - def project_profile(self): - if self.project is None: - return None - return self.project.profile_name def run(self) -> bool: # WARN: this is a legacy workflow that is not compatible with other runtime flags diff --git a/core/dbt/task/deps.py b/core/dbt/task/deps.py index 85788fe440f..0f8e45f073f 100644 --- a/core/dbt/task/deps.py +++ b/core/dbt/task/deps.py @@ -13,7 +13,7 @@ from dbt.deps.base import downloads_directory from dbt.deps.resolver import resolve_lock_packages, resolve_packages from dbt.deps.registry import RegistryPinnedPackage -from dbt.contracts.project import Package +from dbt.contracts.project import PackageSpec from dbt_common.events.functions import fire_event @@ -44,7 +44,7 @@ def increase_indent(self, flow=False, indentless=False): return super(dbtPackageDumper, self).increase_indent(flow, False) -def _create_sha1_hash(packages: List[Package]) -> str: +def _create_sha1_hash(packages: List[PackageSpec]) -> str: """Create a SHA1 hash of the packages list, this is used to determine if the packages for current execution matches the previous lock. @@ -94,14 +94,15 @@ def _create_packages_yml_entry(package: str, version: Optional[str], source: str class DepsTask(BaseTask): def __init__(self, args: Any, project: Project) -> None: + super().__init__(args=args) # N.B. This is a temporary fix for a bug when using relative paths via # --project-dir with deps. A larger overhaul of our path handling methods # is needed to fix this the "right" way. # See GH-7615 project.project_root = str(Path(project.project_root).resolve()) + self.project = project move_to_nearest_project_dir(project.project_root) - super().__init__(args=args, config=None, project=project) self.cli_vars = args.vars def track_package_install( diff --git a/core/dbt/task/list.py b/core/dbt/task/list.py index 3ea4b1e2d27..e345bc78d94 100644 --- a/core/dbt/task/list.py +++ b/core/dbt/task/list.py @@ -9,6 +9,8 @@ UnitTestDefinition, ) from dbt.cli.flags import Flags +from dbt.config.runtime import RuntimeConfig +from dbt.contracts.graph.manifest import Manifest from dbt.flags import get_flags from dbt.graph import ResourceTypeSelector from dbt.task.base import resource_types_from_args @@ -55,7 +57,7 @@ class ListTask(GraphRunnableTask): ) ) - def __init__(self, args: Flags, config, manifest) -> None: + def __init__(self, args: Flags, config: RuntimeConfig, manifest: Manifest) -> None: super().__init__(args, config, manifest) if self.args.models: if self.args.select: diff --git a/core/dbt/task/retry.py b/core/dbt/task/retry.py index 70dea9756f2..57724f455e0 100644 --- a/core/dbt/task/retry.py +++ b/core/dbt/task/retry.py @@ -63,7 +63,7 @@ class RetryTask(ConfiguredTask): - def __init__(self, args: Flags, config) -> None: + def __init__(self, args: Flags, config: RuntimeConfig) -> None: # load previous run results state_path = args.state or config.target_path self.previous_results = load_result_state( diff --git a/core/dbt/task/run.py b/core/dbt/task/run.py index 4fc6ebc64cf..b57d39c785b 100644 --- a/core/dbt/task/run.py +++ b/core/dbt/task/run.py @@ -17,8 +17,10 @@ from dbt.adapters.base import BaseRelation from dbt.cli.flags import Flags from dbt.clients.jinja import MacroGenerator +from dbt.config.runtime import RuntimeConfig from dbt.context.providers import generate_runtime_model_context from dbt.contracts.graph.nodes import HookNode, ResultNode +from dbt.contracts.graph.manifest import Manifest from dbt.artifacts.schemas.results import NodeStatus, RunStatus, RunningStatus, BaseResult from dbt.artifacts.schemas.run import RunResult from dbt.artifacts.resources import Hook @@ -306,7 +308,7 @@ def execute(self, model, manifest): class RunTask(CompileTask): - def __init__(self, args: Flags, config, manifest) -> None: + def __init__(self, args: Flags, config: RuntimeConfig, manifest: Manifest) -> None: super().__init__(args, config, manifest) self.ran_hooks: List[HookNode] = [] self._total_executed = 0 diff --git a/core/dbt/task/runnable.py b/core/dbt/task/runnable.py index 746c08bf656..6593053c285 100644 --- a/core/dbt/task/runnable.py +++ b/core/dbt/task/runnable.py @@ -16,6 +16,7 @@ from dbt.adapters.base import BaseRelation from dbt.adapters.factory import get_adapter from dbt.cli.flags import Flags +from dbt.config.runtime import RuntimeConfig from dbt.contracts.graph.manifest import Manifest from dbt.contracts.graph.nodes import ResultNode from dbt.artifacts.schemas.results import NodeStatus, RunningStatus, RunStatus, BaseResult @@ -66,8 +67,9 @@ class GraphRunnableTask(ConfiguredTask): MARK_DEPENDENT_ERRORS_STATUSES = [NodeStatus.Error] - def __init__(self, args: Flags, config, manifest) -> None: + def __init__(self, args: Flags, config: RuntimeConfig, manifest: Manifest) -> None: super().__init__(args, config, manifest) + self.config = config self._flattened_nodes: Optional[List[ResultNode]] = None self._raise_next_tick: Optional[DbtRuntimeError] = None self._skipped_children: Dict[str, Optional[RunResult]] = {} @@ -124,7 +126,9 @@ def get_selection_spec(self) -> SelectionSpec: # This is what's used with no default selector and no selection # use --select and --exclude args spec = parse_difference(self.selection_arg, self.exclusion_arg, indirect_selection) - return spec + # mypy complains because the return values of get_selector and parse_difference + # are different + return spec # type: ignore @abstractmethod def get_node_selector(self) -> NodeSelector: @@ -624,7 +628,9 @@ def create_schema(relation: BaseRelation) -> None: list_futures = [] create_futures = [] - with dbt_common.utils.executor(self.config) as tpe: + # TODO: following has a mypy issue because profile and project config + # defines threads as int and HasThreadingConfig defines it as Optional[int] + with dbt_common.utils.executor(self.config) as tpe: # type: ignore for req in required_databases: if req.database is None: name = "list_schemas"