Skip to content

Commit

Permalink
Add "Flags" to args parameter in tasks
Browse files Browse the repository at this point in the history
  • Loading branch information
gshank committed Apr 4, 2024
1 parent a994ace commit 391d2f8
Show file tree
Hide file tree
Showing 8 changed files with 21 additions and 12 deletions.
3 changes: 3 additions & 0 deletions core/dbt/cli/flags.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,6 +352,9 @@ def set_common_global_flags(self):
if getattr(self, "MACRO_DEBUGGING", None) is not None:
jinja.MACRO_DEBUGGING = getattr(self, "MACRO_DEBUGGING")

def __getattr__(self, name: str) -> Any:
return super().__get_attribute__(name) # type: ignore


CommandParams = List[str]

Expand Down
13 changes: 7 additions & 6 deletions core/dbt/task/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import dbt_common.exceptions.base
import dbt.exceptions
from dbt import tracking
from dbt.cli.flags import Flags
from dbt.config import RuntimeConfig, Project
from dbt.config.profile import read_profile
from dbt.constants import DBT_PROJECT_FILE_NAME
Expand Down Expand Up @@ -51,7 +52,7 @@

class NoneConfig:
@classmethod
def from_args(cls, args):
def from_args(cls, args: Flags):
return None


Expand All @@ -73,13 +74,13 @@ def read_profiles(profiles_dir=None):
class BaseTask(metaclass=ABCMeta):
ConfigType: Union[Type[NoneConfig], Type[Project]] = NoneConfig

def __init__(self, args, config, project=None) -> None:
def __init__(self, args: Flags, config, project=None) -> None:
self.args = args
self.config = config
self.project = config if isinstance(config, Project) else project

@classmethod
def pre_init_hook(cls, args):
def pre_init_hook(cls, args: Flags):
"""A hook called before the task is initialized."""
if args.log_format == "json":
log_manager.format_json()
Expand Down Expand Up @@ -155,7 +156,7 @@ def move_to_nearest_project_dir(project_dir: Optional[str]) -> Path:
class ConfiguredTask(BaseTask):
ConfigType = RuntimeConfig

def __init__(self, args, config, manifest: Optional[Manifest] = None) -> None:
def __init__(self, args: Flags, config, manifest: Optional[Manifest] = None) -> None:
super().__init__(args, config)
self.graph: Optional[Graph] = None
self.manifest = manifest
Expand All @@ -174,7 +175,7 @@ def compile_manifest(self):
dbt.tracking.track_runnable_timing({"graph_compilation_elapsed": compile_time})

@classmethod
def from_args(cls, args, *pargs, **kwargs):
def from_args(cls, args: Flags, *pargs, **kwargs):
move_to_nearest_project_dir(args.project_dir)
return super().from_args(args, *pargs, **kwargs)

Expand Down Expand Up @@ -487,7 +488,7 @@ def do_skip(self, cause=None):


def resource_types_from_args(
args, all_resource_values: Set[NodeType], default_resource_values: Set[NodeType]
args: Flags, all_resource_values: Set[NodeType], default_resource_values: Set[NodeType]
) -> Set[NodeType]:

if not args.resource_types:
Expand Down
3 changes: 2 additions & 1 deletion core/dbt/task/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from dbt.artifacts.schemas.results import NodeStatus, RunStatus
from dbt.artifacts.schemas.run import RunResult
from dbt.cli.flags import Flags
from dbt.graph import ResourceTypeSelector, GraphQueue, Graph
from dbt.node_types import NodeType
from dbt.task.test import TestSelector
Expand Down Expand Up @@ -74,7 +75,7 @@ class BuildTask(RunTask):
}
ALL_RESOURCE_VALUES = frozenset({x for x in RUNNER_MAP.keys()})

def __init__(self, args, config, manifest) -> None:
def __init__(self, args: Flags, config, manifest) -> None:
super().__init__(args, config, manifest)
self.selected_unit_tests: Set = set()
self.model_to_unit_test_map: Dict[str, List] = {}
Expand Down
3 changes: 2 additions & 1 deletion core/dbt/task/debug.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import dbt.exceptions
import dbt_common.exceptions
from dbt.adapters.factory import get_adapter, register_adapter
from dbt.cli.flags import Flags
from dbt.config import PartialProject, Project, Profile
from dbt.config.renderer import DbtProjectYamlRenderer, ProfileRenderer
from dbt.artifacts.schemas.results import RunStatus
Expand Down Expand Up @@ -77,7 +78,7 @@ class DebugRunStatus(Flag):


class DebugTask(BaseTask):
def __init__(self, args, config) -> None:
def __init__(self, args: Flags, config) -> None:
super().__init__(args, config)
self.profiles_dir = args.PROFILES_DIR
self.profile_path = os.path.join(self.profiles_dir, "profiles.yml")
Expand Down
3 changes: 2 additions & 1 deletion core/dbt/task/list.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
SemanticModel,
UnitTestDefinition,
)
from dbt.cli.flags import Flags

Check warning on line 11 in core/dbt/task/list.py

View check run for this annotation

Codecov / codecov/patch

core/dbt/task/list.py#L11

Added line #L11 was not covered by tests
from dbt.flags import get_flags
from dbt.graph import ResourceTypeSelector
from dbt.task.base import resource_types_from_args
Expand Down Expand Up @@ -54,7 +55,7 @@ class ListTask(GraphRunnableTask):
)
)

def __init__(self, args, config, manifest) -> None:
def __init__(self, args: Flags, config, manifest) -> None:

Check warning on line 58 in core/dbt/task/list.py

View check run for this annotation

Codecov / codecov/patch

core/dbt/task/list.py#L58

Added line #L58 was not covered by tests
super().__init__(args, config, manifest)
if self.args.models:
if self.args.select:
Expand Down
2 changes: 1 addition & 1 deletion core/dbt/task/retry.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@


class RetryTask(ConfiguredTask):
def __init__(self, args, config) -> None:
def __init__(self, args: Flags, config) -> None:
# load previous run results
state_path = args.state or config.target_path
self.previous_results = load_result_state(
Expand Down
3 changes: 2 additions & 1 deletion core/dbt/task/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from dbt import tracking
from dbt import utils
from dbt.adapters.base import BaseRelation
from dbt.cli.flags import Flags
from dbt.clients.jinja import MacroGenerator
from dbt.context.providers import generate_runtime_model_context
from dbt.contracts.graph.nodes import HookNode, ResultNode
Expand Down Expand Up @@ -305,7 +306,7 @@ def execute(self, model, manifest):


class RunTask(CompileTask):
def __init__(self, args, config, manifest) -> None:
def __init__(self, args: Flags, config, manifest) -> None:
super().__init__(args, config, manifest)
self.ran_hooks: List[HookNode] = []
self._total_executed = 0
Expand Down
3 changes: 2 additions & 1 deletion core/dbt/task/runnable.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import dbt.utils
from dbt.adapters.base import BaseRelation
from dbt.adapters.factory import get_adapter
from dbt.cli.flags import Flags
from dbt.contracts.graph.manifest import Manifest
from dbt.contracts.graph.nodes import ResultNode
from dbt.artifacts.schemas.results import NodeStatus, RunningStatus, RunStatus, BaseResult
Expand Down Expand Up @@ -65,7 +66,7 @@
class GraphRunnableTask(ConfiguredTask):
MARK_DEPENDENT_ERRORS_STATUSES = [NodeStatus.Error]

def __init__(self, args, config, manifest) -> None:
def __init__(self, args: Flags, config, manifest) -> None:
super().__init__(args, config, manifest)
self._flattened_nodes: Optional[List[ResultNode]] = None
self._raise_next_tick: Optional[DbtRuntimeError] = None
Expand Down

0 comments on commit 391d2f8

Please sign in to comment.