diff --git a/metricflow/dataflow/builder/dataflow_plan_builder.py b/metricflow/dataflow/builder/dataflow_plan_builder.py index 3911211bc1..00a3f08aaa 100644 --- a/metricflow/dataflow/builder/dataflow_plan_builder.py +++ b/metricflow/dataflow/builder/dataflow_plan_builder.py @@ -43,7 +43,7 @@ from metricflow.dataflow.sql_table import SqlTable from metricflow.dataset.dataset import DataSet from metricflow.errors.errors import UnableToSatisfyQueryError -from metricflow.model.objects.metric import MetricType, MetricTimeWindow +from metricflow.model.objects.metric import ConstantPropertyInput, MetricType, MetricTimeWindow from metricflow.model.semantic_model import SemanticModel from metricflow.instances import DataSourceReference from metricflow.model.validations.unique_valid_name import MetricFlowReservedKeywords @@ -178,6 +178,7 @@ def _build_aggregated_conversion_node( queried_linkable_specs: LinkableSpecSet, where_constraint: Optional[SpecWhereClauseConstraint] = None, time_range_constraint: Optional[TimeRangeConstraint] = None, + constant_properties: Optional[List[ConstantPropertyInput]] = None, ) -> BaseOutput[SqlDataSetT]: """Builds a node that contains aggregated values of conversions and opportunities.""" @@ -250,6 +251,7 @@ def _build_aggregated_conversion_node( conversion_primary_key_specs=primary_key_specs, entity_spec=entity_spec, window=window, + constant_properties=constant_properties, ) conversion_measure_recipe = MeasureRecipe( measure_node=join_conversion_node, @@ -346,6 +348,7 @@ def _get_matching_measure( time_range_constraint=time_range_constraint, entity_spec=entity_spec, window=conversion_metric_params.window, + constant_properties=conversion_metric_params.constant_properties, ) output_nodes.append( self.build_computed_metrics_node( diff --git a/metricflow/model/objects/metric.py b/metricflow/model/objects/metric.py index 75968323a9..b3bd901303 100644 --- a/metricflow/model/objects/metric.py +++ b/metricflow/model/objects/metric.py @@ -166,6 
+166,20 @@ def default_expr_value(cls, value: Any, values: Any) -> str: # type: ignore[mis raise ValueError(f"expr value should be a string (str) type, but got {type(value)} with value: {value}") return value + @property + def base_expression(self) -> str: + """Return base_expr, raising ValueError if it is None (an empty string is not rejected).""" + if self.base_expr is None: + raise ValueError("base_expr is None") + return self.base_expr + + @property + def conversion_expression(self) -> str: + """Return conversion_expr, raising ValueError if it is None (an empty string is not rejected).""" + if self.conversion_expr is None: + raise ValueError("conversion_expr is None") + return self.conversion_expr + class ConversionTypeParams(HashableBaseModel): """Type params to provide context for conversion metrics.""" diff --git a/metricflow/plan_conversion/dataflow_to_sql.py b/metricflow/plan_conversion/dataflow_to_sql.py index df25ae59e7..2218c60aad 100644 --- a/metricflow/plan_conversion/dataflow_to_sql.py +++ b/metricflow/plan_conversion/dataflow_to_sql.py @@ -2,7 +2,7 @@ import logging from collections import OrderedDict -from typing import Generic, List, Optional, Sequence, TypeVar, Union +from typing import Generic, List, Optional, Sequence, TypeVar, Tuple, Union from metricflow.aggregation_properties import AggregationState, AggregationType from metricflow.column_assoc import ColumnAssociation, SingleColumnCorrelationKey @@ -1516,7 +1516,7 @@ def visit_join_conversion_events_node(self, node: JoinConversionEventsNode) -> S ColumnEqualityDescription( left_column_alias=entity_column_name, right_column_alias=entity_column_name, - ), # add constant property here + ), ), ) @@ -1524,7 +1524,9 @@ def visit_join_conversion_events_node(self, node: JoinConversionEventsNode) -> S base_sql_column_references = base_data_set.instance_set.transform( CreateSqlColumnReferencesForInstances(base_data_set_alias, self._column_association_resolver) ) - partition_by_columns = (entity_column_name, conversion_time_dimension_column_name) # add constant property here + 
partition_by_columns: Tuple[str, ...] = (entity_column_name, conversion_time_dimension_column_name) + if node.constant_properties: + partition_by_columns += tuple(x.conversion_expression for x in node.constant_properties) base_sql_select_columns = tuple( SqlSelectColumn( expr=SqlWindowFunctionExpression( diff --git a/metricflow/plan_conversion/sql_join_builder.py b/metricflow/plan_conversion/sql_join_builder.py index 1a3725cdde..c249ed4305 100644 --- a/metricflow/plan_conversion/sql_join_builder.py +++ b/metricflow/plan_conversion/sql_join_builder.py @@ -493,11 +493,21 @@ def make_join_conversion_join_description( time_comparison_dataset=conversion_data_set, window=node.window, ) + + column_equality_descriptions = list(column_equality_descriptions) + + for constant_property in node.constant_properties or []: + column_equality_descriptions.append( + ColumnEqualityDescription( + left_column_alias=constant_property.base_expression, + right_column_alias=constant_property.conversion_expression, + ) + ) return SqlQueryPlanJoinBuilder.make_column_equality_sql_join_description( right_source_node=conversion_data_set.data_set.sql_select_node, left_source_alias=base_data_set.alias, right_source_alias=conversion_data_set.alias, - column_equality_descriptions=column_equality_descriptions, + column_equality_descriptions=tuple(column_equality_descriptions), join_type=SqlJoinType.INNER, additional_on_conditions=(window_condition,), )