Skip to content

Commit

Permalink
Add restrict-access to dbt_project.yml (#7962)
Browse files Browse the repository at this point in the history
  • Loading branch information
MichelleArk authored Jun 28, 2023
1 parent 7a6beda commit e01d4c0
Show file tree
Hide file tree
Showing 9 changed files with 243 additions and 38 deletions.
6 changes: 6 additions & 0 deletions .changes/unreleased/Features-20230627-132749.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
kind: Features
body: Add restrict-access to dbt_project.yml
time: 2023-06-27T13:27:49.114257-04:00
custom:
Author: michelleark
Issue: "7713"
3 changes: 3 additions & 0 deletions core/dbt/config/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -497,6 +497,7 @@ def create_project(self, rendered: RenderComponents) -> "Project":
config_version=cfg.config_version,
unrendered=unrendered,
project_env_vars=project_env_vars,
restrict_access=cfg.restrict_access,
)
# sanity check - this means an internal issue
project.validate()
Expand Down Expand Up @@ -607,6 +608,7 @@ class Project:
config_version: int
unrendered: RenderComponents
project_env_vars: Dict[str, Any]
restrict_access: bool

@property
def all_source_paths(self) -> List[str]:
Expand Down Expand Up @@ -675,6 +677,7 @@ def to_project_config(self, with_packages=False):
"vars": self.vars.to_dict(),
"require-dbt-version": [v.to_version_string() for v in self.dbt_version],
"config-version": self.config_version,
"restrict-access": self.restrict_access,
}
)
if self.query_comment:
Expand Down
1 change: 1 addition & 0 deletions core/dbt/config/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,7 @@ def from_parts(
config_version=project.config_version,
unrendered=project.unrendered,
project_env_vars=project.project_env_vars,
restrict_access=project.restrict_access,
profile_env_vars=profile.profile_env_vars,
profile_name=profile.profile_name,
target_name=profile.target_name,
Expand Down
29 changes: 17 additions & 12 deletions core/dbt/context/providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -506,19 +506,24 @@ def resolve(
target_version=target_version,
disabled=isinstance(target_model, Disabled),
)
elif (
target_model.resource_type == NodeType.Model
and target_model.access == AccessType.Private
# don't raise this reference error for ad hoc 'preview' queries
and self.model.resource_type != NodeType.SqlOperation
and self.model.resource_type != NodeType.RPCCall # TODO: rm
elif self.manifest.is_invalid_private_ref(
self.model, target_model, self.config.dependencies
):
if not self.model.group or self.model.group != target_model.group:
raise DbtReferenceError(
unique_id=self.model.unique_id,
ref_unique_id=target_model.unique_id,
group=cast_to_str(target_model.group),
)
raise DbtReferenceError(
unique_id=self.model.unique_id,
ref_unique_id=target_model.unique_id,
access=AccessType.Private,
scope=cast_to_str(target_model.group),
)
elif self.manifest.is_invalid_protected_ref(
self.model, target_model, self.config.dependencies
):
raise DbtReferenceError(
unique_id=self.model.unique_id,
ref_unique_id=target_model.unique_id,
access=AccessType.Protected,
scope=target_model.package_name,
)

self.validate(target_model, target_name, target_package, target_version)
return self.create_relation(target_model)
Expand Down
46 changes: 45 additions & 1 deletion core/dbt/contracts/graph/manifest.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@
from dbt.events.functions import fire_event
from dbt.events.types import MergedFromState, UnpinnedRefNewVersionAvailable
from dbt.events.contextvars import get_node_info
from dbt.node_types import NodeType
from dbt.node_types import NodeType, AccessType
from dbt.flags import get_flags, MP_CONTEXT
from dbt import tracking
import dbt.utils
Expand Down Expand Up @@ -1123,6 +1123,50 @@ def resolve_doc(
return result
return None

def is_invalid_private_ref(
self, node: GraphMemberNode, target_model: MaybeNonSource, dependencies: Optional[Mapping]
) -> bool:
dependencies = dependencies or {}
if not isinstance(target_model, ModelNode):
return False

is_private_ref = (
target_model.access == AccessType.Private
# don't raise this reference error for ad hoc 'preview' queries
and node.resource_type != NodeType.SqlOperation
and node.resource_type != NodeType.RPCCall # TODO: rm
)
target_dependency = dependencies.get(target_model.package_name)
restrict_package_access = target_dependency.restrict_access if target_dependency else False

# TODO: SemanticModel and SourceDefinition do not have group, and so should not be able to make _any_ private ref.
return is_private_ref and (
not hasattr(node, "group")
or not node.group
or node.group != target_model.group
or restrict_package_access
)

def is_invalid_protected_ref(
self, node: GraphMemberNode, target_model: MaybeNonSource, dependencies: Optional[Mapping]
) -> bool:
dependencies = dependencies or {}
if not isinstance(target_model, ModelNode):
return False

is_protected_ref = (
target_model.access == AccessType.Protected
# don't raise this reference error for ad hoc 'preview' queries
and node.resource_type != NodeType.SqlOperation
and node.resource_type != NodeType.RPCCall # TODO: rm
)
target_dependency = dependencies.get(target_model.package_name)
restrict_package_access = target_dependency.restrict_access if target_dependency else False

return is_protected_ref and (
node.package_name != target_model.package_name and restrict_package_access
)

# Called by RunTask.defer_to_manifest
def merge_from_artifact(
self,
Expand Down
1 change: 1 addition & 0 deletions core/dbt/contracts/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,7 @@ class Project(HyphenatedDbtClassMixin, Replaceable):
)
packages: List[PackageSpec] = field(default_factory=list)
query_comment: Optional[Union[QueryComment, NoValue, str]] = field(default_factory=NoValue)
restrict_access: bool = False

@classmethod
def validate(cls, data):
Expand Down
10 changes: 6 additions & 4 deletions core/dbt/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from dbt.dataclass_schema import ValidationError
from dbt.events.helpers import env_secrets, scrub_secrets
from dbt.node_types import NodeType
from dbt.node_types import NodeType, AccessType
from dbt.ui import line_wrap_message

import dbt.dataclass_schema
Expand Down Expand Up @@ -1219,16 +1219,18 @@ def __init__(self, exc: ValidationError, node):


class DbtReferenceError(ParsingError):
def __init__(self, unique_id: str, ref_unique_id: str, group: str):
def __init__(self, unique_id: str, ref_unique_id: str, access: AccessType, scope: str):
self.unique_id = unique_id
self.ref_unique_id = ref_unique_id
self.group = group
self.access = access
self.scope = scope
self.scope_type = "group" if self.access == AccessType.Private else "package"
super().__init__(msg=self.get_message())

def get_message(self) -> str:
return (
f"Node {self.unique_id} attempted to reference node {self.ref_unique_id}, "
f"which is not allowed because the referenced node is private to the {self.group} group."
f"which is not allowed because the referenced node is {self.access} to the '{self.scope}' {self.scope_type}."
)


Expand Down
51 changes: 30 additions & 21 deletions core/dbt/parser/manifest.py
Original file line number Diff line number Diff line change
Expand Up @@ -508,7 +508,7 @@ def load(self):
# determine whether they need processing.
start_process = time.perf_counter()
self.process_sources(self.root_project.project_name)
self.process_refs(self.root_project.project_name)
self.process_refs(self.root_project.project_name, self.root_project.dependencies)
self.process_docs(self.root_project)
self.process_metrics(self.root_project)
self.check_valid_group_config()
Expand All @@ -533,7 +533,10 @@ def load(self):
external_nodes_modified = self.inject_external_nodes()
if external_nodes_modified:
self.manifest.rebuild_ref_lookup()
self.process_refs(self.root_project.project_name)
self.process_refs(
self.root_project.project_name,
self.root_project.dependencies,
)
# parent and child maps will be rebuilt by write_manifest

if not skip_parsing:
Expand Down Expand Up @@ -1038,23 +1041,23 @@ def track_project_load(self):

# Takes references in 'refs' array of nodes and exposures, finds the target
# node, and updates 'depends_on.nodes' with the unique id
def process_refs(self, current_project: str):
def process_refs(self, current_project: str, dependencies: Optional[Dict[str, Project]]):
for node in self.manifest.nodes.values():
if node.created_at < self.started_at:
continue
_process_refs(self.manifest, current_project, node)
_process_refs(self.manifest, current_project, node, dependencies)
for exposure in self.manifest.exposures.values():
if exposure.created_at < self.started_at:
continue
_process_refs(self.manifest, current_project, exposure)
_process_refs(self.manifest, current_project, exposure, dependencies)
for metric in self.manifest.metrics.values():
if metric.created_at < self.started_at:
continue
_process_refs(self.manifest, current_project, metric)
_process_refs(self.manifest, current_project, metric, dependencies)
for semantic_model in self.manifest.semantic_models.values():
if semantic_model.created_at < self.started_at:
continue
_process_refs(self.manifest, current_project, semantic_model)
_process_refs(self.manifest, current_project, semantic_model, dependencies)
self.update_semantic_model(semantic_model)

# Takes references in 'metrics' array of nodes and exposures, finds the target
Expand Down Expand Up @@ -1372,9 +1375,13 @@ def _process_docs_for_metrics(context: Dict[str, Any], metric: Metric) -> None:
metric.description = get_rendered(metric.description, context)


def _process_refs(manifest: Manifest, current_project: str, node) -> None:
def _process_refs(
manifest: Manifest, current_project: str, node, dependencies: Optional[Mapping[str, Project]]
) -> None:
"""Given a manifest and node in that manifest, process its refs"""

dependencies = dependencies or {}

if isinstance(node, SeedNode):
return

Expand Down Expand Up @@ -1413,18 +1420,20 @@ def _process_refs(manifest: Manifest, current_project: str, node) -> None:
)

continue
elif (
isinstance(target_model, ModelNode)
and target_model.access == AccessType.Private
and node.resource_type != NodeType.SqlOperation
and node.resource_type != NodeType.RPCCall # TODO: rm
):
if not node.group or node.group != target_model.group:
raise dbt.exceptions.DbtReferenceError(
unique_id=node.unique_id,
ref_unique_id=target_model.unique_id,
group=dbt.utils.cast_to_str(target_model.group),
)
elif manifest.is_invalid_private_ref(node, target_model, dependencies):
raise dbt.exceptions.DbtReferenceError(
unique_id=node.unique_id,
ref_unique_id=target_model.unique_id,
access=AccessType.Private,
scope=dbt.utils.cast_to_str(target_model.group),
)
elif manifest.is_invalid_protected_ref(node, target_model, dependencies):
raise dbt.exceptions.DbtReferenceError(
unique_id=node.unique_id,
ref_unique_id=target_model.unique_id,
access=AccessType.Protected,
scope=target_model.package_name,
)

target_model_id = target_model.unique_id
node.depends_on.add_node(target_model_id)
Expand Down Expand Up @@ -1577,7 +1586,7 @@ def process_macro(config: RuntimeConfig, manifest: Manifest, macro: Macro) -> No
def process_node(config: RuntimeConfig, manifest: Manifest, node: ManifestNode):

_process_sources_for_node(manifest, config.project_name, node)
_process_refs(manifest, config.project_name, node)
_process_refs(manifest, config.project_name, node, config.dependencies)
ctx = generate_runtime_docs_context(config, node, manifest, config.project_name)
_process_docs_for_node(ctx, node)

Expand Down
Loading

0 comments on commit e01d4c0

Please sign in to comment.