Skip to content

Commit

Permalink
Replace uses of DataFrame with MetricFlowDataTable (#1235)
Browse files Browse the repository at this point in the history
### Description

Replacing `pandas.DataFrame` with `MetricFlowDataTable` is needed for the
later removal of the `pandas` dependency.
<!--- 
  Before requesting review, please make sure you have:
1. read [the contributing
guide](https://github.com/dbt-labs/metricflow/blob/main/CONTRIBUTING.md),
2. signed the
[CLA](https://docs.getdbt.com/docs/contributor-license-agreements)
3. run `changie new` to [create a changelog
entry](https://github.com/dbt-labs/metricflow/blob/main/CONTRIBUTING.md#adding-a-changelog-entry)
-->
  • Loading branch information
plypaul authored Jun 3, 2024
1 parent 3d26700 commit 8ecf93a
Show file tree
Hide file tree
Showing 258 changed files with 4,119 additions and 4,027 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import logging
import time

import pandas as pd
from dbt.adapters.base.impl import BaseAdapter
from dbt.exceptions import DbtDatabaseError
from dbt_semantic_interfaces.enum_extension import assert_values_exhausted
Expand All @@ -14,6 +13,7 @@
from metricflow_semantics.random_id import random_id
from metricflow_semantics.sql.sql_bind_parameters import SqlBindParameters

from metricflow.data_table.mf_table import MetricFlowDataTable
from metricflow.protocols.sql_client import SqlEngine
from metricflow.sql.render.big_query import BigQuerySqlQueryPlanRenderer
from metricflow.sql.render.databricks import DatabricksSqlQueryPlanRenderer
Expand Down Expand Up @@ -127,8 +127,8 @@ def query(
self,
stmt: str,
sql_bind_parameters: SqlBindParameters = SqlBindParameters(),
) -> pd.DataFrame:
"""Query statement; result expected to be data which will be returned as a DataFrame.
) -> MetricFlowDataTable:
"""Query statement; result expected to be data which will be returned as a DataTable.
Args:
stmt: The SQL query statement to run. This should produce output via a SELECT
Expand All @@ -150,10 +150,14 @@ def query(
logger.info(f"Query returned from dbt Adapter with response {result[0]}")

agate_data = result[1]
df = pd.DataFrame([row.values() for row in agate_data.rows], columns=agate_data.column_names)
rows = [row.values() for row in agate_data.rows]
data_table = MetricFlowDataTable.create_from_rows(
column_names=agate_data.column_names,
rows=rows,
)
stop = time.time()
logger.info(f"Finished running the query in {stop - start:.2f}s with {df.shape[0]} row(s) returned")
return df
logger.info(f"Finished running the query in {stop - start:.2f}s with {data_table.row_count} row(s) returned")
return data_table

def execute(
self,
Expand Down
6 changes: 3 additions & 3 deletions dbt-metricflow/dbt_metricflow/cli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@ def tutorial(ctx: click.core.Context, cfg: CLIContext, msg: bool, clean: bool) -
"--csv",
type=click.File("wb"),
required=False,
help="Provide filepath for dataframe output to csv",
help="Provide filepath for data_table output to csv",
)
@click.option(
"--explain",
Expand Down Expand Up @@ -334,14 +334,14 @@ def query(
df = query_result.result_df
# Show the data if returned successfully
if df is not None:
if df.empty:
if df.row_count == 0:
click.echo("🕳 Successful MQL query returned an empty result set.")
elif csv is not None:
# csv is a LazyFile that is file-like that works in this case.
df.to_csv(csv, index=False) # type: ignore
click.echo(f"🖨 Successfully written query output to {csv.name}")
else:
click.echo(df.to_markdown(index=False, floatfmt=f".{decimals}f"))
click.echo(df.text_format(decimals))

if display_plans:
temp_path = tempfile.mkdtemp()
Expand Down
2 changes: 1 addition & 1 deletion metricflow-semantics/metricflow_semantics/dag/id_prefix.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ class StaticIdPrefix(IdPrefix, Enum, metaclass=EnumMetaClassHelper):
DATAFLOW_NODE_PASS_FILTER_ELEMENTS_ID_PREFIX = "pfe"
DATAFLOW_NODE_READ_SQL_SOURCE_ID_PREFIX = "rss"
DATAFLOW_NODE_WHERE_CONSTRAINT_ID_PREFIX = "wcc"
DATAFLOW_NODE_WRITE_TO_RESULT_DATAFRAME_ID_PREFIX = "wrd"
DATAFLOW_NODE_WRITE_TO_RESULT_DATA_TABLE_ID_PREFIX = "wrd"
DATAFLOW_NODE_WRITE_TO_RESULT_TABLE_ID_PREFIX = "wrt"
DATAFLOW_NODE_COMBINE_AGGREGATED_OUTPUTS_ID_PREFIX = "cao"
DATAFLOW_NODE_CONSTRAIN_TIME_RANGE_ID_PREFIX = "ctr"
Expand Down
7 changes: 1 addition & 6 deletions metricflow/data_table/mf_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,8 +95,6 @@ def column_values_iterator(self, column_index: int) -> Iterator[CellValue]:
return (row[column_index] for row in self.rows)

def _sorted_by_column_name(self) -> MetricFlowDataTable: # noqa: D102
# row_dict_by_row_index: Dict[int, Dict[str, CellType]] = defaultdict(dict)

new_rows: List[List[CellValue]] = [[] for _ in range(self.row_count)]
sorted_column_names = sorted(self.column_names)
for column_name in sorted_column_names:
Expand Down Expand Up @@ -142,10 +140,7 @@ def text_format(self, float_decimals: int = 2) -> str:
continue

if isinstance(cell_value, datetime.datetime):
if cell_value.time() == datetime.time.min:
str_row.append(cell_value.date().isoformat())
else:
str_row.append(cell_value.isoformat())
str_row.append(cell_value.isoformat())
continue

str_row.append(str(cell_value))
Expand Down
6 changes: 3 additions & 3 deletions metricflow/dataflow/builder/dataflow_plan_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@
from metricflow.dataflow.nodes.order_by_limit import OrderByLimitNode
from metricflow.dataflow.nodes.semi_additive_join import SemiAdditiveJoinNode
from metricflow.dataflow.nodes.where_filter import WhereConstraintNode
from metricflow.dataflow.nodes.write_to_dataframe import WriteToResultDataframeNode
from metricflow.dataflow.nodes.write_to_data_table import WriteToResultDataTableNode
from metricflow.dataflow.nodes.write_to_table import WriteToResultTableNode
from metricflow.dataflow.optimizer.dataflow_plan_optimizer import DataflowPlanOptimizer
from metricflow.dataset.dataset_classes import DataSet
Expand Down Expand Up @@ -144,7 +144,7 @@ def build_plan(
output_selection_specs: Optional[InstanceSpecSet] = None,
optimizers: Sequence[DataflowPlanOptimizer] = (),
) -> DataflowPlan:
"""Generate a plan for reading the results of a query with the given spec into a dataframe or table."""
"""Generate a plan for reading the results of a query with the given spec into a data_table or table."""
# Workaround for a Pycharm type inspection issue with decorators.
# noinspection PyArgumentList
return self._build_plan(
Expand Down Expand Up @@ -738,7 +738,7 @@ def build_sink_node(

write_result_node: DataflowPlanNode
if not output_sql_table:
write_result_node = WriteToResultDataframeNode(parent_node=pre_result_node or parent_node)
write_result_node = WriteToResultDataTableNode(parent_node=pre_result_node or parent_node)
else:
write_result_node = WriteToResultTableNode(
parent_node=pre_result_node or parent_node, output_sql_table=output_sql_table
Expand Down
5 changes: 2 additions & 3 deletions metricflow/dataflow/dataflow_plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
import more_itertools
from metricflow_semantics.dag.id_prefix import StaticIdPrefix
from metricflow_semantics.dag.mf_dag import DagId, DagNode, MetricFlowDag, NodeId
from metricflow_semantics.specs.spec_classes import LinkableInstanceSpec
from metricflow_semantics.visitor import Visitable, VisitorOutputT

if typing.TYPE_CHECKING:
Expand All @@ -33,7 +32,7 @@
from metricflow.dataflow.nodes.read_sql_source import ReadSqlSourceNode
from metricflow.dataflow.nodes.semi_additive_join import SemiAdditiveJoinNode
from metricflow.dataflow.nodes.where_filter import WhereConstraintNode
from metricflow.dataflow.nodes.write_to_dataframe import WriteToResultDataframeNode
from metricflow.dataflow.nodes.write_to_data_table import WriteToResultDataTableNode
from metricflow.dataflow.nodes.write_to_table import WriteToResultTableNode


Expand Down Expand Up @@ -147,7 +146,7 @@ def visit_where_constraint_node(self, node: WhereConstraintNode) -> VisitorOutpu
pass

@abstractmethod
def visit_write_to_result_dataframe_node(self, node: WriteToResultDataframeNode) -> VisitorOutputT: # noqa: D102
def visit_write_to_result_data_table_node(self, node: WriteToResultDataTableNode) -> VisitorOutputT: # noqa: D102
pass

@abstractmethod
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,23 +11,23 @@
)


class WriteToResultDataframeNode(DataflowPlanNode):
"""A node where incoming data gets written to a dataframe."""
class WriteToResultDataTableNode(DataflowPlanNode):
"""A node where incoming data gets written to a data_table."""

def __init__(self, parent_node: DataflowPlanNode) -> None: # noqa: D107
self._parent_node = parent_node
super().__init__(node_id=self.create_unique_id(), parent_nodes=(parent_node,))

@classmethod
def id_prefix(cls) -> IdPrefix: # noqa: D102
return StaticIdPrefix.DATAFLOW_NODE_WRITE_TO_RESULT_DATAFRAME_ID_PREFIX
return StaticIdPrefix.DATAFLOW_NODE_WRITE_TO_RESULT_DATA_TABLE_ID_PREFIX

def accept(self, visitor: DataflowPlanNodeVisitor[VisitorOutputT]) -> VisitorOutputT: # noqa: D102
return visitor.visit_write_to_result_dataframe_node(self)
return visitor.visit_write_to_result_data_table_node(self)

@property
def description(self) -> str: # noqa: D102
return """Write to Dataframe"""
return """Write to DataTable"""

@property
def parent_node(self) -> DataflowPlanNode: # noqa: D102
Expand All @@ -39,6 +39,6 @@ def functionally_identical(self, other_node: DataflowPlanNode) -> bool: # noqa:

def with_new_parents( # noqa: D102
self, new_parent_nodes: Sequence[DataflowPlanNode]
) -> WriteToResultDataframeNode:
) -> WriteToResultDataTableNode:
assert len(new_parent_nodes) == 1
return WriteToResultDataframeNode(parent_node=new_parent_nodes[0])
return WriteToResultDataTableNode(parent_node=new_parent_nodes[0])
2 changes: 1 addition & 1 deletion metricflow/dataflow/nodes/write_to_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def __init__(

@classmethod
def id_prefix(cls) -> IdPrefix: # noqa: D102
return StaticIdPrefix.DATAFLOW_NODE_WRITE_TO_RESULT_DATAFRAME_ID_PREFIX
return StaticIdPrefix.DATAFLOW_NODE_WRITE_TO_RESULT_DATA_TABLE_ID_PREFIX

def accept(self, visitor: DataflowPlanNodeVisitor[VisitorOutputT]) -> VisitorOutputT: # noqa: D102
return visitor.visit_write_to_result_table_node(self)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from metricflow.dataflow.nodes.read_sql_source import ReadSqlSourceNode
from metricflow.dataflow.nodes.semi_additive_join import SemiAdditiveJoinNode
from metricflow.dataflow.nodes.where_filter import WhereConstraintNode
from metricflow.dataflow.nodes.write_to_dataframe import WriteToResultDataframeNode
from metricflow.dataflow.nodes.write_to_data_table import WriteToResultDataTableNode
from metricflow.dataflow.nodes.write_to_table import WriteToResultTableNode
from metricflow.dataflow.optimizer.source_scan.matching_linkable_specs import MatchingLinkableSpecsTransform

Expand Down Expand Up @@ -337,8 +337,8 @@ def visit_where_constraint_node( # noqa: D102
self._log_visit_node_type(node)
return self._default_handler(node)

def visit_write_to_result_dataframe_node( # noqa: D102
self, node: WriteToResultDataframeNode
def visit_write_to_result_data_table_node( # noqa: D102
self, node: WriteToResultDataTableNode
) -> ComputeMetricsBranchCombinerResult:
self._log_visit_node_type(node)
return self._handle_unsupported_node(node)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from metricflow.dataflow.nodes.read_sql_source import ReadSqlSourceNode
from metricflow.dataflow.nodes.semi_additive_join import SemiAdditiveJoinNode
from metricflow.dataflow.nodes.where_filter import WhereConstraintNode
from metricflow.dataflow.nodes.write_to_dataframe import WriteToResultDataframeNode
from metricflow.dataflow.nodes.write_to_data_table import WriteToResultDataTableNode
from metricflow.dataflow.nodes.write_to_table import WriteToResultTableNode
from metricflow.dataflow.optimizer.dataflow_plan_optimizer import DataflowPlanOptimizer
from metricflow.dataflow.optimizer.source_scan.cm_branch_combiner import (
Expand Down Expand Up @@ -161,8 +161,8 @@ def visit_where_constraint_node(self, node: WhereConstraintNode) -> OptimizeBran
self._log_visit_node_type(node)
return self._default_base_output_handler(node)

def visit_write_to_result_dataframe_node( # noqa: D102
self, node: WriteToResultDataframeNode
def visit_write_to_result_data_table_node( # noqa: D102
self, node: WriteToResultDataTableNode
) -> OptimizeBranchResult: # noqa: D102
self._log_visit_node_type(node)
return self._default_base_output_handler(node)
Expand Down
24 changes: 6 additions & 18 deletions metricflow/engine/metricflow_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from enum import Enum
from typing import List, Optional, Sequence, Tuple

import pandas as pd
from dbt_semantic_interfaces.implementations.elements.dimension import PydanticDimensionTypeParams
from dbt_semantic_interfaces.implementations.filters.where_filter import PydanticWhereFilter
from dbt_semantic_interfaces.references import EntityReference, MeasureReference, MetricReference
Expand Down Expand Up @@ -35,6 +34,7 @@
from metricflow_semantics.specs.spec_set import InstanceSpecSet
from metricflow_semantics.time.time_source import TimeSource

from metricflow.data_table.mf_table import MetricFlowDataTable
from metricflow.dataflow.builder.dataflow_plan_builder import DataflowPlanBuilder
from metricflow.dataflow.builder.node_data_set import DataflowPlanNodeOutputDataSetResolver
from metricflow.dataflow.builder.source_node import SourceNodeBuilder
Expand Down Expand Up @@ -96,7 +96,7 @@ class MetricFlowQueryRequest:
where_constraint: A SQL string using group by names that can be used like a where clause on the output data.
order_by_names: metric and group by names to order by. A "-" can be used to specify reverse order e.g. "-ds".
order_by: metric, dimension, or entity objects to order by.
output_table: If specified, output the result data to this table instead of a result dataframe.
output_table: If specified, output the result data to this table instead of a result data_table.
sql_optimization_level: The level of optimization for the generated SQL.
query_type: Type of MetricFlow query.
"""
Expand Down Expand Up @@ -160,7 +160,7 @@ class MetricFlowQueryResult:
query_spec: MetricFlowQuerySpec
dataflow_plan: DataflowPlan
sql: str
result_df: Optional[pd.DataFrame] = None
result_df: Optional[MetricFlowDataTable] = None
result_table: Optional[SqlTable] = None


Expand Down Expand Up @@ -703,7 +703,7 @@ def get_dimension_values( # noqa: D102
time_constraint_end: Optional[datetime.datetime] = None,
) -> List[str]:
# Run query
query_result = self.query(
query_result: MetricFlowQueryResult = self.query(
MetricFlowQueryRequest.create_with_random_request_id(
metric_names=metric_names,
group_by_names=[get_group_by_values],
Expand All @@ -712,22 +712,10 @@ def get_dimension_values( # noqa: D102
query_type=MetricFlowQueryType.DIMENSION_VALUES,
)
)
result_dataframe = query_result.result_df
if result_dataframe is None:
if query_result.result_df is None:
return []

# Snowflake likes upper-casing things in result output, so we lower-case all names
# before operating on the dataframe.
metric_names = [metric_name.lower() for metric_name in metric_names]
result_dataframe.columns = result_dataframe.columns.str.lower()

# Get dimension values regardless of input name -> output dimension mapping. This is necessary because
# granularity adjustments on time dimensions produce different output names for dimension values.
# Note: this only works as long as we have exactly one column of group by values
# and no other extraneous output columns
dim_vals = result_dataframe[result_dataframe.columns[~result_dataframe.columns.isin(metric_names)]].iloc[:, 0]

return sorted([str(val) for val in dim_vals])
return sorted([str(val) for val in query_result.result_df.column_values_iterator(0)])

@log_call(module_name=__name__, telemetry_reporter=_telemetry_reporter)
def explain_get_dimension_values( # noqa: D102
Expand Down
8 changes: 4 additions & 4 deletions metricflow/execution/dataflow_to_execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,12 @@
from metricflow.dataflow.nodes.read_sql_source import ReadSqlSourceNode
from metricflow.dataflow.nodes.semi_additive_join import SemiAdditiveJoinNode
from metricflow.dataflow.nodes.where_filter import WhereConstraintNode
from metricflow.dataflow.nodes.write_to_dataframe import WriteToResultDataframeNode
from metricflow.dataflow.nodes.write_to_data_table import WriteToResultDataTableNode
from metricflow.dataflow.nodes.write_to_table import WriteToResultTableNode
from metricflow.execution.convert_to_execution_plan import ConvertToExecutionPlanResult
from metricflow.execution.execution_plan import (
ExecutionPlan,
SelectSqlQueryToDataFrameTask,
SelectSqlQueryToDataTableTask,
SelectSqlQueryToTableTask,
)
from metricflow.plan_conversion.convert_to_sql_plan import ConvertToSqlPlanResult
Expand Down Expand Up @@ -74,12 +74,12 @@ def _render_sql(self, convert_to_sql_plan_result: ConvertToSqlPlanResult) -> Sql
return self._sql_plan_renderer.render_sql_query_plan(convert_to_sql_plan_result.sql_plan)

@override
def visit_write_to_result_dataframe_node(self, node: WriteToResultDataframeNode) -> ConvertToExecutionPlanResult:
def visit_write_to_result_data_table_node(self, node: WriteToResultDataTableNode) -> ConvertToExecutionPlanResult:
convert_to_sql_plan_result = self._convert_to_sql_plan(node)
render_sql_result = self._render_sql(convert_to_sql_plan_result)
execution_plan = ExecutionPlan(
leaf_tasks=(
SelectSqlQueryToDataFrameTask(
SelectSqlQueryToDataTableTask(
sql_client=self._sql_client,
sql_query=render_sql_result.sql,
bind_parameters=render_sql_result.bind_parameters,
Expand Down
10 changes: 5 additions & 5 deletions metricflow/execution/execution_plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,12 @@
from dataclasses import dataclass
from typing import List, Optional, Sequence, Tuple

import pandas as pd
from metricflow_semantics.dag.id_prefix import IdPrefix, StaticIdPrefix
from metricflow_semantics.dag.mf_dag import DagId, DagNode, DisplayedProperty, MetricFlowDag, NodeId
from metricflow_semantics.sql.sql_bind_parameters import SqlBindParameters
from metricflow_semantics.visitor import Visitable

from metricflow.data_table.mf_table import MetricFlowDataTable
from metricflow.protocols.sql_client import SqlClient
from metricflow.sql.sql_table import SqlTable

Expand Down Expand Up @@ -82,12 +82,12 @@ class TaskExecutionResult:
# If the task was an SQL query, it's stored here
sql: Optional[str] = None
bind_params: Optional[SqlBindParameters] = None
# If the task produces a dataframe as a result, it's stored here.
df: Optional[pd.DataFrame] = None
# If the task produces a data_table as a result, it's stored here.
df: Optional[MetricFlowDataTable] = None


class SelectSqlQueryToDataFrameTask(ExecutionPlanTask):
"""A task that runs a SELECT and puts that result into a dataframe."""
class SelectSqlQueryToDataTableTask(ExecutionPlanTask):
"""A task that runs a SELECT and puts that result into a data_table."""

def __init__( # noqa: D107
self,
Expand Down
Loading

0 comments on commit 8ecf93a

Please sign in to comment.