diff --git a/metricflow/query/query_parser.py b/metricflow/query/query_parser.py index 0589beaf6b..202115267f 100644 --- a/metricflow/query/query_parser.py +++ b/metricflow/query/query_parser.py @@ -146,6 +146,8 @@ def __init__( # noqa: D self._metric_time_dimension_reference = DataSet.metric_time_dimension_reference() self._time_granularity_solver = TimeGranularitySolver( semantic_manifest_lookup=self._model, + read_nodes=self._read_nodes, + node_output_resolver=self._node_output_resolver, ) @staticmethod @@ -410,8 +412,6 @@ def _parse_and_validate_query( self._time_granularity_solver.resolve_granularity_for_partial_time_dimension_specs( metric_references=metric_references, partial_time_dimension_specs=requested_linkable_specs.partial_time_dimension_specs, - read_nodes=self._read_nodes, - node_output_resolver=self._node_output_resolver, ) ) @@ -581,10 +581,7 @@ def _adjust_time_range_constraint( ) partial_time_dimension_spec_to_time_dimension_spec = ( self._time_granularity_solver.resolve_granularity_for_partial_time_dimension_specs( - metric_references=metric_references, - partial_time_dimension_specs=(partial_metric_time_spec,), - read_nodes=self._read_nodes, - node_output_resolver=self._node_output_resolver, + metric_references=metric_references, partial_time_dimension_specs=(partial_metric_time_spec,) ) ) adjust_to_granularity = partial_time_dimension_spec_to_time_dimension_spec[ @@ -783,10 +780,7 @@ def _verify_resolved_granularity_for_date_part( ensure that the correct value was passed in. """ resolved_granularity = self._time_granularity_solver.find_minimum_granularity_for_partial_time_dimension_spec( - partial_time_dimension_spec=partial_time_dimension_spec, - metric_references=metric_references, - read_nodes=self._read_nodes, - node_output_resolver=self._node_output_resolver, + partial_time_dimension_spec=partial_time_dimension_spec, metric_references=metric_references ) if resolved_granularity != requested_dimension_structured_name.time_granularity: raise RequestTimeGranularityException( diff --git a/metricflow/test/fixtures/model_fixtures.py b/metricflow/test/fixtures/model_fixtures.py index afcf7bf38e..63e39d1b8b 100644 --- a/metricflow/test/fixtures/model_fixtures.py +++ b/metricflow/test/fixtures/model_fixtures.py @@ -90,6 +90,9 @@ class ConsistentIdObjectRepository: cyclic_join_read_nodes: OrderedDict[str, ReadSqlSourceNode] cyclic_join_source_nodes: Sequence[BaseOutput] + extended_date_model_read_nodes: OrderedDict[str, ReadSqlSourceNode] + extended_date_model_source_nodes: Sequence[BaseOutput] + @pytest.fixture(scope="session") def consistent_id_object_repository( @@ -97,6 +100,7 @@ def consistent_id_object_repository( multi_hop_join_semantic_manifest_lookup: SemanticManifestLookup, scd_semantic_manifest_lookup: SemanticManifestLookup, cyclic_join_semantic_manifest_lookup: SemanticManifestLookup, + extended_date_semantic_manifest_lookup: SemanticManifestLookup, ) -> ConsistentIdObjectRepository: # noqa: D """Create objects that have incremental numeric IDs with a consistent value. @@ -108,6 +112,7 @@ def consistent_id_object_repository( multihop_data_sets = create_data_sets(multi_hop_join_semantic_manifest_lookup) scd_data_sets = create_data_sets(scd_semantic_manifest_lookup) cyclic_join_data_sets = create_data_sets(cyclic_join_semantic_manifest_lookup) + extended_date_data_sets = create_data_sets(extended_date_semantic_manifest_lookup) return ConsistentIdObjectRepository( simple_model_data_sets=sm_data_sets, @@ -126,6 +131,10 @@ def consistent_id_object_repository( cyclic_join_source_nodes=_data_set_to_source_nodes( semantic_manifest_lookup=cyclic_join_semantic_manifest_lookup, data_sets=cyclic_join_data_sets ), + extended_date_model_read_nodes=_data_set_to_read_nodes(extended_date_data_sets), + extended_date_model_source_nodes=_data_set_to_source_nodes( + semantic_manifest_lookup=extended_date_semantic_manifest_lookup, data_sets=extended_date_data_sets + ), ) diff --git a/metricflow/test/time/test_time_granularity_solver.py b/metricflow/test/time/test_time_granularity_solver.py index 0fcdf0ce28..4e645e0186 100644 --- a/metricflow/test/time/test_time_granularity_solver.py +++ b/metricflow/test/time/test_time_granularity_solver.py @@ -22,9 +22,13 @@ @pytest.fixture(scope="session") def time_granularity_solver( # noqa: D extended_date_semantic_manifest_lookup: SemanticManifestLookup, + consistent_id_object_repository: ConsistentIdObjectRepository, + node_output_resolver: DataflowPlanNodeOutputDataSetResolver, ) -> TimeGranularitySolver: return TimeGranularitySolver( semantic_manifest_lookup=extended_date_semantic_manifest_lookup, + read_nodes=list(consistent_id_object_repository.extended_date_model_read_nodes.values()), + node_output_resolver=node_output_resolver, ) @@ -91,31 +95,18 @@ def test_validate_day_granularity_for_day_and_month_metric( # noqa: D PARTIAL_PTD_SPEC = PartialTimeDimensionSpec(element_name=DataSet.metric_time_dimension_name(), entity_links=()) -def test_granularity_solution_for_day_metric( # noqa: D - time_granularity_solver: TimeGranularitySolver, - node_output_resolver: DataflowPlanNodeOutputDataSetResolver, - consistent_id_object_repository: ConsistentIdObjectRepository, -) -> None: +def test_granularity_solution_for_day_metric(time_granularity_solver: TimeGranularitySolver) -> None: # noqa: D assert time_granularity_solver.resolve_granularity_for_partial_time_dimension_specs( - metric_references=[MetricReference(element_name="bookings")], - partial_time_dimension_specs=[PARTIAL_PTD_SPEC], - node_output_resolver=node_output_resolver, - read_nodes=list(consistent_id_object_repository.simple_model_read_nodes.values()), + metric_references=[MetricReference(element_name="bookings")], partial_time_dimension_specs=[PARTIAL_PTD_SPEC] ) == { PARTIAL_PTD_SPEC: MTD_SPEC_DAY, } -def test_granularity_solution_for_month_metric( # noqa: D - time_granularity_solver: TimeGranularitySolver, - node_output_resolver: DataflowPlanNodeOutputDataSetResolver, - consistent_id_object_repository: ConsistentIdObjectRepository, -) -> None: +def test_granularity_solution_for_month_metric(time_granularity_solver: TimeGranularitySolver) -> None: # noqa: D assert time_granularity_solver.resolve_granularity_for_partial_time_dimension_specs( metric_references=[MetricReference(element_name="bookings_monthly")], partial_time_dimension_specs=[PARTIAL_PTD_SPEC], - node_output_resolver=node_output_resolver, - read_nodes=list(consistent_id_object_repository.simple_model_read_nodes.values()), ) == { PARTIAL_PTD_SPEC: MTD_SPEC_MONTH, } @@ -123,14 +114,10 @@ def test_granularity_solution_for_month_metric( # noqa: D def test_granularity_solution_for_day_and_month_metrics( # noqa: D time_granularity_solver: TimeGranularitySolver, - node_output_resolver: DataflowPlanNodeOutputDataSetResolver, - consistent_id_object_repository: ConsistentIdObjectRepository, ) -> None: assert time_granularity_solver.resolve_granularity_for_partial_time_dimension_specs( metric_references=[MetricReference(element_name="bookings"), MetricReference(element_name="bookings_monthly")], partial_time_dimension_specs=[PARTIAL_PTD_SPEC], - node_output_resolver=node_output_resolver, - read_nodes=list(consistent_id_object_repository.simple_model_read_nodes.values()), ) == {PARTIAL_PTD_SPEC: MTD_SPEC_MONTH} diff --git a/metricflow/time/time_granularity_solver.py b/metricflow/time/time_granularity_solver.py index 994f1dcd68..802fc7fd45 100644 --- a/metricflow/time/time_granularity_solver.py +++ b/metricflow/time/time_granularity_solver.py @@ -1,6 +1,7 @@ from __future__ import annotations import logging +from collections import defaultdict from dataclasses import dataclass from typing import Dict, Optional, Sequence, Set, Tuple @@ -67,8 +68,22 @@ class TimeGranularitySolver: def __init__( # noqa: D self, semantic_manifest_lookup: SemanticManifestLookup, + node_output_resolver: DataflowPlanNodeOutputDataSetResolver, + read_nodes: Sequence[ReadSqlSourceNode], ) -> None: self._semantic_manifest_lookup = semantic_manifest_lookup + self._time_dimension_names_to_supported_granularities: Dict[str, Set[TimeGranularity]] = defaultdict(set) + for read_node in read_nodes: + output_data_set = node_output_resolver.get_output_data_set(read_node) + for time_dimension_instance in output_data_set.instance_set.time_dimension_instances: + if time_dimension_instance.spec.date_part: + continue + granularity_free_qualified_name = StructuredLinkableSpecName.from_name( + time_dimension_instance.spec.qualified_name + ).granularity_free_qualified_name + self._time_dimension_names_to_supported_granularities[granularity_free_qualified_name].add( + time_dimension_instance.spec.time_granularity + ) def validate_time_granularity( self, metric_references: Sequence[MetricReference], time_dimension_specs: Sequence[TimeDimensionSpec] @@ -103,8 +118,6 @@ def resolve_granularity_for_partial_time_dimension_specs( self, metric_references: Sequence[MetricReference], partial_time_dimension_specs: Sequence[PartialTimeDimensionSpec], - read_nodes: Sequence[ReadSqlSourceNode], - node_output_resolver: DataflowPlanNodeOutputDataSetResolver, ) -> Dict[PartialTimeDimensionSpec, TimeDimensionSpec]: """Figure out the lowest granularity possible for the partially specified time dimension specs. @@ -114,10 +127,7 @@ def resolve_granularity_for_partial_time_dimension_specs( for partial_time_dimension_spec in partial_time_dimension_specs: minimum_time_granularity = self.find_minimum_granularity_for_partial_time_dimension_spec( - partial_time_dimension_spec=partial_time_dimension_spec, - metric_references=metric_references, - read_nodes=read_nodes, - node_output_resolver=node_output_resolver, + partial_time_dimension_spec=partial_time_dimension_spec, metric_references=metric_references ) result[partial_time_dimension_spec] = TimeDimensionSpec( element_name=partial_time_dimension_spec.element_name, @@ -128,11 +138,7 @@ def resolve_granularity_for_partial_time_dimension_specs( return result def find_minimum_granularity_for_partial_time_dimension_spec( - self, - partial_time_dimension_spec: PartialTimeDimensionSpec, - metric_references: Sequence[MetricReference], - read_nodes: Sequence[ReadSqlSourceNode], - node_output_resolver: DataflowPlanNodeOutputDataSetResolver, + self, partial_time_dimension_spec: PartialTimeDimensionSpec, metric_references: Sequence[MetricReference] ) -> TimeGranularity: """Find minimum granularity allowed for time dimension when queried with given metrics.""" minimum_time_granularity: Optional[TimeGranularity] = None @@ -159,46 +165,26 @@ def find_minimum_granularity_for_partial_time_dimension_spec( f"{pformat_big_objects([spec.qualified_name for spec in valid_group_by_elements.as_spec_set.as_tuple])}" ) else: - minimum_time_granularity = self.get_min_granularity_for_partial_time_dimension_without_metrics( - read_nodes=read_nodes, - node_output_resolver=node_output_resolver, - partial_time_dimension_spec=partial_time_dimension_spec, + granularity_free_qualified_name = StructuredLinkableSpecName( + entity_link_names=tuple( + [entity_link.element_name for entity_link in partial_time_dimension_spec.entity_links] + ), + element_name=partial_time_dimension_spec.element_name, + ).granularity_free_qualified_name + + supported_granularities = self._time_dimension_names_to_supported_granularities.get( + granularity_free_qualified_name ) - if not minimum_time_granularity: + if not supported_granularities: raise RequestTimeGranularityException( f"Unable to resolve the time dimension spec for {partial_time_dimension_spec}. " ) + minimum_time_granularity = min( + self._time_dimension_names_to_supported_granularities[granularity_free_qualified_name] + ) return minimum_time_granularity - def get_min_granularity_for_partial_time_dimension_without_metrics( - self, - read_nodes: Sequence[ReadSqlSourceNode], - node_output_resolver: DataflowPlanNodeOutputDataSetResolver, - partial_time_dimension_spec: PartialTimeDimensionSpec, - ) -> Optional[TimeGranularity]: - """Find the minimum.""" - granularity_free_qualified_name = StructuredLinkableSpecName( - entity_link_names=tuple( - [entity_link.element_name for entity_link in partial_time_dimension_spec.entity_links] - ), - element_name=partial_time_dimension_spec.element_name, - ).granularity_free_qualified_name - - supported_granularities: Set[TimeGranularity] = set() - for read_node in read_nodes: - output_data_set = node_output_resolver.get_output_data_set(read_node) - for time_dimension_instance in output_data_set.instance_set.time_dimension_instances: - if time_dimension_instance.spec.date_part: - continue - time_dim_name_without_granularity = StructuredLinkableSpecName.from_name( - time_dimension_instance.spec.qualified_name - ).granularity_free_qualified_name - if time_dim_name_without_granularity == granularity_free_qualified_name: - supported_granularities.add(time_dimension_instance.spec.time_granularity) - - return min(supported_granularities) if supported_granularities else None - def adjust_time_range_to_granularity( self, time_range_constraint: TimeRangeConstraint, time_granularity: TimeGranularity ) -> TimeRangeConstraint: