diff --git a/dbt-metricflow/dbt_metricflow/cli/dbt_connectors/adapter_backed_client.py b/dbt-metricflow/dbt_metricflow/cli/dbt_connectors/adapter_backed_client.py index f77b68fc95..4650f94f30 100644 --- a/dbt-metricflow/dbt_metricflow/cli/dbt_connectors/adapter_backed_client.py +++ b/dbt-metricflow/dbt_metricflow/cli/dbt_connectors/adapter_backed_client.py @@ -10,7 +10,7 @@ from metricflow_semantics.errors.error_classes import SqlBindParametersNotSupportedError from metricflow_semantics.mf_logging.lazy_formattable import LazyFormat from metricflow_semantics.random_id import random_id -from metricflow_semantics.sql.sql_bind_parameters import SqlBindParameters +from metricflow_semantics.sql.sql_bind_parameters import SqlBindParameterSet from metricflow.data_table.mf_table import MetricFlowDataTable from metricflow.protocols.sql_client import SqlEngine @@ -127,23 +127,25 @@ def sql_query_plan_renderer(self) -> SqlQueryPlanRenderer: def query( self, stmt: str, - sql_bind_parameters: SqlBindParameters = SqlBindParameters(), + sql_bind_parameter_set: SqlBindParameterSet = SqlBindParameterSet(), ) -> MetricFlowDataTable: """Query statement; result expected to be data which will be returned as a DataTable. Args: stmt: The SQL query statement to run. This should produce output via a SELECT - sql_bind_parameters: The parameter replacement mapping for filling in concrete values for SQL query + sql_bind_parameter_set: The parameter replacement mapping for filling in concrete values for SQL query parameters. """ start = time.time() request_id = SqlRequestId(f"mf_rid__{random_id()}") - if sql_bind_parameters.param_dict: + if sql_bind_parameter_set.param_dict: raise SqlBindParametersNotSupportedError( f"Invalid query statement - we do not support queries with bind parameters through dbt adapters! " - f"Bind params: {sql_bind_parameters.param_dict}" + f"Bind params: {sql_bind_parameter_set.param_dict}" ) - logger.info(LazyFormat("Running query() statement", statement=stmt, param_dict=sql_bind_parameters.param_dict)) + logger.info( + LazyFormat("Running query() statement", statement=stmt, param_dict=sql_bind_parameter_set.param_dict) + ) with self._adapter.connection_named(f"MetricFlow_request_{request_id}"): # returns a Tuple[AdapterResponse, agate.Table] but the decorator converts it to Any result = self._adapter.execute(sql=stmt, auto_begin=True, fetch=True) @@ -167,24 +169,24 @@ def query( def execute( self, stmt: str, - sql_bind_parameters: SqlBindParameters = SqlBindParameters(), + sql_bind_parameter_set: SqlBindParameterSet = SqlBindParameterSet(), ) -> None: """Execute a SQL statement. No result will be returned. Args: stmt: The SQL query statement to run. This should not produce output. - sql_bind_parameters: The parameter replacement mapping for filling in + sql_bind_parameter_set: The parameter replacement mapping for filling in concrete values for SQL query parameters. """ - if sql_bind_parameters.param_dict: + if sql_bind_parameter_set.param_dict: raise SqlBindParametersNotSupportedError( f"Invalid execute statement - we do not support execute commands with bind parameters through dbt " - f"adapters! Bind params: {SqlBindParameters.param_dict}" + f"adapters! Bind params: {SqlBindParameterSet.param_dict}" ) start = time.time() request_id = SqlRequestId(f"mf_rid__{random_id()}") logger.info( - LazyFormat("Running execute() statement", statement=stmt, param_dict=sql_bind_parameters.param_dict) + LazyFormat("Running execute() statement", statement=stmt, param_dict=sql_bind_parameter_set.param_dict) ) with self._adapter.connection_named(f"MetricFlow_request_{request_id}"): result = self._adapter.execute(stmt, auto_begin=True, fetch=False) @@ -199,7 +201,7 @@ def execute( def dry_run( self, stmt: str, - sql_bind_parameters: SqlBindParameters = SqlBindParameters(), + sql_bind_parameter_set: SqlBindParameterSet = SqlBindParameterSet(), ) -> None: """Dry run statement; checks that the 'stmt' is queryable. Returns None on success. @@ -207,11 +209,11 @@ def dry_run( Args: stmt: The SQL query statement to dry run. - sql_bind_parameters: The parameter replacement mapping for filling in + sql_bind_parameter_set: The parameter replacement mapping for filling in concrete values for SQL query parameters. """ start = time.time() - logger.info(LazyFormat("Running dry run", statement=stmt, param_dict=sql_bind_parameters.param_dict)) + logger.info(LazyFormat("Running dry run", statement=stmt, param_dict=sql_bind_parameter_set.param_dict)) request_id = SqlRequestId(f"mf_rid__{random_id()}") connection_name = f"MetricFlow_dry_run_request_{request_id}" # TODO - consolidate to self._adapter.validate_sql() when all implementations will work from within MetricFlow diff --git a/metricflow-semantics/metricflow_semantics/specs/where_filter/where_filter_spec.py b/metricflow-semantics/metricflow_semantics/specs/where_filter/where_filter_spec.py index f449bc2695..7b36234c07 100644 --- a/metricflow-semantics/metricflow_semantics/specs/where_filter/where_filter_spec.py +++ b/metricflow-semantics/metricflow_semantics/specs/where_filter/where_filter_spec.py @@ -11,7 +11,7 @@ from metricflow_semantics.model.semantics.linkable_element import LinkableElement, LinkableElementUnion from metricflow_semantics.specs.instance_spec import LinkableInstanceSpec from metricflow_semantics.specs.linkable_spec_set import LinkableSpecSet -from metricflow_semantics.sql.sql_bind_parameters import SqlBindParameters +from metricflow_semantics.sql.sql_bind_parameters import SqlBindParameterSet @dataclass(frozen=True) @@ -26,7 +26,7 @@ class WhereFilterSpec(Mergeable, SerializableDataclass): WhereFilterSpec( where_sql="listing__country == 'US'", - bind_parameters: SqlBindParameters(), + bind_parameter_set: SqlBindParameters(), linkable_specs: ( DimensionSpec( element_name='country', @@ -42,10 +42,10 @@ class WhereFilterSpec(Mergeable, SerializableDataclass): ) """ - # Debating whether where_sql / bind_parameters belongs here. where_sql may become dialect specific if we introduce + # Debating whether where_sql / bind_parameter_set belongs here. where_sql may become dialect specific if we introduce # quoted identifiers later. where_sql: str - bind_parameters: SqlBindParameters + bind_parameters: SqlBindParameterSet linkable_element_unions: Tuple[LinkableElementUnion, ...] linkable_spec_set: LinkableSpecSet @@ -83,7 +83,7 @@ def empty_instance(cls) -> WhereFilterSpec: # line with other cases of Mergeable. return WhereFilterSpec( where_sql="TRUE", - bind_parameters=SqlBindParameters(), + bind_parameters=SqlBindParameterSet(), linkable_spec_set=LinkableSpecSet(), linkable_element_unions=(), ) diff --git a/metricflow-semantics/metricflow_semantics/specs/where_filter/where_filter_transform.py b/metricflow-semantics/metricflow_semantics/specs/where_filter/where_filter_transform.py index e1e45eb675..2a9f303628 100644 --- a/metricflow-semantics/metricflow_semantics/specs/where_filter/where_filter_transform.py +++ b/metricflow-semantics/metricflow_semantics/specs/where_filter/where_filter_transform.py @@ -20,7 +20,7 @@ from metricflow_semantics.specs.where_filter.where_filter_metric import WhereFilterMetricFactory from metricflow_semantics.specs.where_filter.where_filter_spec import WhereFilterSpec from metricflow_semantics.specs.where_filter.where_filter_time_dimension import WhereFilterTimeDimensionFactory -from metricflow_semantics.sql.sql_bind_parameters import SqlBindParameters +from metricflow_semantics.sql.sql_bind_parameters import SqlBindParameterSet logger = logging.getLogger(__name__) @@ -108,7 +108,7 @@ def create_from_where_filter_intersection( # noqa: D102 filter_specs.append( WhereFilterSpec( where_sql=where_sql, - bind_parameters=SqlBindParameters(), + bind_parameters=SqlBindParameterSet(), linkable_spec_set=LinkableSpecSet.create_from_specs(rendered_specs), linkable_element_unions=tuple(linkable_element.as_union for linkable_element in linkable_elements), ) diff --git a/metricflow-semantics/metricflow_semantics/sql/sql_bind_parameters.py b/metricflow-semantics/metricflow_semantics/sql/sql_bind_parameters.py index 05b7ee7e2d..346a8ca826 100644 --- a/metricflow-semantics/metricflow_semantics/sql/sql_bind_parameters.py +++ b/metricflow-semantics/metricflow_semantics/sql/sql_bind_parameters.py @@ -74,7 +74,7 @@ class SqlBindParameter(SerializableDataclass): # noqa: D101 @dataclass(frozen=True) -class SqlBindParameters(SerializableDataclass): +class SqlBindParameterSet(SerializableDataclass): """Helps to build execution parameters during SQL query rendering. These can be used as per https://docs.sqlalchemy.org/en/14/core/tutorial.html#using-textual-sql @@ -83,7 +83,7 @@ class SqlBindParameters(SerializableDataclass): # Using tuples for immutability as dicts are not. param_items: Tuple[SqlBindParameter, ...] = () - def combine(self, additional_params: SqlBindParameters) -> SqlBindParameters: + def combine(self, additional_params: SqlBindParameterSet) -> SqlBindParameterSet: """Create a new set of bind parameters that includes parameters from this and additional_params.""" if len(self.param_items) == 0: return additional_params @@ -108,7 +108,7 @@ def combine(self, additional_params: SqlBindParameters) -> SqlBindParameters: new_items.append(item) included_keys.add(item.key) - return SqlBindParameters(param_items=tuple(new_items)) + return SqlBindParameterSet(param_items=tuple(new_items)) @property def param_dict(self) -> OrderedDict[str, SqlColumnType]: @@ -119,8 +119,8 @@ def param_dict(self) -> OrderedDict[str, SqlColumnType]: return param_dict @staticmethod - def create_from_dict(param_dict: Mapping[str, SqlColumnType]) -> SqlBindParameters: # noqa: D102 - return SqlBindParameters( + def create_from_dict(param_dict: Mapping[str, SqlColumnType]) -> SqlBindParameterSet: # noqa: D102 + return SqlBindParameterSet( tuple( SqlBindParameter(key=key, value=SqlBindParameterValue.create_from_sql_column_type(value)) for key, value in param_dict.items() @@ -128,4 +128,4 @@ def create_from_dict(param_dict: Mapping[str, SqlColumnType]) -> SqlBindParamete ) def __eq__(self, other: Any) -> bool: # type: ignore # noqa: D105 - return isinstance(other, SqlBindParameters) and self.param_dict == other.param_dict + return isinstance(other, SqlBindParameterSet) and self.param_dict == other.param_dict diff --git a/metricflow-semantics/tests_metricflow_semantics/specs/test_spec_serialization.py b/metricflow-semantics/tests_metricflow_semantics/specs/test_spec_serialization.py index 294cbc1c2b..ad1d56bb01 100644 --- a/metricflow-semantics/tests_metricflow_semantics/specs/test_spec_serialization.py +++ b/metricflow-semantics/tests_metricflow_semantics/specs/test_spec_serialization.py @@ -16,14 +16,14 @@ from metricflow_semantics.specs.linkable_spec_set import LinkableSpecSet from metricflow_semantics.specs.time_dimension_spec import TimeDimensionSpec from metricflow_semantics.specs.where_filter.where_filter_spec import WhereFilterSpec -from metricflow_semantics.sql.sql_bind_parameters import SqlBindParameter, SqlBindParameters, SqlBindParameterValue +from metricflow_semantics.sql.sql_bind_parameters import SqlBindParameter, SqlBindParameterSet, SqlBindParameterValue from metricflow_semantics.time.granularity import ExpandedTimeGranularity def test_where_filter_spec_serialization() -> None: # noqa: D103 where_filter_spec = WhereFilterSpec( where_sql="where_sql", - bind_parameters=SqlBindParameters( + bind_parameters=SqlBindParameterSet( param_items=(SqlBindParameter(key="key", value=SqlBindParameterValue(str_value="str_value")),) ), linkable_element_unions=( diff --git a/metricflow-semantics/tests_metricflow_semantics/sql/test_bind_parameter_serialization.py b/metricflow-semantics/tests_metricflow_semantics/sql/test_bind_parameter_serialization.py index 7d4d8b3953..d4ccb8e34b 100644 --- a/metricflow-semantics/tests_metricflow_semantics/sql/test_bind_parameter_serialization.py +++ b/metricflow-semantics/tests_metricflow_semantics/sql/test_bind_parameter_serialization.py @@ -2,7 +2,7 @@ import pytest from dbt_semantic_interfaces.dataclass_serialization import DataClassDeserializer, DataclassSerializer -from metricflow_semantics.sql.sql_bind_parameters import SqlBindParameter, SqlBindParameters, SqlBindParameterValue +from metricflow_semantics.sql.sql_bind_parameters import SqlBindParameter, SqlBindParameterSet, SqlBindParameterValue @pytest.fixture @@ -19,14 +19,14 @@ def test_serialization( # noqa: D103 serializer: DataclassSerializer, deserializer: DataClassDeserializer, ) -> None: - bind_parameters = SqlBindParameters( + bind_parameter_set = SqlBindParameterSet( param_items=( SqlBindParameter("key0", SqlBindParameterValue.create_from_sql_column_type("value0")), SqlBindParameter("key1", SqlBindParameterValue.create_from_sql_column_type("value1")), ) ) - serialized_obj = serializer.pydantic_serialize(bind_parameters) + serialized_obj = serializer.pydantic_serialize(bind_parameter_set) deserialized_obj = deserializer.pydantic_deserialize( - dataclass_type=SqlBindParameters, serialized_obj=serialized_obj + dataclass_type=SqlBindParameterSet, serialized_obj=serialized_obj ) - assert bind_parameters == deserialized_obj + assert bind_parameter_set == deserialized_obj diff --git a/metricflow/engine/metricflow_engine.py b/metricflow/engine/metricflow_engine.py index fa572f7215..83099dd28f 100644 --- a/metricflow/engine/metricflow_engine.py +++ b/metricflow/engine/metricflow_engine.py @@ -204,7 +204,7 @@ def rendered_sql_without_descriptions(self) -> SqlQuery: sql_query.sql_query.split("\n"), ) ), - bind_parameters=sql_query.bind_parameters, + bind_parameter_set=sql_query.bind_parameter_set, ) @property diff --git a/metricflow/execution/dataflow_to_execution.py b/metricflow/execution/dataflow_to_execution.py index 473a809d5d..388b0ee1c8 100644 --- a/metricflow/execution/dataflow_to_execution.py +++ b/metricflow/execution/dataflow_to_execution.py @@ -85,7 +85,7 @@ def visit_write_to_result_data_table_node(self, node: WriteToResultDataTableNode leaf_tasks=( SelectSqlQueryToDataTableTask.create( sql_client=self._sql_client, - sql_query=SqlQuery(render_sql_result.sql, render_sql_result.bind_parameters), + sql_query=SqlQuery(render_sql_result.sql, render_sql_result.bind_parameter_set), ), ) ) @@ -105,7 +105,7 @@ def visit_write_to_result_table_node(self, node: WriteToResultTableNode) -> Conv sql_client=self._sql_client, sql_query=SqlQuery( sql_query=render_sql_result.sql, - bind_parameters=render_sql_result.bind_parameters, + bind_parameter_set=render_sql_result.bind_parameter_set, ), output_table=node.output_sql_table, ), diff --git a/metricflow/execution/execution_plan.py b/metricflow/execution/execution_plan.py index fa3dd2c106..cae16bad9f 100644 --- a/metricflow/execution/execution_plan.py +++ b/metricflow/execution/execution_plan.py @@ -9,7 +9,7 @@ from metricflow_semantics.dag.id_prefix import IdPrefix, StaticIdPrefix from metricflow_semantics.dag.mf_dag import DagId, DagNode, DisplayedProperty, MetricFlowDag, NodeId from metricflow_semantics.mf_logging.lazy_formattable import LazyFormat -from metricflow_semantics.sql.sql_bind_parameters import SqlBindParameters +from metricflow_semantics.sql.sql_bind_parameters import SqlBindParameterSet from metricflow_semantics.sql.sql_table import SqlTable from metricflow_semantics.visitor import Visitable @@ -49,7 +49,7 @@ class SqlQuery: # This field will be renamed as it is confusing given the class name. sql_query: str - bind_parameters: SqlBindParameters + bind_parameter_set: SqlBindParameterSet @dataclass(frozen=True) @@ -69,7 +69,7 @@ class TaskExecutionResult: # If the task was an SQL query, it's stored here sql: Optional[str] = None - bind_params: Optional[SqlBindParameters] = None + bind_params: Optional[SqlBindParameterSet] = None # If the task produces a data_table as a result, it's stored here. df: Optional[MetricFlowDataTable] = None @@ -120,7 +120,7 @@ def execute(self) -> TaskExecutionResult: # noqa: D102 df = self.sql_client.query( sql_query.sql_query, - sql_bind_parameters=sql_query.bind_parameters, + sql_bind_parameter_set=sql_query.bind_parameter_set, ) end_time = time.time() @@ -128,7 +128,7 @@ def execute(self) -> TaskExecutionResult: # noqa: D102 start_time=start_time, end_time=end_time, sql=sql_query.sql_query, - bind_params=sql_query.bind_parameters, + bind_params=sql_query.bind_parameter_set, df=df, ) @@ -180,7 +180,7 @@ def displayed_properties(self) -> Sequence[DisplayedProperty]: # noqa: D102 return tuple(super().displayed_properties) + ( DisplayedProperty(key="sql_query", value=sql_query.sql_query), DisplayedProperty(key="output_table", value=self.output_table), - DisplayedProperty(key="bind_parameters", value=sql_query.bind_parameters), + DisplayedProperty(key="bind_parameter_set", value=sql_query.bind_parameter_set), ) def execute(self) -> TaskExecutionResult: # noqa: D102 @@ -192,7 +192,7 @@ def execute(self) -> TaskExecutionResult: # noqa: D102 logger.debug(LazyFormat(lambda: f"Creating table {self.output_table} using a query")) self.sql_client.execute( sql_query.sql_query, - sql_bind_parameters=sql_query.bind_parameters, + sql_bind_parameter_set=sql_query.bind_parameter_set, ) end_time = time.time() diff --git a/metricflow/plan_conversion/dataflow_to_sql.py b/metricflow/plan_conversion/dataflow_to_sql.py index d0c8a97546..fa990f071f 100644 --- a/metricflow/plan_conversion/dataflow_to_sql.py +++ b/metricflow/plan_conversion/dataflow_to_sql.py @@ -855,7 +855,7 @@ def visit_where_constraint_node(self, node: WhereConstraintNode) -> SqlDataSet: used_columns=tuple( column_association.column_name for column_association in column_associations_in_where_sql ), - bind_parameters=node.where.bind_parameters, + bind_parameter_set=node.where.bind_parameters, ), ), ) diff --git a/metricflow/protocols/sql_client.py b/metricflow/protocols/sql_client.py index 5eababde8e..0feb327d16 100644 --- a/metricflow/protocols/sql_client.py +++ b/metricflow/protocols/sql_client.py @@ -6,7 +6,7 @@ from dbt_semantic_interfaces.enum_extension import assert_values_exhausted from dbt_semantic_interfaces.type_enums.time_granularity import TimeGranularity -from metricflow_semantics.sql.sql_bind_parameters import SqlBindParameters +from metricflow_semantics.sql.sql_bind_parameters import SqlBindParameterSet from metricflow.data_table.mf_table import MetricFlowDataTable from metricflow.sql.render.sql_plan_renderer import SqlQueryPlanRenderer @@ -79,7 +79,7 @@ def sql_query_plan_renderer(self) -> SqlQueryPlanRenderer: def query( self, stmt: str, - sql_bind_parameters: SqlBindParameters = SqlBindParameters(), + sql_bind_parameter_set: SqlBindParameterSet = SqlBindParameterSet(), ) -> MetricFlowDataTable: """Base query method, upon execution will run a query that returns a pandas DataTable.""" raise NotImplementedError @@ -88,7 +88,7 @@ def query( def execute( self, stmt: str, - sql_bind_parameters: SqlBindParameters = SqlBindParameters(), + sql_bind_parameter_set: SqlBindParameterSet = SqlBindParameterSet(), ) -> None: """Base execute method.""" raise NotImplementedError @@ -97,7 +97,7 @@ def execute( def dry_run( self, stmt: str, - sql_bind_parameters: SqlBindParameters = SqlBindParameters(), + sql_bind_parameter_set: SqlBindParameterSet = SqlBindParameterSet(), ) -> None: """Base dry_run method.""" raise NotImplementedError diff --git a/metricflow/sql/render/big_query.py b/metricflow/sql/render/big_query.py index cd08a06722..bef39e99fc 100644 --- a/metricflow/sql/render/big_query.py +++ b/metricflow/sql/render/big_query.py @@ -7,7 +7,7 @@ from dbt_semantic_interfaces.type_enums.date_part import DatePart from dbt_semantic_interfaces.type_enums.time_granularity import TimeGranularity from metricflow_semantics.errors.error_classes import UnsupportedEngineFeatureError -from metricflow_semantics.sql.sql_bind_parameters import SqlBindParameters +from metricflow_semantics.sql.sql_bind_parameters import SqlBindParameterSet from typing_extensions import override from metricflow.protocols.sql_client import SqlEngine @@ -71,7 +71,7 @@ def render_group_by_expr(self, group_by_column: SqlSelectColumn) -> SqlExpressio """ return SqlExpressionRenderResult( sql=group_by_column.column_alias, - bind_parameters=group_by_column.expr.bind_parameters, + bind_parameter_set=group_by_column.expr.bind_parameter_set, ) @override @@ -84,14 +84,14 @@ def visit_percentile_expr(self, node: SqlPercentileExpression) -> SqlExpressionR """ if node.percentile_args.function_type is SqlPercentileFunctionType.APPROXIMATE_CONTINUOUS: arg_rendered = self.render_sql_expr(node.order_by_arg) - params = arg_rendered.bind_parameters + params = arg_rendered.bind_parameter_set percentile = node.percentile_args.percentile fraction = Fraction(percentile).limit_denominator() return SqlExpressionRenderResult( sql=f"APPROX_QUANTILES({arg_rendered.sql}, {fraction.denominator})[OFFSET({fraction.numerator})]", - bind_parameters=params, + bind_parameter_set=params, ) elif ( node.percentile_args.function_type is SqlPercentileFunctionType.APPROXIMATE_DISCRETE @@ -117,7 +117,7 @@ def visit_cast_to_timestamp_expr(self, node: SqlCastToTimestampExpression) -> Sq arg_rendered = self.render_sql_expr(node.arg) return SqlExpressionRenderResult( sql=f"CAST({arg_rendered.sql} AS {self.timestamp_data_type})", - bind_parameters=arg_rendered.bind_parameters, + bind_parameter_set=arg_rendered.bind_parameter_set, ) @override @@ -133,7 +133,7 @@ def visit_date_trunc_expr(self, node: SqlDateTruncExpression) -> SqlExpressionRe return SqlExpressionRenderResult( sql=f"DATETIME_TRUNC({arg_rendered.sql}, {prefix}{node.time_granularity.value})", - bind_parameters=arg_rendered.bind_parameters, + bind_parameter_set=arg_rendered.bind_parameter_set, ) @override @@ -163,7 +163,7 @@ def visit_extract_expr(self, node: SqlExtractExpression) -> SqlExpressionRenderR return SqlExpressionRenderResult( sql=case_expr, - bind_parameters=extract_rendering_result.bind_parameters, + bind_parameter_set=extract_rendering_result.bind_parameter_set, ) @override @@ -173,14 +173,14 @@ def visit_time_delta_expr(self, node: SqlSubtractTimeIntervalExpression) -> SqlE return SqlExpressionRenderResult( sql=f"DATE_SUB(CAST({column.sql} AS {self.timestamp_data_type}), INTERVAL {node.count} {node.granularity.value})", - bind_parameters=column.bind_parameters, + bind_parameter_set=column.bind_parameter_set, ) @override def visit_generate_uuid_expr(self, node: SqlGenerateUuidExpression) -> SqlExpressionRenderResult: return SqlExpressionRenderResult( sql="GENERATE_UUID()", - bind_parameters=SqlBindParameters(), + bind_parameter_set=SqlBindParameterSet(), ) diff --git a/metricflow/sql/render/databricks.py b/metricflow/sql/render/databricks.py index a4f0510062..5aa98000a7 100644 --- a/metricflow/sql/render/databricks.py +++ b/metricflow/sql/render/databricks.py @@ -38,7 +38,7 @@ def render_date_part(self, date_part: DatePart) -> str: def visit_percentile_expr(self, node: SqlPercentileExpression) -> SqlExpressionRenderResult: """Render a percentile expression for Databricks.""" arg_rendered = self.render_sql_expr(node.order_by_arg) - params = arg_rendered.bind_parameters + params = arg_rendered.bind_parameter_set percentile = node.percentile_args.percentile if node.percentile_args.function_type is SqlPercentileFunctionType.CONTINUOUS: @@ -57,14 +57,14 @@ def visit_percentile_expr(self, node: SqlPercentileExpression) -> SqlExpressionR elif node.percentile_args.function_type is SqlPercentileFunctionType.APPROXIMATE_DISCRETE: return SqlExpressionRenderResult( sql=f"APPROX_PERCENTILE({arg_rendered.sql}, {percentile})", - bind_parameters=params, + bind_parameter_set=params, ) else: assert_values_exhausted(node.percentile_args.function_type) return SqlExpressionRenderResult( sql=f"{function_str}({percentile}) WITHIN GROUP (ORDER BY ({arg_rendered.sql}))", - bind_parameters=params, + bind_parameter_set=params, ) diff --git a/metricflow/sql/render/duckdb_renderer.py b/metricflow/sql/render/duckdb_renderer.py index 53094de945..6ff7fb9e7d 100644 --- a/metricflow/sql/render/duckdb_renderer.py +++ b/metricflow/sql/render/duckdb_renderer.py @@ -4,7 +4,7 @@ from dbt_semantic_interfaces.enum_extension import assert_values_exhausted from dbt_semantic_interfaces.type_enums.time_granularity import TimeGranularity -from metricflow_semantics.sql.sql_bind_parameters import SqlBindParameters +from metricflow_semantics.sql.sql_bind_parameters import SqlBindParameterSet from typing_extensions import override from metricflow.protocols.sql_client import SqlEngine @@ -49,21 +49,21 @@ def visit_time_delta_expr(self, node: SqlSubtractTimeIntervalExpression) -> SqlE return SqlExpressionRenderResult( sql=f"{arg_rendered.sql} - INTERVAL {count} {granularity.value}", - bind_parameters=arg_rendered.bind_parameters, + bind_parameter_set=arg_rendered.bind_parameter_set, ) @override def visit_generate_uuid_expr(self, node: SqlGenerateUuidExpression) -> SqlExpressionRenderResult: return SqlExpressionRenderResult( sql="GEN_RANDOM_UUID()", - bind_parameters=SqlBindParameters(), + bind_parameter_set=SqlBindParameterSet(), ) @override def visit_percentile_expr(self, node: SqlPercentileExpression) -> SqlExpressionRenderResult: """Render a percentile expression for DuckDB.""" arg_rendered = self.render_sql_expr(node.order_by_arg) - params = arg_rendered.bind_parameters + params = arg_rendered.bind_parameter_set percentile = node.percentile_args.percentile if node.percentile_args.function_type is SqlPercentileFunctionType.CONTINUOUS: @@ -73,7 +73,7 @@ def visit_percentile_expr(self, node: SqlPercentileExpression) -> SqlExpressionR elif node.percentile_args.function_type is SqlPercentileFunctionType.APPROXIMATE_CONTINUOUS: return SqlExpressionRenderResult( sql=f"approx_quantile({arg_rendered.sql}, {percentile})", - bind_parameters=params, + bind_parameter_set=params, ) elif node.percentile_args.function_type is SqlPercentileFunctionType.APPROXIMATE_DISCRETE: raise RuntimeError( @@ -85,7 +85,7 @@ def visit_percentile_expr(self, node: SqlPercentileExpression) -> SqlExpressionR return SqlExpressionRenderResult( sql=f"{function_str}({percentile}) WITHIN GROUP (ORDER BY ({arg_rendered.sql}))", - bind_parameters=params, + bind_parameter_set=params, ) diff --git a/metricflow/sql/render/expr_renderer.py b/metricflow/sql/render/expr_renderer.py index 71abb10b9d..c6cd709360 100644 --- a/metricflow/sql/render/expr_renderer.py +++ b/metricflow/sql/render/expr_renderer.py @@ -11,7 +11,7 @@ from dbt_semantic_interfaces.type_enums.date_part import DatePart from dbt_semantic_interfaces.type_enums.time_granularity import TimeGranularity from metricflow_semantics.mf_logging.formatting import indent -from metricflow_semantics.sql.sql_bind_parameters import SqlBindParameters +from metricflow_semantics.sql.sql_bind_parameters import SqlBindParameterSet from typing_extensions import override from metricflow.sql.render.rendering_constants import SqlRenderingConstants @@ -53,7 +53,7 @@ class SqlExpressionRenderResult: """The result of rendering an SQL expression tree to a string.""" sql: str - bind_parameters: SqlBindParameters + bind_parameter_set: SqlBindParameterSet class SqlExpressionRenderer(SqlExpressionNodeVisitor[SqlExpressionRenderResult], ABC): @@ -118,7 +118,7 @@ def supported_percentile_function_types(self) -> Collection[SqlPercentileFunctio def visit_string_expr(self, node: SqlStringExpression) -> SqlExpressionRenderResult: """Renders an arbitrary string expression like 1+1=2.""" - return SqlExpressionRenderResult(sql=node.sql_expr, bind_parameters=node.bind_parameters) + return SqlExpressionRenderResult(sql=node.sql_expr, bind_parameter_set=node.bind_parameter_set) def visit_column_reference_expr(self, node: SqlColumnReferenceExpression) -> SqlExpressionRenderResult: """Render a reference to a column in a table like my_table.my_col.""" @@ -128,25 +128,25 @@ def visit_column_reference_expr(self, node: SqlColumnReferenceExpression) -> Sql if node.should_render_table_alias else node.col_ref.column_name ), - bind_parameters=SqlBindParameters(), + bind_parameter_set=SqlBindParameterSet(), ) def visit_column_alias_reference_expr(self, node: SqlColumnAliasReferenceExpression) -> SqlExpressionRenderResult: """Render a reference to a column without a known table alias. e.g. foo.bar vs bar.""" return SqlExpressionRenderResult( sql=node.column_alias, - bind_parameters=SqlBindParameters(), + bind_parameter_set=SqlBindParameterSet(), ) def visit_comparison_expr(self, node: SqlComparisonExpression) -> SqlExpressionRenderResult: """Render a comparison expression like 1 = 2.""" - combined_params = SqlBindParameters() + combined_params = SqlBindParameterSet() left_expr_rendered = self.render_sql_expr(node.left_expr) - combined_params = combined_params.combine(left_expr_rendered.bind_parameters) + combined_params = combined_params.combine(left_expr_rendered.bind_parameter_set) right_expr_rendered = self.render_sql_expr(node.right_expr) - combined_params = combined_params.combine(right_expr_rendered.bind_parameters) + combined_params = combined_params.combine(right_expr_rendered.bind_parameter_set) # To avoid issues with operator precedence, use parenthesis to group the left / right expressions if they # contain operators. @@ -157,22 +157,22 @@ def visit_comparison_expr(self, node: SqlComparisonExpression) -> SqlExpressionR + f" {node.comparison.value} " + (f"({right_expr_rendered.sql})" if node.right_expr.requires_parenthesis else right_expr_rendered.sql) ), - bind_parameters=combined_params, + bind_parameter_set=combined_params, ) def visit_function_expr(self, node: SqlAggregateFunctionExpression) -> SqlExpressionRenderResult: """Render a function call like CONCAT(a, b).""" args_rendered = [self.render_sql_expr(x) for x in node.sql_function_args] - combined_params = SqlBindParameters() + combined_params = SqlBindParameterSet() for arg_rendered in args_rendered: - combined_params = combined_params.combine(arg_rendered.bind_parameters) + combined_params = combined_params.combine(arg_rendered.bind_parameter_set) distinct_prefix = "DISTINCT " if SqlFunction.is_distinct_aggregation(node.sql_function) else "" args_string = ", ".join([x.sql for x in args_rendered]) return SqlExpressionRenderResult( sql=f"{node.sql_function.value}({distinct_prefix}{args_string})", - bind_parameters=combined_params, + bind_parameter_set=combined_params, ) def visit_percentile_expr(self, node: SqlPercentileExpression) -> SqlExpressionRenderResult: @@ -184,25 +184,25 @@ def visit_percentile_expr(self, node: SqlPercentileExpression) -> SqlExpressionR def visit_null_expr(self, node: SqlNullExpression) -> SqlExpressionRenderResult: # noqa: D102 return SqlExpressionRenderResult( sql="NULL", - bind_parameters=SqlBindParameters(), + bind_parameter_set=SqlBindParameterSet(), ) def visit_string_literal_expr(self, node: SqlStringLiteralExpression) -> SqlExpressionRenderResult: # noqa: D102 return SqlExpressionRenderResult( sql=f"'{node.literal_value}'", - bind_parameters=SqlBindParameters(), + bind_parameter_set=SqlBindParameterSet(), ) def visit_logical_expr(self, node: SqlLogicalExpression) -> SqlExpressionRenderResult: # noqa: D102 RenderedExpr = namedtuple("RenderedExpr", ["expr", "requires_parenthesis"]) args_rendered = [RenderedExpr(self.render_sql_expr(x), x.requires_parenthesis) for x in node.args] - combined_parameters = SqlBindParameters() + combined_parameters = SqlBindParameterSet() args_sql: List[str] = [] can_be_rendered_in_one_line = sum(len(x.expr.sql) for x in args_rendered) < 60 for arg_rendered in args_rendered: - combined_parameters.combine(arg_rendered.expr.bind_parameters) + combined_parameters.combine(arg_rendered.expr.bind_parameter_set) arg_sql = self._render_logical_arg( arg_rendered.expr, arg_rendered.requires_parenthesis, render_in_one_line=can_be_rendered_in_one_line ) @@ -212,7 +212,7 @@ def visit_logical_expr(self, node: SqlLogicalExpression) -> SqlExpressionRenderR return SqlExpressionRenderResult( sql=sql, - bind_parameters=combined_parameters, + bind_parameter_set=combined_parameters, ) @staticmethod @@ -259,7 +259,7 @@ def visit_is_null_expr(self, node: SqlIsNullExpression) -> SqlExpressionRenderRe return SqlExpressionRenderResult( sql=f"{arg_rendered.sql} IS NULL" if not node.arg.requires_parenthesis else f"({arg_rendered.sql}) IS NULL", - bind_parameters=arg_rendered.bind_parameters, + bind_parameter_set=arg_rendered.bind_parameter_set, ) def visit_cast_to_timestamp_expr( # noqa: D102 @@ -268,7 +268,7 @@ def visit_cast_to_timestamp_expr( # noqa: D102 arg_rendered = self.render_sql_expr(node.arg) return SqlExpressionRenderResult( sql=f"CAST({arg_rendered.sql} AS {self.timestamp_data_type})", - bind_parameters=arg_rendered.bind_parameters, + bind_parameter_set=arg_rendered.bind_parameter_set, ) def _validate_granularity_for_engine(self, time_granularity: TimeGranularity) -> None: @@ -282,7 +282,7 @@ def visit_date_trunc_expr(self, node: SqlDateTruncExpression) -> SqlExpressionRe return SqlExpressionRenderResult( sql=f"DATE_TRUNC('{node.time_granularity.value}', {arg_rendered.sql})", - bind_parameters=arg_rendered.bind_parameters, + bind_parameter_set=arg_rendered.bind_parameter_set, ) def visit_extract_expr(self, node: SqlExtractExpression) -> SqlExpressionRenderResult: # noqa: D102 @@ -290,7 +290,7 @@ def visit_extract_expr(self, node: SqlExtractExpression) -> SqlExpressionRenderR return SqlExpressionRenderResult( sql=f"EXTRACT({self.render_date_part(node.date_part)} FROM {arg_rendered.sql})", - bind_parameters=arg_rendered.bind_parameters, + bind_parameter_set=arg_rendered.bind_parameter_set, ) def render_date_part(self, date_part: DatePart) -> str: @@ -313,7 +313,7 @@ def visit_time_delta_expr(self, node: SqlSubtractTimeIntervalExpression) -> SqlE count *= 3 return SqlExpressionRenderResult( sql=f"DATEADD({granularity.value}, -{count}, {arg_rendered.sql})", - bind_parameters=arg_rendered.bind_parameters, + bind_parameter_set=arg_rendered.bind_parameter_set, ) def visit_ratio_computation_expr(self, node: SqlRatioComputationExpression) -> SqlExpressionRenderResult: @@ -328,13 +328,13 @@ def visit_ratio_computation_expr(self, node: SqlRatioComputationExpression) -> S numerator_sql = f"CAST({rendered_numerator.sql} AS {self.double_data_type})" denominator_sql = f"CAST(NULLIF({rendered_denominator.sql}, 0) AS {self.double_data_type})" - bind_parameters = SqlBindParameters() - bind_parameters = bind_parameters.combine(rendered_numerator.bind_parameters) - bind_parameters = bind_parameters.combine(rendered_denominator.bind_parameters) + bind_parameter_set = SqlBindParameterSet() + bind_parameter_set = bind_parameter_set.combine(rendered_numerator.bind_parameter_set) + bind_parameter_set = bind_parameter_set.combine(rendered_denominator.bind_parameter_set) return SqlExpressionRenderResult( sql=f"{numerator_sql} / {denominator_sql}", - bind_parameters=bind_parameters, + bind_parameter_set=bind_parameter_set, ) def visit_between_expr(self, node: SqlBetweenExpression) -> SqlExpressionRenderResult: # noqa: D102 @@ -342,14 +342,14 @@ def visit_between_expr(self, node: SqlBetweenExpression) -> SqlExpressionRenderR rendered_start_expr = self.render_sql_expr(node.start_expr) rendered_end_expr = self.render_sql_expr(node.end_expr) - bind_parameters = SqlBindParameters() - bind_parameters = bind_parameters.combine(rendered_column_arg.bind_parameters) - bind_parameters = bind_parameters.combine(rendered_start_expr.bind_parameters) - bind_parameters = bind_parameters.combine(rendered_end_expr.bind_parameters) + bind_parameter_set = SqlBindParameterSet() + bind_parameter_set = bind_parameter_set.combine(rendered_column_arg.bind_parameter_set) + bind_parameter_set = bind_parameter_set.combine(rendered_start_expr.bind_parameter_set) + bind_parameter_set = bind_parameter_set.combine(rendered_end_expr.bind_parameter_set) return SqlExpressionRenderResult( sql=f"{rendered_column_arg.sql} BETWEEN {rendered_start_expr.sql} AND {rendered_end_expr.sql}", - bind_parameters=bind_parameters, + bind_parameter_set=bind_parameter_set, ) def visit_window_function_expr(self, node: SqlWindowFunctionExpression) -> SqlExpressionRenderResult: # noqa: D102 @@ -357,7 +357,7 @@ def visit_window_function_expr(self, node: SqlWindowFunctionExpression) -> SqlEx partition_by_args_rendered = [self.render_sql_expr(x) for x in node.partition_by_args] order_by_args_rendered = {self.render_sql_expr(x.expr): x for x in node.order_by_args} - combined_params = SqlBindParameters() + combined_params = SqlBindParameterSet() args_rendered = [] if sql_function_args_rendered: args_rendered.extend(sql_function_args_rendered) @@ -366,7 +366,7 @@ def visit_window_function_expr(self, node: SqlWindowFunctionExpression) -> SqlEx if order_by_args_rendered: args_rendered.extend(list(order_by_args_rendered.keys())) for arg_rendered in args_rendered: - combined_params = combined_params.combine(arg_rendered.bind_parameters) + combined_params = combined_params.combine(arg_rendered.bind_parameter_set) sql_function_args_string = ", ".join([x.sql for x in sql_function_args_rendered]) window_string_lines: List[str] = [] @@ -407,17 +407,17 @@ def visit_window_function_expr(self, node: SqlWindowFunctionExpression) -> SqlEx if len(window_string_lines) <= 1: return SqlExpressionRenderResult( sql=f"{node.sql_function.value}({sql_function_args_string}) OVER ({window_string})", - bind_parameters=combined_params, + bind_parameter_set=combined_params, ) else: indented_window_string = indent(window_string, indent_prefix=SqlRenderingConstants.INDENT) return SqlExpressionRenderResult( sql=f"{node.sql_function.value}({sql_function_args_string}) OVER (\n{indented_window_string}\n)", - bind_parameters=combined_params, + bind_parameter_set=combined_params, ) def visit_generate_uuid_expr(self, node: SqlGenerateUuidExpression) -> SqlExpressionRenderResult: # noqa: D102 return SqlExpressionRenderResult( sql="UUID()", - bind_parameters=SqlBindParameters(), + bind_parameter_set=SqlBindParameterSet(), ) diff --git a/metricflow/sql/render/postgres.py b/metricflow/sql/render/postgres.py index 04eb2b9bba..36a1b687e9 100644 --- a/metricflow/sql/render/postgres.py +++ b/metricflow/sql/render/postgres.py @@ -5,7 +5,7 @@ from dbt_semantic_interfaces.enum_extension import assert_values_exhausted from dbt_semantic_interfaces.type_enums.time_granularity import TimeGranularity from metricflow_semantics.errors.error_classes import UnsupportedEngineFeatureError -from metricflow_semantics.sql.sql_bind_parameters import SqlBindParameters +from metricflow_semantics.sql.sql_bind_parameters import SqlBindParameterSet from typing_extensions import override from metricflow.protocols.sql_client import SqlEngine @@ -51,21 +51,21 @@ def visit_time_delta_expr(self, node: SqlSubtractTimeIntervalExpression) -> SqlE count *= 3 return SqlExpressionRenderResult( sql=f"{arg_rendered.sql} - MAKE_INTERVAL({granularity.value}s => {count})", - bind_parameters=arg_rendered.bind_parameters, + bind_parameter_set=arg_rendered.bind_parameter_set, ) @override def visit_generate_uuid_expr(self, node: SqlGenerateUuidExpression) -> SqlExpressionRenderResult: return SqlExpressionRenderResult( sql="GEN_RANDOM_UUID()", - bind_parameters=SqlBindParameters(), + bind_parameter_set=SqlBindParameterSet(), ) @override def visit_percentile_expr(self, node: SqlPercentileExpression) -> SqlExpressionRenderResult: """Render a percentile expression for Postgres.""" arg_rendered = self.render_sql_expr(node.order_by_arg) - params = arg_rendered.bind_parameters + params = arg_rendered.bind_parameter_set percentile = node.percentile_args.percentile if node.percentile_args.function_type is SqlPercentileFunctionType.CONTINUOUS: @@ -87,7 +87,7 @@ def visit_percentile_expr(self, node: SqlPercentileExpression) -> SqlExpressionR return SqlExpressionRenderResult( sql=f"{function_str}({percentile}) WITHIN GROUP (ORDER BY ({arg_rendered.sql}))", - bind_parameters=params, + bind_parameter_set=params, ) diff --git a/metricflow/sql/render/redshift.py b/metricflow/sql/render/redshift.py index 8074871169..c916cd3a1c 100644 --- a/metricflow/sql/render/redshift.py +++ b/metricflow/sql/render/redshift.py @@ -5,7 +5,7 @@ from dbt_semantic_interfaces.enum_extension import assert_values_exhausted from dbt_semantic_interfaces.type_enums.date_part import DatePart from metricflow_semantics.errors.error_classes import UnsupportedEngineFeatureError -from metricflow_semantics.sql.sql_bind_parameters import SqlBindParameters +from metricflow_semantics.sql.sql_bind_parameters import SqlBindParameterSet from typing_extensions import override from metricflow.protocols.sql_client import SqlEngine @@ -43,7 +43,7 @@ def supported_percentile_function_types(self) -> Collection[SqlPercentileFunctio def visit_percentile_expr(self, node: SqlPercentileExpression) -> SqlExpressionRenderResult: """Render a percentile expression for Redshift.""" arg_rendered = self.render_sql_expr(node.order_by_arg) - params = arg_rendered.bind_parameters + params = arg_rendered.bind_parameter_set percentile = node.percentile_args.percentile if node.percentile_args.function_type is SqlPercentileFunctionType.CONTINUOUS: @@ -65,7 +65,7 @@ def visit_percentile_expr(self, node: SqlPercentileExpression) -> SqlExpressionR return SqlExpressionRenderResult( sql=f"{function_str}({percentile}) WITHIN GROUP (ORDER BY ({arg_rendered.sql}))", - bind_parameters=params, + bind_parameter_set=params, ) @override @@ -90,7 +90,7 @@ def visit_extract_expr(self, node: SqlExtractExpression) -> SqlExpressionRenderR return SqlExpressionRenderResult( sql=case_expr, - bind_parameters=extract_rendering_result.bind_parameters, + bind_parameter_set=extract_rendering_result.bind_parameter_set, ) @override @@ -104,7 +104,7 @@ def visit_generate_uuid_expr(self, node: SqlGenerateUuidExpression) -> SqlExpres """ return SqlExpressionRenderResult( sql="CONCAT(CAST(RANDOM()*100000000 AS INT)::VARCHAR,CAST(RANDOM()*100000000 AS INT)::VARCHAR)", - bind_parameters=SqlBindParameters(), + bind_parameter_set=SqlBindParameterSet(), ) diff --git a/metricflow/sql/render/snowflake.py b/metricflow/sql/render/snowflake.py index f623d89691..bc125087e8 100644 --- a/metricflow/sql/render/snowflake.py +++ b/metricflow/sql/render/snowflake.py @@ -5,7 +5,7 @@ from dbt_semantic_interfaces.enum_extension import assert_values_exhausted from dbt_semantic_interfaces.type_enums.date_part import DatePart from metricflow_semantics.errors.error_classes import UnsupportedEngineFeatureError -from metricflow_semantics.sql.sql_bind_parameters import SqlBindParameters +from metricflow_semantics.sql.sql_bind_parameters import SqlBindParameterSet from typing_extensions import override from metricflow.protocols.sql_client import SqlEngine @@ -47,14 +47,14 @@ def render_date_part(self, date_part: DatePart) -> str: def visit_generate_uuid_expr(self, node: SqlGenerateUuidExpression) -> SqlExpressionRenderResult: return SqlExpressionRenderResult( sql="UUID_STRING()", - bind_parameters=SqlBindParameters(), + bind_parameter_set=SqlBindParameterSet(), ) @override def visit_percentile_expr(self, node: SqlPercentileExpression) -> SqlExpressionRenderResult: """Render a percentile expression for Snowflake.""" arg_rendered = self.render_sql_expr(node.order_by_arg) - params = arg_rendered.bind_parameters + params = arg_rendered.bind_parameter_set percentile = node.percentile_args.percentile if node.percentile_args.function_type is SqlPercentileFunctionType.CONTINUOUS: @@ -64,7 +64,7 @@ def visit_percentile_expr(self, node: SqlPercentileExpression) -> SqlExpressionR elif node.percentile_args.function_type is SqlPercentileFunctionType.APPROXIMATE_CONTINUOUS: return SqlExpressionRenderResult( sql=f"APPROX_PERCENTILE({arg_rendered.sql}, {percentile})", - bind_parameters=params, + bind_parameter_set=params, ) elif node.percentile_args.function_type is SqlPercentileFunctionType.APPROXIMATE_DISCRETE: raise UnsupportedEngineFeatureError( @@ -76,7 +76,7 @@ def visit_percentile_expr(self, node: SqlPercentileExpression) -> SqlExpressionR return SqlExpressionRenderResult( sql=f"{function_str}({percentile}) WITHIN GROUP (ORDER BY ({arg_rendered.sql}))", - bind_parameters=params, + bind_parameter_set=params, ) diff --git a/metricflow/sql/render/sql_plan_renderer.py b/metricflow/sql/render/sql_plan_renderer.py index 7907541c1c..39523a7f8a 100644 --- a/metricflow/sql/render/sql_plan_renderer.py +++ b/metricflow/sql/render/sql_plan_renderer.py @@ -8,7 +8,7 @@ from typing import List, Optional, Sequence, Tuple from metricflow_semantics.mf_logging.formatting import indent -from metricflow_semantics.sql.sql_bind_parameters import SqlBindParameters +from metricflow_semantics.sql.sql_bind_parameters import SqlBindParameterSet from metricflow.sql.render.expr_renderer import ( DefaultSqlExpressionRenderer, @@ -36,7 +36,7 @@ class SqlPlanRenderResult: # noqa: D101 # The SQL string that could be run. sql: str # The execution parameters that should be specified along with the SQL str to execute() - bind_parameters: SqlBindParameters + bind_parameter_set: SqlBindParameterSet class SqlQueryPlanRenderer(SqlQueryPlanNodeVisitor[SqlPlanRenderResult], ABC): @@ -76,7 +76,7 @@ def _render_select_columns_section( select_columns: Sequence[SqlSelectColumn], num_parents: int, distinct: bool, - ) -> Tuple[str, SqlBindParameters]: + ) -> Tuple[str, SqlBindParameterSet]: """Convert the select columns into a "SELECT" section. e.g. @@ -87,13 +87,13 @@ def _render_select_columns_section( Returns a tuple of the "SELECT" section as a string and the associated execution parameters. """ - params = SqlBindParameters() + params = SqlBindParameterSet() select_section_lines = ["SELECT DISTINCT" if distinct else "SELECT"] first_column = True for select_column in select_columns: expr_rendered = self.EXPR_RENDERER.render_sql_expr(select_column.expr) # Merge all execution parameters together. Similar pattern follows below. - params = params.combine(expr_rendered.bind_parameters) + params = params.combine(expr_rendered.bind_parameter_set) column_select_str = f"{expr_rendered.sql} AS {select_column.column_alias}" @@ -123,7 +123,7 @@ def _render_select_columns_section( def _render_from_section( self, from_source: SqlQueryPlanNode, from_source_alias: str - ) -> Tuple[str, SqlBindParameters]: + ) -> Tuple[str, SqlBindParameterSet]: """Convert the node into a "FROM" section. e.g. @@ -146,9 +146,9 @@ def _render_from_section( from_section_lines.append(f") {from_source_alias}") from_section = "\n".join(from_section_lines) - return from_section, from_render_result.bind_parameters + return from_section, from_render_result.bind_parameter_set - def _render_joins_section(self, join_descriptions: Sequence[SqlJoinDescription]) -> Tuple[str, SqlBindParameters]: + def _render_joins_section(self, join_descriptions: Sequence[SqlJoinDescription]) -> Tuple[str, SqlBindParameterSet]: """Convert the join descriptions into a "JOIN" section. e.g. @@ -160,18 +160,18 @@ def _render_joins_section(self, join_descriptions: Sequence[SqlJoinDescription]) Returns a tuple of the "JOIN" section as a string and the associated execution parameters. """ - params = SqlBindParameters() + params = SqlBindParameterSet() join_section_lines = [] for join_description in join_descriptions: # Render the source for the join right_source_rendered = self._render_node(join_description.right_source) - params = params.combine(right_source_rendered.bind_parameters) + params = params.combine(right_source_rendered.bind_parameter_set) # Render the on condition for the join on_condition_rendered: Optional[SqlExpressionRenderResult] = None if join_description.on_condition: on_condition_rendered = self.EXPR_RENDERER.render_sql_expr(join_description.on_condition) - params = params.combine(on_condition_rendered.bind_parameters) + params = params.combine(on_condition_rendered.bind_parameter_set) if join_description.right_source.is_table: join_section_lines.append(join_description.join_type.value) @@ -196,7 +196,7 @@ def _render_joins_section(self, join_descriptions: Sequence[SqlJoinDescription]) return "\n".join(join_section_lines), params - def _render_group_by_section(self, group_by_columns: Sequence[SqlSelectColumn]) -> Tuple[str, SqlBindParameters]: + def _render_group_by_section(self, group_by_columns: Sequence[SqlSelectColumn]) -> Tuple[str, SqlBindParameterSet]: """Convert the group by columns into a "GROUP BY" section. e.g. @@ -206,11 +206,11 @@ def _render_group_by_section(self, group_by_columns: Sequence[SqlSelectColumn]) Returns a tuple of the "GROUP BY" section as a string and the associated execution parameters. """ group_by_section_lines: List[str] = [] - params = SqlBindParameters() + params = SqlBindParameterSet() first = True for group_by_column in group_by_columns: group_by_expr_rendered = self.EXPR_RENDERER.render_group_by_expr(group_by_column) - params = params.combine(group_by_expr_rendered.bind_parameters) + params = params.combine(group_by_expr_rendered.bind_parameter_set) if first: first = False group_by_section_lines.append("GROUP BY") @@ -226,7 +226,7 @@ def _render_group_by_section(self, group_by_columns: Sequence[SqlSelectColumn]) def visit_select_statement_node(self, node: SqlSelectStatementNode) -> SqlPlanRenderResult: # noqa: D102 # Keep track of all execution parameters for all expressions - combined_params = SqlBindParameters() + combined_params = SqlBindParameterSet() # Render description section description_section = "\n".join([f"-- {x}" for x in node.description.split("\n") if x]) @@ -253,7 +253,7 @@ def visit_select_statement_node(self, node: SqlSelectStatementNode) -> SqlPlanRe where_section = None if node.where: where_render_result = self.EXPR_RENDERER.render_sql_expr(node.where) - combined_params = combined_params.combine(where_render_result.bind_parameters) + combined_params = combined_params.combine(where_render_result.bind_parameter_set) where_section = f"WHERE {where_render_result.sql}" # Render "ORDER BY" section @@ -263,7 +263,7 @@ def visit_select_statement_node(self, node: SqlSelectStatementNode) -> SqlPlanRe for order_by in node.order_bys: order_by_render_result = self.EXPR_RENDERER.render_sql_expr(order_by.expr) order_by_items.append(order_by_render_result.sql + (" DESC" if order_by.desc else "")) - combined_params = combined_params.combine(order_by_render_result.bind_parameters) + combined_params = combined_params.combine(order_by_render_result.bind_parameter_set) order_by_section = "ORDER BY " + ", ".join(order_by_items) @@ -298,19 +298,19 @@ def visit_select_statement_node(self, node: SqlSelectStatementNode) -> SqlPlanRe return SqlPlanRenderResult( sql="\n".join(sections_to_render), - bind_parameters=combined_params, + bind_parameter_set=combined_params, ) def visit_table_node(self, node: SqlTableNode) -> SqlPlanRenderResult: # noqa: D102 return SqlPlanRenderResult( sql=node.sql_table.sql, - bind_parameters=SqlBindParameters(), + bind_parameter_set=SqlBindParameterSet(), ) def visit_query_from_clause_node(self, node: SqlSelectQueryFromClauseNode) -> SqlPlanRenderResult: # noqa: D102 return SqlPlanRenderResult( sql=node.select_query.rstrip(), - bind_parameters=SqlBindParameters(), + bind_parameter_set=SqlBindParameterSet(), ) def visit_create_table_as_node(self, node: SqlCreateTableAsNode) -> SqlPlanRenderResult: # noqa: D102 @@ -330,7 +330,7 @@ def visit_create_table_as_node(self, node: SqlCreateTableAsNode) -> SqlPlanRende return SqlPlanRenderResult( sql=sql, - bind_parameters=inner_sql_render_result.bind_parameters, + bind_parameter_set=inner_sql_render_result.bind_parameter_set, ) @property diff --git a/metricflow/sql/render/trino.py b/metricflow/sql/render/trino.py index 5bfda74fd5..4ecc72282b 100644 --- a/metricflow/sql/render/trino.py +++ b/metricflow/sql/render/trino.py @@ -6,7 +6,7 @@ from dbt_semantic_interfaces.enum_extension import assert_values_exhausted from dbt_semantic_interfaces.type_enums.date_part import DatePart from dbt_semantic_interfaces.type_enums.time_granularity import TimeGranularity -from metricflow_semantics.sql.sql_bind_parameters import SqlBindParameters +from metricflow_semantics.sql.sql_bind_parameters import SqlBindParameterSet from typing_extensions import override from metricflow.protocols.sql_client import SqlEngine @@ -41,7 +41,7 @@ def supported_percentile_function_types(self) -> Collection[SqlPercentileFunctio def visit_generate_uuid_expr(self, node: SqlGenerateUuidExpression) -> SqlExpressionRenderResult: return SqlExpressionRenderResult( sql="uuid()", - bind_parameters=SqlBindParameters(), + bind_parameter_set=SqlBindParameterSet(), ) @override @@ -56,20 +56,20 @@ def visit_time_delta_expr(self, node: SqlSubtractTimeIntervalExpression) -> SqlE count *= 3 return SqlExpressionRenderResult( sql=f"DATE_ADD('{granularity.value}', -{count}, {arg_rendered.sql})", - bind_parameters=arg_rendered.bind_parameters, + bind_parameter_set=arg_rendered.bind_parameter_set, ) @override def visit_percentile_expr(self, node: SqlPercentileExpression) -> SqlExpressionRenderResult: """Render a percentile expression for Trino.""" arg_rendered = self.render_sql_expr(node.order_by_arg) - params = arg_rendered.bind_parameters + params = arg_rendered.bind_parameter_set percentile = node.percentile_args.percentile if node.percentile_args.function_type is SqlPercentileFunctionType.APPROXIMATE_CONTINUOUS: return SqlExpressionRenderResult( sql=f"approx_percentile({arg_rendered.sql}, {percentile})", - bind_parameters=params, + bind_parameter_set=params, ) elif ( node.percentile_args.function_type is SqlPercentileFunctionType.APPROXIMATE_DISCRETE @@ -90,10 +90,10 @@ def visit_between_expr(self, node: SqlBetweenExpression) -> SqlExpressionRenderR rendered_start_expr = self.render_sql_expr(node.start_expr) rendered_end_expr = self.render_sql_expr(node.end_expr) - bind_parameters = SqlBindParameters() - bind_parameters = bind_parameters.combine(rendered_column_arg.bind_parameters) - bind_parameters = bind_parameters.combine(rendered_start_expr.bind_parameters) - bind_parameters = bind_parameters.combine(rendered_end_expr.bind_parameters) + bind_parameter_set = SqlBindParameterSet() + bind_parameter_set = bind_parameter_set.combine(rendered_column_arg.bind_parameter_set) + bind_parameter_set = bind_parameter_set.combine(rendered_start_expr.bind_parameter_set) + bind_parameter_set = bind_parameter_set.combine(rendered_end_expr.bind_parameter_set) # Handle timestamp literals differently. if parse(rendered_start_expr.sql): @@ -103,7 +103,7 @@ def visit_between_expr(self, node: SqlBetweenExpression) -> SqlExpressionRenderR return SqlExpressionRenderResult( sql=sql, - bind_parameters=bind_parameters, + bind_parameter_set=bind_parameter_set, ) @override diff --git a/metricflow/sql/sql_exprs.py b/metricflow/sql/sql_exprs.py index 7afdf8c026..775e962a8b 100644 --- a/metricflow/sql/sql_exprs.py +++ b/metricflow/sql/sql_exprs.py @@ -17,7 +17,7 @@ from dbt_semantic_interfaces.type_enums.time_granularity import TimeGranularity from metricflow_semantics.dag.id_prefix import IdPrefix, StaticIdPrefix from metricflow_semantics.dag.mf_dag import DagNode, DisplayedProperty -from metricflow_semantics.sql.sql_bind_parameters import SqlBindParameters +from metricflow_semantics.sql.sql_bind_parameters import SqlBindParameterSet from metricflow_semantics.visitor import Visitable, VisitorOutputT from typing_extensions import override @@ -47,13 +47,13 @@ def accept(self, visitor: SqlExpressionNodeVisitor[VisitorOutputT]) -> VisitorOu pass @property - def bind_parameters(self) -> SqlBindParameters: + def bind_parameter_set(self) -> SqlBindParameterSet: """Execution parameters when running a query containing this expression. * See: https://docs.sqlalchemy.org/en/14/core/tutorial.html#using-textual-sql * Generally only defined for string expressions. """ - return SqlBindParameters() + return SqlBindParameterSet() @property def as_column_reference_expression(self) -> Optional[SqlColumnReferenceExpression]: @@ -239,7 +239,7 @@ class SqlStringExpression(SqlExpressionNode): Attributes: sql_expr: The SQL in string form. - bind_parameters: See SqlExpressionNode.bind_parameters + bind_parameter_set: See SqlExpressionNode.bind_parameter_set requires_parenthesis: Whether this should be rendered with () if nested in another expression. used_columns: If set, indicates that the expression represented by the string only uses those columns. e.g. sql_expr="a + b", used_columns=["a", "b"]. This may be used by optimizers, and if specified, it must be @@ -247,21 +247,21 @@ class SqlStringExpression(SqlExpressionNode): """ sql_expr: str - bind_parameters: SqlBindParameters = SqlBindParameters() + bind_parameter_set: SqlBindParameterSet = SqlBindParameterSet() requires_parenthesis: bool = True used_columns: Optional[Tuple[str, ...]] = None @staticmethod def create( # noqa: D102 sql_expr: str, - bind_parameters: SqlBindParameters = SqlBindParameters(), + bind_parameter_set: SqlBindParameterSet = SqlBindParameterSet(), requires_parenthesis: bool = True, used_columns: Optional[Tuple[str, ...]] = None, ) -> SqlStringExpression: return SqlStringExpression( parent_nodes=(), sql_expr=sql_expr, - bind_parameters=bind_parameters, + bind_parameter_set=bind_parameter_set, requires_parenthesis=requires_parenthesis, used_columns=used_columns, ) @@ -305,7 +305,7 @@ def matches(self, other: SqlExpressionNode) -> bool: # noqa: D102 return ( self.sql_expr == other.sql_expr and self.used_columns == other.used_columns - and self.bind_parameters == other.bind_parameters + and self.bind_parameter_set == other.bind_parameter_set ) @property @@ -344,8 +344,8 @@ def requires_parenthesis(self) -> bool: # noqa: D102 return False @property - def bind_parameters(self) -> SqlBindParameters: # noqa: D102 - return SqlBindParameters() + def bind_parameter_set(self) -> SqlBindParameterSet: # noqa: D102 + return SqlBindParameterSet() def __repr__(self) -> str: # noqa: D105 return f"{self.__class__.__name__}(node_id={self.node_id}, literal_value={self.literal_value})" @@ -1630,8 +1630,8 @@ def requires_parenthesis(self) -> bool: # noqa: D102 return False @property - def bind_parameters(self) -> SqlBindParameters: # noqa: D102 - return SqlBindParameters() + def bind_parameter_set(self) -> SqlBindParameterSet: # noqa: D102 + return SqlBindParameterSet() def __repr__(self) -> str: # noqa: D105 return f"{self.__class__.__name__}(node_id={self.node_id})" diff --git a/metricflow/validation/data_warehouse_model_validator.py b/metricflow/validation/data_warehouse_model_validator.py index ffc0e89190..58b1055d38 100644 --- a/metricflow/validation/data_warehouse_model_validator.py +++ b/metricflow/validation/data_warehouse_model_validator.py @@ -34,7 +34,7 @@ from metricflow_semantics.specs.instance_spec import LinkableInstanceSpec from metricflow_semantics.specs.measure_spec import MeasureSpec from metricflow_semantics.specs.spec_set import InstanceSpecSet -from metricflow_semantics.sql.sql_bind_parameters import SqlBindParameters +from metricflow_semantics.sql.sql_bind_parameters import SqlBindParameterSet from metricflow.dataflow.builder.node_data_set import DataflowPlanNodeOutputDataSetResolver from metricflow.dataflow.builder.source_node import SourceNodeBuilder @@ -85,7 +85,7 @@ def __init__(self, manifest: SemanticManifest) -> None: # noqa: D107 class DataWarehouseValidationTask: """Dataclass for defining a task to be used in the DataWarehouseModelValidator.""" - query_and_params_callable: Callable[[], Tuple[str, SqlBindParameters]] + query_and_params_callable: Callable[[], Tuple[str, SqlBindParameterSet]] error_message: str description: str context: Optional[ValidationContext] = None @@ -123,8 +123,8 @@ def _semantic_model_nodes( @staticmethod def renderize( sql_client: SqlClient, plan_converter: DataflowToSqlQueryPlanConverter, plan_id: str, nodes: FilterElementsNode - ) -> Tuple[str, SqlBindParameters]: - """Generates a sql query plan and returns the rendered sql and bind_parameters.""" + ) -> Tuple[str, SqlBindParameterSet]: + """Generates a sql query plan and returns the rendered sql and bind_parameter_set.""" conversion_result = plan_converter.convert_to_sql_query_plan( sql_engine_type=sql_client.sql_engine_type, dataflow_plan_node=nodes, @@ -132,7 +132,7 @@ def renderize( sql_plan = conversion_result.sql_plan rendered_plan = sql_client.sql_query_plan_renderer.render_sql_query_plan(sql_plan) - return (rendered_plan.sql, rendered_plan.bind_parameters) + return (rendered_plan.sql, rendered_plan.bind_parameter_set) @classmethod def gen_semantic_model_tasks( @@ -148,7 +148,7 @@ def gen_semantic_model_tasks( query_and_params_callable=partial( lambda name=semantic_model.node_relation.relation_name: ( f"SELECT * FROM {name}", - SqlBindParameters(), + SqlBindParameterSet(), ) ), context=SemanticModelContext( @@ -450,11 +450,11 @@ def gen_measure_tasks( @staticmethod def _gen_explain_query_task_query_and_params( mf_engine: MetricFlowEngine, mf_request: MetricFlowQueryRequest - ) -> Tuple[str, SqlBindParameters]: + ) -> Tuple[str, SqlBindParameterSet]: explain_result: MetricFlowExplainResult = mf_engine.explain(mf_request=mf_request) return ( explain_result.rendered_sql_without_descriptions.sql_query, - explain_result.rendered_sql_without_descriptions.bind_parameters, + explain_result.rendered_sql_without_descriptions.bind_parameter_set, ) @classmethod @@ -569,7 +569,7 @@ def run_tasks( break try: (query_string, query_params) = task.query_and_params_callable() - self._sql_client.dry_run(stmt=query_string, sql_bind_parameters=query_params) + self._sql_client.dry_run(stmt=query_string, sql_bind_parameter_set=query_params) except Exception as e: issues.append( ValidationError( diff --git a/tests_metricflow/dataflow/optimizer/test_predicate_pushdown_optimizer.py b/tests_metricflow/dataflow/optimizer/test_predicate_pushdown_optimizer.py index d1851855d3..e9aa84fb0a 100644 --- a/tests_metricflow/dataflow/optimizer/test_predicate_pushdown_optimizer.py +++ b/tests_metricflow/dataflow/optimizer/test_predicate_pushdown_optimizer.py @@ -10,7 +10,7 @@ from metricflow_semantics.specs.linkable_spec_set import LinkableSpecSet from metricflow_semantics.specs.query_spec import MetricFlowQuerySpec from metricflow_semantics.specs.where_filter.where_filter_spec import WhereFilterSpec -from metricflow_semantics.sql.sql_bind_parameters import SqlBindParameters +from metricflow_semantics.sql.sql_bind_parameters import SqlBindParameterSet from metricflow_semantics.test_helpers.config_helpers import MetricFlowTestConfiguration from metricflow_semantics.test_helpers.snapshot_helpers import assert_plan_snapshot_text_equal @@ -78,7 +78,7 @@ def test_branch_state_propagation(branch_state_tracker: PredicatePushdownBranchS where_filter_specs=( WhereFilterSpec( where_sql="x is true", - bind_parameters=SqlBindParameters(), + bind_parameters=SqlBindParameterSet(), linkable_element_unions=(), linkable_spec_set=LinkableSpecSet(), ), @@ -115,13 +115,13 @@ def test_applied_filter_back_propagation(branch_state_tracker: PredicatePushdown base_state = branch_state_tracker.last_pushdown_state where_spec_x_is_true = WhereFilterSpec( where_sql="x is true", - bind_parameters=SqlBindParameters(), + bind_parameters=SqlBindParameterSet(), linkable_element_unions=(), linkable_spec_set=LinkableSpecSet(), ) where_spec_y_is_null = WhereFilterSpec( where_sql="y is null", - bind_parameters=SqlBindParameters(), + bind_parameters=SqlBindParameterSet(), linkable_element_unions=(), linkable_spec_set=LinkableSpecSet(), ) diff --git a/tests_metricflow/execution/test_tasks.py b/tests_metricflow/execution/test_tasks.py index eed19ed646..b6250d4693 100644 --- a/tests_metricflow/execution/test_tasks.py +++ b/tests_metricflow/execution/test_tasks.py @@ -2,7 +2,7 @@ from metricflow_semantics.dag.mf_dag import DagId from metricflow_semantics.random_id import random_id -from metricflow_semantics.sql.sql_bind_parameters import SqlBindParameters +from metricflow_semantics.sql.sql_bind_parameters import SqlBindParameterSet from metricflow_semantics.sql.sql_table import SqlTable from metricflow_semantics.test_helpers.config_helpers import MetricFlowTestConfiguration @@ -19,7 +19,7 @@ def test_read_sql_task(sql_client: SqlClient) -> None: # noqa: D103 - task = SelectSqlQueryToDataTableTask.create(sql_client, SqlQuery("SELECT 1 AS foo", SqlBindParameters())) + task = SelectSqlQueryToDataTableTask.create(sql_client, SqlQuery("SELECT 1 AS foo", SqlBindParameterSet())) execution_plan = ExecutionPlan(leaf_tasks=[task], dag_id=DagId.from_str("plan0")) results = SequentialPlanExecutor().execute_plan(execution_plan) @@ -46,7 +46,7 @@ def test_write_table_task( # noqa: D103 sql_client=sql_client, sql_query=SqlQuery( sql_query=f"CREATE TABLE {output_table.sql} AS SELECT 1 AS foo", - bind_parameters=SqlBindParameters(), + bind_parameter_set=SqlBindParameterSet(), ), output_table=output_table, ) diff --git a/tests_metricflow/plan_conversion/test_dataflow_to_sql_plan.py b/tests_metricflow/plan_conversion/test_dataflow_to_sql_plan.py index cc75e995f7..a7c45edc0f 100644 --- a/tests_metricflow/plan_conversion/test_dataflow_to_sql_plan.py +++ b/tests_metricflow/plan_conversion/test_dataflow_to_sql_plan.py @@ -26,7 +26,7 @@ from metricflow_semantics.specs.spec_set import InstanceSpecSet from metricflow_semantics.specs.time_dimension_spec import TimeDimensionSpec from metricflow_semantics.specs.where_filter.where_filter_spec import WhereFilterSpec -from metricflow_semantics.sql.sql_bind_parameters import SqlBindParameters +from metricflow_semantics.sql.sql_bind_parameters import SqlBindParameterSet from metricflow_semantics.sql.sql_join_type import SqlJoinType from metricflow_semantics.test_helpers.config_helpers import MetricFlowTestConfiguration from metricflow_semantics.test_helpers.metric_time_dimension import MTD_SPEC_DAY @@ -196,7 +196,7 @@ def test_filter_with_where_constraint_node( where_specs=( WhereFilterSpec( where_sql="booking__ds__day = '2020-01-01'", - bind_parameters=SqlBindParameters(), + bind_parameters=SqlBindParameterSet(), linkable_spec_set=LinkableSpecSet( time_dimension_specs=( TimeDimensionSpec( diff --git a/tests_metricflow/sql_clients/test_sql_client.py b/tests_metricflow/sql_clients/test_sql_client.py index b229822c26..068e2a111c 100644 --- a/tests_metricflow/sql_clients/test_sql_client.py +++ b/tests_metricflow/sql_clients/test_sql_client.py @@ -6,7 +6,7 @@ import pytest from dbt_semantic_interfaces.test_utils import as_datetime from metricflow_semantics.random_id import random_id -from metricflow_semantics.sql.sql_bind_parameters import SqlBindParameters +from metricflow_semantics.sql.sql_bind_parameters import SqlBindParameterSet from metricflow_semantics.sql.sql_table import SqlTable from metricflow_semantics.test_helpers.config_helpers import MetricFlowTestConfiguration @@ -122,15 +122,15 @@ def test_dry_run_of_bad_query_raises_exception(sql_client: SqlClient) -> None: def test_update_params_with_same_item() -> None: # noqa: D103 - bind_params0 = SqlBindParameters.create_from_dict({"key": "value"}) - bind_params1 = SqlBindParameters.create_from_dict({"key": "value"}) + bind_params0 = SqlBindParameterSet.create_from_dict({"key": "value"}) + bind_params1 = SqlBindParameterSet.create_from_dict({"key": "value"}) bind_params0.combine(bind_params1) def test_update_params_with_same_key_different_values() -> None: # noqa: D103 - bind_params0 = SqlBindParameters.create_from_dict(({"key": "value0"})) - bind_params1 = SqlBindParameters.create_from_dict(({"key": "value1"})) + bind_params0 = SqlBindParameterSet.create_from_dict(({"key": "value0"})) + bind_params1 = SqlBindParameterSet.create_from_dict(({"key": "value1"})) with pytest.raises(RuntimeError): bind_params0.combine(bind_params1) diff --git a/tests_metricflow/validation/test_data_warehouse_tasks.py b/tests_metricflow/validation/test_data_warehouse_tasks.py index 75ff968a2b..093047ed6f 100644 --- a/tests_metricflow/validation/test_data_warehouse_tasks.py +++ b/tests_metricflow/validation/test_data_warehouse_tasks.py @@ -14,7 +14,7 @@ from dbt_semantic_interfaces.test_utils import semantic_model_with_guaranteed_meta from dbt_semantic_interfaces.transformations.semantic_manifest_transformer import PydanticSemanticManifestTransformer from dbt_semantic_interfaces.type_enums.aggregation_type import AggregationType -from metricflow_semantics.sql.sql_bind_parameters import SqlBindParameters +from metricflow_semantics.sql.sql_bind_parameters import SqlBindParameterSet from metricflow_semantics.test_helpers.config_helpers import MetricFlowTestConfiguration from metricflow.protocols.sql_client import SqlClient @@ -59,8 +59,8 @@ def test_build_semantic_model_tasks( # noqa: D103 def test_task_runner(sql_client: SqlClient, mf_test_configuration: MetricFlowTestConfiguration) -> None: # noqa: D103 dw_validator = DataWarehouseModelValidator(sql_client=sql_client) - def good_query() -> Tuple[str, SqlBindParameters]: - return ("SELECT 'foo' AS foo", SqlBindParameters()) + def good_query() -> Tuple[str, SqlBindParameterSet]: + return ("SELECT 'foo' AS foo", SqlBindParameterSet()) tasks = [ DataWarehouseValidationTask( @@ -71,8 +71,8 @@ def good_query() -> Tuple[str, SqlBindParameters]: issues = dw_validator.run_tasks(tasks=tasks) assert len(issues.all_issues) == 0 - def bad_query() -> Tuple[str, SqlBindParameters]: - return ("SELECT (true) AS col1 FROM doesnt_exist", SqlBindParameters()) + def bad_query() -> Tuple[str, SqlBindParameterSet]: + return ("SELECT (true) AS col1 FROM doesnt_exist", SqlBindParameterSet()) err_msg_bad = "Could not access table 'doesnt_exist' in data warehouse" bad_task = DataWarehouseValidationTask(