Skip to content

Commit

Permalink
mapping proper deserialization
Browse files Browse the repository at this point in the history
  • Loading branch information
leonardbinet committed Mar 5, 2020
1 parent 83819c1 commit 7bc8aaf
Show file tree
Hide file tree
Showing 7 changed files with 174 additions and 41 deletions.
2 changes: 1 addition & 1 deletion pandagg/node/agg/deserializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def deserialize_agg(d):
agg_class.__name__, children_aggs))
if children_aggs:
if isinstance(children_aggs, dict):
children_aggs = [{k: v for k, v in iteritems(children_aggs)}]
children_aggs = [{k: v} for k, v in iteritems(children_aggs)]
elif isinstance(children_aggs, AggNode):
children_aggs = (children_aggs,)
agg_body['aggs'] = children_aggs
Expand Down
47 changes: 36 additions & 11 deletions pandagg/node/mapping/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import copy
import json
from six import python_2_unicode_compatible
from six import python_2_unicode_compatible, iteritems
from builtins import str as text
from pandagg.node._node import Node
from pandagg.utils import PrettyNode
Expand All @@ -14,18 +14,44 @@ class Field(Node):
KEY = NotImplementedError()
DISPLAY_PATTERN = ' %s'

def __init__(self, name, depth, is_subfield=False, **body):
def __init__(self, name, depth=0, is_subfield=False, **body):
# name will be used for dynamic attribute access in tree
self.name = name
# TODO - remove knowledge of depth here -> PR in treelib to update `show` method
self.depth = depth
self.is_subfield = is_subfield

self.fields = body.pop('fields', None)
self.properties = body.pop('properties', None)
# fields and properties can be a Field instance, a sequence of Field instances, or a dict
self.fields = self._atomize(body.pop('fields', None))
self.properties = self._atomize(body.pop('properties', None))
# rest of body
self._body = body
super(Field, self).__init__(data=PrettyNode(pretty=self.tree_repr))

def reset_data(self):
# hack until treelib show issue is fixed
self.data = PrettyNode(pretty=self.tree_repr)

@staticmethod
def _atomize(children):
if children is None:
return []
if isinstance(children, dict):
return [{k: v} for k, v in iteritems(children)]
if isinstance(children, Field):
return [children]
return children

@staticmethod
def _serialize_atomized(children):
d = {}
for child in children:
if isinstance(child, dict):
d.update(child)
if isinstance(child, Field):
d[child.name] = child.body(with_children=True)
return d

@property
def _identifier_prefix(self):
return self.name
Expand All @@ -36,13 +62,12 @@ def deserialize(cls, name, body, depth=0, is_subfield=False):
raise ValueError('Deserialization error for field <%s>: <%s>' % (cls.KEY, body))
return cls(name=name, depth=depth, is_subfield=is_subfield, **body)

@property
def body(self):
def body(self, with_children=False):
b = copy.deepcopy(self._body)
if self.properties:
b['properties'] = self.properties
if self.fields:
b['fields'] = self.fields
if with_children and self.properties:
b['properties'] = self._serialize_atomized(self.properties)
if with_children and self.fields:
b['fields'] = self._serialize_atomized(self.fields)
if self.KEY in ('object', ''):
return b
b['type'] = self.KEY
Expand All @@ -59,5 +84,5 @@ def __str__(self):
return '<Mapping Field %s> of type %s:\n%s' % (
text(self.name),
text(self.KEY),
text(json.dumps(self.body, indent=4))
text(json.dumps(self.body(with_children=True), indent=4))
)
4 changes: 1 addition & 3 deletions pandagg/tree/agg.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,8 +78,6 @@ def _insert_from_node(self, agg_node, pid=None):
if isinstance(agg_node, BucketAggNode):
for child_agg_node in agg_node.aggs or []:
self._insert(child_agg_node, pid=agg_node.identifier)
# reset children to None to avoid confusion since this serves only __init__ syntax.
agg_node.aggs = None

def _insert(self, from_, pid=None):
inserted_tree = self.deserialize(from_=from_)
Expand Down Expand Up @@ -324,7 +322,7 @@ def paste(self, nid, new_tree, deep=False):
return super(Agg, self).paste(nid, new_tree, deep)
# validates that mappings are similar
if new_tree.tree_mapping is not None:
if new_tree.tree_mapping.body != self.tree_mapping.body:
if new_tree.tree_mapping.serialize() != self.tree_mapping.serialize():
raise MappingError('Pasted tree has a different mapping.')

# check root node nested position in mapping
Expand Down
79 changes: 66 additions & 13 deletions pandagg/tree/mapping.py
Original file line number Diff line number Diff line change
@@ -1,33 +1,86 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import copy

from six import iteritems
from treelib.exceptions import NodeIDAbsentError

from pandagg.node.mapping.abstract import Field
from pandagg.node.mapping.deserializer import deserialize_field
from pandagg.exceptions import AbsentMappingFieldError, InvalidOperationMappingFieldError
from pandagg.node.mapping.field_datatypes import Object
from pandagg.tree._tree import Tree


class Mapping(Tree):

node_class = Field

def __init__(self, body=None, identifier=None):
def __init__(self, from_=None, identifier=None, properties=None, dynamic=False):
if from_ is not None and properties is not None:
raise ValueError('Can provide at most one of "from_" and "properties"')
if properties is not None:
from_ = Object(name='', properties=properties, dynamic=dynamic)
super(Mapping, self).__init__(identifier=identifier)
self.body = body
if body:
self.deserialize(name='', body=body)

def deserialize(self, name, body, pid=None, depth=0, is_subfield=False):
if from_ is not None:
self._insert(from_, depth=0)

@classmethod
def deserialize(cls, from_, depth=0):
if isinstance(from_, Mapping):
return from_
if isinstance(from_, Field):
new = Mapping()
new._insert_from_node(field=from_, depth=depth, is_subfield=False)
return new
if isinstance(from_, dict):
from_ = copy.deepcopy(from_)
new = Mapping()
new._insert_from_dict(name='', body=from_, is_subfield=False, depth=depth)
return new
else:
raise ValueError('Unsupported type <%s>.' % type(from_))

def serialize(self):
if self.root is None:
return None
return self[self.root].body(with_children=True)

def _insert_from_dict(self, name, body, is_subfield, depth, pid=None):
node = deserialize_field(name=name, depth=depth, is_subfield=is_subfield, body=body)
self.add_node(node, parent=pid)
depth += 1
for sub_name, sub_body in iteritems(node.properties or {}):
self.deserialize(name=sub_name, body=sub_body, pid=node.identifier, depth=depth)
for sub_name, sub_body in iteritems(node.fields or {}):
self.deserialize(name=sub_name, body=sub_body, pid=node.identifier, depth=depth, is_subfield=True)
self._insert_from_node(node, depth=depth, pid=pid, is_subfield=is_subfield)

def _insert_from_node(self, field, depth, is_subfield, pid=None):
# overriden to allow smooth DSL declaration
field.depth = depth
field.is_subfield = is_subfield
field.reset_data()

self.add_node(field, pid)
for subfield in field.fields or []:
if isinstance(subfield, dict):
name, body = next(iteritems(subfield))
self._insert_from_dict(name=name, body=body, pid=field.identifier, is_subfield=True, depth=depth + 1)
elif isinstance(subfield, Field):
self._insert_from_node(subfield, pid=field.identifier, depth=depth + 1, is_subfield=True)
else:
raise ValueError('Wrong type %s' % type(field))
for subfield in field.properties or []:
if isinstance(subfield, dict):
name, body = next(iteritems(subfield))
self._insert_from_dict(name=name, body=body, pid=field.identifier, is_subfield=False, depth=depth + 1)
elif isinstance(subfield, Field):
self._insert_from_node(subfield, pid=field.identifier, depth=depth + 1, is_subfield=False)
else:
raise ValueError('Wrong type %s' % type(field))

def _insert(self, from_, depth, pid=None):
inserted_tree = self.deserialize(from_=from_, depth=depth)
if self.root is None:
self.merge(nid=pid, new_tree=inserted_tree)
return self
self.paste(nid=pid, new_tree=inserted_tree)
return self

def __getitem__(self, key):
"""Tries to fetch node by identifier, else by succession of names."""
Expand Down Expand Up @@ -65,7 +118,7 @@ def contains(self, nid):
def _clone(self, identifier, with_tree=False, deep=False):
return Mapping(
identifier=identifier,
body=self.body if with_tree else None
from_=self if with_tree else None
)

def show(self, data_property='pretty', **kwargs):
Expand Down
2 changes: 0 additions & 2 deletions pandagg/tree/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,8 +96,6 @@ def _insert_from_node(self, query_node, pid=None):
if hasattr(query_node, 'children'):
for child_node in query_node.children or []:
self._insert(child_node, pid=query_node.identifier)
# reset children to None to avoid confusion since this serves only __init__ syntax.
query_node.children = None

def add_node(self, node, pid=None):
if pid is None:
Expand Down
8 changes: 4 additions & 4 deletions tests/base/tree/test_aggs.py
Original file line number Diff line number Diff line change
Expand Up @@ -617,8 +617,8 @@ def test_groupby_method(self):

def test_mapping_from_init(self):
agg_from_dict_mapping = Agg(mapping=MAPPING)
agg_from_tree_mapping = Agg(mapping=Mapping(body=MAPPING))
agg_from_obj_mapping = Agg(mapping=IMapping(tree=Mapping(body=MAPPING)))
agg_from_tree_mapping = Agg(mapping=Mapping(from_=MAPPING))
agg_from_obj_mapping = Agg(mapping=IMapping(tree=Mapping(from_=MAPPING)))
self.assertEqual(
agg_from_dict_mapping.tree_mapping.__repr__(),
agg_from_tree_mapping.tree_mapping.__repr__()
Expand All @@ -635,9 +635,9 @@ def test_set_mapping(self):
agg_from_dict_mapping = Agg() \
.set_mapping(mapping=MAPPING)
agg_from_tree_mapping = Agg() \
.set_mapping(mapping=Mapping(body=MAPPING))
.set_mapping(mapping=Mapping(from_=MAPPING))
agg_from_obj_mapping = Agg() \
.set_mapping(mapping=IMapping(tree=Mapping(body=MAPPING), client=None))
.set_mapping(mapping=IMapping(tree=Mapping(from_=MAPPING), client=None))
self.assertEqual(
agg_from_dict_mapping.tree_mapping.__repr__(),
agg_from_tree_mapping.tree_mapping.__repr__()
Expand Down
73 changes: 66 additions & 7 deletions tests/base/tree/test_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from pandagg.exceptions import AbsentMappingFieldError
from pandagg.interactive._field_agg_factory import field_classes_per_name
from pandagg.node.mapping.abstract import Field
from pandagg.node.mapping.field_datatypes import Keyword
from pandagg.node.mapping.field_datatypes import Keyword, Object, Text, Nested, Integer
from pandagg.tree.mapping import Mapping
from pandagg.interactive.mapping import IMapping
from tests.base.mapping_example import MAPPING, EXPECTED_MAPPING_TREE_REPR
Expand Down Expand Up @@ -39,13 +39,72 @@ def test_node_repr(self):
}"""
)

def test_deserialization(self):
mapping_dict = {
"dynamic": False,
"properties": {
"classification_type": {
"type": "keyword",
"fields": {
"raw": {
"type": "text"
}
}
},
"local_metrics": {
"type": "nested",
"dynamic": False,
"properties": {
"dataset": {
"dynamic": False,
"properties": {
"support_test": {
"type": "integer"
},
"support_train": {
"type": "integer"
}
}
}
}
}
}
}

m1 = Mapping(mapping_dict)

m2 = Mapping(dynamic=False, properties={
Keyword('classification_type', fields=[
Text('raw')
]),
Nested('local_metrics', dynamic=False, properties=[
Object('dataset', dynamic=False, properties=[
Integer('support_test'),
Integer('support_train')
])
])
})

expected_repr = """<Mapping>
{Object}
├── classification_type Keyword
│ └── raw ~ Text
└── local_metrics [Nested]
└── dataset {Object}
├── support_test Integer
└── support_train Integer
"""
for i, m in enumerate((m1, m2,)):
self.assertEqual(m.__repr__(), expected_repr, "failed at m%d" % (i + 1))
self.assertEqual(m.serialize(), mapping_dict, "failed at m%d" % (i + 1))

def test_parse_tree_from_dict(self):
mapping_tree = Mapping(body=MAPPING)
mapping_tree = Mapping(from_=MAPPING)

self.assertEqual(mapping_tree.__str__(), EXPECTED_MAPPING_TREE_REPR)

def test_nesteds_applied_at_field(self):
mapping_tree = Mapping(body=MAPPING)
mapping_tree = Mapping(from_=MAPPING)

self.assertEqual(mapping_tree.nested_at_field('classification_type'), None)
self.assertEqual(mapping_tree.list_nesteds_at_field('classification_type'), [])
Expand All @@ -60,7 +119,7 @@ def test_nesteds_applied_at_field(self):
self.assertEqual(mapping_tree.list_nesteds_at_field('local_metrics.dataset.support_test'), ['local_metrics'])

def test_mapping_type_of_field(self):
mapping_tree = Mapping(body=MAPPING)
mapping_tree = Mapping(from_=MAPPING)
with self.assertRaises(AbsentMappingFieldError):
self.assertEqual(mapping_tree.mapping_type_of_field('yolo'), False)

Expand All @@ -70,15 +129,15 @@ def test_mapping_type_of_field(self):
self.assertEqual(mapping_tree.mapping_type_of_field('local_metrics.dataset.support_test'), 'integer')

def test_node_path(self):
mapping_tree = Mapping(body=MAPPING)
mapping_tree = Mapping(from_=MAPPING)
# get node by path syntax
node = mapping_tree['local_metrics.dataset.support_test']
self.assertIsInstance(node, Field)
self.assertEqual(node.name, 'support_test')
self.assertEqual(mapping_tree.node_path(node.identifier), 'local_metrics.dataset.support_test')

def test_mapping_aggregations(self):
mapping_tree = Mapping(body=MAPPING)
mapping_tree = Mapping(from_=MAPPING)
# check that leaves are expanded, based on 'field_name' attribute of nodes
mapping = IMapping(tree=mapping_tree, depth=1)
for field_name in ('classification_type', 'date', 'global_metrics', 'id', 'language', 'local_metrics', 'workflow'):
Expand Down Expand Up @@ -126,7 +185,7 @@ def test_client_bound(self):
}
client_mock.search = Mock(return_value=es_response_mock)

mapping_tree = Mapping(body=MAPPING)
mapping_tree = Mapping(from_=MAPPING)
client_bound_mapping = IMapping(
client=client_mock,
tree=mapping_tree,
Expand Down

0 comments on commit 7bc8aaf

Please sign in to comment.