From 6a7dad5f6067713220eefbac39d023d606bba239 Mon Sep 17 00:00:00 2001 From: Michelle Ark Date: Tue, 10 Oct 2023 11:17:41 -0400 Subject: [PATCH] [Fix] respect project root when loading seeds (#8762) (cherry picked from commit 964e0e4e8ad50a917074aabc8493b210a84e0258) --- .../unreleased/Fixes-20231006-134551.yaml | 6 + core/dbt/context/providers.py | 13 +- core/dbt/tests/util.py | 545 ++++++++++++++++++ tests/functional/partial_parsing/fixtures.py | 66 +++ .../partial_parsing/test_partial_parsing.py | 86 +++ tests/functional/utils.py | 14 + 6 files changed, 728 insertions(+), 2 deletions(-) create mode 100644 .changes/unreleased/Fixes-20231006-134551.yaml create mode 100644 core/dbt/tests/util.py create mode 100644 tests/functional/partial_parsing/fixtures.py create mode 100644 tests/functional/partial_parsing/test_partial_parsing.py create mode 100644 tests/functional/utils.py diff --git a/.changes/unreleased/Fixes-20231006-134551.yaml b/.changes/unreleased/Fixes-20231006-134551.yaml new file mode 100644 index 00000000000..9dea4f6e194 --- /dev/null +++ b/.changes/unreleased/Fixes-20231006-134551.yaml @@ -0,0 +1,6 @@ +kind: Fixes +body: Enable seeds to be handled from stored manifest data +time: 2023-10-06T13:45:51.925546-04:00 +custom: + Author: michelleark + Issue: "6875" diff --git a/core/dbt/context/providers.py b/core/dbt/context/providers.py index 3b0ff96d004..e74722936c0 100644 --- a/core/dbt/context/providers.py +++ b/core/dbt/context/providers.py @@ -769,9 +769,18 @@ def load_agate_table(self) -> agate.Table: 'can only load_agate_table for seeds (got a {})' .format(self.model.resource_type) ) - path = os.path.join( - self.model.root_path, self.model.original_file_path + + # include package_path for seeds defined in packages + package_path = ( + os.path.join(self.config.packages_install_path, self.model.package_name) + if self.model.package_name != self.config.project_name + else "." 
) + path = os.path.join(self.config.project_root, package_path, self.model.original_file_path) + if not os.path.exists(path): + assert self.model.root_path + path = os.path.join(self.model.root_path, self.model.original_file_path) + column_types = self.model.config.column_types try: table = agate_helper.from_csv(path, text_columns=column_types) diff --git a/core/dbt/tests/util.py b/core/dbt/tests/util.py new file mode 100644 index 00000000000..98b5e1f4eb1 --- /dev/null +++ b/core/dbt/tests/util.py @@ -0,0 +1,545 @@ +from io import StringIO +import os +import shutil +import yaml +import json +import warnings +from datetime import datetime +from typing import Dict, List +from contextlib import contextmanager +from dbt.adapters.factory import Adapter + +from dbt.main import handle_and_check +from dbt.logger import log_manager +from dbt.contracts.graph.manifest import Manifest +from dbt.events.functions import fire_event, capture_stdout_logs, stop_capture_stdout_logs, reset_metadata_vars +from dbt.events.test_types import IntegrationTestDebug + +# ============================================================================= +# Test utilities +# run_dbt +# run_dbt_and_capture +# get_manifest +# copy_file +# rm_file +# write_file +# read_file +# mkdir +# rm_dir +# get_artifact +# update_config_file +# write_config_file +# get_unique_ids_in_results +# check_result_nodes_by_name +# check_result_nodes_by_unique_id + +# SQL related utilities that use the adapter +# run_sql_with_adapter +# relation_from_name +# check_relation_types (table/view) +# check_relations_equal +# check_relation_has_expected_schema +# check_relations_equal_with_relations +# check_table_does_exist +# check_table_does_not_exist +# get_relation_columns +# update_rows +# generate_update_clause +# +# Classes for comparing fields in dictionaries +# AnyFloat +# AnyInteger +# AnyString +# AnyStringWith +# ============================================================================= + + +# 'run_dbt' is used in pytest tests to run dbt commands. It will return +# different objects depending on the command that is executed. +# For a run command (and most other commands) it will return a list +# of results. For the 'docs generate' command it returns a CatalogArtifact. +# The first parameter is a list of dbt command line arguments, such as +# run_dbt(["run", "--vars", "seed_name: base"]) +# If the command is expected to fail, pass in "expect_pass=False"): +# run_dbt("test"], expect_pass=False) +def run_dbt(args: List[str] = None, expect_pass=True): + # Ignore logbook warnings + warnings.filterwarnings("ignore", category=DeprecationWarning, module="logbook") + + # reset global vars + reset_metadata_vars() + + # The logger will complain about already being initialized if + # we don't do this. + log_manager.reset_handlers() + if args is None: + args = ["run"] + + print("\n\nInvoking dbt with {}".format(args)) + res, success = handle_and_check(args) + + if expect_pass is not None: + assert success == expect_pass, "dbt exit state did not match expected" + + return res + + +# Use this if you need to capture the command logs in a test. +# If you want the logs that are normally written to a file, you must +# start with the "--debug" flag. The structured schema log CI test +# will turn the logs into json, so you have to be prepared for that. 
+def run_dbt_and_capture(args: List[str] = None, expect_pass=True): + try: + stringbuf = StringIO() + capture_stdout_logs(stringbuf) + res = run_dbt(args, expect_pass=expect_pass) + stdout = stringbuf.getvalue() + + finally: + stop_capture_stdout_logs() + + # Json logs will have lots of escape characters which will + # make checks for strings in the logs fail, so remove those. + if '{"code":' in stdout: + stdout = stdout.replace("\\", "") + + return res, stdout + + +# Used in test cases to get the manifest from the partial parsing file +# Note: this uses an internal version of the manifest, and in the future +# parts of it will not be supported for external use. +def get_manifest(project_root): + path = os.path.join(project_root, "target", "partial_parse.msgpack") + if os.path.exists(path): + with open(path, "rb") as fp: + manifest_mp = fp.read() + manifest: Manifest = Manifest.from_msgpack(manifest_mp) + return manifest + else: + return None + + +# Used in tests to copy a file, usually from a data directory to the project directory +def copy_file(src_path, src, dest_path, dest) -> None: + # dest is a list, so that we can provide nested directories, like 'models' etc. + # copy files from the data_dir to appropriate project directory + shutil.copyfile( + os.path.join(src_path, src), + os.path.join(dest_path, *dest), + ) + + +# Used in tests when you want to remove a file from the project directory +def rm_file(*paths) -> None: + # remove files from proj_path + os.remove(os.path.join(*paths)) + + +# Used in tests to write out the string contents of a file to a +# file in the project directory. +# We need to explicitly use encoding="utf-8" because otherwise on +# Windows we'll get codepage 1252 and things might break +def write_file(contents, *paths): + with open(os.path.join(*paths), "w", encoding="utf-8") as fp: + fp.write(contents) + + +# Used in test utilities +def read_file(*paths): + contents = "" + with open(os.path.join(*paths), "r") as fp: + contents = fp.read() + return contents + + +# To create a directory +def mkdir(directory_path): + try: + os.makedirs(directory_path) + except FileExistsError: + raise FileExistsError(f"{directory_path} already exists.") + + +# To remove a directory +def rm_dir(directory_path): + try: + shutil.rmtree(directory_path) + except FileNotFoundError: + raise FileNotFoundError(f"{directory_path} does not exist.") + + +def rename_dir(src_directory_path, dest_directory_path): + os.rename(src_directory_path, dest_directory_path) + + +# Get an artifact (usually from the target directory) such as +# manifest.json or catalog.json to use in a test +def get_artifact(*paths): + contents = read_file(*paths) + dct = json.loads(contents) + return dct + + +# For updating yaml config files +def update_config_file(updates, *paths): + current_yaml = read_file(*paths) + config = yaml.safe_load(current_yaml) + config.update(updates) + new_yaml = yaml.safe_dump(config) + write_file(new_yaml, *paths) + + +# Write new config file +def write_config_file(data, *paths): + if type(data) is dict: + data = yaml.safe_dump(data) + write_file(data, *paths) + + +# Get the unique_ids in dbt command results +def get_unique_ids_in_results(results): + unique_ids = [] + for result in results: + unique_ids.append(result.node.unique_id) + return unique_ids + + +# Check the nodes in the results returned by a dbt run command +def check_result_nodes_by_name(results, names): + result_names = [] + for result in results: + result_names.append(result.node.name) + assert set(names) == 
set(result_names) + + +# Check the nodes in the results returned by a dbt run command +def check_result_nodes_by_unique_id(results, unique_ids): + result_unique_ids = [] + for result in results: + result_unique_ids.append(result.node.unique_id) + assert set(unique_ids) == set(result_unique_ids) + + +# Check datetime is between start and end/now +def check_datetime_between(timestr, start, end=None): + datefmt = "%Y-%m-%dT%H:%M:%S.%fZ" + if end is None: + end = datetime.utcnow() + parsed = datetime.strptime(timestr, datefmt) + assert start <= parsed + assert end >= parsed + + +class TestProcessingException(Exception): + pass + + +# Testing utilities that use adapter code + +# Uses: +# adapter.config.credentials +# adapter.quote +# adapter.run_sql_for_tests +def run_sql_with_adapter(adapter, sql, fetch=None): + if sql.strip() == "": + return + + # substitute schema and database in sql + kwargs = { + "schema": adapter.config.credentials.schema, + "database": adapter.quote(adapter.config.credentials.database), + } + sql = sql.format(**kwargs) + + msg = f'test connection "__test" executing: {sql}' + fire_event(IntegrationTestDebug(msg=msg)) + with get_connection(adapter) as conn: + return adapter.run_sql_for_tests(sql, fetch, conn) + + +# Get a Relation object from the identifier (name of table/view). +# Uses the default database and schema. If you need a relation +# with a different schema, it should be constructed in the test. +# Uses: +# adapter.Relation +# adapter.config.credentials +# Relation.get_default_quote_policy +# Relation.get_default_include_policy +def relation_from_name(adapter, name: str): + """reverse-engineer a relation from a given name and + the adapter. The relation name is split by the '.' character. + """ + + # Different adapters have different Relation classes + cls = adapter.Relation + credentials = adapter.config.credentials + quote_policy = cls.get_default_quote_policy().to_dict() + include_policy = cls.get_default_include_policy().to_dict() + + # Make sure we have database/schema/identifier parts, even if + # only identifier was supplied. + relation_parts = name.split(".") + if len(relation_parts) == 1: + relation_parts.insert(0, credentials.schema) + if len(relation_parts) == 2: + relation_parts.insert(0, credentials.database) + kwargs = { + "database": relation_parts[0], + "schema": relation_parts[1], + "identifier": relation_parts[2], + } + + relation = cls.create( + include_policy=include_policy, + quote_policy=quote_policy, + **kwargs, + ) + return relation + + +# Ensure that models with different materialiations have the +# corrent table/view. +# Uses: +# adapter.list_relations_without_caching +def check_relation_types(adapter, relation_to_type): + """ + Relation name to table/view + { + "base": "table", + "other": "view", + } + """ + + expected_relation_values = {} + found_relations = [] + schemas = set() + + for key, value in relation_to_type.items(): + relation = relation_from_name(adapter, key) + expected_relation_values[relation] = value + schemas.add(relation.without_identifier()) + + with get_connection(adapter): + for schema in schemas: + found_relations.extend(adapter.list_relations_without_caching(schema)) + + for key, value in relation_to_type.items(): + for relation in found_relations: + # this might be too broad + if relation.identifier == key: + assert relation.type == value, ( + f"Got an unexpected relation type of {relation.type} " + f"for relation {key}, expected {value}" + ) + + +# Replaces assertTablesEqual. 
assertManyTablesEqual can be replaced +# by doing a separate call for each set of tables/relations. +# Wraps check_relations_equal_with_relations by creating relations +# from the list of names passed in. +def check_relations_equal(adapter, relation_names: List, compare_snapshot_cols=False): + if len(relation_names) < 2: + raise TestProcessingException( + "Not enough relations to compare", + ) + relations = [relation_from_name(adapter, name) for name in relation_names] + return check_relations_equal_with_relations( + adapter, relations, compare_snapshot_cols=compare_snapshot_cols + ) + + +# Used to check that a particular relation has an expected schema +# expected_schema should look like {"column_name": "expected datatype"} +def check_relation_has_expected_schema(adapter, relation_name, expected_schema: Dict): + relation = relation_from_name(adapter, relation_name) + with get_connection(adapter): + actual_columns = {c.name: c.data_type for c in adapter.get_columns_in_relation(relation)} + assert ( + actual_columns == expected_schema + ), f"Actual schema did not match expected, actual: {json.dumps(actual_columns)}" + + +# This can be used when checking relations in different schemas, by supplying +# a list of relations. Called by 'check_relations_equal'. +# Uses: +# adapter.get_columns_in_relation +# adapter.get_rows_different_sql +# adapter.execute +def check_relations_equal_with_relations( + adapter: Adapter, relations: List, compare_snapshot_cols=False +): + + with get_connection(adapter): + basis, compares = relations[0], relations[1:] + # Skip columns starting with "dbt_" because we don't want to + # compare those, since they are time sensitive + # (unless comparing "dbt_" snapshot columns is explicitly enabled) + column_names = [ + c.name + for c in adapter.get_columns_in_relation(basis) # type: ignore + if not c.name.lower().startswith("dbt_") or compare_snapshot_cols + ] + + for relation in compares: + sql = adapter.get_rows_different_sql(basis, relation, column_names=column_names) # type: ignore + _, tbl = adapter.execute(sql, fetch=True) + num_rows = len(tbl) + assert ( + num_rows == 1 + ), f"Invalid sql query from get_rows_different_sql: incorrect number of rows ({num_rows})" + num_cols = len(tbl[0]) + assert ( + num_cols == 2 + ), f"Invalid sql query from get_rows_different_sql: incorrect number of cols ({num_cols})" + row_count_difference = tbl[0][0] + assert ( + row_count_difference == 0 + ), f"Got {row_count_difference} difference in row count betwen {basis} and {relation}" + rows_mismatched = tbl[0][1] + assert ( + rows_mismatched == 0 + ), f"Got {rows_mismatched} different rows between {basis} and {relation}" + + +# Uses: +# adapter.update_column_sql +# adapter.execute +# adapter.commit_if_has_connection +def update_rows(adapter, update_rows_config): + """ + { + "name": "base", + "dst_col": "some_date" + "clause": { + "type": "add_timestamp", + "src_col": "some_date", + "where" "id > 10" + } + """ + for key in ["name", "dst_col", "clause"]: + if key not in update_rows_config: + raise TestProcessingException(f"Invalid update_rows: no {key}") + + clause = update_rows_config["clause"] + clause = generate_update_clause(adapter, clause) + + where = None + if "where" in update_rows_config: + where = update_rows_config["where"] + + name = update_rows_config["name"] + dst_col = update_rows_config["dst_col"] + relation = relation_from_name(adapter, name) + + with get_connection(adapter): + sql = adapter.update_column_sql( + dst_name=str(relation), + dst_column=dst_col, + 
clause=clause, + where_clause=where, + ) + adapter.execute(sql, auto_begin=True) + adapter.commit_if_has_connection() + + +# This is called by the 'update_rows' function. +# Uses: +# adapter.timestamp_add_sql +# adapter.string_add_sql +def generate_update_clause(adapter, clause) -> str: + """ + Called by update_rows function. Expects the "clause" dictionary + documented in 'update_rows. + """ + + if "type" not in clause or clause["type"] not in ["add_timestamp", "add_string"]: + raise TestProcessingException("invalid update_rows clause: type missing or incorrect") + clause_type = clause["type"] + + if clause_type == "add_timestamp": + if "src_col" not in clause: + raise TestProcessingException("Invalid update_rows clause: no src_col") + add_to = clause["src_col"] + kwargs = {k: v for k, v in clause.items() if k in ("interval", "number")} + with get_connection(adapter): + return adapter.timestamp_add_sql(add_to=add_to, **kwargs) + elif clause_type == "add_string": + for key in ["src_col", "value"]: + if key not in clause: + raise TestProcessingException(f"Invalid update_rows clause: no {key}") + src_col = clause["src_col"] + value = clause["value"] + location = clause.get("location", "append") + with get_connection(adapter): + return adapter.string_add_sql(src_col, value, location) + return "" + + +@contextmanager +def get_connection(adapter, name="_test"): + with adapter.connection_named(name): + conn = adapter.connections.get_thread_connection() + yield conn + + +# Uses: +# adapter.get_columns_in_relation +def get_relation_columns(adapter, name): + relation = relation_from_name(adapter, name) + with get_connection(adapter): + columns = adapter.get_columns_in_relation(relation) + return sorted(((c.name, c.dtype, c.char_size) for c in columns), key=lambda x: x[0]) + + +def check_table_does_not_exist(adapter, name): + columns = get_relation_columns(adapter, name) + assert len(columns) == 0 + + +def check_table_does_exist(adapter, name): + columns = get_relation_columns(adapter, name) + assert len(columns) > 0 + + +# Utility classes for enabling comparison of dictionaries + + +class AnyFloat: + """Any float. Use this in assert calls""" + + def __eq__(self, other): + return isinstance(other, float) + + +class AnyInteger: + """Any Integer. Use this in assert calls""" + + def __eq__(self, other): + return isinstance(other, int) + + +class AnyString: + """Any string. 
Use this in assert calls""" + + def __eq__(self, other): + return isinstance(other, str) + + +class AnyStringWith: + """AnyStringWith("AUTO")""" + + def __init__(self, contains=None): + self.contains = contains + + def __eq__(self, other): + if not isinstance(other, str): + return False + + if self.contains is None: + return True + + return self.contains in other + + def __repr__(self): + return "AnyStringWith<{!r}>".format(self.contains) diff --git a/tests/functional/partial_parsing/fixtures.py b/tests/functional/partial_parsing/fixtures.py new file mode 100644 index 00000000000..6938b0bfcd7 --- /dev/null +++ b/tests/functional/partial_parsing/fixtures.py @@ -0,0 +1,66 @@ +local_dependency__dbt_project_yml = """ + +name: 'local_dep' +version: '1.0' +config-version: 2 + +profile: 'default' + +model-paths: ["models"] +analysis-paths: ["analyses"] +test-paths: ["tests"] +seed-paths: ["seeds"] +macro-paths: ["macros"] + +require-dbt-version: '>=0.1.0' + +target-path: "target" # directory which will store compiled SQL files +clean-targets: # directories to be removed by `dbt clean` + - "target" + - "dbt_packages" + + +seeds: + quote_columns: False + +""" + +local_dependency__models__schema_yml = """ +version: 2 +sources: + - name: seed_source + schema: "{{ var('schema_override', target.schema) }}" + tables: + - name: "seed" + columns: + - name: id + tests: + - unique + +""" + +local_dependency__models__model_to_import_sql = """ +select * from {{ ref('seed') }} + +""" + +local_dependency__macros__dep_macro_sql = """ +{% macro some_overridden_macro() -%} +100 +{%- endmacro %} + +""" + +local_dependency__seeds__seed_csv = """id +1 +""" + +model_one_sql = """ +select 1 as fun + +""" + +model_two_sql = """ +select 1 as notfun + +""" diff --git a/tests/functional/partial_parsing/test_partial_parsing.py b/tests/functional/partial_parsing/test_partial_parsing.py new file mode 100644 index 00000000000..62a6553f5cf --- /dev/null +++ b/tests/functional/partial_parsing/test_partial_parsing.py @@ -0,0 +1,86 @@ +from argparse import Namespace +import pytest + +import dbt.flags as flags +from dbt.tests.util import ( + run_dbt, + write_file, + rename_dir, +) +from tests.functional.utils import up_one +from dbt.tests.fixtures.project import write_project_files +from tests.functional.partial_parsing.fixtures import ( + model_one_sql, + model_two_sql, + local_dependency__dbt_project_yml, + local_dependency__models__schema_yml, + local_dependency__models__model_to_import_sql, + local_dependency__macros__dep_macro_sql, + local_dependency__seeds__seed_csv, +) + +import os + +os.environ["DBT_PP_TEST"] = "true" + + +class TestPortablePartialParsing: + @pytest.fixture(scope="class") + def models(self): + return { + "model_one.sql": model_one_sql, + } + + @pytest.fixture(scope="class") + def packages(self): + return {"packages": [{"local": "local_dependency"}]} + + @pytest.fixture(scope="class") + def local_dependency_files(self): + return { + "dbt_project.yml": local_dependency__dbt_project_yml, + "models": { + "schema.yml": local_dependency__models__schema_yml, + "model_to_import.sql": local_dependency__models__model_to_import_sql, + }, + "macros": {"dep_macro.sql": local_dependency__macros__dep_macro_sql}, + "seeds": {"seed.csv": local_dependency__seeds__seed_csv}, + } + + def rename_project_root(self, project, new_project_root): + with up_one(new_project_root): + rename_dir(project.project_root, new_project_root) + project.project_root = new_project_root + # flags.project_dir is set during the project test fixture, 
and is persisted across run_dbt calls,
+        # so it needs to be reset between invocations
+        flags.set_from_args(Namespace(PROJECT_DIR=new_project_root), None)
+
+    @pytest.fixture(scope="class", autouse=True)
+    def initial_run_and_rename_project_dir(self, project, local_dependency_files):
+        initial_project_root = project.project_root
+        renamed_project_root = os.path.join(project.project_root.dirname, "renamed_project_dir")
+
+        write_project_files(project.project_root, "local_dependency", local_dependency_files)
+
+        # initial run
+        run_dbt(["deps"])
+        assert len(run_dbt(["seed"])) == 1
+        assert len(run_dbt(["run"])) == 2
+
+        self.rename_project_root(project, renamed_project_root)
+        yield
+        self.rename_project_root(project, initial_project_root)
+
+    def test_pp_renamed_project_dir_unchanged_project_contents(self, project):
+        # partial parse same project in new absolute dir location, using partial_parse.msgpack created in previous dir
+        run_dbt(["deps"])
+        assert len(run_dbt(["--partial-parse", "seed"])) == 1
+        assert len(run_dbt(["--partial-parse", "run"])) == 2
+
+    def test_pp_renamed_project_dir_changed_project_contents(self, project):
+        write_file(model_two_sql, project.project_root, "models", "model_two.sql")
+
+        # partial parse changed project in new absolute dir location, using partial_parse.msgpack created in previous dir
+        run_dbt(["deps"])
+        assert len(run_dbt(["--partial-parse", "seed"])) == 1
+        assert len(run_dbt(["--partial-parse", "run"])) == 3
diff --git a/tests/functional/utils.py b/tests/functional/utils.py
new file mode 100644
index 00000000000..ddfe367856b
--- /dev/null
+++ b/tests/functional/utils.py
@@ -0,0 +1,14 @@
+import os
+from contextlib import contextmanager
+from typing import Optional
+from pathlib import Path
+
+
+@contextmanager
+def up_one(return_path: Optional[Path] = None):
+    current_path = Path.cwd()
+    os.chdir("../")
+    try:
+        yield
+    finally:
+        os.chdir(return_path or current_path)
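
For orientation, the behavioral core of the providers.py change above is small: a seed that belongs to an installed package is now resolved under packages_install_path, relative to the runtime project_root, and the node's stored root_path is only used as a fallback when the computed path does not exist. Below is a minimal standalone sketch of that resolution logic, using hypothetical stand-ins for dbt's real RuntimeConfig and SeedNode classes rather than the classes themselves:

import os
from dataclasses import dataclass


@dataclass
class SeedNode:
    # hypothetical stand-in for the fields load_agate_table reads from the node
    package_name: str
    original_file_path: str
    root_path: str


@dataclass
class RuntimeConfigStub:
    # hypothetical stand-in for the config attributes used in the patch
    project_name: str
    project_root: str
    packages_install_path: str = "dbt_packages"


def resolve_seed_path(config: RuntimeConfigStub, seed: SeedNode) -> str:
    # seeds defined in packages live under <project_root>/<packages_install_path>/<package_name>;
    # seeds in the root project resolve directly against project_root
    package_path = (
        os.path.join(config.packages_install_path, seed.package_name)
        if seed.package_name != config.project_name
        else "."
    )
    path = os.path.join(config.project_root, package_path, seed.original_file_path)
    if not os.path.exists(path):
        # fall back to the absolute root_path stored on the node (the pre-fix behavior)
        path = os.path.join(seed.root_path, seed.original_file_path)
    return path

Because the primary path is built from config.project_root at runtime instead of the root_path serialized into the manifest, a stored partial_parse.msgpack keeps working after the project directory is renamed, which is exactly the scenario TestPortablePartialParsing exercises.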
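
The comments in core/dbt/tests/util.py describe the intended pytest usage of run_dbt, run_dbt_and_capture, and get_manifest. A hedged sketch of a typical functional test follows, assuming dbt's project fixture (from the dbt.tests.fixtures.project plugin used in the test above) is available and using an illustrative one-model project:

import pytest

from dbt.tests.util import get_manifest, run_dbt, run_dbt_and_capture


class TestUtilUsageSketch:
    @pytest.fixture(scope="class")
    def models(self):
        # hypothetical single-model project
        return {"my_model.sql": "select 1 as id"}

    def test_run_and_inspect(self, project):
        # run_dbt returns a list of results for most commands;
        # pass expect_pass=False when the invocation is expected to fail
        results = run_dbt(["run"])
        assert len(results) == 1

        # run_dbt_and_capture also returns the captured stdout logs;
        # start the args with "--debug" to get the file-style log lines
        _, stdout = run_dbt_and_capture(["--debug", "run"])
        assert "my_model" in stdout

        # get_manifest loads target/partial_parse.msgpack if present, else returns None
        manifest = get_manifest(project.project_root)
        assert manifest is not None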
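
The adapter-backed helpers (relation_from_name, check_relations_equal, check_relation_types) are driven through a live adapter; in functional tests that is normally the adapter attached to the project fixture. A sketch under the assumption of a project with a 'base' seed and a 'view_model' view built from it (both names illustrative, as is the project.adapter attribute):

from dbt.tests.util import (
    check_relation_types,
    check_relations_equal,
    relation_from_name,
    run_dbt,
)


def test_seed_and_view_sketch(project):
    run_dbt(["seed"])
    run_dbt(["run"])

    # relation_from_name fills in the default database/schema for a bare identifier
    relation = relation_from_name(project.adapter, "base")
    assert relation.identifier == "base"

    # the view should hold exactly the same rows as the seed it selects from
    check_relations_equal(project.adapter, ["base", "view_model"])

    # and the materializations should be what the project configures
    check_relation_types(project.adapter, {"base": "table", "view_model": "view"})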
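
Finally, the up_one context manager in tests/functional/utils.py lets the test step out of the project directory before renaming it, presumably because renaming the process's current working directory can fail on some platforms (notably Windows); rename_project_root above uses it around rename_dir for that reason. A small illustration of the same pattern against a throwaway temp directory (the paths are hypothetical, and importing from the tests package assumes you are running inside the dbt-core repo):

import os
import tempfile
from pathlib import Path

from dbt.tests.util import rename_dir
from tests.functional.utils import up_one


def demo_rename_cwd_safely():
    base = Path(tempfile.mkdtemp())
    src = base / "project_dir"
    dest = base / "renamed_project_dir"
    src.mkdir()

    os.chdir(src)
    # step out of the directory being renamed, then land in its new location
    with up_one(dest):
        rename_dir(src, dest)
    assert Path.cwd().name == "renamed_project_dir"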