Skip to content

Commit

Permalink
Add SQL exprs needed for custom offset window (#1575)
Browse files Browse the repository at this point in the history
This includes a `CASE` expression, an integer expression, an arithmetic
expression, and some updates to the add time expression & window
function expressions. These are all needed to build the following SQL
column:
```
CASE
  WHEN ds__martian_day__first_value__offset + INTERVAL (ds__day__row_number - 1) day <= ds__martian_day__last_value__offset
    THEN ds__martian_day__first_value__offset + INTERVAL (ds__day__row_number - 1) day
  ELSE ds__martian_day__last_value__offset
  END AS metric_time__day
```

Note that the first commit is moving the `sql_exprs` file into
`metricflow-semantics`, which is needed for a commit further up the
stack.
  • Loading branch information
courtneyholcomb authored Dec 21, 2024
1 parent 72f692f commit 48e85f5
Show file tree
Hide file tree
Showing 39 changed files with 559 additions and 243 deletions.
3 changes: 3 additions & 0 deletions metricflow-semantics/metricflow_semantics/dag/id_prefix.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,9 @@ class StaticIdPrefix(IdPrefix, Enum, metaclass=EnumMetaClassHelper):
SQL_EXPR_BETWEEN_PREFIX = "betw"
SQL_EXPR_WINDOW_FUNCTION_ID_PREFIX = "wfnc"
SQL_EXPR_GENERATE_UUID_PREFIX = "uuid"
SQL_EXPR_CASE_PREFIX = "case"
SQL_EXPR_ARITHMETIC_PREFIX = "arit"
SQL_EXPR_INTEGER_PREFIX = "int"

SQL_PLAN_SELECT_STATEMENT_ID_PREFIX = "ss"
SQL_PLAN_TABLE_FROM_CLAUSE_ID_PREFIX = "tfc"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,13 @@
from dbt_semantic_interfaces.type_enums.date_part import DatePart
from dbt_semantic_interfaces.type_enums.period_agg import PeriodAggregation
from dbt_semantic_interfaces.type_enums.time_granularity import TimeGranularity
from typing_extensions import override

from metricflow_semantics.collection_helpers.merger import Mergeable
from metricflow_semantics.dag.id_prefix import IdPrefix, StaticIdPrefix
from metricflow_semantics.dag.mf_dag import DagNode, DisplayedProperty
from metricflow_semantics.sql.sql_bind_parameters import SqlBindParameterSet
from metricflow_semantics.visitor import Visitable, VisitorOutputT
from typing_extensions import override


@dataclass(frozen=True, eq=False)
Expand Down Expand Up @@ -237,6 +238,18 @@ def visit_window_function_expr(self, node: SqlWindowFunctionExpression) -> Visit
def visit_generate_uuid_expr(self, node: SqlGenerateUuidExpression) -> VisitorOutputT: # noqa: D102
pass

@abstractmethod
def visit_case_expr(self, node: SqlCaseExpression) -> VisitorOutputT:
    """Handle a `SqlCaseExpression` node and return this visitor's output for it."""

@abstractmethod
def visit_arithmetic_expr(self, node: SqlArithmeticExpression) -> VisitorOutputT:
    """Handle a `SqlArithmeticExpression` node and return this visitor's output for it."""

@abstractmethod
def visit_integer_expr(self, node: SqlIntegerExpression) -> VisitorOutputT:
    """Handle a `SqlIntegerExpression` node and return this visitor's output for it."""


@dataclass(frozen=True, eq=False)
class SqlStringExpression(SqlExpressionNode):
Expand Down Expand Up @@ -375,6 +388,59 @@ def matches(self, other: SqlExpressionNode) -> bool: # noqa: D102
return self.literal_value == other.literal_value


@dataclass(frozen=True, eq=False)
class SqlIntegerExpression(SqlExpressionNode):
    """An integer literal in a SQL expression, e.g. 1.

    Attributes:
        integer_value: The integer that this node stands for.
    """

    integer_value: int

    @staticmethod
    def create(integer_value: int) -> SqlIntegerExpression:
        """Build a node for `integer_value`; the node has no parent nodes."""
        return SqlIntegerExpression(integer_value=integer_value, parent_nodes=())

    @classmethod
    def id_prefix(cls) -> IdPrefix:
        """Return the ID prefix used for integer expression nodes."""
        return StaticIdPrefix.SQL_EXPR_INTEGER_PREFIX

    def accept(self, visitor: SqlExpressionNodeVisitor[VisitorOutputT]) -> VisitorOutputT:
        """Dispatch to the visitor's `visit_integer_expr` hook."""
        return visitor.visit_integer_expr(self)

    @property
    def description(self) -> str:
        """Short description that includes the integer value."""
        return f"Integer: {self.integer_value}"

    @property
    def displayed_properties(self) -> Sequence[DisplayedProperty]:
        """The inherited displayed properties followed by this node's value."""
        inherited_properties = tuple(super().displayed_properties)
        value_property = DisplayedProperty("value", self.integer_value)
        return inherited_properties + (value_property,)

    @property
    def requires_parenthesis(self) -> bool:
        """Always False for this node type."""
        return False

    @property
    def bind_parameter_set(self) -> SqlBindParameterSet:
        """An empty bind parameter set."""
        return SqlBindParameterSet()

    def __repr__(self) -> str:
        """Return the class name together with the node ID and the integer value."""
        class_name = self.__class__.__name__
        return f"{class_name}(node_id={self.node_id}, integer_value={self.integer_value})"

    def rewrite(
        self,
        column_replacements: Optional[SqlColumnReplacements] = None,
        should_render_table_alias: Optional[bool] = None,
    ) -> SqlExpressionNode:
        """Return this same node; both arguments are ignored."""
        return self

    @property
    def lineage(self) -> SqlExpressionTreeLineage:
        """A lineage whose `other_exprs` contains only this node."""
        return SqlExpressionTreeLineage(other_exprs=(self,))

    def matches(self, other: SqlExpressionNode) -> bool:
        """Return True when `other` is also an integer expression holding an equal value."""
        return isinstance(other, SqlIntegerExpression) and self.integer_value == other.integer_value


@dataclass(frozen=True)
class SqlColumnReference:
"""Used with string expressions to specify what columns are referred to in the string expression."""
Expand Down Expand Up @@ -950,17 +1016,38 @@ class SqlWindowFunction(Enum):
FIRST_VALUE = "FIRST_VALUE"
LAST_VALUE = "LAST_VALUE"
AVERAGE = "AVG"
ROW_NUMBER = "ROW_NUMBER"
LEAD = "LEAD"

@property
def requires_ordering(self) -> bool:
    """Asserts whether or not ordering the window function will have an impact on the resulting value.

    Returns:
        True for FIRST_VALUE, LAST_VALUE, ROW_NUMBER and LEAD; False for AVERAGE.
    """
    if (
        self is SqlWindowFunction.FIRST_VALUE
        or self is SqlWindowFunction.LAST_VALUE
        or self is SqlWindowFunction.ROW_NUMBER
        or self is SqlWindowFunction.LEAD
    ):
        return True
    elif self is SqlWindowFunction.AVERAGE:
        return False
    else:
        # Every enum member is handled above, so this branch flags a member that was added without updating
        # this property.
        assert_values_exhausted(self)

@property
def allows_frame_clause(self) -> bool:
    """Whether the function allows a frame clause, e.g., 'ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING'.

    Returns:
        False for ROW_NUMBER and LEAD; True for FIRST_VALUE, LAST_VALUE and AVERAGE.
    """
    if self is SqlWindowFunction.ROW_NUMBER or self is SqlWindowFunction.LEAD:
        return False
    elif (
        self is SqlWindowFunction.FIRST_VALUE
        or self is SqlWindowFunction.LAST_VALUE
        or self is SqlWindowFunction.AVERAGE
    ):
        return True
    else:
        # Only reached if a new enum member is added without being classified above.
        assert_values_exhausted(self)

@classmethod
def get_window_function_for_period_agg(cls, period_agg: PeriodAggregation) -> SqlWindowFunction:
"""Get the window function to use for given period agg option."""
Expand Down Expand Up @@ -1106,7 +1193,8 @@ def matches(self, other: SqlExpressionNode) -> bool: # noqa: D102
return (
self.sql_function == other.sql_function
and self.order_by_args == other.order_by_args
and self._parents_match(other)
and self.partition_by_args == other.partition_by_args
and self.sql_function_args == other.sql_function_args
)


Expand Down Expand Up @@ -1367,7 +1455,7 @@ def rewrite( # noqa: D102
) -> SqlExpressionNode:
return SqlAddTimeExpression.create(
arg=self.arg.rewrite(column_replacements, should_render_table_alias),
count_expr=self.count_expr,
count_expr=self.count_expr.rewrite(column_replacements, should_render_table_alias),
granularity=self.granularity,
)

Expand Down Expand Up @@ -1719,3 +1807,157 @@ def lineage(self) -> SqlExpressionTreeLineage: # noqa: D102

def matches(self, other: SqlExpressionNode) -> bool: # noqa: D102
return False


@dataclass(frozen=True, eq=False)
class SqlCaseExpression(SqlExpressionNode):
    """Renders a CASE WHEN expression.

    Attributes:
        when_to_then_exprs: Maps each WHEN condition to the expression used as its THEN result. Parent nodes
            are collected in this dict's iteration order.
        else_expr: The expression for the ELSE branch, or None when there is no ELSE branch.
    """

    when_to_then_exprs: Dict[SqlExpressionNode, SqlExpressionNode]
    else_expr: Optional[SqlExpressionNode]

    @staticmethod
    def create(
        when_to_then_exprs: Dict[SqlExpressionNode, SqlExpressionNode], else_expr: Optional[SqlExpressionNode] = None
    ) -> SqlCaseExpression:
        """Build a CASE expression node.

        The parent nodes are laid out as (when_1, then_1, ..., when_n, then_n) followed by `else_expr` if one
        is given.
        """
        parent_nodes: Tuple[SqlExpressionNode, ...] = tuple(
            expr for when_then_pair in when_to_then_exprs.items() for expr in when_then_pair
        )
        # Test against None explicitly so that the decision does not depend on the truthiness of the node.
        if else_expr is not None:
            parent_nodes += (else_expr,)

        return SqlCaseExpression(parent_nodes=parent_nodes, when_to_then_exprs=when_to_then_exprs, else_expr=else_expr)

    @classmethod
    def id_prefix(cls) -> IdPrefix:
        """Return the ID prefix used for CASE expression nodes."""
        return StaticIdPrefix.SQL_EXPR_CASE_PREFIX

    def accept(self, visitor: SqlExpressionNodeVisitor[VisitorOutputT]) -> VisitorOutputT:
        """Dispatch to the visitor's `visit_case_expr` hook."""
        return visitor.visit_case_expr(self)

    @property
    def description(self) -> str:
        """Short fixed description of this node type."""
        return "Case expression"

    @property
    def displayed_properties(self) -> Sequence[DisplayedProperty]:
        """The inherited displayed properties, unchanged."""
        return super().displayed_properties

    @property
    def requires_parenthesis(self) -> bool:
        """Always False for this node type."""
        return False

    @property
    def bind_parameter_set(self) -> SqlBindParameterSet:
        """An empty bind parameter set."""
        # NOTE(review): bind parameters carried by the WHEN / THEN / ELSE expressions are not merged in here.
        # Confirm that this is intended.
        return SqlBindParameterSet()

    def __repr__(self) -> str:
        """Return the class name together with the node ID."""
        return f"{self.__class__.__name__}(node_id={self.node_id})"

    def rewrite(
        self,
        column_replacements: Optional[SqlColumnReplacements] = None,
        should_render_table_alias: Optional[bool] = None,
    ) -> SqlExpressionNode:
        """Return a new CASE expression in which every WHEN, THEN and ELSE expression has been rewritten."""
        rewritten_else_expr = (
            self.else_expr.rewrite(column_replacements, should_render_table_alias)
            if self.else_expr is not None
            else None
        )
        return SqlCaseExpression.create(
            when_to_then_exprs={
                when.rewrite(column_replacements, should_render_table_alias): then.rewrite(
                    column_replacements, should_render_table_alias
                )
                for when, then in self.when_to_then_exprs.items()
            },
            else_expr=rewritten_else_expr,
        )

    @property
    def lineage(self) -> SqlExpressionTreeLineage:
        """The merged lineage of all parent nodes plus this node itself."""
        return SqlExpressionTreeLineage.merge_iterable(
            tuple(x.lineage for x in self.parent_nodes) + (SqlExpressionTreeLineage(other_exprs=(self,)),)
        )

    def matches(self, other: SqlExpressionNode) -> bool:
        """Return True when `other` is a CASE expression with equal branches in the same order and an equal ELSE."""
        if not isinstance(other, SqlCaseExpression):
            return False
        # Plain dict equality ignores insertion order, but the order of WHEN branches matters when more than
        # one condition can be true, so compare the (when, then) pairs as ordered sequences.
        branches_match = tuple(self.when_to_then_exprs.items()) == tuple(other.when_to_then_exprs.items())
        return branches_match and self.else_expr == other.else_expr


class SqlArithmeticOperator(Enum):
    """The binary math operators that an arithmetic SQL expression can apply.

    Each member's value is the symbol of the operator.
    """

    ADD = "+"
    SUBTRACT = "-"
    MULTIPLY = "*"
    DIVIDE = "/"


@dataclass(frozen=True, eq=False)
class SqlArithmeticExpression(SqlExpressionNode):
    """A binary arithmetic expression built with one of +, -, *, /.

    Example: my_table.my_column + my_table.other_column

    Attributes:
        left_expr: The operand on the left side of the operator.
        operator: The arithmetic operator applied to the two operands.
        right_expr: The operand on the right side of the operator.
    """

    left_expr: SqlExpressionNode
    operator: SqlArithmeticOperator
    right_expr: SqlExpressionNode

    @staticmethod
    def create(
        left_expr: SqlExpressionNode, operator: SqlArithmeticOperator, right_expr: SqlExpressionNode
    ) -> SqlArithmeticExpression:
        """Build a node whose parent nodes are the left operand followed by the right operand."""
        operands = (left_expr, right_expr)
        return SqlArithmeticExpression(
            parent_nodes=operands,
            left_expr=left_expr,
            operator=operator,
            right_expr=right_expr,
        )

    @classmethod
    def id_prefix(cls) -> IdPrefix:
        """Return the ID prefix used for arithmetic expression nodes."""
        return StaticIdPrefix.SQL_EXPR_ARITHMETIC_PREFIX

    def accept(self, visitor: SqlExpressionNodeVisitor[VisitorOutputT]) -> VisitorOutputT:
        """Dispatch to the visitor's `visit_arithmetic_expr` hook."""
        return visitor.visit_arithmetic_expr(self)

    @property
    def description(self) -> str:
        """Short fixed description of this node type."""
        return "Arithmetic Expression"

    @property
    def displayed_properties(self) -> Sequence[DisplayedProperty]:
        """The inherited displayed properties followed by the left operand, operator symbol and right operand."""
        inherited_properties = tuple(super().displayed_properties)
        own_properties = (
            DisplayedProperty("left_expr", self.left_expr),
            DisplayedProperty("operator", self.operator.value),
            DisplayedProperty("right_expr", self.right_expr),
        )
        return inherited_properties + own_properties

    @property
    def requires_parenthesis(self) -> bool:
        """Always True for this node type."""
        return True

    def rewrite(
        self,
        column_replacements: Optional[SqlColumnReplacements] = None,
        should_render_table_alias: Optional[bool] = None,
    ) -> SqlExpressionNode:
        """Return a new arithmetic expression with both operands rewritten and the operator kept as is."""
        rewritten_left = self.left_expr.rewrite(column_replacements, should_render_table_alias)
        rewritten_right = self.right_expr.rewrite(column_replacements, should_render_table_alias)
        return SqlArithmeticExpression.create(
            left_expr=rewritten_left,
            operator=self.operator,
            right_expr=rewritten_right,
        )

    @property
    def lineage(self) -> SqlExpressionTreeLineage:
        """The merged lineage of both operands plus this node itself."""
        parent_lineages = tuple(parent.lineage for parent in self.parent_nodes)
        own_lineage = SqlExpressionTreeLineage(other_exprs=(self,))
        return SqlExpressionTreeLineage.merge_iterable(parent_lineages + (own_lineage,))

    def matches(self, other: SqlExpressionNode) -> bool:
        """Return True when `other` is an arithmetic expression with the same operator and matching parents."""
        return (
            isinstance(other, SqlArithmeticExpression)
            and self.operator == other.operator
            and self._parents_match(other)
        )
14 changes: 7 additions & 7 deletions metricflow/dataset/convert_semantic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,20 +32,20 @@
from metricflow_semantics.specs.dimension_spec import DimensionSpec
from metricflow_semantics.specs.entity_spec import EntitySpec
from metricflow_semantics.specs.time_dimension_spec import DEFAULT_TIME_GRANULARITY, TimeDimensionSpec
from metricflow_semantics.sql.sql_table import SqlTable
from metricflow_semantics.time.granularity import ExpandedTimeGranularity
from metricflow_semantics.time.time_spine_source import TimeSpineSource

from metricflow.dataset.semantic_model_adapter import SemanticModelDataSet
from metricflow.dataset.sql_dataset import SqlDataSet
from metricflow.sql.sql_exprs import (
from metricflow_semantics.sql.sql_exprs import (
SqlColumnReference,
SqlColumnReferenceExpression,
SqlDateTruncExpression,
SqlExpressionNode,
SqlExtractExpression,
SqlStringExpression,
)
from metricflow_semantics.sql.sql_table import SqlTable
from metricflow_semantics.time.granularity import ExpandedTimeGranularity
from metricflow_semantics.time.time_spine_source import TimeSpineSource

from metricflow.dataset.semantic_model_adapter import SemanticModelDataSet
from metricflow.dataset.sql_dataset import SqlDataSet
from metricflow.sql.sql_plan import (
SqlSelectColumn,
SqlSelectStatementNode,
Expand Down
Loading

0 comments on commit 48e85f5

Please sign in to comment.