Add CASE, arithmetic and integer SQL expression nodes.

Summary of the changes carried by this patch:

  id_prefix.py      Registers the "case", "arit" and "int" id prefixes.
  sql_exprs.py      Adds SqlIntegerExpression, SqlCaseExpression,
                    SqlArithmeticOperator and SqlArithmeticExpression, plus the
                    matching abstract visitor hooks. Adds ROW_NUMBER and LAG to
                    SqlWindowFunction (both reported as requiring ordering).
                    SqlAddTimeExpression.rewrite() now rewrites count_expr too.
  renderers         visit_add_time_expr() renders count_expr through the
                    visitor, wraps it in parentheses when the node asks for it,
                    and merges the bind parameters of both rendered operands.
                    A QUARTER offset is expressed as MONTH with the count
                    multiplied by 3 via a SqlArithmeticExpression.

Note for reviewers: in the PostgreSQL and Trino renderers the multiplied
expression has to be assigned back to count_expr. Without the assignment the
granularity is switched to MONTH while the count stays unscaled, so an offset
of N quarters would be rendered as N months.

diff --git a/metricflow-semantics/metricflow_semantics/dag/id_prefix.py b/metricflow-semantics/metricflow_semantics/dag/id_prefix.py
index 8c2a6d1b4..61fdcc5f7 100644
--- a/metricflow-semantics/metricflow_semantics/dag/id_prefix.py
+++ b/metricflow-semantics/metricflow_semantics/dag/id_prefix.py
@@ -75,6 +75,9 @@ class StaticIdPrefix(IdPrefix, Enum, metaclass=EnumMetaClassHelper):
     SQL_EXPR_BETWEEN_PREFIX = "betw"
     SQL_EXPR_WINDOW_FUNCTION_ID_PREFIX = "wfnc"
     SQL_EXPR_GENERATE_UUID_PREFIX = "uuid"
+    SQL_EXPR_CASE_PREFIX = "case"
+    SQL_EXPR_ARITHMETIC_PREFIX = "arit"
+    SQL_EXPR_INTEGER_PREFIX = "int"
 
     SQL_PLAN_SELECT_STATEMENT_ID_PREFIX = "ss"
     SQL_PLAN_TABLE_FROM_CLAUSE_ID_PREFIX = "tfc"
diff --git a/metricflow-semantics/metricflow_semantics/sql/sql_exprs.py b/metricflow-semantics/metricflow_semantics/sql/sql_exprs.py
index ec7866f00..926ea5a39 100644
--- a/metricflow-semantics/metricflow_semantics/sql/sql_exprs.py
+++ b/metricflow-semantics/metricflow_semantics/sql/sql_exprs.py
@@ -14,12 +14,13 @@
 from dbt_semantic_interfaces.type_enums.date_part import DatePart
 from dbt_semantic_interfaces.type_enums.period_agg import PeriodAggregation
 from dbt_semantic_interfaces.type_enums.time_granularity import TimeGranularity
+from typing_extensions import override
+
 from metricflow_semantics.collection_helpers.merger import Mergeable
 from metricflow_semantics.dag.id_prefix import IdPrefix, StaticIdPrefix
 from metricflow_semantics.dag.mf_dag import DagNode, DisplayedProperty
 from metricflow_semantics.sql.sql_bind_parameters import SqlBindParameterSet
 from metricflow_semantics.visitor import Visitable, VisitorOutputT
-from typing_extensions import override
 
 
 @dataclass(frozen=True, eq=False)
@@ -237,6 +238,18 @@ def visit_window_function_expr(self, node: SqlWindowFunctionExpression) -> Visit
     def visit_generate_uuid_expr(self, node: SqlGenerateUuidExpression) -> VisitorOutputT:  # noqa: D102
         pass
 
+    @abstractmethod
+    def visit_case_expr(self, node: SqlCaseExpression) -> VisitorOutputT:  # noqa: D102
+        pass
+
+    @abstractmethod
+    def visit_arithmetic_expr(self, node: SqlArithmeticExpression) -> VisitorOutputT:  # noqa: D102
+        pass
+
+    @abstractmethod
+    def visit_integer_expr(self, node: SqlIntegerExpression) -> VisitorOutputT:  # noqa: D102
+        pass
+
 
 @dataclass(frozen=True, eq=False)
 class SqlStringExpression(SqlExpressionNode):
@@ -375,6 +388,59 @@ def matches(self, other: SqlExpressionNode) -> bool:  # noqa: D102
         return self.literal_value == other.literal_value
 
 
+@dataclass(frozen=True, eq=False)
+class SqlIntegerExpression(SqlExpressionNode):
+    """An integer like 1."""
+
+    integer_value: int
+
+    @staticmethod
+    def create(integer_value: int) -> SqlIntegerExpression:  # noqa: D102
+        return SqlIntegerExpression(parent_nodes=(), integer_value=integer_value)
+
+    @classmethod
+    def id_prefix(cls) -> IdPrefix:  # noqa: D102
+        return StaticIdPrefix.SQL_EXPR_INTEGER_PREFIX
+
+    def accept(self, visitor: SqlExpressionNodeVisitor[VisitorOutputT]) -> VisitorOutputT:  # noqa: D102
+        return visitor.visit_integer_expr(self)
+
+    @property
+    def description(self) -> str:  # noqa: D102
+        return f"Integer: {self.integer_value}"
+
+    @property
+    def displayed_properties(self) -> Sequence[DisplayedProperty]:  # noqa: D102
+        return tuple(super().displayed_properties) + (DisplayedProperty("value", self.integer_value),)
+
+    @property
+    def requires_parenthesis(self) -> bool:  # noqa: D102
+        return False
+
+    @property
+    def bind_parameter_set(self) -> SqlBindParameterSet:  # noqa: D102
+        return SqlBindParameterSet()
+
+    def __repr__(self) -> str:  # noqa: D105
+        return f"{self.__class__.__name__}(node_id={self.node_id}, integer_value={self.integer_value})"
+
+    def rewrite(  # noqa: D102
+        self,
+        column_replacements: Optional[SqlColumnReplacements] = None,
+        should_render_table_alias: Optional[bool] = None,
+    ) -> SqlExpressionNode:
+        return self
+
+    @property
+    def lineage(self) -> SqlExpressionTreeLineage:  # noqa: D102
+        return SqlExpressionTreeLineage(other_exprs=(self,))
+
+    def matches(self, other: SqlExpressionNode) -> bool:  # noqa: D102
+        if not isinstance(other, SqlIntegerExpression):
+            return False
+        return self.integer_value == other.integer_value
+
+
 @dataclass(frozen=True)
 class SqlColumnReference:
     """Used with string expressions to specify what columns are referred to in the string expression."""
@@ -950,11 +1016,18 @@ class SqlWindowFunction(Enum):
     FIRST_VALUE = "FIRST_VALUE"
     LAST_VALUE = "LAST_VALUE"
     AVERAGE = "AVG"
+    ROW_NUMBER = "ROW_NUMBER"
+    LAG = "LAG"
 
     @property
     def requires_ordering(self) -> bool:
         """Asserts whether or not ordering the window function will have an impact on the resulting value."""
-        if self is SqlWindowFunction.FIRST_VALUE or self is SqlWindowFunction.LAST_VALUE:
+        if (
+            self is SqlWindowFunction.FIRST_VALUE
+            or self is SqlWindowFunction.LAST_VALUE
+            or self is SqlWindowFunction.ROW_NUMBER
+            or self is SqlWindowFunction.LAG
+        ):
             return True
         elif self is SqlWindowFunction.AVERAGE:
             return False
@@ -1106,7 +1179,8 @@ def matches(self, other: SqlExpressionNode) -> bool:  # noqa: D102
         return (
             self.sql_function == other.sql_function
             and self.order_by_args == other.order_by_args
-            and self._parents_match(other)
+            and self.partition_by_args == other.partition_by_args
+            and self.sql_function_args == other.sql_function_args
         )
 
 
@@ -1367,7 +1441,7 @@ def rewrite(  # noqa: D102
     ) -> SqlExpressionNode:
         return SqlAddTimeExpression.create(
             arg=self.arg.rewrite(column_replacements, should_render_table_alias),
-            count_expr=self.count_expr,
+            count_expr=self.count_expr.rewrite(column_replacements, should_render_table_alias),
             granularity=self.granularity,
         )
 
@@ -1719,3 +1793,158 @@ def lineage(self) -> SqlExpressionTreeLineage:  # noqa: D102
 
     def matches(self, other: SqlExpressionNode) -> bool:  # noqa: D102
         return False
+
+
+@dataclass(frozen=True, eq=False)
+class SqlCaseExpression(SqlExpressionNode):
+    """Renders a CASE WHEN expression."""
+
+    when_to_then_exprs: Dict[SqlExpressionNode, SqlExpressionNode]
+    else_expr: Optional[SqlExpressionNode]
+
+    @staticmethod
+    def create(  # noqa: D102
+        when_to_then_exprs: Dict[SqlExpressionNode, SqlExpressionNode], else_expr: Optional[SqlExpressionNode] = None
+    ) -> SqlCaseExpression:
+        parent_nodes: Tuple[SqlExpressionNode, ...] = ()
+        for when, then in when_to_then_exprs.items():
+            parent_nodes += (when,)
+            parent_nodes += (then,)
+
+        if else_expr:
+            parent_nodes += (else_expr,)
+
+        return SqlCaseExpression(parent_nodes=parent_nodes, when_to_then_exprs=when_to_then_exprs, else_expr=else_expr)
+
+    @classmethod
+    def id_prefix(cls) -> IdPrefix:  # noqa: D102
+        return StaticIdPrefix.SQL_EXPR_CASE_PREFIX
+
+    def accept(self, visitor: SqlExpressionNodeVisitor[VisitorOutputT]) -> VisitorOutputT:  # noqa: D102
+        return visitor.visit_case_expr(self)
+
+    @property
+    def description(self) -> str:  # noqa: D102
+        return "Case expression"
+
+    @property
+    def displayed_properties(self) -> Sequence[DisplayedProperty]:  # noqa: D102
+        return super().displayed_properties
+
+    @property
+    def requires_parenthesis(self) -> bool:  # noqa: D102
+        return False
+
+    @property
+    def bind_parameter_set(self) -> SqlBindParameterSet:  # noqa: D102
+        return SqlBindParameterSet()
+
+    def __repr__(self) -> str:  # noqa: D105
+        return f"{self.__class__.__name__}(node_id={self.node_id})"
+
+    def rewrite(  # noqa: D102
+        self,
+        column_replacements: Optional[SqlColumnReplacements] = None,
+        should_render_table_alias: Optional[bool] = None,
+    ) -> SqlExpressionNode:
+        return SqlCaseExpression.create(
+            when_to_then_exprs={
+                when.rewrite(column_replacements, should_render_table_alias): then.rewrite(
+                    column_replacements, should_render_table_alias
+                )
+                for when, then in self.when_to_then_exprs.items()
+            },
+            else_expr=(
+                self.else_expr.rewrite(column_replacements, should_render_table_alias) if self.else_expr else None
+            ),
+        )
+
+    @property
+    def lineage(self) -> SqlExpressionTreeLineage:  # noqa: D102
+        return SqlExpressionTreeLineage.merge_iterable(
+            tuple(x.lineage for x in self.parent_nodes) + (SqlExpressionTreeLineage(other_exprs=(self,)),)
+        )
+
+    def matches(self, other: SqlExpressionNode) -> bool:  # noqa: D102
+        if not isinstance(other, SqlCaseExpression):
+            return False
+        return self.when_to_then_exprs == other.when_to_then_exprs and self.else_expr == other.else_expr
+
+
+class SqlArithmeticOperator(Enum):
+    """Arithmetic operator used to do math in a SQL expression."""
+
+    ADD = "+"
+    SUBTRACT = "-"
+    MULTIPLY = "*"
+    DIVIDE = "/"
+
+
+@dataclass(frozen=True, eq=False)
+class SqlArithmeticExpression(SqlExpressionNode):
+    """An arithmetic expression using +, -, *, /.
+
+    e.g. my_table.my_column + my_table.other_column
+
+    Attributes:
+        left_expr: The expression on the left side of the operator
+        operator: The operator to use on the expressions
+        right_expr: The expression on the right side of the operator
+    """
+
+    left_expr: SqlExpressionNode
+    operator: SqlArithmeticOperator
+    right_expr: SqlExpressionNode
+
+    @staticmethod
+    def create(  # noqa: D102
+        left_expr: SqlExpressionNode, operator: SqlArithmeticOperator, right_expr: SqlExpressionNode
+    ) -> SqlArithmeticExpression:
+        return SqlArithmeticExpression(
+            parent_nodes=(left_expr, right_expr), left_expr=left_expr, operator=operator, right_expr=right_expr
+        )
+
+    @classmethod
+    def id_prefix(cls) -> IdPrefix:  # noqa: D102
+        return StaticIdPrefix.SQL_EXPR_ARITHMETIC_PREFIX
+
+    def accept(self, visitor: SqlExpressionNodeVisitor[VisitorOutputT]) -> VisitorOutputT:  # noqa: D102
+        return visitor.visit_arithmetic_expr(self)
+
+    @property
+    def description(self) -> str:  # noqa: D102
+        return "Arithmetic Expression"
+
+    @property
+    def displayed_properties(self) -> Sequence[DisplayedProperty]:  # noqa: D102
+        return tuple(super().displayed_properties) + (
+            DisplayedProperty("left_expr", self.left_expr),
+            DisplayedProperty("operator", self.operator.value),
+            DisplayedProperty("right_expr", self.right_expr),
+        )
+
+    @property
+    def requires_parenthesis(self) -> bool:  # noqa: D102
+        return True
+
+    def rewrite(  # noqa: D102
+        self,
+        column_replacements: Optional[SqlColumnReplacements] = None,
+        should_render_table_alias: Optional[bool] = None,
+    ) -> SqlExpressionNode:
+        return SqlArithmeticExpression.create(
+            left_expr=self.left_expr.rewrite(column_replacements, should_render_table_alias),
+            operator=self.operator,
+            right_expr=self.right_expr.rewrite(column_replacements, should_render_table_alias),
+        )
+
+    @property
+    def lineage(self) -> SqlExpressionTreeLineage:  # noqa: D102
+        return SqlExpressionTreeLineage.merge_iterable(
+            tuple(x.lineage for x in self.parent_nodes) + (SqlExpressionTreeLineage(other_exprs=(self,)),)
+        )
+
+    def matches(self, other: SqlExpressionNode) -> bool:  # noqa: D102
+        if not isinstance(other, SqlArithmeticExpression):
+            return False
+        return self.operator == other.operator and self._parents_match(other)
diff --git a/metricflow/sql/render/duckdb_renderer.py b/metricflow/sql/render/duckdb_renderer.py
index ecfca54f5..48d0c1672 100644
--- a/metricflow/sql/render/duckdb_renderer.py
+++ b/metricflow/sql/render/duckdb_renderer.py
@@ -7,7 +7,10 @@
 from metricflow_semantics.sql.sql_bind_parameters import SqlBindParameterSet
 from metricflow_semantics.sql.sql_exprs import (
     SqlAddTimeExpression,
+    SqlArithmeticExpression,
+    SqlArithmeticOperator,
     SqlGenerateUuidExpression,
+    SqlIntegerExpression,
     SqlPercentileExpression,
     SqlPercentileFunctionType,
     SqlSubtractTimeIntervalExpression,
@@ -56,17 +59,25 @@ def visit_subtract_time_interval_expr(self, node: SqlSubtractTimeIntervalExpress
     @override
     def visit_add_time_expr(self, node: SqlAddTimeExpression) -> SqlExpressionRenderResult:
         """Render time delta expression for DuckDB, which requires slightly different syntax from other engines."""
-        arg_rendered = node.arg.accept(self)
-        count_rendered = node.count_expr.accept(self).sql
-
         granularity = node.granularity
+        count_expr = node.count_expr
         if granularity is TimeGranularity.QUARTER:
             granularity = TimeGranularity.MONTH
-            count_rendered = f"({count_rendered} * 3)"
+            count_expr = SqlArithmeticExpression.create(
+                left_expr=node.count_expr,
+                operator=SqlArithmeticOperator.MULTIPLY,
+                right_expr=SqlIntegerExpression.create(3),
+            )
+
+        arg_rendered = node.arg.accept(self)
+        count_rendered = count_expr.accept(self)
+        count_sql = f"({count_rendered.sql})" if count_expr.requires_parenthesis else count_rendered.sql
 
         return SqlExpressionRenderResult(
-            sql=f"{arg_rendered.sql} + INTERVAL {count_rendered} {granularity.value}",
-            bind_parameter_set=arg_rendered.bind_parameter_set,
+            sql=f"{arg_rendered.sql} + INTERVAL {count_sql} {granularity.value}",
+            bind_parameter_set=SqlBindParameterSet.merge_iterable(
+                (arg_rendered.bind_parameter_set, count_rendered.bind_parameter_set)
+            ),
         )
 
     @override
diff --git a/metricflow/sql/render/expr_renderer.py b/metricflow/sql/render/expr_renderer.py
index 10e3d748b..a89dc2abb 100644
--- a/metricflow/sql/render/expr_renderer.py
+++ b/metricflow/sql/render/expr_renderer.py
@@ -15,7 +15,10 @@
 from metricflow_semantics.sql.sql_exprs import (
     SqlAddTimeExpression,
     SqlAggregateFunctionExpression,
+    SqlArithmeticExpression,
+    SqlArithmeticOperator,
     SqlBetweenExpression,
+    SqlCaseExpression,
     SqlCastToTimestampExpression,
     SqlColumnAliasReferenceExpression,
     SqlColumnReferenceExpression,
@@ -26,6 +29,7 @@
     SqlExtractExpression,
     SqlFunction,
     SqlGenerateUuidExpression,
+    SqlIntegerExpression,
     SqlIsNullExpression,
     SqlLogicalExpression,
     SqlNullExpression,
@@ -320,17 +324,25 @@ def visit_subtract_time_interval_expr(  # noqa: D102
         )
 
     def visit_add_time_expr(self, node: SqlAddTimeExpression) -> SqlExpressionRenderResult:  # noqa: D102
-        arg_rendered = node.arg.accept(self)
-        count_rendered = node.count_expr.accept(self).sql
-
         granularity = node.granularity
+        count_expr = node.count_expr
         if granularity is TimeGranularity.QUARTER:
             granularity = TimeGranularity.MONTH
-            count_rendered = f"({count_rendered} * 3)"
+            count_expr = SqlArithmeticExpression.create(
+                left_expr=node.count_expr,
+                operator=SqlArithmeticOperator.MULTIPLY,
+                right_expr=SqlIntegerExpression.create(3),
+            )
+
+        arg_rendered = node.arg.accept(self)
+        count_rendered = count_expr.accept(self)
+        count_sql = f"({count_rendered.sql})" if count_expr.requires_parenthesis else count_rendered.sql
 
         return SqlExpressionRenderResult(
-            sql=f"DATEADD({granularity.value}, {count_rendered}, {arg_rendered.sql})",
-            bind_parameter_set=arg_rendered.bind_parameter_set,
+            sql=f"DATEADD({granularity.value}, {count_sql}, {arg_rendered.sql})",
+            bind_parameter_set=SqlBindParameterSet.merge_iterable(
+                (arg_rendered.bind_parameter_set, count_rendered.bind_parameter_set)
+            ),
         )
 
     def visit_ratio_computation_expr(self, node: SqlRatioComputationExpression) -> SqlExpressionRenderResult:
@@ -438,3 +450,27 @@ def visit_generate_uuid_expr(self, node: SqlGenerateUuidExpression) -> SqlExpres
             sql="UUID()",
             bind_parameter_set=SqlBindParameterSet(),
         )
+
+    def visit_case_expr(self, node: SqlCaseExpression) -> SqlExpressionRenderResult:  # noqa: D102
+        sql = "CASE\n"
+        for when, then in node.when_to_then_exprs.items():
+            sql += indent(
+                f"WHEN {self.render_sql_expr(when).sql}\n", indent_prefix=SqlRenderingConstants.INDENT
+            ) + indent(
+                f"THEN {self.render_sql_expr(then).sql}\n",
+                indent_prefix=SqlRenderingConstants.INDENT * 2,
+            )
+        if node.else_expr:
+            sql += indent(
+                f"ELSE {self.render_sql_expr(node.else_expr).sql}\n",
+                indent_prefix=SqlRenderingConstants.INDENT,
+            )
+        sql += "END"
+        return SqlExpressionRenderResult(sql=sql, bind_parameter_set=SqlBindParameterSet())
+
+    def visit_arithmetic_expr(self, node: SqlArithmeticExpression) -> SqlExpressionRenderResult:  # noqa: D102
+        sql = f"{self.render_sql_expr(node.left_expr).sql} {node.operator.value} {self.render_sql_expr(node.right_expr).sql}"
+        return SqlExpressionRenderResult(sql=sql, bind_parameter_set=SqlBindParameterSet())
+
+    def visit_integer_expr(self, node: SqlIntegerExpression) -> SqlExpressionRenderResult:  # noqa: D102
+        return SqlExpressionRenderResult(sql=str(node.integer_value), bind_parameter_set=SqlBindParameterSet())
diff --git a/metricflow/sql/render/postgres.py b/metricflow/sql/render/postgres.py
index 2509dfc24..92e910eb1 100644
--- a/metricflow/sql/render/postgres.py
+++ b/metricflow/sql/render/postgres.py
@@ -8,7 +8,10 @@
 from metricflow_semantics.sql.sql_bind_parameters import SqlBindParameterSet
 from metricflow_semantics.sql.sql_exprs import (
     SqlAddTimeExpression,
+    SqlArithmeticExpression,
+    SqlArithmeticOperator,
     SqlGenerateUuidExpression,
+    SqlIntegerExpression,
     SqlPercentileExpression,
     SqlPercentileFunctionType,
     SqlSubtractTimeIntervalExpression,
@@ -58,17 +61,25 @@ def visit_subtract_time_interval_expr(self, node: SqlSubtractTimeIntervalExpress
     @override
     def visit_add_time_expr(self, node: SqlAddTimeExpression) -> SqlExpressionRenderResult:
         """Render time delta operations for PostgreSQL, which needs custom support for quarterly granularity."""
-        arg_rendered = node.arg.accept(self)
-        count_rendered = node.count_expr.accept(self).sql
-
         granularity = node.granularity
+        count_expr = node.count_expr
         if granularity is TimeGranularity.QUARTER:
             granularity = TimeGranularity.MONTH
-            count_rendered = f"({count_rendered} * 3)"
+            count_expr = SqlArithmeticExpression.create(
+                left_expr=node.count_expr,
+                operator=SqlArithmeticOperator.MULTIPLY,
+                right_expr=SqlIntegerExpression.create(3),
+            )
+
+        arg_rendered = node.arg.accept(self)
+        count_rendered = count_expr.accept(self)
+        count_sql = f"({count_rendered.sql})" if count_expr.requires_parenthesis else count_rendered.sql
 
         return SqlExpressionRenderResult(
-            sql=f"{arg_rendered.sql} + MAKE_INTERVAL({granularity.value}s => {count_rendered})",
-            bind_parameter_set=arg_rendered.bind_parameter_set,
+            sql=f"{arg_rendered.sql} + MAKE_INTERVAL({granularity.value}s => {count_sql})",
+            bind_parameter_set=SqlBindParameterSet.merge_iterable(
+                (arg_rendered.bind_parameter_set, count_rendered.bind_parameter_set)
+            ),
         )
 
     @override
diff --git a/metricflow/sql/render/trino.py b/metricflow/sql/render/trino.py
index bd3a58159..f0a8ea2da 100644
--- a/metricflow/sql/render/trino.py
+++ b/metricflow/sql/render/trino.py
@@ -9,8 +9,11 @@
 from metricflow_semantics.sql.sql_bind_parameters import SqlBindParameterSet
 from metricflow_semantics.sql.sql_exprs import (
     SqlAddTimeExpression,
+    SqlArithmeticExpression,
+    SqlArithmeticOperator,
     SqlBetweenExpression,
     SqlGenerateUuidExpression,
+    SqlIntegerExpression,
     SqlPercentileExpression,
     SqlPercentileFunctionType,
     SqlSubtractTimeIntervalExpression,
@@ -63,17 +66,25 @@ def visit_subtract_time_interval_expr(self, node: SqlSubtractTimeIntervalExpress
     @override
     def visit_add_time_expr(self, node: SqlAddTimeExpression) -> SqlExpressionRenderResult:
         """Render time delta for Trino, require granularity in quotes and function name change."""
-        arg_rendered = node.arg.accept(self)
-        count_rendered = node.count_expr.accept(self).sql
-
         granularity = node.granularity
+        count_expr = node.count_expr
         if granularity is TimeGranularity.QUARTER:
             granularity = TimeGranularity.MONTH
-            count_rendered = f"({count_rendered} * 3)"
+            count_expr = SqlArithmeticExpression.create(
+                left_expr=node.count_expr,
+                operator=SqlArithmeticOperator.MULTIPLY,
+                right_expr=SqlIntegerExpression.create(3),
+            )
+
+        arg_rendered = node.arg.accept(self)
+        count_rendered = count_expr.accept(self)
+        count_sql = f"({count_rendered.sql})" if count_expr.requires_parenthesis else count_rendered.sql
 
         return SqlExpressionRenderResult(
-            sql=f"DATE_ADD('{granularity.value}', {count_rendered}, {arg_rendered.sql})",
-            bind_parameter_set=arg_rendered.bind_parameter_set,
+            sql=f"DATE_ADD('{granularity.value}', {count_sql}, {arg_rendered.sql})",
+            bind_parameter_set=SqlBindParameterSet.merge_iterable(
+                (arg_rendered.bind_parameter_set, count_rendered.bind_parameter_set)
+            ),
         )
 
     @override