From 9a1ae151b0127be9f9ce9d06bd932e139c461a3b Mon Sep 17 00:00:00 2001 From: Michelle Ark Date: Tue, 23 Apr 2024 16:09:36 -0400 Subject: [PATCH] add ProjectFlags.project_only_flags from https://github.com/dbt-labs/dbt-core/pull/9366 --- core/dbt/cli/flags.py | 9 ++++++++- core/dbt/contracts/project.py | 4 ++++ tests/unit/test_cli_flags.py | 8 ++++++++ tests/unit/test_graph.py | 3 ++- 4 files changed, 22 insertions(+), 2 deletions(-) diff --git a/core/dbt/cli/flags.py b/core/dbt/cli/flags.py index f5e7ca18104..b92703beb37 100644 --- a/core/dbt/cli/flags.py +++ b/core/dbt/cli/flags.py @@ -212,8 +212,8 @@ def _assign_params( # Add entire invocation command to flags object.__setattr__(self, "INVOCATION_COMMAND", "dbt " + " ".join(sys.argv[1:])) - # Overwrite default assignments with user config if available. if project_flags: + # Overwrite default assignments with project flags if available. param_assigned_from_default_copy = params_assigned_from_default.copy() for param_assigned_from_default in params_assigned_from_default: project_flags_param_value = getattr( @@ -228,6 +228,13 @@ def _assign_params( param_assigned_from_default_copy.remove(param_assigned_from_default) params_assigned_from_default = param_assigned_from_default_copy + # Add project-level flags that are not available as CLI options / env vars + for ( + project_level_flag_name, + project_level_flag_value, + ) in project_flags.project_only_flags.items(): + object.__setattr__(self, project_level_flag_name.upper(), project_level_flag_value) + # Set hard coded flags. object.__setattr__(self, "WHICH", invoked_subcommand_name or ctx.info_name) object.__setattr__(self, "MP_CONTEXT", get_context("spawn")) diff --git a/core/dbt/contracts/project.py b/core/dbt/contracts/project.py index 1442c5bd6ed..0b174517c91 100644 --- a/core/dbt/contracts/project.py +++ b/core/dbt/contracts/project.py @@ -305,6 +305,10 @@ class ProjectFlags(ExtensibleDbtClassMixin, Replaceable): warn_error_options: Optional[Dict[str, Union[str, List[str]]]] = None write_json: Optional[bool] = None + @property + def project_only_flags(self) -> Dict[str, Any]: + return {} + @dataclass class ProfileConfig(dbtClassMixin, Replaceable): diff --git a/tests/unit/test_cli_flags.py b/tests/unit/test_cli_flags.py index da53c203239..c7fcd92e128 100644 --- a/tests/unit/test_cli_flags.py +++ b/tests/unit/test_cli_flags.py @@ -371,6 +371,14 @@ def test_global_flag_at_child_context(self): assert flags_a.USE_COLORS == flags_b.USE_COLORS + def test_set_project_only_flags(self, project_flags, run_context): + flags = Flags(run_context, project_flags) + + for project_only_flag, project_only_flag_value in project_flags.project_only_flags.items(): + assert getattr(flags, project_only_flag) == project_only_flag_value + # sanity check: ensure project_only_flag is not part of the click context + assert project_only_flag not in run_context.params + def _create_flags_from_dict(self, cmd, d): write_file("", "profiles.yml") result = Flags.from_dict(cmd, d) diff --git a/tests/unit/test_graph.py b/tests/unit/test_graph.py index 48011cd2553..9c61b3f97c7 100644 --- a/tests/unit/test_graph.py +++ b/tests/unit/test_graph.py @@ -16,6 +16,7 @@ import dbt.parser.manifest from dbt import tracking from dbt.contracts.files import SourceFile, FileHash, FilePath +from dbt.contracts.project import ProjectFlags from dbt.contracts.graph.manifest import MacroManifest, ManifestStateCheck from dbt.graph import NodeSelector, parse_difference from dbt.events.functions import setup_event_logger @@ -130,7 +131,7 @@ def get_config(self, extra_cfg=None): cfg.update(extra_cfg) config = config_from_parts_or_dicts(project=cfg, profile=self.profile) - dbt.flags.set_from_args(Namespace(), config) + dbt.flags.set_from_args(Namespace(), ProjectFlags()) setup_event_logger(dbt.flags.get_flags()) object.__setattr__(dbt.flags.get_flags(), "PARTIAL_PARSE", False) return config