Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
courtneyholcomb committed Dec 19, 2024
1 parent 33fe95d commit 8f8f2ca
Show file tree
Hide file tree
Showing 52 changed files with 585 additions and 257 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,7 @@ def query(
sql_bind_parameter_set: The parameter replacement mapping for filling in concrete values for SQL query
parameters.
"""
print("sql:", stmt)
start = time.time()
request_id = SqlRequestId(f"mf_rid__{random_id()}")
if sql_bind_parameter_set.param_dict:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ def with_aggregation_state(self, aggregation_state: AggregationState) -> TimeDim
window_function=self.window_function,
)

def with_window_function(self, window_function: SqlWindowFunction) -> TimeDimensionSpec: # noqa: D102
def with_window_function(self, window_function: Optional[SqlWindowFunction]) -> TimeDimensionSpec: # noqa: D102
return TimeDimensionSpec(
element_name=self.element_name,
entity_links=self.entity_links,
Expand Down
20 changes: 18 additions & 2 deletions metricflow-semantics/metricflow_semantics/sql/sql_exprs.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,14 @@ def as_window_function_expression(self) -> Optional[SqlWindowFunctionExpression]
"""If this is a window function expression, return self."""
return None

@property
def is_verbose(self) -> bool:
"""Denotes if the statement is typically verbose, and therefore can be hard to read when optimized.
This is helpful in determining if statements will be harder to read when collapsed.
"""
return False

@abstractmethod
def rewrite(
self,
Expand Down Expand Up @@ -1017,7 +1025,7 @@ class SqlWindowFunction(Enum):
LAST_VALUE = "LAST_VALUE"
AVERAGE = "AVG"
ROW_NUMBER = "ROW_NUMBER"
LAG = "LAG"
LEAD = "LEAD"

@property
def requires_ordering(self) -> bool:
Expand All @@ -1026,7 +1034,7 @@ def requires_ordering(self) -> bool:
self is SqlWindowFunction.FIRST_VALUE
or self is SqlWindowFunction.LAST_VALUE
or self is SqlWindowFunction.ROW_NUMBER
or self is SqlWindowFunction.LAG
or self is SqlWindowFunction.LEAD
):
return True
elif self is SqlWindowFunction.AVERAGE:
Expand Down Expand Up @@ -1183,6 +1191,10 @@ def matches(self, other: SqlExpressionNode) -> bool: # noqa: D102
and self.sql_function_args == other.sql_function_args
)

@property
def is_verbose(self) -> bool: # noqa: D102
return True


@dataclass(frozen=True, eq=False)
class SqlNullExpression(SqlExpressionNode):
Expand Down Expand Up @@ -1870,6 +1882,10 @@ def matches(self, other: SqlExpressionNode) -> bool: # noqa: D102
return False
return self.when_to_then_exprs == other.when_to_then_exprs and self.else_expr == other.else_expr

@property
def is_verbose(self) -> bool: # noqa: D102
return True


class SqlArithmeticOperator(Enum):
"""Arithmetic operator used to do math in a SQL expression."""
Expand Down
6 changes: 5 additions & 1 deletion metricflow/dataflow/builder/dataflow_plan_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -1889,10 +1889,14 @@ def _build_time_spine_node(
required_time_spine_specs = required_time_spine_spec_set.time_dimension_specs

should_dedupe = False
filter_to_specs = tuple(queried_time_spine_specs)
if offset_window and self._offset_window_is_custom(offset_window):
time_spine_node = self._build_custom_offset_time_spine_node(
offset_window=offset_window, required_time_spine_specs=required_time_spine_specs
)
filter_to_specs = self._node_data_set_resolver.get_output_data_set(
time_spine_node
).instance_set.spec_set.time_dimension_specs
else:
# For simpler time spine queries, choose the appropriate time spine node and apply requested aliases.
time_spine_source = self._choose_time_spine_source(required_time_spine_specs)
Expand Down Expand Up @@ -1920,7 +1924,7 @@ def _build_time_spine_node(

return self._build_pre_aggregation_plan(
source_node=time_spine_node,
filter_to_specs=InstanceSpecSet(time_dimension_specs=tuple(queried_time_spine_specs)),
filter_to_specs=InstanceSpecSet(time_dimension_specs=filter_to_specs),
time_range_constraint=time_range_constraint,
where_filter_specs=where_filter_specs,
distinct=should_dedupe,
Expand Down
103 changes: 68 additions & 35 deletions metricflow/plan_conversion/dataflow_to_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -1447,47 +1447,74 @@ def visit_join_to_time_spine_node(self, node: JoinToTimeSpineNode) -> SqlDataSet
time_spine_alias = self._next_unique_table_alias()

required_agg_time_dimension_specs = tuple(node.requested_agg_time_dimension_specs)
if node.join_on_time_dimension_spec not in node.requested_agg_time_dimension_specs:
join_spec_was_requested = node.join_on_time_dimension_spec in node.requested_agg_time_dimension_specs
if not join_spec_was_requested:
# Why is this necessary? Really getting in the way here!! Can we remove it?
required_agg_time_dimension_specs += (node.join_on_time_dimension_spec,)

# Build join expression.
join_column_name = self._column_association_resolver.resolve_spec(node.join_on_time_dimension_spec).column_name
parent_join_column_name = self._column_association_resolver.resolve_spec(
node.join_on_time_dimension_spec
).column_name
time_spine_jon_column_name = time_spine_data_set.instance_from_time_dimension_grain_and_date_part(
time_granularity_name=node.join_on_time_dimension_spec.time_granularity.name, date_part=None
).associated_column.column_name
join_description = SqlQueryPlanJoinBuilder.make_join_to_time_spine_join_description(
node=node,
time_spine_alias=time_spine_alias,
agg_time_dimension_column_name=join_column_name,
time_spine_column_name=time_spine_jon_column_name,
parent_column_name=parent_join_column_name,
parent_sql_select_node=parent_data_set.checked_sql_select_node,
parent_alias=parent_alias,
)

# Build combined instance set.
# Build new instances and columns.
time_spine_required_spec_set = InstanceSpecSet(time_dimension_specs=required_agg_time_dimension_specs)
parent_instance_set = parent_data_set.instance_set.transform(
output_parent_instance_set = parent_data_set.instance_set.transform(
FilterElements(exclude_specs=time_spine_required_spec_set)
)
time_spine_instance_set = time_spine_data_set.instance_set.transform(
FilterElements(include_specs=time_spine_required_spec_set)
)
output_instance_set = InstanceSet.merge([parent_instance_set, time_spine_instance_set])
output_time_spine_instances: Tuple[TimeDimensionInstance, ...] = ()
output_time_spine_columns: Tuple[SqlSelectColumn, ...] = ()
for old_instance in time_spine_data_set.instance_set.time_dimension_instances:
new_spec = old_instance.spec.with_window_function(None)
if new_spec not in required_agg_time_dimension_specs:
continue
if old_instance.spec.window_function:
new_instance = old_instance.with_new_spec(
new_spec=new_spec, column_association_resolver=self._column_association_resolver
)
column = SqlSelectColumn(
expr=SqlColumnReferenceExpression.from_table_and_column_names(
table_alias=time_spine_alias, column_name=old_instance.associated_column.column_name
),
column_alias=new_instance.associated_column.column_name,
)
else:
new_instance = old_instance
column = SqlSelectColumn.from_table_and_column_names(
table_alias=time_spine_alias, column_name=old_instance.associated_column.column_name
)
output_time_spine_instances += (new_instance,)
output_time_spine_columns += (column,)

# Build new simple select columns.
select_columns = create_simple_select_columns_for_instance_sets(
output_instance_set = InstanceSet.merge(
[output_parent_instance_set, InstanceSet(time_dimension_instances=output_time_spine_instances)]
)
select_columns = output_time_spine_columns + create_simple_select_columns_for_instance_sets(
self._column_association_resolver,
OrderedDict({parent_alias: parent_instance_set, time_spine_alias: time_spine_instance_set}),
OrderedDict({parent_alias: output_parent_instance_set}),
)

# If offset_to_grain is used, will need to filter down to rows that match selected granularities.
# Does not apply if one of the granularities selected matches the time spine column granularity.
where_filter: Optional[SqlExpressionNode] = None
need_where_filter = (
node.offset_to_grain and node.join_on_time_dimension_spec not in node.requested_agg_time_dimension_specs
)
need_where_filter = node.offset_to_grain and not join_spec_was_requested

# Filter down to one row per granularity period requested in the group by. Any other granularities
# included here will be filtered out before aggregation and so should not be included in where filter.
if need_where_filter:
join_column_expr = SqlColumnReferenceExpression.from_table_and_column_names(
table_alias=time_spine_alias, column_name=join_column_name
table_alias=time_spine_alias, column_name=parent_join_column_name
)
for requested_spec in node.requested_agg_time_dimension_specs:
column_name = self._column_association_resolver.resolve_spec(requested_spec).column_name
Expand Down Expand Up @@ -2178,7 +2205,7 @@ def visit_offset_by_custom_granularity_node(self, node: OffsetByCustomGranularit
first_value_offset_column, last_value_offset_column = tuple(
SqlSelectColumn(
expr=SqlWindowFunctionExpression.create(
sql_function=SqlWindowFunction.LAG,
sql_function=SqlWindowFunction.LEAD,
sql_function_args=(
SqlColumnReferenceExpression.from_table_and_column_names(
column_name=instance.associated_column.column_name,
Expand All @@ -2203,9 +2230,6 @@ def visit_offset_by_custom_granularity_node(self, node: OffsetByCustomGranularit

# Offset the base column by the requested window. If the offset date is not within the offset custom grain period,
# default to the last value in that period.
new_custom_grain_column = SqlSelectColumn.from_table_and_column_names(
column_name=custom_grain_column_name, table_alias=bounds_data_set_alias
)
first_value_offset_expr, last_value_offset_expr = [
SqlColumnReferenceExpression.from_table_and_column_names(
column_name=offset_column.column_alias, table_alias=offset_bounds_subquery_alias
Expand All @@ -2228,12 +2252,20 @@ def visit_offset_by_custom_granularity_node(self, node: OffsetByCustomGranularit
comparison=SqlComparison.LESS_THAN_OR_EQUALS,
right_expr=last_value_offset_expr,
)
offset_base_instance = base_grain_instance.with_new_spec(
# LAG isn't quite accurate here, but this will differentiate the offset instance (and column) from the original one.
new_spec=base_grain_instance.spec.with_window_function(SqlWindowFunction.LEAD),
column_association_resolver=self._column_association_resolver,
)
offset_base_column = SqlSelectColumn(
expr=SqlCaseExpression.create(
when_to_then_exprs={is_below_last_value_expr: offset_base_grain_expr},
else_expr=last_value_offset_expr,
),
column_alias=base_grain_instance.associated_column.column_name,
column_alias=offset_base_instance.associated_column.column_name,
)
original_base_grain_column = SqlSelectColumn.from_table_and_column_names(
column_name=base_grain_instance.associated_column.column_name, table_alias=bounds_data_set_alias
)
join_desc = SqlJoinDescription(
right_source=offset_bounds_subquery,
Expand All @@ -2251,7 +2283,7 @@ def visit_offset_by_custom_granularity_node(self, node: OffsetByCustomGranularit
)
offset_base_grain_subquery = SqlSelectStatementNode.create(
description=node.description,
select_columns=(new_custom_grain_column, offset_base_column),
select_columns=(original_base_grain_column, offset_base_column),
from_source=bounds_data_set.checked_sql_select_node,
from_source_alias=bounds_data_set_alias,
join_descs=(join_desc,),
Expand All @@ -2261,41 +2293,42 @@ def visit_offset_by_custom_granularity_node(self, node: OffsetByCustomGranularit
# Apply standard grains & date parts requested in the query. Use base grain for any custom grains.
standard_grain_instances: Tuple[TimeDimensionInstance, ...] = ()
standard_grain_columns: Tuple[SqlSelectColumn, ...] = ()
base_column = SqlSelectColumn(
offset_base_column_ref = SqlSelectColumn(
expr=SqlColumnReferenceExpression.from_table_and_column_names(
column_name=base_grain_instance.associated_column.column_name,
column_name=offset_base_instance.associated_column.column_name,
table_alias=offset_base_grain_subquery_alias,
),
column_alias=base_grain_instance.associated_column.column_name,
)
base_grain_requested = False
for spec in node.required_time_spine_specs:
new_instance = base_grain_instance.with_new_spec(
new_spec=spec, column_association_resolver=self._column_association_resolver
)
standard_grain_instances += (new_instance,)
if spec.date_part:
expr: SqlExpressionNode = SqlExtractExpression.create(date_part=spec.date_part, arg=base_column.expr)
expr: SqlExpressionNode = SqlExtractExpression.create(
date_part=spec.date_part, arg=offset_base_column_ref.expr
)
elif spec.time_granularity.base_granularity == base_grain.base_granularity:
expr = base_column.expr
base_grain_requested = True
expr = offset_base_column_ref.expr
else:
expr = SqlDateTruncExpression.create(
time_granularity=spec.time_granularity.base_granularity, arg=base_column.expr
time_granularity=spec.time_granularity.base_granularity, arg=offset_base_column_ref.expr
)
standard_grain_columns += (
SqlSelectColumn(expr=expr, column_alias=new_instance.associated_column.column_name),
)
if not base_grain_requested:
assert 0
standard_grain_instances = (base_grain_instance,) + standard_grain_instances
standard_grain_columns = (base_column,) + standard_grain_columns

# Need to keep the non-offset base grain column in the output. This will be used to join to the source data set.
non_offset_base_grain_column = SqlSelectColumn.from_table_and_column_names(
column_name=base_grain_instance.associated_column.column_name, table_alias=offset_base_grain_subquery_alias
)

return SqlDataSet(
instance_set=InstanceSet(time_dimension_instances=standard_grain_instances),
instance_set=InstanceSet(time_dimension_instances=(base_grain_instance,) + standard_grain_instances),
sql_select_node=SqlSelectStatementNode.create(
description="Apply Requested Granularities",
select_columns=standard_grain_columns,
select_columns=(non_offset_base_grain_column,) + standard_grain_columns,
from_source=offset_base_grain_subquery,
from_source_alias=offset_base_grain_subquery_alias,
),
Expand Down
7 changes: 4 additions & 3 deletions metricflow/plan_conversion/sql_join_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -527,13 +527,14 @@ def make_cumulative_metric_time_range_join_description(
def make_join_to_time_spine_join_description(
node: JoinToTimeSpineNode,
time_spine_alias: str,
agg_time_dimension_column_name: str,
time_spine_column_name: str,
parent_column_name: str,
parent_sql_select_node: SqlSelectStatementNode,
parent_alias: str,
) -> SqlJoinDescription:
"""Build join expression used to join a metric to a time spine dataset."""
left_expr: SqlExpressionNode = SqlColumnReferenceExpression.create(
col_ref=SqlColumnReference(table_alias=time_spine_alias, column_name=agg_time_dimension_column_name)
col_ref=SqlColumnReference(table_alias=time_spine_alias, column_name=time_spine_column_name)
)
if node.offset_window:
left_expr = SqlSubtractTimeIntervalExpression.create(
Expand All @@ -551,7 +552,7 @@ def make_join_to_time_spine_join_description(
left_expr=left_expr,
comparison=SqlComparison.EQUALS,
right_expr=SqlColumnReferenceExpression.create(
col_ref=SqlColumnReference(table_alias=parent_alias, column_name=agg_time_dimension_column_name)
col_ref=SqlColumnReference(table_alias=parent_alias, column_name=parent_column_name)
),
),
join_type=node.join_type,
Expand Down
Loading

0 comments on commit 8f8f2ca

Please sign in to comment.