diff --git a/dbt-metricflow/dbt_metricflow/cli/main.py b/dbt-metricflow/dbt_metricflow/cli/main.py index 2b8d3300c6..7c248654bc 100644 --- a/dbt-metricflow/dbt_metricflow/cli/main.py +++ b/dbt-metricflow/dbt_metricflow/cli/main.py @@ -298,9 +298,9 @@ def query( if explain: assert explain_result sql = ( - explain_result.rendered_sql_without_descriptions.sql_query + explain_result.sql_statement.without_descriptions.sql if not show_sql_descriptions - else explain_result.rendered_sql.sql_query + else explain_result.sql_statement.sql ) if show_dataflow_plan: click.echo("🔎 Generated Dataflow Plan + SQL (remove --explain to see data):") diff --git a/metricflow/engine/metricflow_engine.py b/metricflow/engine/metricflow_engine.py index d50d792b9a..b4343c527d 100644 --- a/metricflow/engine/metricflow_engine.py +++ b/metricflow/engine/metricflow_engine.py @@ -51,7 +51,7 @@ from metricflow.execution.dataflow_to_execution import ( DataflowToExecutionPlanConverter, ) -from metricflow.execution.execution_plan import ExecutionPlan, SqlQuery +from metricflow.execution.execution_plan import ExecutionPlan, SqlStatement from metricflow.execution.executor import SequentialPlanExecutor from metricflow.plan_conversion.dataflow_to_sql import DataflowToSqlQueryPlanConverter from metricflow.protocols.sql_client import SqlClient @@ -177,35 +177,31 @@ class MetricFlowExplainResult: output_table: Optional[SqlTable] = None @property - def rendered_sql(self) -> SqlQuery: + def sql_statement(self) -> SqlStatement: """Return the SQL query that would be run for the given query.""" execution_plan = self.execution_plan if len(execution_plan.tasks) != 1: raise NotImplementedError( - f"Multiple tasks in the execution plan not yet supported. Got tasks: {execution_plan.tasks}" + str( + LazyFormat( + "Multiple tasks in the execution plan not yet supported.", + tasks=[task.task_id for task in execution_plan.tasks], + ) + ) ) - sql_query = execution_plan.tasks[0].sql_query - if not sql_query: + sql_statement = execution_plan.tasks[0].sql_statement + if not sql_statement: raise NotImplementedError( - f"Execution plan tasks without a SQL query not yet supported. Got tasks: {execution_plan.tasks}" + str( + LazyFormat( + "Execution plan tasks without a SQL statement are not yet supported.", + tasks=[task.task_id for task in execution_plan.tasks], + ) + ) ) - return sql_query - - @property - def rendered_sql_without_descriptions(self) -> SqlQuery: - """Return the SQL query without the inline descriptions.""" - sql_query = self.rendered_sql - return SqlQuery( - sql_query="\n".join( - filter( - lambda line: not line.strip().startswith("--"), - sql_query.sql_query.split("\n"), - ) - ), - bind_parameter_set=sql_query.bind_parameter_set, - ) + return sql_statement @property def execution_plan(self) -> ExecutionPlan: # noqa: D102 diff --git a/metricflow/execution/dataflow_to_execution.py b/metricflow/execution/dataflow_to_execution.py index b5369f7350..1c2aa2fd95 100644 --- a/metricflow/execution/dataflow_to_execution.py +++ b/metricflow/execution/dataflow_to_execution.py @@ -36,7 +36,7 @@ ExecutionPlan, SelectSqlQueryToDataTableTask, SelectSqlQueryToTableTask, - SqlQuery, + SqlStatement, ) from metricflow.plan_conversion.convert_to_sql_plan import ConvertToSqlPlanResult from metricflow.plan_conversion.dataflow_to_sql import DataflowToSqlQueryPlanConverter @@ -91,7 +91,7 @@ def visit_write_to_result_data_table_node(self, node: WriteToResultDataTableNode leaf_tasks=( SelectSqlQueryToDataTableTask.create( sql_client=self._sql_client, - sql_query=SqlQuery(render_sql_result.sql, render_sql_result.bind_parameter_set), + sql_statement=SqlStatement(render_sql_result.sql, render_sql_result.bind_parameter_set), ), ) ) @@ -109,8 +109,8 @@ def visit_write_to_result_table_node(self, node: WriteToResultTableNode) -> Conv leaf_tasks=( SelectSqlQueryToTableTask.create( sql_client=self._sql_client, - sql_query=SqlQuery( - sql_query=render_sql_result.sql, + sql_statement=SqlStatement( + sql=render_sql_result.sql, bind_parameter_set=render_sql_result.bind_parameter_set, ), output_table=node.output_sql_table, diff --git a/metricflow/execution/execution_plan.py b/metricflow/execution/execution_plan.py index cae16bad9f..c3e044a6fb 100644 --- a/metricflow/execution/execution_plan.py +++ b/metricflow/execution/execution_plan.py @@ -27,10 +27,10 @@ class ExecutionPlanTask(DagNode["ExecutionPlanTask"], Visitable, ABC): for these nodes as it seems more intuitive. Attributes: - sql_query: If this runs a SQL query, return the associated SQL. + sql_statement: If this runs a SQL query, return the associated SQL. """ - sql_query: Optional[SqlQuery] + sql_statement: Optional[SqlStatement] @abstractmethod def execute(self) -> TaskExecutionResult: @@ -44,13 +44,26 @@ def task_id(self) -> NodeId: @dataclass(frozen=True) -class SqlQuery: - """A SQL query that can be run along with bind parameters.""" +class SqlStatement: + """Encapsulates a SQL statement along with the bind parameters that should be used.""" # This field will be renamed as it is confusing given the class name. - sql_query: str + sql: str bind_parameter_set: SqlBindParameterSet + @property + def without_descriptions(self) -> SqlStatement: + """Return the SQL query without the inline descriptions.""" + return SqlStatement( + sql="\n".join( + filter( + lambda line: not line.strip().startswith("--"), + self.sql.split("\n"), + ) + ), + bind_parameter_set=self.bind_parameter_set, + ) + @dataclass(frozen=True) class TaskExecutionError(Exception): @@ -80,7 +93,7 @@ class SelectSqlQueryToDataTableTask(ExecutionPlanTask): Attributes: sql_client: The SQL client used to run the query. - sql_query: The SQL query to run. + sql_statement: The SQL query to run. parent_nodes: The parent tasks for this execution plan task. """ @@ -90,12 +103,12 @@ class SelectSqlQueryToDataTableTask(ExecutionPlanTask): @staticmethod def create( # noqa: D102 sql_client: SqlClient, - sql_query: SqlQuery, + sql_statement: SqlStatement, parent_nodes: Sequence[ExecutionPlanTask] = (), ) -> SelectSqlQueryToDataTableTask: return SelectSqlQueryToDataTableTask( sql_client=sql_client, - sql_query=sql_query, + sql_statement=sql_statement, parent_nodes=tuple(parent_nodes), ) @@ -109,31 +122,30 @@ def description(self) -> str: # noqa: D102 @property def displayed_properties(self) -> Sequence[DisplayedProperty]: # noqa: D102 - sql_query = self.sql_query - assert sql_query is not None, f"{self.sql_query=} should have been set during creation." - return tuple(super().displayed_properties) + (DisplayedProperty(key="sql_query", value=sql_query.sql_query),) + assert self.sql_statement is not None, f"{self.sql_statement=} should have been set during creation." + return tuple(super().displayed_properties) + (DisplayedProperty(key="sql", value=self.sql_statement.sql),) def execute(self) -> TaskExecutionResult: # noqa: D102 start_time = time.time() - sql_query = self.sql_query - assert sql_query is not None, f"{self.sql_query=} should have been set during creation." + sql_statement = self.sql_statement + assert sql_statement is not None, f"{self.sql_statement=} should have been set during creation." df = self.sql_client.query( - sql_query.sql_query, - sql_bind_parameter_set=sql_query.bind_parameter_set, + sql_statement.sql, + sql_bind_parameter_set=sql_statement.bind_parameter_set, ) end_time = time.time() return TaskExecutionResult( start_time=start_time, end_time=end_time, - sql=sql_query.sql_query, - bind_params=sql_query.bind_parameter_set, + sql=sql_statement.sql, + bind_params=sql_statement.bind_parameter_set, df=df, ) def __repr__(self) -> str: # noqa: D105 - return f"{self.__class__.__name__}(sql_query='{self.sql_query}')" + return f"{self.__class__.__name__}(sql_statement={self.sql_statement!r})" @dataclass(frozen=True) @@ -144,7 +156,7 @@ class SelectSqlQueryToTableTask(ExecutionPlanTask): Attributes: sql_client: The SQL client used to run the query. - sql_query: The SQL query to run. + sql_statement: The SQL query to run. output_table: The table where the results will be written. """ @@ -154,13 +166,13 @@ class SelectSqlQueryToTableTask(ExecutionPlanTask): @staticmethod def create( # noqa: D102 sql_client: SqlClient, - sql_query: SqlQuery, + sql_statement: SqlStatement, output_table: SqlTable, parent_nodes: Sequence[ExecutionPlanTask] = (), ) -> SelectSqlQueryToTableTask: return SelectSqlQueryToTableTask( sql_client=sql_client, - sql_query=sql_query, + sql_statement=sql_statement, output_table=output_table, parent_nodes=tuple(parent_nodes), ) @@ -175,31 +187,31 @@ def description(self) -> str: # noqa: D102 @property def displayed_properties(self) -> Sequence[DisplayedProperty]: # noqa: D102 - sql_query = self.sql_query - assert sql_query is not None, f"{self.sql_query=} should have been set during creation." + sql_statement = self.sql_statement + assert sql_statement is not None, f"{self.sql_statement=} should have been set during creation." return tuple(super().displayed_properties) + ( - DisplayedProperty(key="sql_query", value=sql_query.sql_query), + DisplayedProperty(key="sql_statement", value=sql_statement.sql), DisplayedProperty(key="output_table", value=self.output_table), - DisplayedProperty(key="bind_parameter_set", value=sql_query.bind_parameter_set), + DisplayedProperty(key="bind_parameter_set", value=sql_statement.bind_parameter_set), ) def execute(self) -> TaskExecutionResult: # noqa: D102 - sql_query = self.sql_query - assert sql_query is not None, f"{self.sql_query=} should have been set during creation." + sql_statement = self.sql_statement + assert sql_statement is not None, f"{self.sql_statement=} should have been set during creation." start_time = time.time() logger.debug(LazyFormat(lambda: f"Dropping table {self.output_table} in case it already exists")) self.sql_client.execute(f"DROP TABLE IF EXISTS {self.output_table.sql}") logger.debug(LazyFormat(lambda: f"Creating table {self.output_table} using a query")) self.sql_client.execute( - sql_query.sql_query, - sql_bind_parameter_set=sql_query.bind_parameter_set, + sql_statement.sql, + sql_bind_parameter_set=sql_statement.bind_parameter_set, ) end_time = time.time() - return TaskExecutionResult(start_time=start_time, end_time=end_time, sql=sql_query.sql_query) + return TaskExecutionResult(start_time=start_time, end_time=end_time, sql=sql_statement.sql) def __repr__(self) -> str: # noqa: D105 - return f"{self.__class__.__name__}(sql_query='{self.sql_query}', output_table={self.output_table})" + return f"{self.__class__.__name__}(sql_statement={self.sql_statement!r}', output_table={self.output_table})" class ExecutionPlan(MetricFlowDag[ExecutionPlanTask]): diff --git a/metricflow/validation/data_warehouse_model_validator.py b/metricflow/validation/data_warehouse_model_validator.py index 58b1055d38..95efadcb3e 100644 --- a/metricflow/validation/data_warehouse_model_validator.py +++ b/metricflow/validation/data_warehouse_model_validator.py @@ -453,8 +453,8 @@ def _gen_explain_query_task_query_and_params( ) -> Tuple[str, SqlBindParameterSet]: explain_result: MetricFlowExplainResult = mf_engine.explain(mf_request=mf_request) return ( - explain_result.rendered_sql_without_descriptions.sql_query, - explain_result.rendered_sql_without_descriptions.bind_parameter_set, + explain_result.sql_statement.without_descriptions.sql, + explain_result.sql_statement.without_descriptions.bind_parameter_set, ) @classmethod diff --git a/tests_metricflow/engine/test_explain.py b/tests_metricflow/engine/test_explain.py index f23f419981..92c5817332 100644 --- a/tests_metricflow/engine/test_explain.py +++ b/tests_metricflow/engine/test_explain.py @@ -21,7 +21,7 @@ def _explain_one_query(mf_engine: MetricFlowEngine) -> str: explain_result: MetricFlowExplainResult = mf_engine.explain( MetricFlowQueryRequest.create_with_random_request_id(saved_query_name="p0_booking") ) - return explain_result.rendered_sql.sql_query + return explain_result.sql_statement.sql def test_concurrent_explain_consistency( @@ -64,7 +64,7 @@ def test_optimization_level( sql_optimization_level=optimization_level, ) ) - results[optimization_level.value] = explain_result.rendered_sql_without_descriptions.sql_query + results[optimization_level.value] = explain_result.sql_statement.without_descriptions.sql assert_str_snapshot_equal( request=request, diff --git a/tests_metricflow/execution/noop_task.py b/tests_metricflow/execution/noop_task.py index 8f65ec049a..ad06af4966 100644 --- a/tests_metricflow/execution/noop_task.py +++ b/tests_metricflow/execution/noop_task.py @@ -35,7 +35,7 @@ def create( # noqa: D102 ) -> NoOpExecutionPlanTask: return NoOpExecutionPlanTask( parent_nodes=tuple(parent_tasks), - sql_query=None, + sql_statement=None, should_error=should_error, ) diff --git a/tests_metricflow/execution/test_tasks.py b/tests_metricflow/execution/test_tasks.py index b6250d4693..a88bcc1478 100644 --- a/tests_metricflow/execution/test_tasks.py +++ b/tests_metricflow/execution/test_tasks.py @@ -11,7 +11,7 @@ ExecutionPlan, SelectSqlQueryToDataTableTask, SelectSqlQueryToTableTask, - SqlQuery, + SqlStatement, ) from metricflow.execution.executor import SequentialPlanExecutor from metricflow.protocols.sql_client import SqlClient, SqlEngine @@ -19,7 +19,7 @@ def test_read_sql_task(sql_client: SqlClient) -> None: # noqa: D103 - task = SelectSqlQueryToDataTableTask.create(sql_client, SqlQuery("SELECT 1 AS foo", SqlBindParameterSet())) + task = SelectSqlQueryToDataTableTask.create(sql_client, SqlStatement("SELECT 1 AS foo", SqlBindParameterSet())) execution_plan = ExecutionPlan(leaf_tasks=[task], dag_id=DagId.from_str("plan0")) results = SequentialPlanExecutor().execute_plan(execution_plan) @@ -44,8 +44,8 @@ def test_write_table_task( # noqa: D103 output_table = SqlTable(schema_name=mf_test_configuration.mf_system_schema, table_name=f"test_table_{random_id()}") task = SelectSqlQueryToTableTask.create( sql_client=sql_client, - sql_query=SqlQuery( - sql_query=f"CREATE TABLE {output_table.sql} AS SELECT 1 AS foo", + sql_statement=SqlStatement( + sql=f"CREATE TABLE {output_table.sql} AS SELECT 1 AS foo", bind_parameter_set=SqlBindParameterSet(), ), output_table=output_table, diff --git a/tests_metricflow/integration/test_rendered_query.py b/tests_metricflow/integration/test_rendered_query.py index ec5d1c6563..7bc004a0d5 100644 --- a/tests_metricflow/integration/test_rendered_query.py +++ b/tests_metricflow/integration/test_rendered_query.py @@ -31,7 +31,7 @@ def test_render_query( # noqa: D103 request=request, mf_test_configuration=mf_test_configuration, snapshot_id="query0", - sql=result.rendered_sql.sql_query, + sql=result.sql_statement.sql, sql_engine=it_helpers.sql_client.sql_engine_type, ) @@ -64,7 +64,7 @@ def test_id_enumeration( # noqa: D103 request=request, mf_test_configuration=mf_test_configuration, snapshot_id="query", - sql=result.rendered_sql.sql_query, + sql=result.sql_statement.sql, sql_engine=sql_client.sql_engine_type, ) @@ -80,6 +80,6 @@ def test_id_enumeration( # noqa: D103 request=request, mf_test_configuration=mf_test_configuration, snapshot_id="query", - sql=result.rendered_sql.sql_query, + sql=result.sql_statement.sql, sql_engine=sql_client.sql_engine_type, ) diff --git a/tests_metricflow/plan_conversion/test_dataflow_to_execution.py b/tests_metricflow/plan_conversion/test_dataflow_to_execution.py index fa67ffd4c9..08db7d27e0 100644 --- a/tests_metricflow/plan_conversion/test_dataflow_to_execution.py +++ b/tests_metricflow/plan_conversion/test_dataflow_to_execution.py @@ -36,6 +36,7 @@ def make_execution_plan_converter( # noqa: D103 @pytest.mark.sql_engine_snapshot +@pytest.mark.duckdb_only def test_joined_plan( # noqa: D103 request: FixtureRequest, mf_test_configuration: MetricFlowTestConfiguration, @@ -75,6 +76,7 @@ def test_joined_plan( # noqa: D103 @pytest.mark.sql_engine_snapshot +@pytest.mark.duckdb_only def test_small_combined_metrics_plan( # noqa: D103 request: FixtureRequest, mf_test_configuration: MetricFlowTestConfiguration, @@ -112,6 +114,7 @@ def test_small_combined_metrics_plan( # noqa: D103 @pytest.mark.sql_engine_snapshot +@pytest.mark.duckdb_only def test_combined_metrics_plan( # noqa: D103 request: FixtureRequest, mf_test_configuration: MetricFlowTestConfiguration, @@ -151,6 +154,7 @@ def test_combined_metrics_plan( # noqa: D103 @pytest.mark.sql_engine_snapshot +@pytest.mark.duckdb_only def test_multihop_joined_plan( request: FixtureRequest, mf_test_configuration: MetricFlowTestConfiguration, diff --git a/tests_metricflow/snapshots/test_dataflow_to_execution.py/ExecutionPlan/DuckDB/test_combined_metrics_plan__ep_0.xml b/tests_metricflow/snapshots/test_dataflow_to_execution.py/ExecutionPlan/DuckDB/test_combined_metrics_plan__ep_0.xml index 2d133f4c66..e6b630593b 100644 --- a/tests_metricflow/snapshots/test_dataflow_to_execution.py/ExecutionPlan/DuckDB/test_combined_metrics_plan__ep_0.xml +++ b/tests_metricflow/snapshots/test_dataflow_to_execution.py/ExecutionPlan/DuckDB/test_combined_metrics_plan__ep_0.xml @@ -5,7 +5,7 @@ test_filename: test_dataflow_to_execution.py - + diff --git a/tests_metricflow/snapshots/test_dataflow_to_execution.py/ExecutionPlan/DuckDB/test_joined_plan__ep_0.xml b/tests_metricflow/snapshots/test_dataflow_to_execution.py/ExecutionPlan/DuckDB/test_joined_plan__ep_0.xml index f13d263688..3744e5eccf 100644 --- a/tests_metricflow/snapshots/test_dataflow_to_execution.py/ExecutionPlan/DuckDB/test_joined_plan__ep_0.xml +++ b/tests_metricflow/snapshots/test_dataflow_to_execution.py/ExecutionPlan/DuckDB/test_joined_plan__ep_0.xml @@ -5,7 +5,7 @@ test_filename: test_dataflow_to_execution.py - + diff --git a/tests_metricflow/snapshots/test_dataflow_to_execution.py/ExecutionPlan/DuckDB/test_multihop_joined_plan__ep_0.xml b/tests_metricflow/snapshots/test_dataflow_to_execution.py/ExecutionPlan/DuckDB/test_multihop_joined_plan__ep_0.xml index 684ff89128..7988cd5e69 100644 --- a/tests_metricflow/snapshots/test_dataflow_to_execution.py/ExecutionPlan/DuckDB/test_multihop_joined_plan__ep_0.xml +++ b/tests_metricflow/snapshots/test_dataflow_to_execution.py/ExecutionPlan/DuckDB/test_multihop_joined_plan__ep_0.xml @@ -7,7 +7,7 @@ docstring: - + diff --git a/tests_metricflow/snapshots/test_dataflow_to_execution.py/ExecutionPlan/DuckDB/test_small_combined_metrics_plan__ep_0.xml b/tests_metricflow/snapshots/test_dataflow_to_execution.py/ExecutionPlan/DuckDB/test_small_combined_metrics_plan__ep_0.xml index 4da40aaf0e..d24a6bae71 100644 --- a/tests_metricflow/snapshots/test_dataflow_to_execution.py/ExecutionPlan/DuckDB/test_small_combined_metrics_plan__ep_0.xml +++ b/tests_metricflow/snapshots/test_dataflow_to_execution.py/ExecutionPlan/DuckDB/test_small_combined_metrics_plan__ep_0.xml @@ -5,7 +5,7 @@ test_filename: test_dataflow_to_execution.py - +