Skip to content

Commit

Permalink
Update protocol specs to use WhereFilterIntersection for metrics
Browse files Browse the repository at this point in the history
WhereFilterIntersection is now ready for integration into the
metric parser and validation logic.
  • Loading branch information
tlento committed Oct 10, 2023
1 parent 86d1568 commit 242235d
Show file tree
Hide file tree
Showing 7 changed files with 60 additions and 40 deletions.
8 changes: 4 additions & 4 deletions dbt_semantic_interfaces/implementations/metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
PydanticParseableValueType,
)
from dbt_semantic_interfaces.implementations.filters.where_filter import (
PydanticWhereFilter,
PydanticWhereFilterIntersection,
)
from dbt_semantic_interfaces.implementations.metadata import PydanticMetadata
from dbt_semantic_interfaces.references import MeasureReference, MetricReference
Expand All @@ -28,7 +28,7 @@ class PydanticMetricInputMeasure(PydanticCustomInputParser, HashableBaseModel):
"""

name: str
filter: Optional[PydanticWhereFilter]
filter: Optional[PydanticWhereFilterIntersection]
alias: Optional[str]
join_to_timespine: bool = False
fill_nulls_with: Optional[int] = None
Expand Down Expand Up @@ -118,7 +118,7 @@ class PydanticMetricInput(HashableBaseModel):
"""Provides a pointer to a metric along with the additional properties used on that metric."""

name: str
filter: Optional[PydanticWhereFilter]
filter: Optional[PydanticWhereFilterIntersection]
alias: Optional[str]
offset_window: Optional[PydanticMetricTimeWindow]
offset_to_grain: Optional[TimeGranularity]
Expand Down Expand Up @@ -155,7 +155,7 @@ class PydanticMetric(HashableBaseModel, ModelWithMetadataParsing):
description: Optional[str]
type: MetricType
type_params: PydanticMetricTypeParams
filter: Optional[PydanticWhereFilter]
filter: Optional[PydanticWhereFilterIntersection]
metadata: Optional[PydanticMetadata]
label: Optional[str] = None

Expand Down
5 changes: 4 additions & 1 deletion dbt_semantic_interfaces/protocols/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,4 +28,7 @@
SemanticModelDefaults,
SemanticModelT,
)
from dbt_semantic_interfaces.protocols.where_filter import WhereFilter # noqa:F401
from dbt_semantic_interfaces.protocols.where_filter import ( # noqa:F401
WhereFilter,
WhereFilterIntersection,
)
11 changes: 7 additions & 4 deletions dbt_semantic_interfaces/protocols/metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from typing import Optional, Protocol, Sequence

from dbt_semantic_interfaces.protocols.metadata import Metadata
from dbt_semantic_interfaces.protocols.where_filter import WhereFilter
from dbt_semantic_interfaces.protocols.where_filter import WhereFilterIntersection
from dbt_semantic_interfaces.references import MeasureReference, MetricReference
from dbt_semantic_interfaces.type_enums import MetricType, TimeGranularity

Expand All @@ -23,7 +23,8 @@ def name(self) -> str: # noqa: D

@property
@abstractmethod
def filter(self) -> Optional[WhereFilter]: # noqa: D
def filter(self) -> Optional[WhereFilterIntersection]:
"""Return the set of filters to apply prior to aggregating this input measure."""
pass

@property
Expand Down Expand Up @@ -80,7 +81,8 @@ def name(self) -> str: # noqa: D

@property
@abstractmethod
def filter(self) -> Optional[WhereFilter]: # noqa: D
def filter(self) -> Optional[WhereFilterIntersection]:
"""Return the set of filters to apply prior to calculating this input metric."""
pass

@property
Expand Down Expand Up @@ -181,7 +183,8 @@ def type_params(self) -> MetricTypeParams: # noqa: D

@property
@abstractmethod
def filter(self) -> Optional[WhereFilter]: # noqa: D
def filter(self) -> Optional[WhereFilterIntersection]:
"""Return the set of filters to apply prior to calculating this metric."""
pass

@property
Expand Down
6 changes: 1 addition & 5 deletions dbt_semantic_interfaces/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,6 @@
from dbt_semantic_interfaces.implementations.elements.dimension import PydanticDimension
from dbt_semantic_interfaces.implementations.elements.entity import PydanticEntity
from dbt_semantic_interfaces.implementations.elements.measure import PydanticMeasure
from dbt_semantic_interfaces.implementations.filters.where_filter import (
PydanticWhereFilter,
)
from dbt_semantic_interfaces.implementations.metadata import (
PydanticFileSlice,
PydanticMetadata,
Expand Down Expand Up @@ -124,7 +121,6 @@ def metric_with_guaranteed_meta(
name: str,
type: MetricType,
type_params: PydanticMetricTypeParams,
where_filter: Optional[PydanticWhereFilter] = None,
metadata: PydanticMetadata = default_meta(),
description: str = "adhoc metric",
) -> PydanticMetric:
Expand All @@ -137,7 +133,7 @@ def metric_with_guaranteed_meta(
description=description,
type=type,
type_params=type_params,
filter=where_filter,
filter=None,
metadata=metadata,
)

Expand Down
15 changes: 5 additions & 10 deletions dbt_semantic_interfaces/validations/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ def _validate_metric(metric: Metric) -> Sequence[ValidationIssue]: # noqa: D

if metric.filter is not None:
try:
metric.filter.call_parameter_sets
metric.filter.filter_expression_parameter_sets
except Exception as e:
issues.append(
generate_exception_issue(
Expand All @@ -181,7 +181,6 @@ def _validate_metric(metric: Metric) -> Sequence[ValidationIssue]: # noqa: D
context=context,
extras={
"traceback": "".join(traceback.format_tb(e.__traceback__)),
"filter": metric.filter.where_sql_template,
},
)
)
Expand All @@ -190,7 +189,7 @@ def _validate_metric(metric: Metric) -> Sequence[ValidationIssue]: # noqa: D
measure = metric.type_params.measure
if measure is not None and measure.filter is not None:
try:
measure.filter.call_parameter_sets
measure.filter.filter_expression_parameter_sets
except Exception as e:
issues.append(
generate_exception_issue(
Expand All @@ -200,15 +199,14 @@ def _validate_metric(metric: Metric) -> Sequence[ValidationIssue]: # noqa: D
context=context,
extras={
"traceback": "".join(traceback.format_tb(e.__traceback__)),
"filter": measure.filter.where_sql_template,
},
)
)

numerator = metric.type_params.numerator
if numerator is not None and numerator.filter is not None:
try:
numerator.filter.call_parameter_sets
numerator.filter.filter_expression_parameter_sets
except Exception as e:
issues.append(
generate_exception_issue(
Expand All @@ -217,15 +215,14 @@ def _validate_metric(metric: Metric) -> Sequence[ValidationIssue]: # noqa: D
context=context,
extras={
"traceback": "".join(traceback.format_tb(e.__traceback__)),
"filter": numerator.filter.where_sql_template,
},
)
)

denominator = metric.type_params.denominator
if denominator is not None and denominator.filter is not None:
try:
denominator.filter.call_parameter_sets
denominator.filter.filter_expression_parameter_sets
except Exception as e:
issues.append(
generate_exception_issue(
Expand All @@ -234,15 +231,14 @@ def _validate_metric(metric: Metric) -> Sequence[ValidationIssue]: # noqa: D
context=context,
extras={
"traceback": "".join(traceback.format_tb(e.__traceback__)),
"filter": denominator.filter.where_sql_template,
},
)
)

for input_metric in metric.type_params.metrics or []:
if input_metric.filter is not None:
try:
input_metric.filter.call_parameter_sets
input_metric.filter.filter_expression_parameter_sets
except Exception as e:
issues.append(
generate_exception_issue(
Expand All @@ -252,7 +248,6 @@ def _validate_metric(metric: Metric) -> Sequence[ValidationIssue]: # noqa: D
context=context,
extras={
"traceback": "".join(traceback.format_tb(e.__traceback__)),
"filter": input_metric.filter.where_sql_template,
},
)
)
Expand Down
23 changes: 17 additions & 6 deletions tests/parsing/test_metric_parsing.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from dbt_semantic_interfaces.implementations.filters.where_filter import (
PydanticWhereFilter,
PydanticWhereFilterIntersection,
)
from dbt_semantic_interfaces.implementations.metric import (
PydanticMetricInput,
Expand Down Expand Up @@ -66,7 +67,9 @@ def test_legacy_metric_input_measure_object_parsing() -> None:
metric = build_result.semantic_manifest.metrics[0]
assert metric.type_params.measure == PydanticMetricInputMeasure(
name="legacy_measure_from_object",
filter=PydanticWhereFilter(where_sql_template="""{{ dimension('some_bool') }}"""),
filter=PydanticWhereFilterIntersection(
where_filters=[PydanticWhereFilter(where_sql_template="""{{ dimension('some_bool') }}""")]
),
join_to_timespine=True,
fill_nulls_with=1,
)
Expand Down Expand Up @@ -181,8 +184,12 @@ def test_ratio_metric_input_measure_object_parsing() -> None:
metric = build_result.semantic_manifest.metrics[0]
assert metric.type_params.numerator == PydanticMetricInput(
name="numerator_metric_from_object",
filter=PydanticWhereFilter(
where_sql_template="some_number > 5",
filter=PydanticWhereFilterIntersection(
where_filters=[
PydanticWhereFilter(
where_sql_template="some_number > 5",
)
],
),
)
assert metric.type_params.denominator == PydanticMetricInput(name="denominator_metric_from_object")
Expand Down Expand Up @@ -328,8 +335,10 @@ def test_constraint_metric_parsing() -> None:
metric = build_result.semantic_manifest.metrics[0]
assert metric.name == "constraint_test"
assert metric.type is MetricType.SIMPLE
assert metric.filter == PydanticWhereFilter(
where_sql_template="{{ dimension('some_dimension') }} IN ('value1', 'value2')"
assert metric.filter == PydanticWhereFilterIntersection(
where_filters=[
PydanticWhereFilter(where_sql_template="{{ dimension('some_dimension') }} IN ('value1', 'value2')")
]
)


Expand Down Expand Up @@ -364,7 +373,9 @@ def test_derived_metric_input_parsing() -> None:
assert metric.type_params.metrics[1] == PydanticMetricInput(
name="input_metric",
alias="constrained_input_metric",
filter=PydanticWhereFilter(where_sql_template="input_metric < 10"),
filter=PydanticWhereFilterIntersection(
where_filters=[PydanticWhereFilter(where_sql_template="input_metric < 10")]
),
)


Expand Down
32 changes: 22 additions & 10 deletions tests/validations/test_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,15 @@
)
from dbt_semantic_interfaces.implementations.elements.entity import PydanticEntity
from dbt_semantic_interfaces.implementations.elements.measure import PydanticMeasure
from dbt_semantic_interfaces.implementations.filters.where_filter import (
PydanticWhereFilter,
PydanticWhereFilterIntersection,
)
from dbt_semantic_interfaces.implementations.metric import (
PydanticMetricInput,
PydanticMetricInputMeasure,
PydanticMetricTimeWindow,
PydanticMetricTypeParams,
PydanticWhereFilter,
)
from dbt_semantic_interfaces.implementations.semantic_manifest import (
PydanticSemanticManifest,
Expand Down Expand Up @@ -323,7 +326,8 @@ def test_where_filter_validations_bad_base_filter( # noqa: D

metric, _ = find_metric_with(manifest, lambda metric: metric.filter is not None)
assert metric.filter is not None
metric.filter.where_sql_template = "{{ dimension('too', 'many', 'variables', 'to', 'handle') }}"
assert len(metric.filter.where_filters) > 0
metric.filter.where_filters[0].where_sql_template = "{{ dimension('too', 'many', 'variables', 'to', 'handle') }}"
validator = SemanticManifestValidator[PydanticSemanticManifest]([WhereFiltersAreParseable()])
with pytest.raises(SemanticManifestValidationException, match=f"trying to parse filter of metric `{metric.name}`"):
validator.checked_validations(manifest)
Expand All @@ -338,8 +342,10 @@ def test_where_filter_validations_bad_measure_filter( # noqa: D
manifest, lambda metric: metric.type_params is not None and metric.type_params.measure is not None
)
assert metric.type_params.measure is not None
metric.type_params.measure.filter = PydanticWhereFilter(
where_sql_template="{{ dimension('too', 'many', 'variables', 'to', 'handle') }}"
metric.type_params.measure.filter = PydanticWhereFilterIntersection(
where_filters=[
PydanticWhereFilter(where_sql_template="{{ dimension('too', 'many', 'variables', 'to', 'handle') }}")
]
)
validator = SemanticManifestValidator[PydanticSemanticManifest]([WhereFiltersAreParseable()])
with pytest.raises(
Expand All @@ -358,8 +364,10 @@ def test_where_filter_validations_bad_numerator_filter( # noqa: D
manifest, lambda metric: metric.type_params is not None and metric.type_params.numerator is not None
)
assert metric.type_params.numerator is not None
metric.type_params.numerator.filter = PydanticWhereFilter(
where_sql_template="{{ dimension('too', 'many', 'variables', 'to', 'handle') }}"
metric.type_params.numerator.filter = PydanticWhereFilterIntersection(
where_filters=[
PydanticWhereFilter(where_sql_template="{{ dimension('too', 'many', 'variables', 'to', 'handle') }}")
]
)
validator = SemanticManifestValidator[PydanticSemanticManifest]([WhereFiltersAreParseable()])
with pytest.raises(
Expand All @@ -377,8 +385,10 @@ def test_where_filter_validations_bad_denominator_filter( # noqa: D
manifest, lambda metric: metric.type_params is not None and metric.type_params.denominator is not None
)
assert metric.type_params.denominator is not None
metric.type_params.denominator.filter = PydanticWhereFilter(
where_sql_template="{{ dimension('too', 'many', 'variables', 'to', 'handle') }}"
metric.type_params.denominator.filter = PydanticWhereFilterIntersection(
where_filters=[
PydanticWhereFilter(where_sql_template="{{ dimension('too', 'many', 'variables', 'to', 'handle') }}")
]
)
validator = SemanticManifestValidator[PydanticSemanticManifest]([WhereFiltersAreParseable()])
with pytest.raises(
Expand All @@ -400,8 +410,10 @@ def test_where_filter_validations_bad_input_metric_filter( # noqa: D
)
assert metric.type_params.metrics is not None
input_metric = metric.type_params.metrics[0]
input_metric.filter = PydanticWhereFilter(
where_sql_template="{{ dimension('too', 'many', 'variables', 'to', 'handle') }}"
input_metric.filter = PydanticWhereFilterIntersection(
where_filters=[
PydanticWhereFilter(where_sql_template="{{ dimension('too', 'many', 'variables', 'to', 'handle') }}")
]
)
validator = SemanticManifestValidator[PydanticSemanticManifest]([WhereFiltersAreParseable()])
with pytest.raises(
Expand Down

0 comments on commit 242235d

Please sign in to comment.