Skip to content

Commit

Permalink
Fallback to reading streamer and raise better error messages on true …
Browse files Browse the repository at this point in the history
…failures. (#28)

* Fallback to reading streamer and raise better error messages on true failures.

* Pass context down to Cursor.

* Black and flake8.

* Passed context all the way down to Chunk.

* This is a good error message.

* Fallback for wrong streamers works.
  • Loading branch information
jpivarski authored Jun 18, 2020
1 parent a7d71de commit 53eba54
Show file tree
Hide file tree
Showing 22 changed files with 413 additions and 215 deletions.
4 changes: 1 addition & 3 deletions tests/test_0023-more-interpretations-1.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,9 +109,7 @@ def test_double32():
del uproot4.classes["TBranch"]
del uproot4.classes["TBranchElement"]

with uproot4.open(
skhep_testdata.data_path("uproot-demo-double32.root"),
)["T"] as t:
with uproot4.open(skhep_testdata.data_path("uproot-demo-double32.root"))["T"] as t:

print(t["fD64"].interpretation)
print(t["fF32"].interpretation)
Expand Down
21 changes: 21 additions & 0 deletions tests/test_0028-fallback-to-read-streamer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# BSD 3-Clause License; see https://github.com/scikit-hep/uproot4/blob/master/LICENSE

from __future__ import absolute_import

import numpy
import pytest
import skhep_testdata

import uproot4


def test_fallback_reading():
    """Check in which cases reading a branch populates ``f.file._streamers``."""
    # Looking up this branch is expected to leave the file's streamers
    # unset (``_streamers`` stays None).
    with uproot4.open(
        skhep_testdata.data_path("uproot-small-evnt-tree-fullsplit.root")
    ) as f:
        f["tree:evt/P3/P3.Py"]
        assert f.file._streamers is None

    # Looking up this branch is expected to have populated the file's
    # streamers (``_streamers`` is no longer None).
    with uproot4.open(skhep_testdata.data_path("uproot-demo-double32.root")) as f:
        f["T/fD64"]
        assert f.file._streamers is not None
14 changes: 8 additions & 6 deletions uproot4/compression.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,24 +156,26 @@ def decompress(chunk, cursor, context, compressed_bytes, uncompressed_bytes):
# https://github.com/root-project/root/blob/master/core/lzma/src/ZipLZMA.c#L81
# https://github.com/root-project/root/blob/master/core/lz4/src/ZipLZ4.cxx#L38
algo, method, c1, c2, c3, u1, u2, u3 = cursor.fields(
chunk, _decompress_header_format
chunk, _decompress_header_format, context
)
block_compressed_bytes = c1 + (c2 << 8) + (c3 << 16)
block_uncompressed_bytes = u1 + (u2 << 8) + (u3 << 16)

if algo == b"ZL":
cls = ZLIB
data = cursor.bytes(chunk, block_compressed_bytes)
data = cursor.bytes(chunk, block_compressed_bytes, context)

elif algo == b"XZ":
cls = LZMA
data = cursor.bytes(chunk, block_compressed_bytes)
data = cursor.bytes(chunk, block_compressed_bytes, context)

elif algo == b"L4":
cls = LZ4
block_compressed_bytes -= 8
expected_checksum = cursor.field(chunk, _decompress_checksum_format)
data = cursor.bytes(chunk, block_compressed_bytes)
expected_checksum = cursor.field(
chunk, _decompress_checksum_format, context
)
data = cursor.bytes(chunk, block_compressed_bytes, context)
try:
import xxhash
except ImportError:
Expand All @@ -197,7 +199,7 @@ def decompress(chunk, cursor, context, compressed_bytes, uncompressed_bytes):

elif algo == b"ZS":
cls = ZSTD
data = cursor.bytes(chunk, block_compressed_bytes)
data = cursor.bytes(chunk, block_compressed_bytes, context)

elif algo == b"CS":
raise ValueError(
Expand Down
85 changes: 66 additions & 19 deletions uproot4/deserialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,12 +51,61 @@ def c(name, version=None):
return out


class DeserializationError(Exception):
    """
    Error raised when data cannot be deserialized as expected.

    Args:
        message (str): Description of what went wrong.
        context (dict): Deserialization context. The keys read here are
            ``"breadcrumbs"`` (sequence of objects being read, outermost
            first), ``"TBranch"`` and ``"TKey"`` (objects with an
            ``object_path`` attribute); all of them are optional.
        file_path: Path of the file being read; only used in the message.

    The string form lists every breadcrumb with its bases and members read
    so far, followed by the message, the file path and, when available, the
    TBranch or TKey in which the failure occurred.
    """

    __slots__ = ["message", "context", "file_path"]

    def __init__(self, message, context, file_path):
        # Forward the arguments to Exception so that ``self.args`` is
        # populated; copy and pickle rebuild exceptions as ``cls(*self.args)``
        # and would otherwise call this three-argument constructor with
        # no arguments.
        super(DeserializationError, self).__init__(message, context, file_path)
        self.message = message
        self.context = context
        self.file_path = file_path

    def __str__(self):
        lines = []
        indent = " "
        for obj in self.context.get("breadcrumbs", ()):
            lines.append(
                "{0}{1} version {2} as {3}.{4}".format(
                    indent,
                    obj.classname,
                    obj.instance_version,
                    type(obj).__module__,
                    type(obj).__name__,
                )
            )
            # Each nested object (and its bases/members) is indented one
            # step further than the object that contains it.
            indent = indent + " "
            for v in getattr(obj, "_bases", []):
                lines.append("{0}(base): {1}".format(indent, repr(v)))
            for k, v in getattr(obj, "_members", {}).items():
                lines.append("{0}{1}: {2}".format(indent, k, repr(v)))

        # Prefer the TBranch location over the TKey location when both exist.
        in_parent = ""
        if "TBranch" in self.context:
            in_parent = "\nin TBranch {0}".format(self.context["TBranch"].object_path)
        elif "TKey" in self.context:
            in_parent = "\nin object {0}".format(self.context["TKey"].object_path)

        if not lines:
            return """{0}
in file {1}{2}""".format(
                self.message, self.file_path, in_parent
            )
        else:
            return """while reading
{0}
{1}
in file {2}{3}""".format(
                "\n".join(lines), self.message, self.file_path, in_parent
            )


_numbytes_version_1 = struct.Struct(">IH")
_numbytes_version_2 = struct.Struct(">H")


def numbytes_version(chunk, cursor, move=True):
num_bytes, version = cursor.fields(chunk, _numbytes_version_1, move=False)
def numbytes_version(chunk, cursor, context, move=True):
num_bytes, version = cursor.fields(chunk, _numbytes_version_1, context, move=False)
num_bytes = numpy.int64(num_bytes)

if num_bytes & uproot4.const.kByteCountMask:
Expand All @@ -66,36 +115,34 @@ def numbytes_version(chunk, cursor, move=True):

else:
num_bytes = None
version = cursor.field(chunk, _numbytes_version_2, move=move)
version = cursor.field(chunk, _numbytes_version_2, context, move=move)

return num_bytes, version


def numbytes_check(start_cursor, stop_cursor, num_bytes, classname, file_path):
def numbytes_check(start_cursor, stop_cursor, num_bytes, classname, context, file_path):
if num_bytes is not None:
observed = stop_cursor.displacement(start_cursor)
if observed != num_bytes:
if file_path is None:
in_file = ""
else:
in_file = "\nin file {0}".format(file_path)
raise ValueError(
"""instance of ROOT class {0} has {1} bytes; expected {2}{3}""".format(
classname, observed, num_bytes, in_file
)
raise uproot4.deserialization.DeserializationError(
"""expected {0} bytes but cursor moved by {1} bytes (through {2})""".format(
num_bytes, observed, classname
),
context,
file_path,
)


_map_string_string_format1 = struct.Struct(">I")


def map_string_string(chunk, cursor):
def map_string_string(chunk, cursor, context):
cursor.skip(12)
size = cursor.field(chunk, _map_string_string_format1)
size = cursor.field(chunk, _map_string_string_format1, context)
cursor.skip(6)
keys = [cursor.string(chunk) for i in range(size)]
keys = [cursor.string(chunk, context) for i in range(size)]
cursor.skip(6)
values = [cursor.string(chunk) for i in range(size)]
values = [cursor.string(chunk, context) for i in range(size)]
return dict(zip(keys, values))


Expand All @@ -111,7 +158,7 @@ def read_object_any(chunk, cursor, context, file, parent, as_class=None):
# https://github.com/root-project/root/blob/c4aa801d24d0b1eeb6c1623fd18160ef2397ee54/io/io/src/TBufferFile.cxx#L2404

beg = cursor.displacement()
bcnt = numpy.int64(cursor.field(chunk, _read_object_any_format1))
bcnt = numpy.int64(cursor.field(chunk, _read_object_any_format1, context))

if (bcnt & uproot4.const.kByteCountMask) == 0 or (
bcnt == uproot4.const.kNewClassTag
Expand All @@ -123,7 +170,7 @@ def read_object_any(chunk, cursor, context, file, parent, as_class=None):
else:
vers = 1
start = cursor.displacement()
tag = numpy.int64(cursor.field(chunk, _read_object_any_format1))
tag = numpy.int64(cursor.field(chunk, _read_object_any_format1, context))
bcnt = int(bcnt)

if tag & uproot4.const.kClassMask == 0:
Expand All @@ -146,7 +193,7 @@ def read_object_any(chunk, cursor, context, file, parent, as_class=None):
elif tag == uproot4.const.kNewClassTag:
# new class and object

classname = cursor.classname(chunk)
classname = cursor.classname(chunk, context)

cls = file.class_named(classname)

Expand Down
68 changes: 57 additions & 11 deletions uproot4/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,40 @@
import uproot4._util


# Names of the classes that bootstrap_classes() looks up in uproot4.classes.
bootstrap_classnames = [
    "TStreamerInfo",
    "TStreamerElement",
    "TStreamerArtificial",
    "TStreamerBase",
    "TStreamerBasicPointer",
    "TStreamerBasicType",
    "TStreamerLoop",
    "TStreamerObject",
    "TStreamerObjectAny",
    "TStreamerObjectAnyPointer",
    "TStreamerObjectPointer",
    "TStreamerSTL",
    "TStreamerSTLstring",
    "TStreamerString",
    "TList",
    "TObjArray",
    "TObjString",
]


def bootstrap_classes():
    """
    Return a new dict mapping every name in ``bootstrap_classnames`` to the
    object stored under that name in ``uproot4.classes``.

    The modules below are imported first; presumably importing them is what
    registers these classes in ``uproot4.classes`` (TODO confirm).
    """
    import uproot4.streamers
    import uproot4.models.TList
    import uproot4.models.TObjArray
    import uproot4.models.TObjString

    return {name: uproot4.classes[name] for name in bootstrap_classnames}


class Model(object):
@classmethod
def read(cls, chunk, cursor, context, file, parent):
Expand All @@ -22,6 +56,9 @@ def read(cls, chunk, cursor, context, file, parent):
self._num_bytes = None
self._instance_version = None

old_breadcrumbs = context.get("breadcrumbs", ())
context["breadcrumbs"] = old_breadcrumbs + (self,)

self.hook_before_read(chunk=chunk, cursor=cursor, context=context)

self.read_numbytes_version(chunk, cursor, context)
Expand All @@ -36,7 +73,11 @@ def read(cls, chunk, cursor, context, file, parent):

self.hook_before_postprocess(chunk=chunk, cursor=cursor, context=context)

return self.postprocess(chunk, cursor, context)
out = self.postprocess(chunk, cursor, context)

context["breadcrumbs"] = old_breadcrumbs

return out

def __repr__(self):
return "<{0} at 0x{1:012x}>".format(
Expand All @@ -49,7 +90,7 @@ def read_numbytes_version(self, chunk, cursor, context):
(
self._num_bytes,
self._instance_version,
) = uproot4.deserialization.numbytes_version(chunk, cursor)
) = uproot4.deserialization.numbytes_version(chunk, cursor, context)

def read_members(self, chunk, cursor, context):
pass
Expand All @@ -61,7 +102,8 @@ def check_numbytes(self, cursor, context):
self._cursor,
cursor,
self._num_bytes,
classname_pretty(self.classname, self.class_version),
self.classname,
context,
getattr(self._file, "file_path"),
)

Expand Down Expand Up @@ -259,7 +301,7 @@ def read(cls, chunk, cursor, context, file, parent):
import uproot4.deserialization

num_bytes, version = uproot4.deserialization.numbytes_version(
chunk, cursor, move=False
chunk, cursor, context, move=False
)

versioned_cls = cls.known_versions.get(version)
Expand Down Expand Up @@ -399,11 +441,15 @@ def classname_pretty(classname, version):
return "{0} (version {1})".format(classname, version)


def has_class_named(classname, version=None, classes=None):
if classes is None:
classes = uproot4.classes
def maybe_custom_classes(custom_classes):
    """
    Return ``custom_classes`` unless it is None, in which case return the
    global ``uproot4.classes``.
    """
    return uproot4.classes if custom_classes is None else custom_classes

cls = classes.get(classname)

def has_class_named(classname, version=None, custom_classes=None):
cls = maybe_custom_classes(custom_classes).get(classname)
if cls is None:
return False

Expand All @@ -413,10 +459,10 @@ def has_class_named(classname, version=None, classes=None):
return True


def class_named(classname, version=None, classes=None):
if classes is None:
def class_named(classname, version=None, custom_classes=None):
if custom_classes is None:
classes = uproot4.classes
where = "the given 'classes' dict"
where = "the 'custom_classes' dict"
else:
where = "uproot4.classes"

Expand Down
2 changes: 1 addition & 1 deletion uproot4/models/RNTuple.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def read_members(self, chunk, cursor, context):
self._members["fNBytesFooter"],
self._members["fLenFooter"],
self._members["fReserved"],
) = cursor.fields(chunk, _rntuple_format1)
) = cursor.fields(chunk, _rntuple_format1, context)


uproot4.classes[
Expand Down
4 changes: 2 additions & 2 deletions uproot4/models/TArray.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ def read_numbytes_version(self, chunk, cursor, context):
pass

def read_members(self, chunk, cursor, context):
self._members["fN"] = cursor.field(chunk, _tarray_format1)
self._data = cursor.array(chunk, self._members["fN"], self.dtype)
self._members["fN"] = cursor.field(chunk, _tarray_format1, context)
self._data = cursor.array(chunk, self._members["fN"], self.dtype, context)

def __array__(self):
return self._data
Expand Down
Loading

0 comments on commit 53eba54

Please sign in to comment.