diff --git a/src/lgdo/lh5/_serializers/read/array.py b/src/lgdo/lh5/_serializers/read/array.py
index 49f71600..138f0b49 100644
--- a/src/lgdo/lh5/_serializers/read/array.py
+++ b/src/lgdo/lh5/_serializers/read/array.py
@@ -9,26 +9,26 @@
 log = logging.getLogger(__name__)


-def _h5_read_array_generic(type_, name, h5f, **kwargs):
-    nda, attrs, n_rows_to_read = _h5_read_ndarray(name, h5f, **kwargs)
+def _h5_read_array_generic(type_, h5d, **kwargs):
+    nda, attrs, n_rows_to_read = _h5_read_ndarray(h5d, **kwargs)

     obj_buf = kwargs["obj_buf"]

     if obj_buf is None:
         return type_(nda=nda, attrs=attrs), n_rows_to_read

-    utils.check_obj_buf_attrs(obj_buf.attrs, attrs, h5f, name)
+    utils.check_obj_buf_attrs(obj_buf.attrs, attrs, h5d)

     return obj_buf, n_rows_to_read


-def _h5_read_array(name, h5f, **kwargs):
-    return _h5_read_array_generic(Array, name, h5f, **kwargs)
+def _h5_read_array(h5d, **kwargs):
+    return _h5_read_array_generic(Array, h5d, **kwargs)


-def _h5_read_fixedsize_array(name, h5f, **kwargs):
-    return _h5_read_array_generic(FixedSizeArray, name, h5f, **kwargs)
+def _h5_read_fixedsize_array(h5d, **kwargs):
+    return _h5_read_array_generic(FixedSizeArray, h5d, **kwargs)


-def _h5_read_array_of_equalsized_arrays(name, h5f, **kwargs):
-    return _h5_read_array_generic(ArrayOfEqualSizedArrays, name, h5f, **kwargs)
+def _h5_read_array_of_equalsized_arrays(h5d, **kwargs):
+    return _h5_read_array_generic(ArrayOfEqualSizedArrays, h5d, **kwargs)
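
Throughout this change the reader helpers take an h5py handle directly instead of a (name, h5f) pair. A minimal sketch of the new calling convention, assuming the helper is imported from the module shown above; the file and dataset paths are made up:

    import h5py

    from lgdo.lh5._serializers.read.array import _h5_read_array

    with h5py.File("data.lh5", "r") as h5f:
        # before: _h5_read_array("evt/energy", h5f, obj_buf=None)
        # after: hand over the dataset object itself
        arr, n_read = _h5_read_array(h5f["evt/energy"], obj_buf=None)
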
diff --git a/src/lgdo/lh5/_serializers/read/composite.py b/src/lgdo/lh5/_serializers/read/composite.py
index f9a84b2c..6c0e9ba3 100644
--- a/src/lgdo/lh5/_serializers/read/composite.py
+++ b/src/lgdo/lh5/_serializers/read/composite.py
@@ -40,8 +40,7 @@


 def _h5_read_lgdo(
-    name,
-    h5f,
+    h5o,
     start_row=0,
     n_rows=sys.maxsize,
     idx=None,
@@ -52,11 +51,11 @@ def _h5_read_lgdo(
     decompress=True,
 ):
     # Handle list-of-files recursively
-    if not isinstance(h5f, (str, h5py.File)):
-        lh5_file = list(h5f)
+    if not isinstance(h5o, (h5py.Group, h5py.Dataset)):
+        lh5_objs = list(h5o)
         n_rows_read = 0

-        for i, _h5f in enumerate(lh5_file):
+        for i, _h5o in enumerate(lh5_objs):
             if isinstance(idx, list) and len(idx) > 0 and not np.isscalar(idx[0]):
                 # a list of lists: must be one per file
                 idx_i = idx[i]
@@ -65,7 +64,7 @@ def _h5_read_lgdo(
                 if not (isinstance(idx, tuple) and len(idx) == 1):
                     idx = (idx,)
                 # idx is a long continuous array
-                n_rows_i = read_n_rows(name, _h5f)
+                n_rows_i = read_n_rows(_h5o)
                 # find the length of the subset of idx that contains indices
                 # that are less than n_rows_i
                 n_rows_to_read_i = bisect.bisect_left(idx[0], n_rows_i)
@@ -77,8 +76,7 @@ def _h5_read_lgdo(
                 n_rows_i = n_rows - n_rows_read

             obj_buf, n_rows_read_i = _h5_read_lgdo(
-                name,
-                _h5f,
+                _h5o,
                 start_row=start_row,
                 n_rows=n_rows_i,
                 idx=idx_i,
@@ -97,11 +95,8 @@ def _h5_read_lgdo(

         return obj_buf, n_rows_read

-    if not isinstance(h5f, h5py.File):
-        h5f = h5py.File(h5f, mode="r")
-
     log.debug(
-        f"reading {h5f.filename}:{name}[{start_row}:{n_rows}], decompress = {decompress}, "
+        f"reading {h5o.file.filename}:{h5o.name}[{start_row}:{n_rows}], decompress = {decompress}, "
         + (f" with field mask {field_mask}" if field_mask else "")
     )

@@ -110,15 +105,14 @@ def _h5_read_lgdo(
         idx = (idx,)

     try:
-        lgdotype = dtypeutils.datatype(h5f[name].attrs["datatype"])
+        lgdotype = dtypeutils.datatype(h5o.attrs["datatype"])
     except KeyError as e:
         msg = "dataset not in file or missing 'datatype' attribute"
-        raise LH5DecodeError(msg, h5f, name) from e
+        raise LH5DecodeError(msg, h5o) from e

     if lgdotype is Scalar:
         return _h5_read_scalar(
-            name,
-            h5f,
+            h5o,
             obj_buf=obj_buf,
         )

@@ -138,8 +132,7 @@ def _h5_read_lgdo(

     if lgdotype is Struct:
         return _h5_read_struct(
-            name,
-            h5f,
+            h5o,
             start_row=start_row,
             n_rows=n_rows,
             idx=idx,
@@ -164,8 +157,7 @@ def _h5_read_lgdo(

     if lgdotype is Table:
         return _h5_read_table(
-            name,
-            h5f,
+            h5o,
             start_row=start_row,
             n_rows=n_rows,
             idx=idx,
@@ -178,8 +170,7 @@ def _h5_read_lgdo(

     if lgdotype is ArrayOfEncodedEqualSizedArrays:
         return _h5_read_array_of_encoded_equalsized_arrays(
-            name,
-            h5f,
+            h5o,
             start_row=start_row,
             n_rows=n_rows,
             idx=idx,
@@ -191,8 +182,7 @@ def _h5_read_lgdo(

     if lgdotype is VectorOfEncodedVectors:
         return _h5_read_vector_of_encoded_vectors(
-            name,
-            h5f,
+            h5o,
             start_row=start_row,
             n_rows=n_rows,
             idx=idx,
@@ -204,8 +194,7 @@ def _h5_read_lgdo(

     if lgdotype is VectorOfVectors:
         return _h5_read_vector_of_vectors(
-            name,
-            h5f,
+            h5o,
             start_row=start_row,
             n_rows=n_rows,
             idx=idx,
@@ -216,8 +205,7 @@ def _h5_read_lgdo(

     if lgdotype is FixedSizeArray:
         return _h5_read_fixedsize_array(
-            name,
-            h5f,
+            h5o,
             start_row=start_row,
             n_rows=n_rows,
             idx=idx,
@@ -228,8 +216,7 @@ def _h5_read_lgdo(

     if lgdotype is ArrayOfEqualSizedArrays:
         return _h5_read_array_of_equalsized_arrays(
-            name,
-            h5f,
+            h5o,
             start_row=start_row,
             n_rows=n_rows,
             idx=idx,
@@ -240,8 +227,7 @@ def _h5_read_lgdo(

     if lgdotype is Array:
         return _h5_read_array(
-            name,
-            h5f,
+            h5o,
             start_row=start_row,
             n_rows=n_rows,
             idx=idx,
@@ -251,12 +237,11 @@ def _h5_read_lgdo(
         )

     msg = f"no rule to decode {lgdotype.__name__} from LH5"
-    raise LH5DecodeError(msg, h5f, name)
+    raise LH5DecodeError(msg, h5o)


 def _h5_read_struct(
-    name,
-    h5f,
+    h5g,
     start_row=0,
     n_rows=sys.maxsize,
     idx=None,
@@ -269,7 +254,7 @@ def _h5_read_struct(
     # table... Maybe should emit a warning? Or allow them to be
     # dicts keyed by field name?

-    attrs = dict(h5f[name].attrs)
+    attrs = dict(h5g.attrs)

     # determine fields to be read out
     all_fields = dtypeutils.get_struct_fields(attrs["datatype"])
@@ -288,8 +273,7 @@ def _h5_read_struct(
         # support for integer keys
         field_key = int(field) if attrs.get("int_keys") else str(field)
         obj_dict[field_key], _ = _h5_read_lgdo(
-            f"{name}/{field}",
-            h5f,
+            h5g[field],
             start_row=start_row,
             n_rows=n_rows,
             idx=idx,
@@ -301,8 +285,7 @@ def _h5_read_struct(


 def _h5_read_table(
-    name,
-    h5f,
+    h5g,
     start_row=0,
     n_rows=sys.maxsize,
     idx=None,
@@ -314,9 +297,9 @@ def _h5_read_table(
 ):
     if obj_buf is not None and not isinstance(obj_buf, Table):
         msg = "provided object buffer is not a Table"
-        raise LH5DecodeError(msg, h5f, name)
+        raise LH5DecodeError(msg, h5g)

-    attrs = dict(h5f[name].attrs)
+    attrs = dict(h5g.attrs)

     # determine fields to be read out
     all_fields = dtypeutils.get_struct_fields(attrs["datatype"])
@@ -337,13 +320,12 @@ def _h5_read_table(
         if obj_buf is not None:
             if not isinstance(obj_buf, Table) or field not in obj_buf:
                 msg = "provided object buffer is not a Table or columns are missing"
-                raise LH5DecodeError(msg, h5f, name)
+                raise LH5DecodeError(msg, h5g)

             fld_buf = obj_buf[field]

         col_dict[field], n_rows_read = _h5_read_lgdo(
-            f"{name}/{field}",
-            h5f,
+            h5g[field],
             start_row=start_row,
             n_rows=n_rows,
             idx=idx,
@@ -363,12 +345,12 @@ def _h5_read_table(
         n_rows_read = rows_read[0]
     else:
         n_rows_read = 0
-        log.warning(f"Table '{name}' has no fields specified by {field_mask=}")
+        log.warning(f"Table '{h5g.name}' has no fields specified by {field_mask=}")

     for n in rows_read[1:]:
         if n != n_rows_read:
             log.warning(
-                f"Table '{name}' got strange n_rows_read = {n}, "
+                f"Table '{h5g.name}' got strange n_rows_read = {n}, "
                 "{n_rows_read} was expected ({rows_read})"
             )

@@ -400,6 +382,6 @@ def _h5_read_table(
         obj_buf.loc = obj_buf_start + n_rows_read

     # check attributes
-    utils.check_obj_buf_attrs(obj_buf.attrs, attrs, h5f, name)
+    utils.check_obj_buf_attrs(obj_buf.attrs, attrs, h5g)

     return obj_buf, n_rows_read
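
The recursive list handling above splits one sorted, global index array across the per-file objects with bisect and shifts the remainder by each file's row count. A self-contained toy version of that splitting step (row counts and indices are invented):

    import bisect

    import numpy as np

    idx = np.array([1, 5, 42, 107, 260])  # hypothetical global selection, sorted
    for n_rows_i in (100, 100, 100):      # hypothetical rows per file
        n_to_read_i = bisect.bisect_left(idx, n_rows_i)
        idx_i, idx = idx[:n_to_read_i], idx[n_to_read_i:] - n_rows_i
        print(idx_i)                      # [1 5 42], then [7], then [60]
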
diff --git a/src/lgdo/lh5/_serializers/read/encoded.py b/src/lgdo/lh5/_serializers/read/encoded.py
index 0e63ebc8..23876c91 100644
--- a/src/lgdo/lh5/_serializers/read/encoded.py
+++ b/src/lgdo/lh5/_serializers/read/encoded.py
@@ -19,25 +19,22 @@


 def _h5_read_array_of_encoded_equalsized_arrays(
-    name,
-    h5f,
+    h5g,
     **kwargs,
 ):
-    return _h5_read_encoded_array(ArrayOfEncodedEqualSizedArrays, name, h5f, **kwargs)
+    return _h5_read_encoded_array(ArrayOfEncodedEqualSizedArrays, h5g, **kwargs)


 def _h5_read_vector_of_encoded_vectors(
-    name,
-    h5f,
+    h5g,
     **kwargs,
 ):
-    return _h5_read_encoded_array(VectorOfEncodedVectors, name, h5f, **kwargs)
+    return _h5_read_encoded_array(VectorOfEncodedVectors, h5g, **kwargs)


 def _h5_read_encoded_array(
     lgdotype,
-    name,
-    h5f,
+    h5g,
     start_row=0,
     n_rows=sys.maxsize,
     idx=None,
@@ -48,11 +45,11 @@ def _h5_read_encoded_array(
 ):
     if lgdotype not in (ArrayOfEncodedEqualSizedArrays, VectorOfEncodedVectors):
         msg = f"unsupported read of encoded type {lgdotype.__name__}"
-        raise LH5DecodeError(msg, h5f, name)
+        raise LH5DecodeError(msg, h5g)

     if not decompress and obj_buf is not None and not isinstance(obj_buf, lgdotype):
         msg = f"object buffer is not a {lgdotype.__name__}"
-        raise LH5DecodeError(msg, h5f, name)
+        raise LH5DecodeError(msg, h5g)

     # read out decoded_size, either a Scalar or an Array
     decoded_size_buf = encoded_data_buf = None
@@ -62,8 +59,7 @@ def _h5_read_encoded_array(

     if lgdotype is VectorOfEncodedVectors:
         decoded_size, _ = _h5_read_array(
-            f"{name}/decoded_size",
-            h5f,
+            h5g["decoded_size"],
             start_row=start_row,
             n_rows=n_rows,
             idx=idx,
@@ -74,15 +70,13 @@ def _h5_read_encoded_array(
     else:
         decoded_size, _ = _h5_read_scalar(
-            f"{name}/decoded_size",
-            h5f,
+            h5g["decoded_size"],
             obj_buf=None if decompress else decoded_size_buf,
         )

     # read out encoded_data, a VectorOfVectors
     encoded_data, n_rows_read = _h5_read_vector_of_vectors(
-        f"{name}/encoded_data",
-        h5f,
+        h5g["encoded_data"],
         start_row=start_row,
         n_rows=n_rows,
         idx=idx,
@@ -99,7 +93,7 @@ def _h5_read_encoded_array(
     rawdata = lgdotype(
         encoded_data=encoded_data,
         decoded_size=decoded_size,
-        attrs=h5f[name].attrs,
+        attrs=dict(h5g.attrs),
     )

     # already return if no decompression is requested
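
Both encoded readers now resolve their two members straight off the group handle (h5g["decoded_size"] and h5g["encoded_data"]). From the user's side this path is steered by the decompress flag forwarded down from the top-level reader; a hedged sketch, assuming lgdo.lh5.read exposes that flag and with file and object names made up:

    import lgdo.lh5 as lh5

    # keep the encoded representation (e.g. an ArrayOfEncodedEqualSizedArrays)
    enc = lh5.read("geds/raw/waveform/values", "data.lh5", decompress=False)

    # default: decode into the corresponding plain LGDO array type
    dec = lh5.read("geds/raw/waveform/values", "data.lh5")
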
diff --git a/src/lgdo/lh5/_serializers/read/ndarray.py b/src/lgdo/lh5/_serializers/read/ndarray.py
index 46571a32..3d626f70 100644
--- a/src/lgdo/lh5/_serializers/read/ndarray.py
+++ b/src/lgdo/lh5/_serializers/read/ndarray.py
@@ -14,8 +14,7 @@


 def _h5_read_ndarray(
-    name,
-    h5f,
+    h5d,
     start_row=0,
     n_rows=sys.maxsize,
     idx=None,
@@ -25,16 +24,16 @@ def _h5_read_ndarray(
 ):
     if obj_buf is not None and not isinstance(obj_buf, Array):
         msg = "object buffer is not an Array"
-        raise LH5DecodeError(msg, h5f, name)
+        raise LH5DecodeError(msg, h5d)

     # compute the number of rows to read
     # we culled idx above for start_row and n_rows, now we have to apply
     # the constraint of the length of the dataset
     try:
-        ds_n_rows = h5f[name].shape[0]
+        ds_n_rows = h5d.shape[0]
     except AttributeError as e:
         msg = "does not seem to be an HDF5 dataset"
-        raise LH5DecodeError(msg, h5f, name) from e
+        raise LH5DecodeError(msg, h5d) from e

     if idx is not None:
         if len(idx[0]) > 0 and idx[0][-1] >= ds_n_rows:
@@ -78,23 +77,23 @@ def _h5_read_ndarray(
         # this is required to make the read of multiple files faster
         # until a better solution found.
         if change_idx_to_slice or idx is None or use_h5idx:
-            h5f[name].read_direct(obj_buf.nda, source_sel, dest_sel)
+            h5d.read_direct(obj_buf.nda, source_sel, dest_sel)
         else:
             # it is faster to read the whole object and then do fancy indexing
-            obj_buf.nda[dest_sel] = h5f[name][...][source_sel]
+            obj_buf.nda[dest_sel] = h5d[...][source_sel]

         nda = obj_buf.nda
     elif n_rows == 0:
-        tmp_shape = (0,) + h5f[name].shape[1:]
-        nda = np.empty(tmp_shape, h5f[name].dtype)
+        tmp_shape = (0,) + h5d.shape[1:]
+        nda = np.empty(tmp_shape, h5d.dtype)
     elif change_idx_to_slice or idx is None or use_h5idx:
-        nda = h5f[name][source_sel]
+        nda = h5d[source_sel]
     else:
         # it is faster to read the whole object and then do fancy indexing
-        nda = h5f[name][...][source_sel]
+        nda = h5d[...][source_sel]

     # Finally, set attributes and return objects
-    attrs = h5f[name].attrs
+    attrs = dict(h5d.attrs)

     # special handling for bools
     # (c and Julia store as uint8 so cast to bool)
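
The two branches above are standard h5py access patterns: a contiguous selection can be streamed into a preallocated buffer with Dataset.read_direct, while a scattered selection is often faster to apply in memory after reading the whole dataset. A small sketch (file and dataset names are invented):

    import h5py
    import numpy as np

    with h5py.File("data.lh5", "r") as f:
        h5d = f["evt/energy"]

        # contiguous rows: read directly into an existing buffer
        buf = np.empty(100, dtype=h5d.dtype)
        h5d.read_direct(buf, source_sel=np.s_[0:100], dest_sel=np.s_[0:100])

        # scattered rows: read everything, then fancy-index in memory
        sel = np.array([3, 17, 42])
        vals = h5d[...][sel]
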
diff --git a/src/lgdo/lh5/_serializers/read/scalar.py b/src/lgdo/lh5/_serializers/read/scalar.py
index 1eb02e43..db5ed5fe 100644
--- a/src/lgdo/lh5/_serializers/read/scalar.py
+++ b/src/lgdo/lh5/_serializers/read/scalar.py
@@ -11,24 +11,24 @@


 def _h5_read_scalar(
-    name,
-    h5f,
+    h5d,
     obj_buf=None,
 ):
-    value = h5f[name][()]
+    value = h5d[()]
+    attrs = dict(h5d.attrs)

     # special handling for bools
     # (c and Julia store as uint8 so cast to bool)
-    if h5f[name].attrs["datatype"] == "bool":
+    if attrs["datatype"] == "bool":
         value = np.bool_(value)

     if obj_buf is not None:
         if not isinstance(obj_buf, Scalar):
             msg = "object buffer a Scalar"
-            raise LH5DecodeError(msg, h5f, name)
+            raise LH5DecodeError(msg, h5d)

         obj_buf.value = value
-        obj_buf.attrs.update(h5f[name].attrs)
+        obj_buf.attrs.update(attrs)
         return obj_buf, 1

-    return Scalar(value=value, attrs=h5f[name].attrs), 1
+    return Scalar(value=value, attrs=attrs), 1
diff --git a/src/lgdo/lh5/_serializers/read/utils.py b/src/lgdo/lh5/_serializers/read/utils.py
index 923b2256..7f675fd7 100644
--- a/src/lgdo/lh5/_serializers/read/utils.py
+++ b/src/lgdo/lh5/_serializers/read/utils.py
@@ -3,10 +3,10 @@
 from ...exceptions import LH5DecodeError


-def check_obj_buf_attrs(attrs, new_attrs, file, name):
+def check_obj_buf_attrs(attrs, new_attrs, obj):
     if set(attrs.keys()) != set(new_attrs.keys()):
         msg = (
             f"existing buffer and new data chunk have different attributes: "
-            f"obj_buf.attrs={attrs} != {file.filename}[{name}].attrs={new_attrs}"
+            f"obj_buf.attrs={attrs} != {obj.file.filename}[{obj.name}].attrs={new_attrs}"
         )
-        raise LH5DecodeError(msg, file, name)
+        raise LH5DecodeError(msg, obj)
diff --git a/src/lgdo/lh5/_serializers/read/vector_of_vectors.py b/src/lgdo/lh5/_serializers/read/vector_of_vectors.py
index 1a699527..4c4c7655 100644
--- a/src/lgdo/lh5/_serializers/read/vector_of_vectors.py
+++ b/src/lgdo/lh5/_serializers/read/vector_of_vectors.py
@@ -20,8 +20,7 @@


 def _h5_read_vector_of_vectors(
-    name,
-    h5f,
+    h5g,
     start_row=0,
     n_rows=sys.maxsize,
     idx=None,
@@ -31,13 +30,12 @@ def _h5_read_vector_of_vectors(
 ):
     if obj_buf is not None and not isinstance(obj_buf, VectorOfVectors):
         msg = "object buffer is not a VectorOfVectors"
-        raise LH5DecodeError(msg, h5f, name)
+        raise LH5DecodeError(msg, h5g)

     # read out cumulative_length
     cumulen_buf = None if obj_buf is None else obj_buf.cumulative_length
     cumulative_length, n_rows_read = _h5_read_array(
-        f"{name}/cumulative_length",
-        h5f,
+        h5g["cumulative_length"],
         start_row=start_row,
         n_rows=n_rows,
         idx=idx,
@@ -63,8 +61,7 @@ def _h5_read_vector_of_vectors(

         fd_start = 0  # this variable avoids an ndarray append
         fd_starts, fds_n_rows_read = _h5_read_array(
-            f"{name}/cumulative_length",
-            h5f,
+            h5g["cumulative_length"],
             start_row=start_row,
             n_rows=n_rows,
             idx=idx2,
@@ -101,7 +98,7 @@ def _h5_read_vector_of_vectors(
             # need to read out the cumulen sample -before- the first sample
             # read above in order to get the starting row of the first
             # vector to read out in flattened_data
-            fd_start = h5f[f"{name}/cumulative_length"][start_row - 1]
+            fd_start = h5g["cumulative_length"][start_row - 1]

             # check limits for values that will be used subsequently
             if this_cumulen_nda[-1] < fd_start:
@@ -115,7 +112,7 @@ def _h5_read_vector_of_vectors(
                     f"cumulative_length non-increasing between entries "
                     f"{start_row} and {start_row+n_rows_read}"
                 )
-                raise LH5DecodeError(msg, h5f, name)
+                raise LH5DecodeError(msg, h5g)

         # determine the number of rows for the flattened_data readout
         fd_n_rows = this_cumulen_nda[-1] if n_rows_read > 0 else 0
@@ -147,18 +144,17 @@ def _h5_read_vector_of_vectors(
             fd_buf.resize(fdb_size)

     # now read
-    lgdotype = dtypeutils.datatype(h5f[f"{name}/flattened_data"].attrs["datatype"])
+    lgdotype = dtypeutils.datatype(h5g["flattened_data"].attrs["datatype"])
     if lgdotype is Array:
         _func = _h5_read_array
     elif lgdotype is VectorOfVectors:
         _func = _h5_read_vector_of_vectors
     else:
         msg = "type {lgdotype.__name__} is not supported"
-        raise LH5DecodeError(msg, h5f, f"{name}/flattened_data")
+        raise LH5DecodeError(msg, h5g["flattened_data"])

     flattened_data, _ = _func(
-        f"{name}/flattened_data",
-        h5f,
+        h5g["flattened_data"],
         start_row=fd_start,
         n_rows=fd_n_rows,
         idx=fd_idx,
@@ -180,7 +176,7 @@ def _h5_read_vector_of_vectors(
         VectorOfVectors(
             flattened_data=flattened_data,
             cumulative_length=cumulative_length,
-            attrs=h5f[name].attrs,
+            attrs=dict(h5g.attrs),
         ),
         n_rows_read,
     )
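
The flattened_data bookkeeping in the vector-of-vectors reader follows from the cumulative_length convention: entry i is the end offset of row i in flattened_data. A toy illustration of how fd_start and the readout range fall out of it (values invented):

    import numpy as np

    # cumulative_length of a VectorOfVectors whose rows have lengths 2, 3, 0, 4
    cumulative_length = np.array([2, 5, 5, 9])

    start_row, n_rows = 1, 2
    fd_start = cumulative_length[start_row - 1] if start_row > 0 else 0  # 2
    fd_stop = cumulative_length[start_row + n_rows - 1]                  # 5
    # rows 1 and 2 of the ragged array live in flattened_data[2:5]
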
diff --git a/src/lgdo/lh5/core.py b/src/lgdo/lh5/core.py
index c8a2888d..a98081ca 100644
--- a/src/lgdo/lh5/core.py
+++ b/src/lgdo/lh5/core.py
@@ -107,9 +107,20 @@ def read(
         `n_rows_read` will be``1``. For tables it is redundant with
         ``table.loc``. If `obj_buf` is ``None``, only `object` is returned.
     """
+    if isinstance(lh5_file, h5py.File):
+        lh5_obj = lh5_file[name]
+    elif isinstance(lh5_file, str):
+        lh5_file = h5py.File(lh5_file, mode="r")
+        lh5_obj = lh5_file[name]
+    else:
+        lh5_obj = []
+        for h5f in lh5_file:
+            if isinstance(h5f, str):
+                h5f = h5py.File(h5f, mode="r")  # noqa: PLW2901
+            lh5_obj.append(h5f[name])
+
     obj, n_rows_read = _serializers._h5_read_lgdo(
-        name,
-        lh5_file,
+        lh5_obj,
         start_row=start_row,
         n_rows=n_rows,
         idx=idx,
diff --git a/src/lgdo/lh5/exceptions.py b/src/lgdo/lh5/exceptions.py
index ba64e213..fc290691 100644
--- a/src/lgdo/lh5/exceptions.py
+++ b/src/lgdo/lh5/exceptions.py
@@ -4,11 +4,11 @@


 class LH5DecodeError(Exception):
-    def __init__(self, message: str, file: str, obj: str) -> None:
+    def __init__(self, message: str, obj: h5py.Dataset | h5py.Group) -> None:
         super().__init__(message)

-        self.file = file.filename if isinstance(file, h5py.File) else file
-        self.obj = obj
+        self.file = obj.file.filename
+        self.obj = obj.name

     def __str__(self) -> str:
         return (
diff --git a/src/lgdo/lh5/store.py b/src/lgdo/lh5/store.py
index 2be17860..b4c2e873 100644
--- a/src/lgdo/lh5/store.py
+++ b/src/lgdo/lh5/store.py
@@ -144,13 +144,12 @@ def read(
         """
         # grab files from store
         if not isinstance(lh5_file, (str, h5py.File)):
-            lh5_file = [self.gimme_file(f, "r") for f in list(lh5_file)]
+            lh5_obj = [self.gimme_file(f, "r")[name] for f in list(lh5_file)]
         else:
-            lh5_file = self.gimme_file(lh5_file, "r")
+            lh5_obj = self.gimme_file(lh5_file, "r")[name]

         return _serializers._h5_read_lgdo(
-            name,
-            lh5_file,
+            lh5_obj,
             start_row=start_row,
             n_rows=n_rows,
             idx=idx,
diff --git a/src/lgdo/lh5/utils.py b/src/lgdo/lh5/utils.py
index cf1fed04..73c192e5 100644
--- a/src/lgdo/lh5/utils.py
+++ b/src/lgdo/lh5/utils.py
@@ -29,7 +29,7 @@ def get_buffer(
     Sets size to `size` if object has a size.
     """
     obj, n_rows = _serializers._h5_read_lgdo(
-        name, lh5_file, n_rows=0, field_mask=field_mask
+        lh5_file[name], n_rows=0, field_mask=field_mask
     )

     if hasattr(obj, "resize") and size is not None: