diff --git a/root_pandas/readwrite.py b/root_pandas/readwrite.py index dc7873f..ddc6dc0 100644 --- a/root_pandas/readwrite.py +++ b/root_pandas/readwrite.py @@ -27,27 +27,57 @@ NOEXPAND_PREFIX = 'noexpand:' -def expand_braces(orig): - r = r'.*?(\{.+[^\\]\})' - p = re.compile(r) - - s = orig[:] - res = list() - - m = p.search(s) - if m is not None: - sub = m.group(1) - open_brace = s.find(sub) - close_brace = open_brace + len(sub) - 1 - if sub.find(',') != -1: - for pat in sub[1:-1].split(','): - res.extend(expand_braces(s[:open_brace] + pat + s[close_brace+1:])) - else: - res.extend(expand_braces(s[:open_brace] + sub.replace('}', '\\}') + s[close_brace+1:])) - else: - res.append(s.replace('\\}', '}')) +def _getitem(string, depth=0): + """ + Get an item from the string (where item is up to the next ',' or '}' or the + end of the string) + """ + out = [""] + while string: + char = string[0] + if depth and (char == ',' or char == '}'): + return out, string + if char == '{': + groups_string = _getgroup(string[1:], depth+1) + if groups_string is not None: + groups, string = groups_string + out = [a + g for a in out for g in groups] + continue + if char == '\\' and len(string) > 1: + string, char = string[1:], char + string[1] + + out, string = [a + char for a in out], string[1:] + + return out, string + + +def _getgroup(string, depth): + """ + Get a group from the string, where group is a list of all the comma + separated substrings up to the next '}' char or the brace enclosed substring + if there is no comma + """ + out, comma = [], False + while string: + items, string = _getitem(string, depth) - return list(set(res)) + if not string: + break + out += items + + if string[0] == '}': + if comma: + return out, string[1:] + return ['{' + a + '}' for a in out], string[1:] + + if string[0] == ',': + comma, string = True, string[1:] + + return None + + +def expand_braces(orig): + return _getitem(orig, 0)[0] def get_nonscalar_columns(array): diff --git a/tests/test.py b/tests/test.py index 792f672..56b2610 100644 --- a/tests/test.py +++ b/tests/test.py @@ -304,10 +304,10 @@ def test_brace_pattern_in_columns(): assert_frame_equal(df[['var{03}', 'var2', 'var{04}']], reference_df[['var{03}', 'var2', 'var{04}']]) - # # TODO Recursive expansions - # df = read_root('tmp.root', columns=[r'var{0{2,3},1{1,3}}']) - # assert set(df.columns) == {'var02', 'var03', 'var11', 'var13'} - # assert_frame_equal(df[['var02', 'var03', 'var11', 'var13']], - # reference_df[['var02', 'var03', 'var11', 'var13']]) + # Recursive expansions + df = read_root('tmp.root', columns=[r'var{0{2,3},1{1,3}}']) + assert set(df.columns) == {'var02', 'var03', 'var11', 'var13'} + assert_frame_equal(df[['var02', 'var03', 'var11', 'var13']], + reference_df[['var02', 'var03', 'var11', 'var13']]) os.remove('tmp.root')