Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor validation decorators #1354

Merged
merged 3 commits into from
Oct 24, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions querybook/server/lib/elasticsearch/search_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,9 @@ def get_column_name_suggestion(
return get_matching_objects(search_query, ES_CONFIG["tables"]["index_name"], True)


def get_table_name_suggestion(fuzzy_table_name: str) -> Tuple[Dict, int]:
def get_table_name_suggestion(
fuzzy_table_name: str, metastore_id: int
) -> Tuple[Dict, int]:
"""Given an invalid table name use fuzzy search to search the correctly-spelled table name"""

schema_name, fuzzy_name = None, fuzzy_table_name
Expand All @@ -229,7 +231,12 @@ def get_table_name_suggestion(fuzzy_table_name: str) -> Tuple[Dict, int]:
{
"match": {
"name": {"query": fuzzy_name, "fuzziness": "AUTO"},
}
},
},
{
"match": {
"metastore_id": metastore_id,
},
},
]
if schema_name:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from abc import abstractmethod
from typing import Any, Dict, List, Tuple
from typing import List, Tuple
from sqlglot import Tokenizer
from sqlglot.tokens import Token

Expand All @@ -8,13 +8,12 @@
QueryValidationResultObjectType,
QueryValidationSeverity,
)
from lib.query_analysis.validation.base_query_validator import BaseQueryValidator

from lib.query_analysis.validation.decorators.base_validation_decorator import (
BaseValidationDecorator,
)

class BaseSQLGlotValidator(BaseQueryValidator):
def __init__(self, name: str = "", config: Dict[str, Any] = {}):
super(BaseSQLGlotValidator, self).__init__(name, config)

class BaseSQLGlotValidationDecorator(BaseValidationDecorator):
@property
@abstractmethod
def message(self) -> str:
Expand Down Expand Up @@ -65,7 +64,6 @@ def _get_query_validation_result(
suggestion=suggestion,
)

@abstractmethod
def validate(
self,
query: str,
Expand All @@ -74,20 +72,8 @@ def validate(
raw_tokens: List[Token] = None,
**kwargs,
) -> List[QueryValidationResult]:
raise NotImplementedError()


class BaseSQLGlotDecorator(BaseSQLGlotValidator):
def __init__(self, validator: BaseQueryValidator):
self._validator = validator

def validate(
self,
query: str,
uid: int,
engine_id: int,
raw_tokens: List[Token] = None,
**kwargs,
):
"""Override this method to add suggestions to validation results"""
return self._validator.validate(query, uid, engine_id, **kwargs)
if raw_tokens is None:
raw_tokens = self._tokenize_query(query)
return super(BaseSQLGlotValidationDecorator, self).validate(
query, uid, engine_id, raw_tokens=raw_tokens, **kwargs
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
from abc import ABCMeta, abstractmethod
from typing import List

from lib.query_analysis.validation.base_query_validator import (
QueryValidationResult,
)
from lib.query_analysis.validation.base_query_validator import BaseQueryValidator


class BaseValidationDecorator(metaclass=ABCMeta):
def __init__(self, validator: BaseQueryValidator):
self._validator = validator

@abstractmethod
def decorate_validation_results(
self,
validation_results: List[QueryValidationResult],
query: str,
uid: int,
engine_id: int,
**kwargs,
) -> List[QueryValidationResult]:
raise NotImplementedError()

def validate(
self,
query: str,
uid: int,
engine_id: int,
**kwargs,
) -> List[QueryValidationResult]:
validation_results = self._validator.validate(query, uid, engine_id, **kwargs)
return self.decorate_validation_results(
validation_results, query, uid, engine_id, **kwargs
)
Original file line number Diff line number Diff line change
Expand Up @@ -6,23 +6,14 @@
from lib.query_analysis.lineage import process_query
from lib.query_analysis.validation.base_query_validator import (
QueryValidationResult,
QueryValidationSeverity,
)
from lib.query_analysis.validation.validators.base_sqlglot_validator import (
BaseSQLGlotDecorator,
from lib.query_analysis.validation.decorators.base_sqlglot_validation_decorator import (
BaseValidationDecorator,
)
from logic.admin import get_query_engine_by_id
from logic import admin as admin_logic


class BaseColumnNameSuggester(BaseSQLGlotDecorator):
@property
def severity(self):
return QueryValidationSeverity.WARNING # Unused, severity is not changed

@property
def message(self):
return "" # Unused, message is not changed

class BaseColumnNameSuggester(BaseValidationDecorator):
@abstractmethod
def get_column_name_from_error(
self, validation_result: QueryValidationResult
Expand All @@ -32,7 +23,7 @@ def get_column_name_from_error(
raise NotImplementedError()

def _get_tables_in_query(self, query: str, engine_id: int) -> List[str]:
engine = get_query_engine_by_id(engine_id)
engine = admin_logic.get_query_engine_by_id(engine_id)
tables_per_statement, _ = process_query(query, language=engine.language)
return list(chain.from_iterable(tables_per_statement))

Expand Down Expand Up @@ -69,49 +60,43 @@ def _suggest_column_name_if_needed(
validation_result.start_ch + len(fuzzy_column_name) - 1
)

def validate(
def decorate_validation_results(
self,
validation_results: List[QueryValidationResult],
query: str,
uid: int,
engine_id: int,
raw_tokens: List[QueryValidationResult] = None,
**kwargs,
) -> List[QueryValidationResult]:
if raw_tokens is None:
raw_tokens = self._tokenize_query(query)
validation_results = self._validator.validate(
query, uid, engine_id, raw_tokens=raw_tokens
)
tables_in_query = self._get_tables_in_query(query, engine_id)
for result in validation_results:
self._suggest_column_name_if_needed(result, tables_in_query)
return validation_results


class BaseTableNameSuggester(BaseSQLGlotDecorator):
@property
def severity(self):
return QueryValidationSeverity.WARNING # Unused, severity is not changed

@property
def message(self):
return "" # Unused, message is not changed

class BaseTableNameSuggester(BaseValidationDecorator):
@abstractmethod
def get_full_table_name_from_error(self, validation_result: QueryValidationResult):
"""Returns invalid table name if the validation result is a table name error, otherwise
returns None"""
raise NotImplementedError()

def _suggest_table_name_if_needed(
self, validation_result: QueryValidationResult
self,
validation_result: QueryValidationResult,
engine_id: int,
) -> Optional[str]:
"""Takes validation result and tables in query to update validation result to provide table
name suggestion"""
fuzzy_table_name = self.get_full_table_name_from_error(validation_result)
if not fuzzy_table_name:
return
results, count = search_table.get_table_name_suggestion(fuzzy_table_name)
metastore_id = admin_logic.get_query_metastore_id_by_engine_id(engine_id)
if metastore_id is None:
return
results, count = search_table.get_table_name_suggestion(
kgopal492 marked this conversation as resolved.
Show resolved Hide resolved
fuzzy_table_name, metastore_id
)
if count > 0:
table_result = results[0] # Get top match
table_suggestion = f"{table_result['schema']}.{table_result['name']}"
Expand All @@ -121,19 +106,14 @@ def _suggest_table_name_if_needed(
validation_result.start_ch + len(fuzzy_table_name) - 1
)

def validate(
def decorate_validation_results(
self,
validation_results: List[QueryValidationResult],
query: str,
uid: int,
engine_id: int,
raw_tokens: List[QueryValidationResult] = None,
**kwargs,
) -> List[QueryValidationResult]:
if raw_tokens is None:
raw_tokens = self._tokenize_query(query)
validation_results = self._validator.validate(
query, uid, engine_id, raw_tokens=raw_tokens
)
for result in validation_results:
self._suggest_table_name_if_needed(result)
self._suggest_table_name_if_needed(result, engine_id)
return validation_results
Original file line number Diff line number Diff line change
Expand Up @@ -12,19 +12,16 @@
from lib.query_analysis.validation.validators.presto_explain_validator import (
PrestoExplainValidator,
)
from lib.query_analysis.validation.validators.base_sqlglot_validator import (
BaseSQLGlotDecorator,
from lib.query_analysis.validation.decorators.base_sqlglot_validation_decorator import (
BaseSQLGlotValidationDecorator,
)
from lib.query_analysis.validation.validators.metadata_suggesters import (
from lib.query_analysis.validation.decorators.metadata_decorators import (
BaseColumnNameSuggester,
BaseTableNameSuggester,
)


class BasePrestoSQLGlotDecorator(BaseSQLGlotDecorator):
def languages(self):
return ["presto", "trino"]

class BasePrestoSQLGlotDecorator(BaseSQLGlotValidationDecorator):
@property
def tokenizer(self) -> Tokenizer:
return Trino.Tokenizer()
Expand All @@ -39,19 +36,15 @@ def message(self):
def severity(self) -> str:
return QueryValidationSeverity.WARNING

def validate(
def decorate_validation_results(
self,
validation_results: List[QueryValidationResult],
query: str,
uid: int,
engine_id: int,
raw_tokens: List[Token] = None,
raw_tokens: List[Token] = [],
**kwargs,
) -> List[QueryValidationResult]:
if raw_tokens is None:
raw_tokens = self._tokenize_query(query)
validation_results = self._validator.validate(
query, uid, engine_id, raw_tokens=raw_tokens
)
for i, token in enumerate(raw_tokens):
if token.token_type == TokenType.UNION:
if (
Expand All @@ -77,20 +70,15 @@ def message(self):
def severity(self) -> str:
return QueryValidationSeverity.WARNING

def validate(
def decorate_validation_results(
self,
validation_results: List[QueryValidationResult],
query: str,
uid: int,
engine_id: int,
raw_tokens: List[Token] = None,
raw_tokens: List[Token] = [],
**kwargs,
) -> List[QueryValidationResult]:
if raw_tokens is None:
raw_tokens = self._tokenize_query(query)

validation_results = self._validator.validate(
query, uid, engine_id, raw_tokens=raw_tokens
)
for i, token in enumerate(raw_tokens):
if (
i < len(raw_tokens) - 2
Expand Down Expand Up @@ -125,21 +113,15 @@ def _get_regexp_like_suggestion(self, column_name: str, like_strings: List[str])
]
return f"REGEXP_LIKE({column_name}, '{'|'.join(sanitized_like_strings)}')"

def validate(
def decorate_validation_results(
self,
validation_results: List[QueryValidationResult],
query: str,
uid: int,
engine_id: int,
raw_tokens: List[Token] = None,
raw_tokens: List[Token] = [],
**kwargs,
) -> List[QueryValidationResult]:
if raw_tokens is None:
raw_tokens = self._tokenize_query(query)

validation_results = self._validator.validate(
query, uid, engine_id, raw_tokens=raw_tokens
)

start_column_token = None
like_strings = []
token_idx = 0
Expand Down Expand Up @@ -203,15 +185,15 @@ def validate(
return validation_results


class PrestoColumnNameSuggester(BasePrestoSQLGlotDecorator, BaseColumnNameSuggester):
class PrestoColumnNameSuggester(BaseColumnNameSuggester):
def get_column_name_from_error(self, validation_result: QueryValidationResult):
regex_result = re.match(
r"line \d+:\d+: Column '(.*)' cannot be resolved", validation_result.message
)
return regex_result.groups()[0] if regex_result else None


class PrestoTableNameSuggester(BasePrestoSQLGlotDecorator, BaseTableNameSuggester):
class PrestoTableNameSuggester(BaseTableNameSuggester):
def get_full_table_name_from_error(self, validation_result: QueryValidationResult):
regex_result = re.match(
r"line \d+:\d+: Table '(.*)' does not exist", validation_result.message
Expand Down
6 changes: 6 additions & 0 deletions querybook/server/logic/admin.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,12 @@ def get_query_metastore_by_name(name, session=None):
return session.query(QueryMetastore).filter(QueryMetastore.name == name).first()


@with_session
def get_query_metastore_id_by_engine_id(engine_id: int, session=None):
query_engine = get_query_engine_by_id(engine_id, session=session)
return query_engine.metastore_id if query_engine else None


@with_session
def get_all_query_metastore(session=None):
return session.query(QueryMetastore).all()
Expand Down
Loading
Loading