diff --git a/metricflow/inference/context/snowflake.py b/metricflow/inference/context/snowflake.py
index 97c3684b71..a245718a10 100644
--- a/metricflow/inference/context/snowflake.py
+++ b/metricflow/inference/context/snowflake.py
@@ -2,6 +2,7 @@
 
 import json
 
+from metricflow.data_table.mf_table import MetricFlowDataTable
 from metricflow.inference.context.data_warehouse import (
     ColumnProperties,
     DataWarehouseInferenceContextProvider,
@@ -73,6 +74,22 @@ def _get_select_list_for_column_name(self, name: str, count_nulls: bool) -> str:
 
         return ", ".join(statements)
 
+    def _get_one_int(self, data_table: MetricFlowDataTable, column_name: str) -> int:
+        if len(data_table.rows) == 0:
+            raise ValueError("No rows in the data table")
+
+        return_value = data_table.get_cell_value(0, data_table.column_name_index(column_name))
+
+        if isinstance(return_value, int):
+            return int(return_value)
+        raise RuntimeError(f"Unhandled case {return_value=}")
+
+    def _get_one_str(self, data_table: MetricFlowDataTable, column_name: str) -> str:
+        if len(data_table.rows) == 0:
+            raise ValueError("No rows in the data table")
+
+        return str(data_table.get_cell_value(0, data_table.column_name_index(column_name)))
+
     def _get_table_properties(self, table: SqlTable) -> TableProperties:
         all_columns_query = f"SHOW COLUMNS IN TABLE {table.sql}"
         all_columns = self._client.query(all_columns_query)
@@ -82,16 +99,16 @@ def _get_table_properties(self, table: SqlTable) -> TableProperties:
         col_nullable = {}
         select_lists = []
 
-        for row in all_columns.itertuples():
+        for row in all_columns.rows:
             column = SqlColumn.from_names(
-                db_name=row.database_name.lower(),
-                schema_name=row.schema_name.lower(),
-                table_name=row.table_name.lower(),
-                column_name=row.column_name.lower(),
+                db_name=str(row[all_columns.column_name_index("database_name")]).lower(),
+                schema_name=str(row[all_columns.column_name_index("schema_name")]).lower(),
+                table_name=str(row[all_columns.column_name_index("table_name")]).lower(),
+                column_name=str(row[all_columns.column_name_index("column_name")]).lower(),
             )
             sql_column_list.append(column)
 
-            type_dict = json.loads(row.data_type)
+            type_dict = json.loads(str(row[all_columns.column_name_index("data_type")]))
             col_types[column] = self._column_type_from_show_columns_data_type(type_dict["type"])
             col_nullable[column] = type_dict["nullable"]
             select_lists.append(
@@ -104,22 +121,30 @@ def _get_table_properties(self, table: SqlTable) -> TableProperties:
         select_lists.append("COUNT(*) AS rowcount")
         select_list = ", ".join(select_lists)
         statistics_query = f"SELECT {select_list} FROM {table.sql} SAMPLE ({self.max_sample_size} ROWS)"
-        statistics_df = self._client.query(statistics_query)
+        statistics_data_table = self._client.query(statistics_query)
 
         column_props = [
             ColumnProperties(
                 column=column,
                 type=col_types[column],
                 is_nullable=col_nullable[column],
-                null_count=statistics_df[f"{column.column_name}_{SnowflakeInferenceContextProvider.COUNT_NULL_SUFFIX}"][
-                    0
-                ],
-                row_count=statistics_df["rowcount"][0],
-                distinct_row_count=statistics_df[
-                    f"{column.column_name}_{SnowflakeInferenceContextProvider.COUNT_DISTINCT_SUFFIX}"
-                ][0],
-                min_value=statistics_df[f"{column.column_name}_{SnowflakeInferenceContextProvider.MIN_SUFFIX}"][0],
-                max_value=statistics_df[f"{column.column_name}_{SnowflakeInferenceContextProvider.MAX_SUFFIX}"][0],
+                null_count=self._get_one_int(
+                    statistics_data_table,
+                    f"{column.column_name}_{SnowflakeInferenceContextProvider.COUNT_NULL_SUFFIX}",
+                ),
+                row_count=self._get_one_int(statistics_data_table, "rowcount"),
+                distinct_row_count=self._get_one_int(
+                    statistics_data_table,
+                    f"{column.column_name}_{SnowflakeInferenceContextProvider.COUNT_DISTINCT_SUFFIX}",
+                ),
+                min_value=self._get_one_str(
+                    statistics_data_table,
+                    f"{column.column_name}_{SnowflakeInferenceContextProvider.MIN_SUFFIX}",
+                ),
+                max_value=self._get_one_str(
+                    statistics_data_table,
+                    f"{column.column_name}_{SnowflakeInferenceContextProvider.MAX_SUFFIX}",
+                ),
            )
            for column in sql_column_list
        ]