diff --git a/pandagg/node/query/_parameter_clause.py b/pandagg/node/query/_parameter_clause.py index b9b7754f..06f40328 100644 --- a/pandagg/node/query/_parameter_clause.py +++ b/pandagg/node/query/_parameter_clause.py @@ -22,7 +22,7 @@ def __init__(self, value): def line_repr(self, depth, **kwargs): return "%s=%s" % (self.KEY, json.dumps(self.body["value"])) - def serialize(self, with_name=True): + def to_dict(self, with_name=True): return {self.KEY: self.body["value"]} @@ -131,6 +131,9 @@ def __init__(self, *args, **kwargs): ) super(ParentParameterClause, self).__init__(_children=children) + def to_dict(self, with_name=True): + return {self.KEY: [n.to_dict() for n in self._children]} + class Filter(ParentParameterClause): KEY = "filter" diff --git a/pandagg/node/query/abstract.py b/pandagg/node/query/abstract.py index 2b52475a..9d437e1c 100644 --- a/pandagg/node/query/abstract.py +++ b/pandagg/node/query/abstract.py @@ -97,7 +97,7 @@ def name(self): def _identifier_prefix(self): return "%s_" % self.KEY - def serialize(self, with_name=True): + def to_dict(self, with_name=True): b = self.body.copy() if with_name and self._named: b["_name"] = self.name @@ -113,9 +113,9 @@ def __str__(self): def __eq__(self, other): if isinstance(other, self.__class__): - return other.serialize() == self.serialize() + return other.to_dict() == self.to_dict() # make sure we still equal to a dict with the same data - return other == self.serialize() + return other == self.to_dict() class LeafQueryClause(QueryClause): diff --git a/pandagg/node/query/compound.py b/pandagg/node/query/compound.py index a3a74139..b2ce4201 100644 --- a/pandagg/node/query/compound.py +++ b/pandagg/node/query/compound.py @@ -12,20 +12,22 @@ class CompoundClause(QueryClause): - """Compound clauses can encapsulate other query clauses. + """Compound clauses can encapsulate other query clauses:: + + { + "" : { + + + } + } Note: the children attribute's only purpose is for initiation with the following syntax: + >>> from pandagg.query import Bool, Term >>> query = Bool( >>> filter=Term(field='some_path', value=3), >>> _name='bool_id', >>> ) - { - "" : { - - - } - } """ DEFAULT_OPERATOR = None @@ -67,6 +69,12 @@ def params(cls, parent_only=False): or not issubclass(cls.get_dsl_class(p, "_param_"), SimpleParameter) } + def to_dict(self, with_name=True): + d = {} + for c in self._children: + d.update(c.to_dict()) + return {self.KEY: d} + class Bool(CompoundClause): DEFAULT_OPERATOR = Must diff --git a/pandagg/node/query/term_level.py b/pandagg/node/query/term_level.py index fa313ce8..93b916e2 100644 --- a/pandagg/node/query/term_level.py +++ b/pandagg/node/query/term_level.py @@ -29,7 +29,7 @@ def __init__(self, values, _name=None): self.values = values super(Ids, self).__init__(_name=_name, values=values) - def serialize(self, with_name=True): + def to_dict(self, with_name=True): b = {"values": self.values} if with_name and self._named: b["_name"] = self.name diff --git a/pandagg/search.py b/pandagg/search.py index 478986ff..052c9e95 100644 --- a/pandagg/search.py +++ b/pandagg/search.py @@ -201,7 +201,8 @@ def __getitem__(self, n): return s def size(self, size): - """Equivalent to:: + """ + Equivalent to:: s = Search().params(size=size) diff --git a/pandagg/tree/aggs.py b/pandagg/tree/aggs.py index a1ef7803..a2627360 100644 --- a/pandagg/tree/aggs.py +++ b/pandagg/tree/aggs.py @@ -24,9 +24,49 @@ @python_2_unicode_compatible class Aggs(Tree): - """Tree combination of aggregation nodes. + r""" + Combination of aggregation clauses. This class provides handful methods to build an aggregation (see + :func:`~pandagg.tree.aggs.Aggs.aggs` and :func:`~pandagg.tree.aggs.Aggs.groupby`), and is used as well + to parse aggregations response in handy formats. - Mapping declaration is optional, but doing so validates aggregation validity. + Mapping declaration is optional, but doing so validates aggregation validity and automatically handles missing + nested clauses. + + All following syntaxes are identical: + + From a dict: + + >>> Aggs({"per_user":{"terms":{"field":"user"}}}) + + Using shortcut declaration: first argument is the aggregation type, other arguments are aggregation body parameters: + + >>> Aggs('terms', name='per_user', field='user') + + Using DSL class: + + >>> from pandagg.aggs import Terms + >>> Aggs(Terms('per_user', field='user')) + + Dict and DSL class syntaxes allow to provide multiple clauses aggregations: + + >>> Aggs({"per_user":{"terms":{"field":"user"}, "aggs": {"avg_age": {"avg": {"field": "age"}}}}}) + + With is similar to: + + >>> from pandagg.aggs import Terms, Avg + >>> Aggs(Terms('per_user', field='user', aggs=Avg('avg_age', field='age'))) + + :Keyword Arguments: + * *mapping* (``dict`` or ``pandagg.tree.mapping.Mapping``) -- + Mapping of requested indice(s). Providing it will validate aggregations validity, and add required nested + clauses if missing. + + * *nested_autocorrect* (``bool``) -- + In case of missing nested clauses in aggregation, if True, automatically add missing nested clauses, else + raise error. + + * remaining kwargs: + Used as body in aggregation """ node_class = AggNode @@ -34,6 +74,7 @@ class Aggs(Tree): def __init__(self, *args, **kwargs): self.mapping = Mapping(kwargs.pop("mapping", None)) + self.nested_autocorrect = kwargs.pop("nested_autocorrect", False) super(Aggs, self).__init__() if args or kwargs: self._fill(*args, **kwargs) @@ -43,15 +84,6 @@ def __nonzero__(self): __bool__ = __nonzero__ - @classmethod - def deserialize(cls, *args, **kwargs): - mapping = kwargs.pop("mapping", None) - if len(args) == 1 and isinstance(args[0], Aggs): - return args[0] - - new = cls(mapping=mapping) - return new._fill(*args, **kwargs) - def _fill(self, *args, **kwargs): if args: node_hierarchy = self.node_class._type_deserializer(*args, **kwargs) @@ -63,7 +95,10 @@ def _fill(self, *args, **kwargs): return self def _clone_init(self, deep=False): - return Aggs(mapping=self.mapping.clone(deep=deep)) + return Aggs( + mapping=self.mapping.clone(deep=deep), + nested_autocorrect=self.nested_autocorrect, + ) def _is_eligible_grouping_node(self, nid): """Return whether node can be used as grouping node.""" @@ -77,7 +112,8 @@ def _is_eligible_grouping_node(self, nid): @property def deepest_linear_bucket_agg(self): - """Return deepest bucket aggregation node (pandagg.nodes.abstract.BucketAggNode) of that aggregation that + """ + Return deepest bucket aggregation node (pandagg.nodes.abstract.BucketAggNode) of that aggregation that neither has siblings, nor has an ancestor with siblings. """ if not self.root or not self._is_eligible_grouping_node(self.root): @@ -101,7 +137,8 @@ def deepest_linear_bucket_agg(self): return last_bucket_agg_name def _validate_aggs_parent_id(self, pid): - """If pid is not None, ensure that pid belongs to tree, and that it refers to a bucket aggregation. + """ + If pid is not None, ensure that pid belongs to tree, and that it refers to a bucket aggregation. Else, if not provided, return deepest bucket aggregation if there is no ambiguity (linear aggregations). KO: non-ambiguous:: @@ -131,7 +168,8 @@ def _validate_aggs_parent_id(self, pid): return leaves[0].identifier def groupby(self, *args, **kwargs): - r"""Arrange passed aggregations in vertical/nested manner, above or below another agg clause. + r""" + Arrange passed aggregations in vertical/nested manner, above or below another agg clause. Given the initial aggregation:: @@ -140,12 +178,12 @@ def groupby(self, *args, **kwargs): If `insert_below` = 'A':: - A──> by──> B + A──> new──> B └──> C If `insert_above` = 'B':: - A──> by──> B + A──> new──> B └──> C `by` argument accepts single occurrence or sequence of following formats: @@ -163,11 +201,10 @@ def groupby(self, *args, **kwargs): └──> C - Accepted declarations for single aggregation: + Accepted all Aggs.__init__ syntaxes - Official DSL like: - - >>> Aggs().groupby('terms', name='per_user_id', field='user_id') + >>> Aggs()\ + >>> .groupby('terms', name='per_user_id', field='user_id') {"terms_on_my_field":{"terms":{"field":"some_field"}}} Passing a dict: @@ -221,19 +258,19 @@ def groupby(self, *args, **kwargs): raise ValueError( "Kwargs not allowed when passing multiple aggregations in args." ) - inserted_aggs = [self.deserialize(arg) for arg in args] + inserted_aggs = [Aggs(arg) for arg in args] # groupby([{}, {}]) elif len(args) == 1 and isinstance(args[0], (list, tuple)): if kwargs: raise ValueError( "Kwargs not allowed when passing multiple aggregations in args." ) - inserted_aggs = [self.deserialize(arg) for arg in args[0]] + inserted_aggs = [Aggs(arg) for arg in args[0]] # groupby({}) # groupby(Terms()) # groupby('terms', name='per_tag', field='tag') else: - inserted_aggs = [self.deserialize(*args, **kwargs)] + inserted_aggs = [Aggs(*args, **kwargs)] if insert_above is not None: parent = new_agg.parent(insert_above, id_only=False) @@ -265,35 +302,52 @@ def groupby(self, *args, **kwargs): return new_agg def aggs(self, *args, **kwargs): - """Arrange passed aggregations in `arg` arguments "horizontally". + r""" + Arrange passed aggregations "horizontally". - Those will be placed under the `insert_below` aggregation clause id if provided, else under the deepest linear - bucket aggregation if there is no ambiguity: + Given the initial aggregation:: + + A──> B + └──> C + + If passing multiple aggregations with `insert_below` = 'A':: + + A──> B + └──> C + └──> new1 + └──> new2 + + Note: those will be placed under the `insert_below` aggregation clause id if provided, else under the deepest + linear bucket aggregation if there is no ambiguity: OK:: - A──> B ─> C ─> arg + A──> B ─> C ─> new KO:: A──> B └──> C - `arg` argument accepts single occurrence or sequence of following formats: + `args` accepts single occurrence or sequence of following formats: * string (for terms agg concise declaration) * regular Elasticsearch dict syntax * AggNode instance (for instance Terms, Filters etc) - :param arg: aggregation(s) clauses to insert "horizontally" - :param insert_below: parent aggregation id under which these aggregations should be placed - :param kwargs: agg body arguments when using "string" syntax for terms aggregation + :Keyword Arguments: + * *insert_below* (``string``) -- + Parent aggregation name under which these aggregations should be placed + + * remaining kwargs: + Used as body in aggregation + :rtype: pandagg.aggs.Aggs """ insert_below = self._validate_aggs_parent_id(kwargs.pop("insert_below", None)) new_agg = self.clone(with_tree=True) - deserialized = self.deserialize(*args, mapping=self.mapping, **kwargs) + deserialized = Aggs(*args, **kwargs) deserialized_root = deserialized.get(deserialized.root) if isinstance(deserialized_root, ShadowRoot): new_agg.merge(deserialized, nid=insert_below) @@ -335,7 +389,7 @@ def applied_nested_path_at_node(self, nid): def _insert_node_below(self, node, parent_id, with_children=True): """If mapping is provided, nested aggregations are automatically applied. """ - if isinstance(node, ShadowRoot): + if isinstance(node, ShadowRoot) and parent_id is not None: for child in node._children or []: super(Aggs, self)._insert_node_below( child, parent_id=parent_id, with_children=with_children @@ -346,7 +400,6 @@ def _insert_node_below(self, node, parent_id, with_children=True): isinstance(node, Nested) or isinstance(node, ReverseNested) or not self.mapping - or parent_id is None or not hasattr(node, "field") ): return super(Aggs, self)._insert_node_below( @@ -357,11 +410,20 @@ def _insert_node_below(self, node, parent_id, with_children=True): # from deepest to highest required_nested_level = self.mapping.nested_at_field(node.field) - current_nested_level = self.applied_nested_path_at_node(parent_id) + + if self.is_empty(): + current_nested_level = None + else: + current_nested_level = self.applied_nested_path_at_node(parent_id) if current_nested_level == required_nested_level: return super(Aggs, self)._insert_node_below( node, parent_id, with_children=with_children ) + if not self.nested_autocorrect: + raise ValueError( + "Invalid %s agg on %s field. Invalid nested: expected %s, current %s." + % (node.KEY, node.field, required_nested_level, current_nested_level) + ) if current_nested_level and ( required_nested_level or "" in current_nested_level ): diff --git a/pandagg/tree/mapping.py b/pandagg/tree/mapping.py index a0ca3ec7..cc6120f3 100644 --- a/pandagg/tree/mapping.py +++ b/pandagg/tree/mapping.py @@ -44,7 +44,7 @@ def __nonzero__(self): __bool__ = __nonzero__ - def serialize(self, from_=None, depth=None): + def to_dict(self, from_=None, depth=None): if self.root is None: return None from_ = self.root if from_ is None else from_ @@ -54,7 +54,7 @@ def serialize(self, from_=None, depth=None): if depth is not None: depth -= 1 for child_node in self.children(node.identifier, id_only=False): - children_queries[child_node.name] = self.serialize( + children_queries[child_node.name] = self.to_dict( from_=child_node.identifier, depth=depth ) serialized_node = node.body diff --git a/pandagg/tree/query.py b/pandagg/tree/query.py index 043803e2..8afc646b 100644 --- a/pandagg/tree/query.py +++ b/pandagg/tree/query.py @@ -41,16 +41,29 @@ @python_2_unicode_compatible class Query(Tree): - """Tree combination of query nodes. + r"""Combination of query clauses. Mapping declaration is optional, but doing so validates query validity and automatically inserts nested clauses when necessary. + + :Keyword Arguments: + * *mapping* (``dict`` or ``pandagg.tree.mapping.Mapping``) -- + Mapping of requested indice(s). Providing it will add validation features, and add required nested + clauses if missing. + + * *nested_autocorrect* (``bool``) -- + In case of missing nested clauses in query, if True, automatically add missing nested clauses, else raise + error. + + * remaining kwargs: + Used as body in query clauses. """ node_class = QueryClause def __init__(self, *args, **kwargs): self.mapping = Mapping(kwargs.pop("mapping", None)) + self.nested_autocorrect = kwargs.pop("nested_autocorrect", False) super(Query, self).__init__() if args or kwargs: self._fill(*args, **kwargs) @@ -61,7 +74,10 @@ def __nonzero__(self): __bool__ = __nonzero__ def _clone_init(self, deep=False): - return Query(mapping=self.mapping.clone(with_tree=True, deep=deep)) + return Query( + mapping=self.mapping.clone(with_tree=True, deep=deep), + nested_autocorrect=self.nested_autocorrect, + ) @classmethod def deserialize(cls, *args, **kwargs): @@ -86,34 +102,60 @@ def _fill(self, *args, **kwargs): def _insert_node_below(self, node, parent_id=None, with_children=True): """Override lighttree.Tree._insert_node_below method to ensure inserted query clause is consistent.""" - if parent_id is None: + if parent_id is not None: + pnode = self.get(parent_id) + if isinstance(pnode, LeafQueryClause): + raise ValueError( + "Cannot add clause under leaf query clause <%s>" % pnode.KEY + ) + if isinstance(pnode, ParentParameterClause): + if isinstance(node, ParameterClause): + raise ValueError( + "Cannot add parameter clause <%s> under another paramter clause <%s>" + % (pnode.KEY, node.KEY) + ) + if isinstance(pnode, CompoundClause): + if ( + not isinstance(node, ParameterClause) + or node.KEY not in pnode.PARAMS_WHITELIST + ): + raise ValueError( + "Expect a parameter clause of type %s under <%s> compound clause, got <%s>" + % (pnode.PARAMS_WHITELIST, pnode.KEY, node.KEY) + ) + + # automatic handling of nested clauses + if isinstance(node, Nested) or not self.mapping or not hasattr(node, "field"): return super(Query, self)._insert_node_below( - node, parent_id=parent_id, with_children=with_children + node=node, parent_id=parent_id, with_children=with_children ) - - pnode = self.get(parent_id) - if isinstance(pnode, LeafQueryClause): + required_nested_level = self.mapping.nested_at_field(node.field) + if self.is_empty(): + current_nested_level = None + else: + current_nested_level = self.applied_nested_path_at_node(parent_id) + if not self.nested_autocorrect: raise ValueError( - "Cannot add clause under leaf query clause <%s>" % pnode.KEY + "Invalid %s query clause on %s field. Invalid nested: expected %s, current %s." + % (node.KEY, node.field, required_nested_level, current_nested_level) ) - if isinstance(pnode, ParentParameterClause): - if isinstance(node, ParameterClause): - raise ValueError( - "Cannot add parameter clause <%s> under another paramter clause <%s>" - % (pnode.KEY, node.KEY) - ) - if isinstance(pnode, CompoundClause): - if ( - not isinstance(node, ParameterClause) - or node.KEY not in pnode.PARAMS_WHITELIST - ): - raise ValueError( - "Expect a parameter clause of type %s under <%s> compound clause, got <%s>" - % (pnode.PARAMS_WHITELIST, pnode.KEY, node.KEY) - ) - super(Query, self)._insert_node_below( - node=node, parent_id=parent_id, with_children=with_children - ) + if current_nested_level == required_nested_level: + return super(Query, self)._insert_node_below( + node=node, parent_id=parent_id, with_children=with_children + ) + # requires nested - apply all required nested fields + for nested_lvl in self.mapping.list_nesteds_at_field(node.field): + if current_nested_level != nested_lvl: + node = Nested(path=nested_lvl, query=node) + super(Query, self)._insert_node_below(node, parent_id, with_children=True) + + def applied_nested_path_at_node(self, nid): + # from current node to root + for id_ in [nid] + self.ancestors(nid): + node = self.get(id_) + if isinstance(node, Nested): + return node.path + return None def to_dict(self, from_=None, with_name=True): """Return None if no query clause. @@ -123,7 +165,7 @@ def to_dict(self, from_=None, with_name=True): from_ = self.root if from_ is None else from_ node = self.get(from_) if isinstance(node, (LeafQueryClause, SimpleParameter)): - return node.serialize(with_name=True) + return node.to_dict(with_name=True) serialized_children = [] should_yield = False for child_node in self.children(node.identifier, id_only=False): diff --git a/pandagg/tree/response.py b/pandagg/tree/response.py index eebfc5db..c613bd34 100644 --- a/pandagg/tree/response.py +++ b/pandagg/tree/response.py @@ -14,12 +14,13 @@ class AggsResponseTree(Tree): - """Tree representation of an ES response. ES response format is determined by the aggregation query. + """Tree representation of an ElasticSearch response. """ def __init__(self, aggs, index): """ - :param aggs: instance of pandagg.agg.Agg from which this ES response originates + :param aggs: instance of pandagg.agg.Agg from which this Elasticsearch response originates. + :param index: indice(s) on which aggregation was computed. """ super(AggsResponseTree, self).__init__() self.__aggs = aggs @@ -30,11 +31,12 @@ def _clone_init(self, deep=False): def parse(self, raw_response): """Build response tree from ElasticSearch aggregation response - :param raw_response: ElasticSearch aggregation response - :return: self Note: if the root aggregation node can generate multiple buckets, a response root is crafted to avoid having multiple roots. + + :param raw_response: ElasticSearch aggregation response + :return: self """ root_node = self.__aggs.get(self.__aggs.root) pid = None @@ -49,6 +51,7 @@ def parse(self, raw_response): def _parse_node_with_children(self, agg_node, raw_response, pid=None): """Recursive method to parse ES raw response. + :param agg_node: current aggregation, pandagg.nodes.AggNode instance :param raw_response: ES response at current level, dict :param pid: parent node identifier @@ -73,6 +76,7 @@ def _parse_node_with_children(self, agg_node, raw_response, pid=None): def bucket_properties(self, bucket, properties=None, end_level=None, depth=None): """Recursive method returning a given bucket's properties in the form of an ordered dictionnary. Travel from current bucket through all ancestors until reaching root. + :param bucket: instance of pandagg.buckets.buckets.Bucket :param properties: OrderedDict accumulator of 'level' -> 'key' :param end_level: optional parameter to specify until which level properties are fetched diff --git a/setup.py b/setup.py index 12f99bea..361759f8 100644 --- a/setup.py +++ b/setup.py @@ -1,7 +1,7 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- -__version__ = "0.0.2" +__version__ = "0.0.3" import os @@ -21,7 +21,7 @@ install_requires = [ "future", - "lighttree==0.0.2", + "lighttree==0.0.3", "elasticsearch>=7.1.0,<8.0.0", ] diff --git a/tests/interactive/test_mapping.py b/tests/interactive/test_mapping.py index 914ec209..9b4c3256 100644 --- a/tests/interactive/test_mapping.py +++ b/tests/interactive/test_mapping.py @@ -92,9 +92,7 @@ def test_imapping_init(self): index=index_name, ) for i, m in enumerate((im1, im2, im3)): - self.assertEqual( - m._tree.serialize(), mapping_dict, "failed at m%d" % (i + 1) - ) + self.assertEqual(m._tree.to_dict(), mapping_dict, "failed at m%d" % (i + 1)) self.assertEqual(m._index, index_name) self.assertIs(m._client, client_mock) diff --git a/tests/node/query/test_compound.py b/tests/node/query/test_compound.py index e6dd7440..713f13b1 100644 --- a/tests/node/query/test_compound.py +++ b/tests/node/query/test_compound.py @@ -59,11 +59,24 @@ def test_bool(self): filter=[{"term": {"some_field": 2}}], ) for b in (b1, b2, b3, b4, b5, b6, b7): + self.assertEqual( + b.to_dict(), + { + "bool": { + "boost": 1.2, + "filter": [{"term": {"some_field": {"value": 2}}}], + "should": [ + {"range": {"other": {"gte": 1}}}, + {"term": {"some": {"value": 3}}}, + ], + } + }, + ) self.assertEqual(len(b._children), 3) self.assertEqual(b.line_repr(depth=None), "bool") boost = next((c for c in b._children if isinstance(c, Boost))) - self.assertEqual(boost.serialize(), {"boost": 1.2}) + self.assertEqual(boost.to_dict(), {"boost": 1.2}) self.assertEqual(boost.line_repr(depth=None), "boost=1.2") filter_ = next((c for c in b._children if isinstance(c, Filter))) @@ -111,14 +124,17 @@ def test_boosting(self): positive = next((c for c in b._children if isinstance(c, Positive))) self.assertEqual(len(positive._children), 1) - self.assertEqual(positive.serialize(), {"positive": {}}) + self.assertEqual( + positive.to_dict(), + {"positive": [{"term": {"text": {"value": "apple"}}}]}, + ) self.assertEqual(positive.line_repr(depth=None), "positive") positive_term = positive._children[0] self.assertIsInstance(positive_term, Term) self.assertEqual(positive_term.field, "text") self.assertEqual(positive_term.body, {"text": {"value": "apple"}}) self.assertEqual( - positive_term.serialize(), {"term": {"text": {"value": "apple"}}} + positive_term.to_dict(), {"term": {"text": {"value": "apple"}}} ) self.assertEqual( positive_term.line_repr(depth=None), 'term, field=text, value="apple"' @@ -126,7 +142,14 @@ def test_boosting(self): negative = next((c for c in b._children if isinstance(c, Negative))) self.assertEqual(len(negative._children), 1) - self.assertEqual(negative.serialize(), {"negative": {}}) + self.assertEqual( + negative.to_dict(), + { + "negative": [ + {"term": {"text": {"value": "pie tart fruit crumble tree"}}} + ] + }, + ) self.assertEqual(negative.line_repr(depth=None), "negative") negative_term = negative._children[0] self.assertIsInstance(negative_term, Term) @@ -135,7 +158,7 @@ def test_boosting(self): negative_term.body, {"text": {"value": "pie tart fruit crumble tree"}} ) self.assertEqual( - negative_term.serialize(), + negative_term.to_dict(), {"term": {"text": {"value": "pie tart fruit crumble tree"}}}, ) self.assertEqual( diff --git a/tests/node/query/test_full_text.py b/tests/node/query/test_full_text.py index 7a9cbd21..3fbe0a72 100644 --- a/tests/node/query/test_full_text.py +++ b/tests/node/query/test_full_text.py @@ -79,7 +79,7 @@ def test_interval_clause(self): ) for q in (q1, q2): self.assertEqual(q.body, body) - self.assertEqual(q.serialize(), expected) + self.assertEqual(q.to_dict(), expected) self.assertEqual( q.line_repr(depth=None), 'intervals, field=some_field, all_of={"intervals": [{"match": {"query": "the"}}, {"any_of": {"intervals": [{"match": {"query": "big"}}, {"match": {"query": "big bad"}}]}}, {"match": {"query": "wolf"}}], "max_gaps": 0, "ordered": true}', @@ -93,7 +93,7 @@ def test_match_clause(self): q2 = Match(message={"query": "this is a test", "operator": "and"}) for q in (q1, q2): self.assertEqual(q.body, body) - self.assertEqual(q.serialize(), expected) + self.assertEqual(q.to_dict(), expected) self.assertEqual( q.line_repr(depth=None), 'match, field=message, operator="and", query="this is a test"', @@ -103,7 +103,7 @@ def test_match_clause(self): q3 = Match(message="this is a test") self.assertEqual(q3.body, {"message": {"query": "this is a test"}}) self.assertEqual( - q3.serialize(), {"match": {"message": {"query": "this is a test"}}} + q3.to_dict(), {"match": {"message": {"query": "this is a test"}}} ) self.assertEqual( q3.line_repr(depth=None), 'match, field=message, query="this is a test"' @@ -117,7 +117,7 @@ def test_match_bool_prefix_clause(self): q2 = MatchBoolPrefix(message={"query": "quick brown f", "analyzer": "keyword"}) for q in (q1, q2): self.assertEqual(q.body, body) - self.assertEqual(q.serialize(), expected) + self.assertEqual(q.to_dict(), expected) self.assertEqual( q.line_repr(depth=None), 'match_bool_prefix, field=message, analyzer="keyword", query="quick brown f"', @@ -127,7 +127,7 @@ def test_match_bool_prefix_clause(self): q3 = MatchBoolPrefix(message="quick brown f") self.assertEqual(q3.body, {"message": {"query": "quick brown f"}}) self.assertEqual( - q3.serialize(), + q3.to_dict(), {"match_bool_prefix": {"message": {"query": "quick brown f"}}}, ) self.assertEqual( @@ -145,7 +145,7 @@ def test_match_phrase_clause(self): q2 = MatchPhrase(message={"query": "this is a test", "analyzer": "my_analyzer"}) for q in (q1, q2): self.assertEqual(q.body, body) - self.assertEqual(q.serialize(), expected) + self.assertEqual(q.to_dict(), expected) self.assertEqual( q.line_repr(depth=None), 'match_phrase, field=message, analyzer="my_analyzer", query="this is a test"', @@ -155,7 +155,7 @@ def test_match_phrase_clause(self): q3 = MatchPhrase(message="this is a test") self.assertEqual(q3.body, {"message": {"query": "this is a test"}}) self.assertEqual( - q3.serialize(), {"match_phrase": {"message": {"query": "this is a test"}}}, + q3.to_dict(), {"match_phrase": {"message": {"query": "this is a test"}}}, ) self.assertEqual( q3.line_repr(depth=None), @@ -174,7 +174,7 @@ def test_match_phrase_prefix_clause(self): ) for q in (q1, q2): self.assertEqual(q.body, body) - self.assertEqual(q.serialize(), expected) + self.assertEqual(q.to_dict(), expected) self.assertEqual( q.line_repr(depth=None), 'match_phrase_prefix, field=message, analyzer="my_analyzer", query="this is a test"', @@ -192,7 +192,7 @@ def test_multi_match_clause(self): fields=["subject", "message"], query="this is a test", type="best_fields" ) self.assertEqual(q.body, body) - self.assertEqual(q.serialize(), expected) + self.assertEqual(q.to_dict(), expected) self.assertEqual( q.line_repr(depth=None), "multi_match, fields=['subject', 'message']" ) @@ -203,7 +203,7 @@ def test_query_string_clause(self): q = QueryString(query="(new york city) OR (big apple)", default_field="content") self.assertEqual(q.body, body) - self.assertEqual(q.serialize(), expected) + self.assertEqual(q.to_dict(), expected) self.assertEqual( q.line_repr(depth=None), 'query_string, default_field="content", query="(new york city) OR (big apple)"', @@ -217,7 +217,7 @@ def test_simple_string_clause(self): query="(new york city) OR (big apple)", default_field="content" ) self.assertEqual(q.body, body) - self.assertEqual(q.serialize(), expected) + self.assertEqual(q.to_dict(), expected) self.assertEqual( q.line_repr(depth=None), 'simple_string, default_field="content", query="(new york city) OR (big apple)"', diff --git a/tests/node/query/test_geo.py b/tests/node/query/test_geo.py index 22e8e907..9fb1578b 100644 --- a/tests/node/query/test_geo.py +++ b/tests/node/query/test_geo.py @@ -28,7 +28,7 @@ def test_geo_bounding_box_clause(self): ) for q in (q1, q2): self.assertEqual(q.body, body) - self.assertEqual(q.serialize(), expected) + self.assertEqual(q.to_dict(), expected) self.assertEqual( q.line_repr(depth=None), 'geo_bounding_box, field=pin.location, bottom_right={"lat": 40.01, "lon": -71.12}, top_left={"lat": 40.73, "lon": -74.1}', @@ -44,7 +44,7 @@ def test_geo_polygone_clause(self): q2 = GeoPolygone(person__location={"points": [[-70, 40], [-80, 30], [-90, 20]]}) for q in (q1, q2): self.assertEqual(q.body, body) - self.assertEqual(q.serialize(), expected) + self.assertEqual(q.to_dict(), expected) self.assertEqual( q.line_repr(depth=None), "geo_polygon, field=person.location, points=[[-70, 40], [-80, 30], [-90, 20]]", @@ -56,7 +56,7 @@ def test_geo_distance_clause(self): q = GeoDistance(pin__location="drm3btev3e86", distance="12km") self.assertEqual(q.body, body) - self.assertEqual(q.serialize(), expected) + self.assertEqual(q.to_dict(), expected) self.assertEqual(q.line_repr(depth=None), "geo_distance, field=pin.location") def test_geo_shape(self): @@ -87,7 +87,7 @@ def test_geo_shape(self): ) for q in (q1, q2): self.assertEqual(q.body, body) - self.assertEqual(q.serialize(), expected) + self.assertEqual(q.to_dict(), expected) self.assertEqual( q.line_repr(depth=None), 'geo_shape, field=location, relation="within", shape={"coordinates": [[13.0, 53.0], [14.0, 52.0]], "type": "envelope"}', diff --git a/tests/node/query/test_joining.py b/tests/node/query/test_joining.py index 4dff0318..4dde6221 100644 --- a/tests/node/query/test_joining.py +++ b/tests/node/query/test_joining.py @@ -22,7 +22,10 @@ def test_nested(self): self.assertEqual(n.path, "some_nested_path") q = next((c for c in n._children if isinstance(c, QueryP))) - self.assertEqual(q.serialize(), {"query": {}}) + self.assertEqual( + q.to_dict(), + {"query": [{"term": {"some_nested_path.id": {"value": 2}}}]}, + ) # ensure term query is present self.assertEqual(len(q._children), 1) self.assertIsInstance(q._children[0], Term, i) diff --git a/tests/node/query/test_parameter_clause.py b/tests/node/query/test_parameter_clause.py index 0c8b7598..ba2040ab 100644 --- a/tests/node/query/test_parameter_clause.py +++ b/tests/node/query/test_parameter_clause.py @@ -30,7 +30,7 @@ def test_filter_parameter(self): self.assertEqual(f.line_repr(depth=None), "filter") term = next((c for c in f._children if isinstance(c, Term))) - self.assertEqual(term.serialize(), {"term": {"some_field": {"value": 1}}}) + self.assertEqual(term.to_dict(), {"term": {"some_field": {"value": 1}}}) range_ = next((c for c in f._children if isinstance(c, Range))) - self.assertEqual(range_.serialize(), {"range": {"other_field": {"gte": 2}}}) + self.assertEqual(range_.to_dict(), {"range": {"other_field": {"gte": 2}}}) diff --git a/tests/node/query/test_specialized.py b/tests/node/query/test_specialized.py index fe0e4a05..a4a773e7 100644 --- a/tests/node/query/test_specialized.py +++ b/tests/node/query/test_specialized.py @@ -23,7 +23,7 @@ def test_distance_feature_clause(self): q = DistanceFeature(field="production_date", pivot="7d", origin="now") self.assertEqual(q.body, body) - self.assertEqual(q.serialize(), expected) + self.assertEqual(q.to_dict(), expected) self.assertEqual( q.line_repr(depth=None), 'distance_feature, field="production_date", origin="now", pivot="7d"', @@ -45,7 +45,7 @@ def test_more_like_this_clause(self): max_query_terms=12, ) self.assertEqual(q.body, body) - self.assertEqual(q.serialize(), expected) + self.assertEqual(q.to_dict(), expected) self.assertEqual( q.line_repr(depth=None), "more_like_this, fields=['title', 'description']" ) @@ -61,7 +61,7 @@ def test_percolate_clause(self): field="query", document={"message": "A new bonsai tree in the office"} ) self.assertEqual(q.body, body) - self.assertEqual(q.serialize(), expected) + self.assertEqual(q.to_dict(), expected) self.assertEqual( q.line_repr(depth=None), 'percolate, document={"message": "A new bonsai tree in the office"}, field="query"', @@ -73,7 +73,7 @@ def test_rank_feature_clause(self): q = RankFeature(field="url_length", boost=0.1) self.assertEqual(q.body, body) - self.assertEqual(q.serialize(), expected) + self.assertEqual(q.to_dict(), expected) self.assertEqual( q.line_repr(depth=None), 'rank_feature, boost=0.1, field="url_length"' ) @@ -96,7 +96,7 @@ def test_script_clause(self): } ) self.assertEqual(q.body, body) - self.assertEqual(q.serialize(), expected) + self.assertEqual(q.to_dict(), expected) self.assertEqual( q.line_repr(depth=None), 'script, script={"lang": "painless", "params": {"param1": 5}, "source": "doc[\'num1\'].value > params.param1"}', @@ -108,7 +108,7 @@ def test_wrapper_clause(self): q = Wrapper(query="eyJ0ZXJtIiA6IHsgInVzZXIiIDogIktpbWNoeSIgfX0=") self.assertEqual(q.body, body) - self.assertEqual(q.serialize(), expected) + self.assertEqual(q.to_dict(), expected) self.assertEqual( q.line_repr(depth=None), 'wrapper, query="eyJ0ZXJtIiA6IHsgInVzZXIiIDogIktpbWNoeSIgfX0="', diff --git a/tests/node/query/test_term_level.py b/tests/node/query/test_term_level.py index f8e9af07..c9e0f153 100644 --- a/tests/node/query/test_term_level.py +++ b/tests/node/query/test_term_level.py @@ -26,7 +26,7 @@ def test_fuzzy_clause(self): q3 = Fuzzy(user={"value": "ki"}) for q in (q1, q2, q3): self.assertEqual(q.body, body) - self.assertEqual(q.serialize(), expected) + self.assertEqual(q.to_dict(), expected) self.assertEqual(q.line_repr(depth=None), 'fuzzy, field=user, value="ki"') def test_exists_clause(self): @@ -35,7 +35,7 @@ def test_exists_clause(self): q = Exists(field="user") self.assertEqual(q.body, body) - self.assertEqual(q.serialize(), expected) + self.assertEqual(q.to_dict(), expected) self.assertEqual(q.line_repr(depth=None), "exists, field=user") def test_ids_clause(self): @@ -44,7 +44,7 @@ def test_ids_clause(self): q = Ids(values=[1, 4, 100]) self.assertEqual(q.body, body) - self.assertEqual(q.serialize(), expected) + self.assertEqual(q.to_dict(), expected) self.assertEqual(q.line_repr(depth=None), "ids, values=[1, 4, 100]") def test_prefix_clause(self): @@ -53,7 +53,7 @@ def test_prefix_clause(self): q = Prefix(field="user", value="ki") self.assertEqual(q.body, body) - self.assertEqual(q.serialize(), expected) + self.assertEqual(q.to_dict(), expected) self.assertEqual(q.line_repr(depth=None), 'prefix, field=user, value="ki"') def test_range_clause(self): @@ -64,7 +64,7 @@ def test_range_clause(self): q2 = Range(age={"gte": 10, "lte": 20, "boost": 2}) for q in (q1, q2): self.assertEqual(q.body, body) - self.assertEqual(q.serialize(), expected) + self.assertEqual(q.to_dict(), expected) self.assertEqual( q.line_repr(depth=None), "range, field=age, boost=2, gte=10, lte=20" ) @@ -98,7 +98,7 @@ def test_regexp_clause(self): ) for q in (q1, q2): self.assertEqual(q.body, body) - self.assertEqual(q.serialize(), expected) + self.assertEqual(q.to_dict(), expected) self.assertEqual(q.line_repr(depth=None), tag) def test_term_clause(self): @@ -109,7 +109,7 @@ def test_term_clause(self): q2 = Term(user={"value": "Kimchy", "boost": 1}) for q in (q1, q2): self.assertEqual(q.body, body) - self.assertEqual(q.serialize(), expected) + self.assertEqual(q.to_dict(), expected) self.assertEqual( q.line_repr(depth=None), 'term, field=user, boost=1, value="Kimchy"' ) @@ -117,7 +117,7 @@ def test_term_clause(self): # other format q3 = Term(user="Kimchy") self.assertEqual(q3.body, {"user": {"value": "Kimchy"}}) - self.assertEqual(q3.serialize(), {"term": {"user": {"value": "Kimchy"}}}) + self.assertEqual(q3.to_dict(), {"term": {"user": {"value": "Kimchy"}}}) self.assertEqual(q3.line_repr(depth=None), 'term, field=user, value="Kimchy"') def test_terms_clause(self): @@ -127,7 +127,7 @@ def test_terms_clause(self): q = Terms(user=["kimchy", "elasticsearch"], boost=1) self.assertEqual(q.body, body) - self.assertEqual(q.serialize(), expected) + self.assertEqual(q.to_dict(), expected) self.assertEqual( q.line_repr(depth=None), 'terms, boost=1, user=["kimchy", "elasticsearch"]', ) @@ -154,7 +154,7 @@ def test_terms_set_clause(self): ) for q in (q1, q2): self.assertEqual(q.body, body) - self.assertEqual(q.serialize(), expected) + self.assertEqual(q.to_dict(), expected) self.assertEqual( q.line_repr(depth=None), 'terms_set, field=programming_languages, minimum_should_match_field="required_matches", terms=["c++", "java", "php"]', @@ -168,7 +168,7 @@ def test_wildcard_clause(self): q2 = Wildcard(user={"value": "ki*y", "boost": 1.0, "rewrite": "constant_score"}) for q in (q1, q2): self.assertEqual(q.body, body) - self.assertEqual(q.serialize(), expected) + self.assertEqual(q.to_dict(), expected) self.assertEqual( q.line_repr(depth=None), 'wildcard, field=user, boost=1.0, rewrite="constant_score", value="ki*y"', diff --git a/tests/tree/test_aggs.py b/tests/tree/test_aggs.py index cb6dd190..85d54fa0 100644 --- a/tests/tree/test_aggs.py +++ b/tests/tree/test_aggs.py @@ -92,41 +92,36 @@ def test_deserialize_nodes_with_subaggs(self): self.assertEqual(a.to_dict(), expected) def test_add_node_with_mapping(self): - with_mapping = Aggs(mapping=MAPPING) + with_mapping = Aggs(mapping=MAPPING, nested_autocorrect=True) self.assertEqual(len(with_mapping.list()), 0) # add regular node - with_mapping.insert_node(Terms("workflow", field="workflow")) - self.assertEqual(len(with_mapping.list()), 1) - - # try to add second root fill fail - with self.assertRaises(MultipleRootError): - with_mapping.insert_node( - Terms("classification_type", field="classification_type") - ) + with_mapping = with_mapping.aggs(Terms("workflow", field="workflow")) + self.assertEqual( + with_mapping.to_dict(), {"workflow": {"terms": {"field": "workflow"}}} + ) # try to add field aggregation on non-existing field will fail with self.assertRaises(AbsentMappingFieldError): - with_mapping.insert_node( - node=Terms("imaginary_agg", field="imaginary_field"), - parent_id="workflow", + with_mapping.aggs( + Terms("imaginary_agg", field="imaginary_field"), + insert_below="workflow", ) self.assertEqual(len(with_mapping.list()), 1) # try to add aggregation on a non-compatible field will fail with self.assertRaises(InvalidOperationMappingFieldError): - with_mapping.insert_node( - node=Avg("average_of_string", field="classification_type"), - parent_id="workflow", + with_mapping.aggs( + Avg("average_of_string", field="classification_type"), + insert_below="workflow", ) self.assertEqual(len(with_mapping.list()), 1) # add field aggregation on field passing through nested will automatically add nested - with_mapping.insert_node( - node=Avg("local_f1_score", field="local_metrics.performance.test.f1_score"), - parent_id="workflow", + with_mapping = with_mapping.aggs( + Avg("local_f1_score", field="local_metrics.performance.test.f1_score"), + insert_below="workflow", ) - self.assertEqual(len(with_mapping.list()), 3) self.assertEqual( with_mapping.to_dict(), { @@ -153,11 +148,9 @@ def test_add_node_with_mapping(self): self.assertEqual(nested_node.path, "local_metrics") # add other agg requiring nested will reuse nested agg as parent - with_mapping.insert_node( - node=Avg( - "local_precision", field="local_metrics.performance.test.precision" - ), - parent_id="workflow", + with_mapping = with_mapping.aggs( + Avg("local_precision", field="local_metrics.performance.test.precision"), + insert_below="workflow", ) self.assertEqual( with_mapping.to_dict(), @@ -188,9 +181,9 @@ def test_add_node_with_mapping(self): # add under a nested parent a field aggregation that requires to be located under root will automatically # add reverse-nested - with_mapping.insert_node( - node=Terms("language_terms", field="language"), - parent_id="nested_below_workflow", + with_mapping = with_mapping.aggs( + Terms("language_terms", field="language"), + insert_below="nested_below_workflow", ) self.assertEqual(len(with_mapping.list()), 6) self.assertEqual( @@ -316,6 +309,7 @@ def test_paste_tree_with_mapping(self): } }, mapping=MAPPING, + nested_autocorrect=True, ) self.assertEqual(to_id_set(initial_agg_2.list()), {"week"}) pasted_agg_2 = Aggs( @@ -460,6 +454,7 @@ def test_interpret_agg_string(self): some_agg = Aggs( {"term_workflow": {"terms": {"field": "workflow", "size": 5}}}, mapping=MAPPING, + nested_autocorrect=True, ) some_agg = some_agg.aggs( "local_metrics.field_class.name", insert_below="term_workflow" @@ -494,6 +489,7 @@ def test_interpret_node(self): some_agg = Aggs( {"term_workflow": {"terms": {"field": "workflow", "size": 5}}}, mapping=MAPPING, + nested_autocorrect=True, ) node = Avg(name="min_local_f1", field="local_metrics.performance.test.f1_score") some_agg = some_agg.aggs(node, insert_below="term_workflow") @@ -598,7 +594,7 @@ def test_init_from_node_hierarchy(self): ) ], ) - agg = Aggs(node_hierarchy, mapping=MAPPING) + agg = Aggs(node_hierarchy, mapping=MAPPING, nested_autocorrect=True) self.assertEqual( agg.to_dict(), { @@ -919,7 +915,7 @@ def test_applied_nested_path_at_node(self): ) ], ) - agg = Aggs(node_hierarchy, mapping=MAPPING) + agg = Aggs(node_hierarchy, mapping=MAPPING, nested_autocorrect=True) self.assertEqual(agg.applied_nested_path_at_node("week"), None) for nid in ( @@ -955,7 +951,7 @@ def test_deepest_linear_agg(self): ) ], ) - agg = Aggs(node_hierarchy, mapping=MAPPING) + agg = Aggs(node_hierarchy, mapping=MAPPING, nested_autocorrect=True) self.assertEqual( agg.deepest_linear_bucket_agg, "local_metrics.field_class.name" ) @@ -981,5 +977,5 @@ def test_deepest_linear_agg(self): ), ], ) - agg2 = Aggs(node_hierarchy_2, mapping=MAPPING) + agg2 = Aggs(node_hierarchy_2, mapping=MAPPING, nested_autocorrect=True) self.assertEqual(agg2.deepest_linear_bucket_agg, "week") diff --git a/tests/tree/test_mapping.py b/tests/tree/test_mapping.py index f4a8f244..a0d54678 100644 --- a/tests/tree/test_mapping.py +++ b/tests/tree/test_mapping.py @@ -87,7 +87,7 @@ def test_deserialization(self): """ for i, m in enumerate((m1, m2,)): self.assertEqual(m.__repr__(), expected_repr, "failed at m%d" % (i + 1)) - self.assertEqual(m.serialize(), mapping_dict, "failed at m%d" % (i + 1)) + self.assertEqual(m.to_dict(), mapping_dict, "failed at m%d" % (i + 1)) def test_parse_tree_from_dict(self): mapping_tree = Mapping(MAPPING) diff --git a/tests/tree/test_query.py b/tests/tree/test_query.py index fa524e57..31a9a811 100644 --- a/tests/tree/test_query.py +++ b/tests/tree/test_query.py @@ -1086,6 +1086,28 @@ def test_multiple_must_below_nested_query(self): ), ) + def test_autonested(self): + q = Query( + mapping={ + "properties": { + "actors": { + "type": "nested", + "properties": {"id": {"type": "keyword"}}, + } + } + }, + nested_autocorrect=True, + ) + self.assertEqual( + q.query("term", actors__id=2).to_dict(), + { + "nested": { + "path": "actors", + "query": {"term": {"actors.id": {"value": 2}}}, + } + }, + ) + def test_query_unnamed_inserts(self): q = (