Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added SqlWindowFunctionExpression and MetadataSpec/Instance #355

Merged
merged 8 commits into from
Nov 30, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions metricflow/dag/id_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
SQL_EXPR_DATE_TRUNC = "dt"
SQL_EXPR_RATIO_COMPUTATION = "rc"
SQL_EXPR_BETWEEN_PREFIX = "betw"
SQL_EXPR_WINDOW_FUNCTION_ID_PREFIX = "wfnc"

SQL_PLAN_SELECT_STATEMENT_ID_PREFIX = "ss"
SQL_PLAN_TABLE_FROM_CLAUSE_ID_PREFIX = "tfc"
Expand Down
14 changes: 14 additions & 0 deletions metricflow/instances.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from metricflow.dataclass_serialization import SerializableDataclass
from metricflow.references import ElementReference
from metricflow.specs import (
MetadataSpec,
MeasureSpec,
DimensionSpec,
IdentifierSpec,
Expand Down Expand Up @@ -171,6 +172,12 @@ class MetricInstance(MdoInstance[MetricSpec], SerializableDataclass): # noqa: D
defined_from: Tuple[MetricModelReference, ...]


@dataclass(frozen=True)
class MetadataInstance(MdoInstance[MetadataSpec], SerializableDataclass): # noqa: D
    """Immutable instance record whose spec is a `MetadataSpec`.

    NOTE(review): presumably this models non-semantic, query-generated columns (this change also introduces SQL
    window function expressions) - confirm against the `MetadataSpec` definition.
    """

    # Column associations recorded for this instance.
    associated_columns: Tuple[ColumnAssociation, ...]
    # The metadata spec that this instance is an instance of.
    spec: MetadataSpec


# Output type of transform function
TransformOutputT = TypeVar("TransformOutputT")

Expand Down Expand Up @@ -200,6 +207,7 @@ class InstanceSet(SerializableDataclass):
time_dimension_instances: Tuple[TimeDimensionInstance, ...] = ()
identifier_instances: Tuple[IdentifierInstance, ...] = ()
metric_instances: Tuple[MetricInstance, ...] = ()
metadata_instances: Tuple[MetadataInstance, ...] = ()

def transform(self, transform_function: InstanceSetTransform[TransformOutputT]) -> TransformOutputT: # noqa: D
    """Hand this instance set to `transform_function.transform` and return whatever that call produces."""
    return transform_function.transform(self)
Expand All @@ -215,6 +223,7 @@ def merge(instance_sets: List[InstanceSet]) -> InstanceSet:
time_dimension_instances: List[TimeDimensionInstance] = []
identifier_instances: List[IdentifierInstance] = []
metric_instances: List[MetricInstance] = []
metadata_instances: List[MetadataInstance] = []

for instance_set in instance_sets:
for measure_instance in instance_set.measure_instances:
Expand All @@ -232,13 +241,17 @@ def merge(instance_sets: List[InstanceSet]) -> InstanceSet:
for metric_instance in instance_set.metric_instances:
if metric_instance.spec not in {x.spec for x in metric_instances}:
metric_instances.append(metric_instance)
for metadata_instance in instance_set.metadata_instances:
if metadata_instance.spec not in {x.spec for x in metadata_instances}:
metadata_instances.append(metadata_instance)

return InstanceSet(
measure_instances=tuple(measure_instances),
dimension_instances=tuple(dimension_instances),
time_dimension_instances=tuple(time_dimension_instances),
identifier_instances=tuple(identifier_instances),
metric_instances=tuple(metric_instances),
metadata_instances=tuple(metadata_instances),
)

@property
Expand All @@ -249,4 +262,5 @@ def spec_set(self) -> InstanceSpecSet: # noqa: D
time_dimension_specs=tuple(x.spec for x in self.time_dimension_instances),
identifier_specs=tuple(x.spec for x in self.identifier_instances),
metric_specs=tuple(x.spec for x in self.metric_instances),
metadata_specs=tuple(x.spec for x in self.metadata_instances),
)
7 changes: 7 additions & 0 deletions metricflow/plan_conversion/column_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
)
from metricflow.naming.linkable_spec_name import StructuredLinkableSpecName
from metricflow.specs import (
MetadataSpec,
MetricSpec,
MeasureSpec,
DimensionSpec,
Expand Down Expand Up @@ -109,3 +110,9 @@ def resolve_identifier_spec(self, identifier_spec: IdentifierSpec) -> Tuple[Colu
single_column_correlation_key=SingleColumnCorrelationKey(),
),
)

def resolve_metadata_spec(self, metadata_spec: MetadataSpec) -> ColumnAssociation:  # noqa: D
    """Return the column association for a metadata spec.

    The column is named exactly after the spec's element name and is keyed with a fresh
    `SingleColumnCorrelationKey`.
    """
    # Read the name first, then build the key, matching the argument evaluation order of a direct call.
    name_of_column = metadata_spec.element_name
    correlation_key = SingleColumnCorrelationKey()
    return ColumnAssociation(column_name=name_of_column, single_column_correlation_key=correlation_key)
4 changes: 2 additions & 2 deletions metricflow/plan_conversion/dataflow_to_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@
SqlDateTruncExpression,
SqlStringLiteralExpression,
SqlBetweenExpression,
SqlFunctionExpression,
SqlAggregateFunctionExpression,
)
from metricflow.sql.sql_plan import (
SqlQueryPlan,
Expand Down Expand Up @@ -1168,7 +1168,7 @@ def visit_semi_additive_join_node(self, node: SemiAdditiveJoinNode) -> SqlDataSe
aggregation_state=AggregationState.COMPLETE,
).column_name
time_dimension_select_column = SqlSelectColumn(
expr=SqlFunctionExpression.from_aggregation_type(
expr=SqlAggregateFunctionExpression.from_aggregation_type(
node.agg_by_function,
SqlColumnReferenceExpression(
SqlColumnReference(
Expand Down
76 changes: 74 additions & 2 deletions metricflow/plan_conversion/instance_converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
MdoInstance,
DimensionInstance,
IdentifierInstance,
MetadataInstance,
MetricInstance,
MeasureInstance,
InstanceSet,
Expand All @@ -39,7 +40,7 @@
from metricflow.sql.sql_exprs import (
SqlColumnReferenceExpression,
SqlColumnReference,
SqlFunctionExpression,
SqlAggregateFunctionExpression,
)
from metricflow.sql.sql_plan import SqlSelectColumn
from metricflow.time.time_granularity import TimeGranularity
Expand Down Expand Up @@ -87,12 +88,16 @@ def transform(self, instance_set: InstanceSet) -> SelectColumnSet: # noqa: D
identifier_cols = list(
chain.from_iterable([self._make_sql_column_expression(x) for x in instance_set.identifier_instances])
)
metadata_cols = list(
chain.from_iterable([self._make_sql_column_expression(x) for x in instance_set.metadata_instances])
)
return SelectColumnSet(
metric_columns=metric_cols,
measure_columns=measure_cols,
dimension_columns=dimension_cols,
time_dimension_columns=time_dimension_cols,
identifier_columns=identifier_cols,
metadata_columns=metadata_cols,
)

def _make_sql_column_expression(
Expand Down Expand Up @@ -205,7 +210,7 @@ def _make_sql_column_expression_to_aggregate_measure( # noqa: D
SqlColumnReference(self._table_alias, column_name_in_table)
)

expression_to_aggregate_measure = SqlFunctionExpression.from_aggregation_type(
expression_to_aggregate_measure = SqlAggregateFunctionExpression.from_aggregation_type(
aggregation_type=aggregation_type, sql_column_expression=expression_to_get_measure
)

Expand Down Expand Up @@ -234,12 +239,16 @@ def transform(self, instance_set: InstanceSet) -> SelectColumnSet: # noqa: D
identifier_cols = list(
chain.from_iterable([self._make_sql_column_expression(x) for x in instance_set.identifier_instances])
)
metadata_cols = list(
chain.from_iterable([self._make_sql_column_expression(x) for x in instance_set.metadata_instances])
)
return SelectColumnSet(
metric_columns=metric_cols,
measure_columns=measure_cols,
dimension_columns=dimension_cols,
time_dimension_columns=time_dimension_cols,
identifier_columns=identifier_cols,
metadata_columns=metadata_cols,
)


Expand Down Expand Up @@ -436,6 +445,7 @@ def transform(self, instance_set: InstanceSet) -> InstanceSet: # noqa: D
time_dimension_instances=tuple(time_dimension_instances_with_additional_link),
identifier_instances=tuple(identifier_instances_with_additional_link),
metric_instances=(),
metadata_instances=(),
)


Expand Down Expand Up @@ -477,6 +487,7 @@ def transform(self, instance_set: InstanceSet) -> InstanceSet: # noqa: D
time_dimension_instances=filtered_time_dimension_instances,
identifier_instances=filtered_identifier_instances,
metric_instances=instance_set.metric_instances,
metadata_instances=instance_set.metadata_instances,
)
return output

Expand Down Expand Up @@ -535,6 +546,7 @@ def transform(self, instance_set: InstanceSet) -> InstanceSet: # noqa: D
),
identifier_instances=tuple(x for x in instance_set.identifier_instances if self._should_pass(x.spec)),
metric_instances=tuple(x for x in instance_set.metric_instances if self._should_pass(x.spec)),
metadata_instances=tuple(x for x in instance_set.metadata_instances if self._should_pass(x.spec)),
)
return output

Expand Down Expand Up @@ -573,6 +585,7 @@ def transform(self, instance_set: InstanceSet) -> InstanceSet: # noqa: D
time_dimension_instances=instance_set.time_dimension_instances,
identifier_instances=instance_set.identifier_instances,
metric_instances=instance_set.metric_instances,
metadata_instances=instance_set.metadata_instances,
)


Expand Down Expand Up @@ -626,6 +639,7 @@ def transform(self, instance_set: InstanceSet) -> InstanceSet: # noqa: D
time_dimension_instances=instance_set.time_dimension_instances,
identifier_instances=instance_set.identifier_instances,
metric_instances=instance_set.metric_instances,
metadata_instances=instance_set.metadata_instances,
)


Expand All @@ -642,6 +656,7 @@ def transform(self, instance_set: InstanceSet) -> InstanceSet: # noqa: D
time_dimension_instances=instance_set.time_dimension_instances,
identifier_instances=instance_set.identifier_instances,
metric_instances=instance_set.metric_instances + tuple(self._metric_instances),
metadata_instances=instance_set.metadata_instances,
)


Expand All @@ -655,6 +670,7 @@ def transform(self, instance_set: InstanceSet) -> InstanceSet: # noqa: D
time_dimension_instances=instance_set.time_dimension_instances,
identifier_instances=instance_set.identifier_instances,
metric_instances=instance_set.metric_instances,
metadata_instances=instance_set.metadata_instances,
)


Expand All @@ -668,6 +684,48 @@ def transform(self, instance_set: InstanceSet) -> InstanceSet: # noqa: D
time_dimension_instances=instance_set.time_dimension_instances,
identifier_instances=instance_set.identifier_instances,
metric_instances=(),
metadata_instances=instance_set.metadata_instances,
)


class CreateSqlColumnReferencesForInstances(InstanceSetTransform[Tuple[SqlColumnReferenceExpression, ...]]):
    """Build a column reference expression for every column of every spec in an instance set.

    Column names come from the column associations that each spec reports for the supplied resolver, and every
    reference is qualified with the given table alias.
    """

    def __init__(
        self,
        table_alias: str,
        column_resolver: ColumnAssociationResolver,
    ) -> None:
        """Initializer.

        Args:
            table_alias: alias of the table that the referenced columns are selected from.
            column_resolver: resolver passed to each spec to obtain its column associations (and so the names).
        """
        self._column_resolver = column_resolver
        self._table_alias = table_alias

    def transform(self, instance_set: InstanceSet) -> Tuple[SqlColumnReferenceExpression, ...]:  # noqa: D
        # Phase 1: resolve the associations of every spec before reading any name.
        associations_by_spec = [
            spec.column_associations(self._column_resolver) for spec in instance_set.spec_set.all_specs
        ]

        # Phase 2: flatten to column names, keeping spec order and the per-spec association order.
        names = []
        for associations in associations_by_spec:
            for association in associations:
                names.append(association.column_name)

        # Phase 3: wrap each name in a reference qualified with the table alias.
        references = []
        for name in names:
            qualified_column = SqlColumnReference(table_alias=self._table_alias, column_name=name)
            references.append(SqlColumnReferenceExpression(qualified_column))
        return tuple(references)


Expand Down Expand Up @@ -745,12 +803,26 @@ def transform(self, instance_set: InstanceSet) -> InstanceSet: # noqa: D
)
)

output_metadata_instances = []
for input_metadata_instance in instance_set.metadata_instances:
output_metadata_instances.append(
MetadataInstance(
associated_columns=(
self._column_association_resolver.resolve_metadata_spec(
metadata_spec=input_metadata_instance.spec
),
),
spec=input_metadata_instance.spec,
)
)

return InstanceSet(
measure_instances=tuple(output_measure_instances),
dimension_instances=tuple(output_dimension_instances),
time_dimension_instances=tuple(output_time_dimension_instances),
identifier_instances=tuple(output_identifier_instances),
metric_instances=tuple(output_metric_instances),
metadata_instances=tuple(output_metadata_instances),
)


Expand Down
4 changes: 4 additions & 0 deletions metricflow/plan_conversion/select_column_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ class SelectColumnSet:
dimension_columns: List[SqlSelectColumn] = field(default_factory=list)
time_dimension_columns: List[SqlSelectColumn] = field(default_factory=list)
identifier_columns: List[SqlSelectColumn] = field(default_factory=list)
metadata_columns: List[SqlSelectColumn] = field(default_factory=list)

def merge(self, other_set: SelectColumnSet) -> SelectColumnSet:
"""Combine the select columns by type."""
Expand All @@ -27,6 +28,7 @@ def merge(self, other_set: SelectColumnSet) -> SelectColumnSet:
dimension_columns=self.dimension_columns + other_set.dimension_columns,
time_dimension_columns=self.time_dimension_columns + other_set.time_dimension_columns,
identifier_columns=self.identifier_columns + other_set.identifier_columns,
metadata_columns=self.metadata_columns + other_set.metadata_columns,
)

def as_tuple(self) -> Tuple[SqlSelectColumn, ...]:
Expand All @@ -38,6 +40,7 @@ def as_tuple(self) -> Tuple[SqlSelectColumn, ...]:
+ self.dimension_columns
+ self.metric_columns
+ self.measure_columns
+ self.metadata_columns
)

def without_measure_columns(self) -> SelectColumnSet:
Expand All @@ -47,4 +50,5 @@ def without_measure_columns(self) -> SelectColumnSet:
dimension_columns=self.dimension_columns,
time_dimension_columns=self.time_dimension_columns,
identifier_columns=self.identifier_columns,
metadata_columns=self.metadata_columns,
)
4 changes: 2 additions & 2 deletions metricflow/plan_conversion/spec_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
SqlComparison,
SqlColumnReferenceExpression,
SqlColumnReference,
SqlFunctionExpression,
SqlAggregateFunctionExpression,
SqlFunction,
)
from metricflow.sql.sql_plan import SqlSelectColumn
Expand Down Expand Up @@ -50,7 +50,7 @@ def _make_coalesced_expr(table_aliases: Sequence[str], column_alias: str) -> Sql
)
)
)
return SqlFunctionExpression(
return SqlAggregateFunctionExpression(
sql_function=SqlFunction.COALESCE,
sql_function_args=columns_to_coalesce,
)
Expand Down
Loading