Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Expand MetricFlowEngine Initializer Signature #961

Merged
merged 2 commits into from
Jan 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions metricflow/engine/metricflow_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,6 +318,7 @@ def __init__(
semantic_manifest_lookup: SemanticManifestLookup,
sql_client: SqlClient,
time_source: TimeSource = ServerTimeSource(),
query_parser: Optional[MetricFlowQueryParser] = None,
column_association_resolver: Optional[ColumnAssociationResolver] = None,
) -> None:
"""Initializer for MetricFlowEngine.
Expand Down Expand Up @@ -365,9 +366,8 @@ def __init__(
)
self._executor = SequentialPlanExecutor()

self._query_parser = MetricFlowQueryParser(
column_association_resolver=self._column_association_resolver,
model=self._semantic_manifest_lookup,
self._query_parser = query_parser or MetricFlowQueryParser(
semantic_manifest_lookup=self._semantic_manifest_lookup,
)

@log_call(module_name=__name__, telemetry_reporter=_telemetry_reporter)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
from __future__ import annotations

from abc import ABC, abstractmethod

from dbt_semantic_interfaces.call_parameter_sets import (
DimensionCallParameterSet,
EntityCallParameterSet,
TimeDimensionCallParameterSet,
)
from typing_extensions import override

from metricflow.specs.patterns.spec_pattern import SpecPattern
from metricflow.specs.patterns.typed_patterns import DimensionPattern, EntityPattern, TimeDimensionPattern


class WhereFilterPatternFactory(ABC):
"""Interface that defines how spec patterns should be generated for the group-by-items specified in filters."""

@abstractmethod
def create_for_dimension_call_parameter_set( # noqa: D
self, dimension_call_parameter_set: DimensionCallParameterSet
) -> SpecPattern:
raise NotImplementedError

@abstractmethod
def create_for_time_dimension_call_parameter_set( # noqa: D
self, time_dimension_call_parameter_set: TimeDimensionCallParameterSet
) -> SpecPattern:
raise NotImplementedError

@abstractmethod
def create_for_entity_call_parameter_set( # noqa: D
self, entity_call_parameter_set: EntityCallParameterSet
) -> SpecPattern:
raise NotImplementedError


class DefaultWhereFilterPatternFactory(WhereFilterPatternFactory):
"""Default implementation using patterns derived from EntityLinkPattern."""

@override
def create_for_dimension_call_parameter_set(
self, dimension_call_parameter_set: DimensionCallParameterSet
) -> SpecPattern:
return DimensionPattern.from_call_parameter_set(dimension_call_parameter_set)

@override
def create_for_time_dimension_call_parameter_set(
self, time_dimension_call_parameter_set: TimeDimensionCallParameterSet
) -> SpecPattern:
return TimeDimensionPattern.from_call_parameter_set(time_dimension_call_parameter_set)

@override
def create_for_entity_call_parameter_set(self, entity_call_parameter_set: EntityCallParameterSet) -> SpecPattern:
return EntityPattern.from_call_parameter_set(entity_call_parameter_set)
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from metricflow.naming.object_builder_str import ObjectBuilderNameConverter
from metricflow.query.group_by_item.candidate_push_down.push_down_visitor import DagTraversalPathTracker
from metricflow.query.group_by_item.filter_spec_resolution.filter_location import WhereFilterLocation
from metricflow.query.group_by_item.filter_spec_resolution.filter_pattern_factory import WhereFilterPatternFactory
from metricflow.query.group_by_item.filter_spec_resolution.filter_spec_lookup import (
FilterSpecResolution,
FilterSpecResolutionLookUp,
Expand Down Expand Up @@ -44,7 +45,6 @@
from metricflow.query.issues.issues_base import (
MetricFlowQueryResolutionIssueSet,
)
from metricflow.specs.patterns.typed_patterns import DimensionPattern, EntityPattern, TimeDimensionPattern

logger = logging.getLogger(__name__)

Expand All @@ -66,13 +66,18 @@ def __init__( # noqa: D
self,
manifest_lookup: SemanticManifestLookup,
resolution_dag: GroupByItemResolutionDag,
spec_pattern_factory: WhereFilterPatternFactory,
) -> None:
self._manifest_lookup = manifest_lookup
self._resolution_dag = resolution_dag
self.spec_pattern_factory = spec_pattern_factory

def resolve_lookup(self) -> FilterSpecResolutionLookUp:
"""Find all where filters and return a lookup that provides the specs for the included group-by-items."""
visitor = _ResolveWhereFilterSpecVisitor(manifest_lookup=self._manifest_lookup)
visitor = _ResolveWhereFilterSpecVisitor(
manifest_lookup=self._manifest_lookup,
spec_pattern_factory=self.spec_pattern_factory,
)

return self._resolution_dag.sink_node.accept(visitor)

Expand All @@ -85,9 +90,12 @@ class _ResolveWhereFilterSpecVisitor(GroupByItemResolutionNodeVisitor[FilterSpec
collected and returned in a lookup object.
"""

def __init__(self, manifest_lookup: SemanticManifestLookup) -> None: # noqa: D
def __init__( # noqa: D
self, manifest_lookup: SemanticManifestLookup, spec_pattern_factory: WhereFilterPatternFactory
) -> None:
self._manifest_lookup = manifest_lookup
self._path_from_start_node_tracker = DagTraversalPathTracker()
self._spec_pattern_factory = spec_pattern_factory

@staticmethod
def _dedupe_filter_call_parameter_sets(
Expand Down Expand Up @@ -121,8 +129,8 @@ def _dedupe_filter_call_parameter_sets(
),
)

@staticmethod
def _map_filter_parameter_sets_to_pattern(
self,
filter_call_parameter_sets: FilterCallParameterSets,
) -> Sequence[PatternAssociationForWhereFilterGroupByItem]:
"""Given the call parameter sets in a filter, map them to spec patterns.
Expand All @@ -140,7 +148,9 @@ def _map_filter_parameter_sets_to_pattern(
object_builder_str=ObjectBuilderNameConverter.input_str_from_dimension_call_parameter_set(
dimension_call_parameter_set
),
spec_pattern=DimensionPattern.from_call_parameter_set(dimension_call_parameter_set),
spec_pattern=self._spec_pattern_factory.create_for_dimension_call_parameter_set(
dimension_call_parameter_set
),
)
)
for time_dimension_call_parameter_set in filter_call_parameter_sets.time_dimension_call_parameter_sets:
Expand All @@ -150,7 +160,9 @@ def _map_filter_parameter_sets_to_pattern(
object_builder_str=ObjectBuilderNameConverter.input_str_from_time_dimension_call_parameter_set(
time_dimension_call_parameter_set
),
spec_pattern=TimeDimensionPattern.from_call_parameter_set(time_dimension_call_parameter_set),
spec_pattern=self._spec_pattern_factory.create_for_time_dimension_call_parameter_set(
time_dimension_call_parameter_set
),
)
)
for entity_call_parameter_set in filter_call_parameter_sets.entity_call_parameter_sets:
Expand All @@ -160,7 +172,9 @@ def _map_filter_parameter_sets_to_pattern(
object_builder_str=ObjectBuilderNameConverter.input_str_from_entity_call_parameter_set(
entity_call_parameter_set
),
spec_pattern=EntityPattern.from_call_parameter_set(entity_call_parameter_set),
spec_pattern=self._spec_pattern_factory.create_for_entity_call_parameter_set(
entity_call_parameter_set
),
)
)

Expand Down
16 changes: 10 additions & 6 deletions metricflow/query/query_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,10 @@
OrderByQueryParameter,
SavedQueryParameter,
)
from metricflow.query.group_by_item.filter_spec_resolution.filter_pattern_factory import (
DefaultWhereFilterPatternFactory,
WhereFilterPatternFactory,
)
from metricflow.query.group_by_item.group_by_item_resolver import GroupByItemResolver
from metricflow.query.group_by_item.resolution_dag.dag import GroupByItemResolutionDag
from metricflow.query.issues.issues_base import MetricFlowQueryResolutionIssueSet
Expand All @@ -47,7 +51,6 @@
ResolverInputForQuery,
ResolverInputForQueryLevelWhereFilterIntersection,
)
from metricflow.specs.column_assoc import ColumnAssociationResolver
from metricflow.specs.patterns.base_time_grain import BaseTimeGrainPattern
from metricflow.specs.patterns.metric_time_pattern import MetricTimePattern
from metricflow.specs.patterns.none_date_part import NoneDatePartPattern
Expand Down Expand Up @@ -78,15 +81,16 @@ class MetricFlowQueryParser:

def __init__( # noqa: D
self,
column_association_resolver: ColumnAssociationResolver,
model: SemanticManifestLookup,
semantic_manifest_lookup: SemanticManifestLookup,
where_filter_pattern_factory: WhereFilterPatternFactory = DefaultWhereFilterPatternFactory(),
) -> None:
self._manifest_lookup = model
self._manifest_lookup = semantic_manifest_lookup
self._metric_naming_schemes = (MetricNamingScheme(),)
self._group_by_item_naming_schemes = (
ObjectBuilderNamingScheme(),
DunderNamingScheme(),
)
self._metric_naming_schemes = (MetricNamingScheme(),)
self._where_filter_pattern_factory = where_filter_pattern_factory

def parse_and_validate_saved_query(
self,
Expand Down Expand Up @@ -414,7 +418,7 @@ def parse_and_validate_query(
)

query_resolver = MetricFlowQueryResolver(
manifest_lookup=self._manifest_lookup,
manifest_lookup=self._manifest_lookup, where_filter_pattern_factory=self._where_filter_pattern_factory
)

resolver_inputs_for_order_by: List[ResolverInputForOrderByItem] = []
Expand Down
9 changes: 8 additions & 1 deletion metricflow/query/query_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from metricflow.dag.dag_to_text import dag_as_text
from metricflow.model.semantic_manifest_lookup import SemanticManifestLookup
from metricflow.naming.metric_scheme import MetricNamingScheme
from metricflow.query.group_by_item.filter_spec_resolution.filter_pattern_factory import WhereFilterPatternFactory
from metricflow.query.group_by_item.filter_spec_resolution.filter_spec_lookup import FilterSpecResolutionLookUp
from metricflow.query.group_by_item.filter_spec_resolution.filter_spec_resolver import (
WhereFilterSpecResolver,
Expand Down Expand Up @@ -102,11 +103,16 @@ class ResolveGroupByItemsResult:
class MetricFlowQueryResolver:
"""Resolves inputs to a query (e.g. metrics, group by items into concrete specs."""

def __init__(self, manifest_lookup: SemanticManifestLookup) -> None: # noqa: D
def __init__( # noqa: D
self,
manifest_lookup: SemanticManifestLookup,
where_filter_pattern_factory: WhereFilterPatternFactory,
) -> None:
self._manifest_lookup = manifest_lookup
self._post_resolution_query_validator = PostResolutionQueryValidator(
manifest_lookup=self._manifest_lookup,
)
self._where_filter_pattern_factory = where_filter_pattern_factory

@staticmethod
def _resolve_group_by_item_input(
Expand Down Expand Up @@ -315,6 +321,7 @@ def _build_filter_spec_lookup(
where_filter_spec_resolver = WhereFilterSpecResolver(
manifest_lookup=self._manifest_lookup,
resolution_dag=resolution_dag,
spec_pattern_factory=self._where_filter_pattern_factory,
)

return where_filter_spec_resolver.resolve_lookup()
Expand Down
7 changes: 2 additions & 5 deletions metricflow/test/fixtures/dataflow_fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,12 +42,10 @@ def dataflow_plan_builder( # noqa: D
@pytest.fixture
def query_parser( # noqa: D
simple_semantic_manifest_lookup: SemanticManifestLookup,
column_association_resolver: ColumnAssociationResolver,
consistent_id_object_repository: ConsistentIdObjectRepository,
) -> MetricFlowQueryParser:
return MetricFlowQueryParser(
column_association_resolver=column_association_resolver,
model=simple_semantic_manifest_lookup,
semantic_manifest_lookup=simple_semantic_manifest_lookup,
)


Expand Down Expand Up @@ -104,8 +102,7 @@ def scd_query_parser( # noqa: D
scd_semantic_manifest_lookup: SemanticManifestLookup,
) -> MetricFlowQueryParser:
return MetricFlowQueryParser(
column_association_resolver=scd_column_association_resolver,
model=scd_semantic_manifest_lookup,
semantic_manifest_lookup=scd_semantic_manifest_lookup,
)


Expand Down
3 changes: 1 addition & 2 deletions metricflow/test/fixtures/model_fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,7 @@ def query_parser_from_yaml(yaml_contents: List[YamlConfigFile]) -> MetricFlowQue
)
SemanticManifestValidator[SemanticManifest]().checked_validations(semantic_manifest_lookup.semantic_manifest)
return MetricFlowQueryParser(
model=semantic_manifest_lookup,
column_association_resolver=DunderColumnAssociationResolver(semantic_manifest_lookup),
semantic_manifest_lookup=semantic_manifest_lookup,
)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@
from metricflow.collection_helpers.pretty_print import mf_pformat
from metricflow.model.semantic_manifest_lookup import SemanticManifestLookup
from metricflow.naming.naming_scheme import QueryItemNamingScheme
from metricflow.query.group_by_item.filter_spec_resolution.filter_pattern_factory import (
DefaultWhereFilterPatternFactory,
)
from metricflow.query.group_by_item.filter_spec_resolution.filter_spec_lookup import (
FilterSpecResolutionLookUp,
)
Expand Down Expand Up @@ -65,6 +68,7 @@ def test_filter_spec_resolution( # noqa: D
spec_pattern_resolver = WhereFilterSpecResolver(
manifest_lookup=ambiguous_resolution_manifest_lookup,
resolution_dag=resolution_dag,
spec_pattern_factory=DefaultWhereFilterPatternFactory(),
)

resolution_result = spec_pattern_resolver.resolve_lookup()
Expand Down Expand Up @@ -104,6 +108,7 @@ def check_resolution_with_filter( # noqa: D
spec_pattern_resolver = WhereFilterSpecResolver(
manifest_lookup=manifest_lookup,
resolution_dag=resolution_dag,
spec_pattern_factory=DefaultWhereFilterPatternFactory(),
)

resolution_result = spec_pattern_resolver.resolve_lookup()
Expand Down
9 changes: 3 additions & 6 deletions metricflow/test/query/test_suggestions.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,11 +84,9 @@ def test_suggestions_for_defined_where_filter( # noqa: D
)

semantic_manifest_lookup = SemanticManifestLookup(modified_manifest)
column_association_resolver = DunderColumnAssociationResolver(modified_manifest)

query_parser = MetricFlowQueryParser(
column_association_resolver=column_association_resolver,
model=semantic_manifest_lookup,
semantic_manifest_lookup=semantic_manifest_lookup,
)
with pytest.raises(InvalidQueryException) as e:
query_parser.parse_and_validate_query(metric_names=("listings",), group_by_names=(METRIC_TIME_ELEMENT_NAME,))
Expand Down Expand Up @@ -132,11 +130,10 @@ def test_suggestions_for_defined_filters_in_multi_metric_query(
)

semantic_manifest_lookup = SemanticManifestLookup(modified_manifest)
column_association_resolver = DunderColumnAssociationResolver(modified_manifest)
DunderColumnAssociationResolver(modified_manifest)

query_parser = MetricFlowQueryParser(
column_association_resolver=column_association_resolver,
model=semantic_manifest_lookup,
semantic_manifest_lookup=semantic_manifest_lookup,
)
with pytest.raises(InvalidQueryException) as e:
query_parser.parse_and_validate_query(
Expand Down