Validate default_grain
courtneyholcomb committed Jun 15, 2024
1 parent 86dbbe1 commit ad6949a
Showing 7 changed files with 273 additions and 12 deletions.
2 changes: 1 addition & 1 deletion Makefile
@@ -17,7 +17,7 @@ test:
export FORMAT_JSON_LOGS="1" && hatch -v run dev-env:pytest -n auto tests

lint:
hatch run dev-env:pre-commit run --show-diff-on-failure --color=always --all-files
hatch run dev-env:pre-commit run --color=always --all-files

json_schema:
hatch run dev-env:python dbt_semantic_interfaces/parsing/generate_json_schema_file.py
43 changes: 40 additions & 3 deletions dbt_semantic_interfaces/implementations/metric.py
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import Any, Dict, List, Optional, Sequence
from typing import Any, Dict, List, Optional, Sequence, Set

from typing_extensions import override

@@ -16,7 +16,12 @@
PydanticWhereFilterIntersection,
)
from dbt_semantic_interfaces.implementations.metadata import PydanticMetadata
from dbt_semantic_interfaces.protocols import MetricConfig, ProtocolHint
from dbt_semantic_interfaces.protocols import (
Metric,
MetricConfig,
MetricInputMeasure,
ProtocolHint,
)
from dbt_semantic_interfaces.references import MeasureReference, MetricReference
from dbt_semantic_interfaces.type_enums import (
ConversionCalculationType,
@@ -191,9 +196,13 @@ def _implements_protocol(self) -> MetricConfig: # noqa: D
meta: Dict[str, Any] = Field(default_factory=dict)


class PydanticMetric(HashableBaseModel, ModelWithMetadataParsing):
class PydanticMetric(HashableBaseModel, ModelWithMetadataParsing, ProtocolHint[Metric]):
"""Describes a metric."""

@override
def _implements_protocol(self) -> Metric: # noqa: D
return self

name: str
description: Optional[str]
type: MetricType
@@ -229,3 +238,31 @@ def input_metrics(self) -> Sequence[PydanticMetricInput]:
return (self.type_params.numerator, self.type_params.denominator)
else:
assert_values_exhausted(self.type)

@staticmethod
def all_input_measures_for_metric(
metric: Metric, metric_index: Dict[MetricReference, Metric]
) -> Set[MetricInputMeasure]:
"""Gets all input measures for the metric, including those defined on input metrics (recursively)."""
measures = set()
if metric.type is MetricType.SIMPLE or metric.type is MetricType.CUMULATIVE:
assert (
metric.type_params.measure is not None
), f"Metric {metric.name} should have a measure defined, but it does not."
measures.add(metric.type_params.measure)
elif metric.type is MetricType.DERIVED or metric.type is MetricType.RATIO:
for input_metric in metric.input_metrics:
nested_metric = metric_index.get(MetricReference(input_metric.name))
assert nested_metric, f"Could not find metric {input_metric.name} in semantic manifest."
measures.update(
PydanticMetric.all_input_measures_for_metric(metric=nested_metric, metric_index=metric_index)
)
elif metric.type is MetricType.CONVERSION:
conversion_type_params = metric.type_params.conversion_type_params
assert conversion_type_params, "Conversion metric should have conversion_type_params."
measures.add(conversion_type_params.base_measure)
measures.add(conversion_type_params.conversion_measure)
else:
assert_values_exhausted(metric.type)

return measures
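
A minimal usage sketch (not part of the changed files above) of the new PydanticMetric.all_input_measures_for_metric helper, assuming the metric_with_guaranteed_meta test utility that this diff extends later; the metric and measure names are hypothetical.

from dbt_semantic_interfaces.implementations.metric import (
    PydanticMetric,
    PydanticMetricInput,
    PydanticMetricInputMeasure,
    PydanticMetricTypeParams,
)
from dbt_semantic_interfaces.references import MetricReference
from dbt_semantic_interfaces.test_utils import metric_with_guaranteed_meta
from dbt_semantic_interfaces.type_enums import MetricType

# A simple metric backed by a single measure (hypothetical names).
bookings = metric_with_guaranteed_meta(
    name="bookings",
    type=MetricType.SIMPLE,
    type_params=PydanticMetricTypeParams(measure=PydanticMetricInputMeasure(name="bookings")),
)

# A derived metric whose only input is the simple metric above.
bookings_doubled = metric_with_guaranteed_meta(
    name="bookings_doubled",
    type=MetricType.DERIVED,
    type_params=PydanticMetricTypeParams(
        metrics=[PydanticMetricInput(name="bookings")],
        expr="bookings * 2",
    ),
)

# The helper recurses through derived/ratio inputs down to the measures on the
# simple metrics they reference.
metric_index = {MetricReference(m.name): m for m in (bookings, bookings_doubled)}
measures = PydanticMetric.all_input_measures_for_metric(
    metric=bookings_doubled, metric_index=metric_index
)
print({measure.name for measure in measures})  # {'bookings'}
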
3 changes: 2 additions & 1 deletion dbt_semantic_interfaces/implementations/semantic_model.py
@@ -19,6 +19,7 @@
SemanticModelDefaults,
)
from dbt_semantic_interfaces.references import (
DimensionReference,
EntityReference,
LinkableElementReference,
MeasureReference,
@@ -168,7 +169,7 @@ def get_measure(self, measure_reference: MeasureReference) -> PydanticMeasure:
f"No dimension with name ({measure_reference.element_name}) in semantic_model with name ({self.name})"
)

def get_dimension(self, dimension_reference: LinkableElementReference) -> PydanticDimension: # noqa: D
def get_dimension(self, dimension_reference: DimensionReference) -> PydanticDimension: # noqa: D
for dim in self.dimensions:
if dim.reference == dimension_reference:
return dim
1 change: 1 addition & 0 deletions dbt_semantic_interfaces/references.py
@@ -47,6 +47,7 @@ class EntityReference(LinkableElementReference): # noqa: D
class TimeDimensionReference(DimensionReference): # noqa: D
pass

@property
def dimension_reference(self) -> DimensionReference: # noqa: D
return DimensionReference(element_name=self.element_name)

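
A brief note on the references.py change: with @property added, dimension_reference is read as an attribute rather than called as a method, which is how the new DefaultGrainRule below accesses it (agg_time_dimension_ref.dimension_reference). A small sketch, assuming only the reference classes shown in this diff:

from dbt_semantic_interfaces.references import TimeDimensionReference

ref = TimeDimensionReference(element_name="ds")
# Attribute access (no parentheses) now returns the plain DimensionReference for "ds".
print(ref.dimension_reference)
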
4 changes: 3 additions & 1 deletion dbt_semantic_interfaces/test_utils.py
@@ -24,7 +24,7 @@
PydanticSemanticModel,
)
from dbt_semantic_interfaces.parsing.objects import YamlConfigFile
from dbt_semantic_interfaces.type_enums import MetricType
from dbt_semantic_interfaces.type_enums import MetricType, TimeGranularity

logger = logging.getLogger(__name__)

@@ -123,6 +123,7 @@ def metric_with_guaranteed_meta(
type_params: PydanticMetricTypeParams,
metadata: PydanticMetadata = default_meta(),
description: str = "adhoc metric",
default_grain: Optional[TimeGranularity] = None,
) -> PydanticMetric:
"""Creates a metric with the given input.
@@ -135,6 +136,7 @@
type_params=type_params,
filter=None,
metadata=metadata,
default_grain=default_grain,
)


118 changes: 114 additions & 4 deletions dbt_semantic_interfaces/validations/metrics.py
@@ -1,17 +1,30 @@
import traceback
from typing import Generic, List, Optional, Sequence
from typing import Dict, Generic, List, Optional, Sequence

from dbt_semantic_interfaces.errors import ParsingException
from dbt_semantic_interfaces.implementations.metric import PydanticMetricTimeWindow
from dbt_semantic_interfaces.implementations.metric import (
PydanticMetric,
PydanticMetricTimeWindow,
)
from dbt_semantic_interfaces.protocols import (
ConversionTypeParams,
Dimension,
Metric,
SemanticManifest,
SemanticManifestT,
SemanticModel,
)
from dbt_semantic_interfaces.references import MeasureReference, MetricModelReference
from dbt_semantic_interfaces.type_enums import AggregationType, MetricType
from dbt_semantic_interfaces.references import (
DimensionReference,
MeasureReference,
MetricModelReference,
MetricReference,
)
from dbt_semantic_interfaces.type_enums import (
AggregationType,
MetricType,
TimeGranularity,
)
from dbt_semantic_interfaces.validations.unique_valid_name import UniqueAndValidNameRule
from dbt_semantic_interfaces.validations.validator_helpers import (
FileContext,
@@ -514,3 +527,100 @@ def validate_manifest(semantic_manifest: SemanticManifestT) -> Sequence[ValidationIssue]:
conversion_semantic_model=conversion_semantic_model,
)
return issues


class DefaultGrainRule(SemanticManifestValidationRule[SemanticManifestT], Generic[SemanticManifestT]):
"""Checks that default_grain set for metric is queryable for that metric."""

@staticmethod
def _min_queryable_granularity_for_metric(
metric: Metric,
metric_index: Dict[MetricReference, Metric],
measure_to_agg_time_dimension: Dict[MeasureReference, Dimension],
) -> TimeGranularity:
"""Get the minimum time granularity this metric is allowed to be queried with.
This should be the largest granularity that any of the metric's agg_time_dimensions is defined at.
Defaults to DAY if no granularity is defined for any of the metric's agg_time_dimensions.
"""
min_queryable_granularity: Optional[TimeGranularity] = None
for input_measure in PydanticMetric.all_input_measures_for_metric(metric=metric, metric_index=metric_index):
agg_time_dimension = measure_to_agg_time_dimension.get(input_measure.measure_reference)
assert agg_time_dimension, f"Measure '{input_measure.name}' not found in semantic manifest."
if not agg_time_dimension.type_params:
continue
defined_time_granularity = agg_time_dimension.type_params.time_granularity
if not min_queryable_granularity or defined_time_granularity.to_int() > min_queryable_granularity.to_int():
min_queryable_granularity = defined_time_granularity

return min_queryable_granularity or TimeGranularity.DAY

@staticmethod
@validate_safely(
whats_being_done="running model validation ensuring a metric's default_grain is valid for the metric"
)
def _validate_metric(
metric: Metric,
metric_index: Dict[MetricReference, Metric],
measure_to_agg_time_dimension: Dict[MeasureReference, Dimension],
) -> Sequence[ValidationIssue]: # noqa: D
issues: List[ValidationIssue] = []
context = MetricContext(
file_context=FileContext.from_metadata(metadata=metric.metadata),
metric=MetricModelReference(metric_name=metric.name),
)

if metric.default_grain:
min_queryable_granularity = DefaultGrainRule._min_queryable_granularity_for_metric(
metric=metric, metric_index=metric_index, measure_to_agg_time_dimension=measure_to_agg_time_dimension
)
valid_granularities = [
granularity.name
for granularity in TimeGranularity
if granularity.to_int() >= min_queryable_granularity.to_int()
]
if metric.default_grain.name not in valid_granularities:
issues.append(
ValidationError(
context=context,
message=(
f"`default_grain` for metric '{metric.name}' must be >= {min_queryable_granularity.name}. "
"Valid options are those that are >= the largest granularity defined for the metric's "
f"measures' agg_time_dimensions. Got: {metric.default_grain.name}. "
f"Valid options: {valid_granularities}"
),
)
)

return issues

@staticmethod
@validate_safely(whats_being_done="running manifest validation ensuring metric default_grains are valid")
def validate_manifest(semantic_manifest: SemanticManifestT) -> Sequence[ValidationIssue]:
"""Validate that the default_grain for each metric is queryable for that metric.
TODO: figure out a more efficient way to reference other aspects of the model. This validation essentially
requires parsing the entire model, which could be slow and likely is repeated work. The blocker is that the
inputs to validations are protocols, which don't easily store parsed metadata.
"""
issues: List[ValidationIssue] = []

measure_to_agg_time_dimension: Dict[MeasureReference, Dimension] = {}
for semantic_model in semantic_manifest.semantic_models:
dimension_index = {DimensionReference(dimension.name): dimension for dimension in semantic_model.dimensions}
for measure in semantic_model.measures:
agg_time_dimension_ref = semantic_model.checked_agg_time_dimension_for_measure(measure.reference)
agg_time_dimension = dimension_index.get(agg_time_dimension_ref.dimension_reference)
assert (
agg_time_dimension
), f"Dimension '{agg_time_dimension_ref.element_name}' not found in semantic manifest."
measure_to_agg_time_dimension[measure.reference] = agg_time_dimension

metric_index = {MetricReference(metric.name): metric for metric in semantic_manifest.metrics}
for metric in semantic_manifest.metrics or []:
issues += DefaultGrainRule._validate_metric(
metric=metric,
metric_index=metric_index,
measure_to_agg_time_dimension=measure_to_agg_time_dimension,
)
return issues
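
A small illustrative sketch of the comparison the DefaultGrainRule performs, assuming TimeGranularity.to_int() assigns larger values to coarser grains (which the rule above relies on): any default_grain at least as coarse as the largest agg_time_dimension granularity passes, and anything finer is flagged.

from dbt_semantic_interfaces.type_enums import TimeGranularity

# Suppose the coarsest agg_time_dimension among the metric's input measures is MONTH.
min_queryable = TimeGranularity.MONTH

# Mirror the rule's check: valid default grains are those >= the minimum queryable grain.
valid_grains = [g.name for g in TimeGranularity if g.to_int() >= min_queryable.to_int()]

print(TimeGranularity.MONTH.name in valid_grains)  # True  -> default_grain: month is accepted
print(TimeGranularity.WEEK.name in valid_grains)   # False -> default_grain: week produces a ValidationError issue
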
114 changes: 112 additions & 2 deletions tests/validations/test_metrics.py
@@ -45,6 +45,7 @@
from dbt_semantic_interfaces.validations.metrics import (
ConversionMetricRule,
CumulativeMetricRule,
DefaultGrainRule,
DerivedMetricRule,
WhereFiltersAreParseable,
)
@@ -670,8 +671,6 @@ def test_cumulative_metrics() -> None: # noqa: D
)

build_issues = validation_results.all_issues
for issue in build_issues:
print(issue.message)
assert len(build_issues) == 8
expected_substr1 = "Both window and grain_to_date set for cumulative metric. Please set one or the other."
expected_substr2 = "Got differing values for `window`"
@@ -684,3 +683,114 @@ def test_cumulative_metrics() -> None: # noqa: D
missing_error_strings.add(expected_str)
assert len(missing_error_strings) == 0, "Failed to match one or more expected issues: "
f"{missing_error_strings} in {set([x.as_readable_str() for x in build_issues])}"


def test_default_grain() -> None:
"""Test that default grain is validated appropriately."""
week_measure_name = "foo"
month_measure_name = "boo"
week_time_dim_name = "ds__week"
month_time_dim_name = "ds__month"
model_validator = SemanticManifestValidator[PydanticSemanticManifest]([DefaultGrainRule()])
validation_results = model_validator.validate_semantic_manifest(
PydanticSemanticManifest(
semantic_models=[
semantic_model_with_guaranteed_meta(
name="semantic_model",
measures=[
PydanticMeasure(
name=month_measure_name, agg=AggregationType.SUM, agg_time_dimension=month_time_dim_name
),
PydanticMeasure(
name=week_measure_name, agg=AggregationType.SUM, agg_time_dimension=week_time_dim_name
),
],
dimensions=[
PydanticDimension(
name=month_time_dim_name,
type=DimensionType.TIME,
type_params=PydanticDimensionTypeParams(time_granularity=TimeGranularity.MONTH),
),
PydanticDimension(
name=week_time_dim_name,
type=DimensionType.TIME,
type_params=PydanticDimensionTypeParams(time_granularity=TimeGranularity.WEEK),
),
],
),
],
metrics=[
# Simple metrics
metric_with_guaranteed_meta(
name="month_metric_with_no_default_grain_set",
type=MetricType.SIMPLE,
type_params=PydanticMetricTypeParams(
measure=PydanticMetricInputMeasure(name=month_measure_name),
),
),
metric_with_guaranteed_meta(
name="week_metric_with_valid_default_grain",
type=MetricType.SIMPLE,
type_params=PydanticMetricTypeParams(
measure=PydanticMetricInputMeasure(name=week_measure_name),
),
default_grain=TimeGranularity.MONTH,
),
metric_with_guaranteed_meta(
name="month_metric_with_invalid_default_grain",
type=MetricType.SIMPLE,
type_params=PydanticMetricTypeParams(
measure=PydanticMetricInputMeasure(name=month_measure_name),
),
default_grain=TimeGranularity.WEEK,
),
# Derived metrics
metric_with_guaranteed_meta(
name="derived_metric_with_no_default_grain_set",
type=MetricType.DERIVED,
type_params=PydanticMetricTypeParams(
metrics=[
PydanticMetricInput(name="week_metric_with_valid_default_grain"),
],
expr="week_metric_with_valid_default_grain + 1",
),
),
metric_with_guaranteed_meta(
name="derived_metric_with_valid_default_grain",
type=MetricType.DERIVED,
type_params=PydanticMetricTypeParams(
metrics=[
PydanticMetricInput(name="week_metric_with_valid_default_grain"),
PydanticMetricInput(name="month_metric_with_no_default_grain_set"),
],
expr="week_metric_with_valid_default_grain + month_metric_with_no_default_grain_set",
),
default_grain=TimeGranularity.YEAR,
),
metric_with_guaranteed_meta(
name="derived_metric_with_invalid_default_grain",
type=MetricType.DERIVED,
type_params=PydanticMetricTypeParams(
metrics=[
PydanticMetricInput(name="week_metric_with_valid_default_grain"),
PydanticMetricInput(name="month_metric_with_no_default_grain_set"),
],
expr="week_metric_with_valid_default_grain + month_metric_with_no_default_grain_set",
),
default_grain=TimeGranularity.DAY,
),
],
project_configuration=EXAMPLE_PROJECT_CONFIGURATION,
)
)

build_issues = validation_results.all_issues
assert len(build_issues) == 2
expected_substr1 = "`default_grain` for metric 'month_metric_with_invalid_default_grain' must be >= MONTH."
expected_substr2 = "`default_grain` for metric 'derived_metric_with_invalid_default_grain' must be >= MONTH."
missing_error_strings = set()
for expected_str in [expected_substr1, expected_substr2]:
if not any(actual_str.as_readable_str().find(expected_str) != -1 for actual_str in build_issues):
missing_error_strings.add(expected_str)
assert len(missing_error_strings) == 0, "Failed to match one or more expected issues: "
f"{missing_error_strings} in {set([x.as_readable_str() for x in build_issues])}"
