Skip to content

Commit

Permalink
Merge pull request #22 from legend-exp/compression
Browse files Browse the repository at this point in the history
Turn Numba cache on for compression routines
  • Loading branch information
gipert authored Oct 27, 2023
2 parents 74048fd + 488fe4b commit f71afff
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 16 deletions.
2 changes: 2 additions & 0 deletions src/lgdo/compression/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import re
from dataclasses import asdict, dataclass

numba_defaults: dict = {"nopython": True, "cache": True}


@dataclass(frozen=True)
class WaveformCodec:
Expand Down
18 changes: 9 additions & 9 deletions src/lgdo/compression/radware.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from numpy.typing import NDArray

from .. import types as lgdo
from .base import WaveformCodec
from .base import WaveformCodec, numba_defaults

log = logging.getLogger(__name__)

Expand Down Expand Up @@ -301,7 +301,7 @@ def decode(
raise ValueError("unsupported input signal type")


@numba.jit(nopython=True)
@numba.jit(**numba_defaults)
def _set_hton_u16(a: NDArray[ubyte], i: int, x: int) -> int:
"""Store an unsigned 16-bit integer value in an array of unsigned 8-bit integers.
Expand All @@ -316,7 +316,7 @@ def _set_hton_u16(a: NDArray[ubyte], i: int, x: int) -> int:
return x


@numba.jit(nopython=True)
@numba.jit(**numba_defaults)
def _get_hton_u16(a: NDArray[ubyte], i: int) -> uint16:
"""Read unsigned 16-bit integer values from an array of unsigned 8-bit integers.
Expand All @@ -331,22 +331,22 @@ def _get_hton_u16(a: NDArray[ubyte], i: int) -> uint16:
return a[..., i_1].astype("uint16") << 8 | a[..., i_2]


@numba.jit("uint16(uint32)", nopython=True)
@numba.jit("uint16(uint32)", **numba_defaults)
def _get_high_u16(x: uint32) -> uint16:
return uint16(x >> 16)


@numba.jit("uint32(uint32, uint16)", nopython=True)
@numba.jit("uint32(uint32, uint16)", **numba_defaults)
def _set_high_u16(x: uint32, y: uint16) -> uint32:
return uint32(x & 0x0000FFFF | (y << 16))


@numba.jit("uint16(uint32)", nopython=True)
@numba.jit("uint16(uint32)", **numba_defaults)
def _get_low_u16(x: uint32) -> uint16:
return uint16(x >> 0)


@numba.jit("uint32(uint32, uint16)", nopython=True)
@numba.jit("uint32(uint32, uint16)", **numba_defaults)
def _set_low_u16(x: uint32, y: uint16) -> uint32:
return uint32(x & 0xFFFF0000 | (y << 0))

Expand All @@ -361,7 +361,7 @@ def _set_low_u16(x: uint32, y: uint16) -> uint32:
"void( int64[:], byte[:], int32[:], uint32[:], uint16[:])",
],
"(n),(m),(),(),(o)",
nopython=True,
**numba_defaults,
)
def _radware_sigcompress_encode(
sig_in: NDArray,
Expand Down Expand Up @@ -574,7 +574,7 @@ def _radware_sigcompress_encode(
"void(byte[:], int64[:], int32[:], uint32[:], uint16[:])",
],
"(n),(m),(),(),(o)",
nopython=True,
**numba_defaults,
)
def _radware_sigcompress_decode(
sig_in: NDArray[ubyte],
Expand Down
14 changes: 7 additions & 7 deletions src/lgdo/compression/varlen.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from numpy.typing import NDArray

from .. import types as lgdo
from .base import WaveformCodec
from .base import WaveformCodec, numba_defaults

log = logging.getLogger(__name__)

Expand Down Expand Up @@ -266,7 +266,7 @@ def decode(

@numba.vectorize(
["uint64(int64)", "uint32(int32)", "uint16(int16)"],
nopython=True,
**numba_defaults,
)
def zigzag_encode(x: int | NDArray[int]) -> int | NDArray[int]:
"""ZigZag-encode [#WikiZZ]_ signed integer numbers."""
Expand All @@ -275,14 +275,14 @@ def zigzag_encode(x: int | NDArray[int]) -> int | NDArray[int]:

@numba.vectorize(
["int64(uint64)", "int32(uint32)", "int16(uint16)"],
nopython=True,
**numba_defaults,
)
def zigzag_decode(x: int | NDArray[int]) -> int | NDArray[int]:
"""ZigZag-decode [#WikiZZ]_ signed integer numbers."""
return (x >> 1) ^ -(x & 1)


@numba.jit(["uint32(int64, byte[:])"], nopython=True)
@numba.jit(["uint32(int64, byte[:])"], **numba_defaults)
def uleb128_encode(x: int, encx: NDArray[ubyte]) -> int:
"""Compute a variable-length representation of an unsigned integer.
Expand Down Expand Up @@ -315,7 +315,7 @@ def uleb128_encode(x: int, encx: NDArray[ubyte]) -> int:
return i + 1


@numba.jit(["UniTuple(uint32, 2)(byte[:])"], nopython=True)
@numba.jit(["UniTuple(uint32, 2)(byte[:])"], **numba_defaults)
def uleb128_decode(encx: NDArray[ubyte]) -> (int, int):
"""Decode a variable-length integer into an unsigned integer.
Expand Down Expand Up @@ -360,7 +360,7 @@ def uleb128_decode(encx: NDArray[ubyte]) -> (int, int):
"void(int64[:], byte[:], uint32[:])",
],
"(n),(m),()",
nopython=True,
**numba_defaults,
)
def uleb128_zigzag_diff_array_encode(
sig_in: NDArray[int], sig_out: NDArray[ubyte], nbytes: int
Expand Down Expand Up @@ -410,7 +410,7 @@ def uleb128_zigzag_diff_array_encode(
"void(byte[:], uint32[:], int64[:], uint32[:])",
],
"(n),(),(m),()",
nopython=True,
**numba_defaults,
)
def uleb128_zigzag_diff_array_decode(
sig_in: NDArray[ubyte],
Expand Down

0 comments on commit f71afff

Please sign in to comment.