Commit 5fbe2f2

Replace file/name with hdf5 group/dataset when decoding (#97)

iguinn authored Jul 8, 2024
1 parent d050bf7 commit 5fbe2f2

Showing 11 changed files with 101 additions and 120 deletions.
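The gist of the change: the low-level read functions no longer take a (name, h5f) pair and index into the file themselves; the caller resolves the h5py.Group or h5py.Dataset once and passes that handle down. A minimal sketch of the new calling convention, assuming a hypothetical file "data.lh5" and a hypothetical object path "geds/raw" (_h5_read_lgdo is the internal entry point changed below):

import h5py

from lgdo.lh5._serializers.read.composite import _h5_read_lgdo

with h5py.File("data.lh5", "r") as h5f:
    # before this commit: _h5_read_lgdo("geds/raw", h5f)
    # after: pass the resolved HDF5 object directly
    obj, n_rows = _h5_read_lgdo(h5f["geds/raw"])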
18 changes: 9 additions & 9 deletions src/lgdo/lh5/_serializers/read/array.py
@@ -9,26 +9,26 @@
 log = logging.getLogger(__name__)


-def _h5_read_array_generic(type_, name, h5f, **kwargs):
-    nda, attrs, n_rows_to_read = _h5_read_ndarray(name, h5f, **kwargs)
+def _h5_read_array_generic(type_, h5d, **kwargs):
+    nda, attrs, n_rows_to_read = _h5_read_ndarray(h5d, **kwargs)

     obj_buf = kwargs["obj_buf"]

     if obj_buf is None:
         return type_(nda=nda, attrs=attrs), n_rows_to_read

-    utils.check_obj_buf_attrs(obj_buf.attrs, attrs, h5f, name)
+    utils.check_obj_buf_attrs(obj_buf.attrs, attrs, h5d)

     return obj_buf, n_rows_to_read


-def _h5_read_array(name, h5f, **kwargs):
-    return _h5_read_array_generic(Array, name, h5f, **kwargs)
+def _h5_read_array(h5d, **kwargs):
+    return _h5_read_array_generic(Array, h5d, **kwargs)


-def _h5_read_fixedsize_array(name, h5f, **kwargs):
-    return _h5_read_array_generic(FixedSizeArray, name, h5f, **kwargs)
+def _h5_read_fixedsize_array(h5d, **kwargs):
+    return _h5_read_array_generic(FixedSizeArray, h5d, **kwargs)


-def _h5_read_array_of_equalsized_arrays(name, h5f, **kwargs):
-    return _h5_read_array_generic(ArrayOfEqualSizedArrays, name, h5f, **kwargs)
+def _h5_read_array_of_equalsized_arrays(h5d, **kwargs):
+    return _h5_read_array_generic(ArrayOfEqualSizedArrays, h5d, **kwargs)
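All three readers above delegate to _h5_read_array_generic and differ only in the LGDO class they instantiate. A usage sketch under stated assumptions (the file name and dataset path are hypothetical; obj_buf is passed explicitly because the generic helper looks it up via kwargs["obj_buf"], and the remaining read parameters are assumed to have defaults):

import h5py

from lgdo.lh5._serializers.read.array import _h5_read_array

with h5py.File("data.lh5", "r") as h5f:  # hypothetical file
    h5d = h5f["evt/energy"]              # hypothetical dataset
    arr, n_rows = _h5_read_array(h5d, obj_buf=None)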
78 changes: 30 additions & 48 deletions src/lgdo/lh5/_serializers/read/composite.py
@@ -40,8 +40,7 @@


 def _h5_read_lgdo(
-    name,
-    h5f,
+    h5o,
     start_row=0,
     n_rows=sys.maxsize,
     idx=None,
@@ -52,11 +51,11 @@ def _h5_read_lgdo(
     decompress=True,
 ):
     # Handle list-of-files recursively
-    if not isinstance(h5f, (str, h5py.File)):
-        lh5_file = list(h5f)
+    if not isinstance(h5o, (h5py.Group, h5py.Dataset)):
+        lh5_objs = list(h5o)
         n_rows_read = 0

-        for i, _h5f in enumerate(lh5_file):
+        for i, _h5o in enumerate(lh5_objs):
             if isinstance(idx, list) and len(idx) > 0 and not np.isscalar(idx[0]):
                 # a list of lists: must be one per file
                 idx_i = idx[i]
@@ -65,7 +64,7 @@
                 if not (isinstance(idx, tuple) and len(idx) == 1):
                     idx = (idx,)
                 # idx is a long continuous array
-                n_rows_i = read_n_rows(name, _h5f)
+                n_rows_i = read_n_rows(_h5o)
                 # find the length of the subset of idx that contains indices
                 # that are less than n_rows_i
                 n_rows_to_read_i = bisect.bisect_left(idx[0], n_rows_i)
@@ -77,8 +76,7 @@
                 n_rows_i = n_rows - n_rows_read

             obj_buf, n_rows_read_i = _h5_read_lgdo(
-                name,
-                _h5f,
+                _h5o,
                 start_row=start_row,
                 n_rows=n_rows_i,
                 idx=idx_i,
@@ -97,11 +95,8 @@

         return obj_buf, n_rows_read

-    if not isinstance(h5f, h5py.File):
-        h5f = h5py.File(h5f, mode="r")
-
     log.debug(
-        f"reading {h5f.filename}:{name}[{start_row}:{n_rows}], decompress = {decompress}, "
+        f"reading {h5o.file.filename}:{h5o.name}[{start_row}:{n_rows}], decompress = {decompress}, "
         + (f" with field mask {field_mask}" if field_mask else "")
     )
@@ -110,15 +105,14 @@ def _h5_read_lgdo(
         idx = (idx,)

     try:
-        lgdotype = dtypeutils.datatype(h5f[name].attrs["datatype"])
+        lgdotype = dtypeutils.datatype(h5o.attrs["datatype"])
     except KeyError as e:
         msg = "dataset not in file or missing 'datatype' attribute"
-        raise LH5DecodeError(msg, h5f, name) from e
+        raise LH5DecodeError(msg, h5o) from e

     if lgdotype is Scalar:
         return _h5_read_scalar(
-            name,
-            h5f,
+            h5o,
             obj_buf=obj_buf,
         )
@@ -138,8 +132,7 @@

     if lgdotype is Struct:
         return _h5_read_struct(
-            name,
-            h5f,
+            h5o,
             start_row=start_row,
             n_rows=n_rows,
             idx=idx,
@@ -164,8 +157,7 @@

     if lgdotype is Table:
         return _h5_read_table(
-            name,
-            h5f,
+            h5o,
             start_row=start_row,
             n_rows=n_rows,
             idx=idx,
@@ -178,8 +170,7 @@

     if lgdotype is ArrayOfEncodedEqualSizedArrays:
         return _h5_read_array_of_encoded_equalsized_arrays(
-            name,
-            h5f,
+            h5o,
             start_row=start_row,
             n_rows=n_rows,
             idx=idx,
@@ -191,8 +182,7 @@

     if lgdotype is VectorOfEncodedVectors:
         return _h5_read_vector_of_encoded_vectors(
-            name,
-            h5f,
+            h5o,
             start_row=start_row,
             n_rows=n_rows,
             idx=idx,
@@ -204,8 +194,7 @@

     if lgdotype is VectorOfVectors:
         return _h5_read_vector_of_vectors(
-            name,
-            h5f,
+            h5o,
             start_row=start_row,
             n_rows=n_rows,
             idx=idx,
@@ -216,8 +205,7 @@

     if lgdotype is FixedSizeArray:
         return _h5_read_fixedsize_array(
-            name,
-            h5f,
+            h5o,
             start_row=start_row,
             n_rows=n_rows,
             idx=idx,
@@ -228,8 +216,7 @@

     if lgdotype is ArrayOfEqualSizedArrays:
         return _h5_read_array_of_equalsized_arrays(
-            name,
-            h5f,
+            h5o,
             start_row=start_row,
             n_rows=n_rows,
             idx=idx,
@@ -240,8 +227,7 @@

     if lgdotype is Array:
         return _h5_read_array(
-            name,
-            h5f,
+            h5o,
             start_row=start_row,
             n_rows=n_rows,
             idx=idx,
@@ -251,12 +237,11 @@
     )

     msg = f"no rule to decode {lgdotype.__name__} from LH5"
-    raise LH5DecodeError(msg, h5f, name)
+    raise LH5DecodeError(msg, h5o)


 def _h5_read_struct(
-    name,
-    h5f,
+    h5g,
     start_row=0,
     n_rows=sys.maxsize,
     idx=None,
@@ -269,7 +254,7 @@ def _h5_read_struct(
     # table... Maybe should emit a warning? Or allow them to be
     # dicts keyed by field name?

-    attrs = dict(h5f[name].attrs)
+    attrs = dict(h5g.attrs)

     # determine fields to be read out
     all_fields = dtypeutils.get_struct_fields(attrs["datatype"])
@@ -288,8 +273,7 @@
         # support for integer keys
         field_key = int(field) if attrs.get("int_keys") else str(field)
         obj_dict[field_key], _ = _h5_read_lgdo(
-            f"{name}/{field}",
-            h5f,
+            h5g[field],
             start_row=start_row,
             n_rows=n_rows,
             idx=idx,
@@ -301,8 +285,7 @@


 def _h5_read_table(
-    name,
-    h5f,
+    h5g,
     start_row=0,
     n_rows=sys.maxsize,
     idx=None,
@@ -314,9 +297,9 @@
 ):
     if obj_buf is not None and not isinstance(obj_buf, Table):
         msg = "provided object buffer is not a Table"
-        raise LH5DecodeError(msg, h5f, name)
+        raise LH5DecodeError(msg, h5g)

-    attrs = dict(h5f[name].attrs)
+    attrs = dict(h5g.attrs)

     # determine fields to be read out
     all_fields = dtypeutils.get_struct_fields(attrs["datatype"])
@@ -337,13 +320,12 @@
         if obj_buf is not None:
             if not isinstance(obj_buf, Table) or field not in obj_buf:
                 msg = "provided object buffer is not a Table or columns are missing"
-                raise LH5DecodeError(msg, h5f, name)
+                raise LH5DecodeError(msg, h5g)

             fld_buf = obj_buf[field]

         col_dict[field], n_rows_read = _h5_read_lgdo(
-            f"{name}/{field}",
-            h5f,
+            h5g[field],
             start_row=start_row,
             n_rows=n_rows,
             idx=idx,
@@ -363,12 +345,12 @@
         n_rows_read = rows_read[0]
     else:
         n_rows_read = 0
-        log.warning(f"Table '{name}' has no fields specified by {field_mask=}")
+        log.warning(f"Table '{h5g.name}' has no fields specified by {field_mask=}")

     for n in rows_read[1:]:
         if n != n_rows_read:
             log.warning(
-                f"Table '{name}' got strange n_rows_read = {n}, "
+                f"Table '{h5g.name}' got strange n_rows_read = {n}, "
                 "{n_rows_read} was expected ({rows_read})"
             )
@@ -400,6 +382,6 @@ def _h5_read_table(
         obj_buf.loc = obj_buf_start + n_rows_read

     # check attributes
-    utils.check_obj_buf_attrs(obj_buf.attrs, attrs, h5f, name)
+    utils.check_obj_buf_attrs(obj_buf.attrs, attrs, h5g)

     return obj_buf, n_rows_read
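When _h5_read_lgdo above receives a list of objects instead of a single group or dataset, it recurses over them and splits one long, sorted idx array across consecutive objects with bisect.bisect_left, as in the second hunk. A worked sketch of that splitting step with hypothetical values (the rebasing of leftover indices for the next object is assumed from the collapsed context):

import bisect

idx = [0, 5, 9, 12, 30]  # hypothetical global row indices, sorted
n_rows_i = 10            # rows available in the current object

# indices below n_rows_i are served by the current object
n_rows_to_read_i = bisect.bisect_left(idx, n_rows_i)
idx_i = idx[:n_rows_to_read_i]                        # -> [0, 5, 9]

# leftover indices are rebased and handed to the next object
idx = [j - n_rows_i for j in idx[n_rows_to_read_i:]]  # -> [2, 20]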
28 changes: 11 additions & 17 deletions src/lgdo/lh5/_serializers/read/encoded.py
@@ -19,25 +19,22 @@


 def _h5_read_array_of_encoded_equalsized_arrays(
-    name,
-    h5f,
+    h5g,
     **kwargs,
 ):
-    return _h5_read_encoded_array(ArrayOfEncodedEqualSizedArrays, name, h5f, **kwargs)
+    return _h5_read_encoded_array(ArrayOfEncodedEqualSizedArrays, h5g, **kwargs)


 def _h5_read_vector_of_encoded_vectors(
-    name,
-    h5f,
+    h5g,
     **kwargs,
 ):
-    return _h5_read_encoded_array(VectorOfEncodedVectors, name, h5f, **kwargs)
+    return _h5_read_encoded_array(VectorOfEncodedVectors, h5g, **kwargs)


 def _h5_read_encoded_array(
     lgdotype,
-    name,
-    h5f,
+    h5g,
     start_row=0,
     n_rows=sys.maxsize,
     idx=None,
@@ -48,11 +45,11 @@ def _h5_read_encoded_array(
 ):
     if lgdotype not in (ArrayOfEncodedEqualSizedArrays, VectorOfEncodedVectors):
         msg = f"unsupported read of encoded type {lgdotype.__name__}"
-        raise LH5DecodeError(msg, h5f, name)
+        raise LH5DecodeError(msg, h5g)

     if not decompress and obj_buf is not None and not isinstance(obj_buf, lgdotype):
         msg = f"object buffer is not a {lgdotype.__name__}"
-        raise LH5DecodeError(msg, h5f, name)
+        raise LH5DecodeError(msg, h5g)

     # read out decoded_size, either a Scalar or an Array
     decoded_size_buf = encoded_data_buf = None
@@ -62,8 +59,7 @@

     if lgdotype is VectorOfEncodedVectors:
         decoded_size, _ = _h5_read_array(
-            f"{name}/decoded_size",
-            h5f,
+            h5g["decoded_size"],
             start_row=start_row,
             n_rows=n_rows,
             idx=idx,
@@ -74,15 +70,13 @@

     else:
         decoded_size, _ = _h5_read_scalar(
-            f"{name}/decoded_size",
-            h5f,
+            h5g["decoded_size"],
             obj_buf=None if decompress else decoded_size_buf,
         )

     # read out encoded_data, a VectorOfVectors
     encoded_data, n_rows_read = _h5_read_vector_of_vectors(
-        f"{name}/encoded_data",
-        h5f,
+        h5g["encoded_data"],
         start_row=start_row,
         n_rows=n_rows,
         idx=idx,
@@ -99,7 +93,7 @@
     rawdata = lgdotype(
         encoded_data=encoded_data,
         decoded_size=decoded_size,
-        attrs=h5f[name].attrs,
+        attrs=dict(h5g.attrs),
     )

     # already return if no decompression is requested
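The encoded-array readers above now subscript the group directly for its two children instead of concatenating strings onto name. A sketch of the HDF5 layout they assume, with hypothetical file and group paths:

import h5py

with h5py.File("data.lh5", "r") as h5f:  # hypothetical file
    h5g = h5f["ch0/waveform/values"]     # hypothetical encoded-array group
    decoded_size = h5g["decoded_size"]   # Scalar or Array, read above
    encoded_data = h5g["encoded_data"]   # VectorOfVectors, read above
    datatype = dict(h5g.attrs).get("datatype")  # LGDO datatype attribute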