Tests for metric filters with inner-query joins (#1208)
Add tests for metric filters where the metric subquery has one or two joins.
courtneyholcomb authored May 17, 2024
1 parent 8ba3897 commit 279414e
Showing 61 changed files with 10,339 additions and 1,801 deletions.
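
For orientation: the new tests exercise where-filters that reference a metric whose requested group-by entity is only reachable through one or more joins, so the rendered SQL gains an inner metric subquery containing those joins. Below is a minimal sketch of how such a filter is built, mirroring test_inner_query_single_hop added in this commit — the parse_and_validate_query and build_plan calls are taken from the diff, while the PydanticWhereFilter import path and the standalone function wrapper are assumptions:

from dbt_semantic_interfaces.implementations.filters.where_filter import PydanticWhereFilter


def build_filtered_plan(multihop_query_parser, multihop_dataflow_plan_builder):
    # The filter references the metric 'paraguayan_customers' grouped by the
    # entity path customer_id__customer_third_hop_id, which the inner query
    # can only satisfy by joining out from the measure's semantic model.
    query_spec = multihop_query_parser.parse_and_validate_query(
        metric_names=("third_hop_count",),
        where_constraint=PydanticWhereFilter(
            where_sql_template=(
                "{{ Metric('paraguayan_customers', ['customer_id__customer_third_hop_id']) }} > 0"
            ),
        ),
    ).query_spec
    # The plan's sink node is what the snapshot tests render to engine-specific SQL.
    return multihop_dataflow_plan_builder.build_plan(query_spec)

The multi-hop variant in the tests swaps in the txn_count metric with the two-entity path account_id__customer_id__customer_third_hop_id.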
@@ -7,6 +7,15 @@ semantic_model:
schema_name: $source_schema
alias: bridge_table

defaults:
agg_time_dimension: ds_partitioned

measures:
- name: account_customer_combos
expr: account_id || customer_id
agg: count_distinct
create_metric: true

dimensions:
- name: extra_dim
type: categorical
@@ -7,9 +7,21 @@ semantic_model:
schema_name: $source_schema
alias: customer_other_data

defaults:
agg_time_dimension: acquired_ds

measures:
- name: customers_with_other_data
expr: 1
agg: sum

dimensions:
- name: country
type: categorical
- name: acquired_ds
type: time
type_params:
time_granularity: day

entities:
- name: customer_id
@@ -7,6 +7,16 @@ semantic_model:
schema_name: $source_schema
alias: customer_table


defaults:
agg_time_dimension: ds_partitioned

measures:
- name: customers
expr: 1
agg: sum
create_metric: true

dimensions:
- name: customer_name
type: categorical
@@ -0,0 +1,8 @@
---
metric:
name: paraguayan_customers
type: SIMPLE
type_params:
measure: customers_with_other_data
filter: |
{{ Dimension('customer_id__country') }} = 'paraguay'
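
This new metric exists to be referenced in filters on third_hop_count in the tests below. For reference, both spellings that appear in this commit are collected here; the filter strings are copied from the new tests, and the Python variable names are illustrative only:

# The where_filter as written in the YAML test case
# test_metric_filter_with_inner_query_single_hop_join (itest_metrics.yaml).
integration_where_filter = (
    "{{ render_metric_template('paraguayan_customers', ['customer_third_hop_id']) }} > 0"
)

# The raw where-filter template passed to the query parser in
# test_inner_query_single_hop (test_metric_filter_rendering.py).
rendering_where_filter = (
    "{{ Metric('paraguayan_customers', ['customer_id__customer_third_hop_id']) }} > 0"
)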
@@ -7,9 +7,22 @@ semantic_model:
schema_name: $source_schema
alias: third_hop_table

defaults:
agg_time_dimension: third_hop_ds

measures:
- name: third_hop_count
expr: customer_third_hop_id
agg: count_distinct
create_metric: true

dimensions:
- name: value
type: categorical
- name: third_hop_ds
type: time
type_params:
time_granularity: day

entities:
- name: customer_third_hop_id
7 changes: 7 additions & 0 deletions tests_metricflow/fixtures/dataflow_fixtures.py
@@ -60,6 +60,13 @@ def multihop_dataflow_plan_builder( # noqa: D103
].dataflow_plan_builder


@pytest.fixture(scope="session")
def multihop_query_parser( # noqa: D103
mf_engine_test_fixture_mapping: Mapping[SemanticManifestSetup, MetricFlowEngineTestFixture]
) -> MetricFlowQueryParser:
return mf_engine_test_fixture_mapping[SemanticManifestSetup.PARTITIONED_MULTI_HOP_JOIN_MANIFEST].query_parser


@pytest.fixture(scope="session")
def scd_column_association_resolver( # noqa: D103
mf_engine_test_fixture_mapping: Mapping[SemanticManifestSetup, MetricFlowEngineTestFixture]
@@ -7,8 +7,10 @@ table_snapshot:
type: STRING
- name: customer_third_hop_id
type: STRING
- name: acquired_ds
type: STRING
rows:
- ["0", "turkmenistan", "another_id0"]
- ["1", "paraguay", "another_id1"]
- ["2", "myanmar", "another_id2"]
- ["3", "djibouti", "another_id3"]
- ["0", "turkmenistan", "another_id0", "2020-01-01"]
- ["1", "paraguay", "another_id1", "2020-01-02"]
- ["2", "myanmar", "another_id2", "2020-01-03"]
- ["3", "djibouti", "another_id3", "2020-01-04"]
@@ -5,8 +5,10 @@ table_snapshot:
type: STRING
- name: value
type: STRING
- name: third_hop_ds
type: STRING
rows:
- ["another_id0", "citadel"]
- ["another_id1", "virtu"]
- ["another_id2", "two sigma"]
- ["another_id3", "jump"]
- ["another_id0", "citadel", "2020-01-01"]
- ["another_id1", "virtu", "2020-01-02"]
- ["another_id2", "two sigma", "2020-01-03"]
- ["another_id3", "jump", "2020-01-04"]
45 changes: 45 additions & 0 deletions tests_metricflow/integration/test_cases/itest_metrics.yaml
@@ -2086,3 +2086,48 @@ integration_test:
) metric_subquery
ON l.listing_id = metric_subquery.listing_id
WHERE views > 10
---
integration_test:
name: test_metric_filter_with_inner_query_single_hop_join
description: Query with a metric filter where the inner query uses a single-hop join
model: PARTITIONED_MULTI_HOP_JOIN_MODEL
metrics: ["third_hop_count"]
where_filter: "{{ render_metric_template('paraguayan_customers', ['customer_third_hop_id']) }} > 0"
check_query: |
SELECT
COUNT(DISTINCT t.customer_third_hop_id) AS third_hop_count
FROM {{ source_schema }}.third_hop_table t
LEFT OUTER JOIN (
SELECT
customer_third_hop_id
, SUM(1) AS paraguayan_customers
FROM {{ source_schema }}.customer_other_data c
WHERE country = 'paraguay'
GROUP BY customer_third_hop_id
) metric_subquery
ON t.customer_third_hop_id = metric_subquery.customer_third_hop_id
WHERE paraguayan_customers > 0
---
integration_test:
name: test_metric_filter_with_inner_query_multi_hop_join
description: Query with a metric filter where the inner query uses a two-hop join
model: PARTITIONED_MULTI_HOP_JOIN_MODEL
metrics: ["third_hop_count"]
where_filter: "{{ render_metric_template('txn_count', ['account_id__customer_id__customer_third_hop_id']) }} > 0"
check_query: |
SELECT
COUNT(DISTINCT t.customer_third_hop_id) AS third_hop_count
FROM {{ source_schema }}.third_hop_table t
LEFT OUTER JOIN (
SELECT
c.customer_third_hop_id
, SUM(a.txn_count) AS txn_count
FROM {{ source_schema }}.account_month_txns a
LEFT OUTER JOIN {{ source_schema }}.bridge_table b
ON (a.account_id = b.account_id) AND ({{ render_date_trunc("a.ds_partitioned", TimeGranularity.DAY) }} = {{ render_date_trunc("b.ds_partitioned", TimeGranularity.DAY) }})
LEFT OUTER JOIN {{ source_schema }}.customer_other_data c
ON b.customer_id = c.customer_id
GROUP BY c.customer_third_hop_id
) metric_subquery
ON t.customer_third_hop_id = metric_subquery.customer_third_hop_id
WHERE txn_count > 0
2 changes: 1 addition & 1 deletion tests_metricflow/populate_persistent_source_schemas.py
@@ -36,7 +36,7 @@ def populate_schemas(test_configuration: MetricFlowTestConfiguration) -> None:
hatch_env = f"{engine_name}-env"
run_command(
f"hatch -v run {hatch_env}:pytest -vv --log-cli-level info --use-persistent-source-schema "
"tests/source_schema_tools.py::populate_source_schema"
"tests_metricflow/source_schema_tools.py::populate_source_schema"
)
else:
assert_values_exhausted(test_configuration.engine)
54 changes: 54 additions & 0 deletions tests_metricflow/query_rendering/test_metric_filter_rendering.py
@@ -306,3 +306,57 @@ def test_filter_with_conversion_metric( # noqa: D103
sql_client=sql_client,
node=dataflow_plan.sink_node,
)


@pytest.mark.sql_engine_snapshot
def test_inner_query_single_hop(
request: FixtureRequest,
mf_test_configuration: MetricFlowTestConfiguration,
multihop_dataflow_plan_builder: DataflowPlanBuilder,
sql_client: SqlClient,
multihop_dataflow_to_sql_converter: DataflowToSqlQueryPlanConverter,
multihop_query_parser: MetricFlowQueryParser,
) -> None:
"""Tests rendering for a metric filter using a one-hop join in the inner query."""
query_spec = multihop_query_parser.parse_and_validate_query(
metric_names=("third_hop_count",),
where_constraint=PydanticWhereFilter(
where_sql_template="{{ Metric('paraguayan_customers', ['customer_id__customer_third_hop_id']) }} > 0",
),
).query_spec
dataflow_plan = multihop_dataflow_plan_builder.build_plan(query_spec)

convert_and_check(
request=request,
mf_test_configuration=mf_test_configuration,
dataflow_to_sql_converter=multihop_dataflow_to_sql_converter,
sql_client=sql_client,
node=dataflow_plan.sink_node,
)


@pytest.mark.sql_engine_snapshot
def test_inner_query_multi_hop(
request: FixtureRequest,
mf_test_configuration: MetricFlowTestConfiguration,
multihop_dataflow_plan_builder: DataflowPlanBuilder,
sql_client: SqlClient,
multihop_dataflow_to_sql_converter: DataflowToSqlQueryPlanConverter,
multihop_query_parser: MetricFlowQueryParser,
) -> None:
"""Tests rendering for a metric filter using a two-hop join in the inner query."""
query_spec = multihop_query_parser.parse_and_validate_query(
metric_names=("third_hop_count",),
where_constraint=PydanticWhereFilter(
where_sql_template="{{ Metric('txn_count', ['account_id__customer_id__customer_third_hop_id']) }} > 2",
),
).query_spec
dataflow_plan = multihop_dataflow_plan_builder.build_plan(query_spec)

convert_and_check(
request=request,
mf_test_configuration=mf_test_configuration,
dataflow_to_sql_converter=multihop_dataflow_to_sql_converter,
sql_client=sql_client,
node=dataflow_plan.sink_node,
)
@@ -95,11 +95,16 @@
<!-- ), -->
<!-- ), -->
<!-- ) -->
<ReadSqlSourceNode>
<!-- description = "Read From SemanticModelDataSet('bridge_table')" -->
<!-- node_id = NodeId(id_str='rss_22007') -->
<!-- data_set = SemanticModelDataSet('bridge_table') -->
</ReadSqlSourceNode>
<MetricTimeDimensionTransformNode>
<!-- description = "Metric Time Dimension 'ds_partitioned'" -->
<!-- node_id = NodeId(id_str='sma_22002') -->
<!-- aggregation_time_dimension = 'ds_partitioned' -->
<ReadSqlSourceNode>
<!-- description = "Read From SemanticModelDataSet('bridge_table')" -->
<!-- node_id = NodeId(id_str='rss_22007') -->
<!-- data_set = SemanticModelDataSet('bridge_table') -->
</ReadSqlSourceNode>
</MetricTimeDimensionTransformNode>
<FilterElementsNode>
<!-- description = -->
<!-- ('Pass Only Elements: [\n' -->
@@ -129,6 +134,17 @@
<!-- " 'customer_id__ds_partitioned__extract_day',\n" -->
<!-- " 'customer_id__ds_partitioned__extract_dow',\n" -->
<!-- " 'customer_id__ds_partitioned__extract_doy',\n" -->
<!-- " 'metric_time__day',\n" -->
<!-- " 'metric_time__week',\n" -->
<!-- " 'metric_time__month',\n" -->
<!-- " 'metric_time__quarter',\n" -->
<!-- " 'metric_time__year',\n" -->
<!-- " 'metric_time__extract_year',\n" -->
<!-- " 'metric_time__extract_quarter',\n" -->
<!-- " 'metric_time__extract_month',\n" -->
<!-- " 'metric_time__extract_day',\n" -->
<!-- " 'metric_time__extract_dow',\n" -->
<!-- " 'metric_time__extract_doy',\n" -->
<!-- " 'customer_id',\n" -->
<!-- ']') -->
<!-- node_id = NodeId(id_str='pfe_0') -->
@@ -262,13 +278,64 @@
<!-- time_granularity=DAY, -->
<!-- date_part=DOY, -->
<!-- ) -->
<!-- include_spec = -->
<!-- TimeDimensionSpec(element_name='metric_time', time_granularity=DAY) -->
<!-- include_spec = -->
<!-- TimeDimensionSpec(element_name='metric_time', time_granularity=WEEK) -->
<!-- include_spec = -->
<!-- TimeDimensionSpec(element_name='metric_time', time_granularity=MONTH) -->
<!-- include_spec = -->
<!-- TimeDimensionSpec(element_name='metric_time', time_granularity=QUARTER) -->
<!-- include_spec = -->
<!-- TimeDimensionSpec(element_name='metric_time', time_granularity=YEAR) -->
<!-- include_spec = -->
<!-- TimeDimensionSpec( -->
<!-- element_name='metric_time', -->
<!-- time_granularity=DAY, -->
<!-- date_part=YEAR, -->
<!-- ) -->
<!-- include_spec = -->
<!-- TimeDimensionSpec( -->
<!-- element_name='metric_time', -->
<!-- time_granularity=DAY, -->
<!-- date_part=QUARTER, -->
<!-- ) -->
<!-- include_spec = -->
<!-- TimeDimensionSpec( -->
<!-- element_name='metric_time', -->
<!-- time_granularity=DAY, -->
<!-- date_part=MONTH, -->
<!-- ) -->
<!-- include_spec = -->
<!-- TimeDimensionSpec( -->
<!-- element_name='metric_time', -->
<!-- time_granularity=DAY, -->
<!-- date_part=DAY, -->
<!-- ) -->
<!-- include_spec = -->
<!-- TimeDimensionSpec( -->
<!-- element_name='metric_time', -->
<!-- time_granularity=DAY, -->
<!-- date_part=DOW, -->
<!-- ) -->
<!-- include_spec = -->
<!-- TimeDimensionSpec( -->
<!-- element_name='metric_time', -->
<!-- time_granularity=DAY, -->
<!-- date_part=DOY, -->
<!-- ) -->
<!-- include_spec = EntitySpec(element_name='customer_id') -->
<!-- distinct = False -->
<ReadSqlSourceNode>
<!-- description = "Read From SemanticModelDataSet('customer_table')" -->
<!-- node_id = NodeId(id_str='rss_22009') -->
<!-- data_set = SemanticModelDataSet('customer_table') -->
</ReadSqlSourceNode>
<MetricTimeDimensionTransformNode>
<!-- description = "Metric Time Dimension 'ds_partitioned'" -->
<!-- node_id = NodeId(id_str='sma_22004') -->
<!-- aggregation_time_dimension = 'ds_partitioned' -->
<ReadSqlSourceNode>
<!-- description = "Read From SemanticModelDataSet('customer_table')" -->
<!-- node_id = NodeId(id_str='rss_22009') -->
<!-- data_set = SemanticModelDataSet('customer_table') -->
</ReadSqlSourceNode>
</MetricTimeDimensionTransformNode>
</FilterElementsNode>
</JoinOnEntitiesNode>
</FilterElementsNode>