Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove unneeded protocols for lookup classes #1137

Merged
merged 4 commits into from
Apr 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions metricflow/dataflow/builder/node_data_set.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,18 @@
from __future__ import annotations

from typing import Dict, Optional, Sequence
from typing import TYPE_CHECKING, Dict, Optional, Sequence

from metricflow.dataflow.dataflow_plan import (
DataflowPlanNode,
)
from metricflow.dataset.sql_dataset import SqlDataSet
from metricflow.mf_logging.runtime import log_block_runtime
from metricflow.model.semantic_manifest_lookup import SemanticManifestLookup
from metricflow.plan_conversion.dataflow_to_sql import DataflowToSqlQueryPlanConverter
from metricflow.specs.column_assoc import ColumnAssociationResolver

if TYPE_CHECKING:
from metricflow.model.semantic_manifest_lookup import SemanticManifestLookup


class DataflowPlanNodeOutputDataSetResolver(DataflowToSqlQueryPlanConverter):
"""Given a node in a dataflow plan, figure out what is the data set output by that node.
Expand Down
4 changes: 2 additions & 2 deletions metricflow/dataflow/builder/node_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,8 @@
from metricflow.instances import InstanceSet
from metricflow.mf_logging.pretty_print import mf_pformat
from metricflow.model.semantics.semantic_model_join_evaluator import SemanticModelJoinEvaluator
from metricflow.model.semantics.semantic_model_lookup import SemanticModelLookup
from metricflow.plan_conversion.instance_converters import CreateValidityWindowJoinDescription
from metricflow.protocols.semantics import SemanticModelAccessor
from metricflow.specs.specs import InstanceSpecSet, LinkableInstanceSpec, LinkableSpecSet, LinklessEntitySpec
from metricflow.sql.sql_plan import SqlJoinType

Expand Down Expand Up @@ -159,7 +159,7 @@ class NodeEvaluatorForLinkableInstances:

def __init__(
self,
semantic_model_lookup: SemanticModelAccessor,
semantic_model_lookup: SemanticModelLookup,
nodes_available_for_joins: Sequence[BaseOutput],
node_data_set_resolver: DataflowPlanNodeOutputDataSetResolver,
time_spine_node: MetricTimeDimensionTransformNode,
Expand Down
4 changes: 2 additions & 2 deletions metricflow/dataflow/builder/partitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from typing import List, Sequence, Tuple

from metricflow.dataset.dataset import DataSet
from metricflow.protocols.semantics import SemanticModelAccessor
from metricflow.model.semantics.semantic_model_lookup import SemanticModelLookup
from metricflow.specs.specs import (
DimensionSpec,
InstanceSpecSet,
Expand Down Expand Up @@ -33,7 +33,7 @@ class PartitionTimeDimensionJoinDescription:
class PartitionJoinResolver:
"""When joining data sets, this class helps to figure out the necessary partition specs to join on."""

def __init__(self, semantic_model_lookup: SemanticModelAccessor) -> None: # noqa: D107
def __init__(self, semantic_model_lookup: SemanticModelLookup) -> None: # noqa: D107
self._semantic_model_lookup = semantic_model_lookup

def _get_partitions(self, spec_set: InstanceSpecSet) -> PartitionSpecSet:
Expand Down
3 changes: 1 addition & 2 deletions metricflow/model/semantic_manifest_lookup.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from metricflow.model.semantics.metric_lookup import MetricLookup
from metricflow.model.semantics.semantic_model_lookup import SemanticModelLookup
from metricflow.plan_conversion.time_spine import TimeSpineSource
from metricflow.protocols.semantics import SemanticModelAccessor
from metricflow.sql.sql_table import SqlTable

logger = logging.getLogger(__name__)
Expand All @@ -28,7 +27,7 @@ def semantic_manifest(self) -> SemanticManifest: # noqa: D102
return self._semantic_manifest

@property
def semantic_model_lookup(self) -> SemanticModelAccessor: # noqa: D102
def semantic_model_lookup(self) -> SemanticModelLookup: # noqa: D102
return self._semantic_model_lookup

@property
Expand Down
9 changes: 6 additions & 3 deletions metricflow/model/semantics/linkable_spec_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import time
from collections import defaultdict
from dataclasses import dataclass, field
from typing import Dict, FrozenSet, List, Optional, Sequence, Set, Tuple
from typing import TYPE_CHECKING, Dict, FrozenSet, List, Optional, Sequence, Set, Tuple

from dbt_semantic_interfaces.enum_extension import assert_values_exhausted
from dbt_semantic_interfaces.protocols.dimension import Dimension, DimensionType
Expand All @@ -26,7 +26,6 @@
from metricflow.mf_logging.pretty_print import mf_pformat
from metricflow.model.semantics.linkable_element_properties import LinkableElementProperty
from metricflow.model.semantics.semantic_model_join_evaluator import SemanticModelJoinEvaluator
from metricflow.protocols.semantics import SemanticModelAccessor
from metricflow.specs.specs import (
DEFAULT_TIME_GRANULARITY,
DimensionSpec,
Expand All @@ -37,6 +36,10 @@
TimeDimensionSpec,
)

if TYPE_CHECKING:
from metricflow.model.semantics.semantic_model_lookup import SemanticModelLookup


logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -462,7 +465,7 @@ class ValidLinkableSpecResolver:
def __init__(
self,
semantic_manifest: SemanticManifest,
semantic_model_lookup: SemanticModelAccessor,
semantic_model_lookup: SemanticModelLookup,
max_entity_links: int,
) -> None:
"""Constructor.
Expand Down
7 changes: 3 additions & 4 deletions metricflow/model/semantics/metric_lookup.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,12 @@
)
from metricflow.model.semantics.semantic_model_join_evaluator import MAX_JOIN_HOPS
from metricflow.model.semantics.semantic_model_lookup import SemanticModelLookup
from metricflow.protocols.semantics import MetricAccessor
from metricflow.specs.specs import TimeDimensionSpec

logger = logging.getLogger(__name__)


class MetricLookup(MetricAccessor):
class MetricLookup:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the add_metric method here is the only "write" method in this class, and it's not called anywhere. Mind just renaming it to _add_metric so we don't have any "write" methods in the public interface?

"""Tracks semantic information for metrics by linking them to semantic models."""

def __init__(self, semantic_manifest: SemanticManifest, semantic_model_lookup: SemanticModelLookup) -> None:
Expand All @@ -37,7 +36,7 @@ def __init__(self, semantic_manifest: SemanticManifest, semantic_model_lookup: S
self._semantic_model_lookup = semantic_model_lookup

for metric in semantic_manifest.metrics:
self.add_metric(metric)
self._add_metric(metric)

self._linkable_spec_resolver = ValidLinkableSpecResolver(
semantic_manifest=semantic_manifest,
Expand Down Expand Up @@ -108,7 +107,7 @@ def get_metric(self, metric_reference: MetricReference) -> Metric: # noqa: D102
raise MetricNotFoundError(f"Unable to find metric `{metric_reference}`. Perhaps it has not been registered")
return self._metrics[metric_reference]

def add_metric(self, metric: Metric) -> None:
def _add_metric(self, metric: Metric) -> None:
"""Add metric, validating presence of required measures."""
metric_reference = MetricReference(element_name=metric.name)
if metric_reference in self._metrics:
Expand Down
8 changes: 5 additions & 3 deletions metricflow/model/semantics/semantic_model_join_evaluator.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import Dict, List, Optional
from typing import TYPE_CHECKING, Dict, List, Optional

from dbt_semantic_interfaces.protocols.entity import EntityType
from dbt_semantic_interfaces.references import (
Expand All @@ -12,7 +12,9 @@

from metricflow.instances import EntityInstance, InstanceSet
from metricflow.mf_logging.pretty_print import mf_pformat
from metricflow.protocols.semantics import SemanticModelAccessor

if TYPE_CHECKING:
from metricflow.model.semantics.semantic_model_lookup import SemanticModelLookup

MAX_JOIN_HOPS = 2

Expand Down Expand Up @@ -70,7 +72,7 @@ class SemanticModelJoinEvaluator:
SemanticModelEntityJoinType(left_entity_type=EntityType.NATURAL, right_entity_type=EntityType.NATURAL),
)

def __init__(self, semantic_model_lookup: SemanticModelAccessor) -> None: # noqa: D107
def __init__(self, semantic_model_lookup: SemanticModelLookup) -> None: # noqa: D107
self._semantic_model_lookup = semantic_model_lookup

def get_joinable_semantic_models(
Expand Down
55 changes: 30 additions & 25 deletions metricflow/model/semantics/semantic_model_lookup.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,12 @@
)
from dbt_semantic_interfaces.type_enums import DimensionType, EntityType
from dbt_semantic_interfaces.type_enums.aggregation_type import AggregationType
from typing_extensions import override

from metricflow.errors.errors import InvalidSemanticModelError
from metricflow.mf_logging.pretty_print import mf_pformat
from metricflow.model.semantics.element_group import ElementGrouper
from metricflow.model.semantics.linkable_spec_resolver import ElementPathKey
from metricflow.model.spec_converters import MeasureConverter
from metricflow.protocols.semantics import SemanticModelAccessor
from metricflow.specs.specs import (
DimensionSpec,
EntitySpec,
Expand All @@ -40,12 +38,8 @@
logger = logging.getLogger(__name__)


class SemanticModelLookup(SemanticModelAccessor):
"""Tracks semantic information for semantic models held in a set of SemanticModelContainers.

This implements the SemanticModelAccessors protocol, the interface type we use throughout the codebase.
That interface prevents unwanted calls to methods for adding semantic models to the container.
"""
class SemanticModelLookup:
"""Tracks semantic information for semantic models held in a set of SemanticModelContainers."""

def __init__(
self,
Expand Down Expand Up @@ -74,7 +68,8 @@ def __init__(
for semantic_model in sorted(model.semantic_models, key=lambda semantic_model: semantic_model.name):
self._add_semantic_model(semantic_model)

def get_dimension_references(self) -> Sequence[DimensionReference]: # noqa: D102
def get_dimension_references(self) -> Sequence[DimensionReference]:
"""Retrieve all dimension references from the collection of semantic models."""
return tuple(self._dimension_index.keys())

@staticmethod
Expand Down Expand Up @@ -110,11 +105,16 @@ def get_time_dimension(self, time_dimension_reference: TimeDimensionReference) -
return self.get_dimension(dimension_reference=time_dimension_reference.dimension_reference())

@property
def measure_references(self) -> Sequence[MeasureReference]: # noqa: D102
def measure_references(self) -> Sequence[MeasureReference]:
"""Return all measure references from the collection of semantic models."""
return list(self._measure_index.keys())

@property
def non_additive_dimension_specs_by_measure(self) -> Dict[MeasureReference, NonAdditiveDimensionSpec]: # noqa: D102
def non_additive_dimension_specs_by_measure(self) -> Dict[MeasureReference, NonAdditiveDimensionSpec]:
"""Return a mapping from all semi-additive measures to their corresponding non additive dimension parameters.

This includes all measures with non-additive dimension parameters, if any, from the collection of semantic models.
"""
return self._measure_non_additive_dimension_specs

@staticmethod
Expand All @@ -128,7 +128,8 @@ def get_measure_from_semantic_model(semantic_model: SemanticModel, measure_refer
f"No dimension with name ({measure_reference.element_name}) in semantic_model with name ({semantic_model.name})"
)

def get_measure(self, measure_reference: MeasureReference) -> Measure: # noqa: D102
def get_measure(self, measure_reference: MeasureReference) -> Measure:
"""Retrieve the measure model object associated with the measure reference."""
if measure_reference not in self._measure_index:
raise ValueError(f"Could not find measure with name ({measure_reference}) in configured semantic models")

Expand All @@ -142,20 +143,24 @@ def get_measure(self, measure_reference: MeasureReference) -> Measure: # noqa:
semantic_model=semantic_models[0], measure_reference=measure_reference
)

def get_entity_references(self) -> Sequence[EntityReference]: # noqa: D102
def get_entity_references(self) -> Sequence[EntityReference]:
"""Retrieve all entity references from the collection of semantic models."""
return list(self._entity_index.keys())

def get_semantic_models_for_measure( # noqa: D102
self, measure_reference: MeasureReference
) -> Sequence[SemanticModel]: # noqa: D102
def get_semantic_models_for_measure(self, measure_reference: MeasureReference) -> Sequence[SemanticModel]:
"""Retrieve semantic model where the measure is defined."""
return self._measure_index.get(measure_reference, [])

def get_agg_time_dimension_for_measure( # noqa: D102
self, measure_reference: MeasureReference
) -> TimeDimensionReference: # noqa: D102
def get_agg_time_dimension_for_measure(self, measure_reference: MeasureReference) -> TimeDimensionReference:
"""Retrieves the aggregate time dimension that is associated with the measure reference.

This is the time dimension along which the measure will be aggregated when a metric built on this measure
is queried with metric_time.
"""
return self._measure_agg_time_dimension[measure_reference]

def get_entity_in_semantic_model(self, ref: SemanticModelElementReference) -> Optional[Entity]: # noqa: D102
def get_entity_in_semantic_model(self, ref: SemanticModelElementReference) -> Optional[Entity]:
"""Retrieve the entity matching the element -> semantic model mapping, if any."""
semantic_model = self.get_by_reference(ref.semantic_model_reference)
if not semantic_model:
return None
Expand All @@ -166,9 +171,8 @@ def get_entity_in_semantic_model(self, ref: SemanticModelElementReference) -> Op

return None

def get_by_reference( # noqa: D102
self, semantic_model_reference: SemanticModelReference
) -> Optional[SemanticModel]: # noqa: D102
def get_by_reference(self, semantic_model_reference: SemanticModelReference) -> Optional[SemanticModel]:
"""Retrieve the semantic model object matching the input semantic model reference, if any."""
return self._semantic_model_reference_to_semantic_model.get(semantic_model_reference)

def _add_semantic_model(self, semantic_model: SemanticModel) -> None:
Expand Down Expand Up @@ -309,8 +313,8 @@ def resolved_primary_entity(semantic_model: SemanticModel) -> Optional[EntityRef
return None

@staticmethod
@override
def entity_links_for_local_elements(semantic_model: SemanticModel) -> Sequence[EntityReference]:
"""Return the entity prefix that can be used to access dimensions defined in the semantic model."""
primary_entity_reference = semantic_model.primary_entity_reference

possible_entity_links = set()
Expand All @@ -323,7 +327,8 @@ def entity_links_for_local_elements(semantic_model: SemanticModel) -> Sequence[E

return sorted(possible_entity_links, key=lambda entity_reference: entity_reference.element_name)

def get_element_spec_for_name(self, element_name: str) -> LinkableInstanceSpec: # noqa: D102
def get_element_spec_for_name(self, element_name: str) -> LinkableInstanceSpec:
"""Returns the spec for the given name of a linkable element (dimension or entity)."""
if TimeDimensionReference(element_name=element_name) in self._dimension_ref_to_spec:
return self._dimension_ref_to_spec[TimeDimensionReference(element_name=element_name)]
elif DimensionReference(element_name=element_name) in self._dimension_ref_to_spec:
Expand Down
11 changes: 6 additions & 5 deletions metricflow/plan_conversion/instance_converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,9 @@
MetricInstance,
TimeDimensionInstance,
)
from metricflow.model.semantics.metric_lookup import MetricLookup
from metricflow.model.semantics.semantic_model_lookup import SemanticModelLookup
from metricflow.plan_conversion.select_column_gen import SelectColumnSet
from metricflow.protocols.semantics import MetricAccessor, SemanticModelAccessor
from metricflow.specs.column_assoc import ColumnAssociationResolver
from metricflow.specs.specs import (
DimensionSpec,
Expand Down Expand Up @@ -181,7 +182,7 @@ def __init__( # noqa: D107
self,
table_alias: str,
column_resolver: ColumnAssociationResolver,
semantic_model_lookup: SemanticModelAccessor,
semantic_model_lookup: SemanticModelLookup,
metric_input_measure_specs: Sequence[MetricInputMeasureSpec],
) -> None:
self._semantic_model_lookup = semantic_model_lookup
Expand Down Expand Up @@ -292,8 +293,8 @@ class CreateValidityWindowJoinDescription(InstanceSetTransform[Optional[Validity
an SCD source, and extracting validity window information accordingly.
"""

def __init__(self, semantic_model_lookup: SemanticModelAccessor) -> None:
"""Initializer. The SemanticModelAccessor is needed for getting the original model definition."""
def __init__(self, semantic_model_lookup: SemanticModelLookup) -> None:
"""Initializer. The SemanticModelLookup is needed for getting the original model definition."""
self._semantic_model_lookup = semantic_model_lookup

def _get_validity_window_dimensions_for_semantic_model(
Expand Down Expand Up @@ -838,7 +839,7 @@ def __init__( # noqa: D107
self,
table_alias: str,
column_resolver: ColumnAssociationResolver,
metric_lookup: MetricAccessor,
metric_lookup: MetricLookup,
) -> None:
self._table_alias = table_alias
self._column_resolver = column_resolver
Expand Down
4 changes: 2 additions & 2 deletions metricflow/plan_conversion/node_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from metricflow.filters.time_constraint import TimeRangeConstraint
from metricflow.mf_logging.pretty_print import mf_pformat
from metricflow.model.semantics.semantic_model_join_evaluator import MAX_JOIN_HOPS, SemanticModelJoinEvaluator
from metricflow.protocols.semantics import SemanticModelAccessor
from metricflow.model.semantics.semantic_model_lookup import SemanticModelLookup
from metricflow.specs.spec_set_transforms import ToElementNameSet
from metricflow.specs.specs import InstanceSpecSet, LinkableInstanceSpec, LinklessEntitySpec
from metricflow.sql.sql_plan import SqlJoinType
Expand Down Expand Up @@ -78,7 +78,7 @@ class PreJoinNodeProcessor:

def __init__( # noqa: D107
self,
semantic_model_lookup: SemanticModelAccessor,
semantic_model_lookup: SemanticModelLookup,
node_data_set_resolver: DataflowPlanNodeOutputDataSetResolver,
):
self._node_data_set_resolver = node_data_set_resolver
Expand Down
Loading
Loading