diff --git a/metricflow/inference/__init__.py b/metricflow/inference/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/metricflow/inference/context/__init__.py b/metricflow/inference/context/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/metricflow/inference/context/base.py b/metricflow/inference/context/base.py deleted file mode 100644 index c0541d5e23..0000000000 --- a/metricflow/inference/context/base.py +++ /dev/null @@ -1,22 +0,0 @@ -from __future__ import annotations - -from abc import ABC, abstractmethod -from typing import Generic, TypeVar - - -class InferenceContext(ABC): - """Encapsulates information that can be used by an inference signaling rule or inference policy.""" - - pass - - -TContext = TypeVar("TContext", bound=InferenceContext) - - -class InferenceContextProvider(Generic[TContext], ABC): - """Provides a populated inference context from some SemanticModel.""" - - @abstractmethod - def get_context(self) -> TContext: - """Fetch inference context data and return it.""" - pass diff --git a/metricflow/inference/context/data_warehouse.py b/metricflow/inference/context/data_warehouse.py deleted file mode 100644 index 2e6a30d609..0000000000 --- a/metricflow/inference/context/data_warehouse.py +++ /dev/null @@ -1,118 +0,0 @@ -from __future__ import annotations - -import contextlib -from abc import ABC, abstractmethod -from dataclasses import InitVar, dataclass, field -from datetime import date, datetime -from enum import Enum -from typing import Callable, ContextManager, Dict, Generic, Iterator, List, Optional, TypeVar - -from metricflow.inference.context.base import InferenceContext, InferenceContextProvider -from metricflow.protocols.sql_client import SqlClient -from metricflow.sql.sql_column import SqlColumn -from metricflow.sql.sql_table import SqlTable - -T = TypeVar("T", str, int, float, date, datetime) - - -class InferenceColumnType(str, Enum): - """Represents a column type that can be used for inference. - - This does not provide a 1 to 1 mapping between SQL types and enum values. For example, - all possible floating point types (FLOAT, DOUBLE etc) are mapped to the same FLOAT - value. Same for datetimes and others. - """ - - STRING = "string" - BOOLEAN = "boolean" - INTEGER = "integer" - FLOAT = "float" - DATETIME = "datetime" - UNKNOWN = "unknown" - - -@dataclass(frozen=True) -class ColumnProperties(Generic[T]): - """Holds properties about a column that were extracted from the data warehouse.""" - - column: SqlColumn - - type: InferenceColumnType - row_count: int - distinct_row_count: int - is_nullable: bool - null_count: int - min_value: Optional[T] - max_value: Optional[T] - - @property - def is_empty(self) -> bool: - """Whether the column has any rows.""" - return self.row_count == 0 - - -@dataclass(frozen=True) -class TableProperties: - """Holds properties of a table and its columns that were extracted from the data warehouse.""" - - column_props: InitVar[List[ColumnProperties]] - - table: SqlTable - columns: Dict[SqlColumn, ColumnProperties] = field(default_factory=lambda: {}, init=False) - - def __post_init__(self, column_props: List[ColumnProperties]) -> None: # noqa: D105 - for col in column_props: - self.columns[col.column] = col - - -@dataclass(frozen=True) -class DataWarehouseInferenceContext(InferenceContext): - """The inference context for a data warehouse. Holds statistics and metadata about each table and column.""" - - table_props: InitVar[List[TableProperties]] - - tables: Dict[SqlTable, TableProperties] = field(default_factory=lambda: {}, init=False) - columns: Dict[SqlColumn, ColumnProperties] = field(default_factory=lambda: {}, init=False) - - def __post_init__(self, table_props: List[TableProperties]) -> None: # noqa: D105 - for stats in table_props: - self.tables[stats.table] = stats - for column in stats.columns.values(): - self.columns[column.column] = column - - -@contextlib.contextmanager -def _default_table_progress(table: SqlTable, index: int, total: int) -> Iterator[None]: - yield - - -class DataWarehouseInferenceContextProvider(InferenceContextProvider[DataWarehouseInferenceContext], ABC): - """Provides inference context from a data warehouse by querying data from its tables.""" - - def __init__(self, client: SqlClient, tables: List[SqlTable], max_sample_size: int = 10000) -> None: - """Initialize the class. - - client: the underlying SQL engine client that will be used for querying table data. - tables: an exhaustive list of all tables that should be queried. - max_sample_size: max number of rows to sample from each table - """ - self._client = client - self.tables = tables - self.max_sample_size = max_sample_size - - @abstractmethod - def _get_table_properties(self, table: SqlTable) -> TableProperties: - """Fetch properties about a single table by querying the warehouse.""" - raise NotImplementedError - - def get_context( - self, - table_progress: Callable[[SqlTable, int, int], ContextManager[None]] = _default_table_progress, - ) -> DataWarehouseInferenceContext: - """Query the data warehouse for statistics about all tables and populate a context with it.""" - table_props_list: List[TableProperties] = [] - for i, table in enumerate(self.tables): - with table_progress(table, i, len(self.tables)): - table_props = self._get_table_properties(table) - table_props_list.append(table_props) - return DataWarehouseInferenceContext(table_props=table_props_list) diff --git a/metricflow/inference/context/snowflake.py b/metricflow/inference/context/snowflake.py deleted file mode 100644 index a245718a10..0000000000 --- a/metricflow/inference/context/snowflake.py +++ /dev/null @@ -1,152 +0,0 @@ -from __future__ import annotations - -import json - -from metricflow.data_table.mf_table import MetricFlowDataTable -from metricflow.inference.context.data_warehouse import ( - ColumnProperties, - DataWarehouseInferenceContextProvider, - InferenceColumnType, - TableProperties, -) -from metricflow.sql.sql_column import SqlColumn -from metricflow.sql.sql_table import SqlTable - - -class SnowflakeInferenceContextProvider(DataWarehouseInferenceContextProvider): - """The snowflake implementation for a DataWarehouseInferenceContextProvider.""" - - COUNT_DISTINCT_SUFFIX = "countdistinct" - COUNT_NULL_SUFFIX = "countnull" - MIN_SUFFIX = "min" - MAX_SUFFIX = "max" - - def _column_type_from_show_columns_data_type(self, type_str: str) -> InferenceColumnType: - """Get the correspondent InferenceColumnType from Snowflake's returned type string. - - See for reference: https://docs.snowflake.com/en/sql-reference/sql/show-columns.html - For string types: https://docs.snowflake.com/en/sql-reference/data-types-text.html#data-types-for-text-strings - """ - type_str = type_str.upper() - type_mapping = { - "FIXED": InferenceColumnType.INTEGER, - "REAL": InferenceColumnType.FLOAT, - "BOOLEAN": InferenceColumnType.BOOLEAN, - "DATE": InferenceColumnType.DATETIME, - "TIMESTAMP_TZ": InferenceColumnType.DATETIME, - "TIMESTAMP_LTZ": InferenceColumnType.DATETIME, - "TIMESTAMP_NTZ": InferenceColumnType.DATETIME, - } - string_prefixes = { - "VARCHAR", - "CHAR", - "CHARACTER", - "NCHAR", - "STRING", - "TEXT", - "NVARCHAR", - "NVARCHAR2", - "CHAR VARYING", - "NCHAR VARYING", - } - - if type_str in type_mapping: - return type_mapping[type_str] - - # This might be a string type, which can either be something like "TEXT" or "VARCHAR(256)" - for prefix in string_prefixes: - if type_str.startswith(prefix): - return InferenceColumnType.STRING - - return InferenceColumnType.UNKNOWN - - def _get_select_list_for_column_name(self, name: str, count_nulls: bool) -> str: - statements = [ - f"COUNT(DISTINCT '{name}') AS {name}_{SnowflakeInferenceContextProvider.COUNT_DISTINCT_SUFFIX}", - f"MIN('{name}') AS {name}_{SnowflakeInferenceContextProvider.MIN_SUFFIX}", - f"MAX('{name}') AS {name}_{SnowflakeInferenceContextProvider.MAX_SUFFIX}", - ( - f"SUM(CASE WHEN '{name}' IS NULL THEN 1 ELSE 0 END) AS {name}_{SnowflakeInferenceContextProvider.COUNT_NULL_SUFFIX}" - if count_nulls - else f"0 AS {name}_{SnowflakeInferenceContextProvider.COUNT_NULL_SUFFIX}" - ), - ] - - return ", ".join(statements) - - def _get_one_int(self, data_table: MetricFlowDataTable, column_name: str) -> int: - if len(data_table.rows) == 0: - raise ValueError("No rows in the data table") - - return_value = data_table.get_cell_value(0, data_table.column_name_index(column_name)) - - if isinstance(return_value, int): - return int(return_value) - raise RuntimeError(f"Unhandled case {return_value=}") - - def _get_one_str(self, data_table: MetricFlowDataTable, column_name: str) -> str: - if len(data_table.rows) == 0: - raise ValueError("No rows in the data table") - - return str(data_table.get_cell_value(0, data_table.column_name_index(column_name))) - - def _get_table_properties(self, table: SqlTable) -> TableProperties: - all_columns_query = f"SHOW COLUMNS IN TABLE {table.sql}" - all_columns = self._client.query(all_columns_query) - - sql_column_list = [] - col_types = {} - col_nullable = {} - select_lists = [] - - for row in all_columns.rows: - column = SqlColumn.from_names( - db_name=str(row[all_columns.column_name_index("database_name")]).lower(), - schema_name=str(row[all_columns.column_name_index("schema_name")]).lower(), - table_name=str(row[all_columns.column_name_index("table_name")]).lower(), - column_name=str(row[all_columns.column_name_index("column_name")]).lower(), - ) - sql_column_list.append(column) - - type_dict = json.loads(str(row[all_columns.column_name_index("data_type")])) - col_types[column] = self._column_type_from_show_columns_data_type(type_dict["type"]) - col_nullable[column] = type_dict["nullable"] - select_lists.append( - self._get_select_list_for_column_name( - name=column.column_name, - count_nulls=col_nullable[column], - ) - ) - - select_lists.append("COUNT(*) AS rowcount") - select_list = ", ".join(select_lists) - statistics_query = f"SELECT {select_list} FROM {table.sql} SAMPLE ({self.max_sample_size} ROWS)" - statistics_data_table = self._client.query(statistics_query) - - column_props = [ - ColumnProperties( - column=column, - type=col_types[column], - is_nullable=col_nullable[column], - null_count=self._get_one_int( - statistics_data_table, - f"{column.column_name}_{SnowflakeInferenceContextProvider.COUNT_NULL_SUFFIX}", - ), - row_count=self._get_one_int(statistics_data_table, "rowcount"), - distinct_row_count=self._get_one_int( - statistics_data_table, - f"{column.column_name}_{SnowflakeInferenceContextProvider.COUNT_DISTINCT_SUFFIX}", - ), - min_value=self._get_one_str( - statistics_data_table, - f"{column.column_name}_{SnowflakeInferenceContextProvider.MIN_SUFFIX}", - ), - max_value=self._get_one_str( - statistics_data_table, - f"{column.column_name}_{SnowflakeInferenceContextProvider.MAX_SUFFIX}", - ), - ) - for column in sql_column_list - ] - - return TableProperties(table=table, column_props=column_props) diff --git a/metricflow/inference/models.py b/metricflow/inference/models.py deleted file mode 100644 index 4d1e11870a..0000000000 --- a/metricflow/inference/models.py +++ /dev/null @@ -1,158 +0,0 @@ -"""Module with common models that are used across multiple classes/modules in inference.""" - -from __future__ import annotations - -from abc import ABC -from dataclasses import dataclass -from enum import Enum -from typing import List, Optional - -from metricflow.sql.sql_column import SqlColumn - - -class InferenceSignalConfidence(Enum): - """A discrete enumeration of possible confidence values for an inference signal. - - We chose discrete confidence values instead of a continuous range (e.g a float between 0 and 1) - to standardize confidence outputs between heuristic rules. We want to avoid different rules - assuming different values for, say, "medium" or "high" confidence, since this could skew - results in favor of/against certain rules. - """ - - VERY_HIGH = 3 - HIGH = 2 - MEDIUM = 1 - LOW = 0 - - -class InferenceSignalNode(ABC): - """A node in the inference signal type hierarchy. - - This class can be used to assemble a type hierarchy tree. It can be used by heuristics - to check whether signals produced by rules are conflicting or complementary, relying - on the property that sibling nodes are mutually exclusive in the hierarchy. - """ - - def __init__(self, parent: Optional[InferenceSignalNode], name: str) -> None: # noqa: D107 - self.name = name - - self.parent = parent - self.children: List[InferenceSignalNode] = [] - - if parent is not None: - parent.children.append(self) - - def __str__(self) -> str: # noqa: D105 - return f"InferenceSignalNode: {self.name}" - - @property - def supertypes(self) -> List[InferenceSignalNode]: - """The list of all supertypes for this type node, from the root to the direct parent.""" - if self.parent is None: - return [] - - return self.parent.supertypes + [self.parent] - - def is_subtype_of(self, other: InferenceSignalNode) -> bool: - """Whether self is a subtype of other. - - A subtype can always be used where its supertype is expected, although the reverse - may not be true - """ - return other == self or other in self.supertypes - - -# This is kinda horrible but there's no way of instancing the tree with type safety without -# hardcoding it. Having some magic that dynamically assigns attributes could work, but then -# we lose IDE autocompletion and static checking -class _TreeNodes: - root = InferenceSignalNode(None, "UNKNOWN") - id = InferenceSignalNode(root, "IDENTIFIER") - foreign_id = InferenceSignalNode(id, "FOREIGN_IDENTIFIER") - unique_id = InferenceSignalNode(id, "UNIQUE_IDENTIFIER") - primary_id = InferenceSignalNode(unique_id, "PRIMARY_IDENTIFIER") - dimension = InferenceSignalNode(root, "DIMENSION") - time_dimension = InferenceSignalNode(dimension, "TIME_DIMENSION") - primary_time_dimension = InferenceSignalNode(time_dimension, "PRIMARY_TIME_DIMENSION") - categorical_dimension = InferenceSignalNode(dimension, "CATEGORICAL_DIMENSION") - measure = InferenceSignalNode(root, "MEASURE") - - -class InferenceSignalType: - """All possible inference signal types.""" - - UNKNOWN = _TreeNodes.root - - class ID: - """Indicates a column was inferred to be an ID.""" - - UNKNOWN = _TreeNodes.id - FOREIGN = _TreeNodes.foreign_id - UNIQUE = _TreeNodes.unique_id - PRIMARY = _TreeNodes.primary_id - - class DIMENSION: - """Indicates a column was inferred to be a dimension.""" - - UNKNOWN = _TreeNodes.dimension - TIME = _TreeNodes.time_dimension - PRIMARY_TIME = _TreeNodes.primary_time_dimension - CATEGORICAL = _TreeNodes.categorical_dimension - - class MEASURE: - """Indicates a column was inferred to be a measure.""" - - UNKNOWN = _TreeNodes.measure - - -@dataclass(frozen=True) -class InferenceSignal: - """Encapsulates a piece of evidence about a column produced by an inference rule. - - column: the target column for this signal - type_node: the type node of the signal - reason: a human-readable string that explains why this signal was produced. It may - eventually reach the user's eyeballs. - confidence: the confidence that this signal is correct. - only_applies_to_parent: whether a solver should only consider this signal if the parent type - node is also present. In other words, this signal is _complementary_ to its parent. - - About the usage of `only_applies_to_parent`: - This can be used to produce signals which don't affect each other, are not - individually indicative of a specific outcome, and as such are only used to - further specify parent signals. - - Example: columns with unique values can indicate both unique keys and categorical - dimensions, which are contradictory outcomes. As such, a unique values signal requires - the parent signal (is ID or is DIMENSION) to be present to be useful, and so anything - tagged with this flag will only apply to the parent. So you can safely designate a - signal as having a HIGH confidence for a given base column type while ignoring - it for any other column type. - - - """ - - column: SqlColumn - type_node: InferenceSignalNode - reason: str - confidence: InferenceSignalConfidence - only_applies_to_parent: bool - - -@dataclass(frozen=True) -class InferenceResult: - """Encapsulates a final decision about a column. - - column: the target column for this result - type_node: the type node of the column - reason: a list of human-readable strings that explain positive reasons why this result - was produced. They may eventually reach the user's eyeballs. - problems: a list of human-readable strings that explain why a more specific - result was not produced or any other confusion/errors it might have encountered - while solving. They may eventually reach the user's eyeballs. - """ - - column: SqlColumn - type_node: InferenceSignalNode - reasons: List[str] - problems: List[str] diff --git a/metricflow/inference/renderer/__init__.py b/metricflow/inference/renderer/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/metricflow/inference/renderer/base.py b/metricflow/inference/renderer/base.py deleted file mode 100644 index 4aea95f256..0000000000 --- a/metricflow/inference/renderer/base.py +++ /dev/null @@ -1,14 +0,0 @@ -from __future__ import annotations - -from abc import ABC -from typing import List - -from metricflow.inference.models import InferenceResult - - -class InferenceRenderer(ABC): - """Render inference results into some format.""" - - def render(self, results: List[InferenceResult]) -> None: - """Render a set of inference results into the screen, some file, the network or whatever.""" - raise NotImplementedError diff --git a/metricflow/inference/renderer/config_file.py b/metricflow/inference/renderer/config_file.py deleted file mode 100644 index d4d5461349..0000000000 --- a/metricflow/inference/renderer/config_file.py +++ /dev/null @@ -1,162 +0,0 @@ -from __future__ import annotations - -import os -from collections import defaultdict -from pathlib import Path -from typing import Dict, List, Literal, TypedDict, Union - -from ruamel.yaml import YAML -from ruamel.yaml.comments import CommentedMap -from typing_extensions import NotRequired - -from metricflow.inference.models import InferenceResult, InferenceSignalType -from metricflow.inference.renderer.base import InferenceRenderer -from metricflow.sql.sql_table import SqlTable - -yaml = YAML() - - -class RenderedTimeColumnConfigTypeParams(TypedDict): # noqa: D101 - time_granularity: Literal["day"] - - -class RenderedColumnConfig(TypedDict): # noqa: D101 - name: str - type: str - type_params: NotRequired[RenderedTimeColumnConfigTypeParams] - - -class ConfigFileRenderer(InferenceRenderer): - """Writes inference results to a set of config files.""" - - UNKNOWN_FIELD_VALUE = "FIXME" - - def __init__(self, dir_path: Union[str, Path], overwrite: bool = False) -> None: - """Initializes the renderer. - - dir_path: The path to the config directory - overwrite: If set to False, will raise error if the directory exists - """ - dir_path = os.path.abspath(dir_path) - - if not overwrite and os.path.exists(dir_path): - raise ValueError("ConfigFileRender.overwrite is False but path exists.") - - if os.path.isfile(dir_path): - raise ValueError("ConfigFileRenderer `dir_path` is a file.") - - self.dir_path = dir_path - - def _get_filename_for_table(self, table: SqlTable) -> str: - return os.path.abspath(os.path.join(self.dir_path, f"{table.sql}.yaml")) - - def _fixme(self, comment: str) -> str: - return f"FIXME: {comment}" - - def _render_entity_columns(self, results: List[InferenceResult]) -> List[CommentedMap]: - type_map = { - InferenceSignalType.ID.PRIMARY: "primary", - InferenceSignalType.ID.FOREIGN: "foreign", - InferenceSignalType.ID.UNIQUE: "unique", - } - - rendered: List[CommentedMap] = [ - CommentedMap( - { - "name": result.column.column_name, - "type": type_map.get(result.type_node, ConfigFileRenderer.UNKNOWN_FIELD_VALUE), - } - ) - for result in results - if result.type_node.is_subtype_of(InferenceSignalType.ID.UNKNOWN) - ] - - return rendered - - def _render_dimension_columns(self, results: List[InferenceResult]) -> List[CommentedMap]: - type_map = { - InferenceSignalType.DIMENSION.TIME: "time", - InferenceSignalType.DIMENSION.PRIMARY_TIME: "time", - InferenceSignalType.DIMENSION.CATEGORICAL: "categorical", - } - - rendered: List[CommentedMap] = [] - for result in results: - if not result.type_node.is_subtype_of(InferenceSignalType.DIMENSION.UNKNOWN): - continue - - result_data: CommentedMap = CommentedMap( - { - "name": result.column.column_name, - "type": type_map.get(result.type_node, ConfigFileRenderer.UNKNOWN_FIELD_VALUE), - } - ) - - if result_data["type"] == ConfigFileRenderer.UNKNOWN_FIELD_VALUE: - result_data.yaml_add_eol_comment(self._fixme("unknown field value"), "type") - - if result.type_node.is_subtype_of(InferenceSignalType.DIMENSION.TIME): - type_params: CommentedMap = CommentedMap({"time_granularity": "day"}) - if result.type_node.is_subtype_of(InferenceSignalType.DIMENSION.PRIMARY_TIME): - type_params["is_primary"] = True - result_data["type_params"] = type_params - - if len(result.problems) > 0: - result_data.yaml_set_comment_before_after_key( - key="name", - before=f"{ConfigFileRenderer.UNKNOWN_FIELD_VALUE}: " + ", ".join(result.problems), - ) - - rendered.append(result_data) - - return rendered - - def _get_comments_for_unknown_columns(self, results: List[InferenceResult]) -> List[str]: - delim = "\n - " - return [ - self._fixme(result.column.column_name + delim + delim.join(result.problems)) - for result in results - if result.type_node == InferenceSignalType.UNKNOWN - ] - - def _get_configuration_data_for_table(self, table: SqlTable, results: List[InferenceResult]) -> Dict: - data = CommentedMap( - { - "semantic_model": { - "name": table.table_name, - "node_relation": { - "alias": table.table_name, - "schema_name": table.schema_name, - "database": table.db_name, - "relation_name": table.sql, - }, - "entities": self._render_entity_columns(results), - "dimensions": self._render_dimension_columns(results), - "measures": [], - } - } - ) - - header_comments = [self._fixme("Unreviewed inferred config file")] + self._get_comments_for_unknown_columns( - results - ) - data.yaml_set_comment_before_after_key( - key="semantic_model", - before="\n".join(header_comments), - ) - - return data - - def render(self, results: List[InferenceResult]) -> None: - """Render the inference results to files in the configured directory.""" - os.makedirs(self.dir_path, exist_ok=True) - - results_by_table: Dict[SqlTable, List[InferenceResult]] = defaultdict(list) - for result in results: - results_by_table[result.column.table].append(result) - - for table, results in results_by_table.items(): - table_data = self._get_configuration_data_for_table(table, results) - table_path = self._get_filename_for_table(table) - with open(table_path, "w") as table_file: - yaml.dump(table_data, table_file) diff --git a/metricflow/inference/renderer/stream.py b/metricflow/inference/renderer/stream.py deleted file mode 100644 index 41b26e5f74..0000000000 --- a/metricflow/inference/renderer/stream.py +++ /dev/null @@ -1,65 +0,0 @@ -from __future__ import annotations - -import sys -from collections import defaultdict -from typing import Dict, List, TextIO - -from metricflow.inference.models import InferenceResult -from metricflow.inference.renderer.base import InferenceRenderer -from metricflow.sql.sql_table import SqlTable - - -class StreamInferenceRenderer(InferenceRenderer): - """Writes inference results to a TextIO as human-readable text.""" - - def __init__(self, stream: TextIO, close_after_render: bool) -> None: - """Initializes the renderer. - - stream: the `TextIO` to write outputs to. - """ - self.stream = stream - self.close_after_render = close_after_render - - @staticmethod - def stdout() -> StreamInferenceRenderer: - """Factory method to create a `StreamInferenceRenderer` that writes to stdout.""" - return StreamInferenceRenderer(stream=sys.stdout, close_after_render=False) - - @staticmethod - def file(path: str) -> StreamInferenceRenderer: - """Factory method to create a `StreamInferenceRenderer` that writes to a file.""" - file_stream = open(path, "w") - return StreamInferenceRenderer(stream=file_stream, close_after_render=True) - - def render(self, results: List[InferenceResult]) -> None: - """Write the results to the configured TextIO as human-readable text.""" - list_delim = "\n - " - - results_by_table: Dict[SqlTable, List[InferenceResult]] = defaultdict(list) - for result in results: - results_by_table[result.column.table].append(result) - - for table, results in results_by_table.items(): - lines: List[str] = [table.sql + "\n"] - for result in results: - reasons_str = " -- " - if len(result.reasons) > 0: - reasons_str = list_delim + list_delim.join(result.reasons) - - problems_str = " -- " - if len(result.problems) > 0: - problems_str = list_delim + list_delim.join(result.problems) - - lines += [ - f" {result.column.column_name}\n", - f" type: {result.type_node.name}\n", - " reasons: ", - reasons_str, - "\n problems: ", - problems_str, - "\n\n", - ] - self.stream.writelines(lines) - - if self.close_after_render: - self.stream.close() diff --git a/metricflow/inference/rule/__init__.py b/metricflow/inference/rule/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/metricflow/inference/rule/base.py b/metricflow/inference/rule/base.py deleted file mode 100644 index 57a1dac610..0000000000 --- a/metricflow/inference/rule/base.py +++ /dev/null @@ -1,27 +0,0 @@ -from __future__ import annotations - -from abc import ABC, abstractmethod -from typing import List - -from metricflow.inference.context.data_warehouse import DataWarehouseInferenceContext -from metricflow.inference.models import InferenceSignal - - -class InferenceRule(ABC): - """Implements some sort of heuristic that produces signals about columns. - - An inference rule produces zero or more `InferenceSignal` instances about whatever - columns it thinks it should, based on input `InferenceContext`s. - - Concrete implementations should aim to be short and modularized. It is preferred to - compose multiple small rules that each produce a signal type than to make one large - rule with complex logic to produce a bunch of signals. - - It currently only accepts DataWarehouseInferenceContext as input, but this will probably - be generalized later. - """ - - @abstractmethod - def process(self, warehouse: DataWarehouseInferenceContext) -> List[InferenceSignal]: - """The actual rule implementation that returns a list of signals based on the input contexts.""" - raise NotImplementedError diff --git a/metricflow/inference/rule/defaults.py b/metricflow/inference/rule/defaults.py deleted file mode 100644 index df92492c18..0000000000 --- a/metricflow/inference/rule/defaults.py +++ /dev/null @@ -1,275 +0,0 @@ -from __future__ import annotations - -from typing import List - -from metricflow.inference.context.data_warehouse import ( - ColumnProperties, - DataWarehouseInferenceContext, - InferenceColumnType, -) -from metricflow.inference.models import InferenceSignal, InferenceSignalConfidence, InferenceSignalType -from metricflow.inference.rule.base import InferenceRule -from metricflow.inference.rule.rules import ColumnMatcherRule, LowCardinalityRatioRule - -# ------------- -# ENTITIES -# ------------- - - -class AnyEntityByNameRule(ColumnMatcherRule): - """Inference rule that checks for columns ending with `id`. - - We searched for words ending with "id" just to assess the chance of this resulting in a - false positive. Our guess is most of those words would rarely, if ever, be used as column names. - Therefore, not adding a mandatory "_" before "id" would benefit the product by matching names - like "userid", despite the rare "squid", "mermaid" or "android" matches. - - See: https://www.thefreedictionary.com/words-that-end-in-id - - It will always produce ID.UNKNOWN signal with VERY_HIGH confidence. - """ - - type_node = InferenceSignalType.ID.UNKNOWN - confidence = InferenceSignalConfidence.VERY_HIGH - only_applies_to_parent_signal = False - match_reason = "Column name ends with `id`" - - def match_column(self, props: ColumnProperties) -> bool: # noqa: D102 - return props.column.column_name.lower().endswith("id") - - -class PrimaryEntityByNameRule(ColumnMatcherRule): - """Inference rule that matches primary entities by their names. - - It will match columns such as `db.schema.mytable.mytable_id`, - `db.schema.mytable.mytableid` and `db.schema.mytable.id`. - - It will always produce a ID.PRIMARY signal with VERY_HIGH confidence. - """ - - type_node = InferenceSignalType.ID.PRIMARY - confidence = InferenceSignalConfidence.VERY_HIGH - only_applies_to_parent_signal = False - match_reason = "Column name matches `(table_name?)(_?)id`" - - def match_column(self, props: ColumnProperties) -> bool: # noqa: D102 - col_lower = props.column.column_name.lower() - table_lower = props.column.table_name.lower().rstrip("s") - - if col_lower == "id": - return True - - return col_lower == f"{table_lower}_id" or col_lower == f"{table_lower}id" - - -class UniqueEntityByDistinctCountRule(ColumnMatcherRule): - """Inference rule that matches unique entities by their COUNT DISTINCT. - - It will always produce a ID.UNIQUE complementary signal with VERY_HIGH confidence. - """ - - type_node = InferenceSignalType.ID.UNIQUE - confidence = InferenceSignalConfidence.VERY_HIGH - only_applies_to_parent_signal = True - match_reason = "The values in the column are unique" - - def match_column(self, props: ColumnProperties) -> bool: # noqa: D102 - return props.distinct_row_count == props.row_count - - -class ForeignEntityByCardinalityRatioRule(LowCardinalityRatioRule): - """Inference rule that checks for low cardinality columns. - - It will always produce ID.FOREIGN with MEDIUM confidence (complementary). - """ - - type_node = InferenceSignalType.ID.FOREIGN - confidence = InferenceSignalConfidence.MEDIUM - only_applies_to_parent_signal = True - - -# ------------- -# DIMENSIONS -# ------------- - - -class TimeDimensionByTimeTypeRule(ColumnMatcherRule): - """Inference rule that checks for time (time, date, datetime, timestamp) columns. - - It will always produce DIMENSION.TIME with VERY_HIGH confidence. - """ - - type_node = InferenceSignalType.DIMENSION.TIME - confidence = InferenceSignalConfidence.VERY_HIGH - only_applies_to_parent_signal = False - match_reason = "Column type is time (TIME, DATE, DATETIME, TIMESTAMP)" - - def match_column(self, props: ColumnProperties) -> bool: # noqa: D102 - return props.type == InferenceColumnType.DATETIME - - -class PrimaryTimeDimensionByNameRule(ColumnMatcherRule): - """Inference rule that checks if the column name is one of `ds` or `created_at`. - - It will always produce DIMENSION.PRIMARY_TIME with VERY_HIGH confidence. - """ - - type_node = InferenceSignalType.DIMENSION.PRIMARY_TIME - confidence = InferenceSignalConfidence.VERY_HIGH - only_applies_to_parent_signal = False - match_reason = "Column name is either of 'ds', 'created_at', 'created_date' or 'created_time'" - - def match_column(self, props: ColumnProperties) -> bool: # noqa: D102 - return props.column.column_name in ["ds", "created_at", "created_date", "created_time"] - - -class PrimaryTimeDimensionIfOnlyTimeRule(InferenceRule): - """Inference rule for checking if the column is the only time column in the table. - - It will always produce DIMENSION.PRIMARY_TIME signal with VERY_HIGH confidence. - """ - - def process(self, warehouse: DataWarehouseInferenceContext) -> List[InferenceSignal]: # noqa: D102 - signals: List[InferenceSignal] = [] - for table_props in warehouse.tables.values(): - time_cols = [ - col for col, col_props in table_props.columns.items() if col_props.type == InferenceColumnType.DATETIME - ] - if len(time_cols) == 1: - signals.append( - InferenceSignal( - column=time_cols[0], - type_node=InferenceSignalType.DIMENSION.PRIMARY_TIME, - only_applies_to_parent=False, - reason="The column is the only time column in its table", - confidence=InferenceSignalConfidence.VERY_HIGH, - ) - ) - return signals - - -class CategoricalDimensionByBooleanTypeRule(ColumnMatcherRule): - """Inference rule that checks for boolean columns. - - It will always produce DIMENSION.CATEGORICAL with VERY_HIGH confidence. - """ - - type_node = InferenceSignalType.DIMENSION.CATEGORICAL - confidence = InferenceSignalConfidence.VERY_HIGH - only_applies_to_parent_signal = False - match_reason = "Column type is BOOLEAN" - - def match_column(self, props: ColumnProperties) -> bool: # noqa: D102 - return props.type == InferenceColumnType.BOOLEAN - - -class CategoricalDimensionByStringTypeAndLowCardinalityRule(LowCardinalityRatioRule): - """Inference rule that checks for string typed columns with cardinality below the specified threshold. - - It will always produce DIMENSION.CATEGORICAL with HIGH confidence - """ - - type_node = InferenceSignalType.DIMENSION.CATEGORICAL - confidence = InferenceSignalConfidence.HIGH - only_applies_to_parent_signal = False - match_reason = "Column type is STRING and cardinality ratio is below 0.4" - - def match_column(self, props: ColumnProperties) -> bool: - """This is a bit of a hack for composing rules by invoking one directly here.""" - if props.type != InferenceColumnType.STRING: - return False - return super().match_column(props=props) - - -class CategoricalDimensionByStringTypeRule(ColumnMatcherRule): - """Inference rule that checks for string columns. - - It will always produce DIMENSION.CATEGORICAL with MEDIUM confidence (complementary). - """ - - type_node = InferenceSignalType.DIMENSION.CATEGORICAL - confidence = InferenceSignalConfidence.MEDIUM - only_applies_to_parent_signal = True - match_reason = "Column type is STRING" - - def match_column(self, props: ColumnProperties) -> bool: # noqa: D102 - return props.type == InferenceColumnType.STRING - - -class CategoricalDimensionByIntegerTypeRule(ColumnMatcherRule): - """Inference rule that checks for integer columns. - - It will always produce DIMENSION.CATEGORICAL with MEDIUM confidence (complementary). - """ - - type_node = InferenceSignalType.DIMENSION.CATEGORICAL - confidence = InferenceSignalConfidence.MEDIUM - only_applies_to_parent_signal = True - match_reason = "Column type is INTEGER" - - def match_column(self, props: ColumnProperties) -> bool: # noqa: D102 - return props.type == InferenceColumnType.INTEGER - - -class CategoricalDimensionByCardinalityRatioRule(LowCardinalityRatioRule): - """Inference rule that checks for low cardinality columns. - - It will always produce DIMENSION.CATEGORICAL with MEDIUM confidence (complementary). - """ - - type_node = InferenceSignalType.DIMENSION.CATEGORICAL - confidence = InferenceSignalConfidence.MEDIUM - only_applies_to_parent_signal = True - - -# ------------- -# MEASURES -# ------------- - - -class MeasureByRealTypeRule(ColumnMatcherRule): - """Inference rule that checks for real (float, double) columns. - - It will always produce MEASURE with VERY_HIGH confidence. - """ - - type_node = InferenceSignalType.MEASURE.UNKNOWN - confidence = InferenceSignalConfidence.VERY_HIGH - only_applies_to_parent_signal = False - match_reason = "Column type is real (FLOAT, DOUBLE, DOUBLE PRECISION)" - - def match_column(self, props: ColumnProperties) -> bool: # noqa: D102 - return props.type == InferenceColumnType.FLOAT - - -class MeasureByIntegerTypeRule(ColumnMatcherRule): - """Inference rule that checks for integer columns. - - It will always produce MEASURE with MEDIUM confidence (complementary). - """ - - type_node = InferenceSignalType.MEASURE.UNKNOWN - confidence = InferenceSignalConfidence.MEDIUM - only_applies_to_parent_signal = True - match_reason = "Column type is INTEGER" - - def match_column(self, props: ColumnProperties) -> bool: # noqa: D102 - return props.type == InferenceColumnType.INTEGER - - -DEFAULT_RULESET = [ - AnyEntityByNameRule(), - PrimaryEntityByNameRule(), - UniqueEntityByDistinctCountRule(), - ForeignEntityByCardinalityRatioRule(0.6), - TimeDimensionByTimeTypeRule(), - PrimaryTimeDimensionByNameRule(), - PrimaryTimeDimensionIfOnlyTimeRule(), - CategoricalDimensionByBooleanTypeRule(), - CategoricalDimensionByStringTypeRule(), - CategoricalDimensionByStringTypeAndLowCardinalityRule(0.4), - CategoricalDimensionByIntegerTypeRule(), - CategoricalDimensionByCardinalityRatioRule(0.2), - MeasureByRealTypeRule(), - MeasureByIntegerTypeRule(), -] diff --git a/metricflow/inference/rule/rules.py b/metricflow/inference/rule/rules.py deleted file mode 100644 index 8829235a34..0000000000 --- a/metricflow/inference/rule/rules.py +++ /dev/null @@ -1,92 +0,0 @@ -from __future__ import annotations - -from abc import abstractmethod -from typing import Callable, List, TypeVar - -from metricflow.inference.context.data_warehouse import ( - ColumnProperties, - DataWarehouseInferenceContext, -) -from metricflow.inference.models import ( - InferenceSignal, - InferenceSignalConfidence, - InferenceSignalNode, -) -from metricflow.inference.rule.base import InferenceRule - -T = TypeVar("T", bound="ColumnMatcherRule") - -ColumnMatcher = Callable[[T, ColumnProperties], bool] - - -class ColumnMatcherRule(InferenceRule): - """Inference rule that checks for matches across all columns. - - This is a useful shortcut for making rules that match columns one by one with preset confidence - values, types and match reasons. - - If you need a more specific rule with varying confidence, column cross-checking and that outputs - multiple types, inherit from `InferenceRule` directly. - - type_node: the `InferenceSignalNode` to produce whenever the pattern is matched - confidence: the `InferenceSignalConfidence` to produce whenever the pattern is matched - only_applies_to_parent_signal: whether the produced signal should be only taken into - consideration by the solver if the parent is present in the tree. - match_reason: a human-readable string of the reason why this was matched - """ - - type_node: InferenceSignalNode - confidence: InferenceSignalConfidence - only_applies_to_parent_signal: bool - match_reason: str - - @abstractmethod - def match_column(self, props: ColumnProperties) -> bool: - """A function to determine whether `ColumnProperties` matches. If it does, produce the signal.""" - raise NotImplementedError - - def process(self, warehouse: DataWarehouseInferenceContext) -> List[InferenceSignal]: # type: ignore - """Try to match all columns' properties with the matching function. - - If they do match, produce a signal with the configured type and confidence. - """ - matching_columns = [column for column, props in warehouse.columns.items() if self.match_column(props)] - signals = [ - InferenceSignal( - column=column, - type_node=self.type_node, - reason=self.match_reason, - confidence=self.confidence, - only_applies_to_parent=self.only_applies_to_parent_signal, - ) - for column in matching_columns - ] - return signals - - -class LowCardinalityRatioRule(ColumnMatcherRule): - """Inference rule that checks for string columns with low cardinality to count ratio. - - The ratio is calculated as `distinct_count/(count - null_count)`. - """ - - def __init__(self, cardinality_count_ratio_threshold: float) -> None: - """Initialize the rule. - - cardinality_count_ratio_threshold: rations below this threshold will match. - """ - assert cardinality_count_ratio_threshold >= 0 and cardinality_count_ratio_threshold <= 1 - self.threshold = cardinality_count_ratio_threshold - super().__init__() - - match_reason = "Column has low cardinality" - - def match_column(self, props: ColumnProperties) -> bool: # noqa: D102 - denom = props.row_count - props.null_count - - # undefined ratio - if denom == 0: - return False - - ratio = props.distinct_row_count / denom - return ratio < self.threshold diff --git a/metricflow/inference/runner.py b/metricflow/inference/runner.py deleted file mode 100644 index ed03ee90af..0000000000 --- a/metricflow/inference/runner.py +++ /dev/null @@ -1,156 +0,0 @@ -from __future__ import annotations - -import contextlib -import logging -from abc import ABC, abstractmethod -from collections import defaultdict -from typing import Iterator, List - -import more_itertools - -from metricflow.inference.context.base import InferenceContextProvider -from metricflow.inference.context.data_warehouse import DataWarehouseInferenceContextProvider -from metricflow.inference.renderer.base import InferenceRenderer -from metricflow.inference.rule.base import InferenceRule -from metricflow.inference.solver.base import InferenceSolver -from metricflow.sql.sql_table import SqlTable - -logger = logging.getLogger(__file__) - - -class InferenceProgressReporter(ABC): - """Base class for reporting progress while running inference.""" - - @staticmethod - @abstractmethod - @contextlib.contextmanager - def warehouse() -> Iterator[None]: - """Context manager that wraps the warehouse context fetching step.""" - raise NotImplementedError - - @staticmethod - @abstractmethod - @contextlib.contextmanager - def table(table: SqlTable, index: int, total: int) -> Iterator[None]: - """Context manager that wraps context fetching for a single table.""" - raise NotImplementedError - - @staticmethod - @abstractmethod - @contextlib.contextmanager - def rules() -> Iterator[None]: - """Context manager that wraps all rule-processing.""" - raise NotImplementedError - - @staticmethod - @abstractmethod - @contextlib.contextmanager - def solver() -> Iterator[None]: - """Context manager that wraps the solving step.""" - raise NotImplementedError - - @staticmethod - @abstractmethod - @contextlib.contextmanager - def renderers() -> Iterator[None]: - """Context manager that wraps the rendering step.""" - raise NotImplementedError - - -class NoOpInferenceProgressReporter(InferenceProgressReporter): - """Pass-through implementation of `InferenceProgressReporter`.""" - - @staticmethod - @contextlib.contextmanager - def warehouse() -> Iterator[None]: # noqa: D102 - yield - - @staticmethod - @contextlib.contextmanager - def table(table: SqlTable, index: int, total: int) -> Iterator[None]: # noqa: D102 - yield - - @staticmethod - @contextlib.contextmanager - def rules() -> Iterator[None]: # noqa: D102 - yield - - @staticmethod - @contextlib.contextmanager - def solver() -> Iterator[None]: # noqa: D102 - yield - - @staticmethod - @contextlib.contextmanager - def renderers() -> Iterator[None]: # noqa: D102 - yield - - -# TODO: we still need to add input/output context validations and optimizations. -# Case 1: Rule 1 requires Context of type A, but no provider privides it. Should fail before running -# Case 2: ContextProvider 1 provides Context of type A, but no rule uses it. Should proceed without actually fetching that context. -# Case 3: ContextProviders 1 and 2 both provide Contexts of type A, which is ambiguous. Should fail before running - - -class InferenceRunner: - """Glues together all other inference classes in a sequence that actually runs inference.""" - - def __init__( - self, - context_providers: List[InferenceContextProvider], - ruleset: List[InferenceRule], - solver: InferenceSolver, - renderers: List[InferenceRenderer], - progress_reporter: InferenceProgressReporter = NoOpInferenceProgressReporter(), - ) -> None: - """Initialize the inference runner. - - context_providers: a list of context providers to be used - ruleset: the set of rules that will produce signals - solver: the inference solver to be used - renderers: the renderers that will write inference results as meaningful output - progress_reporter: `InferenceProgressReporter` to report progress - """ - logger.warning( - "Semantic Model Inference is still in Beta. " - "As such, you should not expect it to be 100% stable or be free of bugs. Any public CLI or Python interfaces may change without prior notice." - " If you find any bugs or feel like something is not behaving as it should, feel free to open an issue on the Metricflow Github repo." - ) - - if len(context_providers) != 1 or not isinstance(context_providers[0], DataWarehouseInferenceContextProvider): - raise ValueError("Currently, InferenceRunner requires exactly one DataWarehouseInferenceContextProvider.") - - if len(ruleset) == 0: - raise ValueError("Running inference with an empty ruleset would produce no result.") - - if len(renderers) == 0: - raise ValueError("Running inference with no renderer would discard the results.") - - self.context_providers = context_providers - self.ruleset = ruleset - self.solver = solver - self.renderers = renderers - self._progress = progress_reporter - - def run(self) -> None: - """Runs inference with the given configs.""" - # FIXME: currently we only accept DataWarehouseContextProvider - provider: DataWarehouseInferenceContextProvider = self.context_providers[0] # type: ignore - - with self._progress.warehouse(): - warehouse = provider.get_context( - table_progress=self._progress.table, - ) - - with self._progress.rules(): - signals_by_column = defaultdict(list) - signals = [rule.process(warehouse) for rule in self.ruleset] - for rule_signal in tuple(more_itertools.flatten(signals)): - signals_by_column[rule_signal.column].append(rule_signal) - - with self._progress.solver(): - results = [self.solver.solve_column(column, signals) for column, signals in signals_by_column.items()] - - with self._progress.renderers(): - for renderer in self.renderers: - renderer.render(results) diff --git a/metricflow/inference/solver/__init__.py b/metricflow/inference/solver/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/metricflow/inference/solver/base.py b/metricflow/inference/solver/base.py deleted file mode 100644 index ddb122caac..0000000000 --- a/metricflow/inference/solver/base.py +++ /dev/null @@ -1,20 +0,0 @@ -from __future__ import annotations - -from abc import ABC, abstractmethod -from typing import List - -from metricflow.inference.models import InferenceResult, InferenceSignal -from metricflow.sql.sql_column import SqlColumn - - -class InferenceSolver(ABC): - """Base class for inference solvers. - - Inference solvers implement algorithms for making a final decision over what a column is or isn't, based on - signals produced by inference rules for that column. - """ - - @abstractmethod - def solve_column(self, column: SqlColumn, signals: List[InferenceSignal]) -> InferenceResult: - """Make a decision about a column based on its input signals.""" - raise NotImplementedError diff --git a/metricflow/inference/solver/weighted_tree.py b/metricflow/inference/solver/weighted_tree.py deleted file mode 100644 index 7b637e2ecc..0000000000 --- a/metricflow/inference/solver/weighted_tree.py +++ /dev/null @@ -1,166 +0,0 @@ -from __future__ import annotations - -from collections import defaultdict -from typing import Callable, Dict, List, Optional - -from metricflow.inference.models import ( - InferenceResult, - InferenceSignal, - InferenceSignalConfidence, - InferenceSignalNode, - InferenceSignalType, -) -from metricflow.inference.solver.base import InferenceSolver -from metricflow.sql.sql_column import SqlColumn - -NodeWeighterFunction = Callable[[InferenceSignalConfidence], int] - - -class WeightedTypeTreeInferenceSolver(InferenceSolver): - """Assigns weights to each type in the column type tree and attempts to traverse it using a weight percentage threshold.""" - - @staticmethod - def default_weighter_function(confidence: InferenceSignalConfidence) -> int: - """The default weighter function. - - It assigns weights 1, 2, 4 and 8 for LOW, MEDIUM, HIGH and VERY_HIGH confidences, respectively. It then sums - the weights of all provided confidence scores. - """ - confidence_weight_map = { - InferenceSignalConfidence.LOW: 1, - InferenceSignalConfidence.MEDIUM: 2, - InferenceSignalConfidence.HIGH: 4, - InferenceSignalConfidence.VERY_HIGH: 8, - } - - return confidence_weight_map[confidence] - - def __init__( - self, weight_percent_threshold: float = 0.6, weighter_function: Optional[NodeWeighterFunction] = None - ) -> None: - """Initialize the solver. - - weight_percent_threshold: a number in (0.5, 1]. If a node's weight corresponds to a percentage - above this threshold with respect to its siblings' total weight sum, the solver will progress deeper - into the type tree, entering that node. If not, it stops at the parent. - weighter_function: a function that returns a weight given a confidence score. It will be used - to assign integer weights to each node in the type tree based on its input signals. - """ - assert ( - weight_percent_threshold > 0.5 and weight_percent_threshold <= 1 - ), f"weight_percent_threshold is {weight_percent_threshold}, but it must be > 0.5 and <= 1!" - self._weight_percent_threshold = weight_percent_threshold - - self._weighter_function = ( - weighter_function - if weighter_function is not None - else WeightedTypeTreeInferenceSolver.default_weighter_function - ) - - def _get_cumulative_weights_for_root( - self, - root: InferenceSignalNode, - output_weights: Dict[InferenceSignalNode, int], - output_parent_only_weights: Dict[InferenceSignalNode, int], - signals_by_type: Dict[InferenceSignalNode, List[InferenceSignal]], - ) -> Dict[InferenceSignalNode, int]: - """Get a dict of cumulative weights, starting at `root`. - - A parent node's weight is the sum of all its children plus its own weight. Children tagged as only - applying to their parent are excluded from their grand-parent's sum. - - root: the root to start assigning cumulative weights from. - output_weights: the output dictionary to assign the cumulative weights to - output_parent_only_weights: similar to output_weights, but the weight of each node excludes the weights - of all of its parent-only children - signals_by_type: a dictionary that maps signal type nodes to signals - """ - for child in root.children: - self._get_cumulative_weights_for_root( - root=child, - output_weights=output_weights, - output_parent_only_weights=output_parent_only_weights, - signals_by_type=signals_by_type, - ) - - output_weights[root] = sum(self._weighter_function(signal.confidence) for signal in signals_by_type[root]) - output_weights[root] += sum(output_parent_only_weights[child] for child in root.children) - - output_parent_only_weights[root] = sum( - self._weighter_function(signal.confidence) - for signal in signals_by_type[root] - if not signal.only_applies_to_parent - ) - - return output_weights - - def _get_cumulative_weights(self, signals: List[InferenceSignal]) -> Dict[InferenceSignalNode, int]: - """Get the cumulative weights dict for a list of signals.""" - signals_by_type: Dict[InferenceSignalNode, List[InferenceSignal]] = defaultdict(list) - for signal in signals: - signals_by_type[signal.type_node].append(signal) - - return self._get_cumulative_weights_for_root( - root=InferenceSignalType.UNKNOWN, - output_weights=defaultdict(int), - output_parent_only_weights=defaultdict(int), - signals_by_type=signals_by_type, - ) - - def solve_column(self, column: SqlColumn, signals: List[InferenceSignal]) -> InferenceResult: - """Find the appropriate type for a column by traversing through the type tree. - - It traverses the tree by giving weights to all nodes and greedily finding the path with the most - weight until it either finds a leaf or there is a "weight bifurcation" in the path with respect - to the provided `weight_percent_threshold`. - """ - if len(signals) == 0: - return InferenceResult( - column=column, - type_node=InferenceSignalType.UNKNOWN, - reasons=[], - problems=[ - "No signals were extracted for this column", - "Inference solver could not determine if column was an entity, a dimension, or a measure", - ], - ) - - reasons_by_type: Dict[InferenceSignalNode, List[str]] = defaultdict(list) - for signal in signals: - reasons_by_type[signal.type_node].append(f"{signal.reason} ({signal.type_node.name})") - - node_weights = self._get_cumulative_weights(signals) - - reasons: List[str] = [] - problems: List[str] = [] - node = InferenceSignalType.UNKNOWN - while len(node.children) > 0: - children_weight_total = sum(node_weights[child] for child in node.children) - - if children_weight_total == 0: - break - - next_node = None - for child in node.children: - if node_weights[child] / children_weight_total >= self._weight_percent_threshold: - next_node = child - reasons += reasons_by_type[child] - break - - if next_node is None: - if len(node.children) > 0: # there was confusion - children_weight_strings = [ - f"{child.name} (weight {node_weights[child]})" - for child in node.children - if node_weights[child] != 0 - ] - child_str = " / ".join(children_weight_strings) - problems.append(f"Solver is confused between {child_str}") - break - - node = next_node - - if node == InferenceSignalType.UNKNOWN: - problems.append("Inference solver could not determine if column was an entity, a dimension, or a measure") - - return InferenceResult(column=column, type_node=node, reasons=reasons, problems=problems) diff --git a/tests_metricflow/inference/__init__.py b/tests_metricflow/inference/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/tests_metricflow/inference/context/__init__.py b/tests_metricflow/inference/context/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/tests_metricflow/inference/context/test_data_warehouse.py b/tests_metricflow/inference/context/test_data_warehouse.py deleted file mode 100644 index b481bfd213..0000000000 --- a/tests_metricflow/inference/context/test_data_warehouse.py +++ /dev/null @@ -1,134 +0,0 @@ -from __future__ import annotations - -from metricflow.inference.context.data_warehouse import ( - ColumnProperties, - DataWarehouseInferenceContext, - InferenceColumnType, - TableProperties, -) -from metricflow.sql.sql_column import SqlColumn -from metricflow.sql.sql_table import SqlTable - - -def test_column_properties_is_empty() -> None: - """Just some easy assertions to test is_empty works as intended.""" - props = ColumnProperties( - column=SqlColumn.from_string("db.schema.table.column"), - type=InferenceColumnType.INTEGER, - row_count=10000, - distinct_row_count=1000, - is_nullable=True, - null_count=1, - min_value=0, - max_value=9999, - ) - assert not props.is_empty - - empty_props = ColumnProperties( - column=SqlColumn.from_string("db.schema.table.column"), - type=InferenceColumnType.UNKNOWN, - row_count=0, - distinct_row_count=0, - is_nullable=False, - null_count=0, - min_value=None, - max_value=None, - ) - assert empty_props.is_empty - - -def test_table_properties() -> None: - """Test `TableProperties` initialization. - - This test case asserts that the conversion from the `column_props` argument (which is a list) to - `self.columns` (which is a dict) implemented by `TableProperties.__post_init__` works as intended. - """ - table = SqlTable.from_string("db.schema.table") - col_props = [ - ColumnProperties( - column=SqlColumn(table=table, column_name="column1"), - type=InferenceColumnType.INTEGER, - row_count=1000, - distinct_row_count=1000, - is_nullable=False, - null_count=0, - min_value=0, - max_value=999, - ), - ColumnProperties( - column=SqlColumn(table=table, column_name="column2"), - type=InferenceColumnType.FLOAT, - row_count=2000, - distinct_row_count=1000, - is_nullable=True, - null_count=10, - min_value=0, - max_value=1000, - ), - ] - - table_props = TableProperties(table=table, column_props=col_props) - - assert table_props.columns == { - col_props[0].column: col_props[0], - col_props[1].column: col_props[1], - } - - -def test_data_warehouse_inference_context() -> None: - """Test `DataWarehouseInferenceContext` initialization. - - This test case asserts that the conversion from the `table_props` argument - (which is a list) to `self.tables` and `self.columns` (which are dicts) - implemented by `DataWarehouseInferenceContext.__post_init__` works as intended. - """ - t1 = SqlTable.from_string("db.schema1.table1") - t2 = SqlTable.from_string("db.schema2.table1") - - t1_cols = [ - ColumnProperties( - column=SqlColumn(table=t1, column_name="column1"), - type=InferenceColumnType.INTEGER, - row_count=1000, - distinct_row_count=1000, - is_nullable=False, - null_count=0, - min_value=0, - max_value=999, - ), - ColumnProperties( - column=SqlColumn(table=t1, column_name="column2"), - type=InferenceColumnType.FLOAT, - row_count=2000, - distinct_row_count=1000, - is_nullable=True, - null_count=10, - min_value=0, - max_value=1000, - ), - ] - t1_props = TableProperties(table=t1, column_props=t1_cols) - - t2_cols = [ - ColumnProperties( - column=SqlColumn(table=t2, column_name="column_a"), - type=InferenceColumnType.FLOAT, - row_count=1000, - distinct_row_count=1000, - is_nullable=False, - null_count=0, - min_value=0, - max_value=999, - ), - ] - t2_props = TableProperties(table=t2, column_props=t2_cols) - - ctx = DataWarehouseInferenceContext(table_props=[t1_props, t2_props]) - - assert ctx.tables == {t1: t1_props, t2: t2_props} - - assert ctx.columns == { - t1_cols[0].column: t1_cols[0], - t1_cols[1].column: t1_cols[1], - t2_cols[0].column: t2_cols[0], - } diff --git a/tests_metricflow/inference/context/test_snowflake.py b/tests_metricflow/inference/context/test_snowflake.py deleted file mode 100644 index 7a1fd00c0a..0000000000 --- a/tests_metricflow/inference/context/test_snowflake.py +++ /dev/null @@ -1,131 +0,0 @@ -from __future__ import annotations - -import itertools -import json -from typing import Dict, List, Union -from unittest.mock import MagicMock - -from metricflow.data_table.mf_table import MetricFlowDataTable -from metricflow.inference.context.data_warehouse import ( - ColumnProperties, - DataWarehouseInferenceContext, - InferenceColumnType, - TableProperties, -) -from metricflow.inference.context.snowflake import SnowflakeInferenceContextProvider -from metricflow.sql.sql_column import SqlColumn -from metricflow.sql.sql_table import SqlTable - - -def test_column_type_conversion() -> None: # noqa: D103 - ctx_provider = SnowflakeInferenceContextProvider(client=MagicMock(), tables=[]) - - # known snowflake types - assert ctx_provider._column_type_from_show_columns_data_type("FIXED") == InferenceColumnType.INTEGER - assert ctx_provider._column_type_from_show_columns_data_type("REAL") == InferenceColumnType.FLOAT - assert ctx_provider._column_type_from_show_columns_data_type("BOOLEAN") == InferenceColumnType.BOOLEAN - assert ctx_provider._column_type_from_show_columns_data_type("DATE") == InferenceColumnType.DATETIME - assert ctx_provider._column_type_from_show_columns_data_type("TIMESTAMP_TZ") == InferenceColumnType.DATETIME - assert ctx_provider._column_type_from_show_columns_data_type("TIMESTAMP_LTZ") == InferenceColumnType.DATETIME - assert ctx_provider._column_type_from_show_columns_data_type("TIMESTAMP_NTZ") == InferenceColumnType.DATETIME - - # String types - assert ctx_provider._column_type_from_show_columns_data_type("VARCHAR") == InferenceColumnType.STRING - assert ctx_provider._column_type_from_show_columns_data_type("VARCHAR(256)") == InferenceColumnType.STRING - assert ctx_provider._column_type_from_show_columns_data_type("CHAR") == InferenceColumnType.STRING - assert ctx_provider._column_type_from_show_columns_data_type("CHAR(8)") == InferenceColumnType.STRING - assert ctx_provider._column_type_from_show_columns_data_type("CHARACTER(8)") == InferenceColumnType.STRING - assert ctx_provider._column_type_from_show_columns_data_type("NCHAR(8)") == InferenceColumnType.STRING - assert ctx_provider._column_type_from_show_columns_data_type("STRING") == InferenceColumnType.STRING - assert ctx_provider._column_type_from_show_columns_data_type("TEXT") == InferenceColumnType.STRING - assert ctx_provider._column_type_from_show_columns_data_type("NVARCHAR(16777216)") == InferenceColumnType.STRING - assert ctx_provider._column_type_from_show_columns_data_type("NVARCHAR2(16777216)") == InferenceColumnType.STRING - assert ctx_provider._column_type_from_show_columns_data_type("CHAR VARYING(16777216)") == InferenceColumnType.STRING - assert ( - ctx_provider._column_type_from_show_columns_data_type("NCHAR VARYING(16777216)") == InferenceColumnType.STRING - ) - - # unknowns - assert ctx_provider._column_type_from_show_columns_data_type("BINARY") == InferenceColumnType.UNKNOWN - assert ctx_provider._column_type_from_show_columns_data_type("TIME") == InferenceColumnType.UNKNOWN - - -def test_context_provider() -> None: - """Test `SnowflakeInferenceContextProvider` implementation. - - This test case currently mocks the Snowflake response with a `MagicMock`. This is not ideal - and should probably be replaced by integration tests in the future. - """ - # See for SHOW COLUMNS result data_table spec: - # https://docs.snowflake.com/en/sql-reference/sql/show-columns.html - - show_columns_result_dict: Dict[str, List[Union[int, str]]] = { - "column_name": ["INTCOL", "STRCOL"], - "schema_name": ["SCHEMA", "SCHEMA"], - "table_name": ["TABLE", "TABLE"], - "database_name": ["DB", "DB"], - "data_type": [ - json.dumps({"type": "FIXED", "nullable": False}), - json.dumps({"type": "TEXT", "nullable": True}), - ], - } - show_columns_result = MetricFlowDataTable.create_from_rows( - column_names=tuple(show_columns_result_dict.keys()), - rows=tuple(itertools.zip_longest(*show_columns_result_dict.values())), - ) - stats_result_dict: Dict[str, List[Union[int, str]]] = { - "intcol_countdistinct": [10], - "intcol_min": [0], - "intcol_max": [10], - "intcol_countnull": [0], - "strcol_countdistinct": [40], - "strcol_min": ["aaaa"], - "strcol_max": ["zzzz"], - "strcol_countnull": [10], - "rowcount": [50], - } - stats_result = MetricFlowDataTable.create_from_rows( - column_names=tuple(stats_result_dict.keys()), - rows=tuple(itertools.zip_longest(*stats_result_dict.values())), - ) - client = MagicMock() - client.query = MagicMock() - client.query.side_effect = [show_columns_result, stats_result] - - ctx_provider = SnowflakeInferenceContextProvider( - client=client, - tables=[SqlTable.from_string("db.schema.table")], - max_sample_size=50, - ) - - ctx = ctx_provider.get_context() - - assert ctx == DataWarehouseInferenceContext( - [ - TableProperties( - table=SqlTable.from_string("db.schema.table"), - column_props=[ - ColumnProperties( - column=SqlColumn.from_string("db.schema.table.intcol"), - type=InferenceColumnType.INTEGER, - row_count=50, - distinct_row_count=10, - is_nullable=False, - null_count=0, - min_value="0", - max_value="10", - ), - ColumnProperties( - column=SqlColumn.from_string("db.schema.table.strcol"), - type=InferenceColumnType.STRING, - row_count=50, - distinct_row_count=40, - is_nullable=True, - null_count=10, - min_value="aaaa", - max_value="zzzz", - ), - ], - ), - ] - ) diff --git a/tests_metricflow/inference/renderer/__init__.py b/tests_metricflow/inference/renderer/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/tests_metricflow/inference/renderer/test_config_file.py b/tests_metricflow/inference/renderer/test_config_file.py deleted file mode 100644 index ee44840244..0000000000 --- a/tests_metricflow/inference/renderer/test_config_file.py +++ /dev/null @@ -1,84 +0,0 @@ -"""These tests rely on the pytest tmp_path factory fixture.""" - -from __future__ import annotations - -import os -from pathlib import Path - -import pytest -from ruamel.yaml import YAML - -from metricflow.inference.models import InferenceResult, InferenceSignalType -from metricflow.inference.renderer.config_file import ConfigFileRenderer -from metricflow.sql.sql_column import SqlColumn - -yaml = YAML() - - -def test_no_overwrite_with_existing_dir(tmpdir: Path) -> None: # noqa: D103 - with pytest.raises(ValueError): - ConfigFileRenderer(tmpdir, False) - - -def test_dir_path_is_file(tmpdir: Path) -> None: # noqa: D103 - file_path = os.path.join(tmpdir, "file.txt") - with open(file_path, "w") as f: - f.write("file contents!") - - with pytest.raises(ValueError): - ConfigFileRenderer(file_path, False) - - -def test_render_configs(tmpdir: Path) -> None: # noqa: D103 - inference_results = [ - InferenceResult( - column=SqlColumn.from_string("db.schema.test_table.id"), - type_node=InferenceSignalType.ID.PRIMARY, - reasons=[], - problems=[], - ), - InferenceResult( - column=SqlColumn.from_string("db.schema.test_table.time_dim"), - type_node=InferenceSignalType.DIMENSION.TIME, - reasons=[], - problems=[], - ), - InferenceResult( - column=SqlColumn.from_string("db.schema.test_table.primary_time_dim"), - type_node=InferenceSignalType.DIMENSION.PRIMARY_TIME, - reasons=[], - problems=[], - ), - ] - - renderer = ConfigFileRenderer(tmpdir, True) - - renderer.render(inference_results) - - table_file_path = os.path.join(tmpdir, "db.schema.test_table.yaml") - assert os.path.isfile(table_file_path) - - with open(table_file_path, "r") as f: - file_contents = yaml.load(f) - - assert file_contents == { - "semantic_model": { - "name": "test_table", - "node_relation": { - "alias": "test_table", - "schema_name": "schema", - "database": "db", - "relation_name": "db.schema.test_table", - }, - "entities": [{"type": "primary", "name": "id"}], - "dimensions": [ - {"type": "time", "name": "time_dim", "type_params": {"time_granularity": "day"}}, - { - "type": "time", - "name": "primary_time_dim", - "type_params": {"is_primary": True, "time_granularity": "day"}, - }, - ], - "measures": [], - } - } diff --git a/tests_metricflow/inference/rule/__init__.py b/tests_metricflow/inference/rule/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/tests_metricflow/inference/rule/conftest.py b/tests_metricflow/inference/rule/conftest.py deleted file mode 100644 index 2f2da5cb96..0000000000 --- a/tests_metricflow/inference/rule/conftest.py +++ /dev/null @@ -1,66 +0,0 @@ -from __future__ import annotations - -import pytest - -from metricflow.inference.context.data_warehouse import ( - ColumnProperties, - DataWarehouseInferenceContext, - InferenceColumnType, - TableProperties, -) -from metricflow.sql.sql_column import SqlColumn -from metricflow.sql.sql_table import SqlTable - - -@pytest.fixture -def warehouse_ctx() -> DataWarehouseInferenceContext: - """A dummy DataWarehouseInferenceContext to be used as a fixture.""" - return DataWarehouseInferenceContext( - table_props=[ - TableProperties( - table=SqlTable.from_string("db.schema.table"), - column_props=[ - ColumnProperties( - column=SqlColumn.from_string("db.schema.table.id"), - type=InferenceColumnType.INTEGER, - row_count=1, - distinct_row_count=1, - is_nullable=True, - null_count=0, - min_value=0, - max_value=1, - ), - ColumnProperties( - column=SqlColumn.from_string("db.schema.table.othertable_id"), - type=InferenceColumnType.INTEGER, - row_count=2, - distinct_row_count=2, - is_nullable=True, - null_count=0, - min_value=0, - max_value=1, - ), - ColumnProperties( - column=SqlColumn.from_string("db.schema.table.test_column"), - type=InferenceColumnType.INTEGER, - row_count=1, - distinct_row_count=1, - is_nullable=True, - null_count=0, - min_value=0, - max_value=1, - ), - ColumnProperties( - column=SqlColumn.from_string("db.schema.othertable.othertable_id"), - type=InferenceColumnType.INTEGER, - row_count=2, - distinct_row_count=2, - is_nullable=True, - null_count=0, - min_value=0, - max_value=1, - ), - ], - ) - ] - ) diff --git a/tests_metricflow/inference/rule/test_defaults.py b/tests_metricflow/inference/rule/test_defaults.py deleted file mode 100644 index cdcc0357fe..0000000000 --- a/tests_metricflow/inference/rule/test_defaults.py +++ /dev/null @@ -1,226 +0,0 @@ -from __future__ import annotations - -import metricflow.inference.rule.defaults as defaults -from metricflow.inference.context.data_warehouse import ( - ColumnProperties, - DataWarehouseInferenceContext, - InferenceColumnType, - TableProperties, -) -from metricflow.inference.models import InferenceSignalType -from metricflow.sql.sql_column import SqlColumn -from metricflow.sql.sql_table import SqlTable - - -def get_column_properties(column_str: str, type: InferenceColumnType, unique: bool) -> ColumnProperties: # noqa: D103 - return ColumnProperties( - column=SqlColumn.from_string(column_str), - type=type, - row_count=10000, - distinct_row_count=10000 if unique else 9000, - is_nullable=False, - null_count=0, - min_value=0, - max_value=9999, - ) - - -def test_any_entity_by_name_matcher() -> None: # noqa: D103 - assert defaults.AnyEntityByNameRule().match_column( - get_column_properties("db.schema.table.id", InferenceColumnType.INTEGER, True) - ) - assert defaults.AnyEntityByNameRule().match_column( - get_column_properties("db.schema.table.tableid", InferenceColumnType.INTEGER, True) - ) - assert defaults.AnyEntityByNameRule().match_column( - get_column_properties("db.schema.table.table_id", InferenceColumnType.INTEGER, True) - ) - assert defaults.AnyEntityByNameRule().match_column( - get_column_properties("db.schema.table.othertable_id", InferenceColumnType.INTEGER, True) - ) - assert not defaults.AnyEntityByNameRule().match_column( - get_column_properties("db.schema.table.whatever", InferenceColumnType.INTEGER, True) - ) - - -def test_primary_entity_by_name_matcher() -> None: # noqa: D103 - assert defaults.PrimaryEntityByNameRule().match_column( - get_column_properties("db.schema.table.id", InferenceColumnType.INTEGER, True) - ) - assert defaults.PrimaryEntityByNameRule().match_column( - get_column_properties("db.schema.table.tableid", InferenceColumnType.INTEGER, True) - ) - assert defaults.PrimaryEntityByNameRule().match_column( - get_column_properties("db.schema.table.table_id", InferenceColumnType.INTEGER, True) - ) - assert defaults.PrimaryEntityByNameRule().match_column( - get_column_properties("db.schema.tables.table_id", InferenceColumnType.INTEGER, True) - ) - assert defaults.PrimaryEntityByNameRule().match_column( - get_column_properties("db.schema.tables.tableid", InferenceColumnType.INTEGER, True) - ) - assert not defaults.PrimaryEntityByNameRule().match_column( - get_column_properties("db.schema.table.othertable_id", InferenceColumnType.INTEGER, True) - ) - assert not defaults.PrimaryEntityByNameRule().match_column( - get_column_properties("db.schema.table.othertableid", InferenceColumnType.INTEGER, True) - ) - assert not defaults.PrimaryEntityByNameRule().match_column( - get_column_properties("db.schema.table.whatever", InferenceColumnType.INTEGER, True) - ) - - -def test_unique_entity_by_distinct_count_matcher() -> None: # noqa: D103 - assert defaults.UniqueEntityByDistinctCountRule().match_column( - get_column_properties("db.schema.table.unique_id", InferenceColumnType.INTEGER, True) - ) - assert not defaults.UniqueEntityByDistinctCountRule().match_column( - get_column_properties("db.schema.table.unique_id", InferenceColumnType.STRING, False) - ) - - -def test_time_dimension_by_time_type_matcher() -> None: # noqa: D103 - assert defaults.TimeDimensionByTimeTypeRule().match_column( - get_column_properties("db.schema.table.time", InferenceColumnType.DATETIME, True) - ) - - assert not defaults.TimeDimensionByTimeTypeRule().match_column( - get_column_properties("db.schema.table.time", InferenceColumnType.INTEGER, True) - ) - assert not defaults.TimeDimensionByTimeTypeRule().match_column( - get_column_properties("db.schema.table.time", InferenceColumnType.FLOAT, True) - ) - assert not defaults.TimeDimensionByTimeTypeRule().match_column( - get_column_properties("db.schema.table.time", InferenceColumnType.BOOLEAN, True) - ) - assert not defaults.TimeDimensionByTimeTypeRule().match_column( - get_column_properties("db.schema.table.time", InferenceColumnType.STRING, True) - ) - assert not defaults.TimeDimensionByTimeTypeRule().match_column( - get_column_properties("db.schema.table.time", InferenceColumnType.UNKNOWN, True) - ) - - -def test_primary_time_dimension_by_name_matcher() -> None: # noqa: D103 - assert defaults.PrimaryTimeDimensionByNameRule().match_column( - get_column_properties("db.schema.table.ds", InferenceColumnType.DATETIME, True) - ) - assert defaults.PrimaryTimeDimensionByNameRule().match_column( - get_column_properties("db.schema.table.created_at", InferenceColumnType.DATETIME, True) - ) - assert not defaults.PrimaryTimeDimensionByNameRule().match_column( - get_column_properties("db.schema.table.bla", InferenceColumnType.DATETIME, True) - ) - assert not defaults.PrimaryTimeDimensionByNameRule().match_column( - get_column_properties("db.schema.table.time", InferenceColumnType.DATETIME, True) - ) - - -def test_primary_time_dimension_if_only_time_rule() -> None: # noqa: D103 - table = SqlTable.from_string("db.schema.table") - single_time_col_warehouse = DataWarehouseInferenceContext( - table_props=[ - TableProperties( - table=table, - column_props=[ - get_column_properties("db.schema.table.id", InferenceColumnType.INTEGER, True), - get_column_properties("db.schema.table.time", InferenceColumnType.DATETIME, True), - ], - ) - ] - ) - single_time_col_signals = defaults.PrimaryTimeDimensionIfOnlyTimeRule().process(single_time_col_warehouse) - assert len(single_time_col_signals) == 1 - assert single_time_col_signals[0].column == SqlColumn.from_string("db.schema.table.time") - assert single_time_col_signals[0].type_node == InferenceSignalType.DIMENSION.PRIMARY_TIME - - many_time_col_warehouse = DataWarehouseInferenceContext( - table_props=[ - TableProperties( - table=table, - column_props=[ - get_column_properties("db.schema.table.id", InferenceColumnType.INTEGER, True), - get_column_properties("db.schema.table.time", InferenceColumnType.DATETIME, True), - get_column_properties("db.schema.table.othertime", InferenceColumnType.DATETIME, True), - ], - ) - ] - ) - many_time_col_signals = defaults.PrimaryTimeDimensionIfOnlyTimeRule().process(many_time_col_warehouse) - assert len(many_time_col_signals) == 0 - - -def test_categorical_dimension_by_boolean_type_matcher() -> None: # noqa: D103 - assert defaults.CategoricalDimensionByBooleanTypeRule().match_column( - get_column_properties("db.schema.table.dim", InferenceColumnType.BOOLEAN, True) - ) - assert not defaults.CategoricalDimensionByBooleanTypeRule().match_column( - get_column_properties("db.schema.table.bla", InferenceColumnType.FLOAT, True) - ) - - -def test_categorical_dimension_by_string_type_matcher() -> None: # noqa: D103 - assert defaults.CategoricalDimensionByStringTypeRule().match_column( - get_column_properties("db.schema.table.dim", InferenceColumnType.STRING, True) - ) - assert not defaults.CategoricalDimensionByStringTypeRule().match_column( - get_column_properties("db.schema.table.bla", InferenceColumnType.FLOAT, True) - ) - - -def test_categorical_dimension_by_string__and_cardinality_type_matcher() -> None: - """Tests the composite of string type and cardinality below supplied threshold. - - Since the helper cardinality ratio is always either 1 or 0.9, the cardinality thresholds are set to either above - 0.9 (for checks which should match) or below 0.9 (for checks which should not match, or where the match does - not matter) - """ - assert defaults.CategoricalDimensionByStringTypeAndLowCardinalityRule(0.99).match_column( - get_column_properties("db.schema.table.low_cardinality_string_col", InferenceColumnType.STRING, unique=False) - ) - # INTEGER type columns never match this rule - assert not defaults.CategoricalDimensionByStringTypeAndLowCardinalityRule(0.99).match_column( - get_column_properties("db.schema.table.int_col", InferenceColumnType.INTEGER, unique=False) - ) - assert not defaults.CategoricalDimensionByCardinalityRatioRule(0.40).match_column( - get_column_properties("db.schema.table.high_cardinality_string_col", InferenceColumnType.STRING, unique=False) - ) - - -def test_categorical_dimension_by_integer_type_matcher() -> None: # noqa: D103 - assert defaults.CategoricalDimensionByIntegerTypeRule().match_column( - get_column_properties("db.schema.table.dim", InferenceColumnType.INTEGER, True) - ) - assert not defaults.CategoricalDimensionByIntegerTypeRule().match_column( - get_column_properties("db.schema.table.bla", InferenceColumnType.FLOAT, True) - ) - - -def test_measure_by_real_type_matcher() -> None: # noqa: D103 - assert defaults.MeasureByRealTypeRule().match_column( - get_column_properties("db.schema.table.measure", InferenceColumnType.FLOAT, True) - ) - assert not defaults.MeasureByRealTypeRule().match_column( - get_column_properties("db.schema.table.measure", InferenceColumnType.INTEGER, True) - ) - assert not defaults.MeasureByRealTypeRule().match_column( - get_column_properties("db.schema.table.measure", InferenceColumnType.DATETIME, True) - ) - assert not defaults.MeasureByRealTypeRule().match_column( - get_column_properties("db.schema.table.measure", InferenceColumnType.BOOLEAN, True) - ) - assert not defaults.MeasureByRealTypeRule().match_column( - get_column_properties("db.schema.table.measure", InferenceColumnType.STRING, True) - ) - assert not defaults.MeasureByRealTypeRule().match_column( - get_column_properties("db.schema.table.measure", InferenceColumnType.UNKNOWN, True) - ) - - -def test_measure_by_integer_type_matcher() -> None: # noqa: D103 - assert defaults.MeasureByIntegerTypeRule().match_column( - get_column_properties("db.schema.table.measure", InferenceColumnType.INTEGER, True) - ) - assert not defaults.MeasureByRealTypeRule().match_column( - get_column_properties("db.schema.table.measure", InferenceColumnType.BOOLEAN, True) - ) diff --git a/tests_metricflow/inference/rule/test_rules.py b/tests_metricflow/inference/rule/test_rules.py deleted file mode 100644 index ffc479d67b..0000000000 --- a/tests_metricflow/inference/rule/test_rules.py +++ /dev/null @@ -1,93 +0,0 @@ -from __future__ import annotations - -from metricflow.inference.context.data_warehouse import ( - ColumnProperties, - DataWarehouseInferenceContext, - InferenceColumnType, - TableProperties, -) -from metricflow.inference.models import InferenceSignalConfidence, InferenceSignalType -from metricflow.inference.rule.rules import ColumnMatcherRule, LowCardinalityRatioRule -from metricflow.sql.sql_column import SqlColumn -from metricflow.sql.sql_table import SqlTable - - -def create_context_with_counts(rows: int, distinct: int, nulls: int) -> DataWarehouseInferenceContext: - """Get a `DataWarehouseInferenceContext` with the designated counts.""" - return DataWarehouseInferenceContext( - table_props=[ - TableProperties( - table=SqlTable.from_string("db.schema.table"), - column_props=[ - ColumnProperties( - column=SqlColumn.from_string("db.schema.table.column"), - type=InferenceColumnType.INTEGER, - row_count=rows, - distinct_row_count=distinct, - null_count=nulls, - is_nullable=nulls != 0, - min_value=0, - max_value=rows - 1, - ) - ], - ) - ] - ) - - -class ExampleLowCardinalityRule(LowCardinalityRatioRule): # noqa: D101 - type_node = InferenceSignalType.DIMENSION.CATEGORICAL - confidence = InferenceSignalConfidence.MEDIUM - only_applies_to_parent_signal = False - - -def test_column_matcher(warehouse_ctx: DataWarehouseInferenceContext) -> None: # noqa: D103 - class TestRule(ColumnMatcherRule): - type_node = InferenceSignalType.DIMENSION.UNKNOWN - confidence = InferenceSignalConfidence.MEDIUM - match_reason = "test reason" - only_applies_to_parent_signal = False - - def match_column(self, props: ColumnProperties) -> bool: - return props.column.column_name.endswith("test_column") - - signals = TestRule().process(warehouse_ctx) - - assert len(signals) == 1 - assert signals[0].confidence == InferenceSignalConfidence.MEDIUM - assert signals[0].type_node == InferenceSignalType.DIMENSION.UNKNOWN - assert signals[0].column == SqlColumn.from_string("db.schema.table.test_column") - assert signals[0].reason == "test reason" - assert not signals[0].only_applies_to_parent - - -def test_low_cardinality_ratio_rule_high_cardinality_doesnt_match() -> None: # noqa: D103 - rule = ExampleLowCardinalityRule(0.1) - ctx = create_context_with_counts(100, 100, 0) - - signals = rule.process(ctx) - assert len(signals) == 0 - - -def test_low_cardinality_ratio_rule_low_cardinality_lots_of_nulls_doesnt_match() -> None: # noqa: D103 - rule = ExampleLowCardinalityRule(0.1) - ctx = create_context_with_counts(100, 2, 99) - - signals = rule.process(ctx) - assert len(signals) == 0 - - -def test_low_cardinality_ratio_rule_low_cardinality_all_nulls_doesnt_match() -> None: # noqa: D103 - rule = ExampleLowCardinalityRule(0.1) - ctx = create_context_with_counts(100, 1, 100) - - signals = rule.process(ctx) - assert len(signals) == 0 - - -def test_low_cardinality_ratio_rule_low_cardinality_matches() -> None: # noqa: D103 - rule = ExampleLowCardinalityRule(0.1) - ctx = create_context_with_counts(100, 1, 0) - - signals = rule.process(ctx) - assert len(signals) == 1 diff --git a/tests_metricflow/inference/solver/__init__.py b/tests_metricflow/inference/solver/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/tests_metricflow/inference/solver/test_weighted_tree.py b/tests_metricflow/inference/solver/test_weighted_tree.py deleted file mode 100644 index 89ce225354..0000000000 --- a/tests_metricflow/inference/solver/test_weighted_tree.py +++ /dev/null @@ -1,155 +0,0 @@ -from __future__ import annotations - -from metricflow.inference.models import InferenceSignal, InferenceSignalConfidence, InferenceSignalType -from metricflow.inference.solver.weighted_tree import WeightedTypeTreeInferenceSolver -from metricflow.sql.sql_column import SqlColumn - -column = SqlColumn.from_string("db.schema.table.col") -solver = WeightedTypeTreeInferenceSolver() - - -def test_empty_signals_return_unknown() -> None: # noqa: D103 - result = solver.solve_column(column, []) - - assert result.type_node == InferenceSignalType.UNKNOWN - assert len(result.reasons) == 0 - assert len(result.problems) == 2 - - -def test_follow_signal_path() -> None: - """Test that the solver will return the deepest (most specific) node if it finds a path with multiple signals.""" - signals = [ - InferenceSignal( - column=column, - type_node=InferenceSignalType.ID.UNIQUE, - reason="UNIQUE", - confidence=InferenceSignalConfidence.HIGH, - only_applies_to_parent=False, - ), - InferenceSignal( - column=column, - type_node=InferenceSignalType.ID.PRIMARY, - reason="PRIMARY", - confidence=InferenceSignalConfidence.VERY_HIGH, - only_applies_to_parent=False, - ), - ] - - result = solver.solve_column(column, signals) - - assert result.type_node == InferenceSignalType.ID.PRIMARY - assert "UNIQUE" in result.reasons[0] and "PRIMARY" in result.reasons[1] - - -def test_complementary_signal_with_parent_trail() -> None: - """Test that the solver will follow the weight trail and take complementary signals into account if parent has weight.""" - signals = [ - InferenceSignal( - column=column, - type_node=InferenceSignalType.ID.UNKNOWN, - reason="ID", - confidence=InferenceSignalConfidence.HIGH, - only_applies_to_parent=False, - ), - InferenceSignal( - column=column, - type_node=InferenceSignalType.ID.UNIQUE, - reason="UNIQUE", - confidence=InferenceSignalConfidence.VERY_HIGH, - only_applies_to_parent=True, - ), - InferenceSignal( - column=column, - type_node=InferenceSignalType.ID.PRIMARY, - reason="PRIMARY", - confidence=InferenceSignalConfidence.VERY_HIGH, - only_applies_to_parent=True, - ), - ] - - result = solver.solve_column(column, signals) - - assert result.type_node == InferenceSignalType.ID.PRIMARY - assert "ID" in result.reasons[0] and "UNIQUE" in result.reasons[1] and "PRIMARY" in result.reasons[2] - - -def test_complementary_signals_without_parent_signal() -> None: - """Test that the solver won't follow the weight trail and take complementary signals into account if parent has no weight.""" - signals = [ - InferenceSignal( - column=column, - type_node=InferenceSignalType.DIMENSION.CATEGORICAL, - reason="CATEG_DIM", - confidence=InferenceSignalConfidence.MEDIUM, - only_applies_to_parent=False, - ), - InferenceSignal( - column=column, - type_node=InferenceSignalType.ID.UNIQUE, - reason="UNIQUE", - confidence=InferenceSignalConfidence.VERY_HIGH, - only_applies_to_parent=True, - ), - InferenceSignal( - column=column, - type_node=InferenceSignalType.ID.PRIMARY, - reason="PRIMARY", - confidence=InferenceSignalConfidence.VERY_HIGH, - only_applies_to_parent=True, - ), - ] - - result = solver.solve_column(column, signals) - - assert result.type_node == InferenceSignalType.DIMENSION.CATEGORICAL - assert "CATEG_DIM" in result.reasons[0] - - -def test_contradicting_signals() -> None: - """Test that the solver will return the deepest common ancestor if it finds conflicting signals.""" - signals = [ - InferenceSignal( - column=column, - type_node=InferenceSignalType.ID.FOREIGN, - reason="FOREIGN", - confidence=InferenceSignalConfidence.HIGH, - only_applies_to_parent=False, - ), - InferenceSignal( - column=column, - type_node=InferenceSignalType.ID.PRIMARY, - reason="PRIMARY", - confidence=InferenceSignalConfidence.HIGH, - only_applies_to_parent=False, - ), - ] - - result = solver.solve_column(column, signals) - - assert result.type_node == InferenceSignalType.ID.UNKNOWN - - -def test_stop_at_internal_node_if_trail_stops() -> None: - """Test that if the signal trail stops at an internal node the solver will return that node instead of going deeper.""" - signals = [ - InferenceSignal( - column=column, - type_node=InferenceSignalType.ID.UNKNOWN, - reason="KEY", - confidence=InferenceSignalConfidence.HIGH, - only_applies_to_parent=False, - ), - InferenceSignal( - column=column, - type_node=InferenceSignalType.ID.UNIQUE, - reason="UNIQUE", - confidence=InferenceSignalConfidence.HIGH, - only_applies_to_parent=False, - ), - ] - - result = solver.solve_column(column, signals) - - # should not progress further into the tree and assume it's PRIMARY - assert result.type_node == InferenceSignalType.ID.UNIQUE - assert "KEY" in result.reasons[0] and "UNIQUE" in result.reasons[1] diff --git a/tests_metricflow/inference/test_models.py b/tests_metricflow/inference/test_models.py deleted file mode 100644 index ae59f24679..0000000000 --- a/tests_metricflow/inference/test_models.py +++ /dev/null @@ -1,28 +0,0 @@ -from __future__ import annotations - -from metricflow.inference.models import InferenceSignalType - - -def test_inference_type_node_conflict() -> None: - """Make sure the inference signal type hierarchy is correctly configured.""" - # IDENTIFIER - assert InferenceSignalType.ID.UNKNOWN.is_subtype_of(InferenceSignalType.UNKNOWN) - assert not InferenceSignalType.ID.UNKNOWN.is_subtype_of(InferenceSignalType.DIMENSION.UNKNOWN) - assert not InferenceSignalType.ID.UNKNOWN.is_subtype_of(InferenceSignalType.MEASURE.UNKNOWN) - - assert InferenceSignalType.ID.UNIQUE.is_subtype_of(InferenceSignalType.ID.UNKNOWN) - assert InferenceSignalType.ID.FOREIGN.is_subtype_of(InferenceSignalType.ID.UNKNOWN) - assert InferenceSignalType.ID.PRIMARY.is_subtype_of(InferenceSignalType.ID.UNIQUE) - - # DIMENSION - assert InferenceSignalType.DIMENSION.UNKNOWN.is_subtype_of(InferenceSignalType.UNKNOWN) - assert not InferenceSignalType.DIMENSION.UNKNOWN.is_subtype_of(InferenceSignalType.ID.UNKNOWN) - assert not InferenceSignalType.DIMENSION.UNKNOWN.is_subtype_of(InferenceSignalType.MEASURE.UNKNOWN) - - assert InferenceSignalType.DIMENSION.CATEGORICAL.is_subtype_of(InferenceSignalType.DIMENSION.UNKNOWN) - assert InferenceSignalType.DIMENSION.TIME.is_subtype_of(InferenceSignalType.DIMENSION.UNKNOWN) - - # MEASURE - assert InferenceSignalType.MEASURE.UNKNOWN.is_subtype_of(InferenceSignalType.UNKNOWN) - assert not InferenceSignalType.MEASURE.UNKNOWN.is_subtype_of(InferenceSignalType.ID.UNKNOWN) - assert not InferenceSignalType.MEASURE.UNKNOWN.is_subtype_of(InferenceSignalType.DIMENSION.UNKNOWN)