Skip to content

Commit

Permalink
Update interfaces for custom spec patterns in query parsing.
Browse files Browse the repository at this point in the history
  • Loading branch information
plypaul committed Jan 5, 2024
1 parent ccaf493 commit 384e04e
Show file tree
Hide file tree
Showing 6 changed files with 99 additions and 11 deletions.
3 changes: 2 additions & 1 deletion metricflow/engine/metricflow_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,6 +318,7 @@ def __init__(
semantic_manifest_lookup: SemanticManifestLookup,
sql_client: SqlClient,
time_source: TimeSource = ServerTimeSource(),
query_parser: Optional[MetricFlowQueryParser] = None,
column_association_resolver: Optional[ColumnAssociationResolver] = None,
) -> None:
"""Initializer for MetricFlowEngine.
Expand Down Expand Up @@ -365,7 +366,7 @@ def __init__(
)
self._executor = SequentialPlanExecutor()

self._query_parser = MetricFlowQueryParser(
self._query_parser = query_parser or MetricFlowQueryParser(
semantic_manifest_lookup=self._semantic_manifest_lookup,
)

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
from __future__ import annotations

from abc import ABC, abstractmethod

from dbt_semantic_interfaces.call_parameter_sets import (
DimensionCallParameterSet,
EntityCallParameterSet,
TimeDimensionCallParameterSet,
)
from typing_extensions import override

from metricflow.specs.patterns.spec_pattern import SpecPattern
from metricflow.specs.patterns.typed_patterns import DimensionPattern, EntityPattern, TimeDimensionPattern


class WhereFilterPatternFactory(ABC):
"""Interface that defines how spec patterns should be generated for the group-by-items specified in filters."""

@abstractmethod
def create_for_dimension_call_parameter_set( # noqa: D
self, dimension_call_parameter_set: DimensionCallParameterSet
) -> SpecPattern:
raise NotImplementedError

@abstractmethod
def create_for_time_dimension_call_parameter_set( # noqa: D
self, time_dimension_call_parameter_set: TimeDimensionCallParameterSet
) -> SpecPattern:
raise NotImplementedError

@abstractmethod
def create_for_entity_call_parameter_set( # noqa: D
self, entity_call_parameter_set: EntityCallParameterSet
) -> SpecPattern:
raise NotImplementedError


class DefaultWhereFilterPatternFactory(WhereFilterPatternFactory):
"""Default implementation using patterns derived from EntityLinkPattern."""

@override
def create_for_dimension_call_parameter_set(
self, dimension_call_parameter_set: DimensionCallParameterSet
) -> SpecPattern:
return DimensionPattern.from_call_parameter_set(dimension_call_parameter_set)

@override
def create_for_time_dimension_call_parameter_set(
self, time_dimension_call_parameter_set: TimeDimensionCallParameterSet
) -> SpecPattern:
return TimeDimensionPattern.from_call_parameter_set(time_dimension_call_parameter_set)

@override
def create_for_entity_call_parameter_set(self, entity_call_parameter_set: EntityCallParameterSet) -> SpecPattern:
return EntityPattern.from_call_parameter_set(entity_call_parameter_set)
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from metricflow.naming.object_builder_str import ObjectBuilderNameConverter
from metricflow.query.group_by_item.candidate_push_down.push_down_visitor import DagTraversalPathTracker
from metricflow.query.group_by_item.filter_spec_resolution.filter_location import WhereFilterLocation
from metricflow.query.group_by_item.filter_spec_resolution.filter_pattern_factory import WhereFilterPatternFactory
from metricflow.query.group_by_item.filter_spec_resolution.filter_spec_lookup import (
FilterSpecResolution,
FilterSpecResolutionLookUp,
Expand Down Expand Up @@ -44,7 +45,6 @@
from metricflow.query.issues.issues_base import (
MetricFlowQueryResolutionIssueSet,
)
from metricflow.specs.patterns.typed_patterns import DimensionPattern, EntityPattern, TimeDimensionPattern

logger = logging.getLogger(__name__)

Expand All @@ -66,13 +66,18 @@ def __init__( # noqa: D
self,
manifest_lookup: SemanticManifestLookup,
resolution_dag: GroupByItemResolutionDag,
spec_pattern_factory: WhereFilterPatternFactory,
) -> None:
self._manifest_lookup = manifest_lookup
self._resolution_dag = resolution_dag
self.spec_pattern_factory = spec_pattern_factory

def resolve_lookup(self) -> FilterSpecResolutionLookUp:
"""Find all where filters and return a lookup that provides the specs for the included group-by-items."""
visitor = _ResolveWhereFilterSpecVisitor(manifest_lookup=self._manifest_lookup)
visitor = _ResolveWhereFilterSpecVisitor(
manifest_lookup=self._manifest_lookup,
spec_pattern_factory=self.spec_pattern_factory,
)

return self._resolution_dag.sink_node.accept(visitor)

Expand All @@ -85,9 +90,12 @@ class _ResolveWhereFilterSpecVisitor(GroupByItemResolutionNodeVisitor[FilterSpec
collected and returned in a lookup object.
"""

def __init__(self, manifest_lookup: SemanticManifestLookup) -> None: # noqa: D
def __init__( # noqa: D
self, manifest_lookup: SemanticManifestLookup, spec_pattern_factory: WhereFilterPatternFactory
) -> None:
self._manifest_lookup = manifest_lookup
self._path_from_start_node_tracker = DagTraversalPathTracker()
self._spec_pattern_factory = spec_pattern_factory

@staticmethod
def _dedupe_filter_call_parameter_sets(
Expand Down Expand Up @@ -121,8 +129,8 @@ def _dedupe_filter_call_parameter_sets(
),
)

@staticmethod
def _map_filter_parameter_sets_to_pattern(
self,
filter_call_parameter_sets: FilterCallParameterSets,
) -> Sequence[PatternAssociationForWhereFilterGroupByItem]:
"""Given the call parameter sets in a filter, map them to spec patterns.
Expand All @@ -140,7 +148,9 @@ def _map_filter_parameter_sets_to_pattern(
object_builder_str=ObjectBuilderNameConverter.input_str_from_dimension_call_parameter_set(
dimension_call_parameter_set
),
spec_pattern=DimensionPattern.from_call_parameter_set(dimension_call_parameter_set),
spec_pattern=self._spec_pattern_factory.create_for_dimension_call_parameter_set(
dimension_call_parameter_set
),
)
)
for time_dimension_call_parameter_set in filter_call_parameter_sets.time_dimension_call_parameter_sets:
Expand All @@ -150,7 +160,9 @@ def _map_filter_parameter_sets_to_pattern(
object_builder_str=ObjectBuilderNameConverter.input_str_from_time_dimension_call_parameter_set(
time_dimension_call_parameter_set
),
spec_pattern=TimeDimensionPattern.from_call_parameter_set(time_dimension_call_parameter_set),
spec_pattern=self._spec_pattern_factory.create_for_time_dimension_call_parameter_set(
time_dimension_call_parameter_set
),
)
)
for entity_call_parameter_set in filter_call_parameter_sets.entity_call_parameter_sets:
Expand All @@ -160,7 +172,9 @@ def _map_filter_parameter_sets_to_pattern(
object_builder_str=ObjectBuilderNameConverter.input_str_from_entity_call_parameter_set(
entity_call_parameter_set
),
spec_pattern=EntityPattern.from_call_parameter_set(entity_call_parameter_set),
spec_pattern=self._spec_pattern_factory.create_for_entity_call_parameter_set(
entity_call_parameter_set
),
)
)

Expand Down
10 changes: 8 additions & 2 deletions metricflow/query/query_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,10 @@
OrderByQueryParameter,
SavedQueryParameter,
)
from metricflow.query.group_by_item.filter_spec_resolution.filter_pattern_factory import (
DefaultWhereFilterPatternFactory,
WhereFilterPatternFactory,
)
from metricflow.query.group_by_item.group_by_item_resolver import GroupByItemResolver
from metricflow.query.group_by_item.resolution_dag.dag import GroupByItemResolutionDag
from metricflow.query.issues.issues_base import MetricFlowQueryResolutionIssueSet
Expand Down Expand Up @@ -78,13 +82,15 @@ class MetricFlowQueryParser:
def __init__( # noqa: D
self,
semantic_manifest_lookup: SemanticManifestLookup,
where_filter_pattern_factory: WhereFilterPatternFactory = DefaultWhereFilterPatternFactory(),
) -> None:
self._manifest_lookup = semantic_manifest_lookup
self._metric_naming_schemes = (MetricNamingScheme(),)
self._group_by_item_naming_schemes = (
ObjectBuilderNamingScheme(),
DunderNamingScheme(),
)
self._metric_naming_schemes = (MetricNamingScheme(),)
self._where_filter_pattern_factory = where_filter_pattern_factory

def parse_and_validate_saved_query(
self,
Expand Down Expand Up @@ -412,7 +418,7 @@ def parse_and_validate_query(
)

query_resolver = MetricFlowQueryResolver(
manifest_lookup=self._manifest_lookup,
manifest_lookup=self._manifest_lookup, where_filter_pattern_factory=self._where_filter_pattern_factory
)

resolver_inputs_for_order_by: List[ResolverInputForOrderByItem] = []
Expand Down
9 changes: 8 additions & 1 deletion metricflow/query/query_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from metricflow.dag.dag_to_text import dag_as_text
from metricflow.model.semantic_manifest_lookup import SemanticManifestLookup
from metricflow.naming.metric_scheme import MetricNamingScheme
from metricflow.query.group_by_item.filter_spec_resolution.filter_pattern_factory import WhereFilterPatternFactory
from metricflow.query.group_by_item.filter_spec_resolution.filter_spec_lookup import FilterSpecResolutionLookUp
from metricflow.query.group_by_item.filter_spec_resolution.filter_spec_resolver import (
WhereFilterSpecResolver,
Expand Down Expand Up @@ -102,11 +103,16 @@ class ResolveGroupByItemsResult:
class MetricFlowQueryResolver:
"""Resolves inputs to a query (e.g. metrics, group by items into concrete specs."""

def __init__(self, manifest_lookup: SemanticManifestLookup) -> None: # noqa: D
def __init__( # noqa: D
self,
manifest_lookup: SemanticManifestLookup,
where_filter_pattern_factory: WhereFilterPatternFactory,
) -> None:
self._manifest_lookup = manifest_lookup
self._post_resolution_query_validator = PostResolutionQueryValidator(
manifest_lookup=self._manifest_lookup,
)
self._where_filter_pattern_factory = where_filter_pattern_factory

@staticmethod
def _resolve_group_by_item_input(
Expand Down Expand Up @@ -315,6 +321,7 @@ def _build_filter_spec_lookup(
where_filter_spec_resolver = WhereFilterSpecResolver(
manifest_lookup=self._manifest_lookup,
resolution_dag=resolution_dag,
spec_pattern_factory=self._where_filter_pattern_factory,
)

return where_filter_spec_resolver.resolve_lookup()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@
from metricflow.collection_helpers.pretty_print import mf_pformat
from metricflow.model.semantic_manifest_lookup import SemanticManifestLookup
from metricflow.naming.naming_scheme import QueryItemNamingScheme
from metricflow.query.group_by_item.filter_spec_resolution.filter_pattern_factory import (
DefaultWhereFilterPatternFactory,
)
from metricflow.query.group_by_item.filter_spec_resolution.filter_spec_lookup import (
FilterSpecResolutionLookUp,
)
Expand Down Expand Up @@ -66,6 +69,7 @@ def test_filter_spec_resolution( # noqa: D
spec_pattern_resolver = WhereFilterSpecResolver(
manifest_lookup=ambiguous_resolution_manifest_lookup,
resolution_dag=resolution_dag,
spec_pattern_factory=DefaultWhereFilterPatternFactory(),
)

resolution_result = spec_pattern_resolver.resolve_lookup()
Expand Down Expand Up @@ -105,6 +109,7 @@ def check_resolution_with_filter( # noqa: D
spec_pattern_resolver = WhereFilterSpecResolver(
manifest_lookup=manifest_lookup,
resolution_dag=resolution_dag,
spec_pattern_factory=DefaultWhereFilterPatternFactory(),
)

resolution_result = spec_pattern_resolver.resolve_lookup()
Expand Down

0 comments on commit 384e04e

Please sign in to comment.