Skip to content

Commit

Permalink
Finalize validations for custom granularities (#370)
Browse files Browse the repository at this point in the history
### Description
This PR does two things:
- In [this
PR](https://github.com/dbt-labs/dbt-semantic-interfaces/pull/365/files),
we unintentionally made the `grain_to_date` and `offset_to_grain` fields
case-sensitive. This PR restores case-insensitive handling of both fields,
which needs to be fixed before we can release to core.
- Adds validation errors that block the use of custom grains on any field
except `offset_window`. This follows a product decision we made to
prioritize that feature and defer the rest.

### Checklist

- [ ] I have read [the contributing
guide](https://github.com/dbt-labs/dbt-semantic-interfaces/blob/main/CONTRIBUTING.md)
and understand what's expected of me
- [ ] I have signed the
[CLA](https://docs.getdbt.com/docs/contributor-license-agreements)
- [ ] This PR includes tests, or tests are not required/relevant for
this PR
- [ ] I have run `changie new` to [create a changelog
entry](https://github.com/dbt-labs/dbt-semantic-interfaces/blob/main/CONTRIBUTING.md#adding-a-changelog-entry)
  • Loading branch information
courtneyholcomb authored Dec 2, 2024
1 parent f3a50f0 commit a860e25
Show file tree
Hide file tree
Showing 18 changed files with 414 additions and 178 deletions.
50 changes: 26 additions & 24 deletions dbt_semantic_interfaces/implementations/metric.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

from typing import Dict, List, Optional, Sequence, Set
from copy import deepcopy
from typing import Any, Dict, List, Optional, Sequence, Set

from typing_extensions import override

Expand Down Expand Up @@ -83,7 +84,7 @@ def _from_yaml_value(cls, input: PydanticParseableValueType) -> PydanticMetricTi
The MetricTimeWindow is always expected to be provided as a string in user-defined YAML configs.
"""
if isinstance(input, str):
return PydanticMetricTimeWindow.parse(window=input.lower(), custom_granularity_names=(), strict=False)
return PydanticMetricTimeWindow.parse(window=input.lower())
else:
raise ValueError(
f"MetricTimeWindow inputs from model configs are expected to always be of type string, but got "
Expand All @@ -101,12 +102,8 @@ def window_string(self) -> str:
return f"{self.count} {self.granularity}"

@staticmethod
def parse(window: str, custom_granularity_names: Sequence[str], strict: bool = True) -> PydanticMetricTimeWindow:
"""Returns window values if parsing succeeds, None otherwise.
If strict=True, then the granularity in the window must exist as a valid granularity.
Use strict=True for when you have all valid granularities, otherwise use strict=False.
"""
def parse(window: str) -> PydanticMetricTimeWindow:
"""Returns window values if parsing succeeds, None otherwise."""
parts = window.lower().split(" ")
if len(parts) != 2:
raise ParsingException(
Expand All @@ -115,22 +112,6 @@ def parse(window: str, custom_granularity_names: Sequence[str], strict: bool = T
)

granularity = parts[1]

valid_time_granularities = {item.value.lower() for item in TimeGranularity} | set(
c.lower() for c in custom_granularity_names
)

# if we switched to python 3.9 this could just be `granularity = parts[0].removesuffix('s')
if granularity.endswith("s") and granularity[:-1] in valid_time_granularities:
# months -> month
granularity = granularity[:-1]

if strict and granularity not in valid_time_granularities:
raise ParsingException(
f"Invalid time granularity {granularity} in metric window string: ({window})",
)
# If not strict and not standard granularity, it may be a custom grain, so validations happens later

count = parts[0]
if not count.isdigit():
raise ParsingException(f"Invalid count ({count}) in cumulative metric window string: ({window})")
Expand Down Expand Up @@ -222,6 +203,27 @@ def _implements_protocol(self) -> Metric: # noqa: D
config: Optional[PydanticSemanticLayerElementConfig]
time_granularity: Optional[str] = None

@classmethod
def parse_obj(cls, input: Any) -> PydanticMetric:
    """Parse a metric after lowercasing its granularity-valued fields.

    Two fields are lowercased before parsing so that they are handled
    case-insensitively:

    * ``type_params.cumulative_type_params.grain_to_date``
    * ``offset_to_grain`` on every entry of ``type_params.metrics``

    Normalization is applied to a deep copy, so the caller's ``input`` is
    never mutated.

    Any section that is missing, explicitly null, or not a plain ``dict``
    is left untouched and handed to the default parser as-is. That way a
    malformed config is reported by the parser's own validation rather than
    as an ``AttributeError`` raised from this pre-processing step.
    """
    data = deepcopy(input)

    # ``dict.get(key, {})`` only falls back to ``{}`` when the key is absent;
    # a key that is present with a null value would still yield ``None``.
    # Check the type at every level instead of relying on the default.
    type_params = data.get("type_params") if isinstance(data, dict) else None
    if isinstance(type_params, dict):
        # Ensure grain_to_date is lowercased
        cumulative_type_params = type_params.get("cumulative_type_params")
        if isinstance(cumulative_type_params, dict):
            grain_to_date = cumulative_type_params.get("grain_to_date")
            if isinstance(grain_to_date, str):
                cumulative_type_params["grain_to_date"] = grain_to_date.lower()

        # Ensure offset_to_grain is lowercased
        input_metrics = type_params.get("metrics")
        if isinstance(input_metrics, (list, tuple)):
            for input_metric in input_metrics:
                if not isinstance(input_metric, dict):
                    # Leave unexpected entries for the parser to reject.
                    continue
                offset_to_grain = input_metric.get("offset_to_grain")
                if isinstance(offset_to_grain, str):
                    input_metric["offset_to_grain"] = offset_to_grain.lower()

    # Deliberately start the MRO lookup after HashableBaseModel, exactly as
    # the original implementation does.
    return super(HashableBaseModel, cls).parse_obj(data)

@property
def input_measures(self) -> Sequence[PydanticMetricInputMeasure]:
"""Return the complete list of input measure configurations for this metric."""
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Sequence
from typing import Set

from typing_extensions import override

Expand All @@ -12,7 +12,7 @@
from dbt_semantic_interfaces.transformations.transform_rule import (
SemanticManifestTransformRule,
)
from dbt_semantic_interfaces.type_enums import MetricType
from dbt_semantic_interfaces.type_enums import MetricType, TimeGranularity


class RemovePluralFromWindowGranularityRule(ProtocolHint[SemanticManifestTransformRule[PydanticSemanticManifest]]):
Expand All @@ -30,15 +30,21 @@ def _implements_protocol(self) -> SemanticManifestTransformRule[PydanticSemantic

@staticmethod
def _update_metric(
semantic_manifest: PydanticSemanticManifest, metric_name: str, custom_granularity_names: Sequence[str]
semantic_manifest: PydanticSemanticManifest, metric_name: str, custom_granularity_names: Set[str]
) -> None:
"""Mutates all the MetricTimeWindow by reparsing to remove the trailing 's'."""
valid_time_granularities = {item.value.lower() for item in TimeGranularity} | set(
c.lower() for c in custom_granularity_names
)

def reparse_window(window: PydanticMetricTimeWindow) -> PydanticMetricTimeWindow:
def trim_trailing_s(window: PydanticMetricTimeWindow) -> PydanticMetricTimeWindow:
"""Reparse the window to remove the trailing 's'."""
return PydanticMetricTimeWindow.parse(
window=window.window_string, custom_granularity_names=custom_granularity_names
)
granularity = window.granularity
if granularity.endswith("s") and granularity[:-1] in valid_time_granularities:
# months -> month
granularity = granularity[:-1]
window.granularity = granularity
return window

matched_metric = next(
iter((metric for metric in semantic_manifest.metrics if metric.name == metric_name)), None
Expand All @@ -49,22 +55,23 @@ def reparse_window(window: PydanticMetricTimeWindow) -> PydanticMetricTimeWindow
matched_metric.type_params.cumulative_type_params
and matched_metric.type_params.cumulative_type_params.window
):
matched_metric.type_params.cumulative_type_params.window = reparse_window(
matched_metric.type_params.cumulative_type_params.window = trim_trailing_s(
matched_metric.type_params.cumulative_type_params.window
)

elif matched_metric.type is MetricType.CONVERSION:
if (
matched_metric.type_params.conversion_type_params
and matched_metric.type_params.conversion_type_params.window
):
matched_metric.type_params.conversion_type_params.window = reparse_window(
matched_metric.type_params.conversion_type_params.window = trim_trailing_s(
matched_metric.type_params.conversion_type_params.window
)

elif matched_metric.type is MetricType.DERIVED or matched_metric.type is MetricType.RATIO:
for input_metric in matched_metric.input_metrics:
if input_metric.offset_window:
input_metric.offset_window = reparse_window(input_metric.offset_window)
input_metric.offset_window = trim_trailing_s(input_metric.offset_window)
elif matched_metric.type is MetricType.SIMPLE:
pass
else:
Expand All @@ -74,11 +81,11 @@ def reparse_window(window: PydanticMetricTimeWindow) -> PydanticMetricTimeWindow

@staticmethod
def transform_model(semantic_manifest: PydanticSemanticManifest) -> PydanticSemanticManifest: # noqa: D
custom_granularity_names = [
custom_granularity_names = {
granularity.name
for time_spine in semantic_manifest.project_configuration.time_spines
for granularity in time_spine.custom_granularities
]
}

for metric in semantic_manifest.metrics:
RemovePluralFromWindowGranularityRule._update_metric(
Expand Down
2 changes: 1 addition & 1 deletion dbt_semantic_interfaces/validations/agg_time_dimension.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def validate_manifest(semantic_manifest: SemanticManifestT) -> Sequence[Validati

@staticmethod
@validate_safely(whats_being_done="checking aggregation time dimension for a semantic model")
def _validate_semantic_model(semantic_model: SemanticModel) -> List[ValidationIssue]:
def _validate_semantic_model(semantic_model: SemanticModel) -> Sequence[ValidationIssue]:
issues: List[ValidationIssue] = []

for measure in semantic_model.measures:
Expand Down
14 changes: 8 additions & 6 deletions dbt_semantic_interfaces/validations/common_entities.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def _check_entity(
entity: Entity,
semantic_model: SemanticModel,
entities_to_semantic_models: Dict[EntityReference, Set[str]],
) -> List[ValidationIssue]:
) -> Sequence[ValidationIssue]:
issues: List[ValidationIssue] = []
# If the entity is the dict and if the set of semantic models minus this semantic model is empty,
# then we warn the user that their entity will be unused in joins
Expand Down Expand Up @@ -65,15 +65,17 @@ def _check_entity(
@validate_safely(whats_being_done="running model validation warning if entities are only one one semantic model")
def validate_manifest(semantic_manifest: SemanticManifestT) -> Sequence[ValidationIssue]:
"""Issues a warning for any entity that is associated with only one semantic_model."""
issues = []
issues: List[ValidationIssue] = []

entities_to_semantic_models = CommonEntitysRule._map_semantic_model_entities(semantic_manifest.semantic_models)
for semantic_model in semantic_manifest.semantic_models or []:
for entity in semantic_model.entities or []:
issues += CommonEntitysRule._check_entity(
entity=entity,
semantic_model=semantic_model,
entities_to_semantic_models=entities_to_semantic_models,
issues.extend(
CommonEntitysRule._check_entity(
entity=entity,
semantic_model=semantic_model,
entities_to_semantic_models=entities_to_semantic_models,
)
)

return issues
4 changes: 2 additions & 2 deletions dbt_semantic_interfaces/validations/dimension_const.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def _validate_dimension(
dimension: Dimension,
time_dims_to_granularity: Dict[DimensionReference, TimeGranularity],
semantic_model: SemanticModel,
) -> List[ValidationIssue]:
) -> Sequence[ValidationIssue]:
"""Check that time dimensions of the same name and aren't primary have the same time granularity.
Args:
Expand Down Expand Up @@ -104,7 +104,7 @@ def _validate_semantic_model(
semantic_model: SemanticModel,
dimension_to_invariant: Dict[DimensionReference, DimensionInvariants],
update_invariant_dict: bool,
) -> List[ValidationIssue]:
) -> Sequence[ValidationIssue]:
"""Checks that the given semantic model has dimensions consistent with the given invariants.
Args:
Expand Down
7 changes: 2 additions & 5 deletions dbt_semantic_interfaces/validations/element_const.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
from collections import defaultdict
from typing import DefaultDict, Generic, List, Sequence

from dbt_semantic_interfaces.implementations.semantic_manifest import (
PydanticSemanticManifest,
)
from dbt_semantic_interfaces.protocols import SemanticManifestT
from dbt_semantic_interfaces.references import SemanticModelReference
from dbt_semantic_interfaces.validations.validator_helpers import (
Expand All @@ -28,7 +25,7 @@ class ElementConsistencyRule(SemanticManifestValidationRule[SemanticManifestT],

@staticmethod
@validate_safely(whats_being_done="running model validation ensuring model wide element consistency")
def validate_manifest(semantic_manifest: PydanticSemanticManifest) -> Sequence[ValidationIssue]: # noqa: D
def validate_manifest(semantic_manifest: SemanticManifestT) -> Sequence[ValidationIssue]: # noqa: D
issues = []
element_name_to_types = ElementConsistencyRule._get_element_name_to_types(semantic_manifest=semantic_manifest)
invalid_elements = {
Expand All @@ -54,7 +51,7 @@ def validate_manifest(semantic_manifest: PydanticSemanticManifest) -> Sequence[V

@staticmethod
def _get_element_name_to_types(
semantic_manifest: PydanticSemanticManifest,
semantic_manifest: SemanticManifestT,
) -> DefaultDict[str, DefaultDict[SemanticModelElementType, List[SemanticModelContext]]]:
"""Create a mapping of element names in the semantic manifest to types with a list of associated contexts."""
element_types: DefaultDict[
Expand Down
2 changes: 1 addition & 1 deletion dbt_semantic_interfaces/validations/entities.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ class NaturalEntityConfigurationRule(SemanticManifestValidationRule[SemanticMani
"natural entities are used in the appropriate contexts"
)
)
def _validate_semantic_model_natural_entities(semantic_model: SemanticModel) -> List[ValidationIssue]:
def _validate_semantic_model_natural_entities(semantic_model: SemanticModel) -> Sequence[ValidationIssue]:
issues: List[ValidationIssue] = []
context = SemanticModelContext(
file_context=FileContext.from_metadata(metadata=semantic_model.metadata),
Expand Down
4 changes: 2 additions & 2 deletions dbt_semantic_interfaces/validations/measures.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ class MeasureConstraintAliasesRule(SemanticManifestValidationRule[SemanticManife

@staticmethod
@validate_safely(whats_being_done="ensuring measures aliases are set when required")
def _validate_required_aliases_are_set(metric: Metric, metric_context: MetricContext) -> List[ValidationIssue]:
def _validate_required_aliases_are_set(metric: Metric, metric_context: MetricContext) -> Sequence[ValidationIssue]:
"""Checks if valid aliases are set on the input measure references where they are required.
Aliases are required whenever there are 2 or more input measures with the same measure
Expand Down Expand Up @@ -188,7 +188,7 @@ class MetricMeasuresRule(SemanticManifestValidationRule[SemanticManifestT], Gene

@staticmethod
@validate_safely(whats_being_done="checking all measures referenced by the metric exist")
def _validate_metric_measure_references(metric: Metric, valid_measure_names: Set[str]) -> List[ValidationIssue]:
def _validate_metric_measure_references(metric: Metric, valid_measure_names: Set[str]) -> Sequence[ValidationIssue]:
issues: List[ValidationIssue] = []

for measure_reference in metric.measure_references:
Expand Down
Loading

0 comments on commit a860e25

Please sign in to comment.