Skip to content

Commit

Permalink
query autonested
Browse files Browse the repository at this point in the history
  • Loading branch information
leonardbinet committed May 10, 2020
1 parent 2b3a6d1 commit 2373536
Show file tree
Hide file tree
Showing 5 changed files with 144 additions and 77 deletions.
45 changes: 26 additions & 19 deletions pandagg/tree/aggs.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@

@python_2_unicode_compatible
class Aggs(Tree):
"""
r"""
Combination of aggregation clauses. This class provides handful methods to build an aggregation (see
:func:`~pandagg.tree.aggs.Aggs.aggs` and :func:`~pandagg.tree.aggs.Aggs.groupby`), and is used as well
to parse aggregations response in handy formats.
Expand Down Expand Up @@ -61,6 +61,10 @@ class Aggs(Tree):
Mapping of requested indice(s). Providing it will validate aggregations validity, and add required nested
clauses if missing.
* *nested_autocorrect* (``bool``) --
In case of missing nested clauses in aggregation, if True, automatically add missing nested clauses, else
raise error.
* remaining kwargs:
Used as body in aggregation
"""
Expand All @@ -70,6 +74,7 @@ class Aggs(Tree):

def __init__(self, *args, **kwargs):
self.mapping = Mapping(kwargs.pop("mapping", None))
self.nested_autocorrect = kwargs.pop("nested_autocorrect", False)
super(Aggs, self).__init__()
if args or kwargs:
self._fill(*args, **kwargs)
Expand All @@ -79,15 +84,6 @@ def __nonzero__(self):

__bool__ = __nonzero__

@classmethod
def deserialize(cls, *args, **kwargs):
mapping = kwargs.pop("mapping", None)
if len(args) == 1 and isinstance(args[0], Aggs):
return args[0]

new = cls(mapping=mapping)
return new._fill(*args, **kwargs)

def _fill(self, *args, **kwargs):
if args:
node_hierarchy = self.node_class._type_deserializer(*args, **kwargs)
Expand All @@ -99,7 +95,10 @@ def _fill(self, *args, **kwargs):
return self

def _clone_init(self, deep=False):
return Aggs(mapping=self.mapping.clone(deep=deep))
return Aggs(
mapping=self.mapping.clone(deep=deep),
nested_autocorrect=self.nested_autocorrect,
)

def _is_eligible_grouping_node(self, nid):
"""Return whether node can be used as grouping node."""
Expand Down Expand Up @@ -259,19 +258,19 @@ def groupby(self, *args, **kwargs):
raise ValueError(
"Kwargs not allowed when passing multiple aggregations in args."
)
inserted_aggs = [self.deserialize(arg) for arg in args]
inserted_aggs = [Aggs(arg) for arg in args]
# groupby([{}, {}])
elif len(args) == 1 and isinstance(args[0], (list, tuple)):
if kwargs:
raise ValueError(
"Kwargs not allowed when passing multiple aggregations in args."
)
inserted_aggs = [self.deserialize(arg) for arg in args[0]]
inserted_aggs = [Aggs(arg) for arg in args[0]]
# groupby({})
# groupby(Terms())
# groupby('terms', name='per_tag', field='tag')
else:
inserted_aggs = [self.deserialize(*args, **kwargs)]
inserted_aggs = [Aggs(*args, **kwargs)]

if insert_above is not None:
parent = new_agg.parent(insert_above, id_only=False)
Expand Down Expand Up @@ -303,7 +302,7 @@ def groupby(self, *args, **kwargs):
return new_agg

def aggs(self, *args, **kwargs):
"""
r"""
Arrange passed aggregations "horizontally".
Given the initial aggregation::
Expand Down Expand Up @@ -348,7 +347,7 @@ def aggs(self, *args, **kwargs):
"""
insert_below = self._validate_aggs_parent_id(kwargs.pop("insert_below", None))
new_agg = self.clone(with_tree=True)
deserialized = self.deserialize(*args, mapping=self.mapping, **kwargs)
deserialized = Aggs(*args, **kwargs)
deserialized_root = deserialized.get(deserialized.root)
if isinstance(deserialized_root, ShadowRoot):
new_agg.merge(deserialized, nid=insert_below)
Expand Down Expand Up @@ -390,7 +389,7 @@ def applied_nested_path_at_node(self, nid):
def _insert_node_below(self, node, parent_id, with_children=True):
"""If mapping is provided, nested aggregations are automatically applied.
"""
if isinstance(node, ShadowRoot):
if isinstance(node, ShadowRoot) and parent_id is not None:
for child in node._children or []:
super(Aggs, self)._insert_node_below(
child, parent_id=parent_id, with_children=with_children
Expand All @@ -401,7 +400,6 @@ def _insert_node_below(self, node, parent_id, with_children=True):
isinstance(node, Nested)
or isinstance(node, ReverseNested)
or not self.mapping
or parent_id is None
or not hasattr(node, "field")
):
return super(Aggs, self)._insert_node_below(
Expand All @@ -412,11 +410,20 @@ def _insert_node_below(self, node, parent_id, with_children=True):

# from deepest to highest
required_nested_level = self.mapping.nested_at_field(node.field)
current_nested_level = self.applied_nested_path_at_node(parent_id)

if self.is_empty():
current_nested_level = None
else:
current_nested_level = self.applied_nested_path_at_node(parent_id)
if current_nested_level == required_nested_level:
return super(Aggs, self)._insert_node_below(
node, parent_id, with_children=with_children
)
if not self.nested_autocorrect:
raise ValueError(
"Invalid %s agg on %s field. Invalid nested: expected %s, current %s."
% (node.KEY, node.field, required_nested_level, current_nested_level)
)
if current_nested_level and (
required_nested_level or "" in current_nested_level
):
Expand Down
94 changes: 68 additions & 26 deletions pandagg/tree/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,16 +41,29 @@

@python_2_unicode_compatible
class Query(Tree):
"""Tree combination of query nodes.
r"""Combination of query clauses.
Mapping declaration is optional, but doing so validates query validity and automatically inserts nested clauses
when necessary.
:Keyword Arguments:
* *mapping* (``dict`` or ``pandagg.tree.mapping.Mapping``) --
Mapping of requested indice(s). Providing it will add validation features, and add required nested
clauses if missing.
* *nested_autocorrect* (``bool``) --
In case of missing nested clauses in query, if True, automatically add missing nested clauses, else raise
error.
* remaining kwargs:
Used as body in query clauses.
"""

node_class = QueryClause

def __init__(self, *args, **kwargs):
self.mapping = Mapping(kwargs.pop("mapping", None))
self.nested_autocorrect = kwargs.pop("nested_autocorrect", False)
super(Query, self).__init__()
if args or kwargs:
self._fill(*args, **kwargs)
Expand All @@ -61,7 +74,10 @@ def __nonzero__(self):
__bool__ = __nonzero__

def _clone_init(self, deep=False):
return Query(mapping=self.mapping.clone(with_tree=True, deep=deep))
return Query(
mapping=self.mapping.clone(with_tree=True, deep=deep),
nested_autocorrect=self.nested_autocorrect,
)

@classmethod
def deserialize(cls, *args, **kwargs):
Expand All @@ -86,34 +102,60 @@ def _fill(self, *args, **kwargs):

def _insert_node_below(self, node, parent_id=None, with_children=True):
"""Override lighttree.Tree._insert_node_below method to ensure inserted query clause is consistent."""
if parent_id is None:
if parent_id is not None:
pnode = self.get(parent_id)
if isinstance(pnode, LeafQueryClause):
raise ValueError(
"Cannot add clause under leaf query clause <%s>" % pnode.KEY
)
if isinstance(pnode, ParentParameterClause):
if isinstance(node, ParameterClause):
raise ValueError(
"Cannot add parameter clause <%s> under another paramter clause <%s>"
% (pnode.KEY, node.KEY)
)
if isinstance(pnode, CompoundClause):
if (
not isinstance(node, ParameterClause)
or node.KEY not in pnode.PARAMS_WHITELIST
):
raise ValueError(
"Expect a parameter clause of type %s under <%s> compound clause, got <%s>"
% (pnode.PARAMS_WHITELIST, pnode.KEY, node.KEY)
)

# automatic handling of nested clauses
if isinstance(node, Nested) or not self.mapping or not hasattr(node, "field"):
return super(Query, self)._insert_node_below(
node, parent_id=parent_id, with_children=with_children
node=node, parent_id=parent_id, with_children=with_children
)

pnode = self.get(parent_id)
if isinstance(pnode, LeafQueryClause):
required_nested_level = self.mapping.nested_at_field(node.field)
if self.is_empty():
current_nested_level = None
else:
current_nested_level = self.applied_nested_path_at_node(parent_id)
if not self.nested_autocorrect:
raise ValueError(
"Cannot add clause under leaf query clause <%s>" % pnode.KEY
"Invalid %s query clause on %s field. Invalid nested: expected %s, current %s."
% (node.KEY, node.field, required_nested_level, current_nested_level)
)
if isinstance(pnode, ParentParameterClause):
if isinstance(node, ParameterClause):
raise ValueError(
"Cannot add parameter clause <%s> under another paramter clause <%s>"
% (pnode.KEY, node.KEY)
)
if isinstance(pnode, CompoundClause):
if (
not isinstance(node, ParameterClause)
or node.KEY not in pnode.PARAMS_WHITELIST
):
raise ValueError(
"Expect a parameter clause of type %s under <%s> compound clause, got <%s>"
% (pnode.PARAMS_WHITELIST, pnode.KEY, node.KEY)
)
super(Query, self)._insert_node_below(
node=node, parent_id=parent_id, with_children=with_children
)
if current_nested_level == required_nested_level:
return super(Query, self)._insert_node_below(
node=node, parent_id=parent_id, with_children=with_children
)
# requires nested - apply all required nested fields
for nested_lvl in self.mapping.list_nesteds_at_field(node.field):
if current_nested_level != nested_lvl:
node = Nested(path=nested_lvl, query=node)
super(Query, self)._insert_node_below(node, parent_id, with_children=True)

def applied_nested_path_at_node(self, nid):
# from current node to root
for id_ in [nid] + self.ancestors(nid):
node = self.get(id_)
if isinstance(node, Nested):
return node.path
return None

def to_dict(self, from_=None, with_name=True):
"""Return None if no query clause.
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

install_requires = [
"future",
"lighttree==0.0.2",
"lighttree==0.0.3",
"elasticsearch>=7.1.0,<8.0.0",
]

Expand Down
Loading

0 comments on commit 2373536

Please sign in to comment.