diff --git a/tests/test_0023-more-interpretations-1.py b/tests/test_0023-more-interpretations-1.py index 70e9cd65d..b4dc52084 100644 --- a/tests/test_0023-more-interpretations-1.py +++ b/tests/test_0023-more-interpretations-1.py @@ -49,24 +49,6 @@ def test_strings1(): assert result.tolist() == ["hey-{0}".format(i) for i in range(30)] -@pytest.mark.skip(reason="FIXME: implement strings specified by a TStreamer") -def test_strings2(): - with uproot4.open( - skhep_testdata.data_path("uproot-small-evnt-tree-fullsplit.root") - )["tree/Str"] as branch: - result = branch.array(library="np") - assert result.tolist() == ["evt-{0:03d}".format(i) for i in range(100)] - - -@pytest.mark.skip(reason="FIXME: implement std::string") -def test_strings3(): - with uproot4.open( - skhep_testdata.data_path("uproot-small-evnt-tree-fullsplit.root") - )["tree/StdStr"] as branch: - result = branch.array(library="np") - assert result.tolist() == ["std-{0:03d}".format(i) for i in range(100)] - - @pytest.mark.skip(reason="FIXME: implement std::vector") def test_strings4(): with uproot4.open( diff --git a/tests/test_0028-fallback-to-read-streamer.py b/tests/test_0028-fallback-to-read-streamer.py index b89b2b5c4..1e90368b5 100644 --- a/tests/test_0028-fallback-to-read-streamer.py +++ b/tests/test_0028-fallback-to-read-streamer.py @@ -10,11 +10,11 @@ def test_fallback_reading(): - with uproot4.open( - skhep_testdata.data_path("uproot-small-evnt-tree-fullsplit.root") - ) as f: - f["tree:evt/P3/P3.Py"] - assert f.file._streamers is None + # with uproot4.open( + # skhep_testdata.data_path("uproot-small-evnt-tree-fullsplit.root") + # ) as f: + # f["tree:evt/P3/P3.Py"] + # assert f.file._streamers is None with uproot4.open(skhep_testdata.data_path("uproot-demo-double32.root")) as f: f["T/fD64"] diff --git a/tests/test_0029-more-string-types.py b/tests/test_0029-more-string-types.py new file mode 100644 index 000000000..9223c4916 --- /dev/null +++ b/tests/test_0029-more-string-types.py @@ -0,0 +1,173 @@ +# BSD 3-Clause License; see https://github.com/scikit-hep/uproot4/blob/master/LICENSE + +from __future__ import absolute_import + +import sys +import json + +import numpy +import pytest +import skhep_testdata + +import uproot4 +from uproot4.stl_containers import parse_typename +from uproot4.stl_containers import AsString +from uproot4.stl_containers import AsVector +from uproot4.stl_containers import AsSet +from uproot4.stl_containers import AsMap + + +def test_parse_typename(): + assert parse_typename("TTree") is uproot4.classes["TTree"] + assert parse_typename("string") == AsString() + assert parse_typename("std::string") == AsString() + assert parse_typename("std :: string") == AsString() + assert parse_typename("char*") == AsString(is_stl=False) + assert parse_typename("char *") == AsString(is_stl=False) + assert parse_typename("TString") == AsString(is_stl=False) + assert parse_typename("vector") == AsVector(uproot4.classes["TTree"]) + assert parse_typename("vector") == AsVector(">i4") + assert parse_typename("vector") == AsVector("?") + assert parse_typename("vector") == AsVector(AsString()) + assert parse_typename("vector < string >") == AsVector(AsString()) + assert parse_typename("std::vector") == AsVector(AsString()) + assert parse_typename("vector>") == AsVector(AsVector(">i4")) + assert parse_typename("vector>") == AsVector(AsVector(AsString())) + assert parse_typename("vector>") == AsVector( + AsVector(AsString(is_stl=False)) + ) + assert parse_typename("set") == AsSet(">u2") + assert parse_typename("std::set") == AsSet(">u2") + assert parse_typename("set") == AsSet(AsString()) + assert parse_typename("set>") == AsSet(AsVector(AsString())) + assert parse_typename("set >") == AsSet(AsVector(AsString())) + assert parse_typename("map") == AsMap(">i4", ">f8") + assert parse_typename("map") == AsMap(AsString(), ">f8") + assert parse_typename("map") == AsMap(">i4", AsString()) + assert parse_typename("map") == AsMap(AsString(), AsString()) + assert parse_typename("map") == AsMap(AsString(), AsString()) + assert parse_typename("map< string,string >") == AsMap(AsString(), AsString()) + assert parse_typename("map>") == AsMap( + AsString(), AsVector(">i4") + ) + assert parse_typename("map, string>") == AsMap( + AsVector(">i4"), AsString() + ) + assert parse_typename("map, set>") == AsMap( + AsVector(">i4"), AsSet(">f4") + ) + assert parse_typename("map, set>>") == AsMap( + AsVector(">i4"), AsSet(AsSet(">f4")) + ) + + with pytest.raises(ValueError): + parse_typename("string <") + + with pytest.raises(ValueError): + parse_typename("vector <") + + with pytest.raises(ValueError): + parse_typename("map>") + + with pytest.raises(ValueError): + parse_typename("map>") + + +def test_strings1(): + with uproot4.open( + skhep_testdata.data_path("uproot-small-evnt-tree-fullsplit.root") + )["tree"] as tree: + result = tree["Beg"].array(library="np") + assert result.tolist() == ["beg-{0:03d}".format(i) for i in range(100)] + + result = tree["End"].array(library="np") + assert result.tolist() == ["end-{0:03d}".format(i) for i in range(100)] + + +def test_map_string_string_in_object(): + with uproot4.open(skhep_testdata.data_path("uproot-issue431.root")) as f: + head = f["Head"] + assert head.member("map") == { + "DAQ": "394", + "PDF": "4 58", + "XSecFile": "", + "can": "0 1027 888.4", + "can_user": "0.00 1027.00 888.40", + "coord_origin": "0 0 0", + "cut_in": "0 0 0 0", + "cut_nu": "100 1e+08 -1 1", + "cut_primary": "0 0 0 0", + "cut_seamuon": "0 0 0 0", + "decay": "doesnt happen", + "detector": "NOT", + "drawing": "Volume", + "end_event": "", + "genhencut": "2000 0", + "genvol": "0 1027 888.4 2.649e+09 100000", + "kcut": "2", + "livetime": "0 0", + "model": "1 2 0 1 12", + "muon_desc_file": "", + "ngen": "0.1000E+06", + "norma": "0 0", + "nuflux": "0 3 0 0.500E+00 0.000E+00 0.100E+01 0.300E+01", + "physics": "GENHEN 7.2-220514 181116 1138", + "seed": "GENHEN 3 305765867 0 0", + "simul": "JSirene 11012 11/17/18 07", + "sourcemode": "diffuse", + "spectrum": "-1.4", + "start_run": "1", + "target": "isoscalar", + "usedetfile": "false", + "xlat_user": "0.63297", + "xparam": "OFF", + "zed_user": "0.00 3450.00", + } + + +@pytest.mark.skip( + reason="FIXME: test works, but the file is not in scikit-hep-testdata yet" +) +def test_map_long_int_in_object(): + with uproot4.open( + "/home/pivarski/irishep/scikit-hep-testdata/src/skhep_testdata/data/uproot-issue283.root" + ) as f: + print(f["config/detector"]) + + # raise Exception + + +# has STL vectors at top-level: +# +# python -c 'import uproot; t = uproot.open("/home/pivarski/irishep/scikit-hep-testdata/src/skhep_testdata/data/uproot-issue38a.root")["ntupler/tree"]; print("\n".join(str((x._fName, getattr(x, "_fStreamerType", None), getattr(x, "_fClassName", None), getattr(x, "_fType", None), x.interpretation)) for x in t.allvalues()))' + +# has STL map as described here: +# +# https://github.com/scikit-hep/uproot/issues/468#issuecomment-646325842 +# +# python -c 'import uproot; t = uproot.open("/home/pivarski/irishep/scikit-hep-testdata/src/skhep_testdata/data/uproot-issue468.root")["Geant4Data/Geant4Data./Geant4Data.particles"]; print(t.array(uproot.asdebug)[0][:1000])' + +# def test_strings1(): +# with uproot4.open( +# skhep_testdata.data_path("uproot-issue31.root") +# )["T/name"] as branch: +# result = branch.array(library="np") +# assert result.tolist() == ["one", "two", "three", "four", "five"] + + +@pytest.mark.skip(reason="FIXME: implement strings specified by a TStreamer") +def test_strings2(): + with uproot4.open( + skhep_testdata.data_path("uproot-small-evnt-tree-fullsplit.root") + )["tree/Str"] as branch: + result = branch.array(library="np") + assert result.tolist() == ["evt-{0:03d}".format(i) for i in range(100)] + + +@pytest.mark.skip(reason="FIXME: implement std::string") +def test_strings3(): + with uproot4.open( + skhep_testdata.data_path("uproot-small-evnt-tree-fullsplit.root") + )["tree/StdStr"] as branch: + result = branch.array(library="np") + assert result.tolist() == ["std-{0:03d}".format(i) for i in range(100)] diff --git a/uproot4/__init__.py b/uproot4/__init__.py index 3ef1fe544..ae8a166fb 100644 --- a/uproot4/__init__.py +++ b/uproot4/__init__.py @@ -23,6 +23,8 @@ decompression_executor = ThreadPoolExecutor() interpretation_executor = TrivialExecutor() +from uproot4.deserialization import DeserializationError + from uproot4.reading import open from uproot4.reading import ReadOnlyFile from uproot4.reading import ReadOnlyDirectory @@ -33,6 +35,10 @@ from uproot4.model import has_class_named from uproot4.model import class_named +from uproot4.stl_containers import STLVector +from uproot4.stl_containers import STLSet +from uproot4.stl_containers import STLMap + import uproot4.interpretation import uproot4.interpretation.library @@ -94,12 +100,14 @@ def behavior_of(classname): class KeyInFileError(KeyError): - def __init__(self, key, file_path, cycle=None, because="", object_path=None): + __slots__ = ["key", "because", "cycle", "file_path", "object_path"] + + def __init__(self, key, because="", cycle=None, file_path=None, object_path=None): super(KeyInFileError, self).__init__(key) self.key = key - self.file_path = file_path - self.cycle = cycle self.because = because + self.cycle = cycle + self.file_path = file_path self.object_path = object_path def __str__(self): @@ -108,25 +116,25 @@ def __str__(self): else: because = " because " + self.because - if self.object_path is None: - object_path = "" - else: - object_path = "\nin object {0}".format(self.object_path) + in_file = "" + if self.file_path is not None: + in_file = "\nin file {0}".format(self.file_path) + + in_object = "" + if self.object_path is not None: + in_object = "\nin object {0}".format(self.object_path) if self.cycle == "any": - return """not found: {0} (with any cycle number){1} -in file {2}{3}""".format( - repr(self.key), because, self.file_path, object_path + return """not found: {0} (with any cycle number){1}{2}{3}""".format( + repr(self.key), because, in_file, in_object ) elif self.cycle is None: - return """not found: {0}{1} -in file {2}{3}""".format( - repr(self.key), because, self.file_path, object_path + return """not found: {0}{1}{2}{3}""".format( + repr(self.key), because, in_file, in_object ) else: - return """not found: {0} with cycle {1}{2} -in file {3}{4}""".format( - repr(self.key), self.cycle, because, self.file_path, object_path + return """not found: {0} with cycle {1}{2}{3}{4}""".format( + repr(self.key), self.cycle, because, in_file, in_object ) diff --git a/uproot4/behaviors/TBranch.py b/uproot4/behaviors/TBranch.py index bd2ddb895..11822a9e7 100644 --- a/uproot4/behaviors/TBranch.py +++ b/uproot4/behaviors/TBranch.py @@ -525,7 +525,9 @@ def __getitem__(self, where): return v else: raise uproot4.KeyInFileError( - original_where, self._file.file_path, object_path=self.object_path + original_where, + file_path=self._file.file_path, + object_path=self.object_path, ) elif recursive: @@ -535,7 +537,9 @@ def __getitem__(self, where): return got else: raise uproot4.KeyInFileError( - original_where, self._file.file_path, object_path=self.object_path + original_where, + file_path=self._file.file_path, + object_path=self.object_path, ) else: @@ -545,7 +549,9 @@ def __getitem__(self, where): return branch else: raise uproot4.KeyInFileError( - original_where, self._file.file_path, object_path=self.object_path + original_where, + file_path=self._file.file_path, + object_path=self.object_path, ) def iteritems( diff --git a/uproot4/compute/python.py b/uproot4/compute/python.py index 4451ea3a5..57ce868e8 100644 --- a/uproot4/compute/python.py +++ b/uproot4/compute/python.py @@ -155,7 +155,9 @@ def _expression_to_function( node.body[0].value, keys, aliases, functions, getter ) except KeyError as err: - raise uproot4.KeyInFileError(err.args[0], file_path, object_path=object_path) + raise uproot4.KeyInFileError( + err.args[0], file_path=file_path, object_path=object_path + ) function = ast.parse("lambda: None").body[0].value function.body = expr diff --git a/uproot4/deserialization.py b/uproot4/deserialization.py index 7e18c293a..44c2a988f 100644 --- a/uproot4/deserialization.py +++ b/uproot4/deserialization.py @@ -3,6 +3,7 @@ from __future__ import absolute_import import struct +import sys import numpy @@ -14,7 +15,7 @@ scope = { "struct": struct, "numpy": numpy, - "VersionedModel": uproot4.model.VersionedModel, + "uproot4": uproot4, } @@ -52,10 +53,12 @@ def c(name, version=None): class DeserializationError(Exception): - __slots__ = ["message", "context", "file_path"] + __slots__ = ["message", "chunk", "cursor", "context", "file_path"] - def __init__(self, message, context, file_path): + def __init__(self, message, chunk, cursor, context, file_path): self.message = message + self.chunk = chunk + self.cursor = cursor self.context = context self.file_path = file_path @@ -64,12 +67,13 @@ def __str__(self): indent = " " for obj in self.context.get("breadcrumbs", ()): lines.append( - "{0}{1} version {2} as {3}.{4}".format( + "{0}{1} version {2} as {3}.{4} ({5} bytes)".format( indent, obj.classname, obj.instance_version, type(obj).__module__, type(obj).__name__, + obj.num_bytes, ) ) indent = indent + " " @@ -99,6 +103,34 @@ def __str__(self): "\n".join(lines), self.message, self.file_path, in_parent ) + def debug( + self, skip_bytes=0, limit_bytes=None, dtype=None, offset=0, stream=sys.stdout + ): + cursor = self.cursor.copy() + cursor.skip(skip_bytes) + cursor.debug( + self.chunk, + context=self.context, + limit_bytes=limit_bytes, + dtype=dtype, + offset=offset, + stream=stream, + ) + + def array(self, dtype, skip_bytes=0, limit_bytes=None): + dtype = numpy.dtype(dtype) + cursor = self.cursor.copy() + cursor.skip(skip_bytes) + out = self.chunk.remainder(cursor.index, cursor, self.context)[:limit_bytes] + return out[: (len(out) // dtype.itemsize) * dtype.itemsize].view(dtype) + + @property + def partial_object(self): + if "breadcrumbs" in self.context: + return self.context["breadcrumbs"][-1] + else: + return None + _numbytes_version_1 = struct.Struct(">IH") _numbytes_version_2 = struct.Struct(">H") @@ -133,20 +165,21 @@ def numbytes_check(start_cursor, stop_cursor, num_bytes, classname, context, fil ) -_map_string_string_format1 = struct.Struct(">I") - - -def map_string_string(chunk, cursor, context): - cursor.skip(12) - size = cursor.field(chunk, _map_string_string_format1, context) - cursor.skip(6) - keys = [cursor.string(chunk, context) for i in range(size)] - cursor.skip(6) - values = [cursor.string(chunk, context) for i in range(size)] - return dict(zip(keys, values)) - - -scope["map_string_string"] = map_string_string +# _map_string_string_format1 = struct.Struct(">I") +# def map_long_int(chunk, cursor, context): +# cursor.skip(12) +# size = cursor.field(chunk, _map_string_string_format1, context) +# keys = cursor.array(chunk, size, numpy.dtype(">i8"), context) +# values = cursor.array(chunk, size, numpy.dtype(">i4"), context) +# return dict(zip(keys, values)) +# scope["map_long_int"] = map_long_int + +# def set_long(chunk, cursor, context): +# cursor.skip(6) +# size = cursor.field(chunk, _map_string_string_format1, context) +# values = cursor.array(chunk, size, numpy.dtype(">i8"), context) +# return set(values) +# scope["set_long"] = set_long _read_object_any_format1 = struct.Struct(">I") diff --git a/uproot4/interpretation/identify.py b/uproot4/interpretation/identify.py index 21902315d..2f1f32de1 100644 --- a/uproot4/interpretation/identify.py +++ b/uproot4/interpretation/identify.py @@ -410,6 +410,12 @@ def interpretation_of(branch, context): ) except NotNumerical: + if ( + branch.has_member("fStreamerType") + and branch.member("fStreamerType") == uproot4.const.kTString + ): + return uproot4.interpretation.strings.AsStrings(size_1to5_bytes=True) + if len(branch.member("fLeaves")) != 1: raise UnknownInterpretation( "more or less than one TLeaf ({0}) in a non-numerical TBranch".format( diff --git a/uproot4/model.py b/uproot4/model.py index 6755b78c3..4cbee096a 100644 --- a/uproot4/model.py +++ b/uproot4/model.py @@ -45,6 +45,8 @@ def bootstrap_classes(): class Model(object): + class_streamer = None + @classmethod def read(cls, chunk, cursor, context, file, parent): self = cls.__new__(cls) @@ -201,11 +203,16 @@ def member(self, name, bases=True, recursive_bases=True): if name in base._members: return base._members[name] - if self._file is None: - in_file = "" - else: - in_file = "\nin file {0}".format(self._file.file_path) - raise KeyError("C++ member {0} not found{1}".format(repr(name), in_file)) + raise uproot4.KeyInFileError( + name, + """{0}.{1} has only the following members: + + {2} +""".format( + type(self).__module__, type(self).__name__, "\n ".join(self._members) + ), + file_path=self._file.file_path, + ) def tojson(self): out = {"_typename": self.classname} @@ -335,7 +342,7 @@ def new_class(cls, file, version): if streamer is not None: versioned_cls = streamer.new_class(file) - versioned_cls.streamer = streamer + versioned_cls.class_streamer = streamer cls.known_versions[streamer.class_version] = versioned_cls return versioned_cls diff --git a/uproot4/models/TBasket.py b/uproot4/models/TBasket.py index 25391ffcb..2ee063bad 100644 --- a/uproot4/models/TBasket.py +++ b/uproot4/models/TBasket.py @@ -81,7 +81,12 @@ def read_members(self, chunk, cursor, context): uncompressed = uproot4.compression.decompress( chunk, cursor, {}, self.compressed_bytes, self.uncompressed_bytes, ) - self._raw_data = uncompressed.get(0, self.uncompressed_bytes, context) + self._raw_data = uncompressed.get( + 0, + self.uncompressed_bytes, + uproot4.source.cursor.Cursor(0), + context, + ) else: self._raw_data = cursor.bytes( chunk, self.uncompressed_bytes, context, copy_if_memmap=True diff --git a/uproot4/models/TNamed.py b/uproot4/models/TNamed.py index f0ee259ed..b3fcc9704 100644 --- a/uproot4/models/TNamed.py +++ b/uproot4/models/TNamed.py @@ -17,5 +17,13 @@ def read_members(self, chunk, cursor, context): self._members["fName"] = cursor.string(chunk, context) self._members["fTitle"] = cursor.string(chunk, context) + def __repr__(self): + title = "" + if self._members["fTitle"] != "": + title = " title=" + repr(self._members["fTitle"]) + return "".format( + repr(self._members["fName"]), title, id(self) + ) + uproot4.classes["TNamed"] = Model_TNamed diff --git a/uproot4/models/TObject.py b/uproot4/models/TObject.py index fd08da27e..35573fb3d 100644 --- a/uproot4/models/TObject.py +++ b/uproot4/models/TObject.py @@ -32,5 +32,10 @@ def read_members(self, chunk, cursor, context): cursor.skip(2) self._members["fBits"] = int(self._members["fBits"]) + def __repr__(self): + return "".format( + self._members.get("fUniqueID"), self._members.get("fBits"), id(self) + ) + uproot4.classes["TObject"] = Model_TObject diff --git a/uproot4/reading.py b/uproot4/reading.py index 9df473f03..3ace81a85 100644 --- a/uproot4/reading.py +++ b/uproot4/reading.py @@ -266,7 +266,7 @@ def custom_classes(self, value): else: raise TypeError("custom_classes must be None or a MutableMapping") - def remove_class(self, classname): + def remove_class_definition(self, classname): if self._custom_classes is None: self._custom_classes = dict(uproot4.classes) if classname in self._custom_classes: @@ -308,14 +308,6 @@ def root_directory(self): self, ) - def is_custom_class(self, classname): - if self._custom_classes is None: - return False - else: - mine = self._custom_classes.get(classname) - theirs = uproot4.classes.get(classname) - return mine is not None and mine is not theirs - @property def streamers(self): import uproot4.streamers @@ -792,7 +784,12 @@ def get_uncompressed_chunk_cursor(self): else: uncompressed_chunk = uproot4.source.chunk.Chunk.wrap( chunk.source, - chunk.get(data_start, data_stop, {"breadcrumbs": (), "TKey": self}), + chunk.get( + data_start, + data_stop, + self.data_cursor, + {"breadcrumbs": (), "TKey": self}, + ), ) return uncompressed_chunk, cursor @@ -840,9 +837,12 @@ def get(self): except uproot4.deserialization.DeserializationError: breadcrumbs = context.get("breadcrumbs") + if breadcrumbs is None or all( breadcrumb_cls.classname in uproot4.model.bootstrap_classnames - or self._file.is_custom_class(breadcrumb_cls.classname) + or isinstance(breadcrumb_cls, uproot4.stl_containers.AsSTLContainer) + or getattr(breadcrumb_cls.class_streamer, "file_uuid", None) + == self._file.uuid for breadcrumb_cls in breadcrumbs ): # we're already using the most specialized versions of each class @@ -853,11 +853,12 @@ def get(self): breadcrumb_cls.classname not in uproot4.model.bootstrap_classnames ): - self._file.remove_class(breadcrumb_cls.classname) + self._file.remove_class_definition(breadcrumb_cls.classname) cursor = start_cursor cls = self._file.class_named(self._fClassName) context = {"breadcrumbs": (), "TKey": self} + out = cls.read(chunk, cursor, context, self._file, self) if self._file.object_cache is not None: @@ -1303,9 +1304,9 @@ def __getitem__(self, where): else: raise uproot4.KeyInFileError( where, - self._file.file_path, - because=repr(head) + repr(head) + " is not a TDirectory, TTree, or TBranch", + file_path=self._file.file_path, ) else: step = step[item] @@ -1316,9 +1317,8 @@ def __getitem__(self, where): else: raise uproot4.KeyInFileError( where, - self._file.file_path, - because=repr(item) - + " is not a TDirectory, TTree, or TBranch", + repr(item) + " is not a TDirectory, TTree, or TBranch", + file_path=self._file.file_path, ) return step @@ -1351,8 +1351,8 @@ def key(self, where): else: raise uproot4.KeyInFileError( where, - self._file.file_path, - because=repr(item) + " is not a TDirectory", + repr(item) + " is not a TDirectory", + file_path=self._file.file_path, ) return step.key(items[-1]) @@ -1376,6 +1376,10 @@ def key(self, where): if last is not None: return last elif cycle is None: - raise uproot4.KeyInFileError(item, self._file.file_path, cycle="any") + raise uproot4.KeyInFileError( + item, cycle="any", file_path=self._file.file_path + ) else: - raise uproot4.KeyInFileError(item, self._file.file_path, cycle=cycle) + raise uproot4.KeyInFileError( + item, cycle=cycle, file_path=self._file.file_path + ) diff --git a/uproot4/source/chunk.py b/uproot4/source/chunk.py index 1e5475276..91fba35a7 100644 --- a/uproot4/source/chunk.py +++ b/uproot4/source/chunk.py @@ -340,13 +340,14 @@ def raw_data(self): self.wait() return self._raw_data - def get(self, start, stop, context): + def get(self, start, stop, cursor, context): """ Args: start (int): Starting byte position to extract (inclusive, global in Source). stop (int): Stopping byte position to extract (exclusive, global in Source). + cursor (Cursor): The Cursor that is currently reading this Chunk. context (dict): Information about the current state of deserialization. Returns a subinterval of the `raw_data` using global coordinates as a @@ -369,6 +370,8 @@ def get(self, start, stop, context): outside expected range {2}:{3} for this Chunk""".format( start, stop, self._start, self._stop ), + self, + cursor.copy(), context, self._source.file_path, ) @@ -376,7 +379,7 @@ def get(self, start, stop, context): else: raise RefineChunk(start, stop, self._start, self._stop) - def remainder(self, start, context): + def remainder(self, start, cursor, context): """ Args: start (int): Starting byte position to extract (inclusive, global @@ -402,6 +405,8 @@ def remainder(self, start, context): outside expected range {1}:{2} for this Chunk""".format( start, self._start, self._stop ), + self, + cursor.copy(), context, self._source.file_path, ) diff --git a/uproot4/source/cursor.py b/uproot4/source/cursor.py index 674134be1..6f2620783 100644 --- a/uproot4/source/cursor.py +++ b/uproot4/source/cursor.py @@ -15,6 +15,12 @@ import uproot4.deserialization +_printable_characters = ( + "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLM" + "NOPQRSTUVWXYZ!\"#$%&'()*+,-./:;<=>?@[\\]^_`{|}~ " +) + + class Cursor(object): """ Represents a position in a ROOT file, which may be held for later reference @@ -156,7 +162,7 @@ def fields(self, chunk, format, context, move=True): stop = start + format.size if move: self._index = stop - return format.unpack(chunk.get(start, stop, context)) + return format.unpack(chunk.get(start, stop, self, context)) def field(self, chunk, format, context, move=True): """ @@ -169,7 +175,7 @@ def field(self, chunk, format, context, move=True): stop = start + format.size if move: self._index = stop - return format.unpack(chunk.get(start, stop, context))[0] + return format.unpack(chunk.get(start, stop, self, context))[0] def bytes(self, chunk, length, context, move=True, copy_if_memmap=False): """ @@ -184,7 +190,7 @@ def bytes(self, chunk, length, context, move=True, copy_if_memmap=False): stop = start + length if move: self._index = stop - out = chunk.get(start, stop, context) + out = chunk.get(start, stop, self, context) if copy_if_memmap: step = out while getattr(step, "base", None) is not None: @@ -204,7 +210,7 @@ def array(self, chunk, length, dtype, context, move=True): stop = start + length * dtype.itemsize if move: self._index = stop - return numpy.frombuffer(chunk.get(start, stop, context), dtype=dtype) + return numpy.frombuffer(chunk.get(start, stop, self, context), dtype=dtype) _u1 = numpy.dtype("u1") _i4 = numpy.dtype(">i4") @@ -218,17 +224,17 @@ def bytestring(self, chunk, context, move=True): """ start = self._index stop = start + 1 - length = chunk.get(start, stop, context)[0] + length = chunk.get(start, stop, self, context)[0] if length == 255: start = stop stop = start + 4 - length_data = chunk.get(start, stop, context) + length_data = chunk.get(start, stop, self, context) length = numpy.frombuffer(length_data, dtype=self._u1).view(self._i4)[0] start = stop stop = start + length if move: self._index = stop - return chunk.get(start, stop, context).tostring() + return chunk.get(start, stop, self, context).tostring() def string(self, chunk, context, move=True): """ @@ -255,7 +261,7 @@ def classname(self, chunk, context, move=True): If `move` is False, only peek: don't update the index. """ - remainder = chunk.remainder(self._index, context) + remainder = chunk.remainder(self._index, self, context) local_stop = 0 char = None while char != 0: @@ -278,11 +284,6 @@ def classname(self, chunk, context, move=True): else: return out.decode(errors="surrogateescape") - _printable = ( - "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLM" - "NOPQRSTUVWXYZ!\"#$%&'()*+,-./:;<=>?@[\\]^_`{|}~ " - ) - def debug( self, chunk, @@ -321,7 +322,7 @@ def debug( --- --- --- C J --- --- C --- --- --- { { 101.0 202.0 303.0 """ - data = chunk.remainder(self._index, context) + data = chunk.remainder(self._index, self, context) if limit_bytes is not None: data = data[:limit_bytes] @@ -365,7 +366,9 @@ def debug( stream.write( prefix + u" ".join( - u"{0:>3s}".format(chr(x)) if chr(x) in self._printable else u"---" + u"{0:>3s}".format(chr(x)) + if chr(x) in _printable_characters + else u"---" for x in line_data ) + u"\n" diff --git a/uproot4/stl_containers.py b/uproot4/stl_containers.py new file mode 100644 index 000000000..aabfe5f0e --- /dev/null +++ b/uproot4/stl_containers.py @@ -0,0 +1,807 @@ +# BSD 3-Clause License; see https://github.com/scikit-hep/uproot4/blob/master/LICENSE + +from __future__ import absolute_import + +import re +import types +import struct + +try: + from collections.abc import Sequence + from collections.abc import Set + from collections.abc import Mapping + from collections.abc import KeysView + from collections.abc import ValuesView +except ImportError: + from collections import Sequence + from collections import Set + from collections import Mapping + + KeysView = None + ValuesView = None + +import numpy + +import uproot4._util +import uproot4.model +import uproot4.deserialization + + +_stl_container_size = struct.Struct(">I") +_stl_primitive_types = { + numpy.dtype("?"): "bool", + numpy.dtype("i1"): "int8_t", + numpy.dtype("u1"): "uint8_t", + numpy.dtype("i2"): "int16_t", + numpy.dtype(">i2"): "int16_t", + numpy.dtype("u2"): "unt16_t", + numpy.dtype(">u2"): "unt16_t", + numpy.dtype("i4"): "int32_t", + numpy.dtype(">i4"): "int32_t", + numpy.dtype("u4"): "unt32_t", + numpy.dtype(">u4"): "unt32_t", + numpy.dtype("i8"): "int64_t", + numpy.dtype(">i8"): "int64_t", + numpy.dtype("u8"): "unt64_t", + numpy.dtype(">u8"): "unt64_t", + numpy.dtype("f4"): "float", + numpy.dtype(">f4"): "float", + numpy.dtype("f8"): "double", + numpy.dtype(">f8"): "double", +} +_stl_object_type = numpy.dtype(numpy.object) + + +_tokenize_typename_pattern = re.compile( + r"(\b([A-Za-z_][A-Za-z_0-9]*)(\s*::\s*[A-Za-z_][A-Za-z_0-9]*)*\b(\s*\*)*|<|>|,)" +) + +_simplify_token_1 = re.compile(r"\s*\*") +_simplify_token_2 = re.compile(r"\s*::\s*") + + +def _simplify_token(token): + return _simplify_token_2.sub("::", _simplify_token_1.sub("*", token.group(0))) + + +def _parse_error(pos, typename, file): + in_file = "" + if file is not None: + in_file = "\nin file {0}".format(file.file_path) + raise ValueError( + """invalid C++ type name syntax at char {0} + + {1} +{2}{3}""".format( + pos, typename, "-" * (4 + pos) + "^", in_file + ) + ) + + +def _parse_expect(what, tokens, i, typename, file): + if i >= len(tokens): + _parse_error(len(typename), typename, file) + + if what is not None and tokens[i].group(0) != what: + _parse_error(tokens[i].start() + 1, typename, file) + + +def _parse_maybe_quote(quoted, quote): + if quote: + return quoted + else: + return eval(quoted) + + +def _parse_node(tokens, i, typename, file, quote): + _parse_expect(None, tokens, i, typename, file) + + has2 = i + 1 < len(tokens) + + if tokens[i].group(0) == ",": + _parse_error(tokens[i].start() + 1, typename, file) + + elif tokens[i].group(0) == "Bool_t": + return i + 1, _parse_maybe_quote('numpy.dtype("?")', quote) + elif tokens[i].group(0) == "bool": + return i + 1, _parse_maybe_quote('numpy.dtype("?")', quote) + + elif tokens[i].group(0) == "Char_t": + return i + 1, _parse_maybe_quote('numpy.dtype("i1")', quote) + elif tokens[i].group(0) == "char": + return i + 1, _parse_maybe_quote('numpy.dtype("i1")', quote) + elif tokens[i].group(0) == "UChar_t": + return i + 1, _parse_maybe_quote('numpy.dtype("u1")', quote) + elif has2 and tokens[i].group(0) == "unsigned" and tokens[i + 1].group(0) == "char": + return i + 2, _parse_maybe_quote('numpy.dtype("u1")', quote) + + elif tokens[i].group(0) == "Short_t": + return i + 1, _parse_maybe_quote('numpy.dtype(">i2")', quote) + elif tokens[i].group(0) == "short": + return i + 1, _parse_maybe_quote('numpy.dtype(">i2")', quote) + elif tokens[i].group(0) == "UShort_t": + return i + 1, _parse_maybe_quote('numpy.dtype(">u2")', quote) + elif ( + has2 and tokens[i].group(0) == "unsigned" and tokens[i + 1].group(0) == "short" + ): + return i + 2, _parse_maybe_quote('numpy.dtype(">u2")', quote) + + elif tokens[i].group(0) == "Int_t": + return i + 1, _parse_maybe_quote('numpy.dtype(">i4")', quote) + elif tokens[i].group(0) == "int": + return i + 1, _parse_maybe_quote('numpy.dtype(">i4")', quote) + elif tokens[i].group(0) == "UInt_t": + return i + 1, _parse_maybe_quote('numpy.dtype(">u4")', quote) + elif has2 and tokens[i].group(0) == "unsigned" and tokens[i + 1].group(0) == "int": + return i + 2, _parse_maybe_quote('numpy.dtype(">u4")', quote) + + elif tokens[i].group(0) == "Long_t": + return i + 1, _parse_maybe_quote('numpy.dtype(">i8")', quote) + elif tokens[i].group(0) == "Long64_t": + return i + 1, _parse_maybe_quote('numpy.dtype(">i8")', quote) + elif tokens[i].group(0) == "long": + return i + 1, _parse_maybe_quote('numpy.dtype(">i8")', quote) + elif tokens[i].group(0) == "ULong_t": + return i + 1, _parse_maybe_quote('numpy.dtype(">u8")', quote) + elif tokens[i].group(0) == "ULong64_t": + return i + 1, _parse_maybe_quote('numpy.dtype(">u8")', quote) + elif has2 and tokens[i].group(0) == "unsigned" and tokens[i + 1].group(0) == "long": + return i + 2, _parse_maybe_quote('numpy.dtype(">u8")', quote) + + elif tokens[i].group(0) == "Float_t": + return i + 1, _parse_maybe_quote('numpy.dtype(">f4")', quote) + elif tokens[i].group(0) == "float": + return i + 1, _parse_maybe_quote('numpy.dtype(">f4")', quote) + + elif tokens[i].group(0) == "Double_t": + return i + 1, _parse_maybe_quote('numpy.dtype(">f8")', quote) + elif tokens[i].group(0) == "double": + return i + 1, _parse_maybe_quote('numpy.dtype(">f8")', quote) + + elif tokens[i].group(0) == "string" or _simplify_token(tokens[i]) == "std::string": + return i + 1, _parse_maybe_quote("uproot4.stl_containers.AsString()", quote) + elif tokens[i].group(0) == "TString": + return ( + i + 1, + _parse_maybe_quote("uproot4.stl_containers.AsString(is_stl=False)", quote), + ) + elif _simplify_token(tokens[i]) == "char*": + return ( + i + 1, + _parse_maybe_quote("uproot4.stl_containers.AsString(is_stl=False)", quote), + ) + elif ( + has2 + and tokens[i].group(0) == "const" + and _simplify_token(tokens[i + 1]) == "char*" + ): + return ( + i + 2, + _parse_maybe_quote("uproot4.stl_containers.AsString(is_stl=False)", quote), + ) + + elif tokens[i].group(0) == "vector" or _simplify_token(tokens[i]) == "std::vector": + _parse_expect("<", tokens, i + 1, typename, file) + i, values = _parse_node(tokens, i + 2, typename, file, quote) + _parse_expect(">", tokens, i, typename, file) + if quote: + return i + 1, "uproot4.stl_containers.AsVector({0})".format(values) + else: + return i + 1, AsVector(values) + + elif tokens[i].group(0) == "set" or _simplify_token(tokens[i]) == "std::set": + _parse_expect("<", tokens, i + 1, typename, file) + i, keys = _parse_node(tokens, i + 2, typename, file, quote) + _parse_expect(">", tokens, i, typename, file) + if quote: + return i + 1, "uproot4.stl_containers.AsSet({0})".format(keys) + else: + return i + 1, AsSet(keys) + + elif tokens[i].group(0) == "map" or _simplify_token(tokens[i]) == "std::map": + _parse_expect("<", tokens, i + 1, typename, file) + i, keys = _parse_node(tokens, i + 2, typename, file, quote) + _parse_expect(",", tokens, i, typename, file) + i, values = _parse_node(tokens, i + 1, typename, file, quote) + _parse_expect(">", tokens, i, typename, file) + if quote: + return i + 1, "uproot4.stl_containers.AsMap({0}, {1})".format(keys, values) + else: + return i + 1, AsMap(keys, values) + + else: + start, stop = tokens[i].span() + + if has2 and tokens[i + 1].group(0) == "<": + i, keys = _parse_node(tokens, i + 1, typename, file, quote) + _parse_expect(">", tokens, i + 1, typename, file) + stop = tokens[i + 1].span()[1] + i += 1 + + classname = typename[start:stop] + + if quote: + return "c({0})".format(repr(classname)) + elif file is None: + cls = uproot4.classes[classname] + else: + cls = file.class_named(classname) + + return i + 1, cls + + +def parse_typename(typename, file=None, quote=False): + tokens = list(_tokenize_typename_pattern.finditer(typename)) + i, out = _parse_node(tokens, 0, typename, file, quote) + + if i < len(tokens): + _parse_error(tokens[i].start(), typename, file) + + return out + + +def _read_nested(model, length, chunk, cursor, context, file, parent): + if isinstance(model, numpy.dtype): + return cursor.array(chunk, length, model, context) + + elif isinstance(model, AsSTLContainer): + return model.read(chunk, cursor, context, file, parent, multiplicity=length) + + else: + values = numpy.empty(length, dtype=_stl_object_type) + for i in range(length): + values[i] = model.read(chunk, cursor, context, file, parent) + return values + + +def _tostring(value): + if uproot4._util.isstr(value): + return repr(value) + else: + return str(value) + + +def _str_with_ellipsis(tostring, length, lbracket, rbracket, limit): + leftlen = len(lbracket) + rightlen = len(rbracket) + left, right, i, j, done = [], [], 0, length - 1, False + + while True: + if i > j: + done = True + break + x = tostring(i) + ("" if i == length - 1 else ", ") + i += 1 + dotslen = 0 if i > j else 5 + if leftlen + rightlen + len(x) + dotslen > limit: + break + left.append(x) + leftlen += len(x) + + if i > j: + done = True + break + y = tostring(j) + ("" if j == length - 1 else ", ") + j -= 1 + dotslen = 0 if i > j else 5 + if leftlen + rightlen + len(y) + dotslen > limit: + break + right.insert(0, y) + rightlen += len(y) + + if length == 0: + return lbracket + rbracket + elif done: + return lbracket + "".join(left) + "".join(right) + rbracket + elif len(left) == 0 and len(right) == 0: + return lbracket + "{0}, ...".format(tostring(0)) + rbracket + elif len(right) == 0: + return lbracket + "".join(left) + "..." + rbracket + else: + return lbracket + "".join(left) + "..., " + "".join(right) + rbracket + + +class AsSTLContainer(object): + @property + def classname(self): + raise AssertionError + + def read(self, chunk, cursor, context, file, parent, multiplicity=None): + raise AssertionError + + def __eq__(self, other): + raise AssertionError + + def __ne__(self, other): + return not self == other + + +class STLContainer(object): + pass + + +class AsString(AsSTLContainer): + def __init__(self, is_stl=True): + self._is_stl = is_stl + + def __hash__(self): + return hash((AsString, self._is_stl)) + + @property + def is_stl(self): + return self._is_stl + + def __repr__(self): + is_stl = "" + if not self._is_stl: + is_stl = "is_stl=False" + return "AsString({0})".format(is_stl) + + @property + def classname(self): + if self._is_stl: + return "std::string" + else: + return "const char*" + + def read(self, chunk, cursor, context, file, parent, multiplicity=None): + if self._is_stl: + start_cursor = cursor.copy() + num_bytes, instance_version = uproot4.deserialization.numbytes_version( + chunk, cursor, context + ) + + if multiplicity is None: + out = cursor.string(chunk, context) + else: + out = numpy.empty(multiplicity, dtype=_stl_object_type) + for i in range(multiplicity): + out[i] = cursor.string(chunk, context) + + if self._is_stl: + uproot4.deserialization.numbytes_check( + start_cursor, cursor, num_bytes, self.classname, context, file.file_path + ) + + return out + + def __eq__(self, other): + return isinstance(other, AsString) and self.is_stl == other.is_stl + + +class AsVector(AsSTLContainer): + def __init__(self, values): + if isinstance(values, AsSTLContainer): + self._values = values + elif isinstance(values, type) and issubclass(values, uproot4.model.Model): + self._values = values + else: + self._values = numpy.dtype(values) + + def __hash__(self): + return hash((AsVector, self._values)) + + @property + def values(self): + return self._values + + def __repr__(self): + return "AsVector({0})".format(repr(self._values)) + + @property + def classname(self): + values = _stl_primitive_types.get(self._values) + if values is None: + values = self._values.classname + return "std::vector<{0}>".format(values) + + def read(self, chunk, cursor, context, file, parent, multiplicity=None): + start_cursor = cursor.copy() + num_bytes, instance_version = uproot4.deserialization.numbytes_version( + chunk, cursor, context + ) + + length = cursor.field(chunk, _stl_container_size, context) + + if multiplicity is None: + values = _read_nested( + self._values, length, chunk, cursor, context, file, parent + ) + out = STLVector(values) + + else: + out = numpy.empty(multiplicity, dtype=_stl_object_type) + for i in range(multiplicity): + values = _read_nested( + self._values, length, chunk, cursor, context, file, parent + ) + out[i] = STLVector(values) + + uproot4.deserialization.numbytes_check( + start_cursor, cursor, num_bytes, self.classname, context, file.file_path, + ) + + return out + + def __eq__(self, other): + return isinstance(other, AsVector) and self.values == other.values + + +class STLVector(STLContainer, Sequence): + def __init__(self, values): + if isinstance(values, types.GeneratorType): + values = numpy.asarray(list(values)) + elif isinstance(values, Set): + values = numpy.asarray(list(values)) + elif isinstance(values, (list, tuple)): + values = numpy.asarray(values) + + self._values = values + + def __str__(self, limit=85): + def tostring(i): + return _tostring(self._values[i]) + + return _str_with_ellipsis(tostring, len(self), "[", "]", limit) + + def __repr__(self, limit=85): + return "".format( + self.__str__(limit=limit - 30), id(self) + ) + + def __getitem__(self, where): + return self._values[where] + + def __len__(self): + return len(self._values) + + def __contains__(self, what): + return what in self._values + + def __iter__(self): + return iter(self._values) + + def __reversed__(self): + return STLVector(self._values[::-1]) + + def __eq__(self, other): + if isinstance(other, STLVector): + return self._values == other._values + elif isinstance(other, Sequence): + return self._values == other + else: + return False + + def __ne__(self, other): + return not self == other + + +class AsSet(AsSTLContainer): + def __init__(self, keys): + if isinstance(keys, AsSTLContainer): + self._keys = keys + elif isinstance(keys, type) and issubclass(keys, uproot4.model.Model): + self._keys = keys + else: + self._keys = numpy.dtype(keys) + + def __hash__(self): + return hash((AsSet, self._keys)) + + @property + def keys(self): + return self._keys + + def __repr__(self): + return "AsSet({0})".format(repr(self._keys)) + + @property + def classname(self): + keys = _stl_primitive_types.get(self._keys) + if keys is None: + keys = self._keys.classname + return "std::set<{0}>".format(keys) + + def read(self, chunk, cursor, context, file, parent, multiplicity=None): + start_cursor = cursor.copy() + num_bytes, instance_version = uproot4.deserialization.numbytes_version( + chunk, cursor, context + ) + + length = cursor.field(chunk, _stl_container_size, context) + + if multiplicity is None: + keys = _read_nested( + self._keys, length, chunk, cursor, context, file, parent + ) + out = STLSet(keys) + + else: + out = numpy.empty(multiplicity, dtype=_stl_object_type) + for i in range(multiplicity): + keys = _read_nested( + self._keys, length, chunk, cursor, context, file, parent + ) + out[i] = STLSet(keys) + + uproot4.deserialization.numbytes_check( + start_cursor, cursor, num_bytes, self.classname, context, file.file_path, + ) + + return out + + def __eq__(self, other): + return isinstance(other, AsSet) and self.keys == other.keys + + +class STLSet(STLContainer, Set): + def __init__(self, keys): + if isinstance(keys, types.GeneratorType): + keys = numpy.asarray(list(keys)) + elif isinstance(keys, Set): + keys = numpy.asarray(list(keys)) + else: + keys = numpy.asarray(keys) + + self._keys = numpy.sort(keys) + + def __str__(self, limit=85): + def tostring(i): + return _tostring(self._keys[i]) + + return _str_with_ellipsis(tostring, len(self), "{", "}", limit) + + def __repr__(self, limit=85): + return "".format( + self.__str__(limit=limit - 30), id(self) + ) + + def __len__(self): + return len(self._keys) + + def __iter__(self): + return iter(self._keys) + + def __contains__(self, where): + where = numpy.asarray(where) + index = numpy.searchsorted(self._keys.astype(where.dtype), where, side="left") + + if uproot4._util.isint(index): + if index < len(self._keys) and self._keys[index] == where: + return True + else: + return False + + else: + return False + + def __eq__(self, other): + if isinstance(other, Set): + if not isinstance(other, STLSet): + other = STLSet(other) + else: + return False + + if len(self._keys) != len(other._keys): + return False + + keys_same = self._keys == other._keys + if isinstance(keys_same, bool): + return keys_same + else: + return numpy.all(keys_same) + + def __ne__(self, other): + return not self == other + + +class AsMap(AsSTLContainer): + def __init__(self, keys, values): + if isinstance(keys, AsSTLContainer): + self._keys = keys + else: + self._keys = numpy.dtype(keys) + + if isinstance(values, AsSTLContainer): + self._values = values + elif isinstance(values, type) and issubclass(values, uproot4.model.Model): + self._values = values + else: + self._values = numpy.dtype(values) + + def __hash__(self): + return hash((AsMap, self._keys, self._values)) + + @property + def keys(self): + return self._keys + + @property + def values(self): + return self._values + + def __repr__(self): + return "AsMap({0}, {1})".format(repr(self._keys), repr(self._values)) + + @property + def classname(self): + keys = _stl_primitive_types.get(self._keys) + if keys is None: + keys = self._keys.classname + values = _stl_primitive_types.get(self._values) + if values is None: + values = self._values.classname + return "std::map<{0}, {1}>".format(keys, values) + + def read(self, chunk, cursor, context, file, parent, multiplicity=None): + start_cursor = cursor.copy() + num_bytes, instance_version = uproot4.deserialization.numbytes_version( + chunk, cursor, context + ) + + cursor.skip(6) + + length = cursor.field(chunk, _stl_container_size, context) + + if multiplicity is None: + keys = _read_nested( + self._keys, length, chunk, cursor, context, file, parent + ) + values = _read_nested( + self._values, length, chunk, cursor, context, file, parent + ) + out = STLMap(keys, values) + + else: + out = numpy.empty(multiplicity, dtype=_stl_object_type) + for i in range(multiplicity): + keys = _read_nested( + self._keys, length, chunk, cursor, context, file, parent + ) + values = _read_nested( + self._values, length, chunk, cursor, context, file, parent + ) + out[i] = STLMap(keys, values) + + uproot4.deserialization.numbytes_check( + start_cursor, cursor, num_bytes, self.classname, context, file.file_path, + ) + + return out + + def __eq__(self, other): + return ( + isinstance(other, AsMap) + and self.keys == other.keys + and self.values == other.values + ) + + +class STLMap(STLContainer, Mapping): + @classmethod + def from_mapping(cls, mapping): + return STLMap(mapping.keys(), mapping.values()) + + def __init__(self, keys, values): + if KeysView is not None and isinstance(keys, KeysView): + keys = numpy.asarray(list(keys)) + elif isinstance(keys, types.GeneratorType): + keys = numpy.asarray(list(keys)) + elif isinstance(keys, Set): + keys = numpy.asarray(list(keys)) + else: + keys = numpy.asarray(keys) + + if ValuesView is not None and isinstance(values, ValuesView): + values = numpy.asarray(list(values)) + elif isinstance(values, types.GeneratorType): + values = numpy.asarray(list(values)) + + if len(keys) != len(values): + raise ValueError("number of keys must be equal to the number of values") + + index = numpy.argsort(keys) + + self._keys = keys[index] + try: + self._values = values[index] + except Exception: + self._values = numpy.asarray(values)[index] + + def __str__(self, limit=85): + def tostring(i): + return _tostring(self._keys[i]) + ": " + _tostring(self._values[i]) + + return _str_with_ellipsis(tostring, len(self), "{", "}", limit) + + def __repr__(self, limit=85): + return "".format( + self.__str__(limit=limit - 30), id(self) + ) + + def __getitem__(self, where): + where = numpy.asarray(where) + index = numpy.searchsorted(self._keys.astype(where.dtype), where, side="left") + + if uproot4._util.isint(index): + if index < len(self._keys) and self._keys[index] == where: + return self._values[index] + else: + raise KeyError(where) + + elif len(self._keys) == 0: + values = numpy.empty(len(index)) + return numpy.ma.MaskedArray(values, True) + + else: + index[index >= len(self._keys)] = 0 + mask = self._keys[index] != where + return numpy.ma.MaskedArray(self._values[index], mask) + + def get(self, where, default=None): + where = numpy.asarray(where) + index = numpy.searchsorted(self._keys.astype(where.dtype), where, side="left") + + if uproot4._util.isint(index): + if index < len(self._keys) and self._keys[index] == where: + return self._values[index] + else: + return default + + elif len(self._keys) == 0: + return numpy.array([default])[numpy.zeros(len(index), numpy.int32)] + + else: + index[index >= len(self._keys)] = 0 + matches = self._keys[index] == where + values = self._values[index] + defaults = numpy.array([default])[numpy.zeros(len(index), numpy.int32)] + return numpy.where(matches, values, defaults) + + def __len__(self): + return len(self._keys) + + def __iter__(self): + return iter(self._keys) + + def __contains__(self, where): + where = numpy.asarray(where) + index = numpy.searchsorted(self._keys.astype(where.dtype), where, side="left") + + if uproot4._util.isint(index): + if index < len(self._keys) and self._keys[index] == where: + return True + else: + return False + + else: + return False + + def keys(self): + return self._keys + + def values(self): + return self._values + + def items(self): + return numpy.transpose(numpy.vstack([self._keys, self._values])) + + def __eq__(self, other): + if isinstance(other, Mapping): + if not isinstance(other, STLMap): + other = STLMap(other.keys(), other.values()) + else: + return False + + if len(self._keys) != len(other._keys): + return False + + keys_same = self._keys == other._keys + values_same = self._values == other._values + if isinstance(keys_same, bool) and isinstance(values_same, bool): + return keys_same and values_same + else: + return numpy.logical_and(keys_same, values_same).all() + + def __ne__(self, other): + return not self == other diff --git a/uproot4/streamers.py b/uproot4/streamers.py index 078f3f24d..2ea1f416d 100644 --- a/uproot4/streamers.py +++ b/uproot4/streamers.py @@ -8,6 +8,7 @@ import numpy +import uproot4._util import uproot4.model import uproot4.const import uproot4.deserialization @@ -74,9 +75,9 @@ def _ftype_to_dtype(fType): elif fType in (uproot4.const.kBits, uproot4.const.kUInt, uproot4.const.kCounter): return "numpy.dtype('>u4')" elif fType == uproot4.const.kLong: - return "numpy.dtype(numpy.long).newbyteorder('>')" + return "numpy.dtype('>i8')" elif fType == uproot4.const.kULong: - return "numpy.dtype('>u' + repr(numpy.dtype(numpy.long).itemsize))" + return "numpy.dtype('>u8')" elif fType == uproot4.const.kLong64: return "numpy.dtype('>i8')" elif fType == uproot4.const.kULong64: @@ -105,9 +106,9 @@ def _ftype_to_struct(fType): elif fType in (uproot4.const.kBits, uproot4.const.kUInt, uproot4.const.kCounter): return "I" elif fType == uproot4.const.kLong: - return "l" + return "q" elif fType == uproot4.const.kULong: - return "L" + return "Q" elif fType == uproot4.const.kLong64: return "q" elif fType == uproot4.const.kULong64: @@ -144,6 +145,7 @@ def read_members(self, chunk, cursor, context): def postprocess(self, chunk, cursor, context): # prevent circular dependencies and long-lived references to files + self._file_uuid = self._file.uuid self._file = None self._parent = None return self @@ -165,6 +167,10 @@ def class_version(self): def elements(self): return self._members["fElements"] + @property + def file_uuid(self): + return self._file_uuid + def _dependencies(self, streamers, out): out.append((self.name, self.class_version)) for element in self.elements: @@ -189,6 +195,9 @@ def show(self, stream=sys.stdout): def new_class(self, file): class_code = self.class_code() + + print(class_code) + class_name = uproot4.model.classname_encode(self.name, self.class_version) classes = uproot4.model.maybe_custom_classes(file.custom_classes) return uproot4.deserialization.compile_class( @@ -200,6 +209,7 @@ def class_code(self): fields = [] formats = [] dtypes = [] + stl_containers = [] base_names_versions = [] member_names = [] class_flags = {} @@ -219,6 +229,7 @@ def class_code(self): fields, formats, dtypes, + stl_containers, base_names_versions, member_names, class_flags, @@ -234,12 +245,18 @@ def class_code(self): read_members.append("") class_data = [] + for i, format in enumerate(formats): class_data.append( " _format{0} = struct.Struct('>{1}')".format(i, "".join(format)) ) + for i, dt in enumerate(dtypes): class_data.append(" _dtype{0} = {1}".format(i, dt)) + + for i, stl in enumerate(stl_containers): + class_data.append(" _stl_container{0} = {1}".format(i, stl)) + class_data.append( " base_names_versions = [{0}]".format( ", ".join( @@ -248,19 +265,22 @@ def class_code(self): ) ) ) + class_data.append( " member_names = [{0}]".format(", ".join(repr(x) for x in member_names)) ) + class_data.append( " class_flags = {{{0}}}".format( ", ".join(repr(k) + ": " + repr(v) for k, v in class_flags.items()) ) ) + class_data.append(" hooks = {}") return "\n".join( [ - "class {0}(VersionedModel):".format( + "class {0}(uproot4.model.VersionedModel):".format( uproot4.model.classname_encode(self.name, self.class_version) ) ] @@ -321,6 +341,7 @@ def read_members(self, chunk, cursor, context): def postprocess(self, chunk, cursor, context): # prevent circular dependencies and long-lived references to files + self._file_uuid = self._file.uuid self._file = None self._parent = None return self @@ -345,6 +366,10 @@ def array_length(self): def fType(self): return self.member("fType") + @property + def file_uuid(self): + return self._file_uuid + def _dependencies(self, streamers, out): pass @@ -373,6 +398,7 @@ def class_code( fields, formats, dtypes, + stl_containers, base_names_versions, member_names, class_flags, @@ -428,6 +454,7 @@ def class_code( fields, formats, dtypes, + stl_containers, base_names_versions, member_names, class_flags, @@ -470,6 +497,7 @@ def class_code( fields, formats, dtypes, + stl_containers, base_names_versions, member_names, class_flags, @@ -573,6 +601,7 @@ def class_code( fields, formats, dtypes, + stl_containers, base_names_versions, member_names, class_flags, @@ -659,6 +688,7 @@ def class_code( fields, formats, dtypes, + stl_containers, base_names_versions, member_names, class_flags, @@ -712,88 +742,6 @@ def stl_type(self): def fCtype(self): return self._members["fCtype"] - @property - def is_string(self): - return self.stl_type == uproot4.const.kSTLstring or self.typename == "string" - - @property - def is_vector_dtype(self): - return self.vector_dtype is not None - - @property - def vector_dtype(self): - if self.stl_type == uproot4.const.kSTLvector: - if self.fCtype == uproot4.const.kBool: - return "numpy.dtype('?')" - elif self.fCtype == uproot4.const.kChar: - return "numpy.dtype('i1')" - elif self.fCtype == uproot4.const.kShort: - return "numpy.dtype('>i2')" - elif self.fCtype == uproot4.const.kInt: - return "numpy.dtype('>i4')" - elif self.fCtype == uproot4.const.kLong: - return "numpy.dtype(numpy.long).newbyteorder('>')" - elif self.fCtype == uproot4.const.kLong64: - return "numpy.dtype('>i8')" - elif self.fCtype == uproot4.const.kUChar: - return "numpy.dtype('u1')" - elif self.fCtype == uproot4.const.kUShort: - return "numpy.dtype('>u2')" - elif self.fCtype == uproot4.const.kUInt: - return "numpy.dtype('>u4')" - elif self.fCtype == uproot4.const.kULong: - return "numpy.dtype('>u' + repr(numpy.dtype(numpy.long).itemsize))" - elif self.fCtype == uproot4.const.kULong64: - return "numpy.dtype('>u8')" - elif self.fCtype == uproot4.const.kFloat: - return "numpy.dtype('>f4')" - elif self.fCtype == uproot4.const.kDouble: - return "numpy.dtype('>f8')" - - if self.typename == "vector" or self.typename == "vector": - return "numpy.dtype('?')" - elif self.typename == "vector" or self.typename == "vector": - return "numpy.dtype('i1')" - elif self.typename == "vector" or self.typename == "vector": - return "numpy.dtype('>i2')" - elif self.typename == "vector" or self.typename == "vector": - return "numpy.dtype('>i4')" - elif self.typename == "vector" or self.typename == "vector": - return "numpy.dtype(numpy.long).newbyteorder('>')" - elif self.typename == "vector": - return "numpy.dtype('>i8')" - elif ( - self.typename == "vector" - or self.typename == "vector" - ): - return "numpy.dtype('u1')" - elif ( - self.typename == "vector" - or self.typename == "vector" - ): - return "numpy.dtype('>u2')" - elif ( - self.typename == "vector" or self.typename == "vector" - ): - return "numpy.dtype('>u4')" - elif ( - self.typename == "vector" - or self.typename == "vector" - ): - return "numpy.dtype('>u' + repr(numpy.dtype(numpy.long).itemsize))" - elif self.typename == "vector": - return "numpy.dtype('>u8')" - elif self.typename == "vector" or self.typename == "vector": - return "numpy.dtype('>f4')" - elif self.typename == "vector" or self.typename == "vector": - return "numpy.dtype('>f8')" - else: - return None - - @property - def is_map_string_string(self): - return self.typename == "map" - def class_code( self, streamerinfo, @@ -803,51 +751,26 @@ def class_code( fields, formats, dtypes, + stl_containers, base_names_versions, member_names, class_flags, ): - if self.is_string: - read_members.append(" cursor.skip(6)") - read_members.append( - " self._members[{0}] = cursor.string(chunk, context)".format( - repr(self.name) - ) - ) - - elif self.is_vector_dtype: - read_members.append(" cursor.skip(6)") - read_members.append( - " tmp = cursor.field(chunk, self._format{0}, context)".format( - len(formats) - ) - ) - read_members.append( - " self._members[{0}] = cursor.array(chunk, tmp, " - "self._dtype{1}, context)".format(repr(self.name), len(dtypes)) - ) - formats.append(["i"]) - dtypes.append(self.vector_dtype) - - elif self.is_map_string_string: - read_members.append( - " self._members[{0}] = map_string_string(chunk, cursor, context)" - ) + stl_container = uproot4.stl_containers.parse_typename(self.typename, quote=True) + read_members.append( + " self._members[{0}] = self._stl_container{1}.read(" + "chunk, cursor, context, self._file, self._parent, multiplicity=1)" + "".format(repr(self.name), len(stl_containers)) + ) + stl_containers.append(stl_container) - else: - read_members.append( - " raise NotImplementedError('class members defined by " - "{0} with type {1}')".format(type(self).__name__, self.typename) - ) member_names.append(self.name) -class Model_TStreamerSTLstring(Model_TStreamerElement): +class Model_TStreamerSTLstring(Model_TStreamerSTL): def read_members(self, chunk, cursor, context): self._bases.append( - Model_TStreamerElement.read( - chunk, cursor, context, self._file, self._parent - ) + Model_TStreamerSTL.read(chunk, cursor, context, self._file, self._parent) ) def class_code( @@ -859,6 +782,7 @@ def class_code( fields, formats, dtypes, + stl_containers, base_names_versions, member_names, class_flags, @@ -887,6 +811,7 @@ def class_code( fields, formats, dtypes, + stl_containers, base_names_versions, member_names, class_flags, @@ -947,6 +872,7 @@ def class_code( fields, formats, dtypes, + stl_containers, base_names_versions, member_names, class_flags,