feat(ingest): utilities for query logs (datahub-project#10036)
hsheth2 authored Mar 13, 2024
1 parent 4535f2a commit b0163c4
Showing 10 changed files with 583 additions and 291 deletions.
2 changes: 1 addition & 1 deletion metadata-ingestion/setup.py
@@ -99,7 +99,7 @@
sqlglot_lib = {
# Using an Acryl fork of sqlglot.
# https://github.com/tobymao/sqlglot/compare/main...hsheth2:sqlglot:hsheth?expand=1
"acryl-sqlglot==22.3.1.dev3",
"acryl-sqlglot==22.4.1.dev4",
}

classification_lib = {
29 changes: 29 additions & 0 deletions metadata-ingestion/src/datahub/cli/check_cli.py
@@ -1,4 +1,7 @@
import dataclasses
import json
import logging
import pathlib
import pprint
import shutil
import tempfile
@@ -17,6 +20,7 @@
from datahub.ingestion.source.source_registry import source_registry
from datahub.ingestion.transformer.transform_registry import transform_registry
from datahub.telemetry import telemetry
from datahub.utilities.file_backed_collections import ConnectionWrapper, FileBackedList

logger = logging.getLogger(__name__)

@@ -339,3 +343,28 @@ def test_path_spec(config: str, input: str, path_spec_key: str) -> None:
f"Failed to validate pattern {pattern_dicts} in path {path_spec_key}"
)
raise e


@check.command()
@click.argument("query-log-file", type=click.Path(exists=True, dir_okay=False))
@click.option("--output", type=click.Path())
def extract_sql_agg_log(query_log_file: str, output: Optional[str]) -> None:
"""Convert a sqlite db generated by the SqlParsingAggregator into a JSON."""

from datahub.sql_parsing.sql_parsing_aggregator import LoggedQuery

assert dataclasses.is_dataclass(LoggedQuery)

shared_connection = ConnectionWrapper(pathlib.Path(query_log_file))
query_log = FileBackedList[LoggedQuery](
shared_connection=shared_connection, tablename="stored_queries"
)
logger.info(f"Extracting {len(query_log)} queries from {query_log_file}")
queries = [dataclasses.asdict(query) for query in query_log]

if output:
with open(output, "w") as f:
json.dump(queries, f, indent=2)
logger.info(f"Extracted {len(queries)} queries to {output}")
else:
click.echo(json.dumps(queries, indent=2))
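Not part of the commit, for reference only: since the command is registered via @check.command(), it would presumably be invoked as "datahub check extract-sql-agg-log <query-log-file> [--output out.json]". Below is a minimal sketch of doing the same thing directly in Python, mirroring the command body above; the query_log.db path is a placeholder, and default=str is added only so that datetime timestamps serialize cleanly.

```python
import dataclasses
import json
import pathlib

from datahub.sql_parsing.sql_parsing_aggregator import LoggedQuery
from datahub.utilities.file_backed_collections import ConnectionWrapper, FileBackedList

# Open the sqlite file produced by SqlParsingAggregator when query logging is enabled.
shared_connection = ConnectionWrapper(pathlib.Path("query_log.db"))  # placeholder path
query_log = FileBackedList[LoggedQuery](
    shared_connection=shared_connection, tablename="stored_queries"
)

# Each entry is a LoggedQuery dataclass; convert to plain dicts for JSON output.
queries = [dataclasses.asdict(query) for query in query_log]
print(json.dumps(queries, indent=2, default=str))

shared_connection.close()
```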
@@ -1,8 +1,10 @@
import contextlib
import dataclasses
import enum
import itertools
import json
import logging
import os
import pathlib
import tempfile
import uuid
@@ -15,6 +17,7 @@
from datahub.emitter.mce_builder import get_sys_time, make_ts_millis
from datahub.emitter.mcp import MetadataChangeProposalWrapper
from datahub.emitter.sql_parsing_builder import compute_upstream_fields
from datahub.ingestion.api.closeable import Closeable
from datahub.ingestion.api.report import Report
from datahub.ingestion.api.workunit import MetadataWorkUnit
from datahub.ingestion.graph.client import DataHubGraph
@@ -53,16 +56,30 @@
QueryId = str
UrnStr = str

_DEFAULT_USER_URN = CorpUserUrn("_ingestion")
_MISSING_SESSION_ID = "__MISSING_SESSION_ID"


class QueryLogSetting(enum.Enum):
DISABLED = "DISABLED"
STORE_ALL = "STORE_ALL"
STORE_FAILED = "STORE_FAILED"


_DEFAULT_USER_URN = CorpUserUrn("_ingestion")
_MISSING_SESSION_ID = "__MISSING_SESSION_ID"
_DEFAULT_QUERY_LOG_SETTING = QueryLogSetting[
os.getenv("DATAHUB_SQL_AGG_QUERY_LOG") or QueryLogSetting.DISABLED.name
]


@dataclasses.dataclass
class LoggedQuery:
query: str
session_id: Optional[str]
timestamp: Optional[datetime]
user: Optional[UrnStr]
default_db: Optional[str]
default_schema: Optional[str]


@dataclasses.dataclass
class ViewDefinition:
view_definition: str
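Not part of the commit, for reference only: a small sketch of how the new default query-log setting resolves from the environment and what one LoggedQuery entry looks like. The SQL text, database, and schema values are illustrative.

```python
import os
from datetime import datetime, timezone

from datahub.sql_parsing.sql_parsing_aggregator import LoggedQuery, QueryLogSetting

# DATAHUB_SQL_AGG_QUERY_LOG must name an enum member ("DISABLED", "STORE_ALL",
# or "STORE_FAILED"); when unset, the fallback is DISABLED. Because the lookup
# happens at module import time, the variable has to be set before the
# aggregator module is first imported.
setting = QueryLogSetting[
    os.getenv("DATAHUB_SQL_AGG_QUERY_LOG") or QueryLogSetting.DISABLED.name
]

# The shape of one record appended to the "stored_queries" table.
entry = LoggedQuery(
    query="select * from orders",  # illustrative SQL
    session_id=None,
    timestamp=datetime.now(timezone.utc),
    user=None,
    default_db="analytics",        # illustrative defaults
    default_schema="public",
)
```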
@@ -170,7 +187,7 @@ def compute_stats(self) -> None:
return super().compute_stats()


class SqlParsingAggregator:
class SqlParsingAggregator(Closeable):
def __init__(
self,
*,
@@ -185,7 +202,7 @@ def __init__(
usage_config: Optional[BaseUsageConfig] = None,
is_temp_table: Optional[Callable[[UrnStr], bool]] = None,
format_queries: bool = True,
query_log: QueryLogSetting = QueryLogSetting.DISABLED,
query_log: QueryLogSetting = _DEFAULT_QUERY_LOG_SETTING,
) -> None:
self.platform = DataPlatformUrn(platform)
self.platform_instance = platform_instance
@@ -210,13 +227,18 @@ def __init__(
self.format_queries = format_queries
self.query_log = query_log

# The exit stack helps ensure that we close all the resources we open.
self._exit_stack = contextlib.ExitStack()

# Set up the schema resolver.
self._schema_resolver: SchemaResolver
if graph is None:
self._schema_resolver = SchemaResolver(
platform=self.platform.platform_name,
platform_instance=self.platform_instance,
env=self.env,
self._schema_resolver = self._exit_stack.enter_context(
SchemaResolver(
platform=self.platform.platform_name,
platform_instance=self.platform_instance,
env=self.env,
)
)
else:
self._schema_resolver = None # type: ignore
@@ -235,44 +257,54 @@ def __init__(

# By providing a filename explicitly here, we also ensure that the file
# is not automatically deleted on exit.
self._shared_connection = ConnectionWrapper(filename=query_log_path)
self._shared_connection = self._exit_stack.enter_context(
ConnectionWrapper(filename=query_log_path)
)

# Stores the logged queries.
self._logged_queries = FileBackedList[str](
self._logged_queries = FileBackedList[LoggedQuery](
shared_connection=self._shared_connection, tablename="stored_queries"
)
self._exit_stack.push(self._logged_queries)

# Map of query_id -> QueryMetadata
self._query_map = FileBackedDict[QueryMetadata](
shared_connection=self._shared_connection, tablename="query_map"
)
self._exit_stack.push(self._query_map)

# Map of downstream urn -> { query ids }
self._lineage_map = FileBackedDict[OrderedSet[QueryId]](
shared_connection=self._shared_connection, tablename="lineage_map"
)
self._exit_stack.push(self._lineage_map)

# Map of view urn -> view definition
self._view_definitions = FileBackedDict[ViewDefinition](
shared_connection=self._shared_connection, tablename="view_definitions"
)
self._exit_stack.push(self._view_definitions)

# Map of session ID -> {temp table name -> query id}
# Needs to use the query_map to find the info about the query.
# This assumes that a temp table is created at most once per session.
self._temp_lineage_map = FileBackedDict[Dict[UrnStr, QueryId]](
shared_connection=self._shared_connection, tablename="temp_lineage_map"
)
self._exit_stack.push(self._temp_lineage_map)

# Map of query ID -> schema fields, only for query IDs that generate temp tables.
self._inferred_temp_schemas = FileBackedDict[List[models.SchemaFieldClass]](
shared_connection=self._shared_connection, tablename="inferred_temp_schemas"
shared_connection=self._shared_connection,
tablename="inferred_temp_schemas",
)
self._exit_stack.push(self._inferred_temp_schemas)

# Map of table renames, from original UrnStr to new UrnStr.
self._table_renames = FileBackedDict[UrnStr](
shared_connection=self._shared_connection, tablename="table_renames"
)
self._exit_stack.push(self._table_renames)

# Usage aggregator. This will only be initialized if usage statistics are enabled.
# TODO: Replace with FileBackedDict.
@@ -281,6 +313,9 @@ def __init__(
assert self.usage_config is not None
self._usage_aggregator = UsageAggregator(config=self.usage_config)

def close(self) -> None:
self._exit_stack.close()

@property
def _need_schemas(self) -> bool:
return self.generate_lineage or self.generate_usage_statistics
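Not part of the commit, for reference only: a generic, standard-library sketch of the cleanup pattern added above, where everything opened in __init__ is registered on a contextlib.ExitStack and close() unwinds it in reverse order. The real code registers its Closeable and context-manager resources directly via enter_context() and push(); the classes below are illustrative stand-ins.

```python
import contextlib


class FakeResource:
    """Illustrative stand-in for something like ConnectionWrapper or FileBackedDict."""

    def __init__(self, name: str) -> None:
        self.name = name

    def close(self) -> None:
        print(f"closing {self.name}")


class ResourceOwner:
    def __init__(self) -> None:
        # The exit stack helps ensure that we close all the resources we open.
        self._exit_stack = contextlib.ExitStack()

        # contextlib.closing() adapts a close()-only object into a context
        # manager; enter_context() registers it and hands back the resource.
        self.conn = self._exit_stack.enter_context(
            contextlib.closing(FakeResource("connection"))
        )
        self.table = self._exit_stack.enter_context(
            contextlib.closing(FakeResource("table"))
        )

    def close(self) -> None:
        # Closes "table" first, then "connection" (reverse registration order).
        self._exit_stack.close()


owner = ResourceOwner()
owner.close()
```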
@@ -499,6 +534,9 @@ def add_observed_query(
default_db=default_db,
default_schema=default_schema,
schema_resolver=schema_resolver,
session_id=session_id,
timestamp=query_timestamp,
user=user,
)
if parsed.debug_info.error:
self.report.observed_query_parse_failures.append(
@@ -700,6 +738,9 @@ def _run_sql_parser(
default_db: Optional[str],
default_schema: Optional[str],
schema_resolver: SchemaResolverInterface,
session_id: str = _MISSING_SESSION_ID,
timestamp: Optional[datetime] = None,
user: Optional[CorpUserUrn] = None,
) -> SqlParsingResult:
parsed = sqlglot_lineage(
query,
@@ -712,7 +753,15 @@
if self.query_log == QueryLogSetting.STORE_ALL or (
self.query_log == QueryLogSetting.STORE_FAILED and parsed.debug_info.error
):
self._logged_queries.append(query)
query_log_entry = LoggedQuery(
query=query,
session_id=session_id if session_id != _MISSING_SESSION_ID else None,
timestamp=timestamp,
user=user.urn() if user else None,
default_db=default_db,
default_schema=default_schema,
)
self._logged_queries.append(query_log_entry)

# Also add some extra logging.
if parsed.debug_info.error:
12 changes: 8 additions & 4 deletions metadata-ingestion/src/datahub/testing/compare_metadata_json.py
@@ -62,9 +62,13 @@ def assert_metadata_files_equal(
# We have to "normalize" the golden file by reading and writing it back out.
# This will clean up nulls, double serialization, and other formatting issues.
with tempfile.NamedTemporaryFile() as temp:
golden_metadata = read_metadata_file(pathlib.Path(golden_path))
write_metadata_file(pathlib.Path(temp.name), golden_metadata)
golden = load_json_file(temp.name)
try:
golden_metadata = read_metadata_file(pathlib.Path(golden_path))
write_metadata_file(pathlib.Path(temp.name), golden_metadata)
golden = load_json_file(temp.name)
except (ValueError, AssertionError) as e:
logger.info(f"Error reformatting golden file as MCP/MCEs: {e}")
golden = load_json_file(golden_path)

diff = diff_metadata_json(output, golden, ignore_paths, ignore_order=ignore_order)
if diff and update_golden:
@@ -107,7 +111,7 @@ def diff_metadata_json(
# if ignore_order is False, always use DeepDiff
except CannotCompareMCPs as e:
logger.info(f"{e}, falling back to MCE diff")
except AssertionError as e:
except (AssertionError, ValueError) as e:
logger.warning(f"Reverting to old diff method: {e}")
logger.debug("Error with new diff method", exc_info=True)

@@ -126,6 +126,7 @@ def executemany(
def close(self) -> None:
for obj in self._dependent_objects:
obj.close()
self._dependent_objects.clear()
with self.conn_lock:
self.conn.close()
if self._temp_directory:
@@ -440,7 +441,7 @@ def __del__(self) -> None:
self.close()


class FileBackedList(Generic[_VT]):
class FileBackedList(Generic[_VT], Closeable):
"""An append-only, list-like object that stores its contents in a SQLite database."""

_len: int = field(default=0)
@@ -456,7 +457,6 @@ def __init__(
cache_max_size: Optional[int] = None,
cache_eviction_batch_size: Optional[int] = None,
) -> None:
self._len = 0
self._dict = FileBackedDict[_VT](
shared_connection=shared_connection,
tablename=tablename,
@@ -468,6 +468,12 @@
or _DEFAULT_MEMORY_CACHE_EVICTION_BATCH_SIZE,
)

if shared_connection:
shared_connection._dependent_objects.append(self)

# In case we're reusing an existing list, we need to run a query to get the length.
self._len = len(self._dict)

@property
def tablename(self) -> str:
return self._dict.tablename
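Not part of the commit, for reference only: a minimal sketch of why recomputing _len from the backing table matters, since a FileBackedList that reuses an existing sqlite file now reports the correct length. It assumes, as the _dependent_objects wiring above suggests, that closing the ConnectionWrapper flushes and closes the list, and that the default serializer handles plain strings. The collections.db path and the "events" table name are placeholders.

```python
import pathlib

from datahub.utilities.file_backed_collections import ConnectionWrapper, FileBackedList

path = pathlib.Path("collections.db")  # placeholder path

# First run: append a couple of entries, then close the shared connection,
# which (per the _dependent_objects handling above) also closes the list.
conn = ConnectionWrapper(filename=path)
events = FileBackedList[str](shared_connection=conn, tablename="events")
events.append("first")
events.append("second")
conn.close()

# Second run: reopening the same file restores the list, including its length,
# thanks to the `self._len = len(self._dict)` line added in __init__.
conn = ConnectionWrapper(filename=path)
events = FileBackedList[str](shared_connection=conn, tablename="events")
print(len(events))  # expected: 2
conn.close()
```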