diff --git a/pyproject.toml b/pyproject.toml index 8915bd1..1c48126 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,24 +1,17 @@ [build-system] build-backend = 'setuptools.build_meta' -requires = [ - 'setuptools==69.2.0', - 'setuptools_scm[toml]>=6.2', - 'wheel', -] +requires = ['setuptools==69.2.0', 'setuptools_scm[toml]>=6.2', 'wheel'] [project] name = 'furax' authors = [ - {name = 'Pierre Chanial', email = 'chanial@apc.in2p3.fr'}, -] -maintainers = [ - {name = 'Pierre Chanial', email = 'chanial@apc.in2p3.fr'}, + { name = 'Pierre Chanial', email = 'chanial@apc.in2p3.fr' }, + { name = 'Simon Biquard', email = 'biquard@apc.in2p3.fr' }, ] +maintainers = [{ name = 'Pierre Chanial', email = 'chanial@apc.in2p3.fr' }] description = 'Operators and solvers for high-performance computing.' readme = 'README.md' -keywords = [ - 'scientific computing', -] +keywords = ['scientific computing'] classifiers = [ 'Programming Language :: Python', 'Programming Language :: Python :: 3', @@ -27,7 +20,7 @@ classifiers = [ 'Topic :: Scientific/Engineering', ] requires-python = '>=3.10' -license = {file = 'LICENSE'} +license = { file = 'LICENSE' } dependencies = [ 'jaxtyping', 'healpy>=0.16.6', @@ -42,13 +35,7 @@ dependencies = [ dynamic = ['version'] [project.optional-dependencies] -dev = [ - 'pytest', - 'pytest-cov', - 'pytest-mock', - 'setuptools_scm', - 'beartype', -] +dev = ['pytest', 'pytest-cov', 'pytest-mock', 'setuptools_scm', 'beartype'] [project.urls] homepage = 'https://scipol.in2p3.fr' @@ -65,23 +52,14 @@ show_missing = true skip_covered = true [[tool.mypy.overrides]] -module = [ - 'healpy', - 'jax_healpy', - 'lineax', - 'scipy.stats.sampling' -] +module = ['healpy', 'jax_healpy', 'lineax', 'scipy.stats.sampling'] ignore_missing_imports = true [tool.pytest.ini_options] # addopts = '-ra --cov=furax --jaxtyping-packages=furax,beartype.beartype(conf=beartype.BeartypeConf(strategy=beartype.BeartypeStrategy.On))' addopts = '-s -ra --color=yes' -testpaths = [ - 'tests', -] 
-markers = [ - "slow: mark test as slow.", -] +testpaths = ['tests'] +markers = ["slow: mark test as slow."] #[tool.setuptools] #packages = ['src/furax'] @@ -104,7 +82,13 @@ select = [ # flake8-debugger 'T10', ] -ignore = ['E203', 'E731', 'E741', 'F722'] +ignore = [ + 'E203', + 'E731', + 'E741', + 'F722', + 'UP037', # conflicts with jaxtyping Array annotations +] [tool.ruff.lint.per-file-ignores] "src/furax/_base/core.py" = ['E743'] diff --git a/src/furax/detectors.py b/src/furax/detectors.py index 06ba3c9..f2cc93a 100644 --- a/src/furax/detectors.py +++ b/src/furax/detectors.py @@ -1,6 +1,10 @@ +from hashlib import sha1 +from itertools import product + import jax +import jax.numpy as jnp import numpy as np -from jaxtyping import Float +from jaxtyping import Array, Float, PRNGKeyArray, Shaped, UInt32 class DetectorArray: @@ -22,5 +26,25 @@ def __init__( coords /= length self.coords = jax.device_put(coords) + # generate fake names for the detectors + # TODO(simon): accept user-defined names + widths = [len(str(s - 1)) for s in self.shape] + indices = [[f'{i:0{width}}' for i in range(dim)] for dim, width in zip(self.shape, widths)] + flat_names = ['DET_' + ''.join(combination) for combination in product(*indices)] + self.names = np.array(flat_names).reshape(self.shape) + def __len__(self) -> int: return int(np.prod(self.shape)) + + def split_key(self, key: PRNGKeyArray) -> Shaped[PRNGKeyArray, ' _']: + """Produces a new pseudo-random key for each detector.""" + fold = jax.numpy.vectorize(jax.random.fold_in, signature='(),()->()') + subkeys: Shaped[PRNGKeyArray, '...'] = fold(key, self._ids()) + return subkeys + + def _ids(self) -> UInt32[Array, '...']: + # vectorized hashing + converting to int + masking with 0xEFFFFFFF (low 32 bits, with bit 28 cleared) + name_to_int = np.vectorize(lambda s: int(sha1(s.encode()).hexdigest(), 16) & 0xEFFFFFFF) + # return detector IDs as unsigned 32-bit integers + ids: UInt32[Array, '...'] = jnp.uint32(name_to_int(self.names)) + return ids diff --git 
a/src/furax/preprocessing/__init__.py b/src/furax/preprocessing/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/furax/preprocessing/gap_filling.py b/src/furax/preprocessing/gap_filling.py new file mode 100644 index 0000000..e3fc7c8 --- /dev/null +++ b/src/furax/preprocessing/gap_filling.py @@ -0,0 +1,135 @@ +from functools import partial + +import equinox +import jax +import jax.numpy as jnp +from jaxtyping import Array, ArrayLike, Float, PRNGKeyArray + +from furax._base.indices import IndexOperator +from furax.detectors import DetectorArray +from furax.operators.toeplitz import SymmetricBandToeplitzOperator + +default_fft_size = SymmetricBandToeplitzOperator._get_default_fft_size + +__all__ = [ + 'GapFillingOperator', +] + + +class GapFillingOperator(equinox.Module): + """Class for filling masked time samples with a constrained noise realization. + + We follow the gap-filling algorithm described in https://doi.org/10.1103/PhysRevD.65.022003, + section II.C, page 6. It assumes that the noise is piecewise stationary and has Gaussian + statistics, described by the covariance matrix ``cov``. + + Example: + Gap-filling a single-detector timestream + + >>> detectors = DetectorArray(jnp.array(0), jnp.array(0), jnp.array(1)) + >>> key = jax.random.key(0) + >>> key, subkey = jax.random.split(key) + >>> nsamples = 10 + >>> x = jax.random.normal(subkey, detectors.shape + (nsamples,)) + >>> in_structure = jax.ShapeDtypeStruct(x.shape, x.dtype) + >>> mask = jnp.array([1, 1, 1, 0, 0, 1, 1, 1, 1, 1], dtype=bool) + >>> mask_op = IndexOperator(jnp.where(mask), in_structure=in_structure) + >>> cov = SymmetricBandToeplitzOperator(jnp.array([1.0]), in_structure) + >>> gf = GapFillingOperator(cov, mask_op, detectors) + >>> gap_filled_x = gf(key, x) + >>> assert gap_filled_x.shape == x.shape + >>> assert jnp.all(gap_filled_x[mask] == x[mask]) + + Attributes: + cov: A SymmetricBandToeplitzOperator representing the noise covariance matrix. 
+ mask_op: An IndexOperator for masking the gaps. + detectors: A DetectorArray representing the detectors. + rate: The sampling rate of the data. + """ + + cov: SymmetricBandToeplitzOperator + mask_op: IndexOperator + detectors: DetectorArray + rate: float = 1.0 + + def __call__(self, key: PRNGKeyArray, x: Float[Array, ' *shape']) -> Float[Array, ' *shape']: + """Performs the gap-filling operation. + + Args: + key: A PRNG key generated by the user. + x: The vector to be filled. + """ + real: Float[Array, '...'] = self._generate_realization_for(x, key) + p, u = self.mask_op, self.mask_op.T # pack, unpack operators + incomplete_cov = p @ self.cov @ u + op = self.cov @ u @ incomplete_cov.I @ p + y: Float[Array, '...'] = real + op(x - real) + # copy valid samples from original vector + y = y.at[p.indices].set(p(x)) + return y + + @staticmethod + def folded_psd( + n_tt: Float[ArrayLike, ' _'], fft_size: int + ) -> Float[Array, ' {fft_size // 2 + 1}']: + """Returns the folded Power Spectral Density of a one-dimensional vector. + + Args: + n_tt: The autocorrelation function of the vector. + fft_size: The size of the FFT to use (at least ``2 * n_tt.size - 1``). + """ + n_tt = jnp.asarray(n_tt) + kernel = GapFillingOperator._get_kernel(n_tt, fft_size) + psd = jnp.abs(jnp.fft.rfft(kernel, n=fft_size)) + # zero out DC value + psd = psd.at[0].set(0) + return psd + + @staticmethod + def _get_kernel(n_tt: Float[Array, ' _'], size: int) -> Float[Array, ' {size}']: + lagmax = n_tt.size - 1 + padding_size = size - (2 * lagmax + 1) + if padding_size < 0: + msg = f'The maximum lag ({lagmax}) is too large for the required kernel size ({size}).' 
+ raise ValueError(msg) + kernel = jnp.concatenate((n_tt, jnp.zeros(padding_size), n_tt[-1:0:-1])) + return kernel + + def _generate_realization_for( + self, x: Float[Array, ' *shape'], key: PRNGKeyArray + ) -> Float[Array, ' *shape']: + @partial(jnp.vectorize, signature='(n),(k),()->(n)') + def func(x, n_tt, subkey): # type: ignore[no-untyped-def] + x_size = x.size + fft_size = default_fft_size(x_size) + npsd = fft_size // 2 + 1 + norm = self.rate * float(npsd - 1) + + # Get PSD values (size = fft_size // 2 + 1) + psd = self.folded_psd(n_tt, fft_size) + scale = jnp.sqrt(norm * psd) + + # Gaussian Re/Im random numbers + rngdata = jax.random.normal(subkey, shape=(fft_size,)) + + fdata = jnp.empty(npsd, dtype=complex) + + # Set DC and Nyquist frequency imaginary parts to zero + fdata = fdata.at[0].set(rngdata[0] + 0.0j) + fdata = fdata.at[-1].set(rngdata[npsd - 1] + 0.0j) + + # Repack the other values + fdata = fdata.at[1:-1].set(rngdata[1 : npsd - 1] + 1j * rngdata[-1 : npsd - 1 : -1]) + + # scale by PSD and inverse FFT + tdata = jnp.fft.irfft(fdata * scale) + + # subtract the DC level for the samples we want + offset = (fft_size - x_size) // 2 + xi = tdata[offset : offset + x_size] + xi -= jnp.mean(xi) + return xi + + subkeys = self.detectors.split_key(key) + real: Float[Array, '...'] = func(x, self.cov.band_values, subkeys) + return real diff --git a/tests/preprocessing/__init__.py b/tests/preprocessing/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/preprocessing/test_gap_filling.py b/tests/preprocessing/test_gap_filling.py new file mode 100644 index 0000000..ec499f4 --- /dev/null +++ b/tests/preprocessing/test_gap_filling.py @@ -0,0 +1,134 @@ +import jax +import jax.numpy as jnp +import numpy as np +import pytest +from numpy.testing import assert_allclose + +from furax._base.indices import IndexOperator +from furax.detectors import DetectorArray +from furax.operators.toeplitz import SymmetricBandToeplitzOperator +from 
furax.preprocessing.gap_filling import GapFillingOperator + + +class FakeDetectorArray(DetectorArray): + def __init__(self, num: int | tuple[int, ...]) -> None: + super().__init__(np.zeros(num), np.zeros(num), 1.0) + + +@pytest.mark.parametrize( + 'n_tt, fft_size, expected_kernel', + [ + ([1], 1, [1]), + ([1], 2, [1, 0]), + ([1], 4, [1, 0, 0, 0]), + ([1, 2], 4, [1, 2, 0, 2]), + ([3, 2, 1], 8, [3, 2, 1, 0, 0, 0, 1, 2]), + ], +) +def test_get_kernel(n_tt: list[int], fft_size: int, expected_kernel: list[int]): + n_tt = jnp.array(n_tt) + expected_kernel = np.array(expected_kernel) + actual_kernel = GapFillingOperator._get_kernel(n_tt, fft_size) + assert_allclose(actual_kernel, expected_kernel) + + +@pytest.mark.parametrize('n_tt, fft_size', [([1, 2], 1), ([1, 2, 3], 4)]) +def test_get_kernel_fail_lagmax(n_tt: list[int], fft_size: int): + # This test should fail because the maximum lag is too large for the required fft_size + n_tt = jnp.array(n_tt) + with pytest.raises(ValueError): + _ = GapFillingOperator._get_kernel(n_tt, fft_size) + + +@pytest.mark.parametrize('do_jit', [False, True]) +@pytest.mark.parametrize('x_shape', [(1,), (10,), (1, 100), (2, 10), (2, 100), (1, 2, 100)]) +def test_generate_realization_shape(x_shape: tuple[int, ...], do_jit: bool): + x = jnp.zeros(x_shape, dtype=float) + key = jax.random.key(31415926539) + structure = jax.ShapeDtypeStruct(x.shape, x.dtype) + cov = SymmetricBandToeplitzOperator(jnp.array([1.0]), structure) + indices = jnp.where(jnp.ones_like(x, dtype=bool)) + mask_op = IndexOperator(indices, in_structure=structure) + dets = FakeDetectorArray(x_shape[:-1]) + op = GapFillingOperator(cov, mask_op, dets) + if do_jit: + # avoid error: TypeError: unhashable type: 'jaxlib.xla_extension.ArrayImpl' + func = jax.jit(lambda x, k: op._generate_realization_for(x, k)) + else: + func = op._generate_realization_for + real = func(x, key) + assert real.shape == x_shape + + +@pytest.fixture +def dummy_shape(): + shape = (2, 100) + return shape + + 
+@pytest.fixture +def dummy_x(dummy_shape): + key = jax.random.key(987654321) + x = jax.random.uniform(key, dummy_shape, dtype=float) + return x + + +@pytest.fixture +def dummy_detectors(dummy_shape): + return FakeDetectorArray(dummy_shape[0]) + + +@pytest.fixture +def dummy_mask(dummy_shape): + mask = jnp.ones(dummy_shape, dtype=bool) + samples = dummy_shape[-1] + gap_size = samples // 10 + left, right = (samples - gap_size) // 2, (samples + gap_size) // 2 + mask = mask.at[:, left:right].set(False) + return mask + + +@pytest.fixture +def dummy_mask_op(dummy_x, dummy_mask): + structure = jax.ShapeDtypeStruct(dummy_x.shape, dummy_x.dtype) + indices = jnp.where(dummy_mask) + mask_op = IndexOperator(indices, in_structure=structure) + return mask_op + + +@pytest.fixture +def dummy_cov(dummy_x): + structure = jax.ShapeDtypeStruct(dummy_x.shape, dummy_x.dtype) + cov = SymmetricBandToeplitzOperator(jnp.array([1.0]), structure) + return cov + + +@pytest.fixture +def dummy_gap_filling_operator(dummy_shape, dummy_mask, dummy_detectors): + x = jnp.ones(dummy_shape, dtype=float) + structure = jax.ShapeDtypeStruct(x.shape, x.dtype) + cov = SymmetricBandToeplitzOperator(jnp.array([1.0]), structure) + indices = jnp.where(dummy_mask) + mask_op = IndexOperator(indices, in_structure=structure) + return GapFillingOperator(cov, mask_op, dummy_detectors) + + +@pytest.mark.parametrize( + 'n_tt, fft_size', [([1], 1), ([1], 2), ([1], 4), ([1, 2], 4), ([3, 2, 1], 8)] +) +def test_get_psd_non_negative(n_tt, fft_size): + n_tt = np.array(n_tt) + psd = GapFillingOperator.folded_psd(n_tt, fft_size) + assert np.all(psd >= 0) + + +@pytest.mark.parametrize('do_jit', [False, True]) +def test_valid_samples_and_no_nans(do_jit, dummy_shape, dummy_x, dummy_gap_filling_operator): + op = dummy_gap_filling_operator + if do_jit: + func = jax.jit(lambda k, x: op(k, x)) + else: + func = op + y = func(jax.random.key(1234), dummy_x) + assert_allclose(op.mask_op(dummy_x), op.mask_op(y)) + assert not 
np.any(np.isnan(y))