Skip to content

Commit

Permalink
update dbt-rpc to use the latest flags module
Browse files Browse the repository at this point in the history
  • Loading branch information
ChenyuLInx committed Feb 16, 2023
1 parent ea716ec commit 8f6cbc2
Show file tree
Hide file tree
Showing 5 changed files with 64 additions and 32 deletions.
27 changes: 13 additions & 14 deletions dbt_rpc/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -566,24 +566,19 @@ def parse_args(args, cls=DBTArgumentParser):
sys.exit(1)

parsed = p.parse_args(args)
from dbt.cli.resolvers import default_profiles_dir
parsed.profiles_dir = default_profiles_dir()
parsed.defer_mode = 'eager'
set_from_args(parsed, None)
flags = get_flags()

# get the correct profiles_dir
if os.getenv('DBT_PROFILES_DIR'):
parsed.profiles_dir = os.getenv('DBT_PROFILES_DIR')
else:
from dbt.cli.resolvers import default_profiles_dir
parsed.profiles_dir = default_profiles_dir()
# profiles_dir is set before subcommands and after, so normalize
if hasattr(parsed, 'sub_profiles_dir'):
if parsed.sub_profiles_dir is not None:
parsed.profiles_dir = parsed.sub_profiles_dir
delattr(parsed, 'sub_profiles_dir')
# if hasattr(parsed, 'profiles_dir'):
# if parsed.profiles_dir is None:
# parsed.profiles_dir = flags.PROFILES_DIR
# else:
# parsed.profiles_dir = os.path.abspath(parsed.profiles_dir)
# # needs to be set before the other flags, because it's needed to
# # read the profile that contains them
# flags.PROFILES_DIR = parsed.profiles_dir
parsed.profiles_dir = os.path.abspath(parsed.profiles_dir)

# version_check is set before subcommands and after, so normalize
if hasattr(parsed, 'sub_version_check'):
Expand All @@ -599,7 +594,11 @@ def parse_args(args, cls=DBTArgumentParser):

if getattr(parsed, 'project_dir', None) is not None:
expanded_user = os.path.expanduser(parsed.project_dir)
parsed.project_dir = os.path.abspath(expanded_user)
parsed.project_dir = os.path.abspath(expanded_user)
parsed.defer_mode = 'eager'

# create global flags object
set_from_args(parsed, None)

if not hasattr(parsed, 'which'):
# the user did not provide a valid subcommand. trigger the help message
Expand Down
8 changes: 6 additions & 2 deletions dbt_rpc/rpc/task_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def _spawn_setup(self):
user_config = self.task.config.user_config
set_from_args(self.task.args, user_config)
flags = get_flags()
dbt.tracking.initialize_from_flags(flags.ANONYMOUS_USAGE_STATS, flags.PROFILES_DIR)
dbt.tracking.initialize_from_flags(flags.SEND_ANONYMOUS_USAGE_STATS, flags.PROFILES_DIR)
# reload the active plugin
load_plugin(self.task.config.credentials.type)
# register it
Expand All @@ -97,7 +97,11 @@ def task_exec(self) -> None:
# some commands, like 'debug', won't have a threads value at all.
if getattr(self.task.args, 'threads', None) is not None:
self.task.config.threads = self.task.args.threads
self.task.args.selector = None

# we previously always set a selector here
if not hasattr(self.task.args, 'selector'):
object.__setattr__(self.task.args, "selector", None)
object.__setattr__(self.task.args, "SELECTOR", None)
rpc_exception = None
result = None
try:
Expand Down
50 changes: 35 additions & 15 deletions dbt_rpc/task/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@
from dbt.clients.yaml_helper import Dumper, yaml # noqa: F401
from typing import Type, Optional


from dbt.config.utils import parse_cli_vars
from dbt_rpc.contracts.rpc import RPCCliParameters

from dbt_rpc.rpc.method import (
Expand All @@ -28,6 +26,12 @@ def has_cli_parameters(cls):
def handle_request(self) -> Result:
pass

COMMAND_MAPING = {
"freshness": "source-freshness",
"snapshot-freshness": "source-freshness",
"generate": "docs.generate"
}


class RemoteRPCCli(RPCTask[RPCCliParameters]):
METHOD_NAME = 'cli_args'
Expand All @@ -53,11 +57,25 @@ def set_config(self, config):
)

def set_args(self, params: RPCCliParameters) -> None:
# NOTE: `parse_args` is pinned to the version of dbt-core installed!
from dbt.main import parse_args
from dbt_rpc.__main__ import RPCArgumentParser
split = shlex.split(params.cli)
self.args = parse_args(split, RPCArgumentParser)

from dbt.cli.flags import args_to_context, Flags

ctx = args_to_context(split)
self.args = Flags(ctx)

# previously this info is preserved in gloabl flags module
from dbt.flags import get_flags
object.__setattr__(self.args, 'PROFILES_DIR', get_flags().PROFILES_DIR)
object.__setattr__(self.args, 'profiles_dir', get_flags().PROFILES_DIR)

# this was handled by parse_args in original main before, now move the
# logic here
if ctx.command.name in COMMAND_MAPING:
rpc_method = COMMAND_MAPING[ctx.command.name]
else:
rpc_method = ctx.command.name
object.__setattr__(self.args, 'rpc_method', rpc_method)
self.task_type = self.get_rpc_task_cls()

def get_flags(self):
Expand Down Expand Up @@ -95,21 +113,23 @@ def handle_request(self) -> Result:
# `self.config` is before the fork(), so it would alter the behavior of
# future calls.

# read any cli vars we got and use it to update cli_vars
self.config.cli_vars.update(
parse_cli_vars(getattr(self.args, 'vars', '{}'))
)

# # read any cli vars we got and use it to update cli_vars
self.config.cli_vars.update(self.args.vars)
# If this changed the vars, rewrite args.vars to reflect our merged
# vars and reload the manifest.

#TODO Here's why this logic nolonger works:
# we updated to have cli vars and also vars all being the parsed version, which makes them object
# https://github.com/dbt-labs/dbt-core/pull/6396/files#diff-7685e44a07e8211f0e710116a07186168af0feb2e466fb46e40504d6b2282ec1L237
dumped = yaml.safe_dump(self.config.cli_vars)
if dumped != self.args.vars:
self.real_task.args.cli_vars = self.config.cli_vars
self.args.cli_vars = self.config.cli_vars
self.args.vars = self.config.cli_vars
# dumped = yaml.safe_dump(self.config.cli_vars)
if self.config.cli_vars != self.args.vars:
# object.__setattr__(self.real_task.args, "VARS", self.config.cli_vars)
# object.__setattr__(self.real_task.args, "vars", self.config.cli_vars)
# object.__setattr__(self.args, "VARS", self.config.cli_vars)
# object.__setattr__(self.args, "vars", self.config.cli_vars)
object.__setattr__(self.real_task.args, "cli_vars", self.config.cli_vars)
object.__setattr__(self.args, "cli_vars", self.config.cli_vars)
self.config.args = self.args
if isinstance(self.real_task, RemoteManifestMethod):
self.real_task.manifest = ManifestLoader.get_full_manifest(
Expand Down
2 changes: 2 additions & 0 deletions dbt_rpc/task/project_commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,7 @@ def __init__(self, args, config, manifest):
RemoteManifestMethod.__init__(
self, args, config, manifest # type: ignore
)
self.set_from_params = False

def load_manifest(self):
# we started out with a manifest!
Expand All @@ -202,6 +203,7 @@ def load_manifest(self):
def set_args(self, params: RPCRunOperationParameters) -> None:
self.args.macro = params.macro
self.args.args = params.args
self.set_from_params = True

def _get_kwargs(self):
if isinstance(self.args.args, dict):
Expand Down
9 changes: 8 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,14 @@ def dbt_profile_data(unique_schema, pytestconfig):
@pytest.fixture
def dbt_profile(profiles_root, dbt_profile_data) -> Dict[str, Any]:
flags.PROFILES_DIR = profiles_root
original_profile_dir_env = os.environ.get('DBT_PROFILES_DIR')
# we have to do this to make sure the subprocess uses the correct profiles
os.environ['DBT_PROFILES_DIR'] = str(profiles_root)
path = os.path.join(profiles_root, 'profiles.yml')
with open(path, 'w') as fp:
fp.write(yaml.safe_dump(dbt_profile_data))
return dbt_profile_data
yield dbt_profile_data
if original_profile_dir_env:
os.environ['DBT_PROFILES_DIR'] = original_profile_dir_env
else:
del os.environ['DBT_PROFILES_DIR']

0 comments on commit 8f6cbc2

Please sign in to comment.