From 3d63ff17f0925b4982d0b78442699c52970982a3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Srokosz?= Date: Thu, 1 Feb 2024 16:10:57 +0100 Subject: [PATCH] Next steps --- mwdb/core/search/__init__.py | 8 +- mwdb/core/search/exceptions.py | 43 ++++-- mwdb/core/search/fields.py | 178 ++++++++++-------------- mwdb/core/search/mappings.py | 76 +++-------- mwdb/core/search/query_builder.py | 35 +++-- mwdb/core/search/search.py | 217 ------------------------------ mwdb/core/search/tree.py | 7 - mwdb/resources/object.py | 19 +-- mwdb/resources/search.py | 10 +- 9 files changed, 161 insertions(+), 432 deletions(-) delete mode 100644 mwdb/core/search/search.py delete mode 100644 mwdb/core/search/tree.py diff --git a/mwdb/core/search/__init__.py b/mwdb/core/search/__init__.py index 5f862dc80..4de3c5138 100644 --- a/mwdb/core/search/__init__.py +++ b/mwdb/core/search/__init__.py @@ -1,7 +1,7 @@ -from .exceptions import SQLQueryBuilderBaseException -from .search import SQLQueryBuilder +from .exceptions import QueryBaseException +from .query_builder import build_query __all__ = [ - "SQLQueryBuilderBaseException", - "SQLQueryBuilder", + "QueryBaseException", + "build_query", ] diff --git a/mwdb/core/search/exceptions.py b/mwdb/core/search/exceptions.py index 635b43aaf..601e956b2 100644 --- a/mwdb/core/search/exceptions.py +++ b/mwdb/core/search/exceptions.py @@ -1,29 +1,54 @@ -class SQLQueryBuilderBaseException(Exception): +from typing import Optional, Tuple, Type + +from luqum.tree import Item + + +class QueryBaseException(Exception): """ Base exception for SQLQueryBuilder """ -class UnsupportedGrammarException(SQLQueryBuilderBaseException): +class QueryParseException(QueryBaseException): """ - Raised when SQLQueryBuilder does not support given Lucene grammar + Raised when Lucene parser is unable to parse a query """ -class FieldNotQueryableException(SQLQueryBuilderBaseException): +class UnsupportedNodeException(QueryBaseException): + def __init__(self, message: str, node: Item): + super().__init__(f"{message} ({node.pos}:{node.pos + node.size - 1})") + + +class UnsupportedNodeTypeException(UnsupportedNodeException): """ - Raised when field does not exists, so it can't be queried, eg. file.unexistent_field + Raised when query visitor does not support given Lucene grammar """ + def __init__(self, node: Item, expected: Optional[Tuple[Type, ...]] = None): + message = f"{node.__class__.__name__} is not supported here" + if expected: + message += f", expected {', '.join(typ.__name__ for typ in expected)}" + super().__init__(message, node) -class MultipleObjectsQueryException(SQLQueryBuilderBaseException): + +class UnsupportedLikeStatement(UnsupportedNodeException): + def __init__(self, node: Item): + super().__init__("Like statements are not supported here", node) + + +class InvalidValueException(QueryBaseException): + def __init__(self, value: str, expected: str): + super().__init__(f"Invalid value format: {value}, expected {expected}") + + +class FieldNotQueryableException(QueryBaseException): """ - Raised when multiple object types are queried, - e.g. file.file_name:something AND static.cfg:something2 + Raised when field does not exists, so it can't be queried, eg. file.unexistent_field """ -class ObjectNotFoundException(SQLQueryBuilderBaseException): +class ObjectNotFoundException(QueryBaseException): """ Raised when object referenced in query condition can't be found """ diff --git a/mwdb/core/search/fields.py b/mwdb/core/search/fields.py index fd6db046d..3ac47961b 100644 --- a/mwdb/core/search/fields.py +++ b/mwdb/core/search/fields.py @@ -1,12 +1,12 @@ import re import uuid from datetime import datetime, timedelta, timezone -from typing import Any, List, Optional, Tuple, Type +from typing import Any, Optional, Tuple, Type from dateutil.relativedelta import relativedelta from flask import g -from luqum.tree import Item, Phrase, Range, Term, Word -from sqlalchemy import String, and_, cast, func, or_ +from luqum.tree import FieldGroup, Item, Phrase, Range, Term, Word +from sqlalchemy import and_, any_, or_ from mwdb.core.capabilities import Capabilities from mwdb.model import ( @@ -25,8 +25,10 @@ from .exceptions import ( FieldNotQueryableException, + InvalidValueException, ObjectNotFoundException, - UnsupportedGrammarException, + UnsupportedLikeStatement, + UnsupportedNodeTypeException, ) from .parse_helpers import ( PathSelector, @@ -36,7 +38,6 @@ range_equals, string_equals, ) -from .tree import Subquery def string_from_node(node: Item, escaped: bool = False) -> str: @@ -46,24 +47,24 @@ def string_from_node(node: Item, escaped: bool = False) -> str: # Remove quotes from the beginning and the end of Phrase return node.value[1:-1] if escaped else node.unescaped_value[1:-1] else: - raise UnsupportedGrammarException(...) + raise UnsupportedNodeTypeException(node) def range_from_node(node: Item) -> Tuple[Optional[str], Optional[str], bool, bool]: if not isinstance(node, Range): - raise UnsupportedGrammarException(...) + raise UnsupportedNodeTypeException(node) low_value = string_from_node(node.low) if low_value == "*": low_value = None elif has_wildcard(low_value): - raise UnsupportedGrammarException(...) + raise UnsupportedLikeStatement(node.low) high_value = string_from_node(node.high) if high_value == "*": high_value = None elif has_wildcard(high_value): - raise UnsupportedGrammarException(...) + raise UnsupportedLikeStatement(node.high) return low_value, high_value, node.include_low, node.include_high @@ -71,29 +72,31 @@ def range_from_node(node: Item) -> Tuple[Optional[str], Optional[str], bool, boo class BaseField: accepts_subpath = False - def __init__(self, column): - self.column = column - - @property - def column_type(self) -> Type[Object]: - return self.column.class_ - def _get_condition(self, value: Item, path_selector: PathSelector) -> Any: raise NotImplementedError def get_condition(self, value: Item, path_selector: PathSelector) -> Any: if not self.accepts_subpath and len(path_selector) > 1: - raise UnsupportedGrammarException(...) + raise FieldNotQueryableException("Subfields are not allowed for this field") return self._get_condition(value, path_selector) -class StringField(BaseField): +class ColumnField(BaseField): + def __init__(self, column): + self.column = column + + @property + def column_type(self) -> Type[Object]: + return self.column.class_ + + +class StringField(ColumnField): def _get_condition(self, value: Item, path_selector: PathSelector) -> Any: string_value = string_from_node(value, escaped=True) return string_equals(self.column, string_value) -class SizeField(BaseField): +class SizeField(ColumnField): def _get_condition(self, value: Item, path_selector: PathSelector) -> Any: units = {"B": 1, "KB": 1024, "MB": 1024**2, "GB": 1024**3} @@ -101,10 +104,10 @@ def parse_size(size) -> int: if re.match(r"^\d+$", size) is not None: return int(size) else: - size = re.match(r"(\d+(?:[.]\d+)?)[ ]?([KMGT]?B)", size.upper()) - if size is None: - raise UnsupportedGrammarException("Invalid size value") - number, unit = size.groups() + size_match = re.match(r"(\d+(?:[.]\d+)?)[ ]?([KMGT]?B)", size.upper()) + if size_match is None: + raise InvalidValueException(size, expected="size") + number, unit = size_match.groups() return int(float(number) * units[unit]) if isinstance(value, Range): @@ -120,7 +123,7 @@ def parse_size(size) -> int: return self.column == target_value -class StringListField(BaseField): +class StringListField(ColumnField): def __init__(self, column, value_column): super().__init__(column) self.value_column = value_column @@ -130,7 +133,7 @@ def _get_condition(self, value: Item, path_selector: PathSelector) -> Any: return self.column.any(string_equals(self.value_column, string_value)) -class UUIDField(BaseField): +class UUIDField(ColumnField): def __init__(self, column, value_column): super().__init__(column) self.value_column = value_column @@ -143,7 +146,7 @@ def _get_condition(self, value: Item, path_selector: PathSelector) -> Any: try: uuid_value = uuid.UUID(string_value) except ValueError: - raise UnsupportedGrammarException("Field accepts only correct UUID values") + raise InvalidValueException(string_value, expected="UUID") return self.column.any(self.value_column == uuid_value) @@ -167,7 +170,7 @@ def _get_condition(self, value: Item, path_selector: PathSelector) -> Any: if user is None: raise ObjectNotFoundException(f"No such user: {value}") - return self.column.any(User.id == user.id) + return Object.followers.any(User.id == user.id) class AttributeField(BaseField): @@ -175,9 +178,7 @@ class AttributeField(BaseField): def _get_condition(self, value: Item, path_selector: PathSelector) -> Any: if len(path_selector) <= 1: - raise UnsupportedGrammarException( - "Missing attribute key (attribute.:)" - ) + raise FieldNotQueryableException("Missing attribute key (attribute.:)") attribute_key, _ = path_selector[1] attribute_definition = AttributeDefinition.query_for_read( @@ -188,7 +189,7 @@ def _get_condition(self, value: Item, path_selector: PathSelector) -> Any: raise ObjectNotFoundException(f"No such attribute: {attribute_key}") if not isinstance(value, (Range, Term)): - raise UnsupportedGrammarException(...) + raise UnsupportedNodeTypeException(value) if ( attribute_definition.hidden @@ -207,7 +208,7 @@ def _get_condition(self, value: Item, path_selector: PathSelector) -> Any: else: string_value = string_from_node(value, escaped=True) jsonpath_condition = jsonpath_string_equals(path_selector[1:], string_value) - return self.column.any( + return Object.attributes.any( and_( Attribute.key == attribute_key, Attribute.value.op("@?")(jsonpath_condition), @@ -229,7 +230,7 @@ def _get_condition(self, value: Item, path_selector: PathSelector) -> Any: # Cfg values in database are escaped, so we need to escape search phrase too string_value = string_value.encode("unicode_escape").decode("utf-8") jsonpath_condition = jsonpath_string_equals(path_selector, string_value) - return self.column.op("@?")(jsonpath_condition) + return Config.cfg.op("@?")(jsonpath_condition) class ShareField(BaseField): @@ -246,7 +247,7 @@ def _get_condition(self, value: Item, path_selector: PathSelector) -> Any: ]: raise ObjectNotFoundException(f"No such group: {string_value}") - return self.column.any(ObjectPermission.group_id == group_id) + return Object.shares.any(ObjectPermission.group_id == group_id) class SharerField(BaseField): @@ -278,7 +279,7 @@ def _get_condition(self, value: Item, path_selector: PathSelector) -> Any: raise ObjectNotFoundException(f"No such user or group: {string_value}") uploader_ids = [u.id for u in uploaders] - return self.column.any( + return Object.shares.any( and_( ObjectPermission.get_shares_filter(include_inherited_uploads=False), ObjectPermission.related_user_id.in_(uploader_ids), @@ -320,7 +321,7 @@ def _get_condition(self, value: Item, path_selector: PathSelector) -> Any: raise ObjectNotFoundException(f"No such user or group: {string_value}") uploader_ids = [u.id for u in uploaders] - return self.column.any( + return Object.shares.any( and_( ObjectPermission.get_uploaders_filter(), ObjectPermission.related_user_id.in_(uploader_ids), @@ -328,7 +329,7 @@ def _get_condition(self, value: Item, path_selector: PathSelector) -> Any: ) -class DatetimeField(BaseField): +class DatetimeField(ColumnField): def _is_relative_time(self, expression_value): pattern = r"^(\d+[yYmWwDdHhMSs])+$" return re.search(pattern, expression_value) @@ -349,7 +350,7 @@ def _get_field_for_unit(self, unit): elif unit in ["S", "s"]: unit = "seconds" else: - raise UnsupportedGrammarException("Invalid date-time format") + return None return unit def _get_border_time(self, expression_value): @@ -358,6 +359,8 @@ def _get_border_time(self, expression_value): delta_dict = {} for value, unit in conditions: field = self._get_field_for_unit(unit) + if field is None: + raise InvalidValueException(expression_value, "date-time") if field not in delta_dict.keys(): delta_dict.update({field: int(value)}) border_time = datetime.now(tz=timezone.utc) - relativedelta(**delta_dict) @@ -376,9 +379,7 @@ def _get_date_range(self, date_string): except ValueError: continue else: - raise FieldNotQueryableException( - f"Unsupported date-time format ({date_string})" - ) + raise InvalidValueException(date_string, "date-time") def _get_condition(self, value: Item, path_selector: PathSelector) -> Any: if isinstance(value, Range): @@ -404,20 +405,24 @@ def _get_condition(self, value: Item, path_selector: PathSelector) -> Any: string_value = string_from_node(value) low_datetime, high_datetime = self._get_date_range(string_value) include_low = include_high = True - return range_equals(self.column, low_datetime, high_datetime, include_low, - include_high) + return range_equals( + self.column, low_datetime, high_datetime, include_low, include_high + ) -class RelationField(BaseField): +class RelationField(ColumnField): accepts_subquery = True - def _get_condition( - self, expression: Expression, subfields: List[Tuple[str, int]] - ) -> Any: - ... + def _get_condition(self, subquery: Item, path_selector: PathSelector) -> Any: + from .query_builder import QueryConditionVisitor + + if not isinstance(subquery, FieldGroup): + raise UnsupportedNodeTypeException(subquery) + condition = QueryConditionVisitor(Object).visit(subquery) + return self.column.any(Object.id.in_(condition)) -class CommentAuthorField(BaseField): +class CommentAuthorField(ColumnField): def __init__(self, column, value_column): super().__init__(column) self.value_column = value_column @@ -436,15 +441,12 @@ class UploadCountField(BaseField): def _get_condition(self, value: Item, path_selector: PathSelector) -> Any: def parse_upload_value(value): try: - value = int(value) - if value <= 0: + int_value = int(value) + if int_value <= 0: raise ValueError except ValueError: - raise UnsupportedGrammarException( - "Field upload_count accepts statements with " - "only correct positive integer values" - ) - return value + raise InvalidValueException(value, "positive integer value") + return int_value if isinstance(value, Range): low, high, include_low, include_high = range_from_node(value) @@ -452,14 +454,16 @@ def parse_upload_value(value): low = parse_upload_value(low) if high is not None: high = parse_upload_value(high) - return range_equals(self.column, low, high, include_low, include_high) + return range_equals( + Object.upload_count, low, high, include_low, include_high + ) else: string_value = string_from_node(value) upload_value = parse_upload_value(string_value) - return self.column == upload_value + return Object.upload_count == upload_value -class MultiField(BaseField): +class MultiField(ColumnField): @staticmethod def get_column(queried_type: Type[Object], value: str): if queried_type is File: @@ -490,54 +494,18 @@ def get_column(queried_type: Type[Object], value: str): f"{queried_type.__name__} is not valid data type" ) - def _get_condition( - self, expression: Expression, subfields: List[Tuple[str, int]] - ) -> Any: - string_column = ["TextBlob._content"] - json_column = ["Config._cfg"] - - value = get_term_value(expression).strip() - values_list = re.split("\\s+", value) - - condition = None - for value in values_list: - column = MultiField.get_column(self.column_type, value) - - if str(column) in string_column: - value = f"%{value}%" - value = add_escaping_for_like_statement(value) - condition = or_(condition, (column.like(value))) - elif str(column) in json_column: - value = f"%{value}%" - value = add_escaping_for_like_statement(value) - condition = or_(condition, (cast(column, String).like(value))) - else: - # hashes values - condition = or_(condition, (column == value)) - - return condition + def _get_condition(self, value: Item, path_selector: PathSelector) -> Any: + """ + TODO I don't know how to reimplement it right now + """ + raise NotImplementedError class FileNameField(BaseField): accepts_wildcards = True - def _get_condition( - self, expression: Expression, subfields: List[Tuple[str, int]] - ) -> Any: - value = get_term_value(expression) - - if expression.has_wildcard(): - sub_query = db.session.query( - File.id.label("f_id"), func.unnest(File.alt_names).label("alt_name") - ).subquery() - value = add_escaping_for_like_statement(value) - file_id_matching = ( - db.session.query(File.id) - .join(sub_query, sub_query.c.f_id == File.id) - .filter(sub_query.c.alt_name.like(value)) - ) - - condition = or_(self.column.like(value), File.id.in_(file_id_matching)) - else: - condition = or_((self.column == value), File.alt_names.any(value)) - return condition + def _get_condition(self, value: Item, path_selector: PathSelector) -> Any: + string_value = string_from_node(value, escaped=True) + name_condition = string_equals(File.file_name, string_value) + alt_names_condition = string_equals(any_(File.alt_names), string_value) + return or_(name_condition, alt_names_condition) diff --git a/mwdb/core/search/mappings.py b/mwdb/core/search/mappings.py index 39fa3fdfd..5bdd4e036 100644 --- a/mwdb/core/search/mappings.py +++ b/mwdb/core/search/mappings.py @@ -1,5 +1,4 @@ -import re -from typing import Dict, List, Tuple, Type +from typing import Dict, Tuple, Type from mwdb.model import ( Comment, @@ -12,7 +11,7 @@ User, ) -from .exceptions import FieldNotQueryableException, MultipleObjectsQueryException +from .exceptions import FieldNotQueryableException from .fields import ( AttributeField, BaseField, @@ -32,6 +31,7 @@ UploaderField, UUIDField, ) +from .parse_helpers import PathSelector, parse_field_path object_mapping: Dict[str, Type[Object]] = { "file": File, @@ -46,21 +46,21 @@ "dhash": StringField(Object.dhash), "tag": StringListField(Object.tags, Tag.tag), "comment": StringListField(Object.comments, Comment.comment), - "meta": AttributeField(Object.attributes), # legacy - "attribute": AttributeField(Object.attributes), - "shared": ShareField(Object.shares), - "sharer": SharerField(Object.shares), - "uploader": UploaderField(Object.shares), + "meta": AttributeField(), # legacy + "attribute": AttributeField(), + "shared": ShareField(), + "sharer": SharerField(), + "uploader": UploaderField(), "upload_time": DatetimeField(Object.upload_time), "parent": RelationField(Object.parents), "child": RelationField(Object.children), - "favorites": FavoritesField(Object.followers), + "favorites": FavoritesField(), "karton": UUIDField(Object.analyses, KartonAnalysis.id), "comment_author": CommentAuthorField(Object.comment_authors, User.login), - "upload_count": UploadCountField(Object.upload_count), + "upload_count": UploadCountField(), }, File.__name__: { - "name": FileNameField(File.file_name), + "name": FileNameField(), "size": SizeField(File.file_size), "type": StringField(File.file_type), "md5": StringField(File.md5), @@ -74,7 +74,7 @@ Config.__name__: { "type": StringField(Config.config_type), "family": StringField(Config.family), - "cfg": ConfigField(Config.cfg), + "cfg": ConfigField(), "multi": MultiField(Config.id), }, TextBlob.__name__: { @@ -89,60 +89,14 @@ } -def parse_field_path(field_path): - """ - Extract subfields from fields path with proper control character handling: - - - \\x - escaped character - - * - array element reference e.g. (array*:2) - - . - field separator - - " - quote for control character escaping - """ - fields = [""] - last_pos = 0 - - for match in re.finditer(r"\\.|[.]|[*]+(?:[.]|$)", field_path): - control_char = match.group(0) - control_char_pos, next_pos = match.span(0) - # Append remaining characters to the last field - fields[-1] = fields[-1] + field_path[last_pos:control_char_pos] - last_pos = next_pos - # Check control character - if control_char[0] == "\\": - # Escaped character - fields[-1] = fields[-1] + control_char[1] - elif control_char == ".": - # End of field - fields.append("") - elif control_char[0] == "*": - # Terminate field as a tuple with count of trailing asterisks - fields[-1] = (fields[-1], control_char.count("*")) - # End of field with trailing asterisks - if control_char[-1] == ".": - fields.append("") - - if len(field_path) > last_pos: - # Last field should not be a tuple at this point. If it is: something went wrong - assert type(fields[-1]) is str - fields[-1] = fields[-1] + field_path[last_pos:] - return [field if type(field) is tuple else (field, 0) for field in fields] - - def get_field_mapper( - queried_type: Type[Object], field_selector: str -) -> Tuple[BaseField, List[str]]: + queried_type: str, field_selector: str +) -> Tuple[BaseField, PathSelector]: field_path = parse_field_path(field_selector) field_name, asterisks = field_path[0] # Map object type selector if field_name in object_mapping: selected_type = object_mapping[field_name] - # Because object type selector determines queried type, we can't use specialized - # fields from different types in the same query - if not issubclass(selected_type, queried_type): - raise MultipleObjectsQueryException( - f"Can't search for objects with type '{selected_type.__name__}' " - f"and '{queried_type.__name__}' in the same query" - ) field_path = field_path[1:] else: selected_type = queried_type @@ -154,6 +108,6 @@ def get_field_mapper( elif field_name in field_mapping[Object.__name__]: field = field_mapping[Object.__name__][field_name] else: - raise FieldNotQueryableException(f"No such field: {field_name}") + raise FieldNotQueryableException(f"No such field {field_name}") return field, field_path diff --git a/mwdb/core/search/query_builder.py b/mwdb/core/search/query_builder.py index aa77125c4..6ebd84bc0 100644 --- a/mwdb/core/search/query_builder.py +++ b/mwdb/core/search/query_builder.py @@ -1,5 +1,7 @@ -from typing import Any, Dict +from typing import Any, Dict, Optional, Type +from luqum.exceptions import ParseError +from luqum.parser import parser from luqum.tree import ( AndOperation, BaseGroup, @@ -12,7 +14,10 @@ from luqum.visitor import TreeVisitor from sqlalchemy import and_, not_, or_ -from .exceptions import UnsupportedGrammarException +from mwdb.model import Object, db + +from .exceptions import QueryParseException, UnsupportedNodeTypeException +from .mappings import get_field_mapper # SQLAlchemy doesn't provide typings Condition = Any @@ -37,10 +42,7 @@ def visit(self, tree, context=None): return self.visit_iter(tree, context=context) def visit_unsupported(self, node: Item, context: SQLQueryBuilderContext): - raise UnsupportedGrammarException( - f"Lucene grammar element {node.__class__.__name__} " - f"is not supported here" - ) + raise UnsupportedNodeTypeException(node) class QueryConditionVisitor(QueryTreeVisitor): @@ -48,10 +50,16 @@ class QueryConditionVisitor(QueryTreeVisitor): Builds sqlalchemy condition from parsed Lucene query """ + def __init__(self, queried_type): + super().__init__() + self.queried_type = queried_type + def visit_search_field( - self, node: SearchField, context: SQLQueryBuilderContext + self, node: SearchField, _: SQLQueryBuilderContext ) -> Condition: - ... + field_mapper, path_selector = get_field_mapper(self.queried_type, node.name) + condition = field_mapper.get_condition(node.expr, path_selector) + return condition def visit_and_operation( self, node: AndOperation, context: SQLQueryBuilderContext @@ -75,3 +83,14 @@ def visit_group( self, node: BaseGroup, context: SQLQueryBuilderContext ) -> Condition: return self.visit(node.expr, context) + + +def build_query(query: str, queried_type: Optional[Type[Object]] = None): + try: + tree = parser.parse(query) + except ParseError as e: + raise QueryParseException(str(e)) from e + queried_type = queried_type or Object + condition_visitor = QueryConditionVisitor(queried_type) + condition = condition_visitor.visit(tree) + return db.session.query(queried_type).filter(condition) diff --git a/mwdb/core/search/search.py b/mwdb/core/search/search.py deleted file mode 100644 index 57f032db1..000000000 --- a/mwdb/core/search/search.py +++ /dev/null @@ -1,217 +0,0 @@ -from typing import Any, List, Optional, Type, TypeVar, Union - -from flask import g -from luqum.parser import parser -from luqum.tree import ( - AndOperation, - BaseGroup, - FieldGroup, - Item, - Not, - OrOperation, - Phrase, - Prohibit, - Range, - SearchField, - Term, - Word, -) -from luqum.utils import LuceneTreeVisitorV2 -from sqlalchemy import and_, not_, or_ -from sqlalchemy.orm import aliased - -from mwdb.model import Object, db - -from .exceptions import FieldNotQueryableException, UnsupportedGrammarException -from .mappings import get_field_mapper -from .tree import Subquery - -T = TypeVar("T", bound=Term) -# SQLAlchemy doesn't provide typings -Condition = Any - - -class SQLQueryBuilderContext: - def __init__(self, queried_type: Optional[Type[Object]] = None): - self.queried_type = queried_type or Object - self.field_mapper = None - - -class SQLQueryBuilder(LuceneTreeVisitorV2): - generic_visitor_method_name = "visit_unsupported" - - # Visitor methods for value nodes - - def visit_term( - self, node: T, parents: List[Item], context: SQLQueryBuilderContext - ) -> Union[T, Range]: - """ - Visitor for Term (Word and Phrase). - - checks if field is already set - - performs wildcard mapping and unescaping - - wildcards are not allowed inside ranges - - Returns mapped node - """ - if context.field_mapper is None: - raise FieldNotQueryableException( - "You have to specify field, check help for more information" - ) - - is_range_term = isinstance(parents[-1], Range) - - if node.has_wildcard() and is_range_term and str(node) != "*": - raise UnsupportedGrammarException( - "Wildcards other than * are not supported in range queries" - ) - - if context.field_mapper.accepts_range and not is_range_term: - if node.value.startswith(">="): - node.value = node.value[2:] - return Range( - low=node, high=Term("*"), include_low=True, include_high=False - ) - elif node.value.startswith(">"): - node.value = node.value[1:] - return Range( - low=node, high=Term("*"), include_low=False, include_high=False - ) - elif node.value.startswith("<="): - node.value = node.value[2:] - return Range( - low=Term("*"), high=node, include_low=False, include_high=True - ) - elif node.value.startswith("<"): - node.value = node.value[1:] - return Range( - low=Term("*"), high=node, include_low=False, include_high=False - ) - - return node - - def visit_word( - self, node: Word, parents: List[Item], context: SQLQueryBuilderContext - ) -> Word: - """ - Visitor for Word. Words are non-enquoted Terms. - """ - return self.visit_term(node, parents, context) - - def visit_phrase( - self, node: Phrase, parents: List[Item], context: SQLQueryBuilderContext - ) -> Phrase: - """ - Visitor for Phrase. Phrases are enquoted Terms. - """ - # Strip the " from start and end - node.value = node.value[1:-1] - return self.visit_term(node, parents, context) - - def visit_range( - self, node: Range, parents: List[Item], context: SQLQueryBuilderContext - ) -> Range: - """ - Visitor for Range - - inclusive [ TO ] - - exclusive { TO } - """ - if not context.field_mapper.accepts_range: - raise UnsupportedGrammarException( - "Range queries are not supported for this type of field" - ) - - node.low = self.visit(node.low, parents + [node], context) - node.high = self.visit(node.high, parents + [node], context) - return node - - # Visitor methods for fields - - def visit_search_field( - self, node: SearchField, parents: List[Item], context: SQLQueryBuilderContext - ) -> Condition: - field_mapper, name_remainder = get_field_mapper(context.queried_type, node.name) - - if field_mapper.column_type is not Object: - context.queried_type = field_mapper.column_type - - context.field_mapper = field_mapper - condition = field_mapper.get_condition( - self.visit(node.expr, parents + [node], context), name_remainder - ) - context.field_mapper = None - - return condition - - # Visitor methods for operators - - def visit_and_operation( - self, node: AndOperation, parents: List[Item], context: SQLQueryBuilderContext - ) -> Condition: - return and_( - *[ - self.visit(child_node, parents + [node], context) - for child_node in node.children - ] - ) - - def visit_or_operation( - self, node: OrOperation, parents: List[Item], context: SQLQueryBuilderContext - ) -> Condition: - return or_( - *[ - self.visit(child_node, parents + [node], context) - for child_node in node.children - ] - ) - - def visit_not( - self, node: Not, parents: List[Item], context: SQLQueryBuilderContext - ) -> Condition: - return not_(self.visit(node.a, parents + [node], context)) - - def visit_prohibit( - self, node: Prohibit, parents: List[Item], context: SQLQueryBuilderContext - ) -> Condition: - return not_(self.visit(node.a, parents + [node], context)) - - # Visitor methods for other elements - - def visit_group( - self, node: BaseGroup, parents: List[Item], context: SQLQueryBuilderContext - ) -> Condition: - return self.visit(node.expr, parents + [node], context) - - def visit_unsupported( - self, node: Item, parents: List[Item], context: SQLQueryBuilderContext - ): - raise UnsupportedGrammarException( - f"Lucene grammar element {node.__class__.__name__} " - f"is not supported in search" - ) - - def visit_field_group( - self, node: FieldGroup, parents: List[Item], context: SQLQueryBuilderContext - ) -> Subquery: - if context.field_mapper.accepts_subquery: - inner_context = SQLQueryBuilderContext() - condition = self.visit(node.expr, parents + [node], inner_context) - # Make aliased entity for inner query - relative = aliased(inner_context.queried_type, flat=True) - subquery = ( - db.session.query(relative.id) - .select_entity_from(relative) # Use aliased entity in subquery - .filter(condition) - .filter(g.auth_user.has_access_to_object(relative.id)) - ) - return Subquery(node.expr, subquery) - else: - raise UnsupportedGrammarException( - "Subqueries are not supported for this type of field" - ) - - # Main function - def build_query(self, query: str, queried_type: Optional[Type[Object]] = None): - context = SQLQueryBuilderContext(queried_type=queried_type) - tree = parser.parse(query) - condition = self.visit(tree, context=context) - return db.session.query(context.queried_type).filter(condition) diff --git a/mwdb/core/search/tree.py b/mwdb/core/search/tree.py deleted file mode 100644 index d4e235631..000000000 --- a/mwdb/core/search/tree.py +++ /dev/null @@ -1,7 +0,0 @@ -from luqum.tree import FieldGroup - - -class Subquery(FieldGroup): - def __init__(self, expr, subquery): - super().__init__(expr) - self.subquery = subquery diff --git a/mwdb/resources/object.py b/mwdb/resources/object.py index b03395ea9..b5bb93472 100644 --- a/mwdb/resources/object.py +++ b/mwdb/resources/object.py @@ -3,14 +3,13 @@ from flask import g, request from flask_restful import Resource -from luqum.parser import ParseError from werkzeug.exceptions import BadRequest, Forbidden, MethodNotAllowed, NotFound from mwdb.core.capabilities import Capabilities from mwdb.core.config import app_config from mwdb.core.plugins import hooks from mwdb.core.rate_limit import rate_limited_resource -from mwdb.core.search import SQLQueryBuilder, SQLQueryBuilderBaseException +from mwdb.core.search import QueryBaseException, build_query from mwdb.model import AttributeDefinition, Object, db from mwdb.model.tag import Tag from mwdb.schema.object import ( @@ -231,12 +230,8 @@ def get(self): query = obj["query"] if query: try: - db_query = SQLQueryBuilder().build_query( - query, queried_type=self.ObjectType - ) - except SQLQueryBuilderBaseException as e: - raise BadRequest(str(e)) - except ParseError as e: + db_query = build_query(query, queried_type=self.ObjectType) + except QueryBaseException as e: raise BadRequest(str(e)) else: db_query = db.session.query(self.ObjectType) @@ -430,12 +425,8 @@ def get(self, type): query = obj["query"] if query: try: - db_query = SQLQueryBuilder().build_query( - query, queried_type=get_type_from_str(type) - ) - except SQLQueryBuilderBaseException as e: - raise BadRequest(str(e)) - except ParseError as e: + db_query = build_query(query, queried_type=get_type_from_str(type)) + except QueryBaseException as e: raise BadRequest(str(e)) else: db_query = db.session.query(get_type_from_str(type)) diff --git a/mwdb/resources/search.py b/mwdb/resources/search.py index 619cb6bad..6076aa880 100644 --- a/mwdb/resources/search.py +++ b/mwdb/resources/search.py @@ -1,10 +1,9 @@ from flask import g, request from flask_restful import Resource -from luqum.parser import ParseError from werkzeug.exceptions import BadRequest from mwdb.core.rate_limit import rate_limited_resource -from mwdb.core.search import SQLQueryBuilder, SQLQueryBuilderBaseException +from mwdb.core.search import QueryBaseException, build_query from mwdb.model import Object from mwdb.schema.object import ObjectListItemResponseSchema from mwdb.schema.search import SearchRequestSchema @@ -58,15 +57,12 @@ def post(self): query = obj["query"] try: result = ( - SQLQueryBuilder() - .build_query(query) + build_query(query) .filter(g.auth_user.has_access_to_object(Object.id)) .order_by(Object.id.desc()) .limit(10000) ).all() - except SQLQueryBuilderBaseException as e: - raise BadRequest(str(e)) - except ParseError as e: + except QueryBaseException as e: raise BadRequest(str(e)) schema = ObjectListItemResponseSchema(many=True)