Skip to content
This repository has been archived by the owner on Jan 9, 2023. It is now read-only.

Commit

Permalink
Make sure that default indices increase continuously when reading in …
Browse files Browse the repository at this point in the history
…chunks

This fixes #24
  • Loading branch information
ibab committed Aug 19, 2016
1 parent 38659d1 commit 2001dcc
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 11 deletions.
3 changes: 3 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,9 @@ If the `chunksize` parameter is specified, `read_root` returns an iterator that
for df in read_root('bigfile.root', chunksize=100000):
# process df here
```
If `bigfile.root` doesn't contain an index, the default indices of the
individual `DataFrame` chunks will still increase continuously, as if they were
parts of a single large `DataFrame`.

You can also combine any of the above options at the same time.

Expand Down
30 changes: 19 additions & 11 deletions root_pandas/readwrite.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import numpy as np
from numpy.lib.recfunctions import append_fields
from pandas import DataFrame
from pandas import DataFrame, RangeIndex
from root_numpy import root2array, list_trees
from fnmatch import fnmatch
from root_numpy import list_branches
Expand Down Expand Up @@ -199,11 +199,13 @@ def do_flatten(arr, flatten):
# XXX could explicitly clean up the opened TFiles with TChain::Reset

def genchunks():
    # Generator that yields one DataFrame per chunk of up to `chunksize`
    # tree entries.  `current_index` carries the running entry offset
    # between chunks so that convert_to_dataframe can build a continuous
    # default RangeIndex across chunks (fixes issue #24: without this,
    # every chunk restarted its default index at 0).
    current_index = 0
    for chunk in range(int(ceil(float(n_entries) / chunksize))):
        # NOTE(review): assumes root2array clamps `stop` past the end of
        # the tree so the final, possibly short, chunk is read correctly
        # — confirm against root_numpy's documentation.
        arr = root2array(paths, key, all_vars, start=chunk * chunksize, stop=(chunk+1) * chunksize, selection=where, *args, **kwargs)
        if flatten:
            arr = do_flatten(arr, flatten)
        yield convert_to_dataframe(arr, start_index=current_index)
        # Advance by the number of entries actually returned: with a
        # `selection`/flatten applied, len(arr) can differ from chunksize.
        current_index += len(arr)
return genchunks()

arr = root2array(paths, key, all_vars, selection=where, *args, **kwargs)
Expand All @@ -212,15 +214,17 @@ def genchunks():
return convert_to_dataframe(arr)



def convert_to_dataframe(array):
def convert_to_dataframe(array, start_index=None):
nonscalar_columns = get_nonscalar_columns(array)
if nonscalar_columns:
warnings.warn("Ignored the following non-scalar branches: {bad_names}"
.format(bad_names=", ".join(nonscalar_columns)), UserWarning)
indices = list(filter(lambda x: x.startswith('__index__') and x not in nonscalar_columns, array.dtype.names))
if len(indices) == 0:
df = DataFrame.from_records(array, exclude=nonscalar_columns)
index = None
if start_index is not None:
index = RangeIndex(start=start_index, stop=start_index + len(array))
df = DataFrame.from_records(array, exclude=nonscalar_columns, index=index)
elif len(indices) == 1:
# We store the index under the __index__* branch, where
# * is the name of the index
Expand All @@ -235,7 +239,7 @@ def convert_to_dataframe(array):
return df


def to_root(df, path, key='default', mode='w', *args, **kwargs):
def to_root(df, path, key='default', mode='w', store_index=True, *args, **kwargs):
"""
Write DataFrame to a ROOT file.
Expand All @@ -247,6 +251,9 @@ def to_root(df, path, key='default', mode='w', *args, **kwargs):
Name of tree that the DataFrame will be saved as
mode: string, {'w', 'a'}
Mode that the file should be opened in (default: 'w')
store_index: bool (optional, default: True)
Whether the index of the DataFrame should be stored as
an __index__* branch in the tree
Notes
-----
Expand All @@ -270,11 +277,12 @@ def to_root(df, path, key='default', mode='w', *args, **kwargs):
from root_numpy import array2root
# We don't want to modify the user's DataFrame here, so we make a shallow copy
df_ = df.copy(deep=False)
name = df_.index.name
if name is None:
# Handle the case where the index has no name
name = ''
df_['__index__' + name] = df_.index
if store_index:
name = df_.index.name
if name is None:
# Handle the case where the index has no name
name = ''
df_['__index__' + name] = df_.index
arr = df_.to_records(index=False)
array2root(arr, path, key, mode=mode, *args, **kwargs)

Expand Down
17 changes: 17 additions & 0 deletions tests/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,23 @@ def test_chunked_reading():
assert count == 3
os.remove('tmp.root')

# Chunked reading must produce default indices that continue across
# chunk boundaries, so concatenating the chunks reproduces the original
# DataFrame exactly (regression test for issue #24).
def test_chunked_reading_consistent_index():
    original = pd.DataFrame({'x': [1,2,3,4,5,6]})
    original.to_root('tmp.root', store_index=False)

    chunks = []
    for chunk in read_root('tmp.root', chunksize=2):
        assert(not chunk.empty)
        chunks.append(chunk)
    reconstructed = pd.concat(chunks)

    assert_frame_equal(original, reconstructed)

    os.remove('tmp.root')


def test_multiple_files():
df = pd.DataFrame({'x': [1,2,3,4,5,6]})
df.to_root('tmp1.root')
Expand Down

1 comment on commit 2001dcc

@alexpearce
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🎉 thanks a bunch!

Please sign in to comment.