Support conversion metrics with custom grain #1475

Merged (9 commits) on Oct 30, 2024
@@ -15,6 +15,7 @@
from metricflow_semantics.specs.entity_spec import EntitySpec
from metricflow_semantics.specs.group_by_metric_spec import GroupByMetricSpec
from metricflow_semantics.specs.instance_spec import InstanceSpecVisitor, LinkableInstanceSpec
from metricflow_semantics.specs.spec_set import InstanceSpecSet
from metricflow_semantics.specs.time_dimension_spec import TimeDimensionSpec

if typing.TYPE_CHECKING:
@@ -154,6 +155,15 @@ def difference(self, other: LinkableSpecSet) -> LinkableSpecSet: # noqa: D102
def create_from_specs(specs: Sequence[LinkableInstanceSpec]) -> LinkableSpecSet: # noqa: D102
return _group_specs_by_type(specs)

@property
def as_instance_spec_set(self) -> InstanceSpecSet: # noqa: D102
return InstanceSpecSet(
dimension_specs=self.dimension_specs,
entity_specs=self.entity_specs,
time_dimension_specs=self.time_dimension_specs,
group_by_metric_specs=self.group_by_metric_specs,
)


@dataclass
class _GroupSpecByTypeVisitor(InstanceSpecVisitor[None]):
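For context, the new as_instance_spec_set property above lets callers convert a LinkableSpecSet directly into an InstanceSpecSet instead of going through InstanceSpecSet.create_from_specs(spec_set.as_tuple); the dataflow plan builder changes below rely on it. A minimal illustrative sketch, with a hypothetical spec value and assuming the remaining spec-set fields default to empty tuples:

from metricflow_semantics.specs.dimension_spec import DimensionSpec
from metricflow_semantics.specs.linkable_spec_set import LinkableSpecSet

# Hypothetical spec for illustration only; not part of this PR.
linkable_specs = LinkableSpecSet(
    dimension_specs=(DimensionSpec(element_name="country", entity_links=()),),
)
instance_specs = linkable_specs.as_instance_spec_set  # same specs, repackaged as an InstanceSpecSet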
18 changes: 10 additions & 8 deletions metricflow-semantics/metricflow_semantics/specs/spec_set.py
@@ -4,20 +4,22 @@
import itertools
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Generic, List, Sequence, Tuple, TypeVar
from typing import TYPE_CHECKING, Generic, List, Sequence, Tuple, TypeVar

from dbt_semantic_interfaces.dataclass_serialization import SerializableDataclass
from typing_extensions import override

from metricflow_semantics.collection_helpers.merger import Mergeable
from metricflow_semantics.specs.dimension_spec import DimensionSpec
from metricflow_semantics.specs.entity_spec import EntitySpec
from metricflow_semantics.specs.group_by_metric_spec import GroupByMetricSpec
from metricflow_semantics.specs.instance_spec import InstanceSpec, InstanceSpecVisitor, LinkableInstanceSpec
from metricflow_semantics.specs.measure_spec import MeasureSpec
from metricflow_semantics.specs.metadata_spec import MetadataSpec
from metricflow_semantics.specs.metric_spec import MetricSpec
from metricflow_semantics.specs.time_dimension_spec import TimeDimensionSpec

if TYPE_CHECKING:
Comment from the contributor (PR author): Needed to resolve circular imports. (A brief sketch of this pattern follows this file's diff.)

from metricflow_semantics.specs.dimension_spec import DimensionSpec
from metricflow_semantics.specs.entity_spec import EntitySpec
from metricflow_semantics.specs.group_by_metric_spec import GroupByMetricSpec
from metricflow_semantics.specs.measure_spec import MeasureSpec
from metricflow_semantics.specs.metadata_spec import MetadataSpec
from metricflow_semantics.specs.metric_spec import MetricSpec
from metricflow_semantics.specs.time_dimension_spec import TimeDimensionSpec


@dataclass(frozen=True)
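To illustrate the pattern the author references: imports placed under a TYPE_CHECKING guard are evaluated only by static type checkers, never at runtime, so two modules that reference each other's types in annotations can avoid a runtime import cycle. A generic sketch, not code from this PR (module and class names are hypothetical):

from __future__ import annotations  # annotations are stored as strings, not evaluated at import time

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Only imported during type checking; skipped at runtime, which breaks the import cycle.
    from my_package.other_module import OtherSpec  # hypothetical module


def describe(spec: OtherSpec) -> str:
    # Because of the __future__ import, OtherSpec never needs to be importable at runtime.
    return f"spec: {spec!r}"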
43 changes: 22 additions & 21 deletions metricflow/dataflow/builder/dataflow_plan_builder.py
@@ -315,7 +315,7 @@ def _build_aggregated_conversion_node(
# Filter the source nodes with only the required specs needed for the calculation
constant_property_specs = []
required_local_specs = [base_measure_spec.measure_spec, entity_spec, base_time_dimension_spec] + list(
base_measure_recipe.required_local_linkable_specs
base_measure_recipe.required_local_linkable_specs.as_tuple
)
for constant_property in constant_properties or []:
base_property_spec = self._semantic_model_lookup.get_element_spec_for_name(constant_property.base_property)
@@ -333,6 +333,11 @@
unaggregated_base_measure_node = JoinOnEntitiesNode.create(
left_node=unaggregated_base_measure_node, join_targets=base_measure_recipe.join_targets
)
for time_dimension_spec in base_required_linkable_specs.time_dimension_specs:
if time_dimension_spec.time_granularity.is_custom_granularity:
unaggregated_base_measure_node = JoinToCustomGranularityNode.create(
parent_node=unaggregated_base_measure_node, time_dimension_spec=time_dimension_spec
)
if len(base_measure_spec.filter_spec_set.all_filter_specs) > 0:
unaggregated_base_measure_node = WhereConstraintNode.create(
parent_node=unaggregated_base_measure_node,
@@ -341,7 +346,7 @@
filtered_unaggregated_base_node = FilterElementsNode.create(
parent_node=unaggregated_base_measure_node,
include_specs=group_specs_by_type(required_local_specs)
.merge(InstanceSpecSet.create_from_specs(base_required_linkable_specs.as_tuple))
.merge(base_required_linkable_specs.as_instance_spec_set)
.dedupe(),
)

@@ -361,11 +366,12 @@
constant_properties=constant_property_specs,
)

# Aggregate the conversion events with the JoinConversionEventsNode as the source node
# Aggregate the conversion events with the JoinConversionEventsNode as the source node.
recipe_with_join_conversion_source_node = SourceNodeRecipe(
source_node=join_conversion_node,
required_local_linkable_specs=queried_linkable_specs.as_tuple,
required_local_linkable_specs=queried_linkable_specs,
join_linkable_instances_recipes=(),
all_linkable_specs_required_for_source_nodes=queried_linkable_specs,
)
# TODO: Refine conversion metric configuration to fit into the standard dataflow plan building model
# In this case we override the measure recipe, which currently results in us bypassing predicate pushdown
@@ -925,13 +931,11 @@ def _select_source_nodes_with_linkable_specs(
selected_nodes: Dict[DataflowPlanNode, None] = {}

# TODO: Add support for no-metrics queries for custom grains without a join (i.e., select directly from time spine).
linkable_specs_set_with_base_granularities = set(linkable_specs.as_tuple)
linkable_specs_set = set(linkable_specs.as_tuple)
for source_node in source_nodes:
output_spec_set = self._node_data_set_resolver.get_output_data_set(source_node).instance_set.spec_set
all_linkable_specs_in_node = set(output_spec_set.linkable_specs)
requested_linkable_specs_in_node = linkable_specs_set_with_base_granularities.intersection(
all_linkable_specs_in_node
)
requested_linkable_specs_in_node = linkable_specs_set.intersection(all_linkable_specs_in_node)
if requested_linkable_specs_in_node:
selected_nodes[source_node] = None

@@ -998,13 +1002,7 @@ def _find_source_node_recipe(self, parameter_set: FindSourceNodeRecipeParameterS
return result.source_node_recipe
source_node_recipe = self._find_source_node_recipe_non_cached(parameter_set)
self._cache.set_find_source_node_recipe_result(parameter_set, FindSourceNodeRecipeResult(source_node_recipe))
if source_node_recipe is not None:
return SourceNodeRecipe(
source_node=source_node_recipe.source_node,
required_local_linkable_specs=source_node_recipe.required_local_linkable_specs,
join_linkable_instances_recipes=source_node_recipe.join_linkable_instances_recipes,
)
return None
return source_node_recipe

def _find_source_node_recipe_non_cached(
self, parameter_set: FindSourceNodeRecipeParameterSet
@@ -1234,13 +1232,14 @@ def _find_source_node_recipe_non_cached(
)
return SourceNodeRecipe(
source_node=node_with_lowest_cost_plan,
required_local_linkable_specs=(
required_local_linkable_specs=LinkableSpecSet.create_from_specs(
evaluation.local_linkable_specs
+ required_local_entity_specs
+ required_local_dimension_specs
+ required_local_time_dimension_specs
),
join_linkable_instances_recipes=node_to_evaluation[node_with_lowest_cost_plan].join_recipes,
all_linkable_specs_required_for_source_nodes=linkable_specs_to_satisfy,
)

logger.error(LazyFormat(lambda: "No recipe could be constructed."))
@@ -1641,7 +1640,7 @@ def _build_aggregated_measure_from_measure_source_node(
filtered_measure_source_node = FilterElementsNode.create(
parent_node=join_to_time_spine_node or time_range_node or measure_recipe.source_node,
include_specs=InstanceSpecSet(measure_specs=(measure_spec,)).merge(
group_specs_by_type(measure_recipe.required_local_linkable_specs),
measure_recipe.required_local_linkable_specs.as_instance_spec_set,
),
)

@@ -1654,9 +1653,7 @@
)

specs_to_keep_after_join = InstanceSpecSet(measure_specs=(measure_spec,)).merge(
InstanceSpecSet.create_from_specs(
required_linkable_specs.replace_custom_granularity_with_base_granularity().as_tuple
),
InstanceSpecSet.create_from_specs(measure_recipe.all_linkable_specs_required_for_source_nodes.as_tuple),
)

after_join_filtered_node = FilterElementsNode.create(
@@ -1667,7 +1664,11 @@
unaggregated_measure_node = filtered_measure_source_node

for time_dimension_spec in required_linkable_specs.time_dimension_specs:
if time_dimension_spec.time_granularity.is_custom_granularity:
if (
time_dimension_spec.time_granularity.is_custom_granularity
# If this is the second layer of aggregation for a conversion metric, we have already joined the custom granularity.
and time_dimension_spec not in measure_recipe.all_linkable_specs_required_for_source_nodes.as_tuple
):
unaggregated_measure_node = JoinToCustomGranularityNode.create(
parent_node=unaggregated_measure_node, time_dimension_spec=time_dimension_spec
)
5 changes: 3 additions & 2 deletions metricflow/dataflow/builder/source_node_recipe.py
@@ -3,7 +3,7 @@
from dataclasses import dataclass
from typing import List, Tuple

from metricflow_semantics.specs.instance_spec import LinkableInstanceSpec
from metricflow_semantics.specs.linkable_spec_set import LinkableSpecSet

from metricflow.dataflow.builder.node_evaluator import JoinLinkableInstancesRecipe
from metricflow.dataflow.dataflow_plan import DataflowPlanNode
@@ -15,8 +15,9 @@ class SourceNodeRecipe:
"""Get a recipe for how to build a dataflow plan node that outputs measures and linkable instances as needed."""

source_node: DataflowPlanNode
required_local_linkable_specs: Tuple[LinkableInstanceSpec, ...]
required_local_linkable_specs: LinkableSpecSet
join_linkable_instances_recipes: Tuple[JoinLinkableInstancesRecipe, ...]
all_linkable_specs_required_for_source_nodes: LinkableSpecSet

@property
def join_targets(self) -> List[JoinDescription]:
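For reference, with these field changes a SourceNodeRecipe now carries LinkableSpecSet values rather than raw spec tuples. A hedged sketch of constructing one, mirroring the conversion-metric path in the builder changes above (the variable names are placeholders, not new API):

# Placeholder variables; mirrors recipe_with_join_conversion_source_node in the builder diff above.
recipe = SourceNodeRecipe(
    source_node=join_conversion_node,                      # a DataflowPlanNode
    required_local_linkable_specs=queried_linkable_specs,  # now a LinkableSpecSet, not a tuple
    join_linkable_instances_recipes=(),
    all_linkable_specs_required_for_source_nodes=queried_linkable_specs,
)
join_targets = recipe.join_targets  # unchanged: derived from the join recipes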
2 changes: 1 addition & 1 deletion metricflow/plan_conversion/dataflow_to_sql.py
@@ -1517,7 +1517,7 @@ def visit_join_to_custom_granularity_node(self, node: JoinToCustomGranularityNod
parent_column: Optional[SqlSelectColumn] = None
assert parent_time_dimension_instance, (
"JoinToCustomGranularityNode's expected time_dimension_spec not found in parent dataset instances. "
f"This indicates internal misconfiguration. Expected: {node.time_dimension_spec.with_base_grain}; "
f"This indicates internal misconfiguration. Expected: {node.time_dimension_spec.with_base_grain()}; "
f"Got: {[instance.spec for instance in parent_data_set.instance_set.time_dimension_instances]}"
)
for select_column in parent_data_set.checked_sql_select_node.select_columns:
77 changes: 77 additions & 0 deletions tests_metricflow/query_rendering/test_custom_granularity.py
@@ -455,3 +455,80 @@ def test_offset_metric_with_custom_granularity_filter_not_in_group_by( # noqa:
dataflow_plan_builder=dataflow_plan_builder,
query_spec=query_spec,
)


@pytest.mark.sql_engine_snapshot
def test_conversion_metric_with_custom_granularity( # noqa: D103
request: FixtureRequest,
mf_test_configuration: MetricFlowTestConfiguration,
dataflow_plan_builder: DataflowPlanBuilder,
dataflow_to_sql_converter: DataflowToSqlQueryPlanConverter,
sql_client: SqlClient,
query_parser: MetricFlowQueryParser,
) -> None:
query_spec = query_parser.parse_and_validate_query(
metric_names=("visit_buy_conversion_rate_7days",),
group_by_names=("metric_time__martian_day",),
).query_spec

render_and_check(
request=request,
mf_test_configuration=mf_test_configuration,
dataflow_to_sql_converter=dataflow_to_sql_converter,
sql_client=sql_client,
dataflow_plan_builder=dataflow_plan_builder,
query_spec=query_spec,
)


@pytest.mark.sql_engine_snapshot
def test_conversion_metric_with_custom_granularity_filter( # noqa: D103
request: FixtureRequest,
mf_test_configuration: MetricFlowTestConfiguration,
dataflow_plan_builder: DataflowPlanBuilder,
dataflow_to_sql_converter: DataflowToSqlQueryPlanConverter,
sql_client: SqlClient,
query_parser: MetricFlowQueryParser,
) -> None:
query_spec = query_parser.parse_and_validate_query(
metric_names=("visit_buy_conversion_rate_7days",),
group_by_names=("metric_time__martian_day",),
where_constraints=[
PydanticWhereFilter(where_sql_template=("{{ TimeDimension('metric_time', 'martian_day') }} = '2020-01-01'"))
],
).query_spec

render_and_check(
request=request,
mf_test_configuration=mf_test_configuration,
dataflow_to_sql_converter=dataflow_to_sql_converter,
sql_client=sql_client,
dataflow_plan_builder=dataflow_plan_builder,
query_spec=query_spec,
)


@pytest.mark.sql_engine_snapshot
def test_conversion_metric_with_custom_granularity_filter_not_in_group_by( # noqa: D103
request: FixtureRequest,
mf_test_configuration: MetricFlowTestConfiguration,
dataflow_plan_builder: DataflowPlanBuilder,
dataflow_to_sql_converter: DataflowToSqlQueryPlanConverter,
sql_client: SqlClient,
query_parser: MetricFlowQueryParser,
) -> None:
query_spec = query_parser.parse_and_validate_query(
metric_names=("visit_buy_conversion_rate_7days",),
where_constraints=[
PydanticWhereFilter(where_sql_template=("{{ TimeDimension('metric_time', 'martian_day') }} = '2020-01-01'"))
],
).query_spec

render_and_check(
request=request,
mf_test_configuration=mf_test_configuration,
dataflow_to_sql_converter=dataflow_to_sql_converter,
sql_client=sql_client,
dataflow_plan_builder=dataflow_plan_builder,
query_spec=query_spec,
)