Skip to content

Commit

Permalink
Refactor to simplify time granularity solver
Browse files Browse the repository at this point in the history
  • Loading branch information
courtneyholcomb committed Oct 12, 2023
1 parent 4cea5b6 commit 722d7ea
Show file tree
Hide file tree
Showing 4 changed files with 50 additions and 74 deletions.
14 changes: 4 additions & 10 deletions metricflow/query/query_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,8 @@ def __init__( # noqa: D
self._metric_time_dimension_reference = DataSet.metric_time_dimension_reference()
self._time_granularity_solver = TimeGranularitySolver(
semantic_manifest_lookup=self._model,
read_nodes=self._read_nodes,
node_output_resolver=self._node_output_resolver,
)

@staticmethod
Expand Down Expand Up @@ -410,8 +412,6 @@ def _parse_and_validate_query(
self._time_granularity_solver.resolve_granularity_for_partial_time_dimension_specs(
metric_references=metric_references,
partial_time_dimension_specs=requested_linkable_specs.partial_time_dimension_specs,
read_nodes=self._read_nodes,
node_output_resolver=self._node_output_resolver,
)
)

Expand Down Expand Up @@ -581,10 +581,7 @@ def _adjust_time_range_constraint(
)
partial_time_dimension_spec_to_time_dimension_spec = (
self._time_granularity_solver.resolve_granularity_for_partial_time_dimension_specs(
metric_references=metric_references,
partial_time_dimension_specs=(partial_metric_time_spec,),
read_nodes=self._read_nodes,
node_output_resolver=self._node_output_resolver,
metric_references=metric_references, partial_time_dimension_specs=(partial_metric_time_spec,)
)
)
adjust_to_granularity = partial_time_dimension_spec_to_time_dimension_spec[
Expand Down Expand Up @@ -783,10 +780,7 @@ def _verify_resolved_granularity_for_date_part(
ensure that the correct value was passed in.
"""
resolved_granularity = self._time_granularity_solver.find_minimum_granularity_for_partial_time_dimension_spec(
partial_time_dimension_spec=partial_time_dimension_spec,
metric_references=metric_references,
read_nodes=self._read_nodes,
node_output_resolver=self._node_output_resolver,
partial_time_dimension_spec=partial_time_dimension_spec, metric_references=metric_references
)
if resolved_granularity != requested_dimension_structured_name.time_granularity:
raise RequestTimeGranularityException(
Expand Down
9 changes: 9 additions & 0 deletions metricflow/test/fixtures/model_fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,13 +90,17 @@ class ConsistentIdObjectRepository:
cyclic_join_read_nodes: OrderedDict[str, ReadSqlSourceNode]
cyclic_join_source_nodes: Sequence[BaseOutput]

extended_date_model_read_nodes: OrderedDict[str, ReadSqlSourceNode]
extended_date_model_source_nodes: Sequence[BaseOutput]


@pytest.fixture(scope="session")
def consistent_id_object_repository(
simple_semantic_manifest_lookup: SemanticManifestLookup,
multi_hop_join_semantic_manifest_lookup: SemanticManifestLookup,
scd_semantic_manifest_lookup: SemanticManifestLookup,
cyclic_join_semantic_manifest_lookup: SemanticManifestLookup,
extended_date_semantic_manifest_lookup: SemanticManifestLookup,
) -> ConsistentIdObjectRepository: # noqa: D
"""Create objects that have incremental numeric IDs with a consistent value.
Expand All @@ -108,6 +112,7 @@ def consistent_id_object_repository(
multihop_data_sets = create_data_sets(multi_hop_join_semantic_manifest_lookup)
scd_data_sets = create_data_sets(scd_semantic_manifest_lookup)
cyclic_join_data_sets = create_data_sets(cyclic_join_semantic_manifest_lookup)
extended_date_data_sets = create_data_sets(extended_date_semantic_manifest_lookup)

return ConsistentIdObjectRepository(
simple_model_data_sets=sm_data_sets,
Expand All @@ -126,6 +131,10 @@ def consistent_id_object_repository(
cyclic_join_source_nodes=_data_set_to_source_nodes(
semantic_manifest_lookup=cyclic_join_semantic_manifest_lookup, data_sets=cyclic_join_data_sets
),
extended_date_model_read_nodes=_data_set_to_read_nodes(extended_date_data_sets),
extended_date_model_source_nodes=_data_set_to_source_nodes(
semantic_manifest_lookup=extended_date_semantic_manifest_lookup, data_sets=extended_date_data_sets
),
)


Expand Down
27 changes: 7 additions & 20 deletions metricflow/test/time/test_time_granularity_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,13 @@
@pytest.fixture(scope="session")
def time_granularity_solver(  # noqa: D
    extended_date_semantic_manifest_lookup: SemanticManifestLookup,
    consistent_id_object_repository: ConsistentIdObjectRepository,
    node_output_resolver: DataflowPlanNodeOutputDataSetResolver,
) -> TimeGranularitySolver:
    # The solver indexes its read nodes at construction time, so hand it the extended-date model's
    # read nodes here instead of passing them to every resolve call.
    extended_date_read_nodes = list(consistent_id_object_repository.extended_date_model_read_nodes.values())
    return TimeGranularitySolver(
        semantic_manifest_lookup=extended_date_semantic_manifest_lookup,
        read_nodes=extended_date_read_nodes,
        node_output_resolver=node_output_resolver,
    )


Expand Down Expand Up @@ -91,46 +95,29 @@ def test_validate_day_granularity_for_day_and_month_metric( # noqa: D
PARTIAL_PTD_SPEC = PartialTimeDimensionSpec(element_name=DataSet.metric_time_dimension_name(), entity_links=())


def test_granularity_solution_for_day_metric( # noqa: D
time_granularity_solver: TimeGranularitySolver,
node_output_resolver: DataflowPlanNodeOutputDataSetResolver,
consistent_id_object_repository: ConsistentIdObjectRepository,
) -> None:
def test_granularity_solution_for_day_metric(time_granularity_solver: TimeGranularitySolver) -> None: # noqa: D
assert time_granularity_solver.resolve_granularity_for_partial_time_dimension_specs(
metric_references=[MetricReference(element_name="bookings")],
partial_time_dimension_specs=[PARTIAL_PTD_SPEC],
node_output_resolver=node_output_resolver,
read_nodes=list(consistent_id_object_repository.simple_model_read_nodes.values()),
metric_references=[MetricReference(element_name="bookings")], partial_time_dimension_specs=[PARTIAL_PTD_SPEC]
) == {
PARTIAL_PTD_SPEC: MTD_SPEC_DAY,
}


def test_granularity_solution_for_month_metric( # noqa: D
time_granularity_solver: TimeGranularitySolver,
node_output_resolver: DataflowPlanNodeOutputDataSetResolver,
consistent_id_object_repository: ConsistentIdObjectRepository,
) -> None:
def test_granularity_solution_for_month_metric(time_granularity_solver: TimeGranularitySolver) -> None: # noqa: D
assert time_granularity_solver.resolve_granularity_for_partial_time_dimension_specs(
metric_references=[MetricReference(element_name="bookings_monthly")],
partial_time_dimension_specs=[PARTIAL_PTD_SPEC],
node_output_resolver=node_output_resolver,
read_nodes=list(consistent_id_object_repository.simple_model_read_nodes.values()),
) == {
PARTIAL_PTD_SPEC: MTD_SPEC_MONTH,
}


def test_granularity_solution_for_day_and_month_metrics( # noqa: D
time_granularity_solver: TimeGranularitySolver,
node_output_resolver: DataflowPlanNodeOutputDataSetResolver,
consistent_id_object_repository: ConsistentIdObjectRepository,
) -> None:
assert time_granularity_solver.resolve_granularity_for_partial_time_dimension_specs(
metric_references=[MetricReference(element_name="bookings"), MetricReference(element_name="bookings_monthly")],
partial_time_dimension_specs=[PARTIAL_PTD_SPEC],
node_output_resolver=node_output_resolver,
read_nodes=list(consistent_id_object_repository.simple_model_read_nodes.values()),
) == {PARTIAL_PTD_SPEC: MTD_SPEC_MONTH}


Expand Down
74 changes: 30 additions & 44 deletions metricflow/time/time_granularity_solver.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import logging
from collections import defaultdict
from dataclasses import dataclass
from typing import Dict, Optional, Sequence, Set, Tuple

Expand Down Expand Up @@ -67,8 +68,22 @@ class TimeGranularitySolver:
def __init__(  # noqa: D
    self,
    semantic_manifest_lookup: SemanticManifestLookup,
    node_output_resolver: DataflowPlanNodeOutputDataSetResolver,
    read_nodes: Sequence[ReadSqlSourceNode],
) -> None:
    self._semantic_manifest_lookup = semantic_manifest_lookup

    # Build, once per solver, an index from each time dimension's granularity-free qualified name
    # to the set of granularities at which the read nodes expose it.
    granularities_by_name: Dict[str, Set[TimeGranularity]] = defaultdict(set)
    for node in read_nodes:
        time_dimension_instances = node_output_resolver.get_output_data_set(
            node
        ).instance_set.time_dimension_instances
        for instance in time_dimension_instances:
            spec = instance.spec
            # Specs that carry a date_part are left out of the index.
            if not spec.date_part:
                name_without_granularity = StructuredLinkableSpecName.from_name(
                    spec.qualified_name
                ).granularity_free_qualified_name
                granularities_by_name[name_without_granularity].add(spec.time_granularity)

    self._time_dimension_names_to_supported_granularities = granularities_by_name

def validate_time_granularity(
self, metric_references: Sequence[MetricReference], time_dimension_specs: Sequence[TimeDimensionSpec]
Expand Down Expand Up @@ -103,8 +118,6 @@ def resolve_granularity_for_partial_time_dimension_specs(
self,
metric_references: Sequence[MetricReference],
partial_time_dimension_specs: Sequence[PartialTimeDimensionSpec],
read_nodes: Sequence[ReadSqlSourceNode],
node_output_resolver: DataflowPlanNodeOutputDataSetResolver,
) -> Dict[PartialTimeDimensionSpec, TimeDimensionSpec]:
"""Figure out the lowest granularity possible for the partially specified time dimension specs.
Expand All @@ -114,10 +127,7 @@ def resolve_granularity_for_partial_time_dimension_specs(

for partial_time_dimension_spec in partial_time_dimension_specs:
minimum_time_granularity = self.find_minimum_granularity_for_partial_time_dimension_spec(
partial_time_dimension_spec=partial_time_dimension_spec,
metric_references=metric_references,
read_nodes=read_nodes,
node_output_resolver=node_output_resolver,
partial_time_dimension_spec=partial_time_dimension_spec, metric_references=metric_references
)
result[partial_time_dimension_spec] = TimeDimensionSpec(
element_name=partial_time_dimension_spec.element_name,
Expand All @@ -128,11 +138,7 @@ def resolve_granularity_for_partial_time_dimension_specs(
return result

def find_minimum_granularity_for_partial_time_dimension_spec(
self,
partial_time_dimension_spec: PartialTimeDimensionSpec,
metric_references: Sequence[MetricReference],
read_nodes: Sequence[ReadSqlSourceNode],
node_output_resolver: DataflowPlanNodeOutputDataSetResolver,
self, partial_time_dimension_spec: PartialTimeDimensionSpec, metric_references: Sequence[MetricReference]
) -> TimeGranularity:
"""Find minimum granularity allowed for time dimension when queried with given metrics."""
minimum_time_granularity: Optional[TimeGranularity] = None
Expand All @@ -159,46 +165,26 @@ def find_minimum_granularity_for_partial_time_dimension_spec(
f"{pformat_big_objects([spec.qualified_name for spec in valid_group_by_elements.as_spec_set.as_tuple])}"
)
else:
minimum_time_granularity = self.get_min_granularity_for_partial_time_dimension_without_metrics(
read_nodes=read_nodes,
node_output_resolver=node_output_resolver,
partial_time_dimension_spec=partial_time_dimension_spec,
granularity_free_qualified_name = StructuredLinkableSpecName(
entity_link_names=tuple(
[entity_link.element_name for entity_link in partial_time_dimension_spec.entity_links]
),
element_name=partial_time_dimension_spec.element_name,
).granularity_free_qualified_name

supported_granularities = self._time_dimension_names_to_supported_granularities.get(
granularity_free_qualified_name
)
if not minimum_time_granularity:
if not supported_granularities:
raise RequestTimeGranularityException(
f"Unable to resolve the time dimension spec for {partial_time_dimension_spec}. "
)
minimum_time_granularity = min(
self._time_dimension_names_to_supported_granularities[granularity_free_qualified_name]
)

return minimum_time_granularity

def get_min_granularity_for_partial_time_dimension_without_metrics(
self,
read_nodes: Sequence[ReadSqlSourceNode],
node_output_resolver: DataflowPlanNodeOutputDataSetResolver,
partial_time_dimension_spec: PartialTimeDimensionSpec,
) -> Optional[TimeGranularity]:
"""Find the minimum."""
granularity_free_qualified_name = StructuredLinkableSpecName(
entity_link_names=tuple(
[entity_link.element_name for entity_link in partial_time_dimension_spec.entity_links]
),
element_name=partial_time_dimension_spec.element_name,
).granularity_free_qualified_name

supported_granularities: Set[TimeGranularity] = set()
for read_node in read_nodes:
output_data_set = node_output_resolver.get_output_data_set(read_node)
for time_dimension_instance in output_data_set.instance_set.time_dimension_instances:
if time_dimension_instance.spec.date_part:
continue
time_dim_name_without_granularity = StructuredLinkableSpecName.from_name(
time_dimension_instance.spec.qualified_name
).granularity_free_qualified_name
if time_dim_name_without_granularity == granularity_free_qualified_name:
supported_granularities.add(time_dimension_instance.spec.time_granularity)

return min(supported_granularities) if supported_granularities else None

def adjust_time_range_to_granularity(
self, time_range_constraint: TimeRangeConstraint, time_granularity: TimeGranularity
) -> TimeRangeConstraint:
Expand Down

0 comments on commit 722d7ea

Please sign in to comment.