diff --git a/.changes/unreleased/Fixes-20230913-153924.yaml b/.changes/unreleased/Fixes-20230913-153924.yaml new file mode 100644 index 00000000000..c61452174de --- /dev/null +++ b/.changes/unreleased/Fixes-20230913-153924.yaml @@ -0,0 +1,6 @@ +kind: Fixes +body: 'update dbt show to include limit in DWH query ' +time: 2023-09-13T15:39:24.591805+01:00 +custom: + Author: michelleark + Issue: 8496, 8417 diff --git a/core/dbt/include/global_project/macros/adapters/show.sql b/core/dbt/include/global_project/macros/adapters/show.sql new file mode 100644 index 00000000000..2e23a919d29 --- /dev/null +++ b/core/dbt/include/global_project/macros/adapters/show.sql @@ -0,0 +1,24 @@ +{% macro get_show_sql(compiled_code, sql_header, limit) -%} + {%- set sql_header = sql_header -%} + {{ sql_header if sql_header is not none }} + {%- if sql_header -%} + {{ sql_header }} + {%- endif -%} + {%- if limit is not none -%} + {{ get_limit_subquery_sql(compiled_code, limit) }} + {%- else -%} + {{ compiled_code }} + {%- endif -%} +{% endmacro %} + +{% macro get_limit_subquery_sql(sql, limit) %} + {{ adapter.dispatch('get_limit_subquery_sql', 'dbt')(sql, limit) }} +{% endmacro %} + +{% macro default__get_limit_subquery_sql(sql, limit) %} + select * + from ( + {{ sql }} + ) as model_limit_subq + limit {{ limit }} +{% endmacro %} diff --git a/core/dbt/task/show.py b/core/dbt/task/show.py index 19681b3a0c3..83fd508600f 100644 --- a/core/dbt/task/show.py +++ b/core/dbt/task/show.py @@ -2,6 +2,7 @@ import threading import time +from dbt.context.providers import generate_runtime_model_context from dbt.contracts.graph.nodes import SeedNode from dbt.contracts.results import RunResult, RunStatus from dbt.events.base_types import EventLevel @@ -20,14 +21,24 @@ def __init__(self, config, adapter, node, node_index, num_nodes): def execute(self, compiled_node, manifest): start_time = time.time() - if "sql_header" in compiled_node.unrendered_config: - compiled_node.compiled_code = ( - compiled_node.unrendered_config["sql_header"] + compiled_node.compiled_code - ) - + # Allow passing in -1 (or any negative number) to get all rows + limit = None if self.config.args.limit < 0 else self.config.args.limit + + model_context = generate_runtime_model_context(compiled_node, self.config, manifest) + compiled_node.compiled_code = self.adapter.execute_macro( + macro_name="get_show_sql", + manifest=manifest, + context_override=model_context, + kwargs={ + "compiled_code": model_context["compiled_code"], + "sql_header": model_context["config"].get("sql_header"), + "limit": limit, + }, + ) adapter_response, execute_result = self.adapter.execute( compiled_node.compiled_code, fetch=True ) + end_time = time.time() return RunResult( diff --git a/tests/adapter/dbt/tests/adapter/dbt_show/fixtures.py b/tests/adapter/dbt/tests/adapter/dbt_show/fixtures.py new file mode 100644 index 00000000000..6eda5a695f3 --- /dev/null +++ b/tests/adapter/dbt/tests/adapter/dbt_show/fixtures.py @@ -0,0 +1,34 @@ +models__sql_header = """ +{% call set_sql_header(config) %} +set session time zone '{{ var("timezone", "Europe/Paris") }}'; +{%- endcall %} +select current_setting('timezone') as timezone +""" + +models__ephemeral_model = """ +{{ config(materialized = 'ephemeral') }} +select + coalesce(sample_num, 0) + 10 as col_deci +from {{ ref('sample_model') }} +""" + +models__second_ephemeral_model = """ +{{ config(materialized = 'ephemeral') }} +select + col_deci + 100 as col_hundo +from {{ ref('ephemeral_model') }} +""" + +models__sample_model = """ +select * from {{ ref('sample_seed') }} +""" + +seeds__sample_seed = """sample_num,sample_bool +1,true +2,false +3,true +4,false +5,true +6,false +7,true +""" diff --git a/tests/adapter/dbt/tests/adapter/dbt_show/test_dbt_show.py b/tests/adapter/dbt/tests/adapter/dbt_show/test_dbt_show.py new file mode 100644 index 00000000000..a93bb9dd2ab --- /dev/null +++ b/tests/adapter/dbt/tests/adapter/dbt_show/test_dbt_show.py @@ -0,0 +1,62 @@ +import pytest +from dbt.tests.util import run_dbt + +from dbt.tests.adapter.dbt_show.fixtures import ( + models__sql_header, + models__ephemeral_model, + models__second_ephemeral_model, + models__sample_model, + seeds__sample_seed, +) + + +# -- Below we define base classes for tests you import based on if your adapter supports dbt show or not -- +class BaseShowLimit: + @pytest.fixture(scope="class") + def models(self): + return { + "sample_model.sql": models__sample_model, + "ephemeral_model.sql": models__ephemeral_model, + } + + @pytest.fixture(scope="class") + def seeds(self): + return {"sample_seed.csv": seeds__sample_seed} + + @pytest.mark.parametrize( + "args,expected", + [ + ([], 5), # default limit + (["--limit", 3], 3), # fetch 3 rows + (["--limit", -1], 7), # fetch all rows + ], + ) + def test_limit(self, project, args, expected): + run_dbt(["build"]) + dbt_args = ["show", "--inline", models__second_ephemeral_model, *args] + results = run_dbt(dbt_args) + assert len(results.results[0].agate_table) == expected + # ensure limit was injected in compiled_code when limit specified in command args + limit = results.args.get("limit") + if limit > 0: + assert f"limit {limit}" in results.results[0].node.compiled_code + + +class BaseShowSqlHeader: + @pytest.fixture(scope="class") + def models(self): + return { + "sql_header.sql": models__sql_header, + } + + def test_sql_header(self, project): + run_dbt(["build", "--vars", "timezone: Asia/Kolkata"]) + run_dbt(["show", "--select", "sql_header", "--vars", "timezone: Asia/Kolkata"]) + + +class TestPostgresShowSqlHeader(BaseShowSqlHeader): + pass + + +class TestPostgresShowLimit(BaseShowLimit): + pass diff --git a/tests/functional/show/fixtures.py b/tests/functional/show/fixtures.py index 85bfcd26c29..44e8df63393 100644 --- a/tests/functional/show/fixtures.py +++ b/tests/functional/show/fixtures.py @@ -12,7 +12,7 @@ models__sql_header = """ {% call set_sql_header(config) %} -set session time zone 'Asia/Kolkata'; +set session time zone '{{ var("timezone", "Europe/Paris") }}'; {%- endcall %} select current_setting('timezone') as timezone """ diff --git a/tests/functional/show/test_show.py b/tests/functional/show/test_show.py index a076e6dcf0a..aa271e9cd2f 100644 --- a/tests/functional/show/test_show.py +++ b/tests/functional/show/test_show.py @@ -9,7 +9,6 @@ models__second_model, models__ephemeral_model, schema_yml, - models__sql_header, private_model_yml, ) @@ -21,7 +20,6 @@ def models(self): "sample_model.sql": models__sample_model, "second_model.sql": models__second_model, "ephemeral_model.sql": models__ephemeral_model, - "sql_header.sql": models__sql_header, } @pytest.fixture(scope="class") @@ -88,15 +86,6 @@ def test_second_ephemeral_model(self, project): ) assert "col_hundo" in log_output - def test_seed(self, project): - (results, log_output) = run_dbt_and_capture(["show", "--select", "sample_seed"]) - assert "Previewing node 'sample_seed'" in log_output - - def test_sql_header(self, project): - run_dbt(["build"]) - (results, log_output) = run_dbt_and_capture(["show", "--select", "sql_header"]) - assert "Asia/Kolkata" in log_output - class TestShowModelVersions: @pytest.fixture(scope="class")