Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Gap filling #4

Merged
merged 16 commits into from
Dec 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 17 additions & 33 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,24 +1,17 @@
[build-system]
build-backend = 'setuptools.build_meta'
requires = [
'setuptools==69.2.0',
'setuptools_scm[toml]>=6.2',
'wheel',
]
requires = ['setuptools==69.2.0', 'setuptools_scm[toml]>=6.2', 'wheel']

[project]
name = 'furax'
authors = [
{name = 'Pierre Chanial', email = '[email protected]'},
]
maintainers = [
{name = 'Pierre Chanial', email = '[email protected]'},
{ name = 'Pierre Chanial', email = '[email protected]' },
{ name = 'Simon Biquard', email = '[email protected]' },
]
maintainers = [{ name = 'Pierre Chanial', email = '[email protected]' }]
description = 'Operators and solvers for high-performance computing.'
readme = 'README.md'
keywords = [
'scientific computing',
]
keywords = ['scientific computing']
classifiers = [
'Programming Language :: Python',
'Programming Language :: Python :: 3',
Expand All @@ -27,7 +20,7 @@ classifiers = [
'Topic :: Scientific/Engineering',
]
requires-python = '>=3.10'
license = {file = 'LICENSE'}
license = { file = 'LICENSE' }
dependencies = [
'jaxtyping',
'healpy>=0.16.6',
Expand All @@ -42,13 +35,7 @@ dependencies = [
dynamic = ['version']

[project.optional-dependencies]
dev = [
'pytest',
'pytest-cov',
'pytest-mock',
'setuptools_scm',
'beartype',
]
dev = ['pytest', 'pytest-cov', 'pytest-mock', 'setuptools_scm', 'beartype']

[project.urls]
homepage = 'https://scipol.in2p3.fr'
Expand All @@ -65,23 +52,14 @@ show_missing = true
skip_covered = true

[[tool.mypy.overrides]]
module = [
'healpy',
'jax_healpy',
'lineax',
'scipy.stats.sampling'
]
module = ['healpy', 'jax_healpy', 'lineax', 'scipy.stats.sampling']
ignore_missing_imports = true

[tool.pytest.ini_options]
# addopts = '-ra --cov=furax --jaxtyping-packages=furax,beartype.beartype(conf=beartype.BeartypeConf(strategy=beartype.BeartypeStrategy.On))'
addopts = '-s -ra --color=yes'
testpaths = [
'tests',
]
markers = [
"slow: mark test as slow.",
]
testpaths = ['tests']
markers = ["slow: mark test as slow."]

#[tool.setuptools]
#packages = ['src/furax']
Expand All @@ -104,7 +82,13 @@ select = [
# flake8-debugger
'T10',
]
ignore = ['E203', 'E731', 'E741', 'F722']
ignore = [
'E203',
'E731',
'E741',
'F722',
'UP037', # conflicts with jaxtyping Array annotations
]

[tool.ruff.lint.per-file-ignores]
"src/furax/_base/core.py" = ['E743']
Expand Down
26 changes: 25 additions & 1 deletion src/furax/detectors.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
from hashlib import sha1
from itertools import product

import jax
import jax.numpy as jnp
import numpy as np
from jaxtyping import Float
from jaxtyping import Array, Float, PRNGKeyArray, Shaped, UInt32


class DetectorArray:
Expand All @@ -22,5 +26,25 @@ def __init__(
coords /= length
self.coords = jax.device_put(coords)

# generate fake names for the detectors
# TODO(simon): accept user-defined names
widths = [len(str(s - 1)) for s in self.shape]
indices = [[f'{i:0{width}}' for i in range(dim)] for dim, width in zip(self.shape, widths)]
flat_names = ['DET_' + ''.join(combination) for combination in product(*indices)]
self.names = np.array(flat_names).reshape(self.shape)

def __len__(self) -> int:
    """Returns the total number of detectors, i.e. the product of ``self.shape``."""
    n_detectors = np.prod(self.shape)
    return int(n_detectors)

def split_key(self, key: PRNGKeyArray) -> Shaped[PRNGKeyArray, '...']:
    """Produces a new pseudo-random key for each detector.

    The parent key is folded with every integer ID returned by ``_ids``, so the
    result has the same shape as that ID array (not necessarily one-dimensional).

    Args:
        key: The parent PRNG key.

    Returns:
        An array of keys, one per detector ID.
    """
    # fold_in combines one key with one integer; vectorize it over the array of IDs
    fold = jnp.vectorize(jax.random.fold_in, signature='(),()->()')
    subkeys: Shaped[PRNGKeyArray, '...'] = fold(key, self._ids())
    return subkeys

def _ids(self) -> UInt32[Array, '...']:
    """Returns an unsigned 32-bit integer ID for each detector, derived from its name.

    Each entry of ``self.names`` is hashed with SHA-1 and the digest, read as an
    integer, is reduced with the bit mask ``0xEFFFFFFF``. The output has the same
    shape as ``self.names``.
    """
    # The mask keeps the low 32 bits of the digest with bit 28 cleared, so every ID
    # fits in a uint32.
    # NOTE(review): 0xEFFFFFFF (rather than 0xFFFFFFFF) looks unintentional, but
    # changing it would alter every ID and therefore every key derived from them.
    #
    # ``otypes`` pins the output dtype instead of letting np.vectorize infer a
    # platform-dependent integer type from the first result, and it is required for
    # np.vectorize to accept an empty array of names.
    name_to_int = np.vectorize(
        lambda s: int(sha1(s.encode()).hexdigest(), 16) & 0xEFFFFFFF,
        otypes=[np.uint32],
    )
    # return detectors IDs as unsigned 32-bit integers
    ids: UInt32[Array, '...'] = jnp.asarray(name_to_int(self.names), dtype=jnp.uint32)
    return ids
Empty file.
135 changes: 135 additions & 0 deletions src/furax/preprocessing/gap_filling.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
from functools import partial

import equinox
import jax
import jax.numpy as jnp
from jaxtyping import Array, ArrayLike, Float, PRNGKeyArray

from furax._base.indices import IndexOperator
from furax.detectors import DetectorArray
from furax.operators.toeplitz import SymmetricBandToeplitzOperator

# Alias for the Toeplitz operator's FFT-size helper; it is called with the number of
# time samples to choose the FFT length of the simulated noise realizations.
# NOTE(review): this reaches into a private (underscore-prefixed) static helper of
# SymmetricBandToeplitzOperator — confirm it is meant to be shared across modules.
default_fft_size = SymmetricBandToeplitzOperator._get_default_fft_size

# Names exported by ``from ... import *``.
__all__ = [
    'GapFillingOperator',
]


class GapFillingOperator(equinox.Module):
    """Class for filling masked time samples with a constrained noise realization.

    We follow the gap-filling algorithm described in https://doi.org/10.1103/PhysRevD.65.022003,
    section II.C, page 6. It assumes that the noise is piecewise stationary and has Gaussian
    statistics, described by the covariance matrix ``cov``.

    Example:
        Gap-filling a single-detector timestream

        >>> detectors = DetectorArray(jnp.array(0), jnp.array(0), jnp.array(1))
        >>> key = jax.random.key(0)
        >>> key, subkey = jax.random.split(key)
        >>> nsamples = 10
        >>> x = jax.random.normal(subkey, detectors.shape + (nsamples,))
        >>> in_structure = jax.ShapeDtypeStruct(x.shape, x.dtype)
        >>> mask = jnp.array([1, 1, 1, 0, 0, 1, 1, 1, 1, 1], dtype=bool)
        >>> mask_op = IndexOperator(jnp.where(mask), in_structure=in_structure)
        >>> cov = SymmetricBandToeplitzOperator(jnp.array([1.0]), in_structure)
        >>> gf = GapFillingOperator(cov, mask_op, detectors)
        >>> gap_filled_x = gf(key, x)
        >>> assert gap_filled_x.shape == x.shape
        >>> assert jnp.all(gap_filled_x[mask] == x[mask])

    Attributes:
        cov: A SymmetricBandToeplitzOperator representing the noise covariance matrix.
        mask_op: An IndexOperator for masking the gaps.
        detectors: A DetectorArray representing the detectors.
        rate: The sampling rate of the data.
    """

    cov: SymmetricBandToeplitzOperator
    mask_op: IndexOperator
    detectors: DetectorArray
    rate: float = 1.0

    def __call__(self, key: PRNGKeyArray, x: Float[Array, ' *shape']) -> Float[Array, ' *shape']:
        """Performs the gap-filling operation.

        A noise realization is drawn first. It is then corrected by applying
        ``cov @ u @ (p @ cov @ u).I @ p`` to the difference between ``x`` and the
        realization, where ``p`` is ``mask_op`` and ``u`` its transpose. Finally, the
        samples selected by ``mask_op`` are overwritten with their original values.

        Args:
            key: A PRNG key generated by the user.
            x: The vector to be filled.

        Returns:
            The gap-filled vector, equal to ``x`` on the samples selected by ``mask_op``.
        """
        pack = self.mask_op  # extracts the samples selected by the mask operator
        unpack = pack.T  # scatters them back into a full-size vector

        realization: Float[Array, '...'] = self._generate_realization_for(x, key)

        # covariance restricted to the selected samples, and the resulting correction operator
        incomplete_cov = pack @ self.cov @ unpack
        correction = self.cov @ unpack @ incomplete_cov.I @ pack

        filled: Float[Array, '...'] = realization + correction(x - realization)
        # copy valid samples from original vector
        filled = filled.at[pack.indices].set(pack(x))
        return filled

    @staticmethod
    def folded_psd(
        n_tt: Float[ArrayLike, ' _'], fft_size: int
    ) -> Float[Array, ' {fft_size // 2 + 1}']:
        """Returns the folded Power Spectral Density of a one-dimensional vector.

        The autocorrelation is laid out as a symmetric kernel of length ``fft_size``
        (see ``_get_kernel``); the result is the modulus of its real FFT, with the
        zero-frequency bin set to zero.

        Args:
            n_tt: The autocorrelation function of the vector.
            fft_size: The size of the FFT to use (at least ``2 * n_tt.size - 1``).

        Returns:
            The PSD values, of size ``fft_size // 2 + 1``.

        Raises:
            ValueError: If ``fft_size`` is too small to hold the symmetric kernel.
        """
        autocorrelation = jnp.asarray(n_tt)
        kernel = GapFillingOperator._get_kernel(autocorrelation, fft_size)
        spectrum = jnp.fft.rfft(kernel, n=fft_size)
        # zero out DC value
        return jnp.abs(spectrum).at[0].set(0)

    @staticmethod
    def _get_kernel(n_tt: Float[Array, ' _'], size: int) -> Float[Array, ' {size}']:
        """Builds a symmetric kernel of length ``size`` from a one-sided autocorrelation.

        The layout is ``[n_tt[0], ..., n_tt[lagmax], 0, ..., 0, n_tt[lagmax], ..., n_tt[1]]``.

        Args:
            n_tt: The autocorrelation values for lags ``0`` to ``lagmax``.
            size: The length of the returned kernel.

        Raises:
            ValueError: If ``size`` is smaller than ``2 * lagmax + 1``.
        """
        lagmax = n_tt.size - 1
        padding_size = size - (2 * lagmax + 1)
        if padding_size < 0:
            msg = f'The maximum lag ({lagmax}) is too large for the required kernel size ({size}).'
            raise ValueError(msg)
        # positive lags, zero padding, then the positive lags mirrored (lag 0 excluded)
        mirrored_lags = n_tt[-1:0:-1]
        return jnp.concatenate((n_tt, jnp.zeros(padding_size), mirrored_lags))

    def _generate_realization_for(
        self, x: Float[Array, ' *shape'], key: PRNGKeyArray
    ) -> Float[Array, ' *shape']:
        """Draws a random noise realization with the same shape as ``x``.

        The last axis of ``x`` is treated as the sample axis. The generation is
        vectorized over the remaining axes, each timestream using its own sub-key
        obtained from ``detectors.split_key`` and the band values of ``cov`` as
        autocorrelation.

        Args:
            x: The vector whose shape the realization must match.
            key: The parent PRNG key.
        """

        @partial(jnp.vectorize, signature='(n),(k),()->(n)')
        def func(x, n_tt, subkey):  # type: ignore[no-untyped-def]
            x_size = x.size
            fft_size = default_fft_size(x_size)
            npsd = fft_size // 2 + 1
            norm = self.rate * float(npsd - 1)

            # Get PSD values (size = fft_size // 2 + 1)
            psd = self.folded_psd(n_tt, fft_size)
            scale = jnp.sqrt(norm * psd)

            # Gaussian Re/Im random numbers
            rngdata = jax.random.normal(subkey, shape=(fft_size,))

            fdata = jnp.empty(npsd, dtype=complex)

            # Set DC and Nyquist frequency imaginary parts to zero
            fdata = fdata.at[0].set(rngdata[0] + 0.0j)
            fdata = fdata.at[-1].set(rngdata[npsd - 1] + 0.0j)

            # Repack the other values: real parts come from the first half of the
            # random draws, imaginary parts from the second half read backwards
            fdata = fdata.at[1:-1].set(rngdata[1 : npsd - 1] + 1j * rngdata[-1 : npsd - 1 : -1])

            # scale by PSD and inverse FFT
            tdata = jnp.fft.irfft(fdata * scale)

            # subtract the DC level for the samples we want
            # (a window of x_size samples centered in the simulated timestream)
            offset = (fft_size - x_size) // 2
            xi = tdata[offset : offset + x_size]
            return xi - jnp.mean(xi)

        subkeys = self.detectors.split_key(key)
        real: Float[Array, '...'] = func(x, self.cov.band_values, subkeys)
        return real
Empty file added tests/preprocessing/__init__.py
Empty file.
Loading
Loading