From e01d4c0a6e73c320604953959a5d7e11e115c033 Mon Sep 17 00:00:00 2001 From: Michelle Ark Date: Wed, 28 Jun 2023 07:55:11 -0700 Subject: [PATCH] Add restrict-access to dbt_project.yml (#7962) --- .../unreleased/Features-20230627-132749.yaml | 6 + core/dbt/config/project.py | 3 + core/dbt/config/runtime.py | 1 + core/dbt/context/providers.py | 29 ++-- core/dbt/contracts/graph/manifest.py | 46 +++++- core/dbt/contracts/project.py | 1 + core/dbt/exceptions.py | 10 +- core/dbt/parser/manifest.py | 51 ++++--- .../{groups => access}/test_access.py | 134 ++++++++++++++++++ 9 files changed, 243 insertions(+), 38 deletions(-) create mode 100644 .changes/unreleased/Features-20230627-132749.yaml rename tests/functional/{groups => access}/test_access.py (63%) diff --git a/.changes/unreleased/Features-20230627-132749.yaml b/.changes/unreleased/Features-20230627-132749.yaml new file mode 100644 index 00000000000..fbd427a43da --- /dev/null +++ b/.changes/unreleased/Features-20230627-132749.yaml @@ -0,0 +1,6 @@ +kind: Features +body: Add restrict-access to dbt_project.yml +time: 2023-06-27T13:27:49.114257-04:00 +custom: + Author: michelleark + Issue: "7713" diff --git a/core/dbt/config/project.py b/core/dbt/config/project.py index 0fd2165f660..baa99239e99 100644 --- a/core/dbt/config/project.py +++ b/core/dbt/config/project.py @@ -497,6 +497,7 @@ def create_project(self, rendered: RenderComponents) -> "Project": config_version=cfg.config_version, unrendered=unrendered, project_env_vars=project_env_vars, + restrict_access=cfg.restrict_access, ) # sanity check - this means an internal issue project.validate() @@ -607,6 +608,7 @@ class Project: config_version: int unrendered: RenderComponents project_env_vars: Dict[str, Any] + restrict_access: bool @property def all_source_paths(self) -> List[str]: @@ -675,6 +677,7 @@ def to_project_config(self, with_packages=False): "vars": self.vars.to_dict(), "require-dbt-version": [v.to_version_string() for v in self.dbt_version], "config-version": self.config_version, + "restrict-access": self.restrict_access, } ) if self.query_comment: diff --git a/core/dbt/config/runtime.py b/core/dbt/config/runtime.py index 28416f68519..8e11b2cd43a 100644 --- a/core/dbt/config/runtime.py +++ b/core/dbt/config/runtime.py @@ -172,6 +172,7 @@ def from_parts( config_version=project.config_version, unrendered=project.unrendered, project_env_vars=project.project_env_vars, + restrict_access=project.restrict_access, profile_env_vars=profile.profile_env_vars, profile_name=profile.profile_name, target_name=profile.target_name, diff --git a/core/dbt/context/providers.py b/core/dbt/context/providers.py index 28c6681be7e..989feb9b155 100644 --- a/core/dbt/context/providers.py +++ b/core/dbt/context/providers.py @@ -506,19 +506,24 @@ def resolve( target_version=target_version, disabled=isinstance(target_model, Disabled), ) - elif ( - target_model.resource_type == NodeType.Model - and target_model.access == AccessType.Private - # don't raise this reference error for ad hoc 'preview' queries - and self.model.resource_type != NodeType.SqlOperation - and self.model.resource_type != NodeType.RPCCall # TODO: rm + elif self.manifest.is_invalid_private_ref( + self.model, target_model, self.config.dependencies ): - if not self.model.group or self.model.group != target_model.group: - raise DbtReferenceError( - unique_id=self.model.unique_id, - ref_unique_id=target_model.unique_id, - group=cast_to_str(target_model.group), - ) + raise DbtReferenceError( + unique_id=self.model.unique_id, + ref_unique_id=target_model.unique_id, + access=AccessType.Private, + scope=cast_to_str(target_model.group), + ) + elif self.manifest.is_invalid_protected_ref( + self.model, target_model, self.config.dependencies + ): + raise DbtReferenceError( + unique_id=self.model.unique_id, + ref_unique_id=target_model.unique_id, + access=AccessType.Protected, + scope=target_model.package_name, + ) self.validate(target_model, target_name, target_package, target_version) return self.create_relation(target_model) diff --git a/core/dbt/contracts/graph/manifest.py b/core/dbt/contracts/graph/manifest.py index 971383c831d..e2eb1ae410d 100644 --- a/core/dbt/contracts/graph/manifest.py +++ b/core/dbt/contracts/graph/manifest.py @@ -55,7 +55,7 @@ from dbt.events.functions import fire_event from dbt.events.types import MergedFromState, UnpinnedRefNewVersionAvailable from dbt.events.contextvars import get_node_info -from dbt.node_types import NodeType +from dbt.node_types import NodeType, AccessType from dbt.flags import get_flags, MP_CONTEXT from dbt import tracking import dbt.utils @@ -1123,6 +1123,50 @@ def resolve_doc( return result return None + def is_invalid_private_ref( + self, node: GraphMemberNode, target_model: MaybeNonSource, dependencies: Optional[Mapping] + ) -> bool: + dependencies = dependencies or {} + if not isinstance(target_model, ModelNode): + return False + + is_private_ref = ( + target_model.access == AccessType.Private + # don't raise this reference error for ad hoc 'preview' queries + and node.resource_type != NodeType.SqlOperation + and node.resource_type != NodeType.RPCCall # TODO: rm + ) + target_dependency = dependencies.get(target_model.package_name) + restrict_package_access = target_dependency.restrict_access if target_dependency else False + + # TODO: SemanticModel and SourceDefinition do not have group, and so should not be able to make _any_ private ref. + return is_private_ref and ( + not hasattr(node, "group") + or not node.group + or node.group != target_model.group + or restrict_package_access + ) + + def is_invalid_protected_ref( + self, node: GraphMemberNode, target_model: MaybeNonSource, dependencies: Optional[Mapping] + ) -> bool: + dependencies = dependencies or {} + if not isinstance(target_model, ModelNode): + return False + + is_protected_ref = ( + target_model.access == AccessType.Protected + # don't raise this reference error for ad hoc 'preview' queries + and node.resource_type != NodeType.SqlOperation + and node.resource_type != NodeType.RPCCall # TODO: rm + ) + target_dependency = dependencies.get(target_model.package_name) + restrict_package_access = target_dependency.restrict_access if target_dependency else False + + return is_protected_ref and ( + node.package_name != target_model.package_name and restrict_package_access + ) + # Called by RunTask.defer_to_manifest def merge_from_artifact( self, diff --git a/core/dbt/contracts/project.py b/core/dbt/contracts/project.py index 581932e5888..d9bd0c6fb89 100644 --- a/core/dbt/contracts/project.py +++ b/core/dbt/contracts/project.py @@ -223,6 +223,7 @@ class Project(HyphenatedDbtClassMixin, Replaceable): ) packages: List[PackageSpec] = field(default_factory=list) query_comment: Optional[Union[QueryComment, NoValue, str]] = field(default_factory=NoValue) + restrict_access: bool = False @classmethod def validate(cls, data): diff --git a/core/dbt/exceptions.py b/core/dbt/exceptions.py index 2f714921292..665c7f6dfd9 100644 --- a/core/dbt/exceptions.py +++ b/core/dbt/exceptions.py @@ -7,7 +7,7 @@ from dbt.dataclass_schema import ValidationError from dbt.events.helpers import env_secrets, scrub_secrets -from dbt.node_types import NodeType +from dbt.node_types import NodeType, AccessType from dbt.ui import line_wrap_message import dbt.dataclass_schema @@ -1219,16 +1219,18 @@ def __init__(self, exc: ValidationError, node): class DbtReferenceError(ParsingError): - def __init__(self, unique_id: str, ref_unique_id: str, group: str): + def __init__(self, unique_id: str, ref_unique_id: str, access: AccessType, scope: str): self.unique_id = unique_id self.ref_unique_id = ref_unique_id - self.group = group + self.access = access + self.scope = scope + self.scope_type = "group" if self.access == AccessType.Private else "package" super().__init__(msg=self.get_message()) def get_message(self) -> str: return ( f"Node {self.unique_id} attempted to reference node {self.ref_unique_id}, " - f"which is not allowed because the referenced node is private to the {self.group} group." + f"which is not allowed because the referenced node is {self.access} to the '{self.scope}' {self.scope_type}." ) diff --git a/core/dbt/parser/manifest.py b/core/dbt/parser/manifest.py index 93f4d21fb7c..b98b6fe2fc9 100644 --- a/core/dbt/parser/manifest.py +++ b/core/dbt/parser/manifest.py @@ -508,7 +508,7 @@ def load(self): # determine whether they need processing. start_process = time.perf_counter() self.process_sources(self.root_project.project_name) - self.process_refs(self.root_project.project_name) + self.process_refs(self.root_project.project_name, self.root_project.dependencies) self.process_docs(self.root_project) self.process_metrics(self.root_project) self.check_valid_group_config() @@ -533,7 +533,10 @@ def load(self): external_nodes_modified = self.inject_external_nodes() if external_nodes_modified: self.manifest.rebuild_ref_lookup() - self.process_refs(self.root_project.project_name) + self.process_refs( + self.root_project.project_name, + self.root_project.dependencies, + ) # parent and child maps will be rebuilt by write_manifest if not skip_parsing: @@ -1038,23 +1041,23 @@ def track_project_load(self): # Takes references in 'refs' array of nodes and exposures, finds the target # node, and updates 'depends_on.nodes' with the unique id - def process_refs(self, current_project: str): + def process_refs(self, current_project: str, dependencies: Optional[Dict[str, Project]]): for node in self.manifest.nodes.values(): if node.created_at < self.started_at: continue - _process_refs(self.manifest, current_project, node) + _process_refs(self.manifest, current_project, node, dependencies) for exposure in self.manifest.exposures.values(): if exposure.created_at < self.started_at: continue - _process_refs(self.manifest, current_project, exposure) + _process_refs(self.manifest, current_project, exposure, dependencies) for metric in self.manifest.metrics.values(): if metric.created_at < self.started_at: continue - _process_refs(self.manifest, current_project, metric) + _process_refs(self.manifest, current_project, metric, dependencies) for semantic_model in self.manifest.semantic_models.values(): if semantic_model.created_at < self.started_at: continue - _process_refs(self.manifest, current_project, semantic_model) + _process_refs(self.manifest, current_project, semantic_model, dependencies) self.update_semantic_model(semantic_model) # Takes references in 'metrics' array of nodes and exposures, finds the target @@ -1372,9 +1375,13 @@ def _process_docs_for_metrics(context: Dict[str, Any], metric: Metric) -> None: metric.description = get_rendered(metric.description, context) -def _process_refs(manifest: Manifest, current_project: str, node) -> None: +def _process_refs( + manifest: Manifest, current_project: str, node, dependencies: Optional[Mapping[str, Project]] +) -> None: """Given a manifest and node in that manifest, process its refs""" + dependencies = dependencies or {} + if isinstance(node, SeedNode): return @@ -1413,18 +1420,20 @@ def _process_refs(manifest: Manifest, current_project: str, node) -> None: ) continue - elif ( - isinstance(target_model, ModelNode) - and target_model.access == AccessType.Private - and node.resource_type != NodeType.SqlOperation - and node.resource_type != NodeType.RPCCall # TODO: rm - ): - if not node.group or node.group != target_model.group: - raise dbt.exceptions.DbtReferenceError( - unique_id=node.unique_id, - ref_unique_id=target_model.unique_id, - group=dbt.utils.cast_to_str(target_model.group), - ) + elif manifest.is_invalid_private_ref(node, target_model, dependencies): + raise dbt.exceptions.DbtReferenceError( + unique_id=node.unique_id, + ref_unique_id=target_model.unique_id, + access=AccessType.Private, + scope=dbt.utils.cast_to_str(target_model.group), + ) + elif manifest.is_invalid_protected_ref(node, target_model, dependencies): + raise dbt.exceptions.DbtReferenceError( + unique_id=node.unique_id, + ref_unique_id=target_model.unique_id, + access=AccessType.Protected, + scope=target_model.package_name, + ) target_model_id = target_model.unique_id node.depends_on.add_node(target_model_id) @@ -1577,7 +1586,7 @@ def process_macro(config: RuntimeConfig, manifest: Manifest, macro: Macro) -> No def process_node(config: RuntimeConfig, manifest: Manifest, node: ManifestNode): _process_sources_for_node(manifest, config.project_name, node) - _process_refs(manifest, config.project_name, node) + _process_refs(manifest, config.project_name, node, config.dependencies) ctx = generate_runtime_docs_context(config, node, manifest, config.project_name) _process_docs_for_node(ctx, node) diff --git a/tests/functional/groups/test_access.py b/tests/functional/access/test_access.py similarity index 63% rename from tests/functional/groups/test_access.py rename to tests/functional/access/test_access.py index fdffe4e1abd..4e9551d08a3 100644 --- a/tests/functional/groups/test_access.py +++ b/tests/functional/access/test_access.py @@ -1,4 +1,7 @@ import pytest + +from dbt.tests.fixtures.project import write_project_files +from tests.fixtures.dbt_integration_project import dbt_integration_project # noqa: F401 from dbt.tests.util import run_dbt, get_manifest, write_file, rm_file from dbt.node_types import AccessType from dbt.exceptions import InvalidAccessTypeError, DbtReferenceError @@ -152,6 +155,55 @@ """ +dbt_integration_project__dbt_project_yml_restrited_access = """ +name: dbt_integration_project +version: '1.0' +config-version: 2 + +model-paths: ["models"] # paths to models +analysis-paths: ["analyses"] # path with analysis files which are compiled, but not run +target-path: "target" # path for compiled code +clean-targets: ["target"] # directories removed by the clean task +test-paths: ["tests"] # where to store test results +seed-paths: ["seeds"] # load CSVs from this directory with `dbt seed` +macro-paths: ["macros"] # where to find macros + +profile: user + +models: + dbt_integration_project: + +restrict-access: True +""" + + +dbt_integration_project__schema_yml_protected_model = """ +version: 2 +models: +- name: table_model + access: protected +""" + +dbt_integration_project__schema_yml_private_model = """ +version: 2 +models: +- name: table_model + access: private + group: package +""" + +ref_package_model_sql = """ + select * from {{ ref('dbt_integration_project', 'table_model') }} +""" + +schema_yml_ref_package_model = """ +version: 2 +models: +- name: ref_package_model + group: package +""" + + class TestAccess: @pytest.fixture(scope="class") def models(self): @@ -233,3 +285,85 @@ def test_access_attribute(self, project): manifest = get_manifest(project.project_root) metric_id = "metric.test.number_of_people" assert manifest.metrics[metric_id].group == "analytics" + + +class TestUnrestrictedPackageAccess: + @pytest.fixture(scope="class", autouse=True) + def setUp(self, project_root, dbt_integration_project): # noqa: F811 + write_project_files(project_root, "dbt_integration_project", dbt_integration_project) + + @pytest.fixture(scope="class") + def packages(self): + return {"packages": [{"local": "dbt_integration_project"}]} + + @pytest.fixture(scope="class") + def models(self): + return {"ref_protected_package_model.sql": ref_package_model_sql} + + def test_unrestricted_protected_ref(self, project): + write_file( + dbt_integration_project__schema_yml_protected_model, + project.project_root, + "dbt_integration_project", + "models", + "schema.yml", + ) + run_dbt(["deps"]) + + # Runs without issue because restrict-access defaults to False + manifest = run_dbt(["parse"]) + assert len(manifest.nodes) == 4 + root_project_model = manifest.nodes["model.test.ref_protected_package_model"] + assert root_project_model.depends_on_nodes == ["model.dbt_integration_project.table_model"] + + +class TestRestrictedPackageAccess: + @pytest.fixture(scope="class", autouse=True) + def setUp(self, project_root, dbt_integration_project): # noqa: F811 + write_project_files(project_root, "dbt_integration_project", dbt_integration_project) + # Set table_model.access to protected + write_file( + dbt_integration_project__schema_yml_protected_model, + project_root, + "dbt_integration_project", + "models", + "schema.yml", + ) + # Set dbt_integration_project.restrict-access to True + write_file( + dbt_integration_project__dbt_project_yml_restrited_access, + project_root, + "dbt_integration_project", + "dbt_project.yml", + ) + + @pytest.fixture(scope="class") + def packages(self): + return {"packages": [{"local": "dbt_integration_project"}]} + + @pytest.fixture(scope="class") + def models(self): + return { + "ref_package_model.sql": ref_package_model_sql, + "schema.yml": schema_yml_ref_package_model, + } + + def test_restricted_protected_ref(self, project): + run_dbt(["deps"]) + with pytest.raises(DbtReferenceError): + run_dbt(["parse"]) + + def test_restricted_private_ref(self, project): + run_dbt(["deps"]) + + # Set table_model.access to private + write_file( + dbt_integration_project__schema_yml_private_model, + project.project_root, + "dbt_integration_project", + "models", + "schema.yml", + ) + + with pytest.raises(DbtReferenceError): + run_dbt(["parse"])