Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Assorted Fixes for the sql Module #1081

Merged
7 commits merged on Mar 19, 2024
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions metricflow/dag/id_prefix.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ class StaticIdPrefix(IdPrefix, Enum, metaclass=EnumMetaClassHelper):

SQL_PLAN_SELECT_STATEMENT_ID_PREFIX = "ss"
SQL_PLAN_TABLE_FROM_CLAUSE_ID_PREFIX = "tfc"
SQL_PLAN_QUERY_FROM_CLAUSE_ID_PREFIX = "qfc"

EXEC_NODE_READ_SQL_QUERY = "rsq"
EXEC_NODE_NOOP = "noop"
Expand Down
10 changes: 7 additions & 3 deletions metricflow/dataflow/builder/dataflow_plan_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,6 @@
from metricflow.dataflow.nodes.write_to_dataframe import WriteToResultDataframeNode
from metricflow.dataflow.nodes.write_to_table import WriteToResultTableNode
from metricflow.dataflow.optimizer.dataflow_plan_optimizer import DataflowPlanOptimizer
from metricflow.dataflow.sql_table import SqlTable
from metricflow.dataset.dataset import DataSet
from metricflow.errors.errors import UnableToSatisfyQueryError
from metricflow.filters.time_constraint import TimeRangeConstraint
Expand Down Expand Up @@ -86,6 +85,7 @@
)
from metricflow.specs.where_filter_transform import WhereSpecFactory
from metricflow.sql.sql_plan import SqlJoinType
from metricflow.sql.sql_table import SqlTable

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -422,7 +422,11 @@ def _build_base_metric_output_node(
descendent_filter_specs=metric_spec.filter_specs,
)

logger.info(f"For {metric_spec}, needed measure is:\n" f"{mf_pformat(metric_input_measure_spec)}")
logger.info(
f"For\n{indent(mf_pformat(metric_spec))}"
f"\nneeded measure is:"
f"\n{indent(mf_pformat(metric_input_measure_spec))}"
)

aggregated_measures_node = self.build_aggregated_measure(
metric_input_measure_spec=metric_input_measure_spec,
Expand Down Expand Up @@ -580,7 +584,7 @@ def _build_metrics_output_node(
output_nodes: List[BaseOutput] = []

for metric_spec in metric_specs:
logger.info(f"Generating compute metrics node for {metric_spec}")
logger.info(f"Generating compute metrics node for:\n{indent(mf_pformat(metric_spec))}")
self._metric_lookup.get_metric(metric_spec.reference)

output_nodes.append(
Expand Down
2 changes: 1 addition & 1 deletion metricflow/dataflow/nodes/write_to_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
SinkNodeVisitor,
SinkOutput,
)
from metricflow.dataflow.sql_table import SqlTable
from metricflow.sql.sql_table import SqlTable
from metricflow.visitor import VisitorOutputT


Expand Down
4 changes: 1 addition & 3 deletions metricflow/dataset/convert_semantic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
from metricflow.aggregation_properties import AggregationState
from metricflow.dag.id_prefix import DynamicIdPrefix, StaticIdPrefix
from metricflow.dag.sequential_id import SequentialIdGenerator
from metricflow.dataflow.sql_table import SqlTable
from metricflow.dataset.semantic_model_adapter import SemanticModelDataSet
from metricflow.dataset.sql_dataset import SqlDataSet
from metricflow.instances import (
Expand Down Expand Up @@ -49,11 +48,11 @@
SqlStringExpression,
)
from metricflow.sql.sql_plan import (
SqlQueryPlanNode,
SqlSelectColumn,
SqlSelectStatementNode,
SqlTableFromClauseNode,
)
from metricflow.sql.sql_table import SqlTable

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -492,7 +491,6 @@ def create_sql_source_data_set(self, semantic_model: SemanticModel) -> SemanticM
all_select_columns.extend(select_columns)

# Generate the "from" clause depending on whether it's an SQL query or an SQL table.
from_source: Optional[SqlQueryPlanNode] = None
from_source = SqlTableFromClauseNode(sql_table=SqlTable.from_string(semantic_model.node_relation.relation_name))

select_statement_node = SqlSelectStatementNode(
Expand Down
12 changes: 0 additions & 12 deletions metricflow/dataset/sql_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@

from typing import Sequence

import more_itertools

from metricflow.dataset.dataset import DataSet
from metricflow.instances import (
InstanceSet,
Expand Down Expand Up @@ -106,13 +104,3 @@ def column_association_for_time_dimension(
)

return column_associations_to_return[0]

@property
def groupable_column_associations(self) -> Sequence[ColumnAssociation]:
"""Return a flattened iterable with all groupable column associations for the current data set."""
instances = (
self.instance_set.entity_instances
+ self.instance_set.dimension_instances
+ self.instance_set.time_dimension_instances
)
return tuple(more_itertools.flatten([instance.associated_columns for instance in instances]))
2 changes: 1 addition & 1 deletion metricflow/engine/metricflow_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
from metricflow.dataflow.optimizer.source_scan.source_scan_optimizer import (
SourceScanOptimizer,
)
from metricflow.dataflow.sql_table import SqlTable
from metricflow.dataset.convert_semantic_model import SemanticModelToDataSetConverter
from metricflow.dataset.dataset import DataSet
from metricflow.dataset.semantic_model_adapter import SemanticModelDataSet
Expand Down Expand Up @@ -54,6 +53,7 @@
from metricflow.specs.query_param_implementations import SavedQueryParameter
from metricflow.specs.specs import InstanceSpecSet, MetricFlowQuerySpec
from metricflow.sql.optimizer.optimization_levels import SqlQueryOptimizationLevel
from metricflow.sql.sql_table import SqlTable
from metricflow.telemetry.models import TelemetryLevel
from metricflow.telemetry.reporter import TelemetryReporter, log_call
from metricflow.time.time_source import TimeSource
Expand Down
2 changes: 1 addition & 1 deletion metricflow/execution/execution_plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@

from metricflow.dag.id_prefix import IdPrefix, StaticIdPrefix
from metricflow.dag.mf_dag import DagId, DagNode, DisplayedProperty, MetricFlowDag, NodeId
from metricflow.dataflow.sql_table import SqlTable
from metricflow.protocols.sql_client import SqlClient
from metricflow.sql.sql_bind_parameters import SqlBindParameters
from metricflow.sql.sql_table import SqlTable
from metricflow.visitor import Visitable

logger = logging.getLogger(__name__)
Expand Down
4 changes: 2 additions & 2 deletions metricflow/inference/context/data_warehouse.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@
from enum import Enum
from typing import Callable, ContextManager, Dict, Generic, Iterator, List, Optional, TypeVar

from metricflow.dataflow.sql_column import SqlColumn
from metricflow.dataflow.sql_table import SqlTable
from metricflow.inference.context.base import InferenceContext, InferenceContextProvider
from metricflow.protocols.sql_client import SqlClient
from metricflow.sql.sql_column import SqlColumn
from metricflow.sql.sql_table import SqlTable

T = TypeVar("T", str, int, float, date, datetime)

Expand Down
4 changes: 2 additions & 2 deletions metricflow/inference/context/snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,14 @@

import json

from metricflow.dataflow.sql_column import SqlColumn
from metricflow.dataflow.sql_table import SqlTable
from metricflow.inference.context.data_warehouse import (
ColumnProperties,
DataWarehouseInferenceContextProvider,
InferenceColumnType,
TableProperties,
)
from metricflow.sql.sql_column import SqlColumn
from metricflow.sql.sql_table import SqlTable


class SnowflakeInferenceContextProvider(DataWarehouseInferenceContextProvider):
Expand Down
2 changes: 1 addition & 1 deletion metricflow/inference/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from enum import Enum
from typing import List, Optional

from metricflow.dataflow.sql_column import SqlColumn
from metricflow.sql.sql_column import SqlColumn


class InferenceSignalConfidence(Enum):
Expand Down
2 changes: 1 addition & 1 deletion metricflow/inference/renderer/config_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@
from ruamel.yaml.comments import CommentedMap
from typing_extensions import NotRequired

from metricflow.dataflow.sql_table import SqlTable
from metricflow.inference.models import InferenceResult, InferenceSignalType
from metricflow.inference.renderer.base import InferenceRenderer
from metricflow.sql.sql_table import SqlTable

yaml = YAML()

Expand Down
2 changes: 1 addition & 1 deletion metricflow/inference/renderer/stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@
from collections import defaultdict
from typing import Dict, List, TextIO

from metricflow.dataflow.sql_table import SqlTable
from metricflow.inference.models import InferenceResult
from metricflow.inference.renderer.base import InferenceRenderer
from metricflow.sql.sql_table import SqlTable


class StreamInferenceRenderer(InferenceRenderer):
Expand Down
2 changes: 1 addition & 1 deletion metricflow/inference/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,12 @@

import more_itertools

from metricflow.dataflow.sql_table import SqlTable
from metricflow.inference.context.base import InferenceContextProvider
from metricflow.inference.context.data_warehouse import DataWarehouseInferenceContextProvider
from metricflow.inference.renderer.base import InferenceRenderer
from metricflow.inference.rule.base import InferenceRule
from metricflow.inference.solver.base import InferenceSolver
from metricflow.sql.sql_table import SqlTable

logger = logging.getLogger(__file__)

Expand Down
2 changes: 1 addition & 1 deletion metricflow/inference/solver/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
from abc import ABC, abstractmethod
from typing import List

from metricflow.dataflow.sql_column import SqlColumn
from metricflow.inference.models import InferenceResult, InferenceSignal
from metricflow.sql.sql_column import SqlColumn


class InferenceSolver(ABC):
Expand Down
2 changes: 1 addition & 1 deletion metricflow/inference/solver/weighted_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from collections import defaultdict
from typing import Callable, Dict, List, Optional

from metricflow.dataflow.sql_column import SqlColumn
from metricflow.inference.models import (
InferenceResult,
InferenceSignal,
Expand All @@ -12,6 +11,7 @@
InferenceSignalType,
)
from metricflow.inference.solver.base import InferenceSolver
from metricflow.sql.sql_column import SqlColumn

NodeWeighterFunction = Callable[[InferenceSignalConfidence], int]

Expand Down
2 changes: 1 addition & 1 deletion metricflow/model/semantic_manifest_lookup.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,12 @@
from dbt_semantic_interfaces.protocols.semantic_manifest import SemanticManifest
from dbt_semantic_interfaces.type_enums import TimeGranularity

from metricflow.dataflow.sql_table import SqlTable
from metricflow.mf_logging.pretty_print import mf_pformat
from metricflow.model.semantics.metric_lookup import MetricLookup
from metricflow.model.semantics.semantic_model_lookup import SemanticModelLookup
from metricflow.plan_conversion.time_spine import TimeSpineSource
from metricflow.protocols.semantics import SemanticModelAccessor
from metricflow.sql.sql_table import SqlTable

logger = logging.getLogger(__name__)

Expand Down
2 changes: 1 addition & 1 deletion metricflow/plan_conversion/dataflow_to_execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
)
from metricflow.dataflow.nodes.write_to_dataframe import WriteToResultDataframeNode
from metricflow.dataflow.nodes.write_to_table import WriteToResultTableNode
from metricflow.dataflow.sql_table import SqlTable
from metricflow.execution.execution_plan import (
ExecutionPlan,
ExecutionPlanTask,
Expand All @@ -21,6 +20,7 @@
from metricflow.plan_conversion.dataflow_to_sql import DataflowToSqlQueryPlanConverter
from metricflow.protocols.sql_client import SqlClient
from metricflow.sql.render.sql_plan_renderer import SqlQueryPlanRenderer
from metricflow.sql.sql_table import SqlTable

logger = logging.getLogger(__name__)

Expand Down
2 changes: 1 addition & 1 deletion metricflow/plan_conversion/time_spine.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from dbt_semantic_interfaces.type_enums.time_granularity import TimeGranularity

from metricflow.dataflow.sql_table import SqlTable
from metricflow.sql.sql_table import SqlTable

logger = logging.getLogger(__name__)

Expand Down
metricflow/dataflow/sql_column.py → metricflow/sql/sql_column.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import Optional

from dbt_semantic_interfaces.implementations.base import FrozenBaseModel
from metricflow.sql.sql_table import SqlTable

from metricflow.dataflow.sql_table import SqlTable


class SqlColumn(FrozenBaseModel):
@dataclass(frozen=True, order=True)
class SqlColumn:
"""Represents a reference to a SQL column."""

table: SqlTable
Expand Down
16 changes: 8 additions & 8 deletions metricflow/sql/sql_plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@

from metricflow.dag.id_prefix import IdPrefix, StaticIdPrefix
from metricflow.dag.mf_dag import DagId, DagNode, DisplayedProperty, MetricFlowDag, NodeId
from metricflow.dataflow.sql_table import SqlTable
from metricflow.sql.sql_exprs import SqlExpressionNode
from metricflow.sql.sql_table import SqlTable
from metricflow.visitor import VisitorOutputT

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -42,19 +42,19 @@ def parent_nodes(self) -> List[SqlQueryPlanNode]: # noqa: D
@abstractmethod
def accept(self, visitor: SqlQueryPlanNodeVisitor[VisitorOutputT]) -> VisitorOutputT:
"""Called when a visitor needs to visit this node."""
pass
raise NotImplementedError
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This shouldn't matter, but I like the consistency.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, it seems like that's what the docs suggest, though I've seen some conflicting examples.


@property
@abstractmethod
def is_table(self) -> bool:
"""Returns whether this node resolves to a table (vs. a query)."""
pass
raise NotImplementedError

@property
@abstractmethod
def as_select_node(self) -> Optional[SqlSelectStatementNode]:
"""If possible, return this as a select statement node."""
pass
raise NotImplementedError


class SqlQueryPlanNodeVisitor(Generic[VisitorOutputT], ABC):
Expand All @@ -65,15 +65,15 @@ class SqlQueryPlanNodeVisitor(Generic[VisitorOutputT], ABC):

@abstractmethod
def visit_select_statement_node(self, node: SqlSelectStatementNode) -> VisitorOutputT: # noqa: D
pass
raise NotImplementedError

@abstractmethod
def visit_table_from_clause_node(self, node: SqlTableFromClauseNode) -> VisitorOutputT: # noqa: D
pass
raise NotImplementedError

@abstractmethod
def visit_query_from_clause_node(self, node: SqlSelectQueryFromClauseNode) -> VisitorOutputT: # noqa: D
pass
raise NotImplementedError


@dataclass(frozen=True)
Expand Down Expand Up @@ -267,7 +267,7 @@ def __init__(self, select_query: str) -> None: # noqa: D

@classmethod
def id_prefix(cls) -> IdPrefix: # noqa: D
return StaticIdPrefix.SQL_PLAN_TABLE_FROM_CLAUSE_ID_PREFIX
return StaticIdPrefix.SQL_PLAN_QUERY_FROM_CLAUSE_ID_PREFIX

@property
def description(self) -> str: # noqa: D
Expand Down
27 changes: 4 additions & 23 deletions metricflow/dataflow/sql_table.py → metricflow/sql/sql_table.py
Original file line number Diff line number Diff line change
@@ -1,35 +1,16 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import Optional, Tuple, Union

from dbt_semantic_interfaces.implementations.base import (
FrozenBaseModel,
PydanticCustomInputParser,
PydanticParseableValueType,
)


class SqlTable(PydanticCustomInputParser, FrozenBaseModel):
@dataclass(frozen=True, order=True)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh nice, thanks! There was some reason we couldn't do this originally and I guess we forgot to clean it up later.

class SqlTable:
"""Represents a reference to a SQL table."""

db_name: Optional[str] = None
schema_name: str
table_name: str

@classmethod
def _from_yaml_value(cls, input: PydanticParseableValueType) -> SqlTable:
"""Parses a SqlTable from string input found in a user-provided model specification.

Raises a ValueError on any non-string input, as all user-provided specifications of table entities
should be strings conforming to the expectations defined in the from_string method.
"""
if isinstance(input, str):
return SqlTable.from_string(input)
else:
raise ValueError(
f"SqlTable inputs from model configs are expected to always be of type string, but got type "
f"{type(input)} with value: {input}"
)
db_name: Optional[str] = None

@staticmethod
def from_string(sql_str: str) -> SqlTable: # noqa: D
Expand Down
Loading
Loading