From a81af7fbc2919c22e9d9bca478c9c022f1704c1f Mon Sep 17 00:00:00 2001 From: Paul Yang Date: Mon, 27 May 2024 17:53:05 -0700 Subject: [PATCH] /* PR_START p--py312 03 */ Add replacement class for dataframes. --- metricflow/data_table/__init__.py | 0 metricflow/data_table/column_types.py | 22 ++ metricflow/data_table/mf_column.py | 39 +++ metricflow/data_table/mf_table.py | 279 +++++++++++++++++++++ tests_metricflow/sql/compare_data_table.py | 159 ++++++++++++ tests_metricflow/sql/test_data_table.py | 124 +++++++++ 6 files changed, 623 insertions(+) create mode 100644 metricflow/data_table/__init__.py create mode 100644 metricflow/data_table/column_types.py create mode 100644 metricflow/data_table/mf_column.py create mode 100644 metricflow/data_table/mf_table.py create mode 100644 tests_metricflow/sql/compare_data_table.py create mode 100644 tests_metricflow/sql/test_data_table.py diff --git a/metricflow/data_table/__init__.py b/metricflow/data_table/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/metricflow/data_table/column_types.py b/metricflow/data_table/column_types.py new file mode 100644 index 0000000000..fb4bbf1b4f --- /dev/null +++ b/metricflow/data_table/column_types.py @@ -0,0 +1,22 @@ +from __future__ import annotations + +import datetime +import decimal +from typing import Sequence, Tuple, Type, Union + +# Types supported by `MetricFlowDataTable`. +CellValue = Union[float, str, datetime.datetime, bool, None] +# Types supported as inputs when building a `MetricFlowDataTable`. These inputs will get converted into +# one of the `CellValue` types. +InputCellValue = Union[int, float, str, datetime.datetime, bool, None, decimal.Decimal, datetime.date] +MetricflowNoneType = type(None) + + +def row_cell_types(row: Sequence[CellValue]) -> Tuple[Type[CellValue], ...]: + """Return the cell type / column type for the objects in the row.""" + return tuple(cell_type(cell) for cell in row) + + +def cell_type(cell_value: CellValue) -> Type[CellValue]: + """Return the cell type / column type for the object in the cell.""" + return type(cell_value) diff --git a/metricflow/data_table/mf_column.py b/metricflow/data_table/mf_column.py new file mode 100644 index 0000000000..2c7d61508b --- /dev/null +++ b/metricflow/data_table/mf_column.py @@ -0,0 +1,39 @@ +from __future__ import annotations + +from dataclasses import dataclass +from functools import cached_property +from typing import Iterator, Sequence, Tuple, Type + +from metricflow.data_table.column_types import CellValue + + +@dataclass(frozen=True) +class ColumnDescription: + """Describes a single column in a data table.""" + + column_name: str + column_type: Type[CellValue] + + def with_lower_case_column_name(self) -> ColumnDescription: # noqa: D102 + return ColumnDescription( + column_name=self.column_name.lower(), + column_type=self.column_type, + ) + + +@dataclass(frozen=True) +class ColumnDescriptionSet: + """Describes a collection of columns in a data table.""" + + column_descriptions: Tuple[ColumnDescription, ...] + + def __iter__(self) -> Iterator[ColumnDescription]: # noqa: D105 + return iter(self.column_descriptions) + + @cached_property + def column_names(self) -> Sequence[str]: # noqa: D102 + return tuple(column_description.column_name for column_description in self.column_descriptions) + + @cached_property + def column_types(self) -> Sequence[Type]: # noqa: D102 + return tuple(column_description.column_type for column_description in self.column_descriptions) diff --git a/metricflow/data_table/mf_table.py b/metricflow/data_table/mf_table.py new file mode 100644 index 0000000000..5634d1680d --- /dev/null +++ b/metricflow/data_table/mf_table.py @@ -0,0 +1,279 @@ +from __future__ import annotations + +import datetime +import itertools +import logging +from dataclasses import dataclass +from decimal import Decimal +from typing import Iterable, Iterator, List, Optional, Sequence, Tuple, Type + +import tabulate +from metricflow_semantics.mf_logging.formatting import indent +from metricflow_semantics.mf_logging.pretty_print import mf_pformat, mf_pformat_many +from typing_extensions import Self + +from metricflow.data_table.column_types import CellValue, InputCellValue, row_cell_types +from metricflow.data_table.mf_column import ColumnDescription + +logger = logging.getLogger(__name__) + + +@dataclass(frozen=True, eq=False) +class MetricFlowDataTable: + """Container for tabular data stored in memory. + + This is feature-limited and is used to pass tabular data for tests and the CLI. The only data types that are + supported are described by `CellValue`. + + When constructing the table, additional input types (as described by `InputCellValue`) can be used, but those + additional types will be converted into one of the `CellValue` types. + + Don't use `=` to compare tables as there many be NaNs. Instead, use `check_data_tables_are_equal`. + """ + + column_descriptions: Tuple[ColumnDescription, ...] + rows: Tuple[Tuple[CellValue, ...], ...] + + def __post_init__(self) -> None: # noqa: D105 + expected_column_count = self.column_count + for row_index, row in enumerate(self.rows): + # Check that the number of columns in the rows match. + row_column_count = len(row) + assert row_column_count == expected_column_count, ( + f"Row at index {row_index} has {row_column_count} columns instead of {expected_column_count}. " + f"Row is:" + f"\n{indent(mf_pformat(row))}" + ) + # Check that the type of the object in the rows match. + for column_index, cell_value in enumerate(row): + expected_cell_value_type = self.column_descriptions[column_index].column_type + assert cell_value is None or isinstance(cell_value, expected_cell_value_type), mf_pformat_many( + "Cell value type mismatch.", + { + "row_index": row_index, + "column_index": column_index, + "expected_cell_value_type": expected_cell_value_type, + "actual_cell_value_type": type(cell_value), + "cell_value": cell_value, + }, + ) + # Check that datetimes don't have a timezone set. + if isinstance(cell_value, datetime.datetime): + assert cell_value.tzinfo is None, mf_pformat_many( + "Time zone provided for datetime.", + { + "row_index": row_index, + "column_index": column_index, + "cell_value": cell_value, + }, + ) + + @property + def column_count(self) -> int: # noqa: D102 + return len(self.column_descriptions) + + @property + def row_count(self) -> int: # noqa: D102 + return len(self.rows) + + def column_name_index(self, column_name: str) -> int: + """Return the index of the column that matches the given name. Raises `ValueError` if the name is invalid.""" + for i, column_description in enumerate(self.column_descriptions): + if column_description.column_name == column_name: + return i + raise ValueError( + f"Unknown column name {repr(column_name)}. Known column names are:" + f"\n{indent(mf_pformat([column_name for column_name in self.column_descriptions]))}" + ) + + @property + def column_names(self) -> Sequence[str]: # noqa: D102 + return tuple(column_description.column_name for column_description in self.column_descriptions) + + def column_values_iterator(self, column_index: int) -> Iterator[CellValue]: + """Returns an iterator for values of the column at the tiven index.""" + return (row[column_index] for row in self.rows) + + def _sorted_by_column_name(self) -> MetricFlowDataTable: # noqa: D102 + # row_dict_by_row_index: Dict[int, Dict[str, CellType]] = defaultdict(dict) + + new_rows: List[List[CellValue]] = [[] for _ in range(self.row_count)] + sorted_column_names = sorted(self.column_names) + for column_name in sorted_column_names: + old_column_index = self.column_name_index(column_name) + for row_index, cell_value in enumerate(self.column_values_iterator(old_column_index)): + new_rows[row_index].append(cell_value) + + return MetricFlowDataTable( + column_descriptions=tuple( + self.column_descriptions[self.column_name_index(column_name)] for column_name in sorted_column_names + ), + rows=tuple( + tuple(row_dict[column_index] for column_index in range(self.column_count)) for row_dict in new_rows + ), + ) + + def _sorted_by_row(self) -> MetricFlowDataTable: # noqa: D102 + def _cell_sort_key(cell: CellValue) -> str: + if isinstance(cell, datetime.datetime): + return cell.isoformat() + return str(cell) + + def _row_sort_key(row: Tuple[CellValue, ...]) -> Tuple[str, ...]: + return tuple(_cell_sort_key(cell) for cell in row) + + return MetricFlowDataTable( + column_descriptions=self.column_descriptions, + rows=tuple(sorted((row for row in self.rows), key=_row_sort_key)), + ) + + def sorted(self) -> MetricFlowDataTable: + """Returns this but with the columns in order by name, and the rows in order by values.""" + return self._sorted_by_column_name()._sorted_by_row() + + def text_format(self, float_decimals: int = 2) -> str: + """Return a text version of this table that is suitable for printing.""" + str_rows: List[List[str]] = [] + for row in self.rows: + str_row: List[str] = [] + for cell_value in row: + if isinstance(cell_value, float): + str_row.append(f"{cell_value:.{float_decimals}f}") + continue + + if isinstance(cell_value, datetime.datetime): + if cell_value.time() == datetime.time.min: + str_row.append(cell_value.date().isoformat()) + else: + str_row.append(cell_value.isoformat()) + continue + + str_row.append(str(cell_value)) + str_rows.append(str_row) + return tabulate.tabulate( + tabular_data=tuple(tuple(str_row) for str_row in str_rows), + headers=tuple(column_description.column_name for column_description in self.column_descriptions), + ) + + def with_lower_case_column_names(self) -> MetricFlowDataTable: + """Return this but with columns names in lowercase.""" + return MetricFlowDataTable( + column_descriptions=tuple( + column_description.with_lower_case_column_name() for column_description in self.column_descriptions + ), + rows=self.rows, + ) + + def get_cell_value(self, row_index: int, column_index: int) -> CellValue: # noqa: D102 + return self.rows[row_index][column_index] + + @staticmethod + def create_from_rows( # noqa: D102 + column_names: Sequence[str], rows: Iterable[Sequence[InputCellValue]] + ) -> MetricFlowDataTable: + builder = _MetricFlowDataTableBuilder(column_names) + for row in rows: + builder.add_row(row) + return builder.build() + + +class _MetricFlowDataTableBuilder: + """Helps build `MetricFlowDataTable`, one row at a time. + + This validates each row as it is input to give better error messages. + """ + + def __init__(self, column_names: Sequence[str]) -> None: # noqa: D107 + self._rows: List[Tuple[CellValue, ...]] = [] + self._column_names = tuple(column_names) + + def _build_table_from_rows(self) -> MetricFlowDataTable: # noqa: D102 + # Figure out the type of the column based on the types of the values in the rows. + # Can't use the type of the columns in the first row because it might contain None, so iterate through the rows + # and use the first non-None type. + column_types_so_far: Optional[Tuple[Type[CellValue], ...]] = None + cell_value: CellValue + for row in self._rows: + if column_types_so_far is None: + column_types_so_far = row_cell_types(row) + continue + + # If the types of the objects in the row are the same, no need for updates. + row_column_types = row_cell_types(row) + if row_column_types == column_types_so_far: + continue + + # Types of objects in the row are different from what's known so far. + # They can only be different in that one can be None and other can be not None. + updated_column_types: List[Type[CellValue]] = [] + for column_type_so_far, cell_type, cell_value in itertools.zip_longest( + column_types_so_far, row_column_types, row + ): + if column_type_so_far is cell_type: + updated_column_types.append(column_type_so_far) + elif column_type_so_far is type(None) and cell_type is not None: + updated_column_types.append(cell_type) + elif column_type_so_far is not None and cell_type is type(None): + updated_column_types.append(column_type_so_far) + else: + raise ValueError(f"Expected cell type {column_type_so_far} but got: {cell_type}") + + column_types_so_far = tuple(updated_column_types) + + # Empty table case. + if column_types_so_far is None: + column_types_so_far = tuple(type(None) for _ in range(len(self._column_names))) + + final_column_types = column_types_so_far + + return MetricFlowDataTable( + column_descriptions=tuple( + ColumnDescription(column_name=column_name, column_type=column_type) + for column_name, column_type in itertools.zip_longest(self._column_names, final_column_types) + ), + rows=tuple(self._rows), + ) + + def _convert_row_to_supported_types(self, row: Sequence[InputCellValue]) -> Sequence[CellValue]: + """Since only a limited set of types are supported, convert the input type to the supported type.""" + updated_row: List[CellValue] = [] + for cell_value in row: + if ( + cell_value is None + or isinstance(cell_value, float) + or isinstance(cell_value, bool) + or isinstance(cell_value, int) + or isinstance(cell_value, str) + ): + updated_row.append(cell_value) + continue + + if isinstance(cell_value, datetime.datetime): + updated_row.append(cell_value.replace(tzinfo=None)) + continue + + if isinstance(cell_value, Decimal): + updated_row.append(float(cell_value)) + continue + + if isinstance(cell_value, datetime.date): + updated_row.append(datetime.datetime.combine(cell_value, datetime.datetime.min.time())) + continue + + raise ValueError(f"Row cell has unexpected type: {repr(cell_value)}") + return updated_row + + def add_row(self, row: Sequence[InputCellValue], parse_strings: bool = False) -> Self: # noqa: D102 + row = tuple(row) + expected_column_count = len(self._column_names) + actual_column_count = len(row) + if actual_column_count != expected_column_count: + raise ValueError( + f"Input row has {actual_column_count} columns, but expected {expected_column_count} columns. Row is:" + f"\n{indent(mf_pformat(row))}" + ) + self._rows.append(tuple(self._convert_row_to_supported_types(row))) + return self + + def build(self) -> MetricFlowDataTable: # noqa: D102 + return self._build_table_from_rows() diff --git a/tests_metricflow/sql/compare_data_table.py b/tests_metricflow/sql/compare_data_table.py new file mode 100644 index 0000000000..ece46f82a0 --- /dev/null +++ b/tests_metricflow/sql/compare_data_table.py @@ -0,0 +1,159 @@ +from __future__ import annotations + +import datetime +import difflib +import math +from dataclasses import dataclass +from typing import Dict, Optional, SupportsFloat + +from metricflow_semantics.mf_logging.pretty_print import mf_pformat_many + +from metricflow.data_table.column_types import CellValue +from metricflow.data_table.mf_table import MetricFlowDataTable + + +def _generate_table_diff_fields( + expected_table: MetricFlowDataTable, actual_table: MetricFlowDataTable +) -> Dict[str, str]: + differ = difflib.Differ() + expected_table_text = expected_table.text_format() + actual_table_text = actual_table.text_format() + diff = differ.compare(expected_table_text.splitlines(keepends=True), actual_table_text.splitlines(keepends=True)) + return { + "expected_table": expected_table_text, + "actual_table": actual_table_text, + "expected_table_to_actual_table_diff": "".join(diff), + } + + +@dataclass(frozen=True) +class DataTableMismatch: + """Describes a mismatch in a cell between two tables.""" + + message: str + row_index: int + column_index: int + expected_value: CellValue + actual_value: CellValue + + +def _check_table_cells_for_mismatch( + expected_table: MetricFlowDataTable, actual_table: MetricFlowDataTable +) -> Optional[DataTableMismatch]: + for row_index in range(expected_table.row_count): + for column_index in range(expected_table.column_count): + # NaNs can't be compared for equality. + expected_value = expected_table.get_cell_value(row_index, column_index) + actual_value = actual_table.get_cell_value(row_index, column_index) + if isinstance(expected_value, SupportsFloat) and isinstance(actual_value, SupportsFloat): + if math.isnan(expected_value) and math.isnan(actual_value): + continue + if not math.isclose(expected_value, actual_value, rel_tol=1e-6): + return DataTableMismatch( + message="`SupportsFloat` value mismatch", + row_index=row_index, + column_index=column_index, + expected_value=expected_value, + actual_value=actual_value, + ) + # It should be safe to remove this once `MetricFlowDataTable` is validated as it doesn't allow timezones. + elif ( + isinstance(expected_value, datetime.datetime) + and isinstance(actual_value, datetime.datetime) + # If expected has no tz but actual is UTC, remove timezone. Some engines add UTC by default. + and (actual_value.tzinfo == "UTC" and expected_value.tzinfo is None) + ): + if expected_value.replace(tzinfo=None) != actual_value.replace(tzinfo=None): + return DataTableMismatch( + message="`datetime` value mismatch", + row_index=row_index, + column_index=column_index, + expected_value=expected_value, + actual_value=actual_value, + ) + elif expected_value != actual_value: + return DataTableMismatch( + message="Value mismatch", + row_index=row_index, + column_index=column_index, + expected_value=expected_value, + actual_value=actual_value, + ) + + return None + + +def check_data_tables_are_equal( + expected_table: MetricFlowDataTable, + actual_table: MetricFlowDataTable, + ignore_order: bool = True, + allow_empty: bool = False, + compare_column_names_using_lowercase: bool = False, +) -> None: + """Check if this is equal to another table. If not, raise an exception. + + This was migrated from an existing implementation based on `pandas` dataframes. + """ + if ignore_order: + expected_table = expected_table.sorted() + actual_table = actual_table.sorted() + + if compare_column_names_using_lowercase: + expected_table = expected_table.with_lower_case_column_names() + actual_table = actual_table.with_lower_case_column_names() + + if expected_table.column_names != actual_table.column_names: + raise ValueError( + mf_pformat_many( + "Column descriptions do not match.", + { + "expected_table_column_names": expected_table.column_names, + "actual_table_column_names": actual_table.column_names, + }, + ) + ) + + if expected_table.row_count != actual_table.row_count: + raise ValueError( + mf_pformat_many( + "Row counts do not match.", + dict( + **{ + "expected_table_row_count": expected_table.row_count, + "actual_table_row_count": actual_table.row_count, + }, + **_generate_table_diff_fields(expected_table=expected_table, actual_table=actual_table), + ), + preserve_raw_strings=True, + ) + ) + + if not allow_empty and expected_table.row_count == 0: + raise ValueError( + mf_pformat_many( + f"Expected table is empty and {allow_empty=}. This may indicate an error in configuring the test.", + _generate_table_diff_fields(expected_table=expected_table, actual_table=actual_table), + preserve_raw_strings=True, + ) + ) + + mismatch = _check_table_cells_for_mismatch(expected_table=expected_table, actual_table=actual_table) + + if mismatch is not None: + raise ValueError( + mf_pformat_many( + mismatch.message, + dict( + **{ + "row_index": mismatch.row_index, + "column_index": mismatch.column_index, + "expected_value": mismatch.expected_value, + "actual_value": mismatch.actual_value, + }, + **_generate_table_diff_fields(expected_table=expected_table, actual_table=actual_table), + ), + preserve_raw_strings=True, + ) + ) + + return diff --git a/tests_metricflow/sql/test_data_table.py b/tests_metricflow/sql/test_data_table.py new file mode 100644 index 0000000000..c263e951eb --- /dev/null +++ b/tests_metricflow/sql/test_data_table.py @@ -0,0 +1,124 @@ +from __future__ import annotations + +import logging +from decimal import Decimal + +import pytest + +from metricflow.data_table.mf_table import MetricFlowDataTable +from tests_metricflow.sql.compare_data_table import check_data_tables_are_equal + +logger = logging.getLogger(__name__) + + +@pytest.fixture +def example_table() -> MetricFlowDataTable: # noqa: D103 + return MetricFlowDataTable.create_from_rows( + column_names=["col_0", "col_1"], + rows=[ + (0, "a"), + (1, "b"), + ], + ) + + +def test_properties(example_table: MetricFlowDataTable) -> None: # noqa: D103 + assert example_table.column_count == 2 + assert example_table.row_count == 2 + assert tuple(example_table.column_names) == ("col_0", "col_1") + + +def test_input_type(example_table: MetricFlowDataTable) -> None: # noqa: D103 + table_from_decimals = MetricFlowDataTable.create_from_rows( + column_names=["col_0", "col_1"], + rows=[ + (Decimal(0), "a"), + (Decimal(1), "b"), + ], + ) + + table_from_floats = MetricFlowDataTable.create_from_rows( + column_names=["col_0", "col_1"], + rows=[ + (0.0, "a"), + (1.0, "b"), + ], + ) + + assert table_from_decimals.rows == table_from_floats.rows + + +def test_invalid_row_length() -> None: # noqa: D103 + with pytest.raises(ValueError): + MetricFlowDataTable.create_from_rows( + column_names=["col_0", "col_1"], + rows=( + (1, "a"), + (2,), + ), + ) + + +def test_invalid_cell_type(example_table: MetricFlowDataTable) -> None: # noqa: D103 + with pytest.raises(ValueError): + MetricFlowDataTable.create_from_rows( + column_names=["col_0", "col_1"], + rows=( + (1, "a"), + (2, 1.0), + ), + ) + + +def test_column_name_index(example_table: MetricFlowDataTable) -> None: # noqa: D103 + assert example_table.column_name_index("col_0") == 0 + assert example_table.column_name_index("col_1") == 1 + with pytest.raises(ValueError): + example_table.column_name_index("invalid_index") + + +def test_sorted() -> None: # noqa: D103: + expected_table = MetricFlowDataTable.create_from_rows( + column_names=["a", "b"], + rows=[(0, 1), (2, 3)], + ) + actual_table = MetricFlowDataTable.create_from_rows( + column_names=["b", "a"], + rows=[(3, 2), (1, 0)], + ).sorted() + + check_data_tables_are_equal( + expected_table=expected_table, + actual_table=actual_table, + ignore_order=False, + ) + + +def test_mismatch() -> None: # noqa: D103: + expected_table = MetricFlowDataTable.create_from_rows( + column_names=["a", "b"], + rows=[(0, 1), (2, 3)], + ) + actual_table = MetricFlowDataTable.create_from_rows( + column_names=["a", "b"], + rows=[(0, 1), (2, 4)], + ) + + with pytest.raises(ValueError): + check_data_tables_are_equal( + expected_table=expected_table, + actual_table=actual_table, + ignore_order=False, + ) + + +def test_get_cell_value(example_table: MetricFlowDataTable) -> None: # noqa: D103 + assert example_table.get_cell_value(0, 0) == 0 + assert example_table.get_cell_value(0, 1) == "a" + assert example_table.get_cell_value(1, 0) == 1 + assert example_table.get_cell_value(1, 1) == "b" + + +def test_column_values_iterator(example_table: MetricFlowDataTable) -> None: # noqa: D103 + assert tuple(example_table.column_values_iterator(0)) == (0, 1) + assert tuple(example_table.column_values_iterator(1)) == ("a", "b")