Skip to content

Commit

Permalink
integrated constant properties into builder
Browse files Browse the repository at this point in the history
  • Loading branch information
WilliamDee committed Jan 19, 2023
1 parent f2581d0 commit ac82e5c
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 5 deletions.
5 changes: 4 additions & 1 deletion metricflow/dataflow/builder/dataflow_plan_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
from metricflow.dataflow.sql_table import SqlTable
from metricflow.dataset.dataset import DataSet
from metricflow.errors.errors import UnableToSatisfyQueryError
from metricflow.model.objects.metric import MetricType, MetricTimeWindow
from metricflow.model.objects.metric import ConstantPropertyInput, MetricType, MetricTimeWindow
from metricflow.model.semantic_model import SemanticModel
from metricflow.instances import DataSourceReference
from metricflow.model.validations.unique_valid_name import MetricFlowReservedKeywords
Expand Down Expand Up @@ -178,6 +178,7 @@ def _build_aggregated_conversion_node(
queried_linkable_specs: LinkableSpecSet,
where_constraint: Optional[SpecWhereClauseConstraint] = None,
time_range_constraint: Optional[TimeRangeConstraint] = None,
constant_properties: Optional[List[ConstantPropertyInput]] = None,
) -> BaseOutput[SqlDataSetT]:
"""Builds a node that contains aggregated values of conversions and opportunities."""

Expand Down Expand Up @@ -250,6 +251,7 @@ def _build_aggregated_conversion_node(
conversion_primary_key_specs=primary_key_specs,
entity_spec=entity_spec,
window=window,
constant_properties=constant_properties,
)
conversion_measure_recipe = MeasureRecipe(
measure_node=join_conversion_node,
Expand Down Expand Up @@ -346,6 +348,7 @@ def _get_matching_measure(
time_range_constraint=time_range_constraint,
entity_spec=entity_spec,
window=conversion_metric_params.window,
constant_properties=conversion_metric_params.constant_properties,
)
output_nodes.append(
self.build_computed_metrics_node(
Expand Down
14 changes: 14 additions & 0 deletions metricflow/model/objects/metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,20 @@ def default_expr_value(cls, value: Any, values: Any) -> str: # type: ignore[mis
raise ValueError(f"expr value should be a string (str) type, but got {type(value)} with value: {value}")
return value

@property
def base_expression(self) -> str:
"""Returns a non-empty string value of base_expr."""
if self.base_expr is None:
raise ValueError("base_expr is None")
return self.base_expr

@property
def conversion_expression(self) -> str:
"""Returns a non-empty string value of conversion_expr."""
if self.conversion_expr is None:
raise ValueError("conversion_expr is None")
return self.conversion_expr


class ConversionTypeParams(HashableBaseModel):
"""Type params to provide context for conversion metrics."""
Expand Down
8 changes: 5 additions & 3 deletions metricflow/plan_conversion/dataflow_to_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import logging
from collections import OrderedDict
from typing import Generic, List, Optional, Sequence, TypeVar, Union
from typing import Generic, List, Optional, Sequence, TypeVar, Tuple, Union

from metricflow.aggregation_properties import AggregationState, AggregationType
from metricflow.column_assoc import ColumnAssociation, SingleColumnCorrelationKey
Expand Down Expand Up @@ -1516,15 +1516,17 @@ def visit_join_conversion_events_node(self, node: JoinConversionEventsNode) -> S
ColumnEqualityDescription(
left_column_alias=entity_column_name,
right_column_alias=entity_column_name,
), # add constant property here
),
),
)

# Builds the first_value window function columns
base_sql_column_references = base_data_set.instance_set.transform(
CreateSqlColumnReferencesForInstances(base_data_set_alias, self._column_association_resolver)
)
partition_by_columns = (entity_column_name, conversion_time_dimension_column_name) # add constant property here
partition_by_columns: Tuple[str, ...] = (entity_column_name, conversion_time_dimension_column_name)
if node.constant_properties:
partition_by_columns += tuple(x.conversion_expression for x in node.constant_properties)
base_sql_select_columns = tuple(
SqlSelectColumn(
expr=SqlWindowFunctionExpression(
Expand Down
12 changes: 11 additions & 1 deletion metricflow/plan_conversion/sql_join_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -493,11 +493,21 @@ def make_join_conversion_join_description(
time_comparison_dataset=conversion_data_set,
window=node.window,
)

column_equality_descriptions = list(column_equality_descriptions)

for constant_property in node.constant_properties or []:
column_equality_descriptions.append(
ColumnEqualityDescription(
left_column_alias=constant_property.base_expression,
right_column_alias=constant_property.conversion_expression,
)
)
return SqlQueryPlanJoinBuilder.make_column_equality_sql_join_description(
right_source_node=conversion_data_set.data_set.sql_select_node,
left_source_alias=base_data_set.alias,
right_source_alias=conversion_data_set.alias,
column_equality_descriptions=column_equality_descriptions,
column_equality_descriptions=tuple(column_equality_descriptions),
join_type=SqlJoinType.INNER,
additional_on_conditions=(window_condition,),
)
Expand Down

0 comments on commit ac82e5c

Please sign in to comment.