diff --git a/dbt_semantic_interfaces/transformations/cumulative_type_params.py b/dbt_semantic_interfaces/transformations/cumulative_type_params.py index ae47f5e2..32253969 100644 --- a/dbt_semantic_interfaces/transformations/cumulative_type_params.py +++ b/dbt_semantic_interfaces/transformations/cumulative_type_params.py @@ -1,3 +1,5 @@ +from typing import Optional + from typing_extensions import override from dbt_semantic_interfaces.implementations.metric import PydanticCumulativeTypeParams @@ -27,7 +29,7 @@ def _implements_protocol(self) -> SemanticManifestTransformRule[PydanticSemantic return self @staticmethod - def transform_model(semantic_manifest: PydanticSemanticManifest) -> PydanticSemanticManifest: + def transform_model(semantic_manifest: PydanticSemanticManifest) -> PydanticSemanticManifest: # noqa: D for metric in semantic_manifest.metrics: if metric.type == MetricType.CUMULATIVE: if not metric.type_params.cumulative_type_params: @@ -36,11 +38,20 @@ def transform_model(semantic_manifest: PydanticSemanticManifest) -> PydanticSema if metric.type_params.window and not metric.type_params.cumulative_type_params.window: metric.type_params.cumulative_type_params.window = metric.type_params.window - # Since grain_to_date is just a string, we can't add custom parsing to them. Instead, lowercase them here. - if metric.type_params.grain_to_date or metric.type_params.cumulative_type_params.grain_to_date: - metric.type_params.cumulative_type_params.grain_to_date = ( - metric.type_params.cumulative_type_params.grain_to_date - or metric.type_params.grain_to_date.value - ).lower() + # Since grain_to_date is a string, we can't add custom parsing to it. Instead, lowercase it here. + if metric.type_params.grain_to_date or ( + metric.type_params.cumulative_type_params + and metric.type_params.cumulative_type_params.grain_to_date + ): + grain_to_date: Optional[str] = None + if ( + metric.type_params.cumulative_type_params + and metric.type_params.cumulative_type_params.grain_to_date + ): + grain_to_date = metric.type_params.cumulative_type_params.grain_to_date + elif metric.type_params.grain_to_date: + grain_to_date = metric.type_params.grain_to_date.value + if grain_to_date: + metric.type_params.cumulative_type_params.grain_to_date = grain_to_date.lower() return semantic_manifest diff --git a/tests/parsing/test_metric_parsing.py b/tests/parsing/test_metric_parsing.py index b29329a5..af44915f 100644 --- a/tests/parsing/test_metric_parsing.py +++ b/tests/parsing/test_metric_parsing.py @@ -12,9 +12,10 @@ from dbt_semantic_interfaces.parsing.dir_to_model import ( parse_yaml_files_to_semantic_manifest, ) -from dbt_semantic_interfaces.transformations.semantic_manifest_transformer import PydanticSemanticManifestTransformer - from dbt_semantic_interfaces.parsing.objects import YamlConfigFile +from dbt_semantic_interfaces.transformations.semantic_manifest_transformer import ( + PydanticSemanticManifestTransformer, +) from dbt_semantic_interfaces.type_enums import ( ConversionCalculationType, MetricType, @@ -244,8 +245,10 @@ def test_cumulative_window_metric_parsing() -> None: assert metric.name == "cumulative_test" assert metric.type is MetricType.CUMULATIVE assert metric.type_params.measure == PydanticMetricInputMeasure(name="cumulative_measure") - assert metric.type_params.cumulative_type_params.window == PydanticMetricTimeWindow( - count=7, granularity=TimeGranularity.DAY.value + assert ( + metric.type_params.cumulative_type_params + and metric.type_params.cumulative_type_params.window + == PydanticMetricTimeWindow(count=7, granularity=TimeGranularity.DAY.value) ) @@ -291,7 +294,10 @@ def test_grain_to_date_metric_parsing() -> None: assert metric.type is MetricType.CUMULATIVE assert metric.type_params.measure == PydanticMetricInputMeasure(name="cumulative_measure") assert metric.type_params.window is None - assert metric.type_params.cumulative_type_params.grain_to_date == TimeGranularity.WEEK.value + assert ( + metric.type_params.cumulative_type_params + and metric.type_params.cumulative_type_params.grain_to_date == TimeGranularity.WEEK.value + ) def test_derived_metric_offset_window_parsing() -> None: