From 488fe4bdd9663008e2f8eae36feb15546f056db5 Mon Sep 17 00:00:00 2001 From: Luigi Pertoldi Date: Fri, 27 Oct 2023 16:12:33 +0200 Subject: [PATCH] [compression] turn Numba cache on Significant decrease of import time for the lgdo.compression subpackage. One needs to be careful with this feature in general: https://numba.readthedocs.io/en/stable/user/jit.html#cache But I think it's safe enough in our specific case, especially if we can save a lot of time. --- src/lgdo/compression/base.py | 2 ++ src/lgdo/compression/radware.py | 18 +++++++++--------- src/lgdo/compression/varlen.py | 14 +++++++------- 3 files changed, 18 insertions(+), 16 deletions(-) diff --git a/src/lgdo/compression/base.py b/src/lgdo/compression/base.py index 83cee8db..2588789e 100644 --- a/src/lgdo/compression/base.py +++ b/src/lgdo/compression/base.py @@ -3,6 +3,8 @@ import re from dataclasses import asdict, dataclass +numba_defaults: dict = {"nopython": True, "cache": True} + @dataclass(frozen=True) class WaveformCodec: diff --git a/src/lgdo/compression/radware.py b/src/lgdo/compression/radware.py index 5213c4ff..f8235d7e 100644 --- a/src/lgdo/compression/radware.py +++ b/src/lgdo/compression/radware.py @@ -9,7 +9,7 @@ from numpy.typing import NDArray from .. import types as lgdo -from .base import WaveformCodec +from .base import WaveformCodec, numba_defaults log = logging.getLogger(__name__) @@ -301,7 +301,7 @@ def decode( raise ValueError("unsupported input signal type") -@numba.jit(nopython=True) +@numba.jit(**numba_defaults) def _set_hton_u16(a: NDArray[ubyte], i: int, x: int) -> int: """Store an unsigned 16-bit integer value in an array of unsigned 8-bit integers. @@ -316,7 +316,7 @@ def _set_hton_u16(a: NDArray[ubyte], i: int, x: int) -> int: return x -@numba.jit(nopython=True) +@numba.jit(**numba_defaults) def _get_hton_u16(a: NDArray[ubyte], i: int) -> uint16: """Read unsigned 16-bit integer values from an array of unsigned 8-bit integers. @@ -331,22 +331,22 @@ def _get_hton_u16(a: NDArray[ubyte], i: int) -> uint16: return a[..., i_1].astype("uint16") << 8 | a[..., i_2] -@numba.jit("uint16(uint32)", nopython=True) +@numba.jit("uint16(uint32)", **numba_defaults) def _get_high_u16(x: uint32) -> uint16: return uint16(x >> 16) -@numba.jit("uint32(uint32, uint16)", nopython=True) +@numba.jit("uint32(uint32, uint16)", **numba_defaults) def _set_high_u16(x: uint32, y: uint16) -> uint32: return uint32(x & 0x0000FFFF | (y << 16)) -@numba.jit("uint16(uint32)", nopython=True) +@numba.jit("uint16(uint32)", **numba_defaults) def _get_low_u16(x: uint32) -> uint16: return uint16(x >> 0) -@numba.jit("uint32(uint32, uint16)", nopython=True) +@numba.jit("uint32(uint32, uint16)", **numba_defaults) def _set_low_u16(x: uint32, y: uint16) -> uint32: return uint32(x & 0xFFFF0000 | (y << 0)) @@ -361,7 +361,7 @@ def _set_low_u16(x: uint32, y: uint16) -> uint32: "void( int64[:], byte[:], int32[:], uint32[:], uint16[:])", ], "(n),(m),(),(),(o)", - nopython=True, + **numba_defaults, ) def _radware_sigcompress_encode( sig_in: NDArray, @@ -574,7 +574,7 @@ def _radware_sigcompress_encode( "void(byte[:], int64[:], int32[:], uint32[:], uint16[:])", ], "(n),(m),(),(),(o)", - nopython=True, + **numba_defaults, ) def _radware_sigcompress_decode( sig_in: NDArray[ubyte], diff --git a/src/lgdo/compression/varlen.py b/src/lgdo/compression/varlen.py index 23a4b84f..e3a4846e 100644 --- a/src/lgdo/compression/varlen.py +++ b/src/lgdo/compression/varlen.py @@ -11,7 +11,7 @@ from numpy.typing import NDArray from .. import types as lgdo -from .base import WaveformCodec +from .base import WaveformCodec, numba_defaults log = logging.getLogger(__name__) @@ -266,7 +266,7 @@ def decode( @numba.vectorize( ["uint64(int64)", "uint32(int32)", "uint16(int16)"], - nopython=True, + **numba_defaults, ) def zigzag_encode(x: int | NDArray[int]) -> int | NDArray[int]: """ZigZag-encode [#WikiZZ]_ signed integer numbers.""" @@ -275,14 +275,14 @@ def zigzag_encode(x: int | NDArray[int]) -> int | NDArray[int]: @numba.vectorize( ["int64(uint64)", "int32(uint32)", "int16(uint16)"], - nopython=True, + **numba_defaults, ) def zigzag_decode(x: int | NDArray[int]) -> int | NDArray[int]: """ZigZag-decode [#WikiZZ]_ signed integer numbers.""" return (x >> 1) ^ -(x & 1) -@numba.jit(["uint32(int64, byte[:])"], nopython=True) +@numba.jit(["uint32(int64, byte[:])"], **numba_defaults) def uleb128_encode(x: int, encx: NDArray[ubyte]) -> int: """Compute a variable-length representation of an unsigned integer. @@ -315,7 +315,7 @@ def uleb128_encode(x: int, encx: NDArray[ubyte]) -> int: return i + 1 -@numba.jit(["UniTuple(uint32, 2)(byte[:])"], nopython=True) +@numba.jit(["UniTuple(uint32, 2)(byte[:])"], **numba_defaults) def uleb128_decode(encx: NDArray[ubyte]) -> (int, int): """Decode a variable-length integer into an unsigned integer. @@ -360,7 +360,7 @@ def uleb128_decode(encx: NDArray[ubyte]) -> (int, int): "void(int64[:], byte[:], uint32[:])", ], "(n),(m),()", - nopython=True, + **numba_defaults, ) def uleb128_zigzag_diff_array_encode( sig_in: NDArray[int], sig_out: NDArray[ubyte], nbytes: int @@ -410,7 +410,7 @@ def uleb128_zigzag_diff_array_encode( "void(byte[:], uint32[:], int64[:], uint32[:])", ], "(n),(),(m),()", - nopython=True, + **numba_defaults, ) def uleb128_zigzag_diff_array_decode( sig_in: NDArray[ubyte],