Skip to content

Commit

Permalink
Parse multi-argument templates beyond just std::map. (#52)
Browse files Browse the repository at this point in the history
* Parse multi-argument templates beyond just std::map.

* Include new debugging tools.

* Add stub for #37 so that at least there's a good error message.

* Fix black & flake8.
  • Loading branch information
jpivarski authored Jul 17, 2020
1 parent 0ec248c commit cce6b79
Show file tree
Hide file tree
Showing 4 changed files with 135 additions and 8 deletions.
1 change: 1 addition & 0 deletions uproot4/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@
from uproot4.interpretation.strings import AsStrings
from uproot4.interpretation.objects import AsObjects
from uproot4.interpretation.objects import AsStridedObjects
from uproot4.interpretation.grouped import AsGrouped
from uproot4.containers import AsString
from uproot4.containers import AsPointer
from uproot4.containers import AsArray
Expand Down
24 changes: 16 additions & 8 deletions uproot4/behaviors/TBranch.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,11 +279,14 @@ def _regularize_expressions(
filter_name=filter_name,
filter_typename=filter_typename,
filter_branch=filter_branch,
full_paths=True,
full_paths=False,
):
if not isinstance(
branch.interpretation,
uproot4.interpretation.identify.UnknownInterpretation,
(
uproot4.interpretation.identify.UnknownInterpretation,
uproot4.interpretation.grouped.AsGrouped,
),
):
_regularize_branchname(
hasbranches,
Expand Down Expand Up @@ -1470,6 +1473,16 @@ def entries_to_ranges_or_baskets(self, entry_start, entry_stop):
start = stop
return out

def debug_array(self, entry, dtype=numpy.dtype("u1"), skip_bytes=0):
    """
    Return the raw bytes of a single entry, reinterpreted as ``dtype``.

    Args:
        entry (int): Entry to read; passed to ``self.array`` as
            ``entry_start=entry`` and ``entry_stop=entry + 1``.
        dtype: Anything accepted by ``numpy.dtype``; the element type of
            the returned view. Defaults to unsigned bytes.
        skip_bytes (int): Number of leading bytes of the entry to drop
            before reinterpreting.

    The entry is always read as a jagged array of ``"u1"`` with
    ``library="np"``. Trailing bytes that do not fill a whole ``dtype``
    item are discarded so that the final ``view`` is valid.
    """
    target = numpy.dtype(dtype)

    # Read the entry as plain unsigned bytes, whatever it really contains.
    as_bytes = uproot4.interpretation.jagged.AsJagged(
        uproot4.interpretation.numerical.AsDtype("u1")
    )
    entries = self.array(
        as_bytes, entry_start=entry, entry_stop=entry + 1, library="np"
    )
    raw = entries[0][skip_bytes:]

    # Keep only the largest prefix whose length is a multiple of the item size.
    usable = len(raw) - len(raw) % target.itemsize
    return raw[:usable].view(target)

def debug(
self,
entry,
Expand All @@ -1479,12 +1492,7 @@ def debug(
offset=0,
stream=sys.stdout,
):
interpretation = uproot4.interpretation.jagged.AsJagged(
uproot4.interpretation.numerical.AsDtype("u1")
)
data = self.array(
interpretation, entry_start=entry, entry_stop=entry + 1, library="np"
)[0]
data = self.debug_array(entry)
chunk = uproot4.source.chunk.Chunk.wrap(self._file.source, data)
if skip_bytes is None:
cursor = uproot4.source.cursor.Cursor(0)
Expand Down
101 changes: 101 additions & 0 deletions uproot4/interpretation/grouped.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
# BSD 3-Clause License; see https://github.com/scikit-hep/uproot4/blob/master/LICENSE

from __future__ import absolute_import

import uproot4.interpretation
import uproot4.extras


class Group(object):
    """
    Empty class: adds no attributes or methods to those of ``object``.

    NOTE(review): not referenced by the other code visible in this module;
    its intended role should be confirmed against its users.
    """


class AsGrouped(uproot4.interpretation.Interpretation):
    """
    Interpretation describing a branch as a group of named subbranches.

    Args:
        branch: The grouping branch. Its ``name``, ``file.file_path`` and
            ``object_path`` attributes are used for the cache key and for
            error messages.
        subbranches: Mapping from subbranch name to that subbranch's
            interpretation (iterated with ``.items()``).
        typename (None or str): If not None, reported verbatim by the
            ``typename`` property instead of a generated description.

    This interpretation cannot be used to read data: ``basket_array`` and
    ``final_array`` always raise ``ValueError`` telling the caller to read
    the subbranches instead.
    """

    def __init__(self, branch, subbranches, typename=None):
        self._branch = branch
        self._subbranches = subbranches
        self._typename = typename

    @property
    def branch(self):
        """The grouping branch given to the constructor."""
        return self._branch

    @property
    def subbranches(self):
        """The name -> interpretation mapping given to the constructor."""
        return self._subbranches

    def __repr__(self):
        # Use the actual class name (as cache_key does) rather than a
        # hard-coded string, so the repr cannot drift from the class.
        return "{0}({1}, {2})".format(
            type(self).__name__, self._branch, self._subbranches
        )

    def __eq__(self, other):
        # Equality considers the branch and the subbranches only; the
        # optional typename is not compared.
        return (
            isinstance(other, AsGrouped)
            and self._branch == other._branch
            and self._subbranches == other._subbranches
        )

    @property
    def cache_key(self):
        """
        String built from the class name, the branch name, and each
        subbranch's ``repr(name):cache_key`` pair.
        """
        parts = ",".join(
            "{0}:{1}".format(repr(name), interp.cache_key)
            for name, interp in self._subbranches.items()
        )
        return "{0}({1},[{2}])".format(
            type(self).__name__, self._branch.name, parts
        )

    @property
    def typename(self):
        """
        The explicit typename if one was given; otherwise a
        ``"(group of name:typename, ...)"`` description of the subbranches.
        """
        if self._typename is not None:
            return self._typename
        return "(group of {0})".format(
            ", ".join(
                "{0}:{1}".format(name, interp.typename)
                for name, interp in self._subbranches.items()
            )
        )

    def awkward_form(self, file, index_format="i64", header=False, tobject_header=True):
        """
        Return an ``awkward1.forms.RecordForm`` with one field per
        subbranch, each produced by that subbranch's own ``awkward_form``
        called with the same arguments.
        """
        # Resolve awkward1 first so a missing dependency is reported before
        # any subbranch form is computed.
        awkward1 = uproot4.extras.awkward1()

        names = []
        fields = []
        for name, interp in self._subbranches.items():
            names.append(name)
            fields.append(
                interp.awkward_form(file, index_format, header, tobject_header)
            )

        return awkward1.forms.RecordForm(fields, names)

    def _read_directly_error(self):
        """Build the ValueError raised on any attempt to read this branch."""
        return ValueError(
            """grouping branches like {0} should not be read directly; instead read the subbranches:
{1}
in file {2}
in object {3}""".format(
                repr(self._branch.name),
                ", ".join(repr(x) for x in self._subbranches),
                self._branch.file.file_path,
                self._branch.object_path,
            )
        )

    def basket_array(self, data, byte_offsets, basket, branch, context, cursor_offset):
        """
        Always raises ValueError: a grouping branch is not read directly.
        """
        raise self._read_directly_error()

    def final_array(
        self, basket_arrays, entry_start, entry_stop, entry_offsets, library, branch
    ):
        """
        Always raises ValueError: a grouping branch is not read directly.
        """
        raise self._read_directly_error()
17 changes: 17 additions & 0 deletions uproot4/interpretation/identify.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import uproot4.interpretation.numerical
import uproot4.interpretation.strings
import uproot4.interpretation.objects
import uproot4.interpretation.grouped
import uproot4.containers
import uproot4.streamers
import uproot4._util
Expand Down Expand Up @@ -629,6 +630,10 @@ def _parse_node(tokens, i, typename, file, quote, header, inner_header):
i, keys = _parse_node(
tokens, i + 2, typename, file, quote, inner_header, inner_header
)
while tokens[i].group(0) == ",":
i, keys = _parse_node(
tokens, i + 1, typename, file, quote, inner_header, inner_header
)
_parse_expect(">", tokens, i, typename, file)
stop = tokens[i].span()[1]

Expand Down Expand Up @@ -931,6 +936,18 @@ def _float16_or_double32(branch, context, leaf, is_float16, dims):


def interpretation_of(branch, context, simplify=True):
if len(branch.branches) != 0:
if branch.top_level and branch.has_member("fClassName"):
typename = branch.member("fClassName")
elif branch.streamer is not None:
typename = branch.streamer.typename
else:
typename = None
subbranches = dict((x.name, x.interpretation) for x in branch.branches)
return uproot4.interpretation.grouped.AsGrouped(
branch, subbranches, typename=typename
)

if branch.classname == "TBranchObject":
if branch.top_level and branch.has_member("fClassName"):
model_cls = parse_typename(
Expand Down

0 comments on commit cce6b79

Please sign in to comment.