Skip to content

Commit

Permalink
use mutex to enforce mutually exclusive options
Browse files Browse the repository at this point in the history
  • Loading branch information
aranke committed Sep 19, 2023
1 parent efdedc3 commit 3901d22
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 26 deletions.
25 changes: 2 additions & 23 deletions core/dbt/cli/flags.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,14 @@

from click import Context, get_current_context, Parameter
from click.core import Command as ClickCommand, Group, ParameterSource

from dbt.cli.exceptions import DbtUsageException
from dbt.cli.resolvers import default_log_path, default_project_dir
from dbt.cli.types import Command as CliCommand
from dbt.config.profile import read_user_config
from dbt.contracts.project import UserConfig
from dbt.exceptions import DbtInternalError
from dbt.deprecations import renamed_env_var
from dbt.exceptions import DbtInternalError
from dbt.helper_types import WarnErrorOptions

if os.name != "nt":
Expand Down Expand Up @@ -243,11 +244,6 @@ def _assign_params(
if os.getenv("DO_NOT_TRACK", "").lower() in ("1", "t", "true", "y", "yes"):
object.__setattr__(self, "SEND_ANONYMOUS_USAGE_STATS", False)

# Check mutual exclusivity once all flags are set.
self._assert_mutually_exclusive(
params_assigned_from_default, ["WARN_ERROR", "WARN_ERROR_OPTIONS"]
)

# Support lower cased access for legacy code.
params = set(
x for x in dir(self) if not callable(getattr(self, x)) and not x.startswith("__")
Expand All @@ -263,23 +259,6 @@ def _override_if_set(self, lead: str, follow: str, defaulted: Set[str]) -> None:
if lead.lower() not in defaulted and follow.lower() in defaulted:
object.__setattr__(self, follow.upper(), getattr(self, lead.upper(), None))

def _assert_mutually_exclusive(
self, params_assigned_from_default: Set[str], group: List[str]
) -> None:
"""
Ensure no elements from group are simultaneously provided by a user, as inferred from params_assigned_from_default.
Raises click.UsageError if any two elements from group are simultaneously provided by a user.
"""
set_flag = None
for flag in group:
flag_set_by_user = flag.lower() not in params_assigned_from_default
if flag_set_by_user and set_flag:
raise DbtUsageException(
f"{flag.lower()}: not allowed with argument {set_flag.lower()}"
)
elif flag_set_by_user:
set_flag = flag

def fire_deprecations(self):
"""Fires events for deprecated env_var usage."""
[dep_fn() for dep_fn in self.deprecated_env_var_warnings]
Expand Down
4 changes: 2 additions & 2 deletions core/dbt/cli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,8 @@ def global_flags(func):
@p.use_experimental_parser
@p.version
@p.version_check
@p.warn_error
@p.warn_error_options
@p.write_json
@functools.wraps(func)
def wrapper(*args, **kwargs):
Expand All @@ -166,8 +168,6 @@ def wrapper(*args, **kwargs):
)
@click.pass_context
@global_flags
@p.warn_error
@p.warn_error_options
@p.log_format
def cli(ctx, **kwargs):
"""An ELT tool for managing your SQL transformations and data models.
Expand Down
34 changes: 33 additions & 1 deletion core/dbt/cli/params.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,39 @@
from pathlib import Path

import click
from dbt.cli.options import MultiOption

from dbt.cli.exceptions import DbtUsageException
from dbt.cli.option_types import YAML, ChoiceTuple, WarnErrorOptionsType
from dbt.cli.options import MultiOption
from dbt.cli.resolvers import default_project_dir, default_profiles_dir
from dbt.version import get_version_information


# Copied from https://github.com/pallets/click/issues/257#issuecomment-403312784
class Mutex(click.Option):
def __init__(self, *args, **kwargs):
self.not_required_if: list = kwargs.pop("not_required_if")

assert self.not_required_if, "'not_required_if' parameter required"
kwargs["help"] = (
kwargs.get("help", "")
+ "Option is mutually exclusive with "
+ ", ".join(self.not_required_if)
+ "."
).strip()
super(Mutex, self).__init__(*args, **kwargs)

def handle_parse_result(self, ctx, opts, args):
current_opt: bool = self.name in opts
for mutex_opt in self.not_required_if:
if mutex_opt in opts:
if current_opt:
raise DbtUsageException(f"{self.name}: not allowed with argument {mutex_opt}")
else:
self.prompt = None
return super(Mutex, self).handle_parse_result(ctx, opts, args)


args = click.option(
"--args",
envvar=None,
Expand Down Expand Up @@ -589,6 +617,8 @@ def _version_callback(ctx, _param, value):
help="If dbt would normally warn, instead raise an exception. Examples include --select that selects nothing, deprecations, configurations with no associated models, invalid test configurations, and missing sources/refs in tests.",
default=None,
is_flag=True,
cls=Mutex,
not_required_if=["warn_error_options"],
)

warn_error_options = click.option(
Expand All @@ -598,6 +628,8 @@ def _version_callback(ctx, _param, value):
help="""If dbt would normally warn, instead raise an exception based on include/exclude configuration. Examples include --select that selects nothing, deprecations, configurations with no associated models, invalid test configurations,
and missing sources/refs in tests. This argument should be a YAML string, with keys 'include' or 'exclude'. eg. '{"include": "all", "exclude": ["NoNodesForSelectionCriteria"]}'""",
type=WarnErrorOptionsType(),
cls=Mutex,
not_required_if=["warn_error"],
)

write_json = click.option(
Expand Down

0 comments on commit 3901d22

Please sign in to comment.