diff --git a/src/lgdo/compression/generic.py b/src/lgdo/compression/generic.py index 07e038e8..82dbb81b 100644 --- a/src/lgdo/compression/generic.py +++ b/src/lgdo/compression/generic.py @@ -40,6 +40,7 @@ def encode( def decode( obj: lgdo.VectorOfEncodedVectors | lgdo.ArrayOfEncodedEqualSizedArrays, + out_buf: lgdo.ArrayOfEqualSizedArrays = None, ) -> lgdo.VectorOfVectors | lgdo.ArrayOfEqualsizedArrays: """Decode encoded LGDOs. @@ -51,6 +52,9 @@ def decode( ---------- obj LGDO array type. + out_buf + pre-allocated LGDO for the decoded signals. See documentation of + wrapped encoders for limitations. """ if "codec" not in obj.attrs: raise RuntimeError( @@ -61,9 +65,11 @@ def decode( log.debug(f"decoding {repr(obj)} with {codec}") if _is_codec(codec, radware.RadwareSigcompress): - return radware.decode(obj, shift=int(obj.attrs.get("codec_shift", 0))) + return radware.decode( + obj, sig_out=out_buf, shift=int(obj.attrs.get("codec_shift", 0)) + ) elif _is_codec(codec, varlen.ULEB128ZigZagDiff): - return varlen.decode(obj) + return varlen.decode(obj, sig_out=out_buf) else: raise ValueError(f"'{codec}' not supported") diff --git a/src/lgdo/compression/radware.py b/src/lgdo/compression/radware.py index f8235d7e..f7f9bca0 100644 --- a/src/lgdo/compression/radware.py +++ b/src/lgdo/compression/radware.py @@ -120,7 +120,7 @@ def encode( return sig_out, nbytes elif isinstance(sig_in, lgdo.VectorOfVectors): - if sig_out: + if sig_out is not None: log.warning( "a pre-allocated VectorOfEncodedVectors was given " "to hold an encoded ArrayOfEqualSizedArrays. " @@ -143,7 +143,7 @@ def encode( return sig_out elif isinstance(sig_in, lgdo.ArrayOfEqualSizedArrays): - if sig_out: + if sig_out is not None: log.warning( "a pre-allocated ArrayOfEncodedEqualSizedArrays was given " "to hold an encoded ArrayOfEqualSizedArrays. " @@ -243,7 +243,7 @@ def decode( return sig_out, siglen elif isinstance(sig_in, lgdo.ArrayOfEncodedEqualSizedArrays): - if not sig_out: + if sig_out is None: # initialize output structure with decoded_size sig_out = lgdo.ArrayOfEqualSizedArrays( dims=(1, 1), diff --git a/src/lgdo/compression/varlen.py b/src/lgdo/compression/varlen.py index e3a4846e..f273e038 100644 --- a/src/lgdo/compression/varlen.py +++ b/src/lgdo/compression/varlen.py @@ -94,7 +94,7 @@ def encode( return sig_out, nbytes elif isinstance(sig_in, lgdo.VectorOfVectors): - if sig_out: + if sig_out is not None: log.warning( "a pre-allocated VectorOfEncodedVectors was given " "to hold an encoded ArrayOfEqualSizedArrays. " @@ -208,7 +208,7 @@ def decode( return sig_out, siglen elif isinstance(sig_in, lgdo.ArrayOfEncodedEqualSizedArrays): - if not sig_out: + if sig_out is None: # initialize output structure with decoded_size sig_out = lgdo.ArrayOfEqualSizedArrays( dims=(1, 1), diff --git a/tests/compression/test_radware_sigcompress.py b/tests/compression/test_radware_sigcompress.py index b466e390..ac634c36 100644 --- a/tests/compression/test_radware_sigcompress.py +++ b/tests/compression/test_radware_sigcompress.py @@ -107,8 +107,9 @@ def test_wrapper(wftable): enc_wfs = np.zeros(s[:-1] + (2 * s[-1],), dtype="ubyte") enclen = np.empty(s[0], dtype="uint32") + _shift = np.full(s[0], shift, dtype="int32") - _radware_sigcompress_encode(wfs, enc_wfs, shift, enclen, _mask) + _radware_sigcompress_encode(wfs, enc_wfs, _shift, enclen, _mask) # test if the wrapper gives the same result w_enc_wfs, w_enclen = encode(wfs, shift=shift)