From 19beaba1b942802f7aa1f3e795521b32900fc951 Mon Sep 17 00:00:00 2001 From: Paul Yang Date: Sun, 2 Jun 2024 22:29:29 -0700 Subject: [PATCH] Address comments. --- metricflow/data_table/mf_table.py | 7 +------ tests_metricflow/sql/compare_data_table.py | 10 ++++++---- tests_metricflow/sql_clients/test_sql_client.py | 9 ++++++--- 3 files changed, 13 insertions(+), 13 deletions(-) diff --git a/metricflow/data_table/mf_table.py b/metricflow/data_table/mf_table.py index 5634d1680d..e7136a420f 100644 --- a/metricflow/data_table/mf_table.py +++ b/metricflow/data_table/mf_table.py @@ -95,8 +95,6 @@ def column_values_iterator(self, column_index: int) -> Iterator[CellValue]: return (row[column_index] for row in self.rows) def _sorted_by_column_name(self) -> MetricFlowDataTable: # noqa: D102 - # row_dict_by_row_index: Dict[int, Dict[str, CellType]] = defaultdict(dict) - new_rows: List[List[CellValue]] = [[] for _ in range(self.row_count)] sorted_column_names = sorted(self.column_names) for column_name in sorted_column_names: @@ -142,10 +140,7 @@ def text_format(self, float_decimals: int = 2) -> str: continue if isinstance(cell_value, datetime.datetime): - if cell_value.time() == datetime.time.min: - str_row.append(cell_value.date().isoformat()) - else: - str_row.append(cell_value.isoformat()) + str_row.append(cell_value.isoformat()) continue str_row.append(str(cell_value)) diff --git a/tests_metricflow/sql/compare_data_table.py b/tests_metricflow/sql/compare_data_table.py index c33272789a..88d058315e 100644 --- a/tests_metricflow/sql/compare_data_table.py +++ b/tests_metricflow/sql/compare_data_table.py @@ -94,14 +94,16 @@ def check_data_tables_are_equal( This was migrated from an existing implementation based on `pandas` data_tables. """ - if ignore_order: - expected_table = expected_table.sorted() - actual_table = actual_table.sorted() - if compare_column_names_using_lowercase: expected_table = expected_table.with_lower_case_column_names() actual_table = actual_table.with_lower_case_column_names() + # Sort after case changes since the order can change after a case change. e.g. underscore comes + # before lowercase. + if ignore_order: + expected_table = expected_table.sorted() + actual_table = actual_table.sorted() + if expected_table.column_names != actual_table.column_names: raise ValueError( mf_pformat_many( diff --git a/tests_metricflow/sql_clients/test_sql_client.py b/tests_metricflow/sql_clients/test_sql_client.py index 65bfe34507..a6a2ac23cb 100644 --- a/tests_metricflow/sql_clients/test_sql_client.py +++ b/tests_metricflow/sql_clients/test_sql_client.py @@ -1,7 +1,7 @@ from __future__ import annotations import logging -from typing import Set, Union +from typing import Optional, Set, Union import pytest from dbt_semantic_interfaces.test_utils import as_datetime @@ -26,13 +26,16 @@ def _select_x_as_y(x: int = 1, y: str = "y") -> str: return f"SELECT {x} AS {y}" -def _check_1col(df: MetricFlowDataTable, col: str = "y", vals: Set[Union[int, str]] = {1}) -> None: +def _check_1col(df: MetricFlowDataTable, col: str = "y", vals: Optional[Set[Union[int, str]]] = None) -> None: """Helper to check that 1 column has the same value and a case-insensitive matching name. We lower-case the names due to snowflake's tendency to capitalize things. This isn't ideal but it'll do for now. """ + if vals is None: + vals = {1} + assert df.column_count == 1 - assert df.column_names == (col,) + assert tuple(column_name.lower() for column_name in df.column_names) == (col.lower(),) assert set(df.column_values_iterator(0)) == vals