Skip to content

Commit

Permalink
Added SqlWindowFunctionExpression and MetadataSpec/Instance (#355)
Browse files Browse the repository at this point in the history
* replace SqlFunctionExpression with SqlAggregateFunctionExpression

* added SqlWindowFunctionExpression

* added ExtraSpec/Instance

* added AppendRowNumberColumnNode

* added test for AppendRowNumberColumnNode

* updated snapshots

* address comments

* revert AppendRowNumberColumnNode
  • Loading branch information
WilliamDee authored Nov 30, 2022
1 parent b836679 commit ca2385c
Show file tree
Hide file tree
Showing 49 changed files with 609 additions and 251 deletions.
1 change: 1 addition & 0 deletions metricflow/dag/id_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
SQL_EXPR_DATE_TRUNC = "dt"
SQL_EXPR_RATIO_COMPUTATION = "rc"
SQL_EXPR_BETWEEN_PREFIX = "betw"
SQL_EXPR_WINDOW_FUNCTION_ID_PREFIX = "wfnc"

SQL_PLAN_SELECT_STATEMENT_ID_PREFIX = "ss"
SQL_PLAN_TABLE_FROM_CLAUSE_ID_PREFIX = "tfc"
Expand Down
14 changes: 14 additions & 0 deletions metricflow/instances.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from metricflow.dataclass_serialization import SerializableDataclass
from metricflow.references import ElementReference
from metricflow.specs import (
MetadataSpec,
MeasureSpec,
DimensionSpec,
IdentifierSpec,
Expand Down Expand Up @@ -171,6 +172,12 @@ class MetricInstance(MdoInstance[MetricSpec], SerializableDataclass): # noqa: D
defined_from: Tuple[MetricModelReference, ...]


@dataclass(frozen=True)
class MetadataInstance(MdoInstance[MetadataSpec], SerializableDataclass): # noqa: D
    """Immutable instance that pairs a MetadataSpec with its associated columns.

    NOTE(review): what qualifies as "metadata" is determined by MetadataSpec, which is defined outside this
    view - confirm the intended semantics there.
    """

    # Column association(s) tied to this instance.
    associated_columns: Tuple[ColumnAssociation, ...]
    # The spec that this instance corresponds to.
    spec: MetadataSpec


# Output type of transform function
TransformOutputT = TypeVar("TransformOutputT")

Expand Down Expand Up @@ -200,6 +207,7 @@ class InstanceSet(SerializableDataclass):
time_dimension_instances: Tuple[TimeDimensionInstance, ...] = ()
identifier_instances: Tuple[IdentifierInstance, ...] = ()
metric_instances: Tuple[MetricInstance, ...] = ()
metadata_instances: Tuple[MetadataInstance, ...] = ()

def transform(self, transform_function: InstanceSetTransform[TransformOutputT]) -> TransformOutputT: # noqa: D
return transform_function.transform(self)
Expand All @@ -215,6 +223,7 @@ def merge(instance_sets: List[InstanceSet]) -> InstanceSet:
time_dimension_instances: List[TimeDimensionInstance] = []
identifier_instances: List[IdentifierInstance] = []
metric_instances: List[MetricInstance] = []
metadata_instances: List[MetadataInstance] = []

for instance_set in instance_sets:
for measure_instance in instance_set.measure_instances:
Expand All @@ -232,13 +241,17 @@ def merge(instance_sets: List[InstanceSet]) -> InstanceSet:
for metric_instance in instance_set.metric_instances:
if metric_instance.spec not in {x.spec for x in metric_instances}:
metric_instances.append(metric_instance)
for metadata_instance in instance_set.metadata_instances:
if metadata_instance.spec not in {x.spec for x in metadata_instances}:
metadata_instances.append(metadata_instance)

return InstanceSet(
measure_instances=tuple(measure_instances),
dimension_instances=tuple(dimension_instances),
time_dimension_instances=tuple(time_dimension_instances),
identifier_instances=tuple(identifier_instances),
metric_instances=tuple(metric_instances),
metadata_instances=tuple(metadata_instances),
)

@property
Expand All @@ -249,4 +262,5 @@ def spec_set(self) -> InstanceSpecSet: # noqa: D
time_dimension_specs=tuple(x.spec for x in self.time_dimension_instances),
identifier_specs=tuple(x.spec for x in self.identifier_instances),
metric_specs=tuple(x.spec for x in self.metric_instances),
metadata_specs=tuple(x.spec for x in self.metadata_instances),
)
7 changes: 7 additions & 0 deletions metricflow/plan_conversion/column_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
)
from metricflow.naming.linkable_spec_name import StructuredLinkableSpecName
from metricflow.specs import (
MetadataSpec,
MetricSpec,
MeasureSpec,
DimensionSpec,
Expand Down Expand Up @@ -109,3 +110,9 @@ def resolve_identifier_spec(self, identifier_spec: IdentifierSpec) -> Tuple[Colu
single_column_correlation_key=SingleColumnCorrelationKey(),
),
)

def resolve_metadata_spec(self, metadata_spec: MetadataSpec) -> ColumnAssociation:  # noqa: D
    """Return the column association for the given metadata spec.

    The column is named after the spec's element name and is paired with a single-column correlation key.
    """
    metadata_column_name = metadata_spec.element_name
    return ColumnAssociation(
        column_name=metadata_column_name,
        single_column_correlation_key=SingleColumnCorrelationKey(),
    )
4 changes: 2 additions & 2 deletions metricflow/plan_conversion/dataflow_to_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@
SqlDateTruncExpression,
SqlStringLiteralExpression,
SqlBetweenExpression,
SqlFunctionExpression,
SqlAggregateFunctionExpression,
)
from metricflow.sql.sql_plan import (
SqlQueryPlan,
Expand Down Expand Up @@ -1168,7 +1168,7 @@ def visit_semi_additive_join_node(self, node: SemiAdditiveJoinNode) -> SqlDataSe
aggregation_state=AggregationState.COMPLETE,
).column_name
time_dimension_select_column = SqlSelectColumn(
expr=SqlFunctionExpression.from_aggregation_type(
expr=SqlAggregateFunctionExpression.from_aggregation_type(
node.agg_by_function,
SqlColumnReferenceExpression(
SqlColumnReference(
Expand Down
76 changes: 74 additions & 2 deletions metricflow/plan_conversion/instance_converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
MdoInstance,
DimensionInstance,
IdentifierInstance,
MetadataInstance,
MetricInstance,
MeasureInstance,
InstanceSet,
Expand All @@ -39,7 +40,7 @@
from metricflow.sql.sql_exprs import (
SqlColumnReferenceExpression,
SqlColumnReference,
SqlFunctionExpression,
SqlAggregateFunctionExpression,
)
from metricflow.sql.sql_plan import SqlSelectColumn
from metricflow.time.time_granularity import TimeGranularity
Expand Down Expand Up @@ -87,12 +88,16 @@ def transform(self, instance_set: InstanceSet) -> SelectColumnSet: # noqa: D
identifier_cols = list(
chain.from_iterable([self._make_sql_column_expression(x) for x in instance_set.identifier_instances])
)
metadata_cols = list(
chain.from_iterable([self._make_sql_column_expression(x) for x in instance_set.metadata_instances])
)
return SelectColumnSet(
metric_columns=metric_cols,
measure_columns=measure_cols,
dimension_columns=dimension_cols,
time_dimension_columns=time_dimension_cols,
identifier_columns=identifier_cols,
metadata_columns=metadata_cols,
)

def _make_sql_column_expression(
Expand Down Expand Up @@ -205,7 +210,7 @@ def _make_sql_column_expression_to_aggregate_measure( # noqa: D
SqlColumnReference(self._table_alias, column_name_in_table)
)

expression_to_aggregate_measure = SqlFunctionExpression.from_aggregation_type(
expression_to_aggregate_measure = SqlAggregateFunctionExpression.from_aggregation_type(
aggregation_type=aggregation_type, sql_column_expression=expression_to_get_measure
)

Expand Down Expand Up @@ -234,12 +239,16 @@ def transform(self, instance_set: InstanceSet) -> SelectColumnSet: # noqa: D
identifier_cols = list(
chain.from_iterable([self._make_sql_column_expression(x) for x in instance_set.identifier_instances])
)
metadata_cols = list(
chain.from_iterable([self._make_sql_column_expression(x) for x in instance_set.metadata_instances])
)
return SelectColumnSet(
metric_columns=metric_cols,
measure_columns=measure_cols,
dimension_columns=dimension_cols,
time_dimension_columns=time_dimension_cols,
identifier_columns=identifier_cols,
metadata_columns=metadata_cols,
)


Expand Down Expand Up @@ -436,6 +445,7 @@ def transform(self, instance_set: InstanceSet) -> InstanceSet: # noqa: D
time_dimension_instances=tuple(time_dimension_instances_with_additional_link),
identifier_instances=tuple(identifier_instances_with_additional_link),
metric_instances=(),
metadata_instances=(),
)


Expand Down Expand Up @@ -477,6 +487,7 @@ def transform(self, instance_set: InstanceSet) -> InstanceSet: # noqa: D
time_dimension_instances=filtered_time_dimension_instances,
identifier_instances=filtered_identifier_instances,
metric_instances=instance_set.metric_instances,
metadata_instances=instance_set.metadata_instances,
)
return output

Expand Down Expand Up @@ -535,6 +546,7 @@ def transform(self, instance_set: InstanceSet) -> InstanceSet: # noqa: D
),
identifier_instances=tuple(x for x in instance_set.identifier_instances if self._should_pass(x.spec)),
metric_instances=tuple(x for x in instance_set.metric_instances if self._should_pass(x.spec)),
metadata_instances=tuple(x for x in instance_set.metadata_instances if self._should_pass(x.spec)),
)
return output

Expand Down Expand Up @@ -573,6 +585,7 @@ def transform(self, instance_set: InstanceSet) -> InstanceSet: # noqa: D
time_dimension_instances=instance_set.time_dimension_instances,
identifier_instances=instance_set.identifier_instances,
metric_instances=instance_set.metric_instances,
metadata_instances=instance_set.metadata_instances,
)


Expand Down Expand Up @@ -626,6 +639,7 @@ def transform(self, instance_set: InstanceSet) -> InstanceSet: # noqa: D
time_dimension_instances=instance_set.time_dimension_instances,
identifier_instances=instance_set.identifier_instances,
metric_instances=instance_set.metric_instances,
metadata_instances=instance_set.metadata_instances,
)


Expand All @@ -642,6 +656,7 @@ def transform(self, instance_set: InstanceSet) -> InstanceSet: # noqa: D
time_dimension_instances=instance_set.time_dimension_instances,
identifier_instances=instance_set.identifier_instances,
metric_instances=instance_set.metric_instances + tuple(self._metric_instances),
metadata_instances=instance_set.metadata_instances,
)


Expand All @@ -655,6 +670,7 @@ def transform(self, instance_set: InstanceSet) -> InstanceSet: # noqa: D
time_dimension_instances=instance_set.time_dimension_instances,
identifier_instances=instance_set.identifier_instances,
metric_instances=instance_set.metric_instances,
metadata_instances=instance_set.metadata_instances,
)


Expand All @@ -668,6 +684,48 @@ def transform(self, instance_set: InstanceSet) -> InstanceSet: # noqa: D
time_dimension_instances=instance_set.time_dimension_instances,
identifier_instances=instance_set.identifier_instances,
metric_instances=(),
metadata_instances=instance_set.metadata_instances,
)


class CreateSqlColumnReferencesForInstances(InstanceSetTransform[Tuple[SqlColumnReferenceExpression, ...]]):
    """Build column reference expressions that cover every instance in a set.

    Column names are taken from the column associations that the supplied resolver produces for each spec in the
    set, and every reference is qualified with the given table alias.
    """

    def __init__(
        self,
        table_alias: str,
        column_resolver: ColumnAssociationResolver,
    ) -> None:
        """Initializer.

        Args:
            table_alias: alias of the table that the referenced columns are selected from.
            column_resolver: resolver used to name the columns.
        """
        self._column_resolver = column_resolver
        self._table_alias = table_alias

    def transform(self, instance_set: InstanceSet) -> Tuple[SqlColumnReferenceExpression, ...]:  # noqa: D
        """Return one column reference per column association of every spec in the instance set."""
        # Resolve the associations for all specs up front, then flatten them into a list of column names.
        associations_by_spec = [
            spec.column_associations(self._column_resolver) for spec in instance_set.spec_set.all_specs
        ]
        resolved_column_names = [
            association.column_name for associations in associations_by_spec for association in associations
        ]

        return tuple(
            SqlColumnReferenceExpression(
                SqlColumnReference(
                    table_alias=self._table_alias,
                    column_name=resolved_name,
                ),
            )
            for resolved_name in resolved_column_names
        )


Expand Down Expand Up @@ -745,12 +803,26 @@ def transform(self, instance_set: InstanceSet) -> InstanceSet: # noqa: D
)
)

output_metadata_instances = []
for input_metadata_instance in instance_set.metadata_instances:
output_metadata_instances.append(
MetadataInstance(
associated_columns=(
self._column_association_resolver.resolve_metadata_spec(
metadata_spec=input_metadata_instance.spec
),
),
spec=input_metadata_instance.spec,
)
)

return InstanceSet(
measure_instances=tuple(output_measure_instances),
dimension_instances=tuple(output_dimension_instances),
time_dimension_instances=tuple(output_time_dimension_instances),
identifier_instances=tuple(output_identifier_instances),
metric_instances=tuple(output_metric_instances),
metadata_instances=tuple(output_metadata_instances),
)


Expand Down
4 changes: 4 additions & 0 deletions metricflow/plan_conversion/select_column_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ class SelectColumnSet:
dimension_columns: List[SqlSelectColumn] = field(default_factory=list)
time_dimension_columns: List[SqlSelectColumn] = field(default_factory=list)
identifier_columns: List[SqlSelectColumn] = field(default_factory=list)
metadata_columns: List[SqlSelectColumn] = field(default_factory=list)

def merge(self, other_set: SelectColumnSet) -> SelectColumnSet:
"""Combine the select columns by type."""
Expand All @@ -27,6 +28,7 @@ def merge(self, other_set: SelectColumnSet) -> SelectColumnSet:
dimension_columns=self.dimension_columns + other_set.dimension_columns,
time_dimension_columns=self.time_dimension_columns + other_set.time_dimension_columns,
identifier_columns=self.identifier_columns + other_set.identifier_columns,
metadata_columns=self.metadata_columns + other_set.metadata_columns,
)

def as_tuple(self) -> Tuple[SqlSelectColumn, ...]:
Expand All @@ -38,6 +40,7 @@ def as_tuple(self) -> Tuple[SqlSelectColumn, ...]:
+ self.dimension_columns
+ self.metric_columns
+ self.measure_columns
+ self.metadata_columns
)

def without_measure_columns(self) -> SelectColumnSet:
Expand All @@ -47,4 +50,5 @@ def without_measure_columns(self) -> SelectColumnSet:
dimension_columns=self.dimension_columns,
time_dimension_columns=self.time_dimension_columns,
identifier_columns=self.identifier_columns,
metadata_columns=self.metadata_columns,
)
4 changes: 2 additions & 2 deletions metricflow/plan_conversion/spec_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
SqlComparison,
SqlColumnReferenceExpression,
SqlColumnReference,
SqlFunctionExpression,
SqlAggregateFunctionExpression,
SqlFunction,
)
from metricflow.sql.sql_plan import SqlSelectColumn
Expand Down Expand Up @@ -50,7 +50,7 @@ def _make_coalesced_expr(table_aliases: Sequence[str], column_alias: str) -> Sql
)
)
)
return SqlFunctionExpression(
return SqlAggregateFunctionExpression(
sql_function=SqlFunction.COALESCE,
sql_function_args=columns_to_coalesce,
)
Expand Down
Loading

0 comments on commit ca2385c

Please sign in to comment.