diff --git a/metricflow/engine/metricflow_engine.py b/metricflow/engine/metricflow_engine.py index f5fcc683bc..b53493145f 100644 --- a/metricflow/engine/metricflow_engine.py +++ b/metricflow/engine/metricflow_engine.py @@ -318,6 +318,7 @@ def __init__( semantic_manifest_lookup: SemanticManifestLookup, sql_client: SqlClient, time_source: TimeSource = ServerTimeSource(), + query_parser: Optional[MetricFlowQueryParser] = None, column_association_resolver: Optional[ColumnAssociationResolver] = None, ) -> None: """Initializer for MetricFlowEngine. @@ -365,7 +366,7 @@ def __init__( ) self._executor = SequentialPlanExecutor() - self._query_parser = MetricFlowQueryParser( + self._query_parser = query_parser or MetricFlowQueryParser( semantic_manifest_lookup=self._semantic_manifest_lookup, ) diff --git a/metricflow/query/group_by_item/filter_spec_resolution/filter_pattern_factory.py b/metricflow/query/group_by_item/filter_spec_resolution/filter_pattern_factory.py new file mode 100644 index 0000000000..92fe217816 --- /dev/null +++ b/metricflow/query/group_by_item/filter_spec_resolution/filter_pattern_factory.py @@ -0,0 +1,55 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod + +from dbt_semantic_interfaces.call_parameter_sets import ( + DimensionCallParameterSet, + EntityCallParameterSet, + TimeDimensionCallParameterSet, +) +from typing_extensions import override + +from metricflow.specs.patterns.spec_pattern import SpecPattern +from metricflow.specs.patterns.typed_patterns import DimensionPattern, EntityPattern, TimeDimensionPattern + + +class WhereFilterPatternFactory(ABC): + """Interface that defines how spec patterns should be generated for the group-by-items specified in filters.""" + + @abstractmethod + def create_for_dimension_call_parameter_set( # noqa: D + self, dimension_call_parameter_set: DimensionCallParameterSet + ) -> SpecPattern: + raise NotImplementedError + + @abstractmethod + def create_for_time_dimension_call_parameter_set( # noqa: D + self, time_dimension_call_parameter_set: TimeDimensionCallParameterSet + ) -> SpecPattern: + raise NotImplementedError + + @abstractmethod + def create_for_entity_call_parameter_set( # noqa: D + self, entity_call_parameter_set: EntityCallParameterSet + ) -> SpecPattern: + raise NotImplementedError + + +class DefaultWhereFilterPatternFactory(WhereFilterPatternFactory): + """Default implementation using patterns derived from EntityLinkPattern.""" + + @override + def create_for_dimension_call_parameter_set( + self, dimension_call_parameter_set: DimensionCallParameterSet + ) -> SpecPattern: + return DimensionPattern.from_call_parameter_set(dimension_call_parameter_set) + + @override + def create_for_time_dimension_call_parameter_set( + self, time_dimension_call_parameter_set: TimeDimensionCallParameterSet + ) -> SpecPattern: + return TimeDimensionPattern.from_call_parameter_set(time_dimension_call_parameter_set) + + @override + def create_for_entity_call_parameter_set(self, entity_call_parameter_set: EntityCallParameterSet) -> SpecPattern: + return EntityPattern.from_call_parameter_set(entity_call_parameter_set) diff --git a/metricflow/query/group_by_item/filter_spec_resolution/filter_spec_resolver.py b/metricflow/query/group_by_item/filter_spec_resolution/filter_spec_resolver.py index 00039137b9..557eeae766 100644 --- a/metricflow/query/group_by_item/filter_spec_resolution/filter_spec_resolver.py +++ b/metricflow/query/group_by_item/filter_spec_resolution/filter_spec_resolver.py @@ -15,6 +15,7 @@ from metricflow.naming.object_builder_str import ObjectBuilderNameConverter from metricflow.query.group_by_item.candidate_push_down.push_down_visitor import DagTraversalPathTracker from metricflow.query.group_by_item.filter_spec_resolution.filter_location import WhereFilterLocation +from metricflow.query.group_by_item.filter_spec_resolution.filter_pattern_factory import WhereFilterPatternFactory from metricflow.query.group_by_item.filter_spec_resolution.filter_spec_lookup import ( FilterSpecResolution, FilterSpecResolutionLookUp, @@ -44,7 +45,6 @@ from metricflow.query.issues.issues_base import ( MetricFlowQueryResolutionIssueSet, ) -from metricflow.specs.patterns.typed_patterns import DimensionPattern, EntityPattern, TimeDimensionPattern logger = logging.getLogger(__name__) @@ -66,13 +66,18 @@ def __init__( # noqa: D self, manifest_lookup: SemanticManifestLookup, resolution_dag: GroupByItemResolutionDag, + spec_pattern_factory: WhereFilterPatternFactory, ) -> None: self._manifest_lookup = manifest_lookup self._resolution_dag = resolution_dag + self.spec_pattern_factory = spec_pattern_factory def resolve_lookup(self) -> FilterSpecResolutionLookUp: """Find all where filters and return a lookup that provides the specs for the included group-by-items.""" - visitor = _ResolveWhereFilterSpecVisitor(manifest_lookup=self._manifest_lookup) + visitor = _ResolveWhereFilterSpecVisitor( + manifest_lookup=self._manifest_lookup, + spec_pattern_factory=self.spec_pattern_factory, + ) return self._resolution_dag.sink_node.accept(visitor) @@ -85,9 +90,12 @@ class _ResolveWhereFilterSpecVisitor(GroupByItemResolutionNodeVisitor[FilterSpec collected and returned in a lookup object. """ - def __init__(self, manifest_lookup: SemanticManifestLookup) -> None: # noqa: D + def __init__( # noqa: D + self, manifest_lookup: SemanticManifestLookup, spec_pattern_factory: WhereFilterPatternFactory + ) -> None: self._manifest_lookup = manifest_lookup self._path_from_start_node_tracker = DagTraversalPathTracker() + self._spec_pattern_factory = spec_pattern_factory @staticmethod def _dedupe_filter_call_parameter_sets( @@ -121,8 +129,8 @@ def _dedupe_filter_call_parameter_sets( ), ) - @staticmethod def _map_filter_parameter_sets_to_pattern( + self, filter_call_parameter_sets: FilterCallParameterSets, ) -> Sequence[PatternAssociationForWhereFilterGroupByItem]: """Given the call parameter sets in a filter, map them to spec patterns. @@ -140,7 +148,9 @@ def _map_filter_parameter_sets_to_pattern( object_builder_str=ObjectBuilderNameConverter.input_str_from_dimension_call_parameter_set( dimension_call_parameter_set ), - spec_pattern=DimensionPattern.from_call_parameter_set(dimension_call_parameter_set), + spec_pattern=self._spec_pattern_factory.create_for_dimension_call_parameter_set( + dimension_call_parameter_set + ), ) ) for time_dimension_call_parameter_set in filter_call_parameter_sets.time_dimension_call_parameter_sets: @@ -150,7 +160,9 @@ def _map_filter_parameter_sets_to_pattern( object_builder_str=ObjectBuilderNameConverter.input_str_from_time_dimension_call_parameter_set( time_dimension_call_parameter_set ), - spec_pattern=TimeDimensionPattern.from_call_parameter_set(time_dimension_call_parameter_set), + spec_pattern=self._spec_pattern_factory.create_for_time_dimension_call_parameter_set( + time_dimension_call_parameter_set + ), ) ) for entity_call_parameter_set in filter_call_parameter_sets.entity_call_parameter_sets: @@ -160,7 +172,9 @@ def _map_filter_parameter_sets_to_pattern( object_builder_str=ObjectBuilderNameConverter.input_str_from_entity_call_parameter_set( entity_call_parameter_set ), - spec_pattern=EntityPattern.from_call_parameter_set(entity_call_parameter_set), + spec_pattern=self._spec_pattern_factory.create_for_entity_call_parameter_set( + entity_call_parameter_set + ), ) ) diff --git a/metricflow/query/query_parser.py b/metricflow/query/query_parser.py index b02d016a3d..12c154eeac 100644 --- a/metricflow/query/query_parser.py +++ b/metricflow/query/query_parser.py @@ -29,6 +29,10 @@ OrderByQueryParameter, SavedQueryParameter, ) +from metricflow.query.group_by_item.filter_spec_resolution.filter_pattern_factory import ( + DefaultWhereFilterPatternFactory, + WhereFilterPatternFactory, +) from metricflow.query.group_by_item.group_by_item_resolver import GroupByItemResolver from metricflow.query.group_by_item.resolution_dag.dag import GroupByItemResolutionDag from metricflow.query.issues.issues_base import MetricFlowQueryResolutionIssueSet @@ -78,13 +82,15 @@ class MetricFlowQueryParser: def __init__( # noqa: D self, semantic_manifest_lookup: SemanticManifestLookup, + where_filter_pattern_factory: WhereFilterPatternFactory = DefaultWhereFilterPatternFactory(), ) -> None: self._manifest_lookup = semantic_manifest_lookup + self._metric_naming_schemes = (MetricNamingScheme(),) self._group_by_item_naming_schemes = ( ObjectBuilderNamingScheme(), DunderNamingScheme(), ) - self._metric_naming_schemes = (MetricNamingScheme(),) + self._where_filter_pattern_factory = where_filter_pattern_factory def parse_and_validate_saved_query( self, @@ -412,7 +418,7 @@ def parse_and_validate_query( ) query_resolver = MetricFlowQueryResolver( - manifest_lookup=self._manifest_lookup, + manifest_lookup=self._manifest_lookup, where_filter_pattern_factory=self._where_filter_pattern_factory ) resolver_inputs_for_order_by: List[ResolverInputForOrderByItem] = [] diff --git a/metricflow/query/query_resolver.py b/metricflow/query/query_resolver.py index beb347df0c..36fc370442 100644 --- a/metricflow/query/query_resolver.py +++ b/metricflow/query/query_resolver.py @@ -10,6 +10,7 @@ from metricflow.dag.dag_to_text import dag_as_text from metricflow.model.semantic_manifest_lookup import SemanticManifestLookup from metricflow.naming.metric_scheme import MetricNamingScheme +from metricflow.query.group_by_item.filter_spec_resolution.filter_pattern_factory import WhereFilterPatternFactory from metricflow.query.group_by_item.filter_spec_resolution.filter_spec_lookup import FilterSpecResolutionLookUp from metricflow.query.group_by_item.filter_spec_resolution.filter_spec_resolver import ( WhereFilterSpecResolver, @@ -102,11 +103,16 @@ class ResolveGroupByItemsResult: class MetricFlowQueryResolver: """Resolves inputs to a query (e.g. metrics, group by items into concrete specs.""" - def __init__(self, manifest_lookup: SemanticManifestLookup) -> None: # noqa: D + def __init__( # noqa: D + self, + manifest_lookup: SemanticManifestLookup, + where_filter_pattern_factory: WhereFilterPatternFactory, + ) -> None: self._manifest_lookup = manifest_lookup self._post_resolution_query_validator = PostResolutionQueryValidator( manifest_lookup=self._manifest_lookup, ) + self._where_filter_pattern_factory = where_filter_pattern_factory @staticmethod def _resolve_group_by_item_input( @@ -315,6 +321,7 @@ def _build_filter_spec_lookup( where_filter_spec_resolver = WhereFilterSpecResolver( manifest_lookup=self._manifest_lookup, resolution_dag=resolution_dag, + spec_pattern_factory=self._where_filter_pattern_factory, ) return where_filter_spec_resolver.resolve_lookup() diff --git a/metricflow/test/query/group_by_item/filter_spec_resolution/test_spec_lookup.py b/metricflow/test/query/group_by_item/filter_spec_resolution/test_spec_lookup.py index 9f297b5d4c..b9db5e5148 100644 --- a/metricflow/test/query/group_by_item/filter_spec_resolution/test_spec_lookup.py +++ b/metricflow/test/query/group_by_item/filter_spec_resolution/test_spec_lookup.py @@ -22,6 +22,9 @@ from metricflow.collection_helpers.pretty_print import mf_pformat from metricflow.model.semantic_manifest_lookup import SemanticManifestLookup from metricflow.naming.naming_scheme import QueryItemNamingScheme +from metricflow.query.group_by_item.filter_spec_resolution.filter_pattern_factory import ( + DefaultWhereFilterPatternFactory, +) from metricflow.query.group_by_item.filter_spec_resolution.filter_spec_lookup import ( FilterSpecResolutionLookUp, ) @@ -66,6 +69,7 @@ def test_filter_spec_resolution( # noqa: D spec_pattern_resolver = WhereFilterSpecResolver( manifest_lookup=ambiguous_resolution_manifest_lookup, resolution_dag=resolution_dag, + spec_pattern_factory=DefaultWhereFilterPatternFactory(), ) resolution_result = spec_pattern_resolver.resolve_lookup() @@ -105,6 +109,7 @@ def check_resolution_with_filter( # noqa: D spec_pattern_resolver = WhereFilterSpecResolver( manifest_lookup=manifest_lookup, resolution_dag=resolution_dag, + spec_pattern_factory=DefaultWhereFilterPatternFactory(), ) resolution_result = spec_pattern_resolver.resolve_lookup()