diff --git a/README.md b/README.md index fe07f03..042f474 100644 --- a/README.md +++ b/README.md @@ -3,16 +3,17 @@ ## Features -* Compute integral transforms - $$G(y) = \int_0^\infty F(x) K(xy) \frac{dx}x$$ -* Inverse transform without analytic inversion -* Integral kernels as derivatives - $$G(y) = \int_0^\infty F(x) K'(xy) \frac{dx}x$$ -* Transform input array along any axis of `numpy.ndarray` -* Output the matrix form -* 1-to-n transform for multiple kernels (TODO) - $$G(y_1, \cdots, y_n) = \int_0^\infty \frac{dx}x F(x) \prod_{a=1}^n K_a(xy_a)$$ -* Easily extensible for other kernels +* Compute integral transforms: + $$G(y) = \int_0^\infty F(x) K(xy) \frac{dx}x;$$ +* Inverse transform without analytic inversion; +* Integral kernels as derivatives: + $$G(y) = \int_0^\infty F(x) K'(xy) \frac{dx}x;$$ +* Transform input array along any axis of `ndarray`; +* Output the matrix form; +* 1-to-n transform for multiple kernels (TODO): + $$G(y_1, \cdots, y_n) = \int_0^\infty \frac{dx}x F(x) \prod_{a=1}^n K_a(xy_a);$$ +* Easily extensible for other kernels; +* Support NumPy and JAX. ## Algorithm @@ -30,9 +31,9 @@ input function and the kernel. `mcfit` implements the FFTLog algorithm. The idea is to take advantage of the convolution theorem in $\ln x$ and $\ln y$. -It approximates the input function with truncated Fourier series over -one period of a periodic approximant, and use the exact Fourier -transform of the kernel. +It approximates the input function with a partial sum of the Fourier +series over one period of a periodic approximant, and use the exact +Fourier transform of the kernel. One can calculate the latter analytically as a Mellin transform. This algorithm is optimal when the input function is smooth in $\ln x$, and is ideal for oscillatory kernels with input spanning a wide range in diff --git a/mcfit/mcfit.py b/mcfit/mcfit.py index 8863133..9fa24a5 100644 --- a/mcfit/mcfit.py +++ b/mcfit/mcfit.py @@ -1,6 +1,14 @@ -import numpy as np +import math +import cmath import warnings +import numpy +try: + import jax + jax.config.update("jax_enable_x64", True) +except ModuleNotFoundError as e: + JAXNotFoundError = e + class mcfit(object): r"""Compute integral transforms as a multiplicative convolution. @@ -45,6 +53,8 @@ class mcfit(object): Note that :math:`x_{max}` is not included in `x` but bigger than `x.max()` by one log interval due to the discretization of the periodic approximant, and likewise for :math:`y_{max}` + backend : str in {'numpy', 'jax'}, optional + Which backend to use. Attributes ---------- @@ -108,8 +118,23 @@ class mcfit(object): """ - def __init__(self, x, MK, q, N=2j, lowring=False, xy=1): - self.x = np.asarray(x) + def __init__(self, x, MK, q, N=2j, lowring=False, xy=1, backend='numpy'): + if backend == 'numpy': + self.np = numpy + #self.jit = lambda fun: fun # TODO maybe use Numba? + elif backend == 'jax': + try: + self.np = jax.numpy + #self.jit = jax.jit # TODO maybe leave it to the user? jax.jit for CPU too + except NameError: + raise JAXNotFoundError + else: + raise ValueError(f"backend {backend} not supported") + + #self.__call__ = self.jit(self.__call__) + #self.matrix = self.jit(self.matrix) + + self.x = self.np.asarray(x) self.Nin = len(x) self.MK = MK self.q = q @@ -143,31 +168,34 @@ def postfac(self, value): def _setup(self): if self.Nin < 2: - raise ValueError("input size must not be smaller than 2") - Delta = np.log(self.x[-1] / self.x[0]) / (self.Nin - 1) - x_head = self.x[:10] - if not np.allclose(np.log(x_head[1:] / x_head[:-1]), Delta, rtol=1e-3): + raise ValueError(f"input size {self.Nin} must not be smaller than 2") + Delta = math.log(self.x[-1] / self.x[0]) / (self.Nin - 1) + x_head = self.x[:8] + if not self.np.allclose(self.np.log(x_head[1:] / x_head[:-1]), Delta, + rtol=1e-3): warnings.warn("input must be log-spaced") if isinstance(self.N, complex): - folds = int(np.ceil(np.log2(self.Nin * self.N.imag))) + folds = math.ceil(math.log2(self.Nin * self.N.imag)) self.N = 2**folds if self.N < self.Nin: - raise ValueError("convolution size must not be smaller than input size") + raise ValueError(f"convolution size {self.N} must not be smaller than " + f"the input size {self.Nin}") if self.lowring and self.N % 2 == 0: - lnxy = Delta / np.pi * np.angle(self.MK(self.q + 1j * np.pi / Delta)) - self.xy = np.exp(lnxy) + lnxy = Delta / math.pi * cmath.phase(self.MK(self.q + 1j * math.pi / Delta)) + self.xy = math.exp(lnxy) else: - lnxy = np.log(self.xy) - self.y = np.exp(lnxy - Delta) / self.x[::-1] + lnxy = math.log(self.xy) + self.y = math.exp(lnxy - Delta) / self.x[::-1] self._x_ = self._pad(self.x, 0, True, False) self._y_ = self._pad(self.y, 0, True, True) - m = np.arange(0, self.N//2 + 1) - self._u = self.MK(self.q + 2j * np.pi / self.N / Delta * m) - self._u *= np.exp(-2j * np.pi * lnxy / self.N / Delta * m) + m = numpy.arange(0, self.N//2 + 1) + self._u = self.MK(self.q + 2j * math.pi / self.N / Delta * m) + self._u *= numpy.exp(-2j * math.pi * lnxy / self.N / Delta * m) + self._u = self.np.asarray(self._u, dtype=(self.x[0] + 0j).dtype) # following is unnecessary because hfft ignores the imag at Nyquist anyway #if not self.lowring and self.N % 2 == 0: @@ -205,7 +233,7 @@ def __call__(self, F, axis=-1, extrap=False, keeppads=False, convonly=False): output function """ - F = np.asarray(F) + F = self.np.asarray(F) to_axis = [1] * F.ndim to_axis[axis] = -1 @@ -215,9 +243,9 @@ def __call__(self, F, axis=-1, extrap=False, keeppads=False, convonly=False): f = self._xfac_.reshape(to_axis) * f # convolution - f = np.fft.rfft(f, axis=axis) # f(x_n) -> f_m + f = self.np.fft.rfft(f, axis=axis) # f(x_n) -> f_m g = f * self._u.reshape(to_axis) # f_m -> g_m - g = np.fft.hfft(g, n=self.N, axis=axis) / self.N # g_m -> g(y_n) + g = self.np.fft.hfft(g, n=self.N, axis=axis) / self.N # g_m -> g(y_n) if not keeppads: G = self._unpad(g, axis, True) @@ -288,8 +316,8 @@ def matrix(self, full=False, keeppads=True): matrix. """ - v = np.fft.hfft(self._u, n=self.N) / self.N - idx = sum(np.ogrid[0:self.N, -self.N:0]) + v = self.np.fft.hfft(self._u, n=self.N) / self.N + idx = sum(self.np.ogrid[0:self.N, -self.N:0]) C = v[idx] # follow scipy.linalg.{circulant,toeplitz,hankel} if keeppads: @@ -352,32 +380,32 @@ def _pad(self, a, axis, extrap, out): if isinstance(_extrap, bool): if _extrap: - end = np.take(a, [0], axis=axis) - ratio = np.take(a, [1], axis=axis) / end - exp = np.arange(-_Npad, 0).reshape(to_axis) + end = self.np.take(a, self.np.array([0]), axis=axis) + ratio = self.np.take(a, self.np.array([1]), axis=axis) / end + exp = self.np.arange(-_Npad, 0).reshape(to_axis) _a = end * ratio ** exp else: - _a = np.zeros(a.shape[:axis] + (_Npad,) + a.shape[axis+1:]) + _a = self.np.zeros(a.shape[:axis] + (_Npad,) + a.shape[axis+1:]) elif _extrap == 'const': - end = np.take(a, [0], axis=axis) - _a = np.repeat(end, _Npad, axis=axis) + end = self.np.take(a, self.np.array([0]), axis=axis) + _a = self.np.repeat(end, _Npad, axis=axis) else: - raise ValueError("left extrap not supported") + raise ValueError(f"left extrap {_extrap} not supported") if isinstance(extrap_, bool): if extrap_: - end = np.take(a, [-1], axis=axis) - ratio = end / np.take(a, [-2], axis=axis) - exp = np.arange(1, Npad_ + 1).reshape(to_axis) + end = self.np.take(a, self.np.array([-1]), axis=axis) + ratio = end / self.np.take(a, self.np.array([-2]), axis=axis) + exp = self.np.arange(1, Npad_ + 1).reshape(to_axis) a_ = end * ratio ** exp else: - a_ = np.zeros(a.shape[:axis] + (Npad_,) + a.shape[axis+1:]) + a_ = self.np.zeros(a.shape[:axis] + (Npad_,) + a.shape[axis+1:]) elif extrap_ == 'const': - end = np.take(a, [-1], axis=axis) - a_ = np.repeat(end, Npad_, axis=axis) + end = self.np.take(a, self.np.array([-1]), axis=axis) + a_ = self.np.repeat(end, Npad_, axis=axis) else: - raise ValueError("right extrap not supported") + raise ValueError(f"right extrap {extrap_} not supported") - return np.concatenate((_a, a, a_), axis=axis) + return self.np.concatenate((_a, a, a_), axis=axis) def _unpad(self, a, axis, out): """Undo padding in an array. @@ -404,4 +432,4 @@ def _unpad(self, a, axis, out): else: _Npad, Npad_ = Npad//2, Npad - Npad//2 - return np.take(a, range(_Npad, self.N - Npad_), axis=axis) + return self.np.take(a, self.np.arange(_Npad, self.N - Npad_), axis=axis)