From aa8b94f722982c6eb418f9f32273152039836fb2 Mon Sep 17 00:00:00 2001 From: Jim Pivarski Date: Thu, 15 Aug 2024 10:23:38 -0500 Subject: [PATCH] fix: explicit 'import awkward' needed to write NumPy strings (#1266) --- src/uproot/writing/_cascadetree.py | 19 ++++++++++++++----- .../test_1264_write_NumPy_array_of_strings.py | 18 ++++++++++++++++++ 2 files changed, 32 insertions(+), 5 deletions(-) create mode 100644 tests/test_1264_write_NumPy_array_of_strings.py diff --git a/src/uproot/writing/_cascadetree.py b/src/uproot/writing/_cascadetree.py index 9ecd2e87f..bf5491cec 100644 --- a/src/uproot/writing/_cascadetree.py +++ b/src/uproot/writing/_cascadetree.py @@ -663,11 +663,20 @@ def extend(self, file, sink, data): if datum["counter"] is None: if datum["dtype"] == ">U0": - lengths = numpy.asarray(awkward.num(branch_array.layout)) + awkward = uproot.extras.awkward() + + layout = awkward.to_layout(branch_array) + if isinstance( + layout, + (awkward.contents.ListArray, awkward.contents.RegularArray), + ): + layout = layout.to_ListOffsetArray64() + + lengths = numpy.asarray(awkward.num(layout)) which_big = lengths >= 255 lengths_extension_offsets = numpy.empty( - len(branch_array.layout) + 1, numpy.int64 + len(layout) + 1, numpy.int64 ) lengths_extension_offsets[0] = 0 numpy.cumsum(which_big * 4, out=lengths_extension_offsets[1:]) @@ -685,7 +694,7 @@ def extend(self, file, sink, data): [ lengths.reshape(-1, 1).astype("u1"), lengths_extension, - awkward.without_parameters(branch_array.layout), + awkward.without_parameters(layout), ], axis=1, ) @@ -693,8 +702,8 @@ def extend(self, file, sink, data): big_endian = numpy.asarray(awkward.flatten(leafc_data_awkward)) big_endian_offsets = ( lengths_extension_offsets - + numpy.asarray(branch_array.layout.offsets) - + numpy.arange(len(branch_array.layout.offsets)) + + numpy.asarray(layout.offsets) + + numpy.arange(len(layout.offsets)) ).astype(">i4", copy=True) tofill.append( ( diff --git 
a/tests/test_1264_write_NumPy_array_of_strings.py b/tests/test_1264_write_NumPy_array_of_strings.py
new file mode 100644
index 000000000..0872bf417
--- /dev/null
+++ b/tests/test_1264_write_NumPy_array_of_strings.py
@@ -0,0 +1,18 @@
+# BSD 3-Clause License; see https://github.com/scikit-hep/uproot5/blob/main/LICENSE
+
+import pytest
+import uproot
+import os
+import numpy as np
+
+
+def test(tmp_path):
+    """Write NumPy string/int columns to a new tree, extend it, and read both back."""
+    root_path = os.path.join(tmp_path, "example.root")
+    with uproot.recreate(root_path) as outfile:
+        outfile["t"] = {"x": np.array(["A", "B"]), "y": np.array([1, 2])}
+        outfile["t"].extend({"x": np.array(["A", "B"]), "y": np.array([1, 2])})
+    with uproot.open(root_path) as infile:
+        tree = infile["t"]
+        assert tree["x"].array().tolist() == ["A", "B", "A", "B"]
+        assert tree["y"].array().tolist() == [1, 2, 1, 2]