Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
courtneyholcomb committed Dec 11, 2024
1 parent bd641ef commit ec58d8f
Show file tree
Hide file tree
Showing 21 changed files with 678 additions and 37 deletions.
2 changes: 2 additions & 0 deletions metricflow-semantics/metricflow_semantics/dag/id_prefix.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ class StaticIdPrefix(IdPrefix, Enum, metaclass=EnumMetaClassHelper):
DATAFLOW_NODE_JOIN_CONVERSION_EVENTS_PREFIX = "jce"
DATAFLOW_NODE_WINDOW_REAGGREGATION_ID_PREFIX = "wr"
DATAFLOW_NODE_ALIAS_SPECS_ID_PREFIX = "as"
DATAFLOW_NODE_CUSTOM_GRANULARITY_BOUNDS_ID_PREFIX = "cgb"

SQL_EXPR_COLUMN_REFERENCE_ID_PREFIX = "cr"
SQL_EXPR_COMPARISON_ID_PREFIX = "cmp"
Expand All @@ -75,6 +76,7 @@ class StaticIdPrefix(IdPrefix, Enum, metaclass=EnumMetaClassHelper):
SQL_EXPR_BETWEEN_PREFIX = "betw"
SQL_EXPR_WINDOW_FUNCTION_ID_PREFIX = "wfnc"
SQL_EXPR_GENERATE_UUID_PREFIX = "uuid"
SQL_EXPR_CASE_PREFIX = "case"

SQL_PLAN_SELECT_STATEMENT_ID_PREFIX = "ss"
SQL_PLAN_TABLE_FROM_CLAUSE_ID_PREFIX = "tfc"
Expand Down
6 changes: 1 addition & 5 deletions metricflow-semantics/metricflow_semantics/instances.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,11 +164,7 @@ def with_entity_prefix(
) -> TimeDimensionInstance:
"""Returns a new instance with the entity prefix added to the entity links."""
transformed_spec = self.spec.with_entity_prefix(entity_prefix)
return TimeDimensionInstance(
associated_columns=(column_association_resolver.resolve_spec(transformed_spec),),
defined_from=self.defined_from,
spec=transformed_spec,
)
return self.with_new_spec(transformed_spec, column_association_resolver)

def with_new_defined_from(self, defined_from: Sequence[SemanticModelElementReference]) -> TimeDimensionInstance:
"""Returns a new instance with the defined_from field replaced."""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from metricflow_semantics.naming.linkable_spec_name import StructuredLinkableSpecName
from metricflow_semantics.specs.dimension_spec import DimensionSpec
from metricflow_semantics.specs.instance_spec import InstanceSpecVisitor
from metricflow_semantics.sql.sql_exprs import SqlWindowFunction
from metricflow_semantics.time.granularity import ExpandedTimeGranularity
from metricflow_semantics.visitor import VisitorOutputT

Expand Down Expand Up @@ -91,6 +92,8 @@ class TimeDimensionSpec(DimensionSpec): # noqa: D101
# Used for semi-additive joins. Some more thought is needed, but this may be useful in InstanceSpec.
aggregation_state: Optional[AggregationState] = None

window_function: Optional[SqlWindowFunction] = None

@property
def without_first_entity_link(self) -> TimeDimensionSpec: # noqa: D102
assert len(self.entity_links) > 0, f"Spec does not have any entity links: {self}"
Expand All @@ -99,6 +102,8 @@ def without_first_entity_link(self) -> TimeDimensionSpec: # noqa: D102
entity_links=self.entity_links[1:],
time_granularity=self.time_granularity,
date_part=self.date_part,
aggregation_state=self.aggregation_state,
window_function=self.window_function,
)

@property
Expand All @@ -108,6 +113,8 @@ def without_entity_links(self) -> TimeDimensionSpec: # noqa: D102
time_granularity=self.time_granularity,
date_part=self.date_part,
entity_links=(),
aggregation_state=self.aggregation_state,
window_function=self.window_function,
)

@property
Expand Down Expand Up @@ -153,6 +160,7 @@ def with_grain(self, time_granularity: ExpandedTimeGranularity) -> TimeDimension
time_granularity=time_granularity,
date_part=self.date_part,
aggregation_state=self.aggregation_state,
window_function=self.window_function,
)

def with_base_grain(self) -> TimeDimensionSpec: # noqa: D102
Expand All @@ -162,6 +170,7 @@ def with_base_grain(self) -> TimeDimensionSpec: # noqa: D102
time_granularity=ExpandedTimeGranularity.from_time_granularity(self.time_granularity.base_granularity),
date_part=self.date_part,
aggregation_state=self.aggregation_state,
window_function=self.window_function,
)

def with_grain_and_date_part( # noqa: D102
Expand All @@ -173,6 +182,7 @@ def with_grain_and_date_part( # noqa: D102
time_granularity=time_granularity,
date_part=date_part,
aggregation_state=self.aggregation_state,
window_function=self.window_function,
)

def with_aggregation_state(self, aggregation_state: AggregationState) -> TimeDimensionSpec: # noqa: D102
Expand All @@ -182,6 +192,17 @@ def with_aggregation_state(self, aggregation_state: AggregationState) -> TimeDim
time_granularity=self.time_granularity,
date_part=self.date_part,
aggregation_state=aggregation_state,
window_function=self.window_function,
)

def with_window_function(self, window_function: SqlWindowFunction) -> TimeDimensionSpec: # noqa: D102
return TimeDimensionSpec(
element_name=self.element_name,
entity_links=self.entity_links,
time_granularity=self.time_granularity,
date_part=self.date_part,
aggregation_state=self.aggregation_state,
window_function=window_function,
)

def comparison_key(self, exclude_fields: Sequence[TimeDimensionSpecField] = ()) -> TimeDimensionSpecComparisonKey:
Expand Down Expand Up @@ -243,6 +264,7 @@ def with_entity_prefix(self, entity_prefix: EntityReference) -> TimeDimensionSpe
time_granularity=self.time_granularity,
date_part=self.date_part,
aggregation_state=self.aggregation_state,
window_function=self.window_function,
)

@staticmethod
Expand Down
80 changes: 78 additions & 2 deletions metricflow-semantics/metricflow_semantics/sql/sql_exprs.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,12 @@
from dbt_semantic_interfaces.type_enums.date_part import DatePart
from dbt_semantic_interfaces.type_enums.period_agg import PeriodAggregation
from dbt_semantic_interfaces.type_enums.time_granularity import TimeGranularity
from typing_extensions import override

from metricflow_semantics.dag.id_prefix import IdPrefix, StaticIdPrefix
from metricflow_semantics.dag.mf_dag import DagNode, DisplayedProperty
from metricflow_semantics.sql.sql_bind_parameters import SqlBindParameterSet
from metricflow_semantics.visitor import Visitable, VisitorOutputT
from typing_extensions import override


@dataclass(frozen=True, eq=False)
Expand Down Expand Up @@ -235,6 +236,10 @@ def visit_window_function_expr(self, node: SqlWindowFunctionExpression) -> Visit
def visit_generate_uuid_expr(self, node: SqlGenerateUuidExpression) -> VisitorOutputT: # noqa: D102
pass

@abstractmethod
def visit_case_expr(self, node: SqlCaseExpression) -> VisitorOutputT: # noqa: D102
pass


@dataclass(frozen=True, eq=False)
class SqlStringExpression(SqlExpressionNode):
Expand Down Expand Up @@ -948,11 +953,18 @@ class SqlWindowFunction(Enum):
FIRST_VALUE = "FIRST_VALUE"
LAST_VALUE = "LAST_VALUE"
AVERAGE = "AVG"
ROW_NUMBER = "ROW_NUMBER"
LAG = "LAG"

@property
def requires_ordering(self) -> bool:
"""Asserts whether or not ordering the window function will have an impact on the resulting value."""
if self is SqlWindowFunction.FIRST_VALUE or self is SqlWindowFunction.LAST_VALUE:
if (
self is SqlWindowFunction.FIRST_VALUE
or self is SqlWindowFunction.LAST_VALUE
or self is SqlWindowFunction.ROW_NUMBER
or self is SqlWindowFunction.LAG
):
return True
elif self is SqlWindowFunction.AVERAGE:
return False
Expand Down Expand Up @@ -1715,3 +1727,67 @@ def lineage(self) -> SqlExpressionTreeLineage: # noqa: D102

def matches(self, other: SqlExpressionNode) -> bool: # noqa: D102
return False


@dataclass(frozen=True, eq=False)
class SqlCaseExpression(SqlExpressionNode):
"""Renders a CASE WHEN expression."""

when_to_then_exprs: Dict[SqlExpressionNode, SqlExpressionNode]
else_expr: Optional[SqlExpressionNode]

@staticmethod
def create( # noqa: D102
when_to_then_exprs: Dict[SqlExpressionNode, SqlExpressionNode], else_expr: Optional[SqlExpressionNode] = None
) -> SqlCaseExpression:
parent_nodes: Tuple[SqlExpressionNode, ...] = ()
for when, then in when_to_then_exprs.items():
parent_nodes += (when,)
parent_nodes += (then,)

if else_expr:
parent_nodes += (else_expr,)

return SqlCaseExpression(parent_nodes=parent_nodes, when_to_then_exprs=when_to_then_exprs, else_expr=else_expr)

@classmethod
def id_prefix(cls) -> IdPrefix: # noqa: D102
return StaticIdPrefix.SQL_EXPR_CASE_PREFIX

def accept(self, visitor: SqlExpressionNodeVisitor[VisitorOutputT]) -> VisitorOutputT: # noqa: D102
return visitor.visit_case_expr(self)

@property
def description(self) -> str: # noqa: D102
return "Case expression"

@property
def displayed_properties(self) -> Sequence[DisplayedProperty]: # noqa: D102
return super().displayed_properties

@property
def requires_parenthesis(self) -> bool: # noqa: D102
return False

@property
def bind_parameter_set(self) -> SqlBindParameterSet: # noqa: D102
return SqlBindParameterSet()

def __repr__(self) -> str: # noqa: D105
return f"{self.__class__.__name__}(node_id={self.node_id})"

def rewrite( # noqa: D102
self,
column_replacements: Optional[SqlColumnReplacements] = None,
should_render_table_alias: Optional[bool] = None,
) -> SqlExpressionNode:
return self

@property
def lineage(self) -> SqlExpressionTreeLineage: # noqa: D102
return SqlExpressionTreeLineage(other_exprs=(self,))

def matches(self, other: SqlExpressionNode) -> bool: # noqa: D102
if not isinstance(other, SqlCaseExpression):
return False
return self.when_to_then_exprs == other.when_to_then_exprs and self.else_expr == other.else_expr
Original file line number Diff line number Diff line change
Expand Up @@ -860,3 +860,24 @@ metric:
- name: instant_bookings
alias: shared_alias
---
metric:
name: bookings_offset_one_martian_day
description: bookings offset by one martian_day
type: derived
type_params:
expr: bookings
metrics:
- name: bookings
offset_window: 1 martian_day
---
metric:
name: bookings_martian_day_over_martian_day
description: bookings growth martian day over martian day
type: derived
type_params:
expr: bookings - bookings_offset / NULLIF(bookings_offset, 0)
metrics:
- name: bookings
offset_window: 1 martian_day
alias: bookings_offset
- name: bookings
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ def test_classes() -> None: # noqa: D103
time_granularity=ExpandedTimeGranularity(name='day', base_granularity=DAY),
date_part=None,
aggregation_state=None,
window_function=None,
)
"""
).rstrip()
Expand Down
Loading

0 comments on commit ec58d8f

Please sign in to comment.