Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rename SqlBindParameters to SqlBindParameterSet #1459

Merged
merged 2 commits into from
Oct 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from metricflow_semantics.errors.error_classes import SqlBindParametersNotSupportedError
from metricflow_semantics.mf_logging.lazy_formattable import LazyFormat
from metricflow_semantics.random_id import random_id
from metricflow_semantics.sql.sql_bind_parameters import SqlBindParameters
from metricflow_semantics.sql.sql_bind_parameters import SqlBindParameterSet

from metricflow.data_table.mf_table import MetricFlowDataTable
from metricflow.protocols.sql_client import SqlEngine
Expand Down Expand Up @@ -127,23 +127,25 @@ def sql_query_plan_renderer(self) -> SqlQueryPlanRenderer:
def query(
self,
stmt: str,
sql_bind_parameters: SqlBindParameters = SqlBindParameters(),
sql_bind_parameter_set: SqlBindParameterSet = SqlBindParameterSet(),
) -> MetricFlowDataTable:
"""Query statement; result expected to be data which will be returned as a DataTable.

Args:
stmt: The SQL query statement to run. This should produce output via a SELECT
sql_bind_parameters: The parameter replacement mapping for filling in concrete values for SQL query
sql_bind_parameter_set: The parameter replacement mapping for filling in concrete values for SQL query
parameters.
"""
start = time.time()
request_id = SqlRequestId(f"mf_rid__{random_id()}")
if sql_bind_parameters.param_dict:
if sql_bind_parameter_set.param_dict:
raise SqlBindParametersNotSupportedError(
f"Invalid query statement - we do not support queries with bind parameters through dbt adapters! "
f"Bind params: {sql_bind_parameters.param_dict}"
f"Bind params: {sql_bind_parameter_set.param_dict}"
)
logger.info(LazyFormat("Running query() statement", statement=stmt, param_dict=sql_bind_parameters.param_dict))
logger.info(
LazyFormat("Running query() statement", statement=stmt, param_dict=sql_bind_parameter_set.param_dict)
)
with self._adapter.connection_named(f"MetricFlow_request_{request_id}"):
# returns a Tuple[AdapterResponse, agate.Table] but the decorator converts it to Any
result = self._adapter.execute(sql=stmt, auto_begin=True, fetch=True)
Expand All @@ -167,24 +169,24 @@ def query(
def execute(
self,
stmt: str,
sql_bind_parameters: SqlBindParameters = SqlBindParameters(),
sql_bind_parameter_set: SqlBindParameterSet = SqlBindParameterSet(),
) -> None:
"""Execute a SQL statement. No result will be returned.

Args:
stmt: The SQL query statement to run. This should not produce output.
sql_bind_parameters: The parameter replacement mapping for filling in
sql_bind_parameter_set: The parameter replacement mapping for filling in
concrete values for SQL query parameters.
"""
if sql_bind_parameters.param_dict:
if sql_bind_parameter_set.param_dict:
raise SqlBindParametersNotSupportedError(
f"Invalid execute statement - we do not support execute commands with bind parameters through dbt "
f"adapters! Bind params: {SqlBindParameters.param_dict}"
f"adapters! Bind params: {SqlBindParameterSet.param_dict}"
)
start = time.time()
request_id = SqlRequestId(f"mf_rid__{random_id()}")
logger.info(
LazyFormat("Running execute() statement", statement=stmt, param_dict=sql_bind_parameters.param_dict)
LazyFormat("Running execute() statement", statement=stmt, param_dict=sql_bind_parameter_set.param_dict)
)
with self._adapter.connection_named(f"MetricFlow_request_{request_id}"):
result = self._adapter.execute(stmt, auto_begin=True, fetch=False)
Expand All @@ -199,19 +201,19 @@ def execute(
def dry_run(
self,
stmt: str,
sql_bind_parameters: SqlBindParameters = SqlBindParameters(),
sql_bind_parameter_set: SqlBindParameterSet = SqlBindParameterSet(),
) -> None:
"""Dry run statement; checks that the 'stmt' is queryable. Returns None on success.

Raises an exception if the 'stmt' isn't queryable.

Args:
stmt: The SQL query statement to dry run.
sql_bind_parameters: The parameter replacement mapping for filling in
sql_bind_parameter_set: The parameter replacement mapping for filling in
concrete values for SQL query parameters.
"""
start = time.time()
logger.info(LazyFormat("Running dry run", statement=stmt, param_dict=sql_bind_parameters.param_dict))
logger.info(LazyFormat("Running dry run", statement=stmt, param_dict=sql_bind_parameter_set.param_dict))
request_id = SqlRequestId(f"mf_rid__{random_id()}")
connection_name = f"MetricFlow_dry_run_request_{request_id}"
# TODO - consolidate to self._adapter.validate_sql() when all implementations will work from within MetricFlow
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from metricflow_semantics.model.semantics.linkable_element import LinkableElement, LinkableElementUnion
from metricflow_semantics.specs.instance_spec import LinkableInstanceSpec
from metricflow_semantics.specs.linkable_spec_set import LinkableSpecSet
from metricflow_semantics.sql.sql_bind_parameters import SqlBindParameters
from metricflow_semantics.sql.sql_bind_parameters import SqlBindParameterSet


@dataclass(frozen=True)
Expand All @@ -26,7 +26,7 @@ class WhereFilterSpec(Mergeable, SerializableDataclass):

WhereFilterSpec(
where_sql="listing__country == 'US'",
bind_parameters: SqlBindParameters(),
bind_parameter_set: SqlBindParameters(),
linkable_specs: (
DimensionSpec(
element_name='country',
Expand All @@ -42,10 +42,10 @@ class WhereFilterSpec(Mergeable, SerializableDataclass):
)
"""

# Debating whether where_sql / bind_parameters belongs here. where_sql may become dialect specific if we introduce
# Debating whether where_sql / bind_parameter_set belongs here. where_sql may become dialect specific if we introduce
# quoted identifiers later.
where_sql: str
bind_parameters: SqlBindParameters
bind_parameters: SqlBindParameterSet
linkable_element_unions: Tuple[LinkableElementUnion, ...]
linkable_spec_set: LinkableSpecSet

Expand Down Expand Up @@ -83,7 +83,7 @@ def empty_instance(cls) -> WhereFilterSpec:
# line with other cases of Mergeable.
return WhereFilterSpec(
where_sql="TRUE",
bind_parameters=SqlBindParameters(),
bind_parameters=SqlBindParameterSet(),
linkable_spec_set=LinkableSpecSet(),
linkable_element_unions=(),
)
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from metricflow_semantics.specs.where_filter.where_filter_metric import WhereFilterMetricFactory
from metricflow_semantics.specs.where_filter.where_filter_spec import WhereFilterSpec
from metricflow_semantics.specs.where_filter.where_filter_time_dimension import WhereFilterTimeDimensionFactory
from metricflow_semantics.sql.sql_bind_parameters import SqlBindParameters
from metricflow_semantics.sql.sql_bind_parameters import SqlBindParameterSet

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -108,7 +108,7 @@ def create_from_where_filter_intersection( # noqa: D102
filter_specs.append(
WhereFilterSpec(
where_sql=where_sql,
bind_parameters=SqlBindParameters(),
bind_parameters=SqlBindParameterSet(),
linkable_spec_set=LinkableSpecSet.create_from_specs(rendered_specs),
linkable_element_unions=tuple(linkable_element.as_union for linkable_element in linkable_elements),
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ class SqlBindParameter(SerializableDataclass): # noqa: D101


@dataclass(frozen=True)
class SqlBindParameters(SerializableDataclass):
class SqlBindParameterSet(SerializableDataclass):
"""Helps to build execution parameters during SQL query rendering.

These can be used as per https://docs.sqlalchemy.org/en/14/core/tutorial.html#using-textual-sql
Expand All @@ -83,7 +83,7 @@ class SqlBindParameters(SerializableDataclass):
# Using tuples for immutability as dicts are not.
param_items: Tuple[SqlBindParameter, ...] = ()

def combine(self, additional_params: SqlBindParameters) -> SqlBindParameters:
def combine(self, additional_params: SqlBindParameterSet) -> SqlBindParameterSet:
"""Create a new set of bind parameters that includes parameters from this and additional_params."""
if len(self.param_items) == 0:
return additional_params
Expand All @@ -108,7 +108,7 @@ def combine(self, additional_params: SqlBindParameters) -> SqlBindParameters:
new_items.append(item)
included_keys.add(item.key)

return SqlBindParameters(param_items=tuple(new_items))
return SqlBindParameterSet(param_items=tuple(new_items))

@property
def param_dict(self) -> OrderedDict[str, SqlColumnType]:
Expand All @@ -119,13 +119,13 @@ def param_dict(self) -> OrderedDict[str, SqlColumnType]:
return param_dict

@staticmethod
def create_from_dict(param_dict: Mapping[str, SqlColumnType]) -> SqlBindParameters: # noqa: D102
return SqlBindParameters(
def create_from_dict(param_dict: Mapping[str, SqlColumnType]) -> SqlBindParameterSet: # noqa: D102
return SqlBindParameterSet(
tuple(
SqlBindParameter(key=key, value=SqlBindParameterValue.create_from_sql_column_type(value))
for key, value in param_dict.items()
)
)

def __eq__(self, other: Any) -> bool: # type: ignore # noqa: D105
return isinstance(other, SqlBindParameters) and self.param_dict == other.param_dict
return isinstance(other, SqlBindParameterSet) and self.param_dict == other.param_dict
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,14 @@
from metricflow_semantics.specs.linkable_spec_set import LinkableSpecSet
from metricflow_semantics.specs.time_dimension_spec import TimeDimensionSpec
from metricflow_semantics.specs.where_filter.where_filter_spec import WhereFilterSpec
from metricflow_semantics.sql.sql_bind_parameters import SqlBindParameter, SqlBindParameters, SqlBindParameterValue
from metricflow_semantics.sql.sql_bind_parameters import SqlBindParameter, SqlBindParameterSet, SqlBindParameterValue
from metricflow_semantics.time.granularity import ExpandedTimeGranularity


def test_where_filter_spec_serialization() -> None: # noqa: D103
where_filter_spec = WhereFilterSpec(
where_sql="where_sql",
bind_parameters=SqlBindParameters(
bind_parameters=SqlBindParameterSet(
param_items=(SqlBindParameter(key="key", value=SqlBindParameterValue(str_value="str_value")),)
),
linkable_element_unions=(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import pytest
from dbt_semantic_interfaces.dataclass_serialization import DataClassDeserializer, DataclassSerializer
from metricflow_semantics.sql.sql_bind_parameters import SqlBindParameter, SqlBindParameters, SqlBindParameterValue
from metricflow_semantics.sql.sql_bind_parameters import SqlBindParameter, SqlBindParameterSet, SqlBindParameterValue


@pytest.fixture
Expand All @@ -19,14 +19,14 @@ def test_serialization( # noqa: D103
serializer: DataclassSerializer,
deserializer: DataClassDeserializer,
) -> None:
bind_parameters = SqlBindParameters(
bind_parameter_set = SqlBindParameterSet(
param_items=(
SqlBindParameter("key0", SqlBindParameterValue.create_from_sql_column_type("value0")),
SqlBindParameter("key1", SqlBindParameterValue.create_from_sql_column_type("value1")),
)
)
serialized_obj = serializer.pydantic_serialize(bind_parameters)
serialized_obj = serializer.pydantic_serialize(bind_parameter_set)
deserialized_obj = deserializer.pydantic_deserialize(
dataclass_type=SqlBindParameters, serialized_obj=serialized_obj
dataclass_type=SqlBindParameterSet, serialized_obj=serialized_obj
)
assert bind_parameters == deserialized_obj
assert bind_parameter_set == deserialized_obj
26 changes: 13 additions & 13 deletions metricflow-semantics/tests_metricflow_semantics/test_specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from metricflow_semantics.specs.time_dimension_spec import TimeDimensionSpec
from metricflow_semantics.specs.where_filter.where_filter_spec import WhereFilterSpec
from metricflow_semantics.specs.where_filter.where_filter_spec_set import WhereFilterSpecSet
from metricflow_semantics.sql.sql_bind_parameters import SqlBindParameters
from metricflow_semantics.sql.sql_bind_parameters import SqlBindParameterSet
from metricflow_semantics.time.granularity import ExpandedTimeGranularity


Expand Down Expand Up @@ -224,23 +224,23 @@ def where_filter_spec_set() -> WhereFilterSpecSet: # noqa: D103
measure_level_filter_specs=(
WhereFilterSpec(
where_sql="measure is true",
bind_parameters=SqlBindParameters(),
bind_parameters=SqlBindParameterSet(),
linkable_element_unions=(),
linkable_spec_set=LinkableSpecSet(),
),
),
metric_level_filter_specs=(
WhereFilterSpec(
where_sql="metric is true",
bind_parameters=SqlBindParameters(),
bind_parameters=SqlBindParameterSet(),
linkable_element_unions=(),
linkable_spec_set=LinkableSpecSet(),
),
),
query_level_filter_specs=(
WhereFilterSpec(
where_sql="query is true",
bind_parameters=SqlBindParameters(),
bind_parameters=SqlBindParameterSet(),
linkable_element_unions=(),
linkable_spec_set=LinkableSpecSet(),
),
Expand All @@ -252,19 +252,19 @@ def test_where_filter_spec_set_all_specs(where_filter_spec_set: WhereFilterSpecS
assert set(where_filter_spec_set.all_filter_specs) == {
WhereFilterSpec(
where_sql="measure is true",
bind_parameters=SqlBindParameters(),
bind_parameters=SqlBindParameterSet(),
linkable_element_unions=(),
linkable_spec_set=LinkableSpecSet(),
),
WhereFilterSpec(
where_sql="metric is true",
bind_parameters=SqlBindParameters(),
bind_parameters=SqlBindParameterSet(),
linkable_element_unions=(),
linkable_spec_set=LinkableSpecSet(),
),
WhereFilterSpec(
where_sql="query is true",
bind_parameters=SqlBindParameters(),
bind_parameters=SqlBindParameterSet(),
linkable_element_unions=(),
linkable_spec_set=LinkableSpecSet(),
),
Expand All @@ -275,13 +275,13 @@ def test_where_filter_spec_set_post_aggregation_specs(where_filter_spec_set: Whe
assert set(where_filter_spec_set.after_measure_aggregation_filter_specs) == {
WhereFilterSpec(
where_sql="metric is true",
bind_parameters=SqlBindParameters(),
bind_parameters=SqlBindParameterSet(),
linkable_element_unions=(),
linkable_spec_set=LinkableSpecSet(),
),
WhereFilterSpec(
where_sql="query is true",
bind_parameters=SqlBindParameters(),
bind_parameters=SqlBindParameterSet(),
linkable_element_unions=(),
linkable_spec_set=LinkableSpecSet(),
),
Expand All @@ -293,7 +293,7 @@ def test_where_filter_spec_set_merge(where_filter_spec_set: WhereFilterSpecSet)
measure_level_filter_specs=(
WhereFilterSpec(
where_sql="measure is true",
bind_parameters=SqlBindParameters(),
bind_parameters=SqlBindParameterSet(),
linkable_element_unions=(),
linkable_spec_set=LinkableSpecSet(),
),
Expand All @@ -303,7 +303,7 @@ def test_where_filter_spec_set_merge(where_filter_spec_set: WhereFilterSpecSet)
metric_level_filter_specs=(
WhereFilterSpec(
where_sql="metric is true",
bind_parameters=SqlBindParameters(),
bind_parameters=SqlBindParameterSet(),
linkable_element_unions=(),
linkable_spec_set=LinkableSpecSet(),
),
Expand All @@ -314,15 +314,15 @@ def test_where_filter_spec_set_merge(where_filter_spec_set: WhereFilterSpecSet)
measure_level_filter_specs=(
WhereFilterSpec(
where_sql="measure is true",
bind_parameters=SqlBindParameters(),
bind_parameters=SqlBindParameterSet(),
linkable_element_unions=(),
linkable_spec_set=LinkableSpecSet(),
),
),
metric_level_filter_specs=(
WhereFilterSpec(
where_sql="metric is true",
bind_parameters=SqlBindParameters(),
bind_parameters=SqlBindParameterSet(),
linkable_element_unions=(),
linkable_spec_set=LinkableSpecSet(),
),
Expand Down
2 changes: 1 addition & 1 deletion metricflow/engine/metricflow_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ def rendered_sql_without_descriptions(self) -> SqlQuery:
sql_query.sql_query.split("\n"),
)
),
bind_parameters=sql_query.bind_parameters,
bind_parameter_set=sql_query.bind_parameter_set,
)

@property
Expand Down
4 changes: 2 additions & 2 deletions metricflow/execution/dataflow_to_execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def visit_write_to_result_data_table_node(self, node: WriteToResultDataTableNode
leaf_tasks=(
SelectSqlQueryToDataTableTask.create(
sql_client=self._sql_client,
sql_query=SqlQuery(render_sql_result.sql, render_sql_result.bind_parameters),
sql_query=SqlQuery(render_sql_result.sql, render_sql_result.bind_parameter_set),
),
)
)
Expand All @@ -105,7 +105,7 @@ def visit_write_to_result_table_node(self, node: WriteToResultTableNode) -> Conv
sql_client=self._sql_client,
sql_query=SqlQuery(
sql_query=render_sql_result.sql,
bind_parameters=render_sql_result.bind_parameters,
bind_parameter_set=render_sql_result.bind_parameter_set,
),
output_table=node.output_sql_table,
),
Expand Down
Loading
Loading