From 4e9b489017f816673ad50b57480cfb36bb380b34 Mon Sep 17 00:00:00 2001 From: Chenyu Li Date: Wed, 15 Feb 2023 17:54:10 -0800 Subject: [PATCH] support 1.5 --- dbt_rpc/__main__.py | 49 ++++++++++++++++++-------------- dbt_rpc/rpc/task_handler.py | 12 ++++++-- dbt_rpc/rpc/task_manager.py | 1 + dbt_rpc/task/base.py | 2 +- dbt_rpc/task/cli.py | 40 ++++++++++++++++++-------- dbt_rpc/task/deps.py | 2 ++ dbt_rpc/task/project_commands.py | 14 ++++----- tests/conftest.py | 9 +++++- tests/util.py | 4 +++ 9 files changed, 88 insertions(+), 45 deletions(-) diff --git a/dbt_rpc/__main__.py b/dbt_rpc/__main__.py index 045a749..332c29d 100644 --- a/dbt_rpc/__main__.py +++ b/dbt_rpc/__main__.py @@ -20,14 +20,15 @@ NotImplementedError, FailedToConnectError ) -import dbt.flags as flags - +from dbt.config.utils import parse_cli_vars +from dbt.flags import get_flags, set_from_args from dbt_rpc.task.server import RPCServerTask def initialize_tracking_from_flags(): # NOTE: this is copied from dbt-core # Setting these used to be in UserConfig, but had to be moved here + flags = get_flags() if flags.SEND_ANONYMOUS_USAGE_STATS: dbt.tracking.initialize_tracking(flags.PROFILES_DIR) else: @@ -168,15 +169,12 @@ def adapter_management(): def handle_and_check(args): with log_manager.applicationbound(): + # this also set global flags parsed = parse_args(args) - - # Set flags from args, user config, and env vars - user_config = read_user_config(flags.PROFILES_DIR) # This is read again later - flags.set_from_args(parsed, user_config) initialize_tracking_from_flags() # Set log_format from flags parsed.cls.set_log_format() - + flags = get_flags() # we've parsed the args and set the flags - we can now decide if we're debug or not if flags.DEBUG: log_manager.set_debug() @@ -201,21 +199,23 @@ def handle_and_check(args): @contextmanager def track_run(task): - dbt.tracking.track_invocation_start(config=task.config, args=task.args) + invocation_context = dbt.tracking.get_base_invocation_context() + invocation_context["command"] = 'rpc' + dbt.tracking.track_invocation_start(invocation_context) try: yield dbt.tracking.track_invocation_end( - config=task.config, args=task.args, result_type="ok" + invocation_context, result_type="ok" ) except (NotImplementedError, FailedToConnectError) as e: logger.error('ERROR: {}'.format(e)) dbt.tracking.track_invocation_end( - config=task.config, args=task.args, result_type="error" + invocation_context, result_type="error" ) except Exception: dbt.tracking.track_invocation_end( - config=task.config, args=task.args, result_type="error" + invocation_context, result_type="error" ) raise finally: @@ -233,6 +233,7 @@ def run_from_args(parsed): # this will convert DbtConfigErrors into DbtRuntimeErrors # task could be any one of the task objects + parsed.vars = parse_cli_vars(parsed.vars) task = parsed.cls.from_args(args=parsed) logger.debug("running dbt with arguments {parsed}", parsed=str(parsed)) @@ -246,7 +247,6 @@ def run_from_args(parsed): logger.debug("Tracking: {}".format(dbt.tracking.active_user.state())) results = None - with track_run(task): results = task.run() @@ -492,7 +492,7 @@ def parse_args(args, cls=DBTArgumentParser): p.add_argument( '--no-anonymous-usage-stats', action='store_false', - default=None, + default=False, dest='send_anonymous_usage_stats', help=''' Do not send anonymous usage stat to dbt Labs @@ -566,19 +566,19 @@ def parse_args(args, cls=DBTArgumentParser): sys.exit(1) parsed = p.parse_args(args) + + # get the correct profiles_dir + if os.getenv('DBT_PROFILES_DIR'): + parsed.profiles_dir = os.getenv('DBT_PROFILES_DIR') + else: + from dbt.cli.resolvers import default_profiles_dir + parsed.profiles_dir = default_profiles_dir() # profiles_dir is set before subcommands and after, so normalize if hasattr(parsed, 'sub_profiles_dir'): if parsed.sub_profiles_dir is not None: parsed.profiles_dir = parsed.sub_profiles_dir delattr(parsed, 'sub_profiles_dir') - if hasattr(parsed, 'profiles_dir'): - if parsed.profiles_dir is None: - parsed.profiles_dir = flags.PROFILES_DIR - else: - parsed.profiles_dir = os.path.abspath(parsed.profiles_dir) - # needs to be set before the other flags, because it's needed to - # read the profile that contains them - flags.PROFILES_DIR = parsed.profiles_dir + parsed.profiles_dir = os.path.abspath(parsed.profiles_dir) # version_check is set before subcommands and after, so normalize if hasattr(parsed, 'sub_version_check'): @@ -596,6 +596,13 @@ def parse_args(args, cls=DBTArgumentParser): expanded_user = os.path.expanduser(parsed.project_dir) parsed.project_dir = os.path.abspath(expanded_user) + # set_args construct a flags with command run + # which doesn't have defer_mode, but we need a default value + parsed.defer_mode = 'eager' + + # create global flags object + set_from_args(parsed, None) + if not hasattr(parsed, 'which'): # the user did not provide a valid subcommand. trigger the help message # and exit with a error diff --git a/dbt_rpc/rpc/task_handler.py b/dbt_rpc/rpc/task_handler.py index f615ddd..24855a1 100644 --- a/dbt_rpc/rpc/task_handler.py +++ b/dbt_rpc/rpc/task_handler.py @@ -12,7 +12,7 @@ from dbt.dataclass_schema import dbtClassMixin, ValidationError import dbt.exceptions -from dbt.flags import env_set_truthy +from dbt.flags import env_set_truthy, get_flags, set_from_args import dbt.tracking from dbt.adapters.factory import ( cleanup_connections, load_plugin, register_adapter, @@ -74,8 +74,9 @@ def _spawn_setup(self): user_config = None if self.task.config is not None: user_config = self.task.config.user_config - dbt.flags.set_from_args(self.task.args, user_config) - dbt.tracking.initialize_from_flags() + set_from_args(self.task.args, user_config) + flags = get_flags() + dbt.tracking.initialize_from_flags(flags.SEND_ANONYMOUS_USAGE_STATS, flags.PROFILES_DIR) # reload the active plugin load_plugin(self.task.config.credentials.type) # register it @@ -96,6 +97,11 @@ def task_exec(self) -> None: # some commands, like 'debug', won't have a threads value at all. if getattr(self.task.args, 'threads', None) is not None: self.task.config.threads = self.task.args.threads + + # we previously always set a selector here + if not hasattr(self.task.args, 'selector'): + object.__setattr__(self.task.args, "selector", None) + object.__setattr__(self.task.args, "SELECTOR", None) rpc_exception = None result = None try: diff --git a/dbt_rpc/rpc/task_manager.py b/dbt_rpc/rpc/task_manager.py index b92fb56..5379794 100644 --- a/dbt_rpc/rpc/task_manager.py +++ b/dbt_rpc/rpc/task_manager.py @@ -197,6 +197,7 @@ def set_parsing(self) -> bool: return True def parse_manifest(self) -> None: + register_adapter(self.config) self.manifest = ManifestLoader.get_full_manifest(self.config, reset=True) def set_compile_exception(self, exc, logs=List[LogMessage]) -> None: diff --git a/dbt_rpc/task/base.py b/dbt_rpc/task/base.py index c679c01..0ca726a 100644 --- a/dbt_rpc/task/base.py +++ b/dbt_rpc/task/base.py @@ -25,7 +25,7 @@ class RPCTask( RemoteManifestMethod[Parameters, RemoteExecutionResult] ): def __init__(self, args, config, manifest): - super().__init__(args, config) + super().__init__(args, config, manifest) RemoteManifestMethod.__init__( self, args, config, manifest # type: ignore ) diff --git a/dbt_rpc/task/cli.py b/dbt_rpc/task/cli.py index 68c9772..f4f4329 100644 --- a/dbt_rpc/task/cli.py +++ b/dbt_rpc/task/cli.py @@ -3,8 +3,6 @@ from dbt.clients.yaml_helper import Dumper, yaml # noqa: F401 from typing import Type, Optional - -from dbt.config.utils import parse_cli_vars from dbt_rpc.contracts.rpc import RPCCliParameters from dbt_rpc.rpc.method import ( @@ -28,6 +26,12 @@ def has_cli_parameters(cls): def handle_request(self) -> Result: pass +COMMAND_MAPING = { + "freshness": "source-freshness", + "snapshot-freshness": "source-freshness", + "generate": "docs.generate" +} + class RemoteRPCCli(RPCTask[RPCCliParameters]): METHOD_NAME = 'cli_args' @@ -53,11 +57,25 @@ def set_config(self, config): ) def set_args(self, params: RPCCliParameters) -> None: - # NOTE: `parse_args` is pinned to the version of dbt-core installed! - from dbt.main import parse_args - from dbt_rpc.__main__ import RPCArgumentParser split = shlex.split(params.cli) - self.args = parse_args(split, RPCArgumentParser) + + from dbt.cli.flags import args_to_context, Flags + + ctx = args_to_context(split) + self.args = Flags(ctx) + + # previously this info is preserved in gloabl flags module + from dbt.flags import get_flags + object.__setattr__(self.args, 'PROFILES_DIR', get_flags().PROFILES_DIR) + object.__setattr__(self.args, 'profiles_dir', get_flags().PROFILES_DIR) + + # this was handled by parse_args in original main before, now move the + # logic here + if ctx.command.name in COMMAND_MAPING: + rpc_method = COMMAND_MAPING[ctx.command.name] + else: + rpc_method = ctx.command.name + object.__setattr__(self.args, 'rpc_method', rpc_method) self.task_type = self.get_rpc_task_cls() def get_flags(self): @@ -96,14 +114,12 @@ def handle_request(self) -> Result: # future calls. # read any cli vars we got and use it to update cli_vars - self.config.cli_vars.update( - parse_cli_vars(getattr(self.args, 'vars', '{}')) - ) + self.config.cli_vars.update(self.args.vars) # If this changed the vars, rewrite args.vars to reflect our merged # vars and reload the manifest. - dumped = yaml.safe_dump(self.config.cli_vars) - if dumped != self.args.vars: - self.real_task.args.vars = dumped + if self.config.cli_vars != self.args.vars: + object.__setattr__(self.real_task.args, "cli_vars", self.config.cli_vars) + object.__setattr__(self.args, "cli_vars", self.config.cli_vars) self.config.args = self.args if isinstance(self.real_task, RemoteManifestMethod): self.real_task.manifest = ManifestLoader.get_full_manifest( diff --git a/dbt_rpc/task/deps.py b/dbt_rpc/task/deps.py index bf57a66..806a24f 100644 --- a/dbt_rpc/task/deps.py +++ b/dbt_rpc/task/deps.py @@ -31,5 +31,7 @@ def set_args(self, params: RPCDepsParameters): def handle_request(self) -> RemoteDepsResult: _clean_deps(self.config) + self.project = self.config + self.cli_vars = self.config.cli_vars self.run() return RemoteDepsResult([]) diff --git a/dbt_rpc/task/project_commands.py b/dbt_rpc/task/project_commands.py index bf2162f..73bd5b7 100644 --- a/dbt_rpc/task/project_commands.py +++ b/dbt_rpc/task/project_commands.py @@ -3,7 +3,7 @@ from pathlib import Path from typing import List, Optional, Union -from dbt import flags +from dbt.flags import get_flags from dbt.contracts.graph.manifest import WritableManifest from dbt_rpc.contracts.rpc import ( GetManifestParameters, @@ -65,10 +65,11 @@ def handle_request(self) -> RemoteExecutionResult: def state_path(state: Optional[str]) -> Optional[Path]: + flags = get_flags() if state is not None: return Path(state) - elif flags.ARTIFACT_STATE_PATH is not None: - return Path(flags.ARTIFACT_STATE_PATH) + elif flags.STATE is not None: + return Path(flags.STATE) else: return None @@ -89,7 +90,6 @@ def set_args(self, params: RPCCompileParameters) -> None: self.args.threads = params.threads self.args.state = state_path(params.state) - self.set_previous_state() @@ -107,7 +107,7 @@ def set_args(self, params: RPCRunParameters) -> None: if params.threads is not None: self.args.threads = params.threads if params.defer is None: - self.args.defer = flags.DEFER_MODE + self.args.defer = get_flags().DEFER_MODE else: self.args.defer = params.defer @@ -146,7 +146,7 @@ def set_args(self, params: RPCTestParameters) -> None: if params.threads is not None: self.args.threads = params.threads if params.defer is None: - self.args.defer = flags.DEFER_MODE + self.args.defer = get_flags().DEFER_MODE else: self.args.defer = params.defer @@ -332,7 +332,7 @@ def set_args(self, params: RPCBuildParameters) -> None: if params.threads is not None: self.args.threads = params.threads if params.defer is None: - self.args.defer = flags.DEFER_MODE + self.args.defer = get_flags().DEFER_MODE else: self.args.defer = params.defer diff --git a/tests/conftest.py b/tests/conftest.py index 01b363a..e9d725d 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -113,7 +113,14 @@ def dbt_profile_data(unique_schema, pytestconfig): @pytest.fixture def dbt_profile(profiles_root, dbt_profile_data) -> Dict[str, Any]: flags.PROFILES_DIR = profiles_root + original_profile_dir_env = os.environ.get('DBT_PROFILES_DIR') + # we have to do this to make sure the subprocess uses the correct profiles + os.environ['DBT_PROFILES_DIR'] = str(profiles_root) path = os.path.join(profiles_root, 'profiles.yml') with open(path, 'w') as fp: fp.write(yaml.safe_dump(dbt_profile_data)) - return dbt_profile_data + yield dbt_profile_data + if original_profile_dir_env: + os.environ['DBT_PROFILES_DIR'] = original_profile_dir_env + else: + del os.environ['DBT_PROFILES_DIR'] diff --git a/tests/util.py b/tests/util.py index bb6b93c..27dba11 100644 --- a/tests/util.py +++ b/tests/util.py @@ -648,6 +648,8 @@ def __init__(self, profiles_dir, which='run-operation', kwargs={}): self.project_dir = None self.profile = None self.target = None + self.threads = None + self.selector = None self.__dict__.update(kwargs) @@ -678,6 +680,8 @@ def built_schema(project_dir, schema, profiles_dir, project_def): os.chdir(project_dir) start = os.getcwd() try: + from dbt.flags import set_from_args + set_from_args(args, None) cfg = RuntimeConfig.from_args(args) finally: os.chdir(start)