Skip to content
This repository has been archived by the owner on Jan 9, 2023. It is now read-only.

Commit

Permalink
Merge pull request #62 from chrisburr/fix-59
Browse files Browse the repository at this point in the history
Improve performance of get_matching_variables
  • Loading branch information
chrisburr authored Mar 5, 2018
2 parents 7eff0fb + fec37dd commit a4d3708
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 9 deletions.
25 changes: 16 additions & 9 deletions root_pandas/readwrite.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from numpy.lib.recfunctions import append_fields
from pandas import DataFrame, RangeIndex
from root_numpy import root2array, list_trees
from fnmatch import fnmatch
import fnmatch
from root_numpy import list_branches
from root_numpy.extern.six import string_types
import itertools
Expand Down Expand Up @@ -59,17 +59,24 @@ def get_nonscalar_columns(array):


def get_matching_variables(branches, patterns, fail=True):
selected = []

for p in patterns:
# Convert branches to a set to make x "in branches" O(1) on average
branches = set(branches)
patterns = set(patterns)
# Find any trivial matches
selected = list(branches.intersection(patterns))
# Any matches that weren't trivial need to be looped over...
for pattern in patterns.difference(selected):
found = False
for b in branches:
if fnmatch(b, p):
# Avoid using fnmatch if the pattern if possible
if re.findall(r'(\*)|(\?)|(\[.*\])|(\[\!.*\])', pattern):
for match in fnmatch.filter(branches, pattern):
found = True
if fnmatch(b, p) and b not in selected:
selected.append(b)
if match not in selected:
selected.append(match)
elif pattern in branches:
raise NotImplementedError('I think this is impossible?')
if not found and fail:
raise ValueError("Pattern '{}' didn't match any branch".format(p))
raise ValueError("Pattern '{}' didn't match any branch".format(pattern))
return selected


Expand Down
12 changes: 12 additions & 0 deletions tests/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,18 @@ def test_nonscalar_columns():
os.remove(path)


def test_get_matching_variables_performance():
"""Performance regression test for #59"""
import random
import string
import root_pandas.readwrite
for n in [10, 100, 1000, 10000]:
branches = [' '.join(random.sample(string.ascii_letters*100, k=100)) for i in range(n)]
patterns = [' '.join(random.sample(string.ascii_letters*100, k=100)) for i in range(n)]
root_pandas.readwrite.get_matching_variables(branches, patterns, fail=False)
root_pandas.readwrite.get_matching_variables(branches, branches, fail=False)


def test_noexpand_prefix():
xs = np.array([1, 2, 3])
df = pd.DataFrame({'x': xs})
Expand Down

0 comments on commit a4d3708

Please sign in to comment.