diff --git a/metricflow/time/time_granularity_solver.py b/metricflow/time/time_granularity_solver.py index 1cbfcff8c4..528edce977 100644 --- a/metricflow/time/time_granularity_solver.py +++ b/metricflow/time/time_granularity_solver.py @@ -71,7 +71,9 @@ def __init__( # noqa: D read_nodes: Sequence[ReadSqlSourceNode], ) -> None: self._semantic_manifest_lookup = semantic_manifest_lookup - self._time_dimension_names_to_supported_granularities: Dict[str, Set[TimeGranularity]] = defaultdict(set) + self._time_dimensions_to_supported_granularities: Dict[ + TimeDimensionReference, Set[TimeGranularity] + ] = defaultdict(set) for read_node in read_nodes: output_data_set = node_output_resolver.get_output_data_set(read_node) for time_dimension_instance in output_data_set.instance_set.time_dimension_instances: @@ -80,10 +82,10 @@ def __init__( # noqa: D granularity_free_qualified_name = StructuredLinkableSpecName.from_name( time_dimension_instance.spec.qualified_name ).granularity_free_qualified_name - self._time_dimension_names_to_supported_granularities[granularity_free_qualified_name].add( - time_dimension_instance.spec.time_granularity - ) - self._time_dimension_names_to_supported_granularities[DataSet.metric_time_dimension_name()] = { + self._time_dimensions_to_supported_granularities[ + TimeDimensionReference(granularity_free_qualified_name) + ].add(time_dimension_instance.spec.time_granularity) + self._time_dimensions_to_supported_granularities[DataSet.metric_time_dimension_reference()] = { granularity for granularity in TimeGranularity if granularity.to_int() >= semantic_manifest_lookup.time_spine_source.time_column_granularity.to_int() @@ -169,23 +171,20 @@ def find_minimum_granularity_for_partial_time_dimension_spec( f"{pformat_big_objects([spec.qualified_name for spec in valid_group_by_elements.as_spec_set.as_tuple])}" ) else: - granularity_free_qualified_name = StructuredLinkableSpecName( - entity_link_names=tuple( - [entity_link.element_name for entity_link in partial_time_dimension_spec.entity_links] - ), - element_name=partial_time_dimension_spec.element_name, - ).granularity_free_qualified_name - - supported_granularities = self._time_dimension_names_to_supported_granularities.get( - granularity_free_qualified_name + time_dim_reference = TimeDimensionReference( + StructuredLinkableSpecName( + entity_link_names=tuple( + [entity_link.element_name for entity_link in partial_time_dimension_spec.entity_links] + ), + element_name=partial_time_dimension_spec.element_name, + ).granularity_free_qualified_name ) + supported_granularities = self._time_dimensions_to_supported_granularities.get(time_dim_reference) if not supported_granularities: raise RequestTimeGranularityException( f"Unable to resolve the time dimension spec for {partial_time_dimension_spec}. " ) - minimum_time_granularity = min( - self._time_dimension_names_to_supported_granularities[granularity_free_qualified_name] - ) + minimum_time_granularity = min(self._time_dimensions_to_supported_granularities[time_dim_reference]) return minimum_time_granularity