Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Retrieve elements in xml tree with paths #41

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion onadata/apps/data_migration/common.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
from functools import reduce
from functools import reduce, partial
from operator import add


def mapc(x):
    """Return a callable that maps ``x`` over whatever iterable it is given."""
    return partial(map, x)


def fst(x):
    """Return the first item (index 0) of ``x``."""
    return x[0]


def concat_map(f, iterable):
    """Apply ``f`` to every item of ``iterable`` and flatten the results one level.

    ``f`` must return an iterable for each item; the items of those
    iterables are collected, in order, into a single new list.
    """
    # One flattening pass: folding the pieces together with ``add`` copies
    # the accumulated list on every step, which is quadratic in result size.
    return [item for chunk in map(f, iterable) for item in chunk]

Expand Down
4 changes: 4 additions & 0 deletions onadata/apps/data_migration/surveytree.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,10 @@ def insert_field_into_group_chain(self, field, group_chain):

parent.append(field)

def update_field_contents(self, field_name, new_text):
    """Replace the text of the field called ``field_name`` with ``new_text``.

    The field is looked up through ``self.get_field``; anything that
    lookup raises propagates to the caller.
    """
    self.get_field(field_name).text = new_text

def sort(self, xformtree):
"""Sort XML tree fields according to the order provided by XFormTree"""
pattern = xformtree.get_el_by_path(xformtree.DATA_STRUCT_PATH)[0]
Expand Down
31 changes: 28 additions & 3 deletions onadata/apps/data_migration/tests/test_survey_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,10 @@ def test_get_field(self):
self.assertTrue(etree.iselement(self.survey.get_field('name')))
self.assertTrue(etree.iselement(self.survey.get_field('photo')))

def test_get_field(self):
self.assertTrue(etree.iselement(self.survey.get_field('name')))
self.assertTrue(etree.iselement(self.survey.get_field('photo')))

def test_get_field__no_such_field(self):
with self.assertRaises(MissingFieldException):
self.survey.get_field('i_am_sure_no_such_field_exist')
Expand Down Expand Up @@ -217,9 +221,30 @@ def test_find_group__raises_on_no_such_group(self):
with self.assertRaises(MissingFieldException):
self.survey.find_group('certainly_no_such_group_exist')

def test_get_all_elems(self):
self.assertCountEqual(fixtures.GROUPS_FIELDS_AFTER,
[e.tag for e in self.survey.get_all_elems()])
def test_retrieve_all_elems__tags_only(self):
self.assertCountEqual(fixtures.GROUPS_FIELDS_AFTER + ['AlgebraicTypes2'],
[e[0].tag for e in self.survey.retrieve_all_elems(self.survey.root)])

# Builds a small nested document and checks what retrieve_all_elems reports
# for it when started from the root.
def test_retrieve_all_elems(self):
survey = SurveyTree('''
<AlgebraicTypes>
<functor>
<applicative>
<monad>Either</monad>
</applicative>
<foldable/>
</functor>
</AlgebraicTypes>
''')
# Expected (tag, path) pairs: the path lists the tags of the element's
# ancestors from the root down, so the root itself has an empty path.
result = [
('AlgebraicTypes', []),
('functor', ['AlgebraicTypes']),
('applicative', ['AlgebraicTypes', 'functor']),
('foldable', ['AlgebraicTypes', 'functor']),
('monad', ['AlgebraicTypes', 'functor', 'applicative']),
]
# Each returned item is indexed as e[0] (the element) and e[1] (its path).
self.assertCountEqual(result,
[(e[0].tag, e[1]) for e in survey.retrieve_all_elems(survey.root)])

def test_insert_field_into_group_chain(self):
survey = SurveyTree('<AlgebraicTypes></AlgebraicTypes>')
Expand Down
58 changes: 33 additions & 25 deletions onadata/apps/data_migration/xmltree.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import re
from operator import not_
from functools import partial
from itertools import ifilter
from lxml import etree

from .common import concat_map, compose
from .common import concat_map, compose, fst


class MissingFieldException(Exception):
Expand Down Expand Up @@ -38,15 +39,11 @@ def set_tag(self, field, value):

def get_fields(self):
"""Parse and return list of all fields in form."""
return self.retrieve_leaf_elems(self.root)
return map(fst, self.retrieve_leaf_elems(self.root))

def get_groups(self):
return concat_map(self.retrieve_groups, self.root.getchildren())

def get_all_elems(self):
"""Return a list of both groups and fields"""
return concat_map(self.retrieve_all_elems, self.root.getchildren())

def get_fields_names(self):
"""Return fields as list of string with field names."""
return map(lambda f: f.tag, self.get_fields())
Expand All @@ -57,20 +54,12 @@ def get_groups_names(self):

def _get_matching_elems(self, condition_func):
"""Return elems that match condition"""
return ifilter(condition_func, self.get_all_elems())
return ifilter(condition_func, self.retrieve_all_elems(self.root))

@staticmethod
def is_leaf(element):
return len(element.getchildren()) == 0

@classmethod
def retrieve_leaf_elems(cls, element):
if not cls.is_relevant(element.tag):
return []
if element.getchildren():
return concat_map(cls.retrieve_leaf_elems, element)
return [element]

@staticmethod
def _get_first_element(name):
def get_next_from_iterator(iterator):
Expand All @@ -81,12 +70,16 @@ def get_next_from_iterator(iterator):
"xml tree".format(name))
return get_next_from_iterator

def get_field(self, name):
"""Get field in tree by name."""
cond = lambda f: self.field_tag(f) == name
def get_field_with_path(self, name):
"""Get field along with path in tree by name."""
cond = lambda field: self.field_tag(field[0]) == name
matching_elems = self._get_matching_elems(cond)
return self._get_first_element(name)(matching_elems)

def get_field(self, name):
    """Get field in tree by name, dropping the path it was found under."""
    field_with_path = self.get_field_with_path(name)
    # Index 0 of the lookup result is the element itself.
    return field_with_path[0]

@classmethod
def get_child_field(cls, element, name):
"""Get child of element by name"""
Expand All @@ -102,21 +95,32 @@ def children_tags(cls, element):
def get_el_by_path(self, path):
    """Walk down from ``self.root`` following the child names in ``path``.

    Each step is resolved with ``self.get_child_field``; an empty ``path``
    returns the root itself.
    """
    current = self.root
    for child_name in path:
        current = self.get_child_field(current, child_name)
    return current

@classmethod
def retrieve_elems(cls, path, base_cond, element):
    """Collect ``(element, path)`` pairs from ``element`` and its descendants.

    ``path`` is the list of cleaned ancestor tags leading to ``element``.
    An element is included when ``base_cond(element)`` is truthy. Elements
    whose tag fails ``is_relevant`` are skipped together with everything
    beneath them. Results come back parent-first, children in document
    order.
    """
    if not cls.is_relevant(element.tag):
        return []

    collected = [(element, path)] if base_cond(element) else []
    # Every direct child shares the same extended path list.
    child_path = path + [cls.clean_tag(element.tag)]
    for child in element:
        collected.extend(cls.retrieve_elems(child_path, base_cond, child))
    return collected

@classmethod
def retrieve_leaf_elems(cls, element):
    """Collect ``(element, path)`` pairs for the leaves at or below ``element``.

    Delegates to ``retrieve_elems`` with an empty starting path and
    ``is_leaf`` as the selection predicate.
    """
    return cls.retrieve_elems(path=[], base_cond=cls.is_leaf, element=element)

@classmethod
def retrieve_leaf_elems_tags(cls, element):
return map(cls.field_tag, cls.retrieve_leaf_elems(element))
return map(compose(cls.field_tag, fst), cls.retrieve_leaf_elems(element))

@classmethod
def retrieve_groups(cls, element):
if not cls.is_relevant(element.tag) or not element.getchildren():
return []
return [element] + concat_map(cls.retrieve_groups, element)
cond = compose(not_, cls.is_leaf)
return map(fst, cls.retrieve_elems([], cond, element))

@classmethod
def retrieve_all_elems(cls, element):
if not cls.is_relevant(element.tag):
return []
return [element] + concat_map(cls.retrieve_all_elems, element)
return cls.retrieve_elems([], lambda _: True, element)

@classmethod
def is_relevant(cls, tag):
Expand Down Expand Up @@ -148,3 +152,7 @@ def field_tag(cls, field):
def are_elements_equal(cls, e1, e2):
to_str = lambda el: re.sub(r"\s+", "", etree.tostring(el))
return to_str(e1) == to_str(e2)

def _get_matching_leafs(self, condition_func):
    """Lazily filter the tree's leaf elements by ``condition_func``.

    ``condition_func`` is called with the items produced by
    ``retrieve_leaf_elems(self.root)`` -- ``(element, path)`` tuples, not
    bare elements -- and the accepted items are yielded in that same form.
    """
    return ifilter(condition_func, self.retrieve_leaf_elems(self.root))