
Commit 7b1b227
WIP
courtneyholcomb committed Sep 6, 2024
1 parent b2233f0 commit 7b1b227
Showing 48 changed files with 3,358 additions and 632 deletions.
metricflow_semantics/specs/time_dimension_spec.py

@@ -154,6 +154,16 @@ def with_grain(self, time_granularity: ExpandedTimeGranularity) -> TimeDimension
aggregation_state=self.aggregation_state,
)

+    @property
+    def with_base_grain(self) -> TimeDimensionSpec:  # noqa: D102
+        return TimeDimensionSpec(
+            element_name=self.element_name,
+            entity_links=self.entity_links,
+            time_granularity=ExpandedTimeGranularity.from_time_granularity(self.time_granularity.base_granularity),
+            date_part=self.date_part,
+            aggregation_state=self.aggregation_state,
+        )
+
def with_grain_and_date_part( # noqa: D102
self, time_granularity: ExpandedTimeGranularity, date_part: Optional[DatePart]
) -> TimeDimensionSpec:
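For context on the new property: `with_base_grain` lowers a custom-grain spec (for example a time-spine-defined `fiscal_quarter`) to the standard grain it is built on. A minimal runnable sketch of the idea, using simplified stand-ins for the real MetricFlow types (the toy classes and grain names here are illustrative, not the actual implementations):

```python
from dataclasses import dataclass, replace
from typing import Tuple

@dataclass(frozen=True)
class ExpandedTimeGranularity:
    name: str              # e.g. "fiscal_quarter" (custom) or "day" (standard)
    base_granularity: str  # the standard grain a custom grain is built on

    @classmethod
    def from_time_granularity(cls, granularity: str) -> "ExpandedTimeGranularity":
        return cls(name=granularity, base_granularity=granularity)  # a standard grain is its own base

    @property
    def is_custom_granularity(self) -> bool:
        return self.name != self.base_granularity

@dataclass(frozen=True)
class TimeDimensionSpec:
    element_name: str
    entity_links: Tuple[str, ...]
    time_granularity: ExpandedTimeGranularity

    @property
    def with_base_grain(self) -> "TimeDimensionSpec":
        # Same dimension, re-expressed at the standard grain underlying the custom one.
        return replace(
            self,
            time_granularity=ExpandedTimeGranularity.from_time_granularity(
                self.time_granularity.base_granularity
            ),
        )

spec = TimeDimensionSpec(
    element_name="metric_time",
    entity_links=(),
    time_granularity=ExpandedTimeGranularity(name="fiscal_quarter", base_granularity="day"),
)
assert spec.time_granularity.is_custom_granularity
assert spec.with_base_grain.time_granularity.name == "day"
assert not spec.with_base_grain.time_granularity.is_custom_granularity
```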
36 changes: 32 additions & 4 deletions metricflow/dataflow/builder/dataflow_plan_builder.py
@@ -797,8 +797,13 @@ def _build_plan_for_distinct_values(

for time_dimension_spec in required_linkable_specs.time_dimension_specs:
if time_dimension_spec.time_granularity.is_custom_granularity:
+                include_base_time_dimension_column = (
+                    time_dimension_spec.with_base_grain in required_linkable_specs.time_dimension_specs
+                )
                output_node = JoinToCustomGranularityNode.create(
-                    parent_node=output_node, time_dimension_spec=time_dimension_spec
+                    parent_node=output_node,
+                    time_dimension_spec=time_dimension_spec,
+                    include_base_time_dimension_column=include_base_time_dimension_column,
)

if len(query_level_filter_specs) > 0:
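The flag computed above encodes a simple rule: the base-grain column is carried through the custom-granularity join only when the query also requested the base grain itself. A standalone sketch of that rule (toy tuples stand in for `TimeDimensionSpec`; the grain names are illustrative):

```python
# Toy model: (element_name, grain) tuples stand in for TimeDimensionSpec, and
# BASE_GRAIN maps each custom grain to its standard base grain.
BASE_GRAIN = {"martian_day": "day", "fiscal_quarter": "day"}

def with_base_grain(spec):
    name, grain = spec
    return (name, BASE_GRAIN.get(grain, grain))

requested = {("metric_time", "martian_day"), ("metric_time", "day")}
for spec in sorted(requested):
    _, grain = spec
    if grain in BASE_GRAIN:  # custom grain -> needs a time-spine join
        include_base = with_base_grain(spec) in requested
        print(spec, "-> include base column:", include_base)  # True: 'day' was also requested
```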
@@ -891,11 +896,26 @@ def _select_source_nodes_with_linkable_specs(
"""Find source nodes with requested linkable specs and no measures."""
# Use a dictionary to dedupe for consistent ordering.
selected_nodes: Dict[DataflowPlanNode, None] = {}
-        requested_linkable_specs_set = set(linkable_specs.as_tuple)

+        # Find the source node that will satisfy the base granularity. Custom granularities will be joined in later.
+        linkable_specs_set_with_base_granularities: Set[LinkableInstanceSpec] = set()
+        for linkable_spec in linkable_specs.as_tuple:
+            if isinstance(linkable_spec, TimeDimensionSpec) and linkable_spec.time_granularity.is_custom_granularity:
+                # TODO: if you can satisfy the spec directly, do that. If not, try base grain.
+                # This would enable querying custom metric_time grains without metrics, without a join.
+                linkable_spec_with_base_grain = linkable_spec.with_grain(
+                    ExpandedTimeGranularity.from_time_granularity(linkable_spec.time_granularity.base_granularity)
+                )
+                linkable_specs_set_with_base_granularities.add(linkable_spec_with_base_grain)
+            else:
+                linkable_specs_set_with_base_granularities.add(linkable_spec)
+
for source_node in source_nodes:
output_spec_set = self._node_data_set_resolver.get_output_data_set(source_node).instance_set.spec_set
all_linkable_specs_in_node = set(output_spec_set.linkable_specs)
-            requested_linkable_specs_in_node = requested_linkable_specs_set.intersection(all_linkable_specs_in_node)
+            requested_linkable_specs_in_node = linkable_specs_set_with_base_granularities.intersection(
+                all_linkable_specs_in_node
+            )
if requested_linkable_specs_in_node:
selected_nodes[source_node] = None

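Sketch of the substitution above: a source node can never emit a custom grain directly, so node selection matches on the lowered (base-grain) spec and relies on a later `JoinToCustomGranularityNode` to attach the custom grain. Toy `(name, grain)` tuples stand in for `LinkableInstanceSpec`; names are illustrative:

```python
BASE = {"martian_week": "day"}  # custom grain -> its standard base grain

def to_base(spec):
    name, grain = spec
    return (name, BASE.get(grain, grain))

requested = {("metric_time", "martian_week"), ("user__country", None)}
node_outputs = {("metric_time", "day"), ("user__country", None)}

requested_at_base = {to_base(spec) for spec in requested}
print(requested_at_base & node_outputs)  # both specs match once lowered to 'day'
```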
@@ -1016,10 +1036,12 @@ def _find_dataflow_recipe(
metric_time_dimension_reference=self._metric_time_dimension_reference,
time_spine_nodes=self._source_node_set.time_spine_nodes_tuple,
)

logger.info(
f"After removing unnecessary nodes, there are {len(candidate_nodes_for_right_side_of_join)} candidate "
f"nodes for the right side of the join"
)
+        # TODO: test multi-hop to ensure we handle custom grains correctly
if DataflowPlanBuilder._contains_multihop_linkables(linkable_specs):
candidate_nodes_for_right_side_of_join = list(
node_processor.add_multi_hop_joins(
@@ -1080,6 +1102,7 @@
)

start_time = time.time()
+            # Probably happening here
evaluation = node_evaluator.evaluate_node(
left_node=node,
required_linkable_specs=list(linkable_specs),
@@ -1542,8 +1565,13 @@ def _build_aggregated_measure_from_measure_source_node(

for time_dimension_spec in queried_linkable_specs.time_dimension_specs:
if time_dimension_spec.time_granularity.is_custom_granularity:
+                include_base_time_dimension_column = (
+                    time_dimension_spec.with_base_grain in required_linkable_specs.time_dimension_specs
+                )
                unaggregated_measure_node = JoinToCustomGranularityNode.create(
-                    parent_node=unaggregated_measure_node, time_dimension_spec=time_dimension_spec
+                    parent_node=unaggregated_measure_node,
+                    time_dimension_spec=time_dimension_spec,
+                    include_base_time_dimension_column=include_base_time_dimension_column,
)

# If time constraint was previously adjusted for cumulative window or grain, apply original time constraint
10 changes: 7 additions & 3 deletions metricflow/dataflow/builder/node_evaluator.py
@@ -29,6 +29,7 @@
from metricflow_semantics.specs.entity_spec import LinklessEntitySpec
from metricflow_semantics.specs.instance_spec import LinkableInstanceSpec
from metricflow_semantics.specs.spec_set import group_specs_by_type
+from metricflow_semantics.specs.time_dimension_spec import TimeDimensionSpec
from metricflow_semantics.sql.sql_join_type import SqlJoinType

from metricflow.dataflow.builder.node_data_set import DataflowPlanNodeOutputDataSetResolver
@@ -406,19 +407,22 @@ def evaluate_node(
logger.debug(f"Candidate spec set is:\n{mf_pformat(candidate_spec_set)}")

data_set_linkable_specs = candidate_spec_set.linkable_specs
+        # Look for which nodes can satisfy the linkable specs at their base grains. Custom grains will be joined later.
+        required_linkable_specs_with_base_grains = [
+            spec.with_base_grain if isinstance(spec, TimeDimensionSpec) else spec for spec in required_linkable_specs
+        ]

# These are linkable specs in the start node data set. Those are considered "local".
local_linkable_specs: List[LinkableInstanceSpec] = []

# These are linkable specs that aren't in the data set, but they might be able to be joined in.
possibly_joinable_linkable_specs: List[LinkableInstanceSpec] = []

# Group required_linkable_specs into local / un-joinable / or possibly joinable.
unjoinable_linkable_specs = []
-        for required_linkable_spec in required_linkable_specs:
+        for required_linkable_spec in required_linkable_specs_with_base_grains:
is_metric_time = required_linkable_spec.element_name == DataSet.metric_time_dimension_name()
is_local = required_linkable_spec in data_set_linkable_specs
-            is_unjoinable = not is_metric_time and (
+            is_unjoinable = (not is_metric_time) and (
len(required_linkable_spec.entity_links) == 0
or LinklessEntitySpec.from_reference(required_linkable_spec.entity_links[0])
not in data_set_linkable_specs
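The same base-grain lowering feeds the triage in `evaluate_node`: each required spec is classified as local (already in the start node's data set), unjoinable (no usable entity path), or possibly joinable. A runnable sketch of that classification under toy types (names and the helper are illustrative, not the real API):

```python
from dataclasses import dataclass
from typing import FrozenSet, Tuple

@dataclass(frozen=True)
class Spec:
    element_name: str
    entity_links: Tuple[str, ...] = ()

def classify(spec: Spec, local_specs: FrozenSet[Spec], local_entities: FrozenSet[str]) -> str:
    is_metric_time = spec.element_name == "metric_time"
    if spec in local_specs:
        return "local"
    # Mirrors `is_unjoinable` above: without an entity link whose first hop
    # exists locally, there is nothing to join on (metric_time is exempt).
    if not is_metric_time and (not spec.entity_links or spec.entity_links[0] not in local_entities):
        return "unjoinable"
    return "possibly_joinable"

print(classify(Spec("country", ("user",)), frozenset(), frozenset({"user"})))  # possibly_joinable
print(classify(Spec("country", ("user",)), frozenset(), frozenset()))          # unjoinable
```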
20 changes: 16 additions & 4 deletions metricflow/dataflow/nodes/join_to_custom_granularity.py
@@ -17,6 +17,7 @@ class JoinToCustomGranularityNode(DataflowPlanNode, ABC):
"""Join parent dataset to time spine dataset to convert time dimension to a custom granularity."""

time_dimension_spec: TimeDimensionSpec
+    include_base_time_dimension_column: bool

def __post_init__(self) -> None: # noqa: D105
assert (
@@ -26,9 +27,13 @@ def __post_init__(self) -> None:  # noqa: D105

@staticmethod
def create( # noqa: D102
-        parent_node: DataflowPlanNode, time_dimension_spec: TimeDimensionSpec
+        parent_node: DataflowPlanNode, time_dimension_spec: TimeDimensionSpec, include_base_time_dimension_column: bool
) -> JoinToCustomGranularityNode:
-        return JoinToCustomGranularityNode(parent_nodes=(parent_node,), time_dimension_spec=time_dimension_spec)
+        return JoinToCustomGranularityNode(
+            parent_nodes=(parent_node,),
+            time_dimension_spec=time_dimension_spec,
+            include_base_time_dimension_column=include_base_time_dimension_column,
+        )

@classmethod
def id_prefix(cls) -> IdPrefix: # noqa: D102
@@ -45,19 +50,26 @@ def description(self) -> str:  # noqa: D102
def displayed_properties(self) -> Sequence[DisplayedProperty]: # noqa: D102
return tuple(super().displayed_properties) + (
DisplayedProperty("time_dimension_spec", self.time_dimension_spec),
DisplayedProperty("include_base_time_dimension_column", self.include_base_time_dimension_column),
)

@property
def parent_node(self) -> DataflowPlanNode: # noqa: D102
return self.parent_nodes[0]

def functionally_identical(self, other_node: DataflowPlanNode) -> bool: # noqa: D102
-        return isinstance(other_node, self.__class__) and other_node.time_dimension_spec == self.time_dimension_spec
+        return (
+            isinstance(other_node, self.__class__)
+            and other_node.time_dimension_spec == self.time_dimension_spec
+            and other_node.include_base_time_dimension_column == self.include_base_time_dimension_column
+        )

def with_new_parents( # noqa: D102
self, new_parent_nodes: Sequence[DataflowPlanNode]
) -> JoinToCustomGranularityNode:
assert len(new_parent_nodes) == 1, "JoinToCustomGranularity accepts exactly one parent node."
return JoinToCustomGranularityNode.create(
-            parent_node=new_parent_nodes[0], time_dimension_spec=self.time_dimension_spec
+            parent_node=new_parent_nodes[0],
+            time_dimension_spec=self.time_dimension_spec,
+            include_base_time_dimension_column=self.include_base_time_dimension_column,
)
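Because the flag now participates in `functionally_identical`, two joins that differ only in whether they expose the base column can no longer be deduplicated during plan optimization. A toy re-implementation showing why that matters (simplified types, not the real node class):

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class ToyJoinToCustomGranularityNode:
    time_dimension_spec: str  # simplified to a string for the sketch
    include_base_time_dimension_column: bool

    def functionally_identical(self, other: "ToyJoinToCustomGranularityNode") -> bool:
        return (
            isinstance(other, self.__class__)
            and other.time_dimension_spec == self.time_dimension_spec
            and other.include_base_time_dimension_column == self.include_base_time_dimension_column
        )

a = ToyJoinToCustomGranularityNode("metric_time__martian_week", True)
b = ToyJoinToCustomGranularityNode("metric_time__martian_week", False)
assert not a.functionally_identical(b)  # deduping these would silently drop a column
```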
@@ -462,7 +462,8 @@ def visit_join_conversion_events_node(self, node: JoinConversionEventsNode) -> O
def visit_join_to_custom_granularity_node( # noqa: D102
self, node: JoinToCustomGranularityNode
) -> OptimizeBranchResult:
-        raise NotImplementedError
+        self._log_visit_node_type(node)
+        return self._default_handler(node)

def visit_join_on_entities_node(self, node: JoinOnEntitiesNode) -> OptimizeBranchResult:
"""Handles pushdown state propagation for the standard join node type.
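The change above means the pushdown optimizer no longer rejects `JoinToCustomGranularityNode` outright; it logs the visit and falls back to generic single-parent handling, which optimizes the parent branch and rebuilds the node on top. A toy visitor sketch of that pattern (the real optimizer also threads predicate-pushdown state; the classes here are illustrative):

```python
class ToyNode:
    def __init__(self, name, parent=None):
        self.name, self.parent = name, parent

class ToyOptimizer:
    def _default_handler(self, node: ToyNode) -> ToyNode:
        # Recurse into the parent branch, then rebuild the node unchanged.
        optimized_parent = self.visit(node.parent) if node.parent else None
        return ToyNode(node.name, optimized_parent)

    def visit(self, node: ToyNode) -> ToyNode:
        print(f"visiting {node.name}")  # stands in for _log_visit_node_type
        return self._default_handler(node)

plan = ToyNode("join_to_custom_granularity", ToyNode("read_source"))
ToyOptimizer().visit(plan)
```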
72 changes: 46 additions & 26 deletions metricflow/dataset/convert_semantic_model.py
@@ -105,15 +105,15 @@ def _create_time_dimension_instance(
self,
element_name: str,
entity_links: Tuple[EntityReference, ...],
-        time_granularity: TimeGranularity,
+        time_granularity: ExpandedTimeGranularity,
date_part: Optional[DatePart] = None,
semantic_model_name: Optional[str] = None,
) -> TimeDimensionInstance:
"""Create a time dimension instance from the dimension object from a semantic model in the model."""
time_dimension_spec = TimeDimensionSpec(
element_name=element_name,
entity_links=entity_links,
-            time_granularity=ExpandedTimeGranularity.from_time_granularity(time_granularity),
+            time_granularity=time_granularity,
date_part=date_part,
)

@@ -289,7 +289,7 @@ def _convert_time_dimension(
semantic_model_name=semantic_model_name,
element_name=dimension.reference.element_name,
entity_links=entity_links,
-            time_granularity=defined_time_granularity,
+            time_granularity=ExpandedTimeGranularity.from_time_granularity(defined_time_granularity),
)
time_dimension_instances.append(time_dimension_instance)

@@ -305,7 +305,7 @@
)
else:
select_columns.append(
-                self._build_column_for_time_granularity(
+                self._build_column_for_standard_time_granularity(
time_granularity=defined_time_granularity,
expr=dimension_select_expr,
column_alias=time_dimension_instance.associated_column.column_name,
@@ -340,12 +340,12 @@ def _build_time_dimension_instances_and_columns(
semantic_model_name=semantic_model_name,
element_name=element_name,
entity_links=entity_links,
-            time_granularity=time_granularity,
+            time_granularity=ExpandedTimeGranularity.from_time_granularity(time_granularity),
)
time_dimension_instances.append(time_dimension_instance)

select_columns.append(
-                self._build_column_for_time_granularity(
+                self._build_column_for_standard_time_granularity(
time_granularity=time_granularity,
expr=dimension_select_expr,
column_alias=time_dimension_instance.associated_column.column_name,
@@ -359,7 +359,7 @@
semantic_model_name=semantic_model_name,
element_name=element_name,
entity_links=entity_links,
-            time_granularity=defined_time_granularity,
+            time_granularity=ExpandedTimeGranularity.from_time_granularity(defined_time_granularity),
date_part=date_part,
)
time_dimension_instances.append(time_dimension_instance)
@@ -373,7 +373,7 @@

return (time_dimension_instances, select_columns)

-    def _build_column_for_time_granularity(
+    def _build_column_for_standard_time_granularity(
self, time_granularity: TimeGranularity, expr: SqlExpressionNode, column_alias: str
) -> SqlSelectColumn:
return SqlSelectColumn(
@@ -517,35 +517,55 @@ def create_sql_source_data_set(self, semantic_model: SemanticModel) -> SemanticM
def build_time_spine_source_data_set(self, time_spine_source: TimeSpineSource) -> SqlDataSet:
"""Build data set for time spine."""
from_source_alias = SequentialIdGenerator.create_next_id(StaticIdPrefix.TIME_SPINE_SOURCE).str_value
-        defined_time_granularity = time_spine_source.base_granularity
+        base_granularity = time_spine_source.base_granularity
time_column_name = time_spine_source.base_column

time_dimension_instances: List[TimeDimensionInstance] = []
select_columns: List[SqlSelectColumn] = []

-        time_dimension_instance = self._create_time_dimension_instance(
-            element_name=time_column_name, entity_links=(), time_granularity=defined_time_granularity
+        # Build base time dimension instances & columns
+        base_time_dimension_instance = self._create_time_dimension_instance(
+            element_name=time_column_name,
+            entity_links=(),
+            time_granularity=ExpandedTimeGranularity.from_time_granularity(base_granularity),
        )
-        time_dimension_instances.append(time_dimension_instance)
-
-        dimension_select_expr = SemanticModelToDataSetConverter._make_element_sql_expr(
+        time_dimension_instances.append(base_time_dimension_instance)
+        base_dimension_select_expr = SemanticModelToDataSetConverter._make_element_sql_expr(
table_alias=from_source_alias, element_name=time_column_name
)
-        select_column = self._build_column_for_time_granularity(
-            time_granularity=defined_time_granularity,
-            expr=dimension_select_expr,
-            column_alias=time_dimension_instance.associated_column.column_name,
+        base_select_column = self._build_column_for_standard_time_granularity(
+            time_granularity=base_granularity,
+            expr=base_dimension_select_expr,
+            column_alias=base_time_dimension_instance.associated_column.column_name,
)
-        select_columns.append(select_column)
-
-        new_instances, new_columns = self._build_time_dimension_instances_and_columns(
-            defined_time_granularity=defined_time_granularity,
-            element_name=time_column_name,
+        select_columns.append(base_select_column)
+        new_base_instances, new_base_columns = self._build_time_dimension_instances_and_columns(
+            defined_time_granularity=base_granularity,
+            element_name=time_column_name,  # is this right? should it be metric time instead?
            entity_links=(),
-            dimension_select_expr=dimension_select_expr,
+            dimension_select_expr=base_dimension_select_expr,
)
-        time_dimension_instances.extend(new_instances)
-        select_columns.extend(new_columns)
+        time_dimension_instances.extend(new_base_instances)
+        select_columns.extend(new_base_columns)

+        # Build custom granularity time dimension instances & columns
+        for custom_granularity in time_spine_source.custom_granularities:
+            custom_time_dimension_instance = self._create_time_dimension_instance(
+                element_name=time_column_name,  # is this right? should it be metric time instead? or gran name instead?
+                entity_links=(),
+                time_granularity=ExpandedTimeGranularity(
+                    name=custom_granularity.name, base_granularity=base_granularity
+                ),
+            )
+            time_dimension_instances.append(custom_time_dimension_instance)
+            custom_select_column = SqlSelectColumn(
+                expr=SemanticModelToDataSetConverter._make_element_sql_expr(
+                    table_alias=from_source_alias,
+                    element_name=custom_granularity.column_name or custom_granularity.name,
+                ),
+                column_alias=custom_time_dimension_instance.associated_column.column_name,
+            )
+            select_columns.append(custom_select_column)

return SqlDataSet(
instance_set=InstanceSet(time_dimension_instances=tuple(time_dimension_instances)),
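With this change, the converted time-spine data set exposes the base column plus one column per configured custom granularity, with the physical column name falling back to the granularity name (mirroring the `or` above). A sketch of the resulting SELECT list; table, column, and granularity names here are illustrative, not the real generated aliases:

```python
base_column = "ds"  # time_spine_source.base_column
custom_granularities = [
    {"name": "martian_week", "column_name": None},   # physical column defaults to the name
    {"name": "fiscal_quarter", "column_name": "fq"},
]

select_columns = [f"time_spine_src.{base_column} AS ds__day"]
for custom in custom_granularities:
    physical_column = custom["column_name"] or custom["name"]  # mirrors the `or` fallback above
    select_columns.append(f"time_spine_src.{physical_column} AS ds__{custom['name']}")

print(",\n".join(select_columns))
```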