From 2001dcc8675d19fce8b15f02f63aa47944eec3d6 Mon Sep 17 00:00:00 2001 From: Igor Babuschkin Date: Fri, 19 Aug 2016 13:35:41 +0100 Subject: [PATCH] Make sure that default indices increase continuously when reading in chunks This fixes #24 --- README.md | 3 +++ root_pandas/readwrite.py | 30 +++++++++++++++++++----------- tests/test.py | 17 +++++++++++++++++ 3 files changed, 39 insertions(+), 11 deletions(-) diff --git a/README.md b/README.md index 5de421d..6a47747 100644 --- a/README.md +++ b/README.md @@ -76,6 +76,9 @@ If the `chunksize` parameter is specified, `read_root` returns an iterator that for df in read_root('bigfile.root', chunksize=100000): # process df here ``` +If `bigfile.root` doesn't contain an index, the default indices of the +individual `DataFrame` chunks will still increase continuously, as if they were +parts of a single large `DataFrame`. You can also combine any of the above options at the same time. diff --git a/root_pandas/readwrite.py b/root_pandas/readwrite.py index 79f919c..1f44c29 100644 --- a/root_pandas/readwrite.py +++ b/root_pandas/readwrite.py @@ -5,7 +5,7 @@ import numpy as np from numpy.lib.recfunctions import append_fields -from pandas import DataFrame +from pandas import DataFrame, RangeIndex from root_numpy import root2array, list_trees from fnmatch import fnmatch from root_numpy import list_branches @@ -199,11 +199,13 @@ def do_flatten(arr, flatten): # XXX could explicitly clean up the opened TFiles with TChain::Reset def genchunks(): + current_index = 0 for chunk in range(int(ceil(float(n_entries) / chunksize))): arr = root2array(paths, key, all_vars, start=chunk * chunksize, stop=(chunk+1) * chunksize, selection=where, *args, **kwargs) if flatten: arr = do_flatten(arr, flatten) - yield convert_to_dataframe(arr) + yield convert_to_dataframe(arr, start_index=current_index) + current_index += len(arr) return genchunks() arr = root2array(paths, key, all_vars, selection=where, *args, **kwargs) @@ -212,15 +214,17 @@ def genchunks(): return convert_to_dataframe(arr) - -def convert_to_dataframe(array): +def convert_to_dataframe(array, start_index=None): nonscalar_columns = get_nonscalar_columns(array) if nonscalar_columns: warnings.warn("Ignored the following non-scalar branches: {bad_names}" .format(bad_names=", ".join(nonscalar_columns)), UserWarning) indices = list(filter(lambda x: x.startswith('__index__') and x not in nonscalar_columns, array.dtype.names)) if len(indices) == 0: - df = DataFrame.from_records(array, exclude=nonscalar_columns) + index = None + if start_index is not None: + index = RangeIndex(start=start_index, stop=start_index + len(array)) + df = DataFrame.from_records(array, exclude=nonscalar_columns, index=index) elif len(indices) == 1: # We store the index under the __index__* branch, where # * is the name of the index @@ -235,7 +239,7 @@ def convert_to_dataframe(array): return df -def to_root(df, path, key='default', mode='w', *args, **kwargs): +def to_root(df, path, key='default', mode='w', store_index=True, *args, **kwargs): """ Write DataFrame to a ROOT file. @@ -247,6 +251,9 @@ def to_root(df, path, key='default', mode='w', *args, **kwargs): Name of tree that the DataFrame will be saved as mode: string, {'w', 'a'} Mode that the file should be opened in (default: 'w') + store_index: bool (optional, default: True) + Whether the index of the DataFrame should be stored as + an __index__* branch in the tree Notes ----- @@ -270,11 +277,12 @@ def to_root(df, path, key='default', mode='w', *args, **kwargs): from root_numpy import array2root # We don't want to modify the user's DataFrame here, so we make a shallow copy df_ = df.copy(deep=False) - name = df_.index.name - if name is None: - # Handle the case where the index has no name - name = '' - df_['__index__' + name] = df_.index + if store_index: + name = df_.index.name + if name is None: + # Handle the case where the index has no name + name = '' + df_['__index__' + name] = df_.index arr = df_.to_records(index=False) array2root(arr, path, key, mode=mode, *args, **kwargs) diff --git a/tests/test.py b/tests/test.py index cb13f58..1501be9 100644 --- a/tests/test.py +++ b/tests/test.py @@ -82,6 +82,23 @@ def test_chunked_reading(): assert count == 3 os.remove('tmp.root') +# Make sure that the default index counts up properly, +# even if the input is chunked +def test_chunked_reading_consistent_index(): + df = pd.DataFrame({'x': [1,2,3,4,5,6]}) + df.to_root('tmp.root', store_index=False) + + dfs = [] + for df_ in read_root('tmp.root', chunksize=2): + dfs.append(df_) + assert(not df_.empty) + df_reconstructed = pd.concat(dfs) + + assert_frame_equal(df, df_reconstructed) + + os.remove('tmp.root') + + def test_multiple_files(): df = pd.DataFrame({'x': [1,2,3,4,5,6]}) df.to_root('tmp1.root')