diff --git a/metricflow/dag/id_generation.py b/metricflow/dag/id_generation.py index 35d15ee302..349771b247 100644 --- a/metricflow/dag/id_generation.py +++ b/metricflow/dag/id_generation.py @@ -34,6 +34,7 @@ SQL_EXPR_DATE_TRUNC = "dt" SQL_EXPR_RATIO_COMPUTATION = "rc" SQL_EXPR_BETWEEN_PREFIX = "betw" +SQL_EXPR_WINDOW_FUNCTION_ID_PREFIX = "wfnc" SQL_PLAN_SELECT_STATEMENT_ID_PREFIX = "ss" SQL_PLAN_TABLE_FROM_CLAUSE_ID_PREFIX = "tfc" diff --git a/metricflow/instances.py b/metricflow/instances.py index b5b467f4c2..6fa3bf79ed 100644 --- a/metricflow/instances.py +++ b/metricflow/instances.py @@ -11,6 +11,7 @@ from metricflow.dataclass_serialization import SerializableDataclass from metricflow.references import ElementReference from metricflow.specs import ( + MetadataSpec, MeasureSpec, DimensionSpec, IdentifierSpec, @@ -171,6 +172,12 @@ class MetricInstance(MdoInstance[MetricSpec], SerializableDataclass): # noqa: D defined_from: Tuple[MetricModelReference, ...] +@dataclass(frozen=True) +class MetadataInstance(MdoInstance[MetadataSpec], SerializableDataclass): # noqa: D + associated_columns: Tuple[ColumnAssociation, ...] + spec: MetadataSpec + + # Output type of transform function TransformOutputT = TypeVar("TransformOutputT") @@ -200,6 +207,7 @@ class InstanceSet(SerializableDataclass): time_dimension_instances: Tuple[TimeDimensionInstance, ...] = () identifier_instances: Tuple[IdentifierInstance, ...] = () metric_instances: Tuple[MetricInstance, ...] = () + metadata_instances: Tuple[MetadataInstance, ...] = () def transform(self, transform_function: InstanceSetTransform[TransformOutputT]) -> TransformOutputT: # noqa: D return transform_function.transform(self) @@ -215,6 +223,7 @@ def merge(instance_sets: List[InstanceSet]) -> InstanceSet: time_dimension_instances: List[TimeDimensionInstance] = [] identifier_instances: List[IdentifierInstance] = [] metric_instances: List[MetricInstance] = [] + metadata_instances: List[MetadataInstance] = [] for instance_set in instance_sets: for measure_instance in instance_set.measure_instances: @@ -232,6 +241,9 @@ def merge(instance_sets: List[InstanceSet]) -> InstanceSet: for metric_instance in instance_set.metric_instances: if metric_instance.spec not in {x.spec for x in metric_instances}: metric_instances.append(metric_instance) + for metadata_instance in instance_set.metadata_instances: + if metadata_instance.spec not in {x.spec for x in metadata_instances}: + metadata_instances.append(metadata_instance) return InstanceSet( measure_instances=tuple(measure_instances), @@ -239,6 +251,7 @@ def merge(instance_sets: List[InstanceSet]) -> InstanceSet: time_dimension_instances=tuple(time_dimension_instances), identifier_instances=tuple(identifier_instances), metric_instances=tuple(metric_instances), + metadata_instances=tuple(metadata_instances), ) @property @@ -249,4 +262,5 @@ def spec_set(self) -> InstanceSpecSet: # noqa: D time_dimension_specs=tuple(x.spec for x in self.time_dimension_instances), identifier_specs=tuple(x.spec for x in self.identifier_instances), metric_specs=tuple(x.spec for x in self.metric_instances), + metadata_specs=tuple(x.spec for x in self.metadata_instances), ) diff --git a/metricflow/plan_conversion/column_resolver.py b/metricflow/plan_conversion/column_resolver.py index 07f82a6ac8..8f56afcdff 100644 --- a/metricflow/plan_conversion/column_resolver.py +++ b/metricflow/plan_conversion/column_resolver.py @@ -9,6 +9,7 @@ ) from metricflow.naming.linkable_spec_name import StructuredLinkableSpecName from metricflow.specs import ( + MetadataSpec, MetricSpec, MeasureSpec, DimensionSpec, @@ -109,3 +110,9 @@ def resolve_identifier_spec(self, identifier_spec: IdentifierSpec) -> Tuple[Colu single_column_correlation_key=SingleColumnCorrelationKey(), ), ) + + def resolve_metadata_spec(self, metadata_spec: MetadataSpec) -> ColumnAssociation: # noqa: D + return ColumnAssociation( + column_name=metadata_spec.element_name, + single_column_correlation_key=SingleColumnCorrelationKey(), + ) diff --git a/metricflow/plan_conversion/dataflow_to_sql.py b/metricflow/plan_conversion/dataflow_to_sql.py index c931edacb0..fa06d90214 100644 --- a/metricflow/plan_conversion/dataflow_to_sql.py +++ b/metricflow/plan_conversion/dataflow_to_sql.py @@ -89,7 +89,7 @@ SqlDateTruncExpression, SqlStringLiteralExpression, SqlBetweenExpression, - SqlFunctionExpression, + SqlAggregateFunctionExpression, ) from metricflow.sql.sql_plan import ( SqlQueryPlan, @@ -1168,7 +1168,7 @@ def visit_semi_additive_join_node(self, node: SemiAdditiveJoinNode) -> SqlDataSe aggregation_state=AggregationState.COMPLETE, ).column_name time_dimension_select_column = SqlSelectColumn( - expr=SqlFunctionExpression.from_aggregation_type( + expr=SqlAggregateFunctionExpression.from_aggregation_type( node.agg_by_function, SqlColumnReferenceExpression( SqlColumnReference( diff --git a/metricflow/plan_conversion/instance_converters.py b/metricflow/plan_conversion/instance_converters.py index 3bfb359da5..c023d79f65 100644 --- a/metricflow/plan_conversion/instance_converters.py +++ b/metricflow/plan_conversion/instance_converters.py @@ -15,6 +15,7 @@ MdoInstance, DimensionInstance, IdentifierInstance, + MetadataInstance, MetricInstance, MeasureInstance, InstanceSet, @@ -39,7 +40,7 @@ from metricflow.sql.sql_exprs import ( SqlColumnReferenceExpression, SqlColumnReference, - SqlFunctionExpression, + SqlAggregateFunctionExpression, ) from metricflow.sql.sql_plan import SqlSelectColumn from metricflow.time.time_granularity import TimeGranularity @@ -87,12 +88,16 @@ def transform(self, instance_set: InstanceSet) -> SelectColumnSet: # noqa: D identifier_cols = list( chain.from_iterable([self._make_sql_column_expression(x) for x in instance_set.identifier_instances]) ) + metadata_cols = list( + chain.from_iterable([self._make_sql_column_expression(x) for x in instance_set.metadata_instances]) + ) return SelectColumnSet( metric_columns=metric_cols, measure_columns=measure_cols, dimension_columns=dimension_cols, time_dimension_columns=time_dimension_cols, identifier_columns=identifier_cols, + metadata_columns=metadata_cols, ) def _make_sql_column_expression( @@ -205,7 +210,7 @@ def _make_sql_column_expression_to_aggregate_measure( # noqa: D SqlColumnReference(self._table_alias, column_name_in_table) ) - expression_to_aggregate_measure = SqlFunctionExpression.from_aggregation_type( + expression_to_aggregate_measure = SqlAggregateFunctionExpression.from_aggregation_type( aggregation_type=aggregation_type, sql_column_expression=expression_to_get_measure ) @@ -234,12 +239,16 @@ def transform(self, instance_set: InstanceSet) -> SelectColumnSet: # noqa: D identifier_cols = list( chain.from_iterable([self._make_sql_column_expression(x) for x in instance_set.identifier_instances]) ) + metadata_cols = list( + chain.from_iterable([self._make_sql_column_expression(x) for x in instance_set.metadata_instances]) + ) return SelectColumnSet( metric_columns=metric_cols, measure_columns=measure_cols, dimension_columns=dimension_cols, time_dimension_columns=time_dimension_cols, identifier_columns=identifier_cols, + metadata_columns=metadata_cols, ) @@ -436,6 +445,7 @@ def transform(self, instance_set: InstanceSet) -> InstanceSet: # noqa: D time_dimension_instances=tuple(time_dimension_instances_with_additional_link), identifier_instances=tuple(identifier_instances_with_additional_link), metric_instances=(), + metadata_instances=(), ) @@ -477,6 +487,7 @@ def transform(self, instance_set: InstanceSet) -> InstanceSet: # noqa: D time_dimension_instances=filtered_time_dimension_instances, identifier_instances=filtered_identifier_instances, metric_instances=instance_set.metric_instances, + metadata_instances=instance_set.metadata_instances, ) return output @@ -535,6 +546,7 @@ def transform(self, instance_set: InstanceSet) -> InstanceSet: # noqa: D ), identifier_instances=tuple(x for x in instance_set.identifier_instances if self._should_pass(x.spec)), metric_instances=tuple(x for x in instance_set.metric_instances if self._should_pass(x.spec)), + metadata_instances=tuple(x for x in instance_set.metadata_instances if self._should_pass(x.spec)), ) return output @@ -573,6 +585,7 @@ def transform(self, instance_set: InstanceSet) -> InstanceSet: # noqa: D time_dimension_instances=instance_set.time_dimension_instances, identifier_instances=instance_set.identifier_instances, metric_instances=instance_set.metric_instances, + metadata_instances=instance_set.metadata_instances, ) @@ -626,6 +639,7 @@ def transform(self, instance_set: InstanceSet) -> InstanceSet: # noqa: D time_dimension_instances=instance_set.time_dimension_instances, identifier_instances=instance_set.identifier_instances, metric_instances=instance_set.metric_instances, + metadata_instances=instance_set.metadata_instances, ) @@ -642,6 +656,7 @@ def transform(self, instance_set: InstanceSet) -> InstanceSet: # noqa: D time_dimension_instances=instance_set.time_dimension_instances, identifier_instances=instance_set.identifier_instances, metric_instances=instance_set.metric_instances + tuple(self._metric_instances), + metadata_instances=instance_set.metadata_instances, ) @@ -655,6 +670,7 @@ def transform(self, instance_set: InstanceSet) -> InstanceSet: # noqa: D time_dimension_instances=instance_set.time_dimension_instances, identifier_instances=instance_set.identifier_instances, metric_instances=instance_set.metric_instances, + metadata_instances=instance_set.metadata_instances, ) @@ -668,6 +684,48 @@ def transform(self, instance_set: InstanceSet) -> InstanceSet: # noqa: D time_dimension_instances=instance_set.time_dimension_instances, identifier_instances=instance_set.identifier_instances, metric_instances=(), + metadata_instances=instance_set.metadata_instances, + ) + + +class CreateSqlColumnReferencesForInstances(InstanceSetTransform[Tuple[SqlColumnReferenceExpression, ...]]): + """Create select column expressions that will express all instances in the set. + + It assumes that the column names of the instances are represented by the supplied column association resolver and + come from the given table alias. + """ + + def __init__( + self, + table_alias: str, + column_resolver: ColumnAssociationResolver, + ) -> None: + """Initializer. + + Args: + table_alias: the table alias to select columns from + column_resolver: resolver to name columns. + """ + self._table_alias = table_alias + self._column_resolver = column_resolver + + def transform(self, instance_set: InstanceSet) -> Tuple[SqlColumnReferenceExpression, ...]: # noqa: D + column_names = [ + col.column_name + for col in ( + chain.from_iterable( + [x.column_associations(self._column_resolver) for x in instance_set.spec_set.all_specs] + ) + ) + ] + return tuple( + SqlColumnReferenceExpression( + SqlColumnReference( + table_alias=self._table_alias, + column_name=column_name, + ), + ) + for column_name in column_names ) @@ -745,12 +803,26 @@ def transform(self, instance_set: InstanceSet) -> InstanceSet: # noqa: D ) ) + output_metadata_instances = [] + for input_metadata_instance in instance_set.metadata_instances: + output_metadata_instances.append( + MetadataInstance( + associated_columns=( + self._column_association_resolver.resolve_metadata_spec( + metadata_spec=input_metadata_instance.spec + ), + ), + spec=input_metadata_instance.spec, + ) + ) + return InstanceSet( measure_instances=tuple(output_measure_instances), dimension_instances=tuple(output_dimension_instances), time_dimension_instances=tuple(output_time_dimension_instances), identifier_instances=tuple(output_identifier_instances), metric_instances=tuple(output_metric_instances), + metadata_instances=tuple(output_metadata_instances), ) diff --git a/metricflow/plan_conversion/select_column_gen.py b/metricflow/plan_conversion/select_column_gen.py index 7486e7203a..1e313ed2ef 100644 --- a/metricflow/plan_conversion/select_column_gen.py +++ b/metricflow/plan_conversion/select_column_gen.py @@ -18,6 +18,7 @@ class SelectColumnSet: dimension_columns: List[SqlSelectColumn] = field(default_factory=list) time_dimension_columns: List[SqlSelectColumn] = field(default_factory=list) identifier_columns: List[SqlSelectColumn] = field(default_factory=list) + metadata_columns: List[SqlSelectColumn] = field(default_factory=list) def merge(self, other_set: SelectColumnSet) -> SelectColumnSet: """Combine the select columns by type.""" @@ -27,6 +28,7 @@ def merge(self, other_set: SelectColumnSet) -> SelectColumnSet: dimension_columns=self.dimension_columns + other_set.dimension_columns, time_dimension_columns=self.time_dimension_columns + other_set.time_dimension_columns, identifier_columns=self.identifier_columns + other_set.identifier_columns, + metadata_columns=self.metadata_columns + other_set.metadata_columns, ) def as_tuple(self) -> Tuple[SqlSelectColumn, ...]: @@ -38,6 +40,7 @@ def as_tuple(self) -> Tuple[SqlSelectColumn, ...]: + self.dimension_columns + self.metric_columns + self.measure_columns + + self.metadata_columns ) def without_measure_columns(self) -> SelectColumnSet: @@ -47,4 +50,5 @@ def without_measure_columns(self) -> SelectColumnSet: dimension_columns=self.dimension_columns, time_dimension_columns=self.time_dimension_columns, identifier_columns=self.identifier_columns, + metadata_columns=self.metadata_columns, ) diff --git a/metricflow/plan_conversion/spec_transforms.py b/metricflow/plan_conversion/spec_transforms.py index 0830be934a..d3175ab8c6 100644 --- a/metricflow/plan_conversion/spec_transforms.py +++ b/metricflow/plan_conversion/spec_transforms.py @@ -14,7 +14,7 @@ SqlComparison, SqlColumnReferenceExpression, SqlColumnReference, - SqlFunctionExpression, + SqlAggregateFunctionExpression, SqlFunction, ) from metricflow.sql.sql_plan import SqlSelectColumn @@ -50,7 +50,7 @@ def _make_coalesced_expr(table_aliases: Sequence[str], column_alias: str) -> Sql ) ) ) - return SqlFunctionExpression( + return SqlAggregateFunctionExpression( sql_function=SqlFunction.COALESCE, sql_function_args=columns_to_coalesce, ) diff --git a/metricflow/specs.py b/metricflow/specs.py index cfeae9dcf5..8701cb244f 100644 --- a/metricflow/specs.py +++ b/metricflow/specs.py @@ -71,6 +71,10 @@ def resolve_time_dimension_spec( # noqa: D def resolve_identifier_spec(self, identifier_spec: IdentifierSpec) -> Tuple[ColumnAssociation, ...]: # noqa: D pass + @abstractmethod + def resolve_metadata_spec(self, metadata_spec: MetadataSpec) -> ColumnAssociation: # noqa: D + pass + @dataclass(frozen=True) class InstanceSpec(SerializableDataclass): @@ -107,6 +111,24 @@ def qualified_name(self) -> str: raise NotImplementedError() +@dataclass(frozen=True) +class MetadataSpec(InstanceSpec): + """A specification for a specification that is built during the dataflow plan and not defined in config.""" + + element_name: str + + def column_associations(self, resolver: ColumnAssociationResolver) -> Tuple[ColumnAssociation, ...]: # noqa: D + return (resolver.resolve_metadata_spec(self),) + + @property + def qualified_name(self) -> str: # noqa: D + return self.element_name + + @staticmethod + def from_name(name: str) -> MetadataSpec: # noqa: D + return MetadataSpec(element_name=name) + + @dataclass(frozen=True) class LinkableInstanceSpec(InstanceSpec): """Generally a dimension or identifier that may be specified using identifier links. @@ -576,17 +598,21 @@ class InstanceSpecSet(SerializableDataclass): dimension_specs: Tuple[DimensionSpec, ...] = () identifier_specs: Tuple[IdentifierSpec, ...] = () time_dimension_specs: Tuple[TimeDimensionSpec, ...] = () + metadata_specs: Tuple[MetadataSpec, ...] = () def merge(self, others: Sequence[InstanceSpecSet]) -> InstanceSpecSet: """Merge all sets into one set, without de-duplication.""" return InstanceSpecSet( metric_specs=self.metric_specs + tuple(itertools.chain.from_iterable([x.metric_specs for x in others])), + measure_specs=self.measure_specs + tuple(itertools.chain.from_iterable([x.measure_specs for x in others])), dimension_specs=self.dimension_specs + tuple(itertools.chain.from_iterable([x.dimension_specs for x in others])), identifier_specs=self.identifier_specs + tuple(itertools.chain.from_iterable([x.identifier_specs for x in others])), time_dimension_specs=self.time_dimension_specs + tuple(itertools.chain.from_iterable([x.time_dimension_specs for x in others])), + metadata_specs=self.metadata_specs + + tuple(itertools.chain.from_iterable([x.metadata_specs for x in others])), ) @property @@ -603,6 +629,7 @@ def all_specs(self) -> Sequence[InstanceSpec]: # noqa: D self.time_dimension_specs, self.identifier_specs, self.metric_specs, + self.metadata_specs, ) ) diff --git a/metricflow/sql/render/expr_renderer.py b/metricflow/sql/render/expr_renderer.py index 0d1da4a819..55e550b15b 100644 --- a/metricflow/sql/render/expr_renderer.py +++ b/metricflow/sql/render/expr_renderer.py @@ -17,7 +17,7 @@ SqlComparisonExpression, SqlExpressionNode, SqlFunction, - SqlFunctionExpression, + SqlAggregateFunctionExpression, SqlNullExpression, SqlLogicalExpression, SqlStringLiteralExpression, @@ -28,6 +28,7 @@ SqlRatioComputationExpression, SqlColumnAliasReferenceExpression, SqlBetweenExpression, + SqlWindowFunctionExpression, ) from metricflow.time.time_granularity import TimeGranularity @@ -107,7 +108,7 @@ def visit_comparison_expr(self, node: SqlComparisonExpression) -> SqlExpressionR execution_parameters=combined_params, ) - def visit_function_expr(self, node: SqlFunctionExpression) -> SqlExpressionRenderResult: # noqa: D + def visit_function_expr(self, node: SqlAggregateFunctionExpression) -> SqlExpressionRenderResult: # noqa: D """Render a function call like CONCAT(a, b)""" args_rendered = [self.render_sql_expr(x) for x in node.sql_function_args] combined_params = SqlBindParameters() @@ -271,3 +272,45 @@ def visit_between_expr(self, node: SqlBetweenExpression) -> SqlExpressionRenderR sql=f"{rendered_column_arg.sql} BETWEEN {rendered_start_expr.sql} AND {rendered_end_expr.sql}", execution_parameters=execution_parameters, ) + + def visit_window_function_expr(self, node: SqlWindowFunctionExpression) -> SqlExpressionRenderResult: # noqa: D + sql_function_args_rendered = [self.render_sql_expr(x) for x in node.sql_function_args] + partition_by_args_rendered = [self.render_sql_expr(x) for x in node.partition_by_args] + order_by_args_rendered = {self.render_sql_expr(x.expr): x for x in node.order_by_args} + + combined_params = SqlBindParameters() + args_rendered = [] + if sql_function_args_rendered: + args_rendered.extend(sql_function_args_rendered) + if partition_by_args_rendered: + args_rendered.extend(partition_by_args_rendered) + if order_by_args_rendered: + args_rendered.extend(list(order_by_args_rendered.keys())) + for arg_rendered in args_rendered: + combined_params.update(arg_rendered.execution_parameters) + + sql_function_args_string = ", ".join([x.sql for x in sql_function_args_rendered]) + partition_by_args_string = ( + ("PARTITION BY " + ", ".join([x.sql for x in partition_by_args_rendered])) + if partition_by_args_rendered + else "" + ) + order_by_args_string = ( + ( + "ORDER BY " + + ", ".join( + [ + rendered_result.sql + (f" {x.suffix}" if x.suffix else "") + for rendered_result, x in order_by_args_rendered.items() + ] + ) + ) + if order_by_args_rendered + else "" + ) + + window_string = " ".join(filter(bool, [partition_by_args_string, order_by_args_string])) + return SqlExpressionRenderResult( + sql=f"{node.sql_function.value}({sql_function_args_string}) OVER ({window_string})", + execution_parameters=combined_params, + ) diff --git a/metricflow/sql/sql_exprs.py b/metricflow/sql/sql_exprs.py index 58ec4bbefa..df88d60f15 100644 --- a/metricflow/sql/sql_exprs.py +++ b/metricflow/sql/sql_exprs.py @@ -22,6 +22,7 @@ SQL_EXPR_DATE_TRUNC, SQL_EXPR_RATIO_COMPUTATION, SQL_EXPR_BETWEEN_PREFIX, + SQL_EXPR_WINDOW_FUNCTION_ID_PREFIX, ) from metricflow.sql.sql_bind_parameters import SqlBindParameters from metricflow.visitor import Visitable, VisitorOutputT @@ -177,7 +178,7 @@ def visit_comparison_expr(self, node: SqlComparisonExpression) -> VisitorOutputT pass @abstractmethod - def visit_function_expr(self, node: SqlFunctionExpression) -> VisitorOutputT: # noqa: D + def visit_function_expr(self, node: SqlAggregateFunctionExpression) -> VisitorOutputT: # noqa: D pass @abstractmethod @@ -216,6 +217,10 @@ def visit_ratio_computation_expr(self, node: SqlRatioComputationExpression) -> V def visit_between_expr(self, node: SqlBetweenExpression) -> VisitorOutputT: # noqa: D pass + @abstractmethod + def visit_window_function_expr(self, node: SqlWindowFunctionExpression) -> VisitorOutputT: # noqa: D + pass + class SqlStringExpression(SqlExpressionNode): """An SQL expression in a string format, so it lacks information about the structure. @@ -688,14 +693,20 @@ def from_aggregation_type(aggregation_type: AggregationType) -> SqlFunction: class SqlFunctionExpression(SqlExpressionNode): - """A function expression like SUM(1).""" + """Denotes a function expression in SQL.""" + + pass + + +class SqlAggregateFunctionExpression(SqlFunctionExpression): + """An aggregate function expression like SUM(1).""" @staticmethod def from_aggregation_type( aggregation_type: AggregationType, sql_column_expression: SqlColumnReferenceExpression - ) -> SqlFunctionExpression: + ) -> SqlAggregateFunctionExpression: """Given the aggregation type, return an SQL function expression that does that aggregation on the given col.""" - return SqlFunctionExpression( + return SqlAggregateFunctionExpression( sql_function=SqlFunction.from_aggregation_type(aggregation_type=aggregation_type), sql_function_args=[sql_column_expression], ) @@ -751,7 +762,7 @@ def rewrite( # noqa: D column_replacements: Optional[SqlColumnReplacements] = None, should_render_table_alias: Optional[bool] = None, ) -> SqlExpressionNode: - return SqlFunctionExpression( + return SqlAggregateFunctionExpression( sql_function=self.sql_function, sql_function_args=[ x.rewrite(column_replacements, should_render_table_alias) for x in self.sql_function_args @@ -765,11 +776,156 @@ def lineage(self) -> SqlExpressionTreeLineage: # noqa: D ) def matches(self, other: SqlExpressionNode) -> bool: # noqa: D - if not isinstance(other, SqlFunctionExpression): + if not isinstance(other, SqlAggregateFunctionExpression): return False return self.sql_function == other.sql_function and self._parents_match(other) +class SqlWindowFunction(Enum): + """Names of known SQL window functions like SUM(), RANK(), ROW_NUMBER() + + Values are the SQL string to be used in rendering. + """ + + FIRST_VALUE = "first_value" + ROW_NUMBER = "row_number" + + +@dataclass(frozen=True) +class SqlWindowOrderByArgument: + """In window functions, the ORDER BY clause can accept an expr, ordering, null ranking.""" + + expr: SqlExpressionNode + descending: Optional[bool] = None + nulls_last: Optional[bool] = None + + @property + def suffix(self) -> str: + """Helper to build suffix to append to {expr}{suffix}""" + result = [] + if self.descending is not None: + result.append("DESC" if self.descending else "ASC") + if self.nulls_last is not None: + result.append("NULLS LAST" if self.nulls_last else "NULLS FIRST") + return " ".join(result) + + +class SqlWindowFunctionExpression(SqlFunctionExpression): + """A window function expression like SUM(foo) OVER bar""" + + def __init__( + self, + sql_function: SqlWindowFunction, + sql_function_args: Optional[List[SqlExpressionNode]] = None, + partition_by_args: Optional[List[SqlExpressionNode]] = None, + order_by_args: Optional[List[SqlWindowOrderByArgument]] = None, + ) -> None: + """Constructor. + + Args: + sql_function: The function that this represents. + sql_function_args: The arguments that should go into the function. e.g. for "CONCAT(a, b)", the arg + expressions should be "a" and "b". + partition_by_args: The arguments to partition the rows. e.g. PARTITION BY expr1, expr2, + the args are "expr1", "expr2". + order_by_args: The expr to order the partitions by. + """ + self._sql_function = sql_function + self._sql_function_args = sql_function_args + self._partition_by_args = partition_by_args + self._order_by_args = order_by_args + parent_nodes = [] + if sql_function_args: + parent_nodes.extend(sql_function_args) + if partition_by_args: + parent_nodes.extend(partition_by_args) + if order_by_args: + parent_nodes.extend([x.expr for x in order_by_args]) + super().__init__(node_id=self.create_unique_id(), parent_nodes=parent_nodes) + + @classmethod + def id_prefix(cls) -> str: # noqa: D + return SQL_EXPR_WINDOW_FUNCTION_ID_PREFIX + + @property + def requires_parenthesis(self) -> bool: # noqa: D + return False + + def accept(self, visitor: SqlExpressionNodeVisitor) -> VisitorOutputT: # noqa: D + return visitor.visit_window_function_expr(self) + + @property + def description(self) -> str: # noqa: D + return f"{self._sql_function.value} Window Function Expression" + + @property + def displayed_properties(self) -> List[DisplayedProperty]: # noqa: D + return ( + super().displayed_properties + + [DisplayedProperty("function", self.sql_function)] + + [DisplayedProperty("argument", x) for x in self.sql_function_args] + + [DisplayedProperty("partition_by_argument", x) for x in self.partition_by_args] + + [DisplayedProperty("order_by_argument", x) for x in self.order_by_args] + ) + + @property + def sql_function(self) -> SqlWindowFunction: # noqa: D + return self._sql_function + + @property + def sql_function_args(self) -> List[SqlExpressionNode]: # noqa: D + return self._sql_function_args or [] + + @property + def partition_by_args(self) -> List[SqlExpressionNode]: # noqa: D + return self._partition_by_args or [] + + @property + def order_by_args(self) -> List[SqlWindowOrderByArgument]: # noqa: D + return self._order_by_args or [] + + def __repr__(self) -> str: # noqa: D + return f"{self.__class__.__name__}(node_id={self.node_id}, sql_function={self.sql_function.name})" + + def rewrite( # noqa: D + self, + column_replacements: Optional[SqlColumnReplacements] = None, + should_render_table_alias: Optional[bool] = None, + ) -> SqlExpressionNode: + return SqlWindowFunctionExpression( + sql_function=self.sql_function, + sql_function_args=[ + x.rewrite(column_replacements, should_render_table_alias) for x in self.sql_function_args + ], + partition_by_args=[ + x.rewrite(column_replacements, should_render_table_alias) for x in self.partition_by_args + ], + order_by_args=[ + SqlWindowOrderByArgument( + expr=x.expr.rewrite(column_replacements, should_render_table_alias), + descending=x.descending, + nulls_last=x.nulls_last, + ) + for x in self.order_by_args + ], + ) + + @property + def lineage(self) -> SqlExpressionTreeLineage: # noqa: D + return SqlExpressionTreeLineage.combine( + tuple(x.lineage for x in self.parent_nodes) + (SqlExpressionTreeLineage(function_exprs=(self,)),) + ) + + def matches(self, other: SqlExpressionNode) -> bool: # noqa: D + if not isinstance(other, SqlWindowFunctionExpression): + return False + return ( + self.sql_function == other.sql_function + and self.order_by_args == other.order_by_args + and self._parents_match(other) + ) + + class SqlNullExpression(SqlExpressionNode): """Represents NULL.""" diff --git a/metricflow/test/plan_conversion/instance_converters/test_create_select_columns_with_measures_aggregated.py b/metricflow/test/plan_conversion/instance_converters/test_create_select_columns_with_measures_aggregated.py index 9bae6afd41..eb1b0baca8 100644 --- a/metricflow/test/plan_conversion/instance_converters/test_create_select_columns_with_measures_aggregated.py +++ b/metricflow/test/plan_conversion/instance_converters/test_create_select_columns_with_measures_aggregated.py @@ -9,7 +9,7 @@ from metricflow.specs import MeasureSpec, MetricInputMeasureSpec from metricflow.sql.sql_exprs import ( SqlFunction, - SqlFunctionExpression, + SqlAggregateFunctionExpression, ) from metricflow.test.fixtures.model_fixtures import ConsistentIdObjectRepository @@ -46,7 +46,7 @@ def test_sum_aggregation( assert len(select_column_set.measure_columns) == 1 measure_column = select_column_set.measure_columns[0] expr = measure_column.expr - assert isinstance(expr, SqlFunctionExpression) + assert isinstance(expr, SqlAggregateFunctionExpression) assert expr.sql_function == SqlFunction.SUM @@ -68,7 +68,7 @@ def test_sum_boolean_aggregation( assert len(select_column_set.measure_columns) == 1 measure_column = select_column_set.measure_columns[0] expr = measure_column.expr - assert isinstance(expr, SqlFunctionExpression) + assert isinstance(expr, SqlAggregateFunctionExpression) # The SUM_BOOLEAN aggregation type is transformed to SUM at model parsing time assert expr.sql_function == SqlFunction.SUM @@ -91,7 +91,7 @@ def test_avg_aggregation( assert len(select_column_set.measure_columns) == 1 measure_column = select_column_set.measure_columns[0] expr = measure_column.expr - assert isinstance(expr, SqlFunctionExpression) + assert isinstance(expr, SqlAggregateFunctionExpression) assert expr.sql_function == SqlFunction.AVERAGE @@ -113,7 +113,7 @@ def test_count_distinct_aggregation( assert len(select_column_set.measure_columns) == 1 measure_column = select_column_set.measure_columns[0] expr = measure_column.expr - assert isinstance(expr, SqlFunctionExpression) + assert isinstance(expr, SqlAggregateFunctionExpression) assert expr.sql_function == SqlFunction.COUNT_DISTINCT @@ -135,7 +135,7 @@ def test_max_aggregation( assert len(select_column_set.measure_columns) == 1 measure_column = select_column_set.measure_columns[0] expr = measure_column.expr - assert isinstance(expr, SqlFunctionExpression) + assert isinstance(expr, SqlAggregateFunctionExpression) assert expr.sql_function == SqlFunction.MAX @@ -157,7 +157,7 @@ def test_min_aggregation( assert len(select_column_set.measure_columns) == 1 measure_column = select_column_set.measure_columns[0] expr = measure_column.expr - assert isinstance(expr, SqlFunctionExpression) + assert isinstance(expr, SqlAggregateFunctionExpression) assert expr.sql_function == SqlFunction.MIN @@ -179,6 +179,6 @@ def test_aliased_sum( assert len(select_column_set.measure_columns) == 1 measure_column = select_column_set.measure_columns[0] expr = measure_column.expr - assert isinstance(expr, SqlFunctionExpression) + assert isinstance(expr, SqlAggregateFunctionExpression) assert expr.sql_function == SqlFunction.SUM assert measure_column.column_alias == "bvalue" diff --git a/metricflow/test/snapshots/test_dataflow_to_sql_plan.py/SqlQueryPlan/test_composite_identifier__plan0.xml b/metricflow/test/snapshots/test_dataflow_to_sql_plan.py/SqlQueryPlan/test_composite_identifier__plan0.xml index 402b988cf0..86cf3e8089 100644 --- a/metricflow/test/snapshots/test_dataflow_to_sql_plan.py/SqlQueryPlan/test_composite_identifier__plan0.xml +++ b/metricflow/test/snapshots/test_dataflow_to_sql_plan.py/SqlQueryPlan/test_composite_identifier__plan0.xml @@ -27,10 +27,10 @@ - - - - + + + + diff --git a/metricflow/test/snapshots/test_dataflow_to_sql_plan.py/SqlQueryPlan/test_composite_identifier_with_join__plan0.xml b/metricflow/test/snapshots/test_dataflow_to_sql_plan.py/SqlQueryPlan/test_composite_identifier_with_join__plan0.xml index ff20347ecc..9c15ccca15 100644 --- a/metricflow/test/snapshots/test_dataflow_to_sql_plan.py/SqlQueryPlan/test_composite_identifier_with_join__plan0.xml +++ b/metricflow/test/snapshots/test_dataflow_to_sql_plan.py/SqlQueryPlan/test_composite_identifier_with_join__plan0.xml @@ -35,10 +35,10 @@ - - - - + + + + diff --git a/metricflow/test/snapshots/test_dataflow_to_sql_plan.py/SqlQueryPlan/test_composite_identifier_with_order_by__plan0.xml b/metricflow/test/snapshots/test_dataflow_to_sql_plan.py/SqlQueryPlan/test_composite_identifier_with_order_by__plan0.xml index 7956b51d64..13ec116169 100644 --- a/metricflow/test/snapshots/test_dataflow_to_sql_plan.py/SqlQueryPlan/test_composite_identifier_with_order_by__plan0.xml +++ b/metricflow/test/snapshots/test_dataflow_to_sql_plan.py/SqlQueryPlan/test_composite_identifier_with_order_by__plan0.xml @@ -52,10 +52,10 @@ - - - - + + + + diff --git a/metricflow/test/snapshots/test_dataflow_to_sql_plan.py/SqlQueryPlan/test_compute_metrics_node__plan0.xml b/metricflow/test/snapshots/test_dataflow_to_sql_plan.py/SqlQueryPlan/test_compute_metrics_node__plan0.xml index 7ed9c2dc81..d0dc520287 100644 --- a/metricflow/test/snapshots/test_dataflow_to_sql_plan.py/SqlQueryPlan/test_compute_metrics_node__plan0.xml +++ b/metricflow/test/snapshots/test_dataflow_to_sql_plan.py/SqlQueryPlan/test_compute_metrics_node__plan0.xml @@ -27,10 +27,10 @@ - - - - + + + + diff --git a/metricflow/test/snapshots/test_dataflow_to_sql_plan.py/SqlQueryPlan/test_compute_metrics_node_ratio_from_multiple_data_sources__plan0.xml b/metricflow/test/snapshots/test_dataflow_to_sql_plan.py/SqlQueryPlan/test_compute_metrics_node_ratio_from_multiple_data_sources__plan0.xml index 7cc0d5ab60..9d54117c57 100644 --- a/metricflow/test/snapshots/test_dataflow_to_sql_plan.py/SqlQueryPlan/test_compute_metrics_node_ratio_from_multiple_data_sources__plan0.xml +++ b/metricflow/test/snapshots/test_dataflow_to_sql_plan.py/SqlQueryPlan/test_compute_metrics_node_ratio_from_multiple_data_sources__plan0.xml @@ -77,10 +77,10 @@ - - - - + + + + @@ -896,10 +896,10 @@ - - - - + + + + diff --git a/metricflow/test/snapshots/test_dataflow_to_sql_plan.py/SqlQueryPlan/test_compute_metrics_node_ratio_from_single_data_source__plan0.xml b/metricflow/test/snapshots/test_dataflow_to_sql_plan.py/SqlQueryPlan/test_compute_metrics_node_ratio_from_single_data_source__plan0.xml index bc560731e5..07db9cb500 100644 --- a/metricflow/test/snapshots/test_dataflow_to_sql_plan.py/SqlQueryPlan/test_compute_metrics_node_ratio_from_single_data_source__plan0.xml +++ b/metricflow/test/snapshots/test_dataflow_to_sql_plan.py/SqlQueryPlan/test_compute_metrics_node_ratio_from_single_data_source__plan0.xml @@ -27,14 +27,14 @@ - - - - - - - - + + + + + + + + diff --git a/metricflow/test/snapshots/test_dataflow_to_sql_plan.py/SqlQueryPlan/test_compute_metrics_node_simple_expr__plan0.xml b/metricflow/test/snapshots/test_dataflow_to_sql_plan.py/SqlQueryPlan/test_compute_metrics_node_simple_expr__plan0.xml index 48b5ca8e0d..3d8e37ac19 100644 --- a/metricflow/test/snapshots/test_dataflow_to_sql_plan.py/SqlQueryPlan/test_compute_metrics_node_simple_expr__plan0.xml +++ b/metricflow/test/snapshots/test_dataflow_to_sql_plan.py/SqlQueryPlan/test_compute_metrics_node_simple_expr__plan0.xml @@ -27,10 +27,10 @@ - - - - + + + + diff --git a/metricflow/test/snapshots/test_dataflow_to_sql_plan.py/SqlQueryPlan/test_cumulative_metric__plan0.xml b/metricflow/test/snapshots/test_dataflow_to_sql_plan.py/SqlQueryPlan/test_cumulative_metric__plan0.xml index f56abed0fc..61ca539dfc 100644 --- a/metricflow/test/snapshots/test_dataflow_to_sql_plan.py/SqlQueryPlan/test_cumulative_metric__plan0.xml +++ b/metricflow/test/snapshots/test_dataflow_to_sql_plan.py/SqlQueryPlan/test_cumulative_metric__plan0.xml @@ -19,10 +19,10 @@ - - - - + + + + diff --git a/metricflow/test/snapshots/test_dataflow_to_sql_plan.py/SqlQueryPlan/test_cumulative_metric_grain_to_date__plan0.xml b/metricflow/test/snapshots/test_dataflow_to_sql_plan.py/SqlQueryPlan/test_cumulative_metric_grain_to_date__plan0.xml index e38d77eacf..4f0a3a64ca 100644 --- a/metricflow/test/snapshots/test_dataflow_to_sql_plan.py/SqlQueryPlan/test_cumulative_metric_grain_to_date__plan0.xml +++ b/metricflow/test/snapshots/test_dataflow_to_sql_plan.py/SqlQueryPlan/test_cumulative_metric_grain_to_date__plan0.xml @@ -19,10 +19,10 @@ - - - - + + + + diff --git a/metricflow/test/snapshots/test_dataflow_to_sql_plan.py/SqlQueryPlan/test_cumulative_metric_no_ds__plan0.xml b/metricflow/test/snapshots/test_dataflow_to_sql_plan.py/SqlQueryPlan/test_cumulative_metric_no_ds__plan0.xml index 5c8337e0d3..524b3e9746 100644 --- a/metricflow/test/snapshots/test_dataflow_to_sql_plan.py/SqlQueryPlan/test_cumulative_metric_no_ds__plan0.xml +++ b/metricflow/test/snapshots/test_dataflow_to_sql_plan.py/SqlQueryPlan/test_cumulative_metric_no_ds__plan0.xml @@ -11,10 +11,10 @@ - - - - + + + + diff --git a/metricflow/test/snapshots/test_dataflow_to_sql_plan.py/SqlQueryPlan/test_cumulative_metric_no_window__plan0.xml b/metricflow/test/snapshots/test_dataflow_to_sql_plan.py/SqlQueryPlan/test_cumulative_metric_no_window__plan0.xml index 09d1a4b721..8f23dad803 100644 --- a/metricflow/test/snapshots/test_dataflow_to_sql_plan.py/SqlQueryPlan/test_cumulative_metric_no_window__plan0.xml +++ b/metricflow/test/snapshots/test_dataflow_to_sql_plan.py/SqlQueryPlan/test_cumulative_metric_no_window__plan0.xml @@ -19,10 +19,10 @@ - - - - + + + + diff --git a/metricflow/test/snapshots/test_dataflow_to_sql_plan.py/SqlQueryPlan/test_cumulative_metric_no_window_with_time_constraint__plan0.xml b/metricflow/test/snapshots/test_dataflow_to_sql_plan.py/SqlQueryPlan/test_cumulative_metric_no_window_with_time_constraint__plan0.xml index fb2fbeb53f..3a0eabc5d9 100644 --- a/metricflow/test/snapshots/test_dataflow_to_sql_plan.py/SqlQueryPlan/test_cumulative_metric_no_window_with_time_constraint__plan0.xml +++ b/metricflow/test/snapshots/test_dataflow_to_sql_plan.py/SqlQueryPlan/test_cumulative_metric_no_window_with_time_constraint__plan0.xml @@ -19,10 +19,10 @@ - - - - + + + + diff --git a/metricflow/test/snapshots/test_dataflow_to_sql_plan.py/SqlQueryPlan/test_cumulative_metric_with_time_constraint__plan0.xml b/metricflow/test/snapshots/test_dataflow_to_sql_plan.py/SqlQueryPlan/test_cumulative_metric_with_time_constraint__plan0.xml index ada67eb503..c82fcc4aaa 100644 --- a/metricflow/test/snapshots/test_dataflow_to_sql_plan.py/SqlQueryPlan/test_cumulative_metric_with_time_constraint__plan0.xml +++ b/metricflow/test/snapshots/test_dataflow_to_sql_plan.py/SqlQueryPlan/test_cumulative_metric_with_time_constraint__plan0.xml @@ -19,10 +19,10 @@ - - - - + + + + diff --git a/metricflow/test/snapshots/test_dataflow_to_sql_plan.py/SqlQueryPlan/test_derived_metric__plan0.xml b/metricflow/test/snapshots/test_dataflow_to_sql_plan.py/SqlQueryPlan/test_derived_metric__plan0.xml index 188df2fb57..4bea59c302 100644 --- a/metricflow/test/snapshots/test_dataflow_to_sql_plan.py/SqlQueryPlan/test_derived_metric__plan0.xml +++ b/metricflow/test/snapshots/test_dataflow_to_sql_plan.py/SqlQueryPlan/test_derived_metric__plan0.xml @@ -15,10 +15,10 @@ - - - - + + + + @@ -55,10 +55,10 @@ - - - - + + + + @@ -519,10 +519,10 @@ - - - - + + + + diff --git a/metricflow/test/snapshots/test_dataflow_to_sql_plan.py/SqlQueryPlan/test_distinct_values__plan0.xml b/metricflow/test/snapshots/test_dataflow_to_sql_plan.py/SqlQueryPlan/test_distinct_values__plan0.xml index 8386b2eefd..39fb42a6bb 100644 --- a/metricflow/test/snapshots/test_dataflow_to_sql_plan.py/SqlQueryPlan/test_distinct_values__plan0.xml +++ b/metricflow/test/snapshots/test_dataflow_to_sql_plan.py/SqlQueryPlan/test_distinct_values__plan0.xml @@ -43,10 +43,10 @@ - - - - + + + + diff --git a/metricflow/test/snapshots/test_dataflow_to_sql_plan.py/SqlQueryPlan/test_filter_with_where_constraint_on_join_dim__plan0.xml b/metricflow/test/snapshots/test_dataflow_to_sql_plan.py/SqlQueryPlan/test_filter_with_where_constraint_on_join_dim__plan0.xml index c0601b8d84..ee211f46c6 100644 --- a/metricflow/test/snapshots/test_dataflow_to_sql_plan.py/SqlQueryPlan/test_filter_with_where_constraint_on_join_dim__plan0.xml +++ b/metricflow/test/snapshots/test_dataflow_to_sql_plan.py/SqlQueryPlan/test_filter_with_where_constraint_on_join_dim__plan0.xml @@ -19,10 +19,10 @@ - - - - + + + + diff --git a/metricflow/test/snapshots/test_dataflow_to_sql_plan.py/SqlQueryPlan/test_join_to_scd_dimension__plan0.xml b/metricflow/test/snapshots/test_dataflow_to_sql_plan.py/SqlQueryPlan/test_join_to_scd_dimension__plan0.xml index c1050efab2..c61673873d 100644 --- a/metricflow/test/snapshots/test_dataflow_to_sql_plan.py/SqlQueryPlan/test_join_to_scd_dimension__plan0.xml +++ b/metricflow/test/snapshots/test_dataflow_to_sql_plan.py/SqlQueryPlan/test_join_to_scd_dimension__plan0.xml @@ -19,10 +19,10 @@ - - - - + + + + diff --git a/metricflow/test/snapshots/test_dataflow_to_sql_plan.py/SqlQueryPlan/test_limit_rows__plan0.xml b/metricflow/test/snapshots/test_dataflow_to_sql_plan.py/SqlQueryPlan/test_limit_rows__plan0.xml index c9d0e501fa..ca80767100 100644 --- a/metricflow/test/snapshots/test_dataflow_to_sql_plan.py/SqlQueryPlan/test_limit_rows__plan0.xml +++ b/metricflow/test/snapshots/test_dataflow_to_sql_plan.py/SqlQueryPlan/test_limit_rows__plan0.xml @@ -32,10 +32,10 @@ - - - - + + + + diff --git a/metricflow/test/snapshots/test_dataflow_to_sql_plan.py/SqlQueryPlan/test_local_dimension_using_local_identifier__plan0.xml b/metricflow/test/snapshots/test_dataflow_to_sql_plan.py/SqlQueryPlan/test_local_dimension_using_local_identifier__plan0.xml index f20ff9c816..feeb1adfcb 100644 --- a/metricflow/test/snapshots/test_dataflow_to_sql_plan.py/SqlQueryPlan/test_local_dimension_using_local_identifier__plan0.xml +++ b/metricflow/test/snapshots/test_dataflow_to_sql_plan.py/SqlQueryPlan/test_local_dimension_using_local_identifier__plan0.xml @@ -19,10 +19,10 @@ - - - - + + + + diff --git a/metricflow/test/snapshots/test_dataflow_to_sql_plan.py/SqlQueryPlan/test_measure_aggregation_node__plan0.xml b/metricflow/test/snapshots/test_dataflow_to_sql_plan.py/SqlQueryPlan/test_measure_aggregation_node__plan0.xml index 113a53fc1a..eed275d38c 100644 --- a/metricflow/test/snapshots/test_dataflow_to_sql_plan.py/SqlQueryPlan/test_measure_aggregation_node__plan0.xml +++ b/metricflow/test/snapshots/test_dataflow_to_sql_plan.py/SqlQueryPlan/test_measure_aggregation_node__plan0.xml @@ -2,22 +2,22 @@ - - - - - - - - - - - - - - - - + + + + + + + + + + + + + + + + diff --git a/metricflow/test/snapshots/test_dataflow_to_sql_plan.py/SqlQueryPlan/test_measure_constraint__plan0.xml b/metricflow/test/snapshots/test_dataflow_to_sql_plan.py/SqlQueryPlan/test_measure_constraint__plan0.xml index 125a88edb4..33b2a07792 100644 --- a/metricflow/test/snapshots/test_dataflow_to_sql_plan.py/SqlQueryPlan/test_measure_constraint__plan0.xml +++ b/metricflow/test/snapshots/test_dataflow_to_sql_plan.py/SqlQueryPlan/test_measure_constraint__plan0.xml @@ -69,14 +69,14 @@ - - - - - - - - + + + + + + + + @@ -939,10 +939,10 @@ - - - - + + + + diff --git a/metricflow/test/snapshots/test_dataflow_to_sql_plan.py/SqlQueryPlan/test_measure_constraint_with_reused_measure__plan0.xml b/metricflow/test/snapshots/test_dataflow_to_sql_plan.py/SqlQueryPlan/test_measure_constraint_with_reused_measure__plan0.xml index 530b1b3863..55cc2c4874 100644 --- a/metricflow/test/snapshots/test_dataflow_to_sql_plan.py/SqlQueryPlan/test_measure_constraint_with_reused_measure__plan0.xml +++ b/metricflow/test/snapshots/test_dataflow_to_sql_plan.py/SqlQueryPlan/test_measure_constraint_with_reused_measure__plan0.xml @@ -61,10 +61,10 @@ - - - - + + + + @@ -549,10 +549,10 @@ - - - - + + + + diff --git a/metricflow/test/snapshots/test_dataflow_to_sql_plan.py/SqlQueryPlan/test_measure_constraint_with_single_expr_and_alias__plan0.xml b/metricflow/test/snapshots/test_dataflow_to_sql_plan.py/SqlQueryPlan/test_measure_constraint_with_single_expr_and_alias__plan0.xml index 6b55ed0962..23783414ab 100644 --- a/metricflow/test/snapshots/test_dataflow_to_sql_plan.py/SqlQueryPlan/test_measure_constraint_with_single_expr_and_alias__plan0.xml +++ b/metricflow/test/snapshots/test_dataflow_to_sql_plan.py/SqlQueryPlan/test_measure_constraint_with_single_expr_and_alias__plan0.xml @@ -19,10 +19,10 @@ - - - - + + + + diff --git a/metricflow/test/snapshots/test_dataflow_to_sql_plan.py/SqlQueryPlan/test_metric_with_measures_from_multiple_sources_no_dimensions__plan0.xml b/metricflow/test/snapshots/test_dataflow_to_sql_plan.py/SqlQueryPlan/test_metric_with_measures_from_multiple_sources_no_dimensions__plan0.xml index 70510c4e3c..7ab9db8261 100644 --- a/metricflow/test/snapshots/test_dataflow_to_sql_plan.py/SqlQueryPlan/test_metric_with_measures_from_multiple_sources_no_dimensions__plan0.xml +++ b/metricflow/test/snapshots/test_dataflow_to_sql_plan.py/SqlQueryPlan/test_metric_with_measures_from_multiple_sources_no_dimensions__plan0.xml @@ -45,10 +45,10 @@ - - - - + + + + @@ -483,10 +483,10 @@ - - - - + + + + diff --git a/metricflow/test/snapshots/test_dataflow_to_sql_plan.py/SqlQueryPlan/test_multi_hop_through_scd_dimension__plan0.xml b/metricflow/test/snapshots/test_dataflow_to_sql_plan.py/SqlQueryPlan/test_multi_hop_through_scd_dimension__plan0.xml index f80313e649..98573544b7 100644 --- a/metricflow/test/snapshots/test_dataflow_to_sql_plan.py/SqlQueryPlan/test_multi_hop_through_scd_dimension__plan0.xml +++ b/metricflow/test/snapshots/test_dataflow_to_sql_plan.py/SqlQueryPlan/test_multi_hop_through_scd_dimension__plan0.xml @@ -27,10 +27,10 @@ - - - - + + + + diff --git a/metricflow/test/snapshots/test_dataflow_to_sql_plan.py/SqlQueryPlan/test_multi_hop_to_scd_dimension__plan0.xml b/metricflow/test/snapshots/test_dataflow_to_sql_plan.py/SqlQueryPlan/test_multi_hop_to_scd_dimension__plan0.xml index 509456d435..8a44947f5f 100644 --- a/metricflow/test/snapshots/test_dataflow_to_sql_plan.py/SqlQueryPlan/test_multi_hop_to_scd_dimension__plan0.xml +++ b/metricflow/test/snapshots/test_dataflow_to_sql_plan.py/SqlQueryPlan/test_multi_hop_to_scd_dimension__plan0.xml @@ -27,10 +27,10 @@ - - - - + + + + diff --git a/metricflow/test/snapshots/test_dataflow_to_sql_plan.py/SqlQueryPlan/test_multihop_node__plan0.xml b/metricflow/test/snapshots/test_dataflow_to_sql_plan.py/SqlQueryPlan/test_multihop_node__plan0.xml index 757bcac6fd..011532621e 100644 --- a/metricflow/test/snapshots/test_dataflow_to_sql_plan.py/SqlQueryPlan/test_multihop_node__plan0.xml +++ b/metricflow/test/snapshots/test_dataflow_to_sql_plan.py/SqlQueryPlan/test_multihop_node__plan0.xml @@ -19,10 +19,10 @@ - - - - + + + + diff --git a/metricflow/test/snapshots/test_dataflow_to_sql_plan.py/SqlQueryPlan/test_multiple_metrics_no_dimensions__plan0.xml b/metricflow/test/snapshots/test_dataflow_to_sql_plan.py/SqlQueryPlan/test_multiple_metrics_no_dimensions__plan0.xml index 25b6863929..f4399656a0 100644 --- a/metricflow/test/snapshots/test_dataflow_to_sql_plan.py/SqlQueryPlan/test_multiple_metrics_no_dimensions__plan0.xml +++ b/metricflow/test/snapshots/test_dataflow_to_sql_plan.py/SqlQueryPlan/test_multiple_metrics_no_dimensions__plan0.xml @@ -30,10 +30,10 @@ - - - - + + + + @@ -693,10 +693,10 @@ - - - - + + + + diff --git a/metricflow/test/snapshots/test_dataflow_to_sql_plan.py/SqlQueryPlan/test_nested_derived_metric__plan0.xml b/metricflow/test/snapshots/test_dataflow_to_sql_plan.py/SqlQueryPlan/test_nested_derived_metric__plan0.xml index fd6fea4b7b..131cb46293 100644 --- a/metricflow/test/snapshots/test_dataflow_to_sql_plan.py/SqlQueryPlan/test_nested_derived_metric__plan0.xml +++ b/metricflow/test/snapshots/test_dataflow_to_sql_plan.py/SqlQueryPlan/test_nested_derived_metric__plan0.xml @@ -15,10 +15,10 @@ - - - - + + + + @@ -61,10 +61,10 @@ - - - - + + + + @@ -101,10 +101,10 @@ - - - - + + + + @@ -565,10 +565,10 @@ - - - - + + + + @@ -1031,10 +1031,10 @@ - - - - + + + + @@ -1495,10 +1495,10 @@ - - - - + + + + diff --git a/metricflow/test/snapshots/test_dataflow_to_sql_plan.py/SqlQueryPlan/test_order_by_node__plan0.xml b/metricflow/test/snapshots/test_dataflow_to_sql_plan.py/SqlQueryPlan/test_order_by_node__plan0.xml index 20008d0b7a..69bcc0baa3 100644 --- a/metricflow/test/snapshots/test_dataflow_to_sql_plan.py/SqlQueryPlan/test_order_by_node__plan0.xml +++ b/metricflow/test/snapshots/test_dataflow_to_sql_plan.py/SqlQueryPlan/test_order_by_node__plan0.xml @@ -52,10 +52,10 @@ - - - - + + + + diff --git a/metricflow/test/snapshots/test_dataflow_to_sql_plan.py/SqlQueryPlan/test_partitioned_join__plan0.xml b/metricflow/test/snapshots/test_dataflow_to_sql_plan.py/SqlQueryPlan/test_partitioned_join__plan0.xml index 0fce36ff56..0fdfc2d73a 100644 --- a/metricflow/test/snapshots/test_dataflow_to_sql_plan.py/SqlQueryPlan/test_partitioned_join__plan0.xml +++ b/metricflow/test/snapshots/test_dataflow_to_sql_plan.py/SqlQueryPlan/test_partitioned_join__plan0.xml @@ -19,10 +19,10 @@ - - - - + + + + diff --git a/metricflow/test/snapshots/test_dataflow_to_sql_plan.py/SqlQueryPlan/test_semi_additive_join_node__plan0.xml b/metricflow/test/snapshots/test_dataflow_to_sql_plan.py/SqlQueryPlan/test_semi_additive_join_node__plan0.xml index f92ddd8159..57765accb3 100644 --- a/metricflow/test/snapshots/test_dataflow_to_sql_plan.py/SqlQueryPlan/test_semi_additive_join_node__plan0.xml +++ b/metricflow/test/snapshots/test_dataflow_to_sql_plan.py/SqlQueryPlan/test_semi_additive_join_node__plan0.xml @@ -103,10 +103,10 @@ - - - - + + + + diff --git a/metricflow/test/snapshots/test_dataflow_to_sql_plan.py/SqlQueryPlan/test_semi_additive_join_node_with_grouping__plan0.xml b/metricflow/test/snapshots/test_dataflow_to_sql_plan.py/SqlQueryPlan/test_semi_additive_join_node_with_grouping__plan0.xml index 9b66d6a477..1ef9811e05 100644 --- a/metricflow/test/snapshots/test_dataflow_to_sql_plan.py/SqlQueryPlan/test_semi_additive_join_node_with_grouping__plan0.xml +++ b/metricflow/test/snapshots/test_dataflow_to_sql_plan.py/SqlQueryPlan/test_semi_additive_join_node_with_grouping__plan0.xml @@ -107,10 +107,10 @@ - - - - + + + + diff --git a/metricflow/test/snapshots/test_dataflow_to_sql_plan.py/SqlQueryPlan/test_semi_additive_join_node_with_queried_group_by__plan0.xml b/metricflow/test/snapshots/test_dataflow_to_sql_plan.py/SqlQueryPlan/test_semi_additive_join_node_with_queried_group_by__plan0.xml index 30fadde42a..95a0705568 100644 --- a/metricflow/test/snapshots/test_dataflow_to_sql_plan.py/SqlQueryPlan/test_semi_additive_join_node_with_queried_group_by__plan0.xml +++ b/metricflow/test/snapshots/test_dataflow_to_sql_plan.py/SqlQueryPlan/test_semi_additive_join_node_with_queried_group_by__plan0.xml @@ -103,10 +103,10 @@ - - - - + + + + diff --git a/metricflow/test/snapshots/test_metric_time_dimension_to_sql.py/SqlQueryPlan/test_simple_query_with_metric_time_dimension__plan0.xml b/metricflow/test/snapshots/test_metric_time_dimension_to_sql.py/SqlQueryPlan/test_simple_query_with_metric_time_dimension__plan0.xml index 4124c567aa..a09a871f53 100644 --- a/metricflow/test/snapshots/test_metric_time_dimension_to_sql.py/SqlQueryPlan/test_simple_query_with_metric_time_dimension__plan0.xml +++ b/metricflow/test/snapshots/test_metric_time_dimension_to_sql.py/SqlQueryPlan/test_simple_query_with_metric_time_dimension__plan0.xml @@ -2,10 +2,10 @@ - - - - + + + + @@ -42,10 +42,10 @@ - - - - + + + + @@ -506,10 +506,10 @@ - - - - + + + + diff --git a/metricflow/test/sql/optimizer/test_rewriting_sub_query_reducer.py b/metricflow/test/sql/optimizer/test_rewriting_sub_query_reducer.py index 9c88d3702d..4d0eb87a08 100644 --- a/metricflow/test/sql/optimizer/test_rewriting_sub_query_reducer.py +++ b/metricflow/test/sql/optimizer/test_rewriting_sub_query_reducer.py @@ -9,7 +9,7 @@ SqlComparisonExpression, SqlComparison, SqlStringLiteralExpression, - SqlFunctionExpression, + SqlAggregateFunctionExpression, SqlFunction, SqlStringExpression, ) @@ -56,7 +56,7 @@ def base_select_statement() -> SqlSelectStatementNode: description="src3", select_columns=( SqlSelectColumn( - expr=SqlFunctionExpression( + expr=SqlAggregateFunctionExpression( sql_function=SqlFunction.SUM, sql_function_args=[ SqlColumnReferenceExpression( @@ -220,7 +220,7 @@ def join_select_statement() -> SqlSelectStatementNode: description="query", select_columns=( SqlSelectColumn( - expr=SqlFunctionExpression( + expr=SqlAggregateFunctionExpression( sql_function=SqlFunction.SUM, sql_function_args=[ SqlColumnReferenceExpression( @@ -391,7 +391,7 @@ def colliding_select_statement() -> SqlSelectStatementNode: description="query", select_columns=( SqlSelectColumn( - expr=SqlFunctionExpression( + expr=SqlAggregateFunctionExpression( sql_function=SqlFunction.SUM, sql_function_args=[ SqlColumnReferenceExpression( @@ -572,7 +572,7 @@ def reduce_all_join_select_statement() -> SqlSelectStatementNode: description="query", select_columns=( SqlSelectColumn( - expr=SqlFunctionExpression( + expr=SqlAggregateFunctionExpression( sql_function=SqlFunction.SUM, sql_function_args=[ SqlColumnReferenceExpression( diff --git a/metricflow/test/sql/test_sql_expr_render.py b/metricflow/test/sql/test_sql_expr_render.py index 5c67335c0d..a2f7a6cc52 100644 --- a/metricflow/test/sql/test_sql_expr_render.py +++ b/metricflow/test/sql/test_sql_expr_render.py @@ -10,7 +10,7 @@ SqlColumnReference, SqlComparisonExpression, SqlComparison, - SqlFunctionExpression, + SqlAggregateFunctionExpression, SqlFunction, SqlNullExpression, SqlLogicalExpression, @@ -22,6 +22,9 @@ SqlColumnReplacements, SqlCastToTimestampExpression, SqlBetweenExpression, + SqlWindowFunctionExpression, + SqlWindowFunction, + SqlWindowOrderByArgument, ) from metricflow.time.time_granularity import TimeGranularity @@ -72,7 +75,7 @@ def test_require_parenthesis(default_expr_renderer: DefaultSqlExpressionRenderer def test_function_expr(default_expr_renderer: DefaultSqlExpressionRenderer) -> None: # noqa: D actual = default_expr_renderer.render_sql_expr( - SqlFunctionExpression( + SqlAggregateFunctionExpression( sql_function=SqlFunction.SUM, sql_function_args=[ SqlColumnReferenceExpression(SqlColumnReference("my_table", "a")), @@ -86,7 +89,7 @@ def test_function_expr(default_expr_renderer: DefaultSqlExpressionRenderer) -> N def test_distinct_agg_expr(default_expr_renderer: DefaultSqlExpressionRenderer) -> None: """Distinct aggregation functions require the insertion of the DISTINCT keyword in the rendered function expr""" actual = default_expr_renderer.render_sql_expr( - SqlFunctionExpression( + SqlAggregateFunctionExpression( sql_function=SqlFunction.COUNT_DISTINCT, sql_function_args=[ SqlColumnReferenceExpression(SqlColumnReference("my_table", "a")), @@ -100,11 +103,11 @@ def test_distinct_agg_expr(default_expr_renderer: DefaultSqlExpressionRenderer) def test_nested_function_expr(default_expr_renderer: DefaultSqlExpressionRenderer) -> None: # noqa: D actual = default_expr_renderer.render_sql_expr( - SqlFunctionExpression( + SqlAggregateFunctionExpression( sql_function=SqlFunction.CONCAT, sql_function_args=[ SqlColumnReferenceExpression(SqlColumnReference("my_table", "a")), - SqlFunctionExpression( + SqlAggregateFunctionExpression( sql_function=SqlFunction.CONCAT, sql_function_args=[ SqlColumnReferenceExpression(SqlColumnReference("my_table", "b")), @@ -191,7 +194,7 @@ def test_date_trunc_expr(default_expr_renderer: DefaultSqlExpressionRenderer) -> def test_ratio_computation_expr(default_expr_renderer: DefaultSqlExpressionRenderer) -> None: # noqa: D actual = default_expr_renderer.render_sql_expr( SqlRatioComputationExpression( - numerator=SqlFunctionExpression( + numerator=SqlAggregateFunctionExpression( SqlFunction.SUM, sql_function_args=[SqlStringExpression(sql_expr="1", requires_parenthesis=False)] ), denominator=SqlColumnReferenceExpression(SqlColumnReference(column_name="divide_by_me", table_alias="a")), @@ -236,3 +239,32 @@ def test_between_expr(default_expr_renderer: DefaultSqlExpressionRenderer) -> No ) ).sql assert actual == "a.col0 BETWEEN CAST('2020-01-01' AS TIMESTAMP) AND CAST('2020-01-10' AS TIMESTAMP)" + + +def test_window_function_expr(default_expr_renderer: DefaultSqlExpressionRenderer) -> None: # noqa: D + actual = default_expr_renderer.render_sql_expr( + SqlWindowFunctionExpression( + sql_function=SqlWindowFunction.FIRST_VALUE, + sql_function_args=[SqlColumnReferenceExpression(SqlColumnReference("a", "col0"))], + partition_by_args=[ + SqlColumnReferenceExpression(SqlColumnReference("b", "col0")), + SqlColumnReferenceExpression(SqlColumnReference("b", "col1")), + ], + order_by_args=[ + SqlWindowOrderByArgument( + expr=SqlColumnReferenceExpression(SqlColumnReference("a", "col0")), + descending=True, + nulls_last=False, + ), + SqlWindowOrderByArgument( + expr=SqlColumnReferenceExpression(SqlColumnReference("b", "col0")), + descending=False, + nulls_last=True, + ), + ], + ) + ).sql + assert ( + actual + == "first_value(a.col0) OVER (PARTITION BY b.col0, b.col1 ORDER BY a.col0 DESC NULLS FIRST, b.col0 ASC NULLS LAST)" + ) diff --git a/metricflow/test/sql/test_sql_plan_render.py b/metricflow/test/sql/test_sql_plan_render.py index db3dce05ab..748e2e1576 100644 --- a/metricflow/test/sql/test_sql_plan_render.py +++ b/metricflow/test/sql/test_sql_plan_render.py @@ -12,7 +12,7 @@ SqlColumnReference, SqlComparisonExpression, SqlComparison, - SqlFunctionExpression, + SqlAggregateFunctionExpression, SqlFunction, ) from metricflow.sql.sql_plan import ( @@ -44,7 +44,9 @@ def test_component_rendering( # Test single SELECT column select_columns = [ SqlSelectColumn( - expr=SqlFunctionExpression(sql_function=SqlFunction.SUM, sql_function_args=[SqlStringExpression("1")]), + expr=SqlAggregateFunctionExpression( + sql_function=SqlFunction.SUM, sql_function_args=[SqlStringExpression("1")] + ), column_alias="bookings", ), ]