From 23735368e780aff7d853687dd76c266c81d196a7 Mon Sep 17 00:00:00 2001 From: Leonard Binet Date: Sun, 10 May 2020 22:29:45 +0200 Subject: [PATCH] query autonested --- pandagg/tree/aggs.py | 45 +++++++++++-------- pandagg/tree/query.py | 94 +++++++++++++++++++++++++++++----------- setup.py | 2 +- tests/tree/test_aggs.py | 58 ++++++++++++------------- tests/tree/test_query.py | 22 ++++++++++ 5 files changed, 144 insertions(+), 77 deletions(-) diff --git a/pandagg/tree/aggs.py b/pandagg/tree/aggs.py index fcb3b3a2..a2627360 100644 --- a/pandagg/tree/aggs.py +++ b/pandagg/tree/aggs.py @@ -24,7 +24,7 @@ @python_2_unicode_compatible class Aggs(Tree): - """ + r""" Combination of aggregation clauses. This class provides handful methods to build an aggregation (see :func:`~pandagg.tree.aggs.Aggs.aggs` and :func:`~pandagg.tree.aggs.Aggs.groupby`), and is used as well to parse aggregations response in handy formats. @@ -61,6 +61,10 @@ class Aggs(Tree): Mapping of requested indice(s). Providing it will validate aggregations validity, and add required nested clauses if missing. + * *nested_autocorrect* (``bool``) -- + In case of missing nested clauses in aggregation, if True, automatically add missing nested clauses, else + raise error. + * remaining kwargs: Used as body in aggregation """ @@ -70,6 +74,7 @@ class Aggs(Tree): def __init__(self, *args, **kwargs): self.mapping = Mapping(kwargs.pop("mapping", None)) + self.nested_autocorrect = kwargs.pop("nested_autocorrect", False) super(Aggs, self).__init__() if args or kwargs: self._fill(*args, **kwargs) @@ -79,15 +84,6 @@ def __nonzero__(self): __bool__ = __nonzero__ - @classmethod - def deserialize(cls, *args, **kwargs): - mapping = kwargs.pop("mapping", None) - if len(args) == 1 and isinstance(args[0], Aggs): - return args[0] - - new = cls(mapping=mapping) - return new._fill(*args, **kwargs) - def _fill(self, *args, **kwargs): if args: node_hierarchy = self.node_class._type_deserializer(*args, **kwargs) @@ -99,7 +95,10 @@ def _fill(self, *args, **kwargs): return self def _clone_init(self, deep=False): - return Aggs(mapping=self.mapping.clone(deep=deep)) + return Aggs( + mapping=self.mapping.clone(deep=deep), + nested_autocorrect=self.nested_autocorrect, + ) def _is_eligible_grouping_node(self, nid): """Return whether node can be used as grouping node.""" @@ -259,19 +258,19 @@ def groupby(self, *args, **kwargs): raise ValueError( "Kwargs not allowed when passing multiple aggregations in args." ) - inserted_aggs = [self.deserialize(arg) for arg in args] + inserted_aggs = [Aggs(arg) for arg in args] # groupby([{}, {}]) elif len(args) == 1 and isinstance(args[0], (list, tuple)): if kwargs: raise ValueError( "Kwargs not allowed when passing multiple aggregations in args." ) - inserted_aggs = [self.deserialize(arg) for arg in args[0]] + inserted_aggs = [Aggs(arg) for arg in args[0]] # groupby({}) # groupby(Terms()) # groupby('terms', name='per_tag', field='tag') else: - inserted_aggs = [self.deserialize(*args, **kwargs)] + inserted_aggs = [Aggs(*args, **kwargs)] if insert_above is not None: parent = new_agg.parent(insert_above, id_only=False) @@ -303,7 +302,7 @@ def groupby(self, *args, **kwargs): return new_agg def aggs(self, *args, **kwargs): - """ + r""" Arrange passed aggregations "horizontally". Given the initial aggregation:: @@ -348,7 +347,7 @@ def aggs(self, *args, **kwargs): """ insert_below = self._validate_aggs_parent_id(kwargs.pop("insert_below", None)) new_agg = self.clone(with_tree=True) - deserialized = self.deserialize(*args, mapping=self.mapping, **kwargs) + deserialized = Aggs(*args, **kwargs) deserialized_root = deserialized.get(deserialized.root) if isinstance(deserialized_root, ShadowRoot): new_agg.merge(deserialized, nid=insert_below) @@ -390,7 +389,7 @@ def applied_nested_path_at_node(self, nid): def _insert_node_below(self, node, parent_id, with_children=True): """If mapping is provided, nested aggregations are automatically applied. """ - if isinstance(node, ShadowRoot): + if isinstance(node, ShadowRoot) and parent_id is not None: for child in node._children or []: super(Aggs, self)._insert_node_below( child, parent_id=parent_id, with_children=with_children @@ -401,7 +400,6 @@ def _insert_node_below(self, node, parent_id, with_children=True): isinstance(node, Nested) or isinstance(node, ReverseNested) or not self.mapping - or parent_id is None or not hasattr(node, "field") ): return super(Aggs, self)._insert_node_below( @@ -412,11 +410,20 @@ def _insert_node_below(self, node, parent_id, with_children=True): # from deepest to highest required_nested_level = self.mapping.nested_at_field(node.field) - current_nested_level = self.applied_nested_path_at_node(parent_id) + + if self.is_empty(): + current_nested_level = None + else: + current_nested_level = self.applied_nested_path_at_node(parent_id) if current_nested_level == required_nested_level: return super(Aggs, self)._insert_node_below( node, parent_id, with_children=with_children ) + if not self.nested_autocorrect: + raise ValueError( + "Invalid %s agg on %s field. Invalid nested: expected %s, current %s." + % (node.KEY, node.field, required_nested_level, current_nested_level) + ) if current_nested_level and ( required_nested_level or "" in current_nested_level ): diff --git a/pandagg/tree/query.py b/pandagg/tree/query.py index f16582d1..8afc646b 100644 --- a/pandagg/tree/query.py +++ b/pandagg/tree/query.py @@ -41,16 +41,29 @@ @python_2_unicode_compatible class Query(Tree): - """Tree combination of query nodes. + r"""Combination of query clauses. Mapping declaration is optional, but doing so validates query validity and automatically inserts nested clauses when necessary. + + :Keyword Arguments: + * *mapping* (``dict`` or ``pandagg.tree.mapping.Mapping``) -- + Mapping of requested indice(s). Providing it will add validation features, and add required nested + clauses if missing. + + * *nested_autocorrect* (``bool``) -- + In case of missing nested clauses in query, if True, automatically add missing nested clauses, else raise + error. + + * remaining kwargs: + Used as body in query clauses. """ node_class = QueryClause def __init__(self, *args, **kwargs): self.mapping = Mapping(kwargs.pop("mapping", None)) + self.nested_autocorrect = kwargs.pop("nested_autocorrect", False) super(Query, self).__init__() if args or kwargs: self._fill(*args, **kwargs) @@ -61,7 +74,10 @@ def __nonzero__(self): __bool__ = __nonzero__ def _clone_init(self, deep=False): - return Query(mapping=self.mapping.clone(with_tree=True, deep=deep)) + return Query( + mapping=self.mapping.clone(with_tree=True, deep=deep), + nested_autocorrect=self.nested_autocorrect, + ) @classmethod def deserialize(cls, *args, **kwargs): @@ -86,34 +102,60 @@ def _fill(self, *args, **kwargs): def _insert_node_below(self, node, parent_id=None, with_children=True): """Override lighttree.Tree._insert_node_below method to ensure inserted query clause is consistent.""" - if parent_id is None: + if parent_id is not None: + pnode = self.get(parent_id) + if isinstance(pnode, LeafQueryClause): + raise ValueError( + "Cannot add clause under leaf query clause <%s>" % pnode.KEY + ) + if isinstance(pnode, ParentParameterClause): + if isinstance(node, ParameterClause): + raise ValueError( + "Cannot add parameter clause <%s> under another paramter clause <%s>" + % (pnode.KEY, node.KEY) + ) + if isinstance(pnode, CompoundClause): + if ( + not isinstance(node, ParameterClause) + or node.KEY not in pnode.PARAMS_WHITELIST + ): + raise ValueError( + "Expect a parameter clause of type %s under <%s> compound clause, got <%s>" + % (pnode.PARAMS_WHITELIST, pnode.KEY, node.KEY) + ) + + # automatic handling of nested clauses + if isinstance(node, Nested) or not self.mapping or not hasattr(node, "field"): return super(Query, self)._insert_node_below( - node, parent_id=parent_id, with_children=with_children + node=node, parent_id=parent_id, with_children=with_children ) - - pnode = self.get(parent_id) - if isinstance(pnode, LeafQueryClause): + required_nested_level = self.mapping.nested_at_field(node.field) + if self.is_empty(): + current_nested_level = None + else: + current_nested_level = self.applied_nested_path_at_node(parent_id) + if not self.nested_autocorrect: raise ValueError( - "Cannot add clause under leaf query clause <%s>" % pnode.KEY + "Invalid %s query clause on %s field. Invalid nested: expected %s, current %s." + % (node.KEY, node.field, required_nested_level, current_nested_level) ) - if isinstance(pnode, ParentParameterClause): - if isinstance(node, ParameterClause): - raise ValueError( - "Cannot add parameter clause <%s> under another paramter clause <%s>" - % (pnode.KEY, node.KEY) - ) - if isinstance(pnode, CompoundClause): - if ( - not isinstance(node, ParameterClause) - or node.KEY not in pnode.PARAMS_WHITELIST - ): - raise ValueError( - "Expect a parameter clause of type %s under <%s> compound clause, got <%s>" - % (pnode.PARAMS_WHITELIST, pnode.KEY, node.KEY) - ) - super(Query, self)._insert_node_below( - node=node, parent_id=parent_id, with_children=with_children - ) + if current_nested_level == required_nested_level: + return super(Query, self)._insert_node_below( + node=node, parent_id=parent_id, with_children=with_children + ) + # requires nested - apply all required nested fields + for nested_lvl in self.mapping.list_nesteds_at_field(node.field): + if current_nested_level != nested_lvl: + node = Nested(path=nested_lvl, query=node) + super(Query, self)._insert_node_below(node, parent_id, with_children=True) + + def applied_nested_path_at_node(self, nid): + # from current node to root + for id_ in [nid] + self.ancestors(nid): + node = self.get(id_) + if isinstance(node, Nested): + return node.path + return None def to_dict(self, from_=None, with_name=True): """Return None if no query clause. diff --git a/setup.py b/setup.py index 12f99bea..6d4f44cd 100644 --- a/setup.py +++ b/setup.py @@ -21,7 +21,7 @@ install_requires = [ "future", - "lighttree==0.0.2", + "lighttree==0.0.3", "elasticsearch>=7.1.0,<8.0.0", ] diff --git a/tests/tree/test_aggs.py b/tests/tree/test_aggs.py index cb6dd190..85d54fa0 100644 --- a/tests/tree/test_aggs.py +++ b/tests/tree/test_aggs.py @@ -92,41 +92,36 @@ def test_deserialize_nodes_with_subaggs(self): self.assertEqual(a.to_dict(), expected) def test_add_node_with_mapping(self): - with_mapping = Aggs(mapping=MAPPING) + with_mapping = Aggs(mapping=MAPPING, nested_autocorrect=True) self.assertEqual(len(with_mapping.list()), 0) # add regular node - with_mapping.insert_node(Terms("workflow", field="workflow")) - self.assertEqual(len(with_mapping.list()), 1) - - # try to add second root fill fail - with self.assertRaises(MultipleRootError): - with_mapping.insert_node( - Terms("classification_type", field="classification_type") - ) + with_mapping = with_mapping.aggs(Terms("workflow", field="workflow")) + self.assertEqual( + with_mapping.to_dict(), {"workflow": {"terms": {"field": "workflow"}}} + ) # try to add field aggregation on non-existing field will fail with self.assertRaises(AbsentMappingFieldError): - with_mapping.insert_node( - node=Terms("imaginary_agg", field="imaginary_field"), - parent_id="workflow", + with_mapping.aggs( + Terms("imaginary_agg", field="imaginary_field"), + insert_below="workflow", ) self.assertEqual(len(with_mapping.list()), 1) # try to add aggregation on a non-compatible field will fail with self.assertRaises(InvalidOperationMappingFieldError): - with_mapping.insert_node( - node=Avg("average_of_string", field="classification_type"), - parent_id="workflow", + with_mapping.aggs( + Avg("average_of_string", field="classification_type"), + insert_below="workflow", ) self.assertEqual(len(with_mapping.list()), 1) # add field aggregation on field passing through nested will automatically add nested - with_mapping.insert_node( - node=Avg("local_f1_score", field="local_metrics.performance.test.f1_score"), - parent_id="workflow", + with_mapping = with_mapping.aggs( + Avg("local_f1_score", field="local_metrics.performance.test.f1_score"), + insert_below="workflow", ) - self.assertEqual(len(with_mapping.list()), 3) self.assertEqual( with_mapping.to_dict(), { @@ -153,11 +148,9 @@ def test_add_node_with_mapping(self): self.assertEqual(nested_node.path, "local_metrics") # add other agg requiring nested will reuse nested agg as parent - with_mapping.insert_node( - node=Avg( - "local_precision", field="local_metrics.performance.test.precision" - ), - parent_id="workflow", + with_mapping = with_mapping.aggs( + Avg("local_precision", field="local_metrics.performance.test.precision"), + insert_below="workflow", ) self.assertEqual( with_mapping.to_dict(), @@ -188,9 +181,9 @@ def test_add_node_with_mapping(self): # add under a nested parent a field aggregation that requires to be located under root will automatically # add reverse-nested - with_mapping.insert_node( - node=Terms("language_terms", field="language"), - parent_id="nested_below_workflow", + with_mapping = with_mapping.aggs( + Terms("language_terms", field="language"), + insert_below="nested_below_workflow", ) self.assertEqual(len(with_mapping.list()), 6) self.assertEqual( @@ -316,6 +309,7 @@ def test_paste_tree_with_mapping(self): } }, mapping=MAPPING, + nested_autocorrect=True, ) self.assertEqual(to_id_set(initial_agg_2.list()), {"week"}) pasted_agg_2 = Aggs( @@ -460,6 +454,7 @@ def test_interpret_agg_string(self): some_agg = Aggs( {"term_workflow": {"terms": {"field": "workflow", "size": 5}}}, mapping=MAPPING, + nested_autocorrect=True, ) some_agg = some_agg.aggs( "local_metrics.field_class.name", insert_below="term_workflow" @@ -494,6 +489,7 @@ def test_interpret_node(self): some_agg = Aggs( {"term_workflow": {"terms": {"field": "workflow", "size": 5}}}, mapping=MAPPING, + nested_autocorrect=True, ) node = Avg(name="min_local_f1", field="local_metrics.performance.test.f1_score") some_agg = some_agg.aggs(node, insert_below="term_workflow") @@ -598,7 +594,7 @@ def test_init_from_node_hierarchy(self): ) ], ) - agg = Aggs(node_hierarchy, mapping=MAPPING) + agg = Aggs(node_hierarchy, mapping=MAPPING, nested_autocorrect=True) self.assertEqual( agg.to_dict(), { @@ -919,7 +915,7 @@ def test_applied_nested_path_at_node(self): ) ], ) - agg = Aggs(node_hierarchy, mapping=MAPPING) + agg = Aggs(node_hierarchy, mapping=MAPPING, nested_autocorrect=True) self.assertEqual(agg.applied_nested_path_at_node("week"), None) for nid in ( @@ -955,7 +951,7 @@ def test_deepest_linear_agg(self): ) ], ) - agg = Aggs(node_hierarchy, mapping=MAPPING) + agg = Aggs(node_hierarchy, mapping=MAPPING, nested_autocorrect=True) self.assertEqual( agg.deepest_linear_bucket_agg, "local_metrics.field_class.name" ) @@ -981,5 +977,5 @@ def test_deepest_linear_agg(self): ), ], ) - agg2 = Aggs(node_hierarchy_2, mapping=MAPPING) + agg2 = Aggs(node_hierarchy_2, mapping=MAPPING, nested_autocorrect=True) self.assertEqual(agg2.deepest_linear_bucket_agg, "week") diff --git a/tests/tree/test_query.py b/tests/tree/test_query.py index fa524e57..31a9a811 100644 --- a/tests/tree/test_query.py +++ b/tests/tree/test_query.py @@ -1086,6 +1086,28 @@ def test_multiple_must_below_nested_query(self): ), ) + def test_autonested(self): + q = Query( + mapping={ + "properties": { + "actors": { + "type": "nested", + "properties": {"id": {"type": "keyword"}}, + } + } + }, + nested_autocorrect=True, + ) + self.assertEqual( + q.query("term", actors__id=2).to_dict(), + { + "nested": { + "path": "actors", + "query": {"term": {"actors.id": {"value": 2}}}, + } + }, + ) + def test_query_unnamed_inserts(self): q = (