diff --git a/dbt_semantic_interfaces/implementations/metric.py b/dbt_semantic_interfaces/implementations/metric.py index 656e09c5..3e44118f 100644 --- a/dbt_semantic_interfaces/implementations/metric.py +++ b/dbt_semantic_interfaces/implementations/metric.py @@ -17,7 +17,11 @@ ) from dbt_semantic_interfaces.implementations.metadata import PydanticMetadata from dbt_semantic_interfaces.references import MeasureReference, MetricReference -from dbt_semantic_interfaces.type_enums import MetricType, TimeGranularity +from dbt_semantic_interfaces.type_enums import ( + ConversionCalculationType, + MetricType, + TimeGranularity, +) class PydanticMetricInputMeasure(PydanticCustomInputParser, HashableBaseModel): @@ -134,6 +138,26 @@ def post_aggregation_reference(self) -> MetricReference: return MetricReference(element_name=self.alias or self.name) +class PydanticConversionTypeParams(HashableBaseModel): + """Type params to provide context for conversion metrics properties.""" + + base_measure: PydanticMetricInputMeasure + conversion_measure: PydanticMetricInputMeasure + entity: str + calculation: ConversionCalculationType = ConversionCalculationType.CONVERSION_RATE + window: Optional[PydanticMetricTimeWindow] + + @property + def base_measure_reference(self) -> MeasureReference: + """Return the measure reference associated with the base measure.""" + return self.base_measure.measure_reference + + @property + def conversion_measure_reference(self) -> MeasureReference: + """Return the measure reference associated with the conversion measure.""" + return self.conversion_measure.measure_reference + + class PydanticMetricTypeParams(HashableBaseModel): """Type params add additional context to certain metric types (the context depends on the metric type).""" @@ -144,6 +168,7 @@ class PydanticMetricTypeParams(HashableBaseModel): window: Optional[PydanticMetricTimeWindow] grain_to_date: Optional[TimeGranularity] metrics: Optional[List[PydanticMetricInput]] + conversion_type_params: Optional[PydanticConversionTypeParams] input_measures: List[PydanticMetricInputMeasure] = Field(default_factory=list) @@ -172,7 +197,7 @@ def measure_references(self) -> List[MeasureReference]: @property def input_metrics(self) -> Sequence[PydanticMetricInput]: """Return the associated input metrics for this metric.""" - if self.type is MetricType.SIMPLE or self.type is MetricType.CUMULATIVE: + if self.type is MetricType.SIMPLE or self.type is MetricType.CUMULATIVE or self.type is MetricType.CONVERSION: return () elif self.type is MetricType.DERIVED: assert self.type_params.metrics is not None, f"{MetricType.DERIVED} should have type_params.metrics set" @@ -184,3 +209,11 @@ def input_metrics(self) -> Sequence[PydanticMetricInput]: return (self.type_params.numerator, self.type_params.denominator) else: assert_values_exhausted(self.type) + + @property + def conversion_params(self) -> PydanticConversionTypeParams: + """Accessor for conversion type params, enforces that it's set.""" + assert self.type == MetricType.CONVERSION, "Should only access this for a conversion metric." + if self.type_params.conversion_type_params is None: + raise ValueError("conversion_type_params is not defined.") + return self.type_params.conversion_type_params