Skip to content

Commit

Permalink
Remove tags in SQL comments.
Browse files Browse the repository at this point in the history
  • Loading branch information
plypaul committed Jan 31, 2024
1 parent 44236bd commit 56ace64
Show file tree
Hide file tree
Showing 8 changed files with 13 additions and 285 deletions.
30 changes: 5 additions & 25 deletions metricflow/cli/dbt_connectors/adapter_backed_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,7 @@
from metricflow.sql.render.sql_plan_renderer import SqlQueryPlanRenderer
from metricflow.sql.render.trino import TrinoSqlQueryPlanRenderer
from metricflow.sql.sql_bind_parameters import SqlBindParameters
from metricflow.sql_request.sql_request_attributes import SqlJsonTag, SqlRequestId, SqlRequestTagSet
from metricflow.sql_request.sql_statement_metadata import CombinedSqlTags, SqlStatementCommentMetadata
from metricflow.sql_request.sql_request_attributes import SqlRequestId

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -128,7 +127,6 @@ def query(
self,
stmt: str,
sql_bind_parameters: SqlBindParameters = SqlBindParameters(),
extra_tags: SqlJsonTag = SqlJsonTag(),
) -> pd.DataFrame:
"""Query statement; result expected to be data which will be returned as a DataFrame.
Expand All @@ -140,19 +138,15 @@ def query(
"""
start = time.time()
request_id = SqlRequestId(f"mf_rid__{random_id()}")
combined_tags = AdapterBackedSqlClient._consolidate_tags(json_tags=extra_tags, request_id=request_id)
statement = SqlStatementCommentMetadata.add_tag_metadata_as_comment(
sql_statement=stmt, combined_tags=combined_tags
)
if sql_bind_parameters.param_dict:
raise SqlBindParametersNotSupportedError(
f"Invalid execute statement - we do not support queries with bind parameters through dbt adapters! "
f"Bind params: {sql_bind_parameters.param_dict}"
)
logger.info(AdapterBackedSqlClient._format_run_query_log_message(statement, sql_bind_parameters))
logger.info(AdapterBackedSqlClient._format_run_query_log_message(stmt, sql_bind_parameters))
with self._adapter.connection_named(f"MetricFlow_request_{request_id}"):
# returns a Tuple[AdapterResponse, agate.Table] but the decorator converts it to Any
result = self._adapter.execute(sql=statement, auto_begin=True, fetch=True)
result = self._adapter.execute(sql=stmt, auto_begin=True, fetch=True)
logger.info(f"Query returned from dbt Adapter with response {result[0]}")

agate_data = result[1]
Expand All @@ -165,15 +159,13 @@ def execute(
self,
stmt: str,
sql_bind_parameters: SqlBindParameters = SqlBindParameters(),
extra_tags: SqlJsonTag = SqlJsonTag(),
) -> None:
"""Execute a SQL statement. No result will be returned.
Args:
stmt: The SQL query statement to run. This should not produce output.
sql_bind_parameters: The parameter replacement mapping for filling in
concrete values for SQL query parameters.
extra_tags: An object containing JSON serialized tags meant for annotating queries.
"""
if sql_bind_parameters.param_dict:
raise SqlBindParametersNotSupportedError(
Expand All @@ -182,13 +174,9 @@ def execute(
)
start = time.time()
request_id = SqlRequestId(f"mf_rid__{random_id()}")
combined_tags = AdapterBackedSqlClient._consolidate_tags(json_tags=extra_tags, request_id=request_id)
statement = SqlStatementCommentMetadata.add_tag_metadata_as_comment(
sql_statement=stmt, combined_tags=combined_tags
)
logger.info(AdapterBackedSqlClient._format_run_query_log_message(statement, sql_bind_parameters))
logger.info(AdapterBackedSqlClient._format_run_query_log_message(stmt, sql_bind_parameters))
with self._adapter.connection_named(f"MetricFlow_request_{request_id}"):
result = self._adapter.execute(statement, auto_begin=True, fetch=False)
result = self._adapter.execute(stmt, auto_begin=True, fetch=False)
# Calls to execute often involve some amount of DDL so we commit here
self._adapter.commit_if_has_connection()
logger.info(f"Query executed via dbt Adapter with response {result[0]}")
Expand Down Expand Up @@ -270,11 +258,3 @@ def _format_run_query_log_message(statement: str, sql_bind_parameters: SqlBindPa
if len(sql_bind_parameters.param_dict) > 0:
message += f"\n\nwith parameters:\n\n{indent(mf_pformat(sql_bind_parameters.param_dict))}"
return message

@staticmethod
def _consolidate_tags(json_tags: SqlJsonTag, request_id: SqlRequestId) -> CombinedSqlTags:
    """Bundle the supplied JSON tags with a system tag set that carries the request ID."""
    # Start from an empty tag set and attach the request ID to form the system tags.
    system_tags = SqlRequestTagSet().add_request_id(request_id=request_id)
    return CombinedSqlTags(system_tags=system_tags, extra_tag=json_tags)
7 changes: 0 additions & 7 deletions metricflow/execution/execution_plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
from metricflow.dataflow.sql_table import SqlTable
from metricflow.protocols.sql_client import SqlClient
from metricflow.sql.sql_bind_parameters import SqlBindParameters
from metricflow.sql_request.sql_request_attributes import SqlJsonTag
from metricflow.visitor import Visitable

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -97,13 +96,11 @@ def __init__( # noqa: D
sql_client: SqlClient,
sql_query: str,
bind_parameters: SqlBindParameters,
extra_sql_tags: SqlJsonTag = SqlJsonTag(),
parent_nodes: Optional[List[ExecutionPlanTask]] = None,
) -> None:
self._sql_client = sql_client
self._sql_query = sql_query
self._bind_parameters = bind_parameters
self._extra_sql_tags = extra_sql_tags
super().__init__(task_id=self.create_unique_id(), parent_nodes=parent_nodes or [])

@classmethod
Expand All @@ -128,7 +125,6 @@ def execute(self) -> TaskExecutionResult: # noqa: D
df = self._sql_client.query(
self._sql_query,
sql_bind_parameters=self.bind_parameters,
extra_tags=self._extra_sql_tags,
)

end_time = time.time()
Expand Down Expand Up @@ -156,14 +152,12 @@ def __init__( # noqa: D
sql_query: str,
bind_parameters: SqlBindParameters,
output_table: SqlTable,
extra_sql_tags: SqlJsonTag = SqlJsonTag(),
parent_nodes: Optional[List[ExecutionPlanTask]] = None,
) -> None:
self._sql_client = sql_client
self._sql_query = sql_query
self._output_table = output_table
self._bind_parameters = bind_parameters
self._extra_sql_tags = extra_sql_tags
super().__init__(task_id=self.create_unique_id(), parent_nodes=parent_nodes or [])

@classmethod
Expand Down Expand Up @@ -192,7 +186,6 @@ def execute(self) -> TaskExecutionResult: # noqa: D
self._sql_client.execute(
sql_query.sql_query,
sql_bind_parameters=sql_query.bind_parameters,
extra_tags=self._extra_sql_tags,
)

end_time = time.time()
Expand Down
6 changes: 0 additions & 6 deletions metricflow/plan_conversion/dataflow_to_execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
from metricflow.plan_conversion.dataflow_to_sql import DataflowToSqlQueryPlanConverter
from metricflow.protocols.sql_client import SqlClient
from metricflow.sql.render.sql_plan_renderer import SqlQueryPlanRenderer
from metricflow.sql_request.sql_request_attributes import SqlJsonTag

logger = logging.getLogger(__name__)

Expand All @@ -34,20 +33,17 @@ def __init__(
sql_plan_converter: DataflowToSqlQueryPlanConverter,
sql_plan_renderer: SqlQueryPlanRenderer,
sql_client: SqlClient,
extra_sql_tags: SqlJsonTag = SqlJsonTag(),
) -> None:
"""Constructor.
Args:
sql_plan_converter: Converts a dataflow plan node to a SQL query plan
sql_plan_renderer: Converts a SQL query plan to SQL text
sql_client: The client to use for running queries.
extra_sql_tags: Tags to supply to the SQL client when running statements.
"""
self._sql_plan_converter = sql_plan_converter
self._sql_plan_renderer = sql_plan_renderer
self._sql_client = sql_client
self._sql_tags = extra_sql_tags

def _build_execution_plan( # noqa: D
self,
Expand All @@ -70,15 +66,13 @@ def _build_execution_plan( # noqa: D
sql_client=self._sql_client,
sql_query=render_result.sql,
bind_parameters=render_result.bind_parameters,
extra_sql_tags=self._sql_tags,
)
else:
leaf_task = SelectSqlQueryToTableTask(
sql_client=self._sql_client,
sql_query=render_result.sql,
bind_parameters=render_result.bind_parameters,
output_table=output_table,
extra_sql_tags=self._sql_tags,
)

return ExecutionPlan(
Expand Down
3 changes: 0 additions & 3 deletions metricflow/protocols/sql_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@

from metricflow.sql.render.sql_plan_renderer import SqlQueryPlanRenderer
from metricflow.sql.sql_bind_parameters import SqlBindParameters
from metricflow.sql_request.sql_request_attributes import SqlJsonTag


class SqlEngine(Enum):
Expand Down Expand Up @@ -53,7 +52,6 @@ def query(
self,
stmt: str,
sql_bind_parameters: SqlBindParameters = SqlBindParameters(),
extra_tags: SqlJsonTag = SqlJsonTag(),
) -> DataFrame:
"""Base query method, upon execution will run a query that returns a pandas DataFrame."""
raise NotImplementedError
Expand All @@ -63,7 +61,6 @@ def execute(
self,
stmt: str,
sql_bind_parameters: SqlBindParameters = SqlBindParameters(),
extra_tags: SqlJsonTag = SqlJsonTag(),
) -> None:
"""Base execute method."""
raise NotImplementedError
Expand Down
108 changes: 0 additions & 108 deletions metricflow/sql_request/sql_request_attributes.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,7 @@
from __future__ import annotations

import logging
import typing
from collections import OrderedDict
from dataclasses import dataclass
from enum import Enum
from operator import itemgetter
from typing import Any, Dict, Optional, Sequence

from dbt_semantic_interfaces.implementations.base import FrozenBaseModel
from pydantic import Field

logger = logging.getLogger(__name__)

Expand All @@ -22,103 +14,3 @@ class SqlRequestId:

def __repr__(self) -> str: # noqa: D
return self.id_str


class SqlRequestTagSet(FrozenBaseModel):
    """Ordered collection of request tags, modeled with Pydantic so it serializes easily."""

    # Keys and values are plain strings so the serialized output stays clean. Callers should go
    # through the properties and helpers of this class instead of reading this field directly.
    tag_dict: typing.OrderedDict[str, str] = Field(default_factory=OrderedDict)

    @property
    def tags(self) -> Sequence[SqlRequestTag]:  # noqa: D
        # One SqlRequestTag per stored entry, in the stored order.
        return tuple(SqlRequestTag(tag_key, tag_value) for tag_key, tag_value in self.tag_dict.items())

    @staticmethod
    def create_from_dict(tag_dict: Dict[SqlRequestTagKey, str]) -> SqlRequestTagSet:  # noqa: D
        # Replace the enum keys with their string values, then order the entries by (key, value).
        str_tag_dict = {}
        for tag_key_enum, tag_value in tag_dict.items():
            str_tag_dict[tag_key_enum.value] = tag_value
        ordered_items = sorted(str_tag_dict.items(), key=itemgetter(0, 1))
        return SqlRequestTagSet(tag_dict=OrderedDict(ordered_items))

    @staticmethod
    def create_from_request_id(request_id: SqlRequestId) -> SqlRequestTagSet:
        """Build a tag set whose only entry is the request ID tag."""
        request_id_entry = (SqlRequestTagKey.REQUEST_ID_KEY.value, request_id.id_str)
        return SqlRequestTagSet(tag_dict=OrderedDict([request_id_entry]))

    def add_request_id(self, request_id: SqlRequestId) -> SqlRequestTagSet:
        """Return the combination of this set with a set holding only the request ID tag."""
        request_id_only = SqlRequestTagSet.create_from_request_id(request_id)
        return SqlRequestTagSet.combine((self, request_id_only))

    @staticmethod
    def combine(tag_sets: Sequence[SqlRequestTagSet]) -> SqlRequestTagSet:  # noqa: D
        # Merge the sets in order; a key may repeat only if it carries the same value each time.
        tag_dict: typing.OrderedDict[str, str] = OrderedDict()
        for tag_set in tag_sets:
            for key, value in tag_set.tag_dict.items():
                # Falling back to `value` means an unseen key never registers as a conflict.
                if tag_dict.get(key, value) != value:
                    raise RuntimeError(
                        f"Can't combine tag sets due to a conflicting value for key: {key}. Conflicting values are "
                        f"at least: {value} and {tag_dict[key]}"
                    )
                tag_dict[key] = value

        return SqlRequestTagSet(tag_dict=tag_dict)

    @property
    def request_id(self) -> Optional[SqlRequestId]:
        """The request ID stored in this set, or None when the tag is missing or empty."""
        tag_value = self.tag_dict.get(SqlRequestTagKey.REQUEST_ID_KEY.value)
        return SqlRequestId(tag_value) if tag_value else None

    def is_subset_of(self, tag_set: SqlRequestTagSet) -> bool:  # noqa: D
        # True when every (key, value) entry of this set also appears in `tag_set`.
        return self.tag_dict.items() <= tag_set.tag_dict.items()


@dataclass(frozen=True)
class SqlRequestTag:
    """A key / value pair that can be used to label requests to the SQL engine."""

    key: str
    value: str


class SqlRequestTagKey(Enum):
    """Specific tags used by the system."""

    # Key under which a request's ID string is stored in a tag set.
    REQUEST_ID_KEY = "MF_REQUEST_ID"


MF_SYSTEM_TAGS_KEY = "MF_SYSTEM_TAGS"
MF_EXTRA_TAGS_KEY = "MF_EXTRA_TAGS"

# Helps to reduce the need to have "ignore type" everywhere.
JsonDict = Dict[str, Any] # type: ignore [misc]


class SqlJsonTag:
    """Immutable wrapper around a JSON object that is used to tag SQL requests."""

    def __init__(self, json_dict: Optional[JsonDict] = None) -> None:  # noqa: D
        # A missing or empty input yields an empty tag; otherwise the input mapping is copied.
        self._json_dict = OrderedDict(json_dict) if json_dict else OrderedDict()

    @property
    def json_dict(self) -> JsonDict:  # noqa: D
        # Return a copy so that callers cannot mutate the internal state.
        return OrderedDict(self._json_dict)

    def combine(self, other_tag: SqlJsonTag) -> SqlJsonTag:  # noqa: D
        new_json_dict = OrderedDict(self._json_dict)
        # Report every key that both tags define before the other tag's values overwrite ours.
        for k, v in other_tag._json_dict.items():
            if k not in new_json_dict:
                continue
            logger.error(
                f"Conflict while combining tags. Conflict key: {k} Conflicting values: {v} and {new_json_dict[k]}"
            )
        # Existing keys keep their position and take the other tag's value; new keys are appended.
        new_json_dict.update(other_tag._json_dict)
        return SqlJsonTag(new_json_dict)

    def __repr__(self) -> str:  # noqa: D
        return f"{self.__class__.__name__}(json_dict={self._json_dict})"
Loading

0 comments on commit 56ace64

Please sign in to comment.