Skip to content

Commit

Permalink
Update SnowflakeInferenceContextProvider to use correct types.
Browse files Browse the repository at this point in the history
  • Loading branch information
plypaul committed May 31, 2024
1 parent b6a313a commit a3c2110
Showing 1 changed file with 41 additions and 16 deletions.
57 changes: 41 additions & 16 deletions metricflow/inference/context/snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import json

from metricflow.data_table.mf_table import MetricFlowDataTable
from metricflow.inference.context.data_warehouse import (
ColumnProperties,
DataWarehouseInferenceContextProvider,
Expand Down Expand Up @@ -73,6 +74,22 @@ def _get_select_list_for_column_name(self, name: str, count_nulls: bool) -> str:

return ", ".join(statements)

def _get_one_int(self, data_table: MetricFlowDataTable, column_name: str) -> int:
if len(data_table.rows) == 0:
raise ValueError("No rows in the data table")

return_value = data_table.get_cell_value(0, data_table.column_name_index(column_name))

if isinstance(return_value, int):
return int(return_value)
raise RuntimeError(f"Unhandled case {return_value=}")

def _get_one_str(self, data_table: MetricFlowDataTable, column_name: str) -> str:
if len(data_table.rows) == 0:
raise ValueError("No rows in the data table")

return str(data_table.get_cell_value(0, data_table.column_name_index(column_name)))

def _get_table_properties(self, table: SqlTable) -> TableProperties:
all_columns_query = f"SHOW COLUMNS IN TABLE {table.sql}"
all_columns = self._client.query(all_columns_query)
Expand All @@ -82,16 +99,16 @@ def _get_table_properties(self, table: SqlTable) -> TableProperties:
col_nullable = {}
select_lists = []

for row in all_columns.itertuples():
for row in all_columns.rows:
column = SqlColumn.from_names(
db_name=row.database_name.lower(),
schema_name=row.schema_name.lower(),
table_name=row.table_name.lower(),
column_name=row.column_name.lower(),
db_name=str(row[all_columns.column_name_index("database_name")]).lower(),
schema_name=str(row[all_columns.column_name_index("schema_name")]).lower(),
table_name=str(row[all_columns.column_name_index("table_name")]).lower(),
column_name=str(row[all_columns.column_name_index("column_name")]).lower(),
)
sql_column_list.append(column)

type_dict = json.loads(row.data_type)
type_dict = json.loads(str(row[all_columns.column_name_index("data_type")]))
col_types[column] = self._column_type_from_show_columns_data_type(type_dict["type"])
col_nullable[column] = type_dict["nullable"]
select_lists.append(
Expand All @@ -104,22 +121,30 @@ def _get_table_properties(self, table: SqlTable) -> TableProperties:
select_lists.append("COUNT(*) AS rowcount")
select_list = ", ".join(select_lists)
statistics_query = f"SELECT {select_list} FROM {table.sql} SAMPLE ({self.max_sample_size} ROWS)"
statistics_df = self._client.query(statistics_query)
statistics_data_table = self._client.query(statistics_query)

column_props = [
ColumnProperties(
column=column,
type=col_types[column],
is_nullable=col_nullable[column],
null_count=statistics_df[f"{column.column_name}_{SnowflakeInferenceContextProvider.COUNT_NULL_SUFFIX}"][
0
],
row_count=statistics_df["rowcount"][0],
distinct_row_count=statistics_df[
f"{column.column_name}_{SnowflakeInferenceContextProvider.COUNT_DISTINCT_SUFFIX}"
][0],
min_value=statistics_df[f"{column.column_name}_{SnowflakeInferenceContextProvider.MIN_SUFFIX}"][0],
max_value=statistics_df[f"{column.column_name}_{SnowflakeInferenceContextProvider.MAX_SUFFIX}"][0],
null_count=self._get_one_int(
statistics_data_table,
f"{column.column_name}_{SnowflakeInferenceContextProvider.COUNT_NULL_SUFFIX}",
),
row_count=self._get_one_int(statistics_data_table, "rowcount"),
distinct_row_count=self._get_one_int(
statistics_data_table,
f"{column.column_name}_{SnowflakeInferenceContextProvider.COUNT_DISTINCT_SUFFIX}",
),
min_value=self._get_one_str(
statistics_data_table,
f"{column.column_name}_{SnowflakeInferenceContextProvider.MIN_SUFFIX}",
),
max_value=self._get_one_str(
statistics_data_table,
f"{column.column_name}_{SnowflakeInferenceContextProvider.MAX_SUFFIX}",
),
)
for column in sql_column_list
]
Expand Down

0 comments on commit a3c2110

Please sign in to comment.