Skip to content

Commit

Permalink
Retrieve elements in xml tree with paths
Browse files Browse the repository at this point in the history
  • Loading branch information
Tomatosoup97 committed Mar 25, 2020
1 parent 6e8abce commit c6a1ae2
Show file tree
Hide file tree
Showing 4 changed files with 71 additions and 29 deletions.
7 changes: 6 additions & 1 deletion onadata/apps/data_migration/common.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
from functools import reduce
from functools import reduce, partial
from operator import add


mapc = lambda x: partial(map, x)

fst = lambda x: x[0]


def concat_map(f, iterable):
return reduce(add, map(f, iterable), [])

Expand Down
4 changes: 4 additions & 0 deletions onadata/apps/data_migration/surveytree.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,10 @@ def insert_field_into_group_chain(self, field, group_chain):

parent.append(field)

def update_field_contents(self, field_name, new_text):
field = self.get_field(field_name)
field.text = new_text

def sort(self, xformtree):
"""Sort XML tree fields according to the order provided by XFormTree"""
pattern = xformtree.get_el_by_path(xformtree.DATA_STRUCT_PATH)[0]
Expand Down
31 changes: 28 additions & 3 deletions onadata/apps/data_migration/tests/test_survey_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,10 @@ def test_get_field(self):
self.assertTrue(etree.iselement(self.survey.get_field('name')))
self.assertTrue(etree.iselement(self.survey.get_field('photo')))

def test_get_field(self):
self.assertTrue(etree.iselement(self.survey.get_field('name')))
self.assertTrue(etree.iselement(self.survey.get_field('photo')))

def test_get_field__no_such_field(self):
with self.assertRaises(MissingFieldException):
self.survey.get_field('i_am_sure_no_such_field_exist')
Expand Down Expand Up @@ -217,9 +221,30 @@ def test_find_group__raises_on_no_such_group(self):
with self.assertRaises(MissingFieldException):
self.survey.find_group('certainly_no_such_group_exist')

def test_get_all_elems(self):
self.assertCountEqual(fixtures.GROUPS_FIELDS_AFTER,
[e.tag for e in self.survey.get_all_elems()])
def test_retrieve_all_elems__tags_only(self):
self.assertCountEqual(fixtures.GROUPS_FIELDS_AFTER + ['AlgebraicTypes2'],
[e[0].tag for e in self.survey.retrieve_all_elems(self.survey.root)])

def test_retrieve_all_elems(self):
survey = SurveyTree('''
<AlgebraicTypes>
<functor>
<applicative>
<monad>Either</monad>
</applicative>
<foldable/>
</functor>
</AlgebraicTypes>
''')
result = [
('AlgebraicTypes', []),
('functor', ['AlgebraicTypes']),
('applicative', ['AlgebraicTypes', 'functor']),
('foldable', ['AlgebraicTypes', 'functor']),
('monad', ['AlgebraicTypes', 'functor', 'applicative']),
]
self.assertCountEqual(result,
[(e[0].tag, e[1]) for e in survey.retrieve_all_elems(survey.root)])

def test_insert_field_into_group_chain(self):
survey = SurveyTree('<AlgebraicTypes></AlgebraicTypes>')
Expand Down
58 changes: 33 additions & 25 deletions onadata/apps/data_migration/xmltree.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import re
from operator import not_
from functools import partial
from itertools import ifilter
from lxml import etree

from .common import concat_map, compose
from .common import concat_map, compose, fst


class MissingFieldException(Exception):
Expand Down Expand Up @@ -38,15 +39,11 @@ def set_tag(self, field, value):

def get_fields(self):
"""Parse and return list of all fields in form."""
return self.retrieve_leaf_elems(self.root)
return map(fst, self.retrieve_leaf_elems(self.root))

def get_groups(self):
return concat_map(self.retrieve_groups, self.root.getchildren())

def get_all_elems(self):
"""Return a list of both groups and fields"""
return concat_map(self.retrieve_all_elems, self.root.getchildren())

def get_fields_names(self):
"""Return fields as list of string with field names."""
return map(lambda f: f.tag, self.get_fields())
Expand All @@ -57,20 +54,12 @@ def get_groups_names(self):

def _get_matching_elems(self, condition_func):
"""Return elems that match condition"""
return ifilter(condition_func, self.get_all_elems())
return ifilter(condition_func, self.retrieve_all_elems(self.root))

@staticmethod
def is_leaf(element):
return len(element.getchildren()) == 0

@classmethod
def retrieve_leaf_elems(cls, element):
if not cls.is_relevant(element.tag):
return []
if element.getchildren():
return concat_map(cls.retrieve_leaf_elems, element)
return [element]

@staticmethod
def _get_first_element(name):
def get_next_from_iterator(iterator):
Expand All @@ -81,12 +70,16 @@ def get_next_from_iterator(iterator):
"xml tree".format(name))
return get_next_from_iterator

def get_field(self, name):
"""Get field in tree by name."""
cond = lambda f: self.field_tag(f) == name
def get_field_with_path(self, name):
"""Get field along with path in tree by name."""
cond = lambda field: self.field_tag(field[0]) == name
matching_elems = self._get_matching_elems(cond)
return self._get_first_element(name)(matching_elems)

def get_field(self, name):
"""Get field in tree by name."""
return compose(fst, self.get_field_with_path)(name)

@classmethod
def get_child_field(cls, element, name):
"""Get child of element by name"""
Expand All @@ -102,21 +95,32 @@ def children_tags(cls, element):
def get_el_by_path(self, path):
return reduce(self.get_child_field, path, self.root)

@classmethod
def retrieve_elems(cls, path, base_cond, element):
if not cls.is_relevant(element.tag):
return []

new_elem = [(element, path)] if base_cond(element) else []
new_path = path+[cls.clean_tag(element.tag)]
induction_step = partial(cls.retrieve_elems, new_path, base_cond)
return new_elem + concat_map(induction_step, element)

@classmethod
def retrieve_leaf_elems(cls, element):
return cls.retrieve_elems([], cls.is_leaf, element)

@classmethod
def retrieve_leaf_elems_tags(cls, element):
return map(cls.field_tag, cls.retrieve_leaf_elems(element))
return map(compose(cls.field_tag, fst), cls.retrieve_leaf_elems(element))

@classmethod
def retrieve_groups(cls, element):
if not cls.is_relevant(element.tag) or not element.getchildren():
return []
return [element] + concat_map(cls.retrieve_groups, element)
cond = compose(not_, cls.is_leaf)
return map(fst, cls.retrieve_elems([], cond, element))

@classmethod
def retrieve_all_elems(cls, element):
if not cls.is_relevant(element.tag):
return []
return [element] + concat_map(cls.retrieve_all_elems, element)
return cls.retrieve_elems([], lambda _: True, element)

@classmethod
def is_relevant(cls, tag):
Expand Down Expand Up @@ -148,3 +152,7 @@ def field_tag(cls, field):
def are_elements_equal(cls, e1, e2):
to_str = lambda el: re.sub(r"\s+", "", etree.tostring(el))
return to_str(e1) == to_str(e2)

def _get_matching_leafs(self, condition_func):
"""Return elems that match condition"""
return ifilter(condition_func, self.retrieve_leaf_elems(self.root))

0 comments on commit c6a1ae2

Please sign in to comment.