Skip to content

Commit

Permalink
support 1.5
Browse files Browse the repository at this point in the history
  • Loading branch information
ChenyuLInx committed Feb 16, 2023
1 parent c305bfc commit 4e9b489
Show file tree
Hide file tree
Showing 9 changed files with 88 additions and 45 deletions.
49 changes: 28 additions & 21 deletions dbt_rpc/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,15 @@
NotImplementedError,
FailedToConnectError
)
import dbt.flags as flags

from dbt.config.utils import parse_cli_vars
from dbt.flags import get_flags, set_from_args
from dbt_rpc.task.server import RPCServerTask


def initialize_tracking_from_flags():
# NOTE: this is copied from dbt-core
# Setting these used to be in UserConfig, but had to be moved here
flags = get_flags()
if flags.SEND_ANONYMOUS_USAGE_STATS:
dbt.tracking.initialize_tracking(flags.PROFILES_DIR)
else:
Expand Down Expand Up @@ -168,15 +169,12 @@ def adapter_management():

def handle_and_check(args):
with log_manager.applicationbound():
# this also set global flags
parsed = parse_args(args)

# Set flags from args, user config, and env vars
user_config = read_user_config(flags.PROFILES_DIR) # This is read again later
flags.set_from_args(parsed, user_config)
initialize_tracking_from_flags()
# Set log_format from flags
parsed.cls.set_log_format()

flags = get_flags()
# we've parsed the args and set the flags - we can now decide if we're debug or not
if flags.DEBUG:
log_manager.set_debug()
Expand All @@ -201,21 +199,23 @@ def handle_and_check(args):

@contextmanager
def track_run(task):
dbt.tracking.track_invocation_start(config=task.config, args=task.args)
invocation_context = dbt.tracking.get_base_invocation_context()
invocation_context["command"] = 'rpc'
dbt.tracking.track_invocation_start(invocation_context)
try:
yield
dbt.tracking.track_invocation_end(
config=task.config, args=task.args, result_type="ok"
invocation_context, result_type="ok"
)
except (NotImplementedError,
FailedToConnectError) as e:
logger.error('ERROR: {}'.format(e))
dbt.tracking.track_invocation_end(
config=task.config, args=task.args, result_type="error"
invocation_context, result_type="error"
)
except Exception:
dbt.tracking.track_invocation_end(
config=task.config, args=task.args, result_type="error"
invocation_context, result_type="error"
)
raise
finally:
Expand All @@ -233,6 +233,7 @@ def run_from_args(parsed):

# this will convert DbtConfigErrors into DbtRuntimeErrors
# task could be any one of the task objects
parsed.vars = parse_cli_vars(parsed.vars)
task = parsed.cls.from_args(args=parsed)

logger.debug("running dbt with arguments {parsed}", parsed=str(parsed))
Expand All @@ -246,7 +247,6 @@ def run_from_args(parsed):
logger.debug("Tracking: {}".format(dbt.tracking.active_user.state()))

results = None

with track_run(task):
results = task.run()

Expand Down Expand Up @@ -492,7 +492,7 @@ def parse_args(args, cls=DBTArgumentParser):
p.add_argument(
'--no-anonymous-usage-stats',
action='store_false',
default=None,
default=False,
dest='send_anonymous_usage_stats',
help='''
Do not send anonymous usage stat to dbt Labs
Expand Down Expand Up @@ -566,19 +566,19 @@ def parse_args(args, cls=DBTArgumentParser):
sys.exit(1)

parsed = p.parse_args(args)

# get the correct profiles_dir
if os.getenv('DBT_PROFILES_DIR'):
parsed.profiles_dir = os.getenv('DBT_PROFILES_DIR')
else:
from dbt.cli.resolvers import default_profiles_dir
parsed.profiles_dir = default_profiles_dir()
# profiles_dir is set before subcommands and after, so normalize
if hasattr(parsed, 'sub_profiles_dir'):
if parsed.sub_profiles_dir is not None:
parsed.profiles_dir = parsed.sub_profiles_dir
delattr(parsed, 'sub_profiles_dir')
if hasattr(parsed, 'profiles_dir'):
if parsed.profiles_dir is None:
parsed.profiles_dir = flags.PROFILES_DIR
else:
parsed.profiles_dir = os.path.abspath(parsed.profiles_dir)
# needs to be set before the other flags, because it's needed to
# read the profile that contains them
flags.PROFILES_DIR = parsed.profiles_dir
parsed.profiles_dir = os.path.abspath(parsed.profiles_dir)

# version_check is set before subcommands and after, so normalize
if hasattr(parsed, 'sub_version_check'):
Expand All @@ -596,6 +596,13 @@ def parse_args(args, cls=DBTArgumentParser):
expanded_user = os.path.expanduser(parsed.project_dir)
parsed.project_dir = os.path.abspath(expanded_user)

# set_args construct a flags with command run
# which doesn't have defer_mode, but we need a default value
parsed.defer_mode = 'eager'

# create global flags object
set_from_args(parsed, None)

if not hasattr(parsed, 'which'):
# the user did not provide a valid subcommand. trigger the help message
# and exit with a error
Expand Down
12 changes: 9 additions & 3 deletions dbt_rpc/rpc/task_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from dbt.dataclass_schema import dbtClassMixin, ValidationError

import dbt.exceptions
from dbt.flags import env_set_truthy
from dbt.flags import env_set_truthy, get_flags, set_from_args
import dbt.tracking
from dbt.adapters.factory import (
cleanup_connections, load_plugin, register_adapter,
Expand Down Expand Up @@ -74,8 +74,9 @@ def _spawn_setup(self):
user_config = None
if self.task.config is not None:
user_config = self.task.config.user_config
dbt.flags.set_from_args(self.task.args, user_config)
dbt.tracking.initialize_from_flags()
set_from_args(self.task.args, user_config)
flags = get_flags()
dbt.tracking.initialize_from_flags(flags.SEND_ANONYMOUS_USAGE_STATS, flags.PROFILES_DIR)
# reload the active plugin
load_plugin(self.task.config.credentials.type)
# register it
Expand All @@ -96,6 +97,11 @@ def task_exec(self) -> None:
# some commands, like 'debug', won't have a threads value at all.
if getattr(self.task.args, 'threads', None) is not None:
self.task.config.threads = self.task.args.threads

# we previously always set a selector here
if not hasattr(self.task.args, 'selector'):
object.__setattr__(self.task.args, "selector", None)
object.__setattr__(self.task.args, "SELECTOR", None)
rpc_exception = None
result = None
try:
Expand Down
1 change: 1 addition & 0 deletions dbt_rpc/rpc/task_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,7 @@ def set_parsing(self) -> bool:
return True

def parse_manifest(self) -> None:
register_adapter(self.config)
self.manifest = ManifestLoader.get_full_manifest(self.config, reset=True)

def set_compile_exception(self, exc, logs=List[LogMessage]) -> None:
Expand Down
2 changes: 1 addition & 1 deletion dbt_rpc/task/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ class RPCTask(
RemoteManifestMethod[Parameters, RemoteExecutionResult]
):
def __init__(self, args, config, manifest):
super().__init__(args, config)
super().__init__(args, config, manifest)
RemoteManifestMethod.__init__(
self, args, config, manifest # type: ignore
)
Expand Down
40 changes: 28 additions & 12 deletions dbt_rpc/task/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@
from dbt.clients.yaml_helper import Dumper, yaml # noqa: F401
from typing import Type, Optional


from dbt.config.utils import parse_cli_vars
from dbt_rpc.contracts.rpc import RPCCliParameters

from dbt_rpc.rpc.method import (
Expand All @@ -28,6 +26,12 @@ def has_cli_parameters(cls):
def handle_request(self) -> Result:
pass

COMMAND_MAPING = {
"freshness": "source-freshness",
"snapshot-freshness": "source-freshness",
"generate": "docs.generate"
}


class RemoteRPCCli(RPCTask[RPCCliParameters]):
METHOD_NAME = 'cli_args'
Expand All @@ -53,11 +57,25 @@ def set_config(self, config):
)

def set_args(self, params: RPCCliParameters) -> None:
# NOTE: `parse_args` is pinned to the version of dbt-core installed!
from dbt.main import parse_args
from dbt_rpc.__main__ import RPCArgumentParser
split = shlex.split(params.cli)
self.args = parse_args(split, RPCArgumentParser)

from dbt.cli.flags import args_to_context, Flags

ctx = args_to_context(split)
self.args = Flags(ctx)

# previously this info is preserved in gloabl flags module
from dbt.flags import get_flags
object.__setattr__(self.args, 'PROFILES_DIR', get_flags().PROFILES_DIR)
object.__setattr__(self.args, 'profiles_dir', get_flags().PROFILES_DIR)

# this was handled by parse_args in original main before, now move the
# logic here
if ctx.command.name in COMMAND_MAPING:
rpc_method = COMMAND_MAPING[ctx.command.name]
else:
rpc_method = ctx.command.name
object.__setattr__(self.args, 'rpc_method', rpc_method)
self.task_type = self.get_rpc_task_cls()

def get_flags(self):
Expand Down Expand Up @@ -96,14 +114,12 @@ def handle_request(self) -> Result:
# future calls.

# read any cli vars we got and use it to update cli_vars
self.config.cli_vars.update(
parse_cli_vars(getattr(self.args, 'vars', '{}'))
)
self.config.cli_vars.update(self.args.vars)
# If this changed the vars, rewrite args.vars to reflect our merged
# vars and reload the manifest.
dumped = yaml.safe_dump(self.config.cli_vars)
if dumped != self.args.vars:
self.real_task.args.vars = dumped
if self.config.cli_vars != self.args.vars:
object.__setattr__(self.real_task.args, "cli_vars", self.config.cli_vars)
object.__setattr__(self.args, "cli_vars", self.config.cli_vars)
self.config.args = self.args
if isinstance(self.real_task, RemoteManifestMethod):
self.real_task.manifest = ManifestLoader.get_full_manifest(
Expand Down
2 changes: 2 additions & 0 deletions dbt_rpc/task/deps.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,5 +31,7 @@ def set_args(self, params: RPCDepsParameters):

def handle_request(self) -> RemoteDepsResult:
_clean_deps(self.config)
self.project = self.config
self.cli_vars = self.config.cli_vars
self.run()
return RemoteDepsResult([])
14 changes: 7 additions & 7 deletions dbt_rpc/task/project_commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from pathlib import Path
from typing import List, Optional, Union

from dbt import flags
from dbt.flags import get_flags
from dbt.contracts.graph.manifest import WritableManifest
from dbt_rpc.contracts.rpc import (
GetManifestParameters,
Expand Down Expand Up @@ -65,10 +65,11 @@ def handle_request(self) -> RemoteExecutionResult:


def state_path(state: Optional[str]) -> Optional[Path]:
flags = get_flags()
if state is not None:
return Path(state)
elif flags.ARTIFACT_STATE_PATH is not None:
return Path(flags.ARTIFACT_STATE_PATH)
elif flags.STATE is not None:
return Path(flags.STATE)
else:
return None

Expand All @@ -89,7 +90,6 @@ def set_args(self, params: RPCCompileParameters) -> None:
self.args.threads = params.threads

self.args.state = state_path(params.state)

self.set_previous_state()


Expand All @@ -107,7 +107,7 @@ def set_args(self, params: RPCRunParameters) -> None:
if params.threads is not None:
self.args.threads = params.threads
if params.defer is None:
self.args.defer = flags.DEFER_MODE
self.args.defer = get_flags().DEFER_MODE
else:
self.args.defer = params.defer

Expand Down Expand Up @@ -146,7 +146,7 @@ def set_args(self, params: RPCTestParameters) -> None:
if params.threads is not None:
self.args.threads = params.threads
if params.defer is None:
self.args.defer = flags.DEFER_MODE
self.args.defer = get_flags().DEFER_MODE
else:
self.args.defer = params.defer

Expand Down Expand Up @@ -332,7 +332,7 @@ def set_args(self, params: RPCBuildParameters) -> None:
if params.threads is not None:
self.args.threads = params.threads
if params.defer is None:
self.args.defer = flags.DEFER_MODE
self.args.defer = get_flags().DEFER_MODE
else:
self.args.defer = params.defer

Expand Down
9 changes: 8 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,14 @@ def dbt_profile_data(unique_schema, pytestconfig):
@pytest.fixture
def dbt_profile(profiles_root, dbt_profile_data) -> Dict[str, Any]:
flags.PROFILES_DIR = profiles_root
original_profile_dir_env = os.environ.get('DBT_PROFILES_DIR')
# we have to do this to make sure the subprocess uses the correct profiles
os.environ['DBT_PROFILES_DIR'] = str(profiles_root)
path = os.path.join(profiles_root, 'profiles.yml')
with open(path, 'w') as fp:
fp.write(yaml.safe_dump(dbt_profile_data))
return dbt_profile_data
yield dbt_profile_data
if original_profile_dir_env:
os.environ['DBT_PROFILES_DIR'] = original_profile_dir_env
else:
del os.environ['DBT_PROFILES_DIR']
4 changes: 4 additions & 0 deletions tests/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -648,6 +648,8 @@ def __init__(self, profiles_dir, which='run-operation', kwargs={}):
self.project_dir = None
self.profile = None
self.target = None
self.threads = None
self.selector = None
self.__dict__.update(kwargs)


Expand Down Expand Up @@ -678,6 +680,8 @@ def built_schema(project_dir, schema, profiles_dir, project_def):
os.chdir(project_dir)
start = os.getcwd()
try:
from dbt.flags import set_from_args
set_from_args(args, None)
cfg = RuntimeConfig.from_args(args)
finally:
os.chdir(start)
Expand Down

0 comments on commit 4e9b489

Please sign in to comment.