added JoinConversionEventsNode
WilliamDee committed Dec 5, 2022
1 parent 07c3363 commit 8787fa0
Showing 5 changed files with 317 additions and 30 deletions.
1 change: 1 addition & 0 deletions metricflow/dag/id_generation.py
@@ -22,6 +22,7 @@
DATAFLOW_NODE_SET_MEASURE_AGGREGATION_TIME = "sma"
DATAFLOW_NODE_SEMI_ADDITIVE_JOIN_ID_PREFIX = "saj"
DATAFLOW_NODE_ADD_UUID_COLUMN_PREFIX = "auid"
DATAFLOW_NODE_JOIN_CONVERSION_EVENTS_PREFIX = "jce"

SQL_EXPR_COLUMN_REFERENCE_ID_PREFIX = "cr"
SQL_EXPR_COMPARISON_ID_PREFIX = "cmp"
10 changes: 10 additions & 0 deletions metricflow/dataflow/builder/costing.py
@@ -24,6 +24,7 @@
    ComputeMetricsNode,
    AggregateMeasuresNode,
    JoinAggregatedMeasuresByGroupByColumnsNode,
    JoinConversionEventsNode,
    JoinOverTimeRangeNode,
    JoinToBaseOutputNode,
    ReadSqlSourceNode,
@@ -167,3 +168,12 @@ def visit_add_generated_uuid_column_node( # noqa: D
        self, node: AddGeneratedUuidColumnNode[SourceDataSetT]
    ) -> DefaultCost:
        return DefaultCost.sum([x.accept(self) for x in node.parent_nodes])

    def visit_join_conversion_events_node(  # noqa: D
        self, node: JoinConversionEventsNode[SourceDataSetT]
    ) -> DefaultCost:
        parent_costs = [x.accept(self) for x in node.parent_nodes]

        # This node adds one join (base to conversion events) and one aggregation
        # (the window-function deduplication) to the cost.
        node_cost = DefaultCost(num_joins=1, num_aggregations=1)
        return DefaultCost.sum(parent_costs + [node_cost])
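
For context, here is a minimal runnable sketch of how the cost accumulation above behaves. It assumes DefaultCost is a frozen dataclass whose sum() adds num_joins and num_aggregations fieldwise (an assumption for illustration; the real class lives in metricflow/dataflow/builder/costing.py):

from dataclasses import dataclass
from typing import Sequence


@dataclass(frozen=True)
class DefaultCost:
    num_joins: int = 0
    num_aggregations: int = 0

    @staticmethod
    def sum(costs: Sequence["DefaultCost"]) -> "DefaultCost":
        # Combine costs fieldwise across all inputs.
        return DefaultCost(
            num_joins=sum(c.num_joins for c in costs),
            num_aggregations=sum(c.num_aggregations for c in costs),
        )


# Two parents costing one join each, plus this node's own join and aggregation:
parent_costs = [DefaultCost(num_joins=1), DefaultCost(num_joins=1)]
node_cost = DefaultCost(num_joins=1, num_aggregations=1)
assert DefaultCost.sum(parent_costs + [node_cost]) == DefaultCost(num_joins=3, num_aggregations=1)
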
90 changes: 90 additions & 0 deletions metricflow/dataflow/dataflow_plan.py
@@ -26,6 +26,7 @@
    DATAFLOW_NODE_COMBINE_METRICS_ID_PREFIX,
    DATAFLOW_NODE_CONSTRAIN_TIME_RANGE_ID_PREFIX,
    DATAFLOW_NODE_SET_MEASURE_AGGREGATION_TIME,
    DATAFLOW_NODE_JOIN_CONVERSION_EVENTS_PREFIX,
)
from metricflow.dag.mf_dag import DagNode, DisplayedProperty, MetricFlowDag, NodeId
from metricflow.dataflow.builder.partitions import (
@@ -41,6 +42,7 @@
from metricflow.specs import (
    MetricInputMeasureSpec,
    OrderBySpec,
    IdentifierSpec,
    InstanceSpec,
    MetricSpec,
    LinklessIdentifierSpec,
@@ -167,6 +169,12 @@ def visit_add_generated_uuid_column_node( # noqa: D
    ) -> VisitorOutputT:
        pass

    @abstractmethod
    def visit_join_conversion_events_node(  # noqa: D
        self, node: JoinConversionEventsNode[SourceDataSetT]
    ) -> VisitorOutputT:
        pass


class BaseOutput(Generic[SourceDataSetT], DataflowPlanNode[SourceDataSetT], ABC):
"""A node that outputs data in a "base" format.
@@ -1011,6 +1019,88 @@ def displayed_properties(self) -> List[DisplayedProperty]: # noqa: D
        return super().displayed_properties


class JoinConversionEventsNode(Generic[SourceDataSetT], BaseOutput[SourceDataSetT]):
    """Builds a data set containing successful conversion events."""

    def __init__(
        self,
        base_node: BaseOutput[SourceDataSetT],
        base_time_dimension_spec: TimeDimensionSpec,
        conversion_node: BaseOutput[SourceDataSetT],
        conversion_time_dimension_spec: TimeDimensionSpec,
        conversion_primary_key_specs: Tuple[InstanceSpec, ...],
        entity_spec: IdentifierSpec,
        window: Optional[CumulativeMetricWindow] = None,
    ) -> None:
        """Constructor.

        Args:
            base_node: node containing the data set for computing base events.
            base_time_dimension_spec: time dimension for the base events to compute against.
            conversion_node: node containing the data set that is joined to the base node for computing conversion events.
            conversion_time_dimension_spec: time dimension for the conversion events to compute against.
            conversion_primary_key_specs: primary key(s) that uniquely identify each conversion event.
            entity_spec: the entity for which the conversion is happening.
            window: time range bound within which a conversion is still considered valid (default: INF).
        """
        self._base_node = base_node
        self._conversion_node = conversion_node
        self._base_time_dimension_spec = base_time_dimension_spec
        self._conversion_time_dimension_spec = conversion_time_dimension_spec
        self._conversion_primary_key_specs = conversion_primary_key_specs
        self._entity_spec = entity_spec
        self._window = window
        super().__init__(node_id=self.create_unique_id(), parent_nodes=[base_node, conversion_node])

    @classmethod
    def id_prefix(cls) -> str:  # noqa: D
        return DATAFLOW_NODE_JOIN_CONVERSION_EVENTS_PREFIX

    def accept(self, visitor: DataflowPlanNodeVisitor[SourceDataSetT, VisitorOutputT]) -> VisitorOutputT:  # noqa: D
        return visitor.visit_join_conversion_events_node(self)

    @property
    def base_node(self) -> DataflowPlanNode:  # noqa: D
        return self._base_node

    @property
    def conversion_node(self) -> DataflowPlanNode:  # noqa: D
        return self._conversion_node

    @property
    def base_time_dimension_spec(self) -> TimeDimensionSpec:  # noqa: D
        return self._base_time_dimension_spec

    @property
    def conversion_time_dimension_spec(self) -> TimeDimensionSpec:  # noqa: D
        return self._conversion_time_dimension_spec

    @property
    def conversion_primary_key_specs(self) -> Tuple[InstanceSpec, ...]:  # noqa: D
        return self._conversion_primary_key_specs

    @property
    def entity_spec(self) -> IdentifierSpec:  # noqa: D
        return self._entity_spec

    @property
    def window(self) -> Optional[CumulativeMetricWindow]:  # noqa: D
        return self._window

    @property
    def description(self) -> str:  # noqa: D
        return f"""Find conversions for {self.entity_spec} within the range of {self.window.to_string() if self.window else 'INF'}"""

    @property
    def displayed_properties(self) -> List[DisplayedProperty]:  # noqa: D
        return super().displayed_properties + [
            DisplayedProperty("base_time_dimension_spec", self.base_time_dimension_spec),
            DisplayedProperty("conversion_time_dimension_spec", self.conversion_time_dimension_spec),
            DisplayedProperty("entity_spec", self.entity_spec),
            DisplayedProperty("window", self.window),
        ]
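
As a usage sketch, wiring the new node into a plan could look like the following. Every name here (visits_node, buys_node, the *_spec values, dataflow_to_sql_visitor) is a hypothetical placeholder for illustration, not something introduced by this commit:

# Hypothetical wiring inside a dataflow plan builder (illustrative only).
join_conversion_node = JoinConversionEventsNode(
    base_node=visits_node,  # upstream BaseOutput producing base (opportunity) events
    base_time_dimension_spec=metric_time_spec,
    conversion_node=buys_node,  # upstream BaseOutput producing conversion events
    conversion_time_dimension_spec=metric_time_spec,
    conversion_primary_key_specs=(buy_uuid_spec,),  # uniquely identifies each conversion row
    entity_spec=user_identifier_spec,  # conversions are matched per user
    window=None,  # None means an unbounded (INF) conversion window
)
sql_data_set = join_conversion_node.accept(dataflow_to_sql_visitor)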


class DataflowPlan(Generic[SourceDataSetT], MetricFlowDag[SinkOutput[SourceDataSetT]]):
"""Describes the flow of metric data as it goes from source nodes to sink nodes in the graph."""

139 changes: 139 additions & 0 deletions metricflow/plan_conversion/dataflow_to_sql.py
@@ -17,6 +17,7 @@
    ComputeMetricsNode,
    AggregateMeasuresNode,
    JoinAggregatedMeasuresByGroupByColumnsNode,
    JoinConversionEventsNode,
    JoinToBaseOutputNode,
    ReadSqlSourceNode,
    BaseOutput,
@@ -50,6 +51,7 @@
    AddMetrics,
    CreateSelectColumnsForInstances,
    CreateSelectColumnsWithMeasuresAggregated,
    CreateSqlColumnReferencesForInstances,
    create_select_columns_for_instance_sets,
    AddLinkToLinkableElements,
    FilterElements,
@@ -96,6 +98,9 @@
    SqlStringLiteralExpression,
    SqlBetweenExpression,
    SqlAggregateFunctionExpression,
    SqlWindowFunctionExpression,
    SqlWindowFunction,
    SqlWindowOrderByArgument,
)
from metricflow.sql.sql_plan import (
    SqlQueryPlan,
@@ -1312,3 +1317,137 @@ def visit_add_generated_uuid_column_node(self, node: AddGeneratedUuidColumnNode)
                order_bys=(),
            ),
        )

    def visit_join_conversion_events_node(self, node: JoinConversionEventsNode) -> SqlDataSet:
        """Builds a resulting data set with all valid conversion events.

        This node takes the base and conversion data sets and joins them on an entity and
        a valid time range to get successful conversions. It then deduplicates opportunities
        via the window function `first_value`, taking the opportunity closest to the
        corresponding conversion. The result is a data set with each row representing a
        successful conversion. Duplication may still exist in the result because a single
        base event can link to multiple conversion events.
        """
        base_data_set: SqlDataSet = node.base_node.accept(self)
        base_data_set_alias = self._next_unique_table_alias()

        conversion_data_set: SqlDataSet = node.conversion_node.accept(self)
        conversion_data_set_alias = self._next_unique_table_alias()

        base_time_dimension_column_name = self._column_association_resolver.resolve_time_dimension_spec(
            node.base_time_dimension_spec
        ).column_name
        conversion_time_dimension_column_name = self._column_association_resolver.resolve_time_dimension_spec(
            node.conversion_time_dimension_spec
        ).column_name
        entity_column_name = self._column_association_resolver.resolve_identifier_spec(node.entity_spec)[0].column_name

        # Builds the join conditions that are required for a successful conversion
        sql_join_description = SqlQueryPlanJoinBuilder.make_join_conversion_join_description(
            node=node,
            base_data_set=AnnotatedSqlDataSet(
                data_set=base_data_set,
                alias=base_data_set_alias,
                _metric_time_column_name=base_time_dimension_column_name,
            ),
            conversion_data_set=AnnotatedSqlDataSet(
                data_set=conversion_data_set,
                alias=conversion_data_set_alias,
                _metric_time_column_name=conversion_time_dimension_column_name,
            ),
            column_equality_descriptions=(
                ColumnEqualityDescription(
                    left_column_alias=entity_column_name,
                    right_column_alias=entity_column_name,
                ),  # add constant property here
            ),
        )

        # Builds the first_value window function columns
        base_sql_column_references = base_data_set.instance_set.transform(
            CreateSqlColumnReferencesForInstances(base_data_set_alias, self._column_association_resolver)
        )
        partition_by_columns = (entity_column_name, conversion_time_dimension_column_name)  # add constant property here
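        # Partitioning by the conversion entity and its time column means first_value
        # returns, for each conversion event, the base-event columns from the row with
        # the latest base time (ORDER BY base time DESC below), i.e. the closest
        # preceding opportunity.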
        base_sql_select_columns = tuple(
            SqlSelectColumn(
                expr=SqlWindowFunctionExpression(
                    sql_function=SqlWindowFunction.FIRST_VALUE,
                    sql_function_args=[
                        SqlColumnReferenceExpression(
                            SqlColumnReference(
                                table_alias=base_data_set_alias,
                                column_name=base_sql_column_reference.col_ref.column_name,
                            ),
                        )
                    ],
                    partition_by_args=[
                        SqlColumnReferenceExpression(
                            SqlColumnReference(
                                table_alias=conversion_data_set_alias,
                                column_name=column,
                            ),
                        )
                        for column in partition_by_columns
                    ],
                    order_by_args=[
                        SqlWindowOrderByArgument(
                            expr=SqlColumnReferenceExpression(
                                SqlColumnReference(
                                    table_alias=base_data_set_alias,
                                    column_name=base_time_dimension_column_name,
                                ),
                            ),
                            descending=True,
                        )
                    ],
                ),
                column_alias=base_sql_column_reference.col_ref.column_name,
            )
            for base_sql_column_reference in base_sql_column_references
        )

        # Deduplicate the fanout results
        conversion_primary_key_select_columns = tuple(
            SqlSelectColumn(
                expr=SqlColumnReferenceExpression(
                    SqlColumnReference(
                        table_alias=conversion_data_set_alias,
                        column_name=spec.column_associations(self._column_association_resolver)[0].column_name,
                    ),
                ),
                column_alias=spec.column_associations(self._column_association_resolver)[0].column_name,
            )
            for spec in node.conversion_primary_key_specs
        )
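        # Because first_value is constant within each partition, selecting those columns
        # together with the conversion primary keys under DISTINCT collapses the join
        # fanout back to one row per conversion event.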
        deduped_sql_select_node = SqlSelectStatementNode(
            description=f"Dedupe the fanout on {node.conversion_primary_key_specs} in the conversion data set",
            select_columns=base_sql_select_columns + conversion_primary_key_select_columns,
            from_source=base_data_set.sql_select_node,
            from_source_alias=base_data_set_alias,
            joins_descs=(sql_join_description,),
            group_bys=(),
            where=None,
            order_bys=(),
            distinct=True,
        )

        # Returns the original data set with all the successful conversions
        output_data_set_alias = self._next_unique_table_alias()
        output_instance_set = ChangeAssociatedColumns(self._column_association_resolver).transform(
            base_data_set.instance_set
        )
        return SqlDataSet(
            instance_set=output_instance_set,
            sql_select_node=SqlSelectStatementNode(
                description=node.description,
                select_columns=output_instance_set.transform(
                    CreateSelectColumnsForInstances(output_data_set_alias, self._column_association_resolver)
                ).as_tuple(),
                from_source=deduped_sql_select_node,
                from_source_alias=output_data_set_alias,
                joins_descs=(),
                group_bys=(),
                where=None,
                order_bys=(),
            ),
        )
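
Putting the pieces together, the rendered SQL should have roughly the following shape. The table and column names (visits, buys, user_id, ds, buy_uuid) are hypothetical placeholders, and the time-range predicate is the one produced by make_join_conversion_join_description:

# Illustrative shape of the generated query; not emitted verbatim by this commit.
EXPECTED_QUERY_SHAPE = """
SELECT DISTINCT
    FIRST_VALUE(base.user_id) OVER (
        PARTITION BY conv.user_id, conv.ds ORDER BY base.ds DESC
    ) AS user_id
    -- ... one FIRST_VALUE column per base-event column ...
    , conv.buy_uuid
FROM visits base
JOIN buys conv
    ON base.user_id = conv.user_id
    AND conv.ds BETWEEN base.ds AND base.ds + <window>
"""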