diff --git a/querybook/server/lib/elasticsearch/search_table.py b/querybook/server/lib/elasticsearch/search_table.py
index ad1d73160..55ac8ee66 100644
--- a/querybook/server/lib/elasticsearch/search_table.py
+++ b/querybook/server/lib/elasticsearch/search_table.py
@@ -1,9 +1,14 @@
+from typing import Dict, List, Tuple
 from lib.elasticsearch.query_utils import (
     match_filters,
     highlight_fields,
     order_by_fields,
     combine_keyword_and_filter_query,
 )
+from lib.elasticsearch.search_utils import (
+    ES_CONFIG,
+    get_matching_objects,
+)

 FILTERS_TO_AND = ["tags", "data_elements"]

@@ -173,3 +178,65 @@ def construct_tables_query_by_table_names(
     }

     return query
+
+
+def get_column_name_suggestion(
+    fuzzy_column_name: str, full_table_names: List[str]
+) -> Tuple[Dict, int]:
+    """Given an invalid column name and a list of tables to search from, uses fuzzy search to find
+    the correctly-spelled column name"""
+    should_clause = []
+    for full_table_name in full_table_names:
+        schema_name, table_name = full_table_name.split(".")
+        should_clause.append(
+            {
+                "bool": {
+                    "must": [
+                        {"match": {"name": table_name}},
+                        {"match": {"schema": schema_name}},
+                    ]
+                }
+            }
+        )
+
+    search_query = {
+        "query": {
+            "bool": {
+                "must": {
+                    "match": {
+                        "columns": {"query": fuzzy_column_name, "fuzziness": "AUTO"}
+                    }
+                },
+                "should": should_clause,
+                "minimum_should_match": 1,
+            },
+        },
+        "highlight": {"pre_tags": [""], "post_tags": [""], "fields": {"columns": {}}},
+    }
+
+    return get_matching_objects(search_query, ES_CONFIG["tables"]["index_name"], True)
+
+
+def get_table_name_suggestion(fuzzy_table_name: str) -> Tuple[Dict, int]:
+    """Given an invalid table name, uses fuzzy search to find the correctly-spelled table name"""
+
+    schema_name, fuzzy_name = None, fuzzy_table_name
+    fuzzy_table_name_parts = fuzzy_table_name.split(".")
+    if len(fuzzy_table_name_parts) == 2:
+        schema_name, fuzzy_name = fuzzy_table_name_parts
+
+    must_clause = [
+        {
+            "match": {
+                "name": {"query": fuzzy_name, "fuzziness": "AUTO"},
+            }
+        },
+    ]
+    if schema_name:
+        must_clause.append({"match": {"schema": schema_name}})
+
+    search_query = {
+        "query": {"bool": {"must": must_clause}},
+    }
+
+    return get_matching_objects(search_query, ES_CONFIG["tables"]["index_name"], True)
diff --git a/querybook/server/lib/query_analysis/validation/base_query_validator.py b/querybook/server/lib/query_analysis/validation/base_query_validator.py
index d2ed3a3d2..6014c89db 100644
--- a/querybook/server/lib/query_analysis/validation/base_query_validator.py
+++ b/querybook/server/lib/query_analysis/validation/base_query_validator.py
@@ -67,6 +67,7 @@ def validate(
         query: str,
         uid: int,  # who is doing the syntax check
         engine_id: int,  # which engine they are checking against
+        **kwargs,
     ) -> List[QueryValidationResult]:
         raise NotImplementedError()

diff --git a/querybook/server/lib/query_analysis/validation/validators/base_sqlglot_validator.py b/querybook/server/lib/query_analysis/validation/validators/base_sqlglot_validator.py
index 067f1eef0..16dab444d 100644
--- a/querybook/server/lib/query_analysis/validation/validators/base_sqlglot_validator.py
+++ b/querybook/server/lib/query_analysis/validation/validators/base_sqlglot_validator.py
@@ -1,5 +1,5 @@
-from abc import ABCMeta, abstractmethod
-from typing import List, Tuple
+from abc import abstractmethod
+from typing import Any, Dict, List, Tuple
 from sqlglot import Tokenizer
 from sqlglot.tokens import Token

@@ -8,9 +8,13 @@
     QueryValidationResultObjectType,
     QueryValidationSeverity,
 )
+from lib.query_analysis.validation.base_query_validator import BaseQueryValidator


-class BaseSQLGlotValidator(metaclass=ABCMeta):
+class BaseSQLGlotValidator(BaseQueryValidator):
+    def __init__(self, name: str = "", config: Dict[str, Any] = {}):
+        super(BaseSQLGlotValidator, self).__init__(name, config)
+
     @property
     @abstractmethod
     def message(self) -> str:
@@ -33,6 +37,12 @@ def _get_query_coordinate_by_index(self, query: str, index: int) -> Tuple[int, i
         rows = query[: index + 1].splitlines(keepends=False)
         return len(rows) - 1, len(rows[-1]) - 1

+    def _get_query_index_by_coordinate(
+        self, query: str, start_line: int, start_ch: int
+    ) -> int:
+        rows = query.splitlines(keepends=True)[:start_line]
+        return sum([len(row) for row in rows]) + start_ch
+
     def _get_query_validation_result(
         self,
         query: str,
@@ -56,7 +66,28 @@
         )

     @abstractmethod
-    def get_query_validation_results(
-        self, query: str, raw_tokens: List[Token] = None
+    def validate(
+        self,
+        query: str,
+        uid: int,
+        engine_id: int,
+        raw_tokens: List[Token] = None,
+        **kwargs,
     ) -> List[QueryValidationResult]:
         raise NotImplementedError()
+
+
+class BaseSQLGlotDecorator(BaseSQLGlotValidator):
+    def __init__(self, validator: BaseQueryValidator):
+        self._validator = validator
+
+    def validate(
+        self,
+        query: str,
+        uid: int,
+        engine_id: int,
+        raw_tokens: List[Token] = None,
+        **kwargs,
+    ):
+        """Override this method to add suggestions to validation results"""
+        return self._validator.validate(query, uid, engine_id, **kwargs)
diff --git a/querybook/server/lib/query_analysis/validation/validators/metadata_suggesters.py b/querybook/server/lib/query_analysis/validation/validators/metadata_suggesters.py
new file mode 100644
index 000000000..726e2ce9d
--- /dev/null
+++ b/querybook/server/lib/query_analysis/validation/validators/metadata_suggesters.py
@@ -0,0 +1,139 @@
+from abc import abstractmethod
+from itertools import chain
+from typing import List, Optional
+
+from lib.elasticsearch import search_table
+from lib.query_analysis.lineage import process_query
+from lib.query_analysis.validation.base_query_validator import (
+    QueryValidationResult,
+    QueryValidationSeverity,
+)
+from lib.query_analysis.validation.validators.base_sqlglot_validator import (
+    BaseSQLGlotDecorator,
+)
+from logic.admin import get_query_engine_by_id
+
+
+class BaseColumnNameSuggester(BaseSQLGlotDecorator):
+    @property
+    def severity(self):
+        return QueryValidationSeverity.WARNING  # Unused, severity is not changed
+
+    @property
+    def message(self):
+        return ""  # Unused, message is not changed
+
+    @abstractmethod
+    def get_column_name_from_error(
+        self, validation_result: QueryValidationResult
+    ) -> Optional[str]:
+        """Returns the invalid column name if the validation result is a column name error,
+        otherwise returns None"""
+        raise NotImplementedError()
+
+    def _get_tables_in_query(self, query: str, engine_id: int) -> List[str]:
+        engine = get_query_engine_by_id(engine_id)
+        tables_per_statement, _ = process_query(query, language=engine.language)
+        return list(chain.from_iterable(tables_per_statement))
+
+    def _search_columns_for_suggestion(self, columns: List[str], suggestion: str):
+        """Return the case-sensitive column name by searching the table's columns for the suggestion text"""
+        for col in columns:
+            if col.lower() == suggestion.lower():
+                return col
+        return suggestion
+
+    def _suggest_column_name_if_needed(
+        self,
+        validation_result: QueryValidationResult,
+        tables_in_query: List[str],
+    ):
+        """Takes the validation result and the tables in the query, and updates the validation
+        result with a column name suggestion"""
+        fuzzy_column_name = self.get_column_name_from_error(validation_result)
+        if not fuzzy_column_name:
+            return
+        results, count = search_table.get_column_name_suggestion(
+            fuzzy_column_name, tables_in_query
+        )
+        if count == 1:  # Only suggest column if there's a single match
+            table_result = results[0]
+            highlights = table_result.get("highlight", {}).get("columns", [])
+            if len(highlights) == 1:
+                column_suggestion = self._search_columns_for_suggestion(
+                    table_result.get("columns"), highlights[0]
+                )
+                validation_result.suggestion = column_suggestion
+                validation_result.end_line = validation_result.start_line
+                validation_result.end_ch = (
+                    validation_result.start_ch + len(fuzzy_column_name) - 1
+                )
+
+    def validate(
+        self,
+        query: str,
+        uid: int,
+        engine_id: int,
+        raw_tokens: List[QueryValidationResult] = None,
+        **kwargs,
+    ) -> List[QueryValidationResult]:
+        if raw_tokens is None:
+            raw_tokens = self._tokenize_query(query)
+        validation_results = self._validator.validate(
+            query, uid, engine_id, raw_tokens=raw_tokens
+        )
+        tables_in_query = self._get_tables_in_query(query, engine_id)
+        for result in validation_results:
+            self._suggest_column_name_if_needed(result, tables_in_query)
+        return validation_results
+
+
+class BaseTableNameSuggester(BaseSQLGlotDecorator):
+    @property
+    def severity(self):
+        return QueryValidationSeverity.WARNING  # Unused, severity is not changed
+
+    @property
+    def message(self):
+        return ""  # Unused, message is not changed
+
+    @abstractmethod
+    def get_full_table_name_from_error(self, validation_result: QueryValidationResult):
+        """Returns the invalid table name if the validation result is a table name error,
+        otherwise returns None"""
+        raise NotImplementedError()
+
+    def _suggest_table_name_if_needed(
+        self, validation_result: QueryValidationResult
+    ) -> Optional[str]:
+        """Takes the validation result and updates it with a table
+        name suggestion"""
+        fuzzy_table_name = self.get_full_table_name_from_error(validation_result)
+        if not fuzzy_table_name:
+            return
+        results, count = search_table.get_table_name_suggestion(fuzzy_table_name)
+        if count > 0:
+            table_result = results[0]  # Get top match
+            table_suggestion = f"{table_result['schema']}.{table_result['name']}"
+            validation_result.suggestion = table_suggestion
+            validation_result.end_line = validation_result.start_line
+            validation_result.end_ch = (
+                validation_result.start_ch + len(fuzzy_table_name) - 1
+            )
+
+    def validate(
+        self,
+        query: str,
+        uid: int,
+        engine_id: int,
+        raw_tokens: List[QueryValidationResult] = None,
+        **kwargs,
+    ) -> List[QueryValidationResult]:
+        if raw_tokens is None:
+            raw_tokens = self._tokenize_query(query)
+        validation_results = self._validator.validate(
+            query, uid, engine_id, raw_tokens=raw_tokens
+        )
+        for result in validation_results:
+            self._suggest_table_name_if_needed(result)
+        return validation_results
diff --git a/querybook/server/lib/query_analysis/validation/validators/presto_explain_validator.py b/querybook/server/lib/query_analysis/validation/validators/presto_explain_validator.py
index 3a0736b9f..a307921d5 100644
--- a/querybook/server/lib/query_analysis/validation/validators/presto_explain_validator.py
+++ b/querybook/server/lib/query_analysis/validation/validators/presto_explain_validator.py
@@ -72,6 +72,7 @@ def validate(
         query: str,
         uid: int,  # who is doing the syntax check
         engine_id: int,  # which engine they are checking against
+        **kwargs,
     ) -> List[QueryValidationResult]:
         validation_errors = []
         (
diff --git a/querybook/server/lib/query_analysis/validation/validators/presto_optimizing_validator.py b/querybook/server/lib/query_analysis/validation/validators/presto_optimizing_validator.py
index 5d726a4ff..6cf87de45 100644
--- a/querybook/server/lib/query_analysis/validation/validators/presto_optimizing_validator.py
+++ b/querybook/server/lib/query_analysis/validation/validators/presto_optimizing_validator.py
@@ -1,7 +1,8 @@
-from typing import List
+import re
 from sqlglot import TokenType, Tokenizer
 from sqlglot.dialects import Trino
 from sqlglot.tokens import Token
+from typing import List

 from lib.query_analysis.validation.base_query_validator import (
     BaseQueryValidator,
@@ -12,17 +13,24 @@
     PrestoExplainValidator,
 )
 from lib.query_analysis.validation.validators.base_sqlglot_validator import (
-    BaseSQLGlotValidator,
+    BaseSQLGlotDecorator,
+)
+from lib.query_analysis.validation.validators.metadata_suggesters import (
+    BaseColumnNameSuggester,
+    BaseTableNameSuggester,
 )


-class BasePrestoSQLGlotValidator(BaseSQLGlotValidator):
+class BasePrestoSQLGlotDecorator(BaseSQLGlotDecorator):
+    def languages(self):
+        return ["presto", "trino"]
+
     @property
     def tokenizer(self) -> Tokenizer:
         return Trino.Tokenizer()


-class UnionAllValidator(BasePrestoSQLGlotValidator):
+class UnionAllValidator(BasePrestoSQLGlotDecorator):
     @property
     def message(self):
         return "Using UNION ALL instead of UNION will execute faster"
@@ -31,27 +39,34 @@ def message(self):
     def severity(self) -> str:
         return QueryValidationSeverity.WARNING

-    def get_query_validation_results(
-        self, query: str, raw_tokens: List[Token] = None
+    def validate(
+        self,
+        query: str,
+        uid: int,
+        engine_id: int,
+        raw_tokens: List[Token] = None,
+        **kwargs,
     ) -> List[QueryValidationResult]:
         if raw_tokens is None:
             raw_tokens = self._tokenize_query(query)
-        validation_errors = []
+        validation_results = self._validator.validate(
+            query, uid, engine_id, raw_tokens=raw_tokens
+        )
         for i, token in enumerate(raw_tokens):
             if token.token_type == TokenType.UNION:
                 if (
                     i < len(raw_tokens) - 1
                     and raw_tokens[i + 1].token_type != TokenType.ALL
                 ):
-                    validation_errors.append(
+                    validation_results.append(
                         self._get_query_validation_result(
                             query, token.start, token.end, "UNION ALL"
                         )
                     )
-        return validation_errors
+        return validation_results


-class ApproxDistinctValidator(BasePrestoSQLGlotValidator):
+class ApproxDistinctValidator(BasePrestoSQLGlotDecorator):
     @property
     def message(self):
         return (
@@ -62,13 +77,20 @@ def message(self):
     def severity(self) -> str:
         return QueryValidationSeverity.WARNING

-    def get_query_validation_results(
-        self, query: str, raw_tokens: List[Token] = None
+    def validate(
+        self,
+        query: str,
+        uid: int,
+        engine_id: int,
+        raw_tokens: List[Token] = None,
+        **kwargs,
     ) -> List[QueryValidationResult]:
         if raw_tokens is None:
             raw_tokens = self._tokenize_query(query)

-        validation_errors = []
+        validation_results = self._validator.validate(
+            query, uid, engine_id, raw_tokens=raw_tokens
+        )
         for i, token in enumerate(raw_tokens):
             if (
                 i < len(raw_tokens) - 2
@@ -77,7 +99,7 @@ def get_query_validation_results(
                 and raw_tokens[i + 1].token_type == TokenType.L_PAREN
                 and raw_tokens[i + 2].token_type == TokenType.DISTINCT
             ):
-                validation_errors.append(
+                validation_results.append(
                     self._get_query_validation_result(
                         query,
                         token.start,
@@ -85,10 +107,10 @@ def get_query_validation_results(
                         "APPROX_DISTINCT(",
                     )
                 )
-        return validation_errors
+        return validation_results


-class RegexpLikeValidator(BasePrestoSQLGlotValidator):
+class RegexpLikeValidator(BasePrestoSQLGlotDecorator):
     @property
     def message(self):
         return "Combining multiple LIKEs into one REGEXP_LIKE will execute faster"
@@ -103,13 +125,20 @@ def _get_regexp_like_suggestion(self, column_name: str, like_strings: List[str])
         ]
         return f"REGEXP_LIKE({column_name}, '{'|'.join(sanitized_like_strings)}')"

-    def get_query_validation_results(
-        self, query: str, raw_tokens: List[Token] = None
+    def validate(
+        self,
+        query: str,
+        uid: int,
+        engine_id: int,
+        raw_tokens: List[Token] = None,
+        **kwargs,
     ) -> List[QueryValidationResult]:
         if raw_tokens is None:
             raw_tokens = self._tokenize_query(query)
-        validation_errors = []
+        validation_results = self._validator.validate(
+            query, uid, engine_id, raw_tokens=raw_tokens
+        )

         start_column_token = None
         like_strings = []
@@ -139,7 +168,7 @@ def get_query_validation_results(
                 ):  # No "OR" token following the phrase, so we cannot combine additional phrases
                     # Check if there are multiple phrases that can be combined
                     if len(like_strings) > 1:
-                        validation_errors.append(
+                        validation_results.append(
                             self._get_query_validation_result(
                                 query,
                                 start_column_token.start,
@@ -157,7 +186,7 @@ def get_query_validation_results(
                 if (
                     len(like_strings) > 1
                 ):  # Check if a validation suggestion can be created
-                    validation_errors.append(
+                    validation_results.append(
                         self._get_query_validation_result(
                             query,
                             start_column_token.start,
@@ -171,53 +200,49 @@ def get_query_validation_results(
                 like_strings = []
             token_idx += 1

-        return validation_errors
+        return validation_results
+
+
+class PrestoColumnNameSuggester(BasePrestoSQLGlotDecorator, BaseColumnNameSuggester):
+    def get_column_name_from_error(self, validation_result: QueryValidationResult):
+        regex_result = re.match(
+            r"line \d+:\d+: Column '(.*)' cannot be resolved", validation_result.message
+        )
+        return regex_result.groups()[0] if regex_result else None
+
+
+class PrestoTableNameSuggester(BasePrestoSQLGlotDecorator, BaseTableNameSuggester):
+    def get_full_table_name_from_error(self, validation_result: QueryValidationResult):
+        regex_result = re.match(
+            r"line \d+:\d+: Table '(.*)' does not exist", validation_result.message
+        )
+        return regex_result.groups()[0] if regex_result else None


 class PrestoOptimizingValidator(BaseQueryValidator):
     def languages(self):
         return ["presto", "trino"]

+    @property
+    def tokenizer(self) -> Tokenizer:
+        return Trino.Tokenizer()
+
     def _get_explain_validator(self):
         return PrestoExplainValidator("")

-    def _get_sqlglot_validators(self) -> List[BaseSQLGlotValidator]:
-        return [
-            UnionAllValidator(),
-            ApproxDistinctValidator(),
-            RegexpLikeValidator(),
-        ]
-
-    def _get_sql_glot_validation_results(
-        self, query: str
-    ) -> List[QueryValidationResult]:
-        validation_suggestions = []
-
-        query_raw_tokens = None
-        for validator in self._get_sqlglot_validators():
-            if query_raw_tokens is None:
-                query_raw_tokens = validator._tokenize_query(query)
-            validation_suggestions.extend(
-                validator.get_query_validation_results(
-                    query, raw_tokens=query_raw_tokens
+    def _get_decorated_validator(self) -> BaseQueryValidator:
+        return UnionAllValidator(
+            ApproxDistinctValidator(
+                RegexpLikeValidator(
+                    PrestoTableNameSuggester(
+                        PrestoColumnNameSuggester(self._get_explain_validator())
+                    )
                 )
             )
-
-        return validation_suggestions
-
-    def _get_presto_explain_validation_results(
-        self, query: str, uid: int, engine_id: int
-    ) -> List[QueryValidationResult]:
-        return self._get_explain_validator().validate(query, uid, engine_id)
+        )

     def validate(
-        self,
-        query: str,
-        uid: int,
-        engine_id: int,
+        self, query: str, uid: int, engine_id: int, **kwargs
     ) -> List[QueryValidationResult]:
-        validation_results = [
-            *self._get_presto_explain_validation_results(query, uid, engine_id),
-            *self._get_sql_glot_validation_results(query),
-        ]
-        return validation_results
+        validator = self._get_decorated_validator()
+        return validator.validate(query, uid, engine_id)
diff --git a/querybook/tests/test_lib/test_query_analysis/test_validation/test_validators/test_presto_optimizing_validator.py b/querybook/tests/test_lib/test_query_analysis/test_validation/test_validators/test_presto_optimizing_validator.py
index 0abb7aa51..93a98bbb6 100644
--- a/querybook/tests/test_lib/test_query_analysis/test_validation/test_validators/test_presto_optimizing_validator.py
+++ b/querybook/tests/test_lib/test_query_analysis/test_validation/test_validators/test_presto_optimizing_validator.py
@@ -1,5 +1,6 @@
 from typing import List
 from unittest import TestCase
+from unittest.mock import patch, MagicMock

 from lib.query_analysis.validation.base_query_validator import (
     QueryValidationResult,
@@ -8,6 +9,8 @@
 )
 from lib.query_analysis.validation.validators.presto_optimizing_validator import (
     ApproxDistinctValidator,
+    PrestoColumnNameSuggester,
+    PrestoTableNameSuggester,
     RegexpLikeValidator,
     UnionAllValidator,
     PrestoOptimizingValidator,
@@ -15,6 +18,11 @@


 class BaseValidatorTestCase(TestCase):
+    def _get_explain_validator_mock(self):
+        explain_validator_mock = MagicMock()
+        explain_validator_mock.validate.return_value = []
+        return explain_validator_mock
+
     def _verify_query_validation_results(
         self,
         validation_results: List[QueryValidationResult],
@@ -75,12 +83,12 @@

 class UnionAllValidatorTestCase(BaseValidatorTestCase):
     def setUp(self):
-        self._validator = UnionAllValidator()
+        self._validator = UnionAllValidator(self._get_explain_validator_mock())

     def test_basic_union(self):
         query = "SELECT * FROM a \nUNION SELECT * FROM b"
         self._verify_query_validation_results(
-            self._validator.get_query_validation_results(query),
+            self._validator.validate(query, 0, 0),
             [
                 self._get_union_all_validation_result(
                     1,
@@ -94,7 +102,7 @@ def test_basic_union(self):
     def test_multiple_unions(self):
         query = "SELECT * FROM a \nUNION SELECT * FROM b \nUNION SELECT * FROM c"
         self._verify_query_validation_results(
-            self._validator.get_query_validation_results(query),
+            self._validator.validate(query, 0, 0),
             [
                 self._get_union_all_validation_result(
                     1,
@@ -113,26 +121,24 @@ def test_multiple_unions(self):
     def test_union_all(self):
         query = "SELECT * FROM a UNION ALL SELECT * FROM b"
-        self._verify_query_validation_results(
-            self._validator.get_query_validation_results(query), []
-        )
+        self._verify_query_validation_results(self._validator.validate(query, 0, 0), [])


 class ApproxDistinctValidatorTestCase(BaseValidatorTestCase):
     def setUp(self):
-        self._validator = ApproxDistinctValidator()
+        self._validator = ApproxDistinctValidator(self._get_explain_validator_mock())

     def test_basic_count_distinct(self):
         query = "SELECT COUNT(DISTINCT x) FROM a"
         self._verify_query_validation_results(
-            self._validator.get_query_validation_results(query),
+            self._validator.validate(query, 0, 0),
             [self._get_approx_distinct_validation_result(0, 7, 0, 20)],
         )

     def test_count_not_followed_by_distinct(self):
         query = "SELECT \nCOUNT * FROM a"
         self._verify_query_validation_results(
-            self._validator.get_query_validation_results(query),
+            self._validator.validate(query, 0, 0),
             [],
         )

@@ -141,7 +147,7 @@ def test_multiple_count_distincts(self):
             "SELECT \nCOUNT(DISTINCT y) FROM a UNION SELECT \nCOUNT(DISTINCT x) FROM b"
         )
         self._verify_query_validation_results(
-            self._validator.get_query_validation_results(query),
+            self._validator.validate(query, 0, 0),
             [
                 self._get_approx_distinct_validation_result(1, 0, 1, 13),
                 self._get_approx_distinct_validation_result(2, 0, 2, 13),
@@ -153,7 +159,7 @@ def test_count_distinct_in_where_clause(self):
             "SELECT \nCOUNT(DISTINCT a), b FROM table_a WHERE \nCOUNT(DISTINCT a) > 10"
         )
         self._verify_query_validation_results(
-            self._validator.get_query_validation_results(query),
+            self._validator.validate(query, 0, 0),
             [
                 self._get_approx_distinct_validation_result(1, 0, 1, 13),
                 self._get_approx_distinct_validation_result(2, 0, 2, 13),
@@ -163,12 +169,12 @@

 class RegexpLikeValidatorTestCase(BaseValidatorTestCase):
     def setUp(self):
-        self._validator = RegexpLikeValidator()
+        self._validator = RegexpLikeValidator(self._get_explain_validator_mock())

     def test_basic_combine_case(self):
         query = "SELECT * from a WHERE \nx LIKE 'foo' OR x LIKE \n'bar'"
         self._verify_query_validation_results(
-            self._validator.get_query_validation_results(query),
+            self._validator.validate(query, 0, 0),
             [
                 self._get_regexp_like_validation_result(
                     1, 0, 2, 4, "REGEXP_LIKE(x, 'foo|bar')"
@@ -179,14 +185,14 @@
     def test_and_clause(self):
         query = "SELECT * from a WHERE \nx LIKE 'foo%' AND x LIKE \n'%bar'"
         self._verify_query_validation_results(
-            self._validator.get_query_validation_results(query),
+            self._validator.validate(query, 0, 0),
             [],
         )

     def test_more_than_two_phrases(self):
         query = "SELECT * from a WHERE \nx LIKE 'foo' OR x LIKE 'bar' OR x LIKE \n'baz'"
         self._verify_query_validation_results(
-            self._validator.get_query_validation_results(query),
+            self._validator.validate(query, 0, 0),
             [
                 self._get_regexp_like_validation_result(
                     1, 0, 2, 4, "REGEXP_LIKE(x, 'foo|bar|baz')"
@@ -197,7 +203,7 @@
     def test_different_column_names(self):
         query = "SELECT * from a WHERE \nx LIKE 'foo' OR y LIKE 'bar'"
         self._verify_query_validation_results(
-            self._validator.get_query_validation_results(query),
+            self._validator.validate(query, 0, 0),
             [],
         )

@@ -206,7 +212,7 @@ def test_both_or_and(self):
             "SELECT * from a WHERE \nx LIKE 'foo' OR x LIKE \n'bar' AND y LIKE 'foo'"
         )
         self._verify_query_validation_results(
-            self._validator.get_query_validation_results(query),
+            self._validator.validate(query, 0, 0),
             [
                 self._get_regexp_like_validation_result(
                     1, 0, 2, 4, "REGEXP_LIKE(x, 'foo|bar')"
@@ -217,7 +223,7 @@
     def test_multiple_suggestions(self):
         query = "SELECT * from a WHERE \nx LIKE 'foo' OR x LIKE \n'bar' AND \ny LIKE 'foo' OR y LIKE \n'bar'"
         self._verify_query_validation_results(
-            self._validator.get_query_validation_results(query),
+            self._validator.validate(query, 0, 0),
             [
                 self._get_regexp_like_validation_result(
                     1, 0, 2, 4, "REGEXP_LIKE(x, 'foo|bar')"
@@ -231,72 +237,288 @@
     def test_phrase_not_match(self):
         query = "SELECT * from a WHERE x LIKE 'foo' OR x = 'bar'"
         self._verify_query_validation_results(
-            self._validator.get_query_validation_results(query),
+            self._validator.validate(query, 0, 0),
             [],
         )


+class PrestoColumnNameSuggesterTestCase(BaseValidatorTestCase):
+    def setUp(self):
+        self._validator = PrestoColumnNameSuggester(MagicMock())
+
+    def test_get_column_name_from_error(self):
+        self.assertEqual(
+            self._validator.get_column_name_from_error(
+                QueryValidationResult(
+                    0,
+                    0,
+                    QueryValidationSeverity.WARNING,
+                    "line 0:1: Column 'happyness' cannot be resolved",
+                )
+            ),
+            "happyness",
+        )
+        self.assertEqual(
+            self._validator.get_column_name_from_error(
+                QueryValidationResult(
+                    0,
+                    0,
+                    QueryValidationSeverity.WARNING,
+                    "line 0:1: Table 'world_happiness_rank' does not exist",
+                )
+            ),
+            None,
+        )
+
+    def test_search_columns_for_suggestion(self):
+        self.assertEqual(
+            self._validator._search_columns_for_suggestion(
+                ["HappinessRank", "Country", "Region"], "country"
+            ),
+            "Country",
+        )
+        self.assertEqual(
+            self._validator._search_columns_for_suggestion(
+                ["HappinessRank, Region"], "country"
+            ),
+            "country",
+        )
+
+    def _get_new_validation_result_obj(self):
+        return QueryValidationResult(
+            0,
+            7,
+            QueryValidationSeverity.WARNING,
+            "line 0:1: Column 'happynessrank' cannot be resolved",
+        )
+
+    @patch(
+        "lib.elasticsearch.search_table.get_column_name_suggestion",
+    )
+    def test__get_column_name_suggestion(self, mock_get_column_name_suggestion):
+        # Test too many tables matched
+        validation_result = self._get_new_validation_result_obj()
+        mock_get_column_name_suggestion.return_value = [
+            [
+                {
+                    "columns": ["HappinessRank"],
+                    "highlight": {"columns": ["happinessrank"]},
+                },
+                {
+                    "columns": ["HappinessRank"],
+                    "highlight": {"columns": ["happinessrank1"]},
+                },
+            ],
+            2,
+        ]
+        self._validator._suggest_column_name_if_needed(
+            validation_result,
+            ["main.world_happiness_report"],
+        )
+        self.assertEqual(validation_result.suggestion, None)
+
+        # Test too many columns in a table matched
+        validation_result = self._get_new_validation_result_obj()
+        mock_get_column_name_suggestion.return_value = [
+            [
+                {
+                    "columns": ["HappinessRank", "HappinessRank1"],
+                    "highlight": {"columns": ["happinessrank", "happinessrank1"]},
+                },
+            ],
+            1,
+        ]
+        self._validator._suggest_column_name_if_needed(
+            validation_result,
+            ["main.world_happiness_report"],
+        ),
+        self.assertEqual(
+            validation_result.suggestion,
+            None,
+        )
+
+        # Test single column matched
+        validation_result = self._get_new_validation_result_obj()
+        mock_get_column_name_suggestion.return_value = [
+            [
+                {
+                    "columns": ["HappinessRank", "HappinessRank1"],
+                    "highlight": {"columns": ["happinessrank"]},
+                },
+            ],
+            1,
+        ]
+        self._validator._suggest_column_name_if_needed(
+            validation_result,
+            ["main.world_happiness_report"],
+        ),
+        self.assertEqual(validation_result.suggestion, "HappinessRank")
+
+        # Test no search results
+        validation_result = self._get_new_validation_result_obj()
+        mock_get_column_name_suggestion.return_value = [
+            [],
+            0,
+        ]
+        self._validator._suggest_column_name_if_needed(
+            validation_result,
+            ["main.world_happiness_report"],
+        ),
+        self.assertEqual(
+            validation_result.suggestion,
+            None,
+        )
+
+
+class PrestoTableNameSuggesterTestCase(BaseValidatorTestCase):
+    def setUp(self):
+        self._validator = PrestoTableNameSuggester(MagicMock())
+
+    def test_get_full_table_name_from_error(self):
+        self.assertEqual(
+            self._validator.get_full_table_name_from_error(
+                QueryValidationResult(
+                    0,
+                    0,
+                    QueryValidationSeverity.WARNING,
+                    "line 0:1: Table 'world_happiness_15' does not exist",
+                )
+            ),
+            "world_happiness_15",
+        )
+        self.assertEqual(
+            self._validator.get_full_table_name_from_error(
+                QueryValidationResult(
+                    0,
+                    0,
+                    QueryValidationSeverity.WARNING,
+                    "line 0:1: column 'happiness_rank' cannot be resolved",
+                )
+            ),
+            None,
+        )
+
+    @patch(
+        "lib.elasticsearch.search_table.get_table_name_suggestion",
+    )
+    def test__suggest_table_name_if_needed_single_hit(self, mock_table_suggestion):
+        validation_result = QueryValidationResult(
+            0,
+            0,
+            QueryValidationSeverity.WARNING,
+            "line 0:1: Table 'world_happiness_15' does not exist",
+        )
+        mock_table_suggestion.return_value = [
+            {"schema": "main", "name": "world_happiness_rank_2015"}
+        ], 1
+        self._validator._suggest_table_name_if_needed(validation_result)
+        self.assertEqual(
+            validation_result.suggestion, "main.world_happiness_rank_2015"
+        )
+
+    @patch(
+        "lib.elasticsearch.search_table.get_table_name_suggestion",
+    )
+    def test__suggest_table_name_if_needed_multiple_hits(self, mock_table_suggestion):
+        validation_result = QueryValidationResult(
+            0,
+            0,
+            QueryValidationSeverity.WARNING,
+            "line 0:1: Table 'world_happiness_15' does not exist",
+        )
+        mock_table_suggestion.return_value = [
+            {"schema": "main", "name": "world_happiness_rank_2015"},
+            {"schema": "main", "name": "world_happiness_rank_2016"},
+        ], 2
+        self._validator._suggest_table_name_if_needed(validation_result)
+        self.assertEqual(
+            validation_result.suggestion, "main.world_happiness_rank_2015"
+        )
+
+    @patch(
+        "lib.elasticsearch.search_table.get_table_name_suggestion",
+    )
+    def test__suggest_table_name_if_needed_no_hits(self, mock_table_suggestion):
+        validation_result = QueryValidationResult(
+            0,
+            0,
+            QueryValidationSeverity.WARNING,
+            "line 0:1: Table 'world_happiness_15' does not exist",
+        )
+        mock_table_suggestion.return_value = [], 0
+        self._validator._suggest_table_name_if_needed(validation_result)
+        self.assertEqual(validation_result.suggestion, None)
+
+
 class PrestoOptimizingValidatorTestCase(BaseValidatorTestCase):
     def setUp(self):
+        super(PrestoOptimizingValidatorTestCase, self).setUp()
+        patch_validator = patch.object(
+            PrestoColumnNameSuggester,
+            "validate",
+            return_value=[],
+        )
+        patch_validator.start()
+        self.addCleanup(patch_validator.stop)
         self._validator = PrestoOptimizingValidator("")

     def test_union_and_count_distinct(self):
         query = "SELECT \nCOUNT( DISTINCT x) from a \nUNION select \ncount(distinct y) from b"
         self._verify_query_validation_results(
-            self._validator._get_sql_glot_validation_results(query),
+            self._validator.validate(query, 0, 0),
             [
-                self._get_union_all_validation_result(2, 0, 2, 4),
                 self._get_approx_distinct_validation_result(1, 0, 1, 14),
                 self._get_approx_distinct_validation_result(3, 0, 3, 13),
+                self._get_union_all_validation_result(2, 0, 2, 4),
             ],
         )

     def test_union_and_regexp_like(self):
         query = "SELECT * from a WHERE \nx like 'foo' or x like \n'bar' \nUNION select * from b where y like 'foo' AND x like 'bar'"
         self._verify_query_validation_results(
-            self._validator._get_sql_glot_validation_results(query),
+            self._validator.validate(query, 0, 0),
             [
-                self._get_union_all_validation_result(3, 0, 3, 4),
                 self._get_regexp_like_validation_result(
                     1, 0, 2, 4, "REGEXP_LIKE(x, 'foo|bar')"
                 ),
+                self._get_union_all_validation_result(3, 0, 3, 4),
             ],
         )

     def test_count_distinct_and_regexp_like(self):
         query = "SELECT \nCOUNT( DISTINCT x) from a WHERE \nx LIKE 'foo' or x like \n'bar' and y like 'foo'"
         self._verify_query_validation_results(
-            self._validator._get_sql_glot_validation_results(query),
+            self._validator.validate(query, 0, 0),
             [
-                self._get_approx_distinct_validation_result(1, 0, 1, 14),
                 self._get_regexp_like_validation_result(
                     2, 0, 3, 4, "REGEXP_LIKE(x, 'foo|bar')"
                 ),
+                self._get_approx_distinct_validation_result(1, 0, 1, 14),
             ],
         )

     def test_all_errors(self):
         query = "SELECT \nCOUNT( DISTINCT x) from a WHERE \nx LIKE 'foo' or x like \n'bar' and y like 'foo' \nUNION select * from b"
         self._verify_query_validation_results(
-            self._validator._get_sql_glot_validation_results(query),
+            self._validator.validate(query, 0, 0),
             [
-                self._get_union_all_validation_result(4, 0, 4, 4),
-                self._get_approx_distinct_validation_result(1, 0, 1, 14),
                 self._get_regexp_like_validation_result(
                     2, 0, 3, 4, "REGEXP_LIKE(x, 'foo|bar')"
                 ),
+                self._get_approx_distinct_validation_result(1, 0, 1, 14),
+                self._get_union_all_validation_result(4, 0, 4, 4),
             ],
         )

     def test_extra_whitespace(self):
         query = "SELECT \n COUNT( DISTINCT x) from a WHERE \n\t x LIKE 'foo' or x like \n'bar' and y like 'foo' \n UNION select * from b"
         self._verify_query_validation_results(
-            self._validator._get_sql_glot_validation_results(query),
+            self._validator.validate(query, 0, 0),
             [
-                self._get_union_all_validation_result(4, 5, 4, 9),
-                self._get_approx_distinct_validation_result(1, 2, 1, 16),
                 self._get_regexp_like_validation_result(
                     2, 3, 3, 4, "REGEXP_LIKE(x, 'foo|bar')"
                 ),
+                self._get_approx_distinct_validation_result(1, 2, 1, 16),
+                self._get_union_all_validation_result(4, 5, 4, 9),
             ],
         )
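
Usage note (not part of the patch): a minimal sketch of how the decorated validator chain added above is exercised end to end. It assumes a running Querybook server environment with a configured Presto/Trino engine; the uid and engine_id values below are placeholders.

    # Illustrative only -- mirrors the wiring in PrestoOptimizingValidator._get_decorated_validator:
    # each decorator first delegates to the validator it wraps (ending at the EXPLAIN-based
    # PrestoExplainValidator), then appends its own results or annotates suggestions on the way out.
    from lib.query_analysis.validation.validators.presto_optimizing_validator import (
        PrestoOptimizingValidator,
    )

    validator = PrestoOptimizingValidator("")
    results = validator.validate(
        "SELECT * FROM a UNION SELECT * FROM b", uid=1, engine_id=1  # placeholder ids
    )
    for result in results:
        print(result.start_line, result.start_ch, result.suggestion)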