Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace uses of DataFrame with MetricflowDataTable #1235

Merged
merged 15 commits into from
Jun 3, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import logging
import time

import pandas as pd
from dbt.adapters.base.impl import BaseAdapter
from dbt.exceptions import DbtDatabaseError
from dbt_semantic_interfaces.enum_extension import assert_values_exhausted
Expand All @@ -14,6 +13,7 @@
from metricflow_semantics.random_id import random_id
from metricflow_semantics.sql.sql_bind_parameters import SqlBindParameters

from metricflow.data_table.mf_table import MetricFlowDataTable
from metricflow.protocols.sql_client import SqlEngine
from metricflow.sql.render.big_query import BigQuerySqlQueryPlanRenderer
from metricflow.sql.render.databricks import DatabricksSqlQueryPlanRenderer
Expand Down Expand Up @@ -127,7 +127,7 @@ def query(
self,
stmt: str,
sql_bind_parameters: SqlBindParameters = SqlBindParameters(),
) -> pd.DataFrame:
) -> MetricFlowDataTable:
"""Query statement; result expected to be data which will be returned as a DataFrame.

Args:
Expand All @@ -150,10 +150,14 @@ def query(
logger.info(f"Query returned from dbt Adapter with response {result[0]}")

agate_data = result[1]
df = pd.DataFrame([row.values() for row in agate_data.rows], columns=agate_data.column_names)
rows = [row.values() for row in agate_data.rows]
data_table = MetricFlowDataTable.create_from_rows(
column_names=agate_data.column_names,
rows=rows,
)
stop = time.time()
logger.info(f"Finished running the query in {stop - start:.2f}s with {df.shape[0]} row(s) returned")
return df
logger.info(f"Finished running the query in {stop - start:.2f}s with {data_table.row_count} row(s) returned")
return data_table

def execute(
self,
Expand Down
22 changes: 5 additions & 17 deletions metricflow/engine/metricflow_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from enum import Enum
from typing import List, Optional, Sequence, Tuple

import pandas as pd
from dbt_semantic_interfaces.implementations.elements.dimension import PydanticDimensionTypeParams
from dbt_semantic_interfaces.implementations.filters.where_filter import PydanticWhereFilter
from dbt_semantic_interfaces.references import EntityReference, MeasureReference, MetricReference
Expand Down Expand Up @@ -35,6 +34,7 @@
from metricflow_semantics.specs.spec_set import InstanceSpecSet
from metricflow_semantics.time.time_source import TimeSource

from metricflow.data_table.mf_table import MetricFlowDataTable
from metricflow.dataflow.builder.dataflow_plan_builder import DataflowPlanBuilder
from metricflow.dataflow.builder.node_data_set import DataflowPlanNodeOutputDataSetResolver
from metricflow.dataflow.builder.source_node import SourceNodeBuilder
Expand Down Expand Up @@ -160,7 +160,7 @@ class MetricFlowQueryResult:
query_spec: MetricFlowQuerySpec
dataflow_plan: DataflowPlan
sql: str
result_df: Optional[pd.DataFrame] = None
result_df: Optional[MetricFlowDataTable] = None
result_table: Optional[SqlTable] = None


Expand Down Expand Up @@ -703,7 +703,7 @@ def get_dimension_values( # noqa: D102
time_constraint_end: Optional[datetime.datetime] = None,
) -> List[str]:
# Run query
query_result = self.query(
query_result: MetricFlowQueryResult = self.query(
MetricFlowQueryRequest.create_with_random_request_id(
metric_names=metric_names,
group_by_names=[get_group_by_values],
Expand All @@ -712,22 +712,10 @@ def get_dimension_values( # noqa: D102
query_type=MetricFlowQueryType.DIMENSION_VALUES,
)
)
result_dataframe = query_result.result_df
if result_dataframe is None:
if query_result.result_df is None:
return []

# Snowflake likes upper-casing things in result output, so we lower-case all names
# before operating on the dataframe.
metric_names = [metric_name.lower() for metric_name in metric_names]
result_dataframe.columns = result_dataframe.columns.str.lower()

# Get dimension values regardless of input name -> output dimension mapping. This is necessary because
# granularity adjustments on time dimensions produce different output names for dimension values.
# Note: this only works as long as we have exactly one column of group by values
# and no other extraneous output columns
dim_vals = result_dataframe[result_dataframe.columns[~result_dataframe.columns.isin(metric_names)]].iloc[:, 0]

return sorted([str(val) for val in dim_vals])
return sorted([str(val) for val in query_result.result_df.column_values_iterator(0)])
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh nice. Do we know for sure the dimension will always come first, or should we get rid of the magic number and do something like query_result.result_df.column_values_iterator(query_result.result_df.column_name_index(get_group_by_values))?

The pandas operation was skipping all of the metric columns.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, there's a specified order when the SQL is rendered to have the dimension values first

def as_tuple(self) -> Tuple[SqlSelectColumn, ...]:

It would be better to do a lookup, but mapping the name to the output column is not as straightforward as it should be since the output column name can be different from the input.


@log_call(module_name=__name__, telemetry_reporter=_telemetry_reporter)
def explain_get_dimension_values( # noqa: D102
Expand Down
4 changes: 2 additions & 2 deletions metricflow/execution/execution_plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,12 @@
from dataclasses import dataclass
from typing import List, Optional, Sequence, Tuple

import pandas as pd
from metricflow_semantics.dag.id_prefix import IdPrefix, StaticIdPrefix
from metricflow_semantics.dag.mf_dag import DagId, DagNode, DisplayedProperty, MetricFlowDag, NodeId
from metricflow_semantics.sql.sql_bind_parameters import SqlBindParameters
from metricflow_semantics.visitor import Visitable

from metricflow.data_table.mf_table import MetricFlowDataTable
from metricflow.protocols.sql_client import SqlClient
from metricflow.sql.sql_table import SqlTable

Expand Down Expand Up @@ -83,7 +83,7 @@ class TaskExecutionResult:
sql: Optional[str] = None
bind_params: Optional[SqlBindParameters] = None
# If the task produces a dataframe as a result, it's stored here.
df: Optional[pd.DataFrame] = None
df: Optional[MetricFlowDataTable] = None


class SelectSqlQueryToDataFrameTask(ExecutionPlanTask):
Expand Down
4 changes: 2 additions & 2 deletions metricflow/protocols/sql_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
from typing import Protocol

from metricflow_semantics.sql.sql_bind_parameters import SqlBindParameters
from pandas import DataFrame

from metricflow.data_table.mf_table import MetricFlowDataTable
from metricflow.sql.render.sql_plan_renderer import SqlQueryPlanRenderer


Expand Down Expand Up @@ -52,7 +52,7 @@ def query(
self,
stmt: str,
sql_bind_parameters: SqlBindParameters = SqlBindParameters(),
) -> DataFrame:
) -> MetricFlowDataTable:
"""Base query method, upon execution will run a query that returns a pandas DataFrame."""
raise NotImplementedError

Expand Down
41 changes: 12 additions & 29 deletions tests_metricflow/compare_df.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@

import pandas as pd

from metricflow.data_table.mf_table import MetricFlowDataTable
from tests_metricflow.sql.compare_data_table import check_data_tables_are_equal

logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -43,8 +46,8 @@ def _dataframes_contain_same_data(


def assert_dataframes_equal(
actual: pd.DataFrame,
expected: pd.DataFrame,
actual: MetricFlowDataTable,
expected: MetricFlowDataTable,
sort_columns: bool = True,
allow_empty: bool = False,
compare_names_using_lowercase: bool = False,
Expand All @@ -55,30 +58,10 @@ def assert_dataframes_equal(
If compare_names_using_lowercase is set to True, we copy the dataframes and lower-case their names.
This is useful for Snowflake query output comparisons.
"""
if compare_names_using_lowercase:
actual = actual.copy()
expected = expected.copy()
actual.columns = actual.columns.str.lower()
expected.columns = expected.columns.str.lower()

if set(actual.columns) != set(expected.columns):
raise ValueError(
f"DataFrames do not contain the same columns. actual: {set(actual.columns)}, "
f"expected: {set(expected.columns)}"
)

if not allow_empty and actual.shape[0] == 0 and expected.shape[0] == 0:
raise AssertionError("Both dataframes have no rows; likely there is a mistake with the test")

if sort_columns:
sort_by = list(sorted(actual.columns.tolist()))
expected = expected.loc[:, sort_by].sort_values(sort_by).reset_index(drop=True)
actual = actual.loc[:, sort_by].sort_values(sort_by).reset_index(drop=True)

if not _dataframes_contain_same_data(actual=actual, expected=expected):
raise ValueError(
f"Dataframes not equal.\n"
f"Expected:\n{expected.to_markdown(index=False)}"
"\n---\n"
f"Actual:\n{actual.to_markdown(index=False)}"
)
check_data_tables_are_equal(
expected_table=expected,
actual_table=actual,
ignore_order=sort_columns,
allow_empty=allow_empty,
compare_column_names_using_lowercase=compare_names_using_lowercase,
)
14 changes: 7 additions & 7 deletions tests_metricflow/execution/test_tasks.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
from __future__ import annotations

import pandas as pd
from metricflow_semantics.dag.mf_dag import DagId
from metricflow_semantics.random_id import random_id
from metricflow_semantics.sql.sql_bind_parameters import SqlBindParameters
from metricflow_semantics.test_helpers.config_helpers import MetricFlowTestConfiguration

from metricflow.data_table.mf_table import MetricFlowDataTable
from metricflow.execution.execution_plan import (
ExecutionPlan,
SelectSqlQueryToDataFrameTask,
Expand All @@ -29,9 +29,9 @@ def test_read_sql_task(sql_client: SqlClient) -> None: # noqa: D103

assert_dataframes_equal(
actual=task_result.df,
expected=pd.DataFrame(
columns=["foo"],
data=[(1,)],
expected=MetricFlowDataTable.create_from_rows(
column_names=["foo"],
rows=[(1,)],
),
compare_names_using_lowercase=sql_client.sql_engine_type is SqlEngine.SNOWFLAKE,
)
Expand All @@ -55,9 +55,9 @@ def test_write_table_task( # noqa: D103

assert_dataframes_equal(
actual=sql_client.query(f"SELECT * FROM {output_table.sql}"),
expected=pd.DataFrame(
columns=["foo"],
data=[(1,)],
expected=MetricFlowDataTable.create_from_rows(
column_names=["foo"],
rows=[(1,)],
),
compare_names_using_lowercase=sql_client.sql_engine_type is SqlEngine.SNOWFLAKE,
)
Expand Down
47 changes: 22 additions & 25 deletions tests_metricflow/fixtures/sql_clients/adapter_backed_ddl_client.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
from __future__ import annotations

import datetime
import logging
import time
from typing import Optional

import pandas as pd

from dbt_metricflow.cli.dbt_connectors.adapter_backed_client import AdapterBackedSqlClient
from metricflow.data_table.mf_column import ColumnDescription
from metricflow.data_table.mf_table import MetricFlowDataTable
from metricflow.protocols.sql_client import SqlEngine
from metricflow.sql.sql_table import SqlTable

Expand All @@ -19,7 +20,7 @@ class AdapterBackedDDLSqlClient(AdapterBackedSqlClient):
def create_table_from_dataframe(
self,
sql_table: SqlTable,
df: pd.DataFrame,
df: MetricFlowDataTable,
chunk_size: Optional[int] = None,
) -> None:
"""Create a table in the data warehouse containing the contents of the dataframe.
Expand All @@ -31,21 +32,16 @@ def create_table_from_dataframe(
df: The Pandas DataFrame object containing the column schema and data to load
chunk_size: The number of rows to insert per transaction
"""
logger.info(f"Creating table '{sql_table.sql}' from a DataFrame with {df.shape[0]} row(s)")
logger.info(f"Creating table '{sql_table.sql}' from a DataFrame with {df.row_count} row(s)")
start_time = time.time()

with self._adapter.connection_named("MetricFlow_create_from_dataframe"):
# Create table
# update dtypes to convert None to NA in boolean columns.
# This mirrors the SQLAlchemy schema detection logic in pandas.io.sql
df = df.convert_dtypes()
columns = df.columns

columns_to_insert = []
for i in range(len(df.columns)):
for column_description in df.column_descriptions:
# Format as "column_name column_type"
columns_to_insert.append(
f"{columns[i]} {self._get_type_from_pandas_dtype(str(df[columns[i]].dtype).lower())}"
)
columns_to_insert.append(f"{column_description.column_name} {self._get_sql_type(column_description)}")

self._adapter.execute(
f"CREATE TABLE IF NOT EXISTS {sql_table.sql} ({', '.join(columns_to_insert)})",
auto_begin=True,
Expand All @@ -55,18 +51,18 @@ def create_table_from_dataframe(

# Insert rows
values = []
for row in df.itertuples(index=False, name=None):
for row in df.rows:
cells = []
for cell in row:
if pd.isnull(cell):
if cell is None:
# use null keyword instead of isNA/None/etc.
cells.append("null")
elif type(cell) in [str, pd.Timestamp]:
elif type(cell) in [str, datetime.datetime]:
# Wrap cell in quotes & escape existing single quotes
escaped_cell = self._quote_escape_value(str(cell))
# Trino requires timestamp literals to be wrapped in a timestamp() function.
# There is probably a better way to handle this.
if self.sql_engine_type is SqlEngine.TRINO and type(cell) is pd.Timestamp:
if self.sql_engine_type is SqlEngine.TRINO and type(cell) is datetime.datetime:
cells.append(f"timestamp '{escaped_cell}'")
else:
cells.append(f"'{escaped_cell}'")
Expand All @@ -88,30 +84,31 @@ def create_table_from_dataframe(
# Commit all insert transaction at once
self._adapter.commit_if_has_connection()

logger.info(f"Created table '{sql_table.sql}' from a DataFrame in {time.time() - start_time:.2f}s")
logger.info(f"Created SQL table '{sql_table.sql}' from an in-memory table in {time.time() - start_time:.2f}s")

def _get_type_from_pandas_dtype(self, dtype: str) -> str:
def _get_sql_type(self, column_description: ColumnDescription) -> str:
"""Helper method to get the engine-specific type value.

The dtype dict here is non-exhaustive but should be adequate for our needs.
"""
# TODO: add type handling for string/bool/bigint types for all engines
if dtype == "string" or dtype == "object":
column_type = column_description.column_type
if column_type is str:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh this is so much better than the magic string comparisons....

if self.sql_engine_type is SqlEngine.DATABRICKS or self.sql_engine_type is SqlEngine.BIGQUERY:
return "string"
if self.sql_engine_type is SqlEngine.TRINO:
return "varchar"
return "text"
elif dtype == "boolean" or dtype == "bool":
elif column_type is bool:
return "boolean"
elif dtype == "int64":
elif column_type is int:
return "bigint"
elif dtype == "float64":
elif column_type is float:
return self._sql_query_plan_renderer.expr_renderer.double_data_type
elif dtype == "datetime64[ns]":
elif column_type is datetime.datetime:
return self._sql_query_plan_renderer.expr_renderer.timestamp_data_type
else:
raise ValueError(f"Encountered unexpected Pandas dtype ({dtype})!")
raise ValueError(f"Encountered unexpected {column_type=}!")

def _quote_escape_value(self, value: str) -> str:
"""Escape single quotes in string-like values for create_table_from_dataframe.
Expand Down
5 changes: 2 additions & 3 deletions tests_metricflow/fixtures/sql_clients/ddl_sql_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,7 @@
from abc import abstractmethod
from typing import Optional, Protocol

from pandas import DataFrame

from metricflow.data_table.mf_table import MetricFlowDataTable
from metricflow.protocols.sql_client import SqlClient
from metricflow.sql.sql_table import SqlTable

Expand All @@ -24,7 +23,7 @@ class SqlClientWithDDLMethods(SqlClient, Protocol):
def create_table_from_dataframe(
self,
sql_table: SqlTable,
df: DataFrame,
df: MetricFlowDataTable,
chunk_size: Optional[int] = None,
) -> None:
"""Creates a table and populates it with the contents of the dataframe.
Expand Down
Loading