Skip to content

Commit

Permalink
Support --empty flag for schema-only dry runs (#8971)
Browse files Browse the repository at this point in the history
  • Loading branch information
MichelleArk authored Nov 29, 2023
1 parent 0935570 commit 5488dfb
Show file tree
Hide file tree
Showing 8 changed files with 241 additions and 8 deletions.
6 changes: 6 additions & 0 deletions .changes/unreleased/Features-20231116-234049.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
kind: Features
body: Support --empty flag for schema-only dry runs
time: 2023-11-16T23:40:49.96651-05:00
custom:
  Author: michelleark
  Issue: "8971"
18 changes: 13 additions & 5 deletions core/dbt/adapters/base/relation.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ class BaseRelation(FakeAPIObject, Hashable):
include_policy: Policy = field(default_factory=lambda: Policy())
quote_policy: Policy = field(default_factory=lambda: Policy())
dbt_created: bool = False
limit: Optional[int] = None

# register relation types that can be renamed for the purpose of replacing relations using stages and backups
# adding a relation type here also requires defining the associated rename macro
Expand Down Expand Up @@ -194,6 +195,15 @@ def render(self) -> str:
# if there is nothing set, this will return the empty string.
return ".".join(part for _, part in self._render_iterator() if part is not None)

def render_limited(self) -> str:
    """Render the relation, wrapped in a row-limiting subquery when ``limit`` is set.

    With ``limit`` of ``None`` the plain ``render()`` text is returned.
    A limit of 0 produces a subquery that both filters with ``where false``
    and applies ``limit 0``; any other limit produces ``limit <n>``. The
    subquery is always aliased ``_dbt_limit_subq``.
    """
    base = self.render()
    cap = self.limit
    if cap is None:
        return base
    if cap == 0:
        return f"(select * from {base} where false limit 0) _dbt_limit_subq"
    return f"(select * from {base} limit {cap}) _dbt_limit_subq"

def quoted(self, identifier):
return "{quote_char}{identifier}{quote_char}".format(
quote_char=self.quote_character,
Expand Down Expand Up @@ -227,13 +237,11 @@ def create_ephemeral_from_node(
cls: Type[Self],
config: HasQuoting,
node: ManifestNode,
limit: Optional[int],
) -> Self:
# Note that ephemeral models are based on the name.
identifier = cls.add_ephemeral_prefix(node.name)
return cls.create(
type=cls.CTE,
identifier=identifier,
).quote(identifier=False)
return cls.create(type=cls.CTE, identifier=identifier, limit=limit).quote(identifier=False)

@classmethod
def create_from_node(
Expand Down Expand Up @@ -313,7 +321,7 @@ def __hash__(self) -> int:
return hash(self.render())

def __str__(self) -> str:
    """Return the SQL text of this relation, honouring ``limit`` when one is set.

    An unlimited relation (``limit is None``) renders as ``render()``;
    otherwise the text comes from ``render_limited()``. The stale
    unconditional ``return self.render()`` that preceded this expression
    made the limit-aware branch unreachable and has been dropped.
    """
    return self.render() if self.limit is None else self.render_limited()

@property
def database(self) -> Optional[str]:
Expand Down
2 changes: 2 additions & 0 deletions core/dbt/cli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,6 +342,7 @@ def docs_serve(ctx, **kwargs):
@p.profile
@p.profiles_dir
@p.project_dir
@p.empty
@p.select
@p.selector
@p.inline
Expand Down Expand Up @@ -599,6 +600,7 @@ def parse(ctx, **kwargs):
@p.profile
@p.profiles_dir
@p.project_dir
@p.empty
@p.select
@p.selector
@p.state
Expand Down
6 changes: 6 additions & 0 deletions core/dbt/cli/params.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,12 @@
is_flag=True,
)

# `--empty/--no-empty` on/off switch; its value can also come from DBT_EMPTY.
empty = click.option(
    "--empty/--no-empty",
    is_flag=True,
    envvar="DBT_EMPTY",
    help="If specified, limit input refs and sources to zero rows.",
)

enable_legacy_logger = click.option(
"--enable-legacy-logger/--no-enable-legacy-logger",
Expand Down
12 changes: 9 additions & 3 deletions core/dbt/context/providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,10 @@ def current_project(self):
def Relation(self):
return self.db_wrapper.Relation

@property
def resolve_limit(self) -> Optional[int]:
    """Row limit to hand to relations created by this resolver.

    Returns 0 when ``self.config.args`` has a truthy ``EMPTY`` attribute and
    ``None`` when that attribute is falsy or missing.
    """
    empty_requested = getattr(self.config.args, "EMPTY", False)
    if empty_requested:
        return 0
    return None

@abc.abstractmethod
def __call__(self, *args: str) -> Union[str, RelationProxy, MetricReference]:
    """Abstract call hook taking positional string arguments.

    Subclasses must override this; the base body does nothing.
    """
Expand Down Expand Up @@ -531,9 +535,11 @@ def resolve(
def create_relation(self, target_model: ManifestNode) -> RelationProxy:
    """Build the relation for ``target_model``, forwarding the resolver's row limit.

    For an ephemeral target, ``set_cte(unique_id, None)`` is first called on
    the current model and the relation comes from
    ``Relation.create_ephemeral_from_node``. Any other target goes through
    ``Relation.create_from``. Both constructors receive
    ``limit=self.resolve_limit``.

    The stale limit-less ``return`` statements that preceded the limit-aware
    ones made the latter unreachable; they have been removed so the limit is
    actually passed.
    """
    if target_model.is_ephemeral_model:
        self.model.set_cte(target_model.unique_id, None)
        return self.Relation.create_ephemeral_from_node(
            self.config, target_model, limit=self.resolve_limit
        )
    return self.Relation.create_from(self.config, target_model, limit=self.resolve_limit)

def validate(
self,
Expand Down Expand Up @@ -590,7 +596,7 @@ def resolve(self, source_name: str, table_name: str):
target_kind="source",
disabled=(isinstance(target_source, Disabled)),
)
return self.Relation.create_from_source(target_source)
return self.Relation.create_from_source(target_source, limit=self.resolve_limit)


# `metric` implementations
Expand Down
75 changes: 75 additions & 0 deletions tests/adapter/dbt/tests/adapter/empty/test_empty.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
import pytest
from dbt.tests.util import run_dbt, relation_from_name


# Regular (non-ephemeral) input model: a single row with id = 1.
model_input_sql = """
select 1 as id
"""

# Ephemeral input model: a single row with id = 2.
ephemeral_model_input_sql = """
{{ config(materialized='ephemeral') }}
select 2 as id
"""

# Seed file that backs the `raw_source` source table: a single row with id = 3.
raw_source_csv = """id
3
"""


# Model under test: unions one ref, one ephemeral ref and one source.
model_sql = """
select *
from {{ ref('model_input') }}
union all
select *
from {{ ref('ephemeral_model_input') }}
union all
select *
from {{ source('seed_sources', 'raw_source') }}
"""


# Source declaration pointing at the seeded table. The nesting matters:
# `schema` and `tables` belong to the `seed_sources` entry, not the top level.
schema_sources_yml = """
sources:
  - name: seed_sources
    schema: "{{ target.schema }}"
    tables:
      - name: raw_source
"""


class BaseTestEmpty:
    """Adapter-level check that ``dbt run --empty`` leaves the built model with zero rows."""

    @pytest.fixture(scope="class")
    def seeds(self):
        """Seed files for the project, keyed by file name."""
        return {"raw_source.csv": raw_source_csv}

    @pytest.fixture(scope="class")
    def models(self):
        """Model and source files for the project, keyed by file name."""
        project_files = {
            "model_input.sql": model_input_sql,
            "ephemeral_model_input.sql": ephemeral_model_input_sql,
            "model.sql": model_sql,
            "sources.yml": schema_sources_yml,
        }
        return project_files

    def assert_row_count(self, project, relation_name: str, expected_row_count: int):
        """Assert that counting the rows of ``relation_name`` gives ``expected_row_count``."""
        target = relation_from_name(project.adapter, relation_name)
        count_query = f"select count(*) as num_rows from {target}"
        row = project.run_sql(count_query, fetch="one")
        assert row[0] == expected_row_count

    def test_run_with_empty(self, project):
        """Run once normally and once with ``--empty``, checking the row count each time."""
        # The source table is created by seeding it.
        run_dbt(["seed"])

        # A plain run yields one row per input (3); an --empty run yields none.
        scenarios = (
            (["run"], 3),
            (["run", "--empty"], 0),
        )
        for cli_args, expected_rows in scenarios:
            run_dbt(cli_args)
            self.assert_row_count(project, "model", expected_rows)


class TestEmpty(BaseTestEmpty):
    """Concrete test class that runs the BaseTestEmpty checks unchanged."""
104 changes: 104 additions & 0 deletions tests/unit/test_providers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
import pytest
from unittest import mock

from dbt.adapters.base import BaseRelation
from dbt.context.providers import BaseResolver, RuntimeRefResolver, RuntimeSourceResolver
from dbt.contracts.graph.unparsed import Quoting


class TestBaseResolver:
    """Unit tests for ``BaseResolver.resolve_limit``."""

    class ResolverSubclass(BaseResolver):
        """Minimal concrete resolver so the abstract base can be instantiated."""

        def __call__(self, *args: str):
            return None

    @pytest.fixture
    def resolver(self):
        """A resolver whose four collaborators are all plain mocks."""
        collaborators = {
            name: mock.Mock() for name in ("db_wrapper", "model", "config", "manifest")
        }
        return self.ResolverSubclass(**collaborators)

    @pytest.mark.parametrize(
        "empty,expected_resolve_limit",
        [
            (False, None),
            (True, 0),
        ],
    )
    def test_resolve_limit(self, resolver, empty, expected_resolve_limit):
        """``resolve_limit`` follows the ``EMPTY`` attribute on ``config.args``."""
        resolver.config.args.EMPTY = empty

        observed = resolver.resolve_limit
        assert observed == expected_resolve_limit


class TestRuntimeRefResolver:
    """Unit tests for the limit set by ``RuntimeRefResolver.create_relation``."""

    @pytest.fixture
    def resolver(self):
        """A ref resolver backed by mocks, except that relations are real ``BaseRelation``s."""
        return RuntimeRefResolver(
            db_wrapper=mock.Mock(Relation=BaseRelation),
            model=mock.Mock(),
            config=mock.Mock(),
            manifest=mock.Mock(),
        )

    @pytest.mark.parametrize(
        "empty,is_ephemeral_model,expected_limit",
        [
            (False, False, None),
            (True, False, 0),
            (False, True, None),
            (True, True, 0),
        ],
    )
    def test_create_relation_with_empty(self, resolver, empty, is_ephemeral_model, expected_limit):
        """The created relation's ``limit`` tracks ``EMPTY`` for ephemeral and non-ephemeral nodes."""
        # Configure the flag and a stand-in node to resolve.
        resolver.config.args.EMPTY = empty
        target_node = mock.Mock(
            database="test",
            schema="test",
            identifier="test",
            alias="test",
            is_ephemeral_model=is_ephemeral_model,
        )

        # ParsedNode is swapped for Mock while the relation is built from the mock node.
        with mock.patch("dbt.adapters.base.relation.ParsedNode", new=mock.Mock):
            created = resolver.create_relation(target_node)
            assert created.limit == expected_limit


class TestRuntimeSourceResolver:
    """Unit tests for the limit set by ``RuntimeSourceResolver.resolve``."""

    @pytest.fixture
    def resolver(self):
        """A source resolver backed by mocks, except that relations are real ``BaseRelation``s."""
        return RuntimeSourceResolver(
            db_wrapper=mock.Mock(Relation=BaseRelation),
            model=mock.Mock(),
            config=mock.Mock(),
            manifest=mock.Mock(),
        )

    @pytest.mark.parametrize(
        "empty,expected_limit",
        [
            (False, None),
            (True, 0),
        ],
    )
    def test_create_relation_with_empty(self, resolver, empty, expected_limit):
        """The resolved source relation's ``limit`` tracks the ``EMPTY`` flag."""
        # Configure the flag and make the manifest hand back a stand-in source.
        resolver.config.args.EMPTY = empty
        stub_source = mock.Mock(
            database="test",
            schema="test",
            identifier="test",
            quoting=Quoting(),
        )
        resolver.manifest.resolve_source.return_value = stub_source

        # Resolve the source and inspect the limit on the resulting relation.
        resolved = resolver.resolve("test", "test")
        assert resolved.limit == expected_limit
26 changes: 26 additions & 0 deletions tests/unit/test_relation.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,3 +40,29 @@ def test_can_be_replaced(relation_type, result):
def test_can_be_replaced_default():
    """A relation created with only a type reports ``can_be_replaced`` as False."""
    assert BaseRelation.create(type=RelationType.View).can_be_replaced is False


@pytest.mark.parametrize(
    "limit,expected_result",
    [
        (None, '"test_database"."test_schema"."test_identifier"'),
        (
            0,
            '(select * from "test_database"."test_schema"."test_identifier" where false limit 0) _dbt_limit_subq',
        ),
        (
            1,
            '(select * from "test_database"."test_schema"."test_identifier" limit 1) _dbt_limit_subq',
        ),
    ],
)
def test_render_limited(limit, expected_result):
    """``render_limited()`` and ``str()`` both give the expected SQL for each limit."""
    relation_kwargs = {
        "database": "test_database",
        "schema": "test_schema",
        "identifier": "test_identifier",
        "limit": limit,
    }
    limited = BaseRelation.create(**relation_kwargs)

    rendered = limited.render_limited()
    assert rendered == expected_result
    assert str(limited) == expected_result

0 comments on commit 5488dfb

Please sign in to comment.