Skip to content

Commit

Permalink
/* PR_START p--cte 01 */ Rename SqlBindParameters to `SqlBindParame…
Browse files Browse the repository at this point in the history
…terSet`.
  • Loading branch information
plypaul committed Oct 15, 2024
1 parent a497234 commit 648d343
Show file tree
Hide file tree
Showing 27 changed files with 191 additions and 189 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from metricflow_semantics.errors.error_classes import SqlBindParametersNotSupportedError
from metricflow_semantics.mf_logging.lazy_formattable import LazyFormat
from metricflow_semantics.random_id import random_id
from metricflow_semantics.sql.sql_bind_parameters import SqlBindParameters
from metricflow_semantics.sql.sql_bind_parameters import SqlBindParameterSet

from metricflow.data_table.mf_table import MetricFlowDataTable
from metricflow.protocols.sql_client import SqlEngine
Expand Down Expand Up @@ -127,23 +127,25 @@ def sql_query_plan_renderer(self) -> SqlQueryPlanRenderer:
def query(
self,
stmt: str,
sql_bind_parameters: SqlBindParameters = SqlBindParameters(),
sql_bind_parameter_set: SqlBindParameterSet = SqlBindParameterSet(),
) -> MetricFlowDataTable:
"""Query statement; result expected to be data which will be returned as a DataTable.
Args:
stmt: The SQL query statement to run. This should produce output via a SELECT
sql_bind_parameters: The parameter replacement mapping for filling in concrete values for SQL query
sql_bind_parameter_set: The parameter replacement mapping for filling in concrete values for SQL query
parameters.
"""
start = time.time()
request_id = SqlRequestId(f"mf_rid__{random_id()}")
if sql_bind_parameters.param_dict:
if sql_bind_parameter_set.param_dict:
raise SqlBindParametersNotSupportedError(
f"Invalid query statement - we do not support queries with bind parameters through dbt adapters! "
f"Bind params: {sql_bind_parameters.param_dict}"
f"Bind params: {sql_bind_parameter_set.param_dict}"
)
logger.info(LazyFormat("Running query() statement", statement=stmt, param_dict=sql_bind_parameters.param_dict))
logger.info(
LazyFormat("Running query() statement", statement=stmt, param_dict=sql_bind_parameter_set.param_dict)
)
with self._adapter.connection_named(f"MetricFlow_request_{request_id}"):
# returns a Tuple[AdapterResponse, agate.Table] but the decorator converts it to Any
result = self._adapter.execute(sql=stmt, auto_begin=True, fetch=True)
Expand All @@ -167,24 +169,24 @@ def query(
def execute(
self,
stmt: str,
sql_bind_parameters: SqlBindParameters = SqlBindParameters(),
sql_bind_parameter_set: SqlBindParameterSet = SqlBindParameterSet(),
) -> None:
"""Execute a SQL statement. No result will be returned.
Args:
stmt: The SQL query statement to run. This should not produce output.
sql_bind_parameters: The parameter replacement mapping for filling in
sql_bind_parameter_set: The parameter replacement mapping for filling in
concrete values for SQL query parameters.
"""
if sql_bind_parameters.param_dict:
if sql_bind_parameter_set.param_dict:
raise SqlBindParametersNotSupportedError(
f"Invalid execute statement - we do not support execute commands with bind parameters through dbt "
f"adapters! Bind params: {SqlBindParameters.param_dict}"
f"adapters! Bind params: {SqlBindParameterSet.param_dict}"
)
start = time.time()
request_id = SqlRequestId(f"mf_rid__{random_id()}")
logger.info(
LazyFormat("Running execute() statement", statement=stmt, param_dict=sql_bind_parameters.param_dict)
LazyFormat("Running execute() statement", statement=stmt, param_dict=sql_bind_parameter_set.param_dict)
)
with self._adapter.connection_named(f"MetricFlow_request_{request_id}"):
result = self._adapter.execute(stmt, auto_begin=True, fetch=False)
Expand All @@ -199,19 +201,19 @@ def execute(
def dry_run(
self,
stmt: str,
sql_bind_parameters: SqlBindParameters = SqlBindParameters(),
sql_bind_parameter_set: SqlBindParameterSet = SqlBindParameterSet(),
) -> None:
"""Dry run statement; checks that the 'stmt' is queryable. Returns None on success.
Raises an exception if the 'stmt' isn't queryable.
Args:
stmt: The SQL query statement to dry run.
sql_bind_parameters: The parameter replacement mapping for filling in
sql_bind_parameter_set: The parameter replacement mapping for filling in
concrete values for SQL query parameters.
"""
start = time.time()
logger.info(LazyFormat("Running dry run", statement=stmt, param_dict=sql_bind_parameters.param_dict))
logger.info(LazyFormat("Running dry run", statement=stmt, param_dict=sql_bind_parameter_set.param_dict))
request_id = SqlRequestId(f"mf_rid__{random_id()}")
connection_name = f"MetricFlow_dry_run_request_{request_id}"
# TODO - consolidate to self._adapter.validate_sql() when all implementations will work from within MetricFlow
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from metricflow_semantics.model.semantics.linkable_element import LinkableElement, LinkableElementUnion
from metricflow_semantics.specs.instance_spec import LinkableInstanceSpec
from metricflow_semantics.specs.linkable_spec_set import LinkableSpecSet
from metricflow_semantics.sql.sql_bind_parameters import SqlBindParameters
from metricflow_semantics.sql.sql_bind_parameters import SqlBindParameterSet


@dataclass(frozen=True)
Expand All @@ -26,7 +26,7 @@ class WhereFilterSpec(Mergeable, SerializableDataclass):
WhereFilterSpec(
where_sql="listing__country == 'US'",
bind_parameters: SqlBindParameters(),
bind_parameter_set: SqlBindParameters(),
linkable_specs: (
DimensionSpec(
element_name='country',
Expand All @@ -42,10 +42,10 @@ class WhereFilterSpec(Mergeable, SerializableDataclass):
)
"""

# Debating whether where_sql / bind_parameters belongs here. where_sql may become dialect specific if we introduce
# Debating whether where_sql / bind_parameter_set belongs here. where_sql may become dialect specific if we introduce
# quoted identifiers later.
where_sql: str
bind_parameters: SqlBindParameters
bind_parameters: SqlBindParameterSet
linkable_element_unions: Tuple[LinkableElementUnion, ...]
linkable_spec_set: LinkableSpecSet

Expand Down Expand Up @@ -83,7 +83,7 @@ def empty_instance(cls) -> WhereFilterSpec:
# line with other cases of Mergeable.
return WhereFilterSpec(
where_sql="TRUE",
bind_parameters=SqlBindParameters(),
bind_parameters=SqlBindParameterSet(),
linkable_spec_set=LinkableSpecSet(),
linkable_element_unions=(),
)
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from metricflow_semantics.specs.where_filter.where_filter_metric import WhereFilterMetricFactory
from metricflow_semantics.specs.where_filter.where_filter_spec import WhereFilterSpec
from metricflow_semantics.specs.where_filter.where_filter_time_dimension import WhereFilterTimeDimensionFactory
from metricflow_semantics.sql.sql_bind_parameters import SqlBindParameters
from metricflow_semantics.sql.sql_bind_parameters import SqlBindParameterSet

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -108,7 +108,7 @@ def create_from_where_filter_intersection( # noqa: D102
filter_specs.append(
WhereFilterSpec(
where_sql=where_sql,
bind_parameters=SqlBindParameters(),
bind_parameters=SqlBindParameterSet(),
linkable_spec_set=LinkableSpecSet.create_from_specs(rendered_specs),
linkable_element_unions=tuple(linkable_element.as_union for linkable_element in linkable_elements),
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ class SqlBindParameter(SerializableDataclass): # noqa: D101


@dataclass(frozen=True)
class SqlBindParameters(SerializableDataclass):
class SqlBindParameterSet(SerializableDataclass):
"""Helps to build execution parameters during SQL query rendering.
These can be used as per https://docs.sqlalchemy.org/en/14/core/tutorial.html#using-textual-sql
Expand All @@ -83,7 +83,7 @@ class SqlBindParameters(SerializableDataclass):
# Using tuples for immutability as dicts are not.
param_items: Tuple[SqlBindParameter, ...] = ()

def combine(self, additional_params: SqlBindParameters) -> SqlBindParameters:
def combine(self, additional_params: SqlBindParameterSet) -> SqlBindParameterSet:
"""Create a new set of bind parameters that includes parameters from this and additional_params."""
if len(self.param_items) == 0:
return additional_params
Expand All @@ -108,7 +108,7 @@ def combine(self, additional_params: SqlBindParameters) -> SqlBindParameters:
new_items.append(item)
included_keys.add(item.key)

return SqlBindParameters(param_items=tuple(new_items))
return SqlBindParameterSet(param_items=tuple(new_items))

@property
def param_dict(self) -> OrderedDict[str, SqlColumnType]:
Expand All @@ -119,13 +119,13 @@ def param_dict(self) -> OrderedDict[str, SqlColumnType]:
return param_dict

@staticmethod
def create_from_dict(param_dict: Mapping[str, SqlColumnType]) -> SqlBindParameters: # noqa: D102
return SqlBindParameters(
def create_from_dict(param_dict: Mapping[str, SqlColumnType]) -> SqlBindParameterSet: # noqa: D102
return SqlBindParameterSet(
tuple(
SqlBindParameter(key=key, value=SqlBindParameterValue.create_from_sql_column_type(value))
for key, value in param_dict.items()
)
)

def __eq__(self, other: Any) -> bool: # type: ignore # noqa: D105
return isinstance(other, SqlBindParameters) and self.param_dict == other.param_dict
return isinstance(other, SqlBindParameterSet) and self.param_dict == other.param_dict
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,14 @@
from metricflow_semantics.specs.linkable_spec_set import LinkableSpecSet
from metricflow_semantics.specs.time_dimension_spec import TimeDimensionSpec
from metricflow_semantics.specs.where_filter.where_filter_spec import WhereFilterSpec
from metricflow_semantics.sql.sql_bind_parameters import SqlBindParameter, SqlBindParameters, SqlBindParameterValue
from metricflow_semantics.sql.sql_bind_parameters import SqlBindParameter, SqlBindParameterSet, SqlBindParameterValue
from metricflow_semantics.time.granularity import ExpandedTimeGranularity


def test_where_filter_spec_serialization() -> None: # noqa: D103
where_filter_spec = WhereFilterSpec(
where_sql="where_sql",
bind_parameters=SqlBindParameters(
bind_parameters=SqlBindParameterSet(
param_items=(SqlBindParameter(key="key", value=SqlBindParameterValue(str_value="str_value")),)
),
linkable_element_unions=(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import pytest
from dbt_semantic_interfaces.dataclass_serialization import DataClassDeserializer, DataclassSerializer
from metricflow_semantics.sql.sql_bind_parameters import SqlBindParameter, SqlBindParameters, SqlBindParameterValue
from metricflow_semantics.sql.sql_bind_parameters import SqlBindParameter, SqlBindParameterSet, SqlBindParameterValue


@pytest.fixture
Expand All @@ -19,14 +19,14 @@ def test_serialization( # noqa: D103
serializer: DataclassSerializer,
deserializer: DataClassDeserializer,
) -> None:
bind_parameters = SqlBindParameters(
bind_parameter_set = SqlBindParameterSet(
param_items=(
SqlBindParameter("key0", SqlBindParameterValue.create_from_sql_column_type("value0")),
SqlBindParameter("key1", SqlBindParameterValue.create_from_sql_column_type("value1")),
)
)
serialized_obj = serializer.pydantic_serialize(bind_parameters)
serialized_obj = serializer.pydantic_serialize(bind_parameter_set)
deserialized_obj = deserializer.pydantic_deserialize(
dataclass_type=SqlBindParameters, serialized_obj=serialized_obj
dataclass_type=SqlBindParameterSet, serialized_obj=serialized_obj
)
assert bind_parameters == deserialized_obj
assert bind_parameter_set == deserialized_obj
2 changes: 1 addition & 1 deletion metricflow/engine/metricflow_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ def rendered_sql_without_descriptions(self) -> SqlQuery:
sql_query.sql_query.split("\n"),
)
),
bind_parameters=sql_query.bind_parameters,
bind_parameter_set=sql_query.bind_parameter_set,
)

@property
Expand Down
4 changes: 2 additions & 2 deletions metricflow/execution/dataflow_to_execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def visit_write_to_result_data_table_node(self, node: WriteToResultDataTableNode
leaf_tasks=(
SelectSqlQueryToDataTableTask.create(
sql_client=self._sql_client,
sql_query=SqlQuery(render_sql_result.sql, render_sql_result.bind_parameters),
sql_query=SqlQuery(render_sql_result.sql, render_sql_result.bind_parameter_set),
),
)
)
Expand All @@ -105,7 +105,7 @@ def visit_write_to_result_table_node(self, node: WriteToResultTableNode) -> Conv
sql_client=self._sql_client,
sql_query=SqlQuery(
sql_query=render_sql_result.sql,
bind_parameters=render_sql_result.bind_parameters,
bind_parameter_set=render_sql_result.bind_parameter_set,
),
output_table=node.output_sql_table,
),
Expand Down
14 changes: 7 additions & 7 deletions metricflow/execution/execution_plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from metricflow_semantics.dag.id_prefix import IdPrefix, StaticIdPrefix
from metricflow_semantics.dag.mf_dag import DagId, DagNode, DisplayedProperty, MetricFlowDag, NodeId
from metricflow_semantics.mf_logging.lazy_formattable import LazyFormat
from metricflow_semantics.sql.sql_bind_parameters import SqlBindParameters
from metricflow_semantics.sql.sql_bind_parameters import SqlBindParameterSet
from metricflow_semantics.sql.sql_table import SqlTable
from metricflow_semantics.visitor import Visitable

Expand Down Expand Up @@ -49,7 +49,7 @@ class SqlQuery:

# This field will be renamed as it is confusing given the class name.
sql_query: str
bind_parameters: SqlBindParameters
bind_parameter_set: SqlBindParameterSet


@dataclass(frozen=True)
Expand All @@ -69,7 +69,7 @@ class TaskExecutionResult:

# If the task was an SQL query, it's stored here
sql: Optional[str] = None
bind_params: Optional[SqlBindParameters] = None
bind_params: Optional[SqlBindParameterSet] = None
# If the task produces a data_table as a result, it's stored here.
df: Optional[MetricFlowDataTable] = None

Expand Down Expand Up @@ -120,15 +120,15 @@ def execute(self) -> TaskExecutionResult: # noqa: D102

df = self.sql_client.query(
sql_query.sql_query,
sql_bind_parameters=sql_query.bind_parameters,
sql_bind_parameter_set=sql_query.bind_parameter_set,
)

end_time = time.time()
return TaskExecutionResult(
start_time=start_time,
end_time=end_time,
sql=sql_query.sql_query,
bind_params=sql_query.bind_parameters,
bind_params=sql_query.bind_parameter_set,
df=df,
)

Expand Down Expand Up @@ -180,7 +180,7 @@ def displayed_properties(self) -> Sequence[DisplayedProperty]: # noqa: D102
return tuple(super().displayed_properties) + (
DisplayedProperty(key="sql_query", value=sql_query.sql_query),
DisplayedProperty(key="output_table", value=self.output_table),
DisplayedProperty(key="bind_parameters", value=sql_query.bind_parameters),
DisplayedProperty(key="bind_parameter_set", value=sql_query.bind_parameter_set),
)

def execute(self) -> TaskExecutionResult: # noqa: D102
Expand All @@ -192,7 +192,7 @@ def execute(self) -> TaskExecutionResult: # noqa: D102
logger.debug(LazyFormat(lambda: f"Creating table {self.output_table} using a query"))
self.sql_client.execute(
sql_query.sql_query,
sql_bind_parameters=sql_query.bind_parameters,
sql_bind_parameter_set=sql_query.bind_parameter_set,
)

end_time = time.time()
Expand Down
2 changes: 1 addition & 1 deletion metricflow/plan_conversion/dataflow_to_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -855,7 +855,7 @@ def visit_where_constraint_node(self, node: WhereConstraintNode) -> SqlDataSet:
used_columns=tuple(
column_association.column_name for column_association in column_associations_in_where_sql
),
bind_parameters=node.where.bind_parameters,
bind_parameter_set=node.where.bind_parameters,
),
),
)
Expand Down
8 changes: 4 additions & 4 deletions metricflow/protocols/sql_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from dbt_semantic_interfaces.enum_extension import assert_values_exhausted
from dbt_semantic_interfaces.type_enums.time_granularity import TimeGranularity
from metricflow_semantics.sql.sql_bind_parameters import SqlBindParameters
from metricflow_semantics.sql.sql_bind_parameters import SqlBindParameterSet

from metricflow.data_table.mf_table import MetricFlowDataTable
from metricflow.sql.render.sql_plan_renderer import SqlQueryPlanRenderer
Expand Down Expand Up @@ -79,7 +79,7 @@ def sql_query_plan_renderer(self) -> SqlQueryPlanRenderer:
def query(
self,
stmt: str,
sql_bind_parameters: SqlBindParameters = SqlBindParameters(),
sql_bind_parameter_set: SqlBindParameterSet = SqlBindParameterSet(),
) -> MetricFlowDataTable:
"""Base query method, upon execution will run a query that returns a pandas DataTable."""
raise NotImplementedError
Expand All @@ -88,7 +88,7 @@ def query(
def execute(
self,
stmt: str,
sql_bind_parameters: SqlBindParameters = SqlBindParameters(),
sql_bind_parameter_set: SqlBindParameterSet = SqlBindParameterSet(),
) -> None:
"""Base execute method."""
raise NotImplementedError
Expand All @@ -97,7 +97,7 @@ def execute(
def dry_run(
self,
stmt: str,
sql_bind_parameters: SqlBindParameters = SqlBindParameters(),
sql_bind_parameter_set: SqlBindParameterSet = SqlBindParameterSet(),
) -> None:
"""Base dry_run method."""
raise NotImplementedError
Expand Down
Loading

0 comments on commit 648d343

Please sign in to comment.