Add JAX backend #8

Merged 3 commits on Mar 10, 2023
README.md (27 changes: 14 additions & 13 deletions)
@@ -3,16 +3,17 @@

 ## Features
 
-* Compute integral transforms
-$$G(y) = \int_0^\infty F(x) K(xy) \frac{dx}x$$
-* Inverse transform without analytic inversion
-* Integral kernels as derivatives
-$$G(y) = \int_0^\infty F(x) K'(xy) \frac{dx}x$$
-* Transform input array along any axis of `numpy.ndarray`
-* Output the matrix form
-* 1-to-n transform for multiple kernels (TODO)
-$$G(y_1, \cdots, y_n) = \int_0^\infty \frac{dx}x F(x) \prod_{a=1}^n K_a(xy_a)$$
-* Easily extensible for other kernels
+* Compute integral transforms:
+$$G(y) = \int_0^\infty F(x) K(xy) \frac{dx}x;$$
+* Inverse transform without analytic inversion;
+* Integral kernels as derivatives:
+$$G(y) = \int_0^\infty F(x) K'(xy) \frac{dx}x;$$
+* Transform input array along any axis of `ndarray`;
+* Output the matrix form;
+* 1-to-n transform for multiple kernels (TODO):
+$$G(y_1, \cdots, y_n) = \int_0^\infty \frac{dx}x F(x) \prod_{a=1}^n K_a(xy_a);$$
+* Easily extensible for other kernels;
+* Support NumPy and JAX.
 
 
 ## Algorithm
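For context on the new backend option, a typical call might look like the sketch below. `P2xi` is one of mcfit's bundled transform subclasses; whether it forwards the new `backend` keyword to `mcfit.__init__` is an assumption here, not shown in this diff.

```python
import numpy as np
from mcfit import P2xi  # power spectrum -> correlation function

k = np.logspace(-3, 2, 256)      # log-spaced input grid, as the algorithm requires
P = 1e4 * k / (1 + (10 * k)**2)  # toy power spectrum, for illustration only

r, xi = P2xi(k)(P)               # default NumPy backend
# r, xi = P2xi(k, backend='jax')(P)  # JAX backend added by this PR, assuming
#                                    # subclasses forward the keyword
```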
@@ -30,9 +31,9 @@ input function and the kernel.
 `mcfit` implements the FFTLog algorithm.
 The idea is to take advantage of the convolution theorem in $\ln x$ and
 $\ln y$.
-It approximates the input function with truncated Fourier series over
-one period of a periodic approximant, and use the exact Fourier
-transform of the kernel.
+It approximates the input function with a partial sum of the Fourier
+series over one period of a periodic approximant, and uses the exact
+Fourier transform of the kernel.
 One can calculate the latter analytically as a Mellin transform.
 This algorithm is optimal when the input function is smooth in $\ln x$,
 and is ideal for oscillatory kernels with input spanning a wide range in
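To spell out the convolution-theorem step the paragraph above relies on (standard FFTLog background rather than text from this diff): substituting $x = e^u$ and $y = e^v$ turns the transform into

$$G(e^v) = \int_{-\infty}^\infty F(e^u) K(e^{u+v}) \, du,$$

a convolution in log space, so multiplying the Fourier coefficients of $F(e^u)$ by those of $K(e^t)$, which follow exactly from the Mellin transform $U(z) = \int_0^\infty K(t)\, t^{z-1}\, dt$ evaluated along a vertical line, evaluates the integral.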
mcfit/mcfit.py (102 changes: 65 additions & 37 deletions)
@@ -1,6 +1,14 @@
-import numpy as np
+import math
+import cmath
+import warnings
+
+import numpy
+try:
+    import jax
+    jax.config.update("jax_enable_x64", True)
+except ModuleNotFoundError as e:
+    JAXNotFoundError = e
 
 
 class mcfit(object):
     r"""Compute integral transforms as a multiplicative convolution.
@@ -45,6 +53,8 @@ class mcfit(object):
         Note that :math:`x_{max}` is not included in `x` but bigger than
         `x.max()` by one log interval due to the discretization of the periodic
         approximant, and likewise for :math:`y_{max}`
+    backend : str in {'numpy', 'jax'}, optional
+        Which backend to use.
 
     Attributes
     ----------
@@ -108,8 +118,23 @@ class mcfit(object):

"""

def __init__(self, x, MK, q, N=2j, lowring=False, xy=1):
self.x = np.asarray(x)
def __init__(self, x, MK, q, N=2j, lowring=False, xy=1, backend='numpy'):
if backend == 'numpy':
self.np = numpy
#self.jit = lambda fun: fun # TODO maybe use Numba?
elif backend == 'jax':
try:
self.np = jax.numpy
#self.jit = jax.jit # TODO maybe leave it to the user? jax.jit for CPU too
except NameError:
raise JAXNotFoundError
else:
raise ValueError(f"backend {backend} not supported")

#self.__call__ = self.jit(self.__call__)
#self.matrix = self.jit(self.matrix)

self.x = self.np.asarray(x)
self.Nin = len(x)
self.MK = MK
self.q = q
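The pattern above stores the chosen array module as an attribute so all later array calls dispatch through it. A minimal standalone sketch of the same idea (the class name `Transform` and its method are illustrative, not mcfit's API):

```python
import numpy

class Transform:
    def __init__(self, backend='numpy'):
        if backend == 'numpy':
            self.np = numpy
        elif backend == 'jax':
            import jax
            jax.config.update("jax_enable_x64", True)  # JAX defaults to single precision
            self.np = jax.numpy
        else:
            raise ValueError(f"backend {backend} not supported")

    def roundtrip(self, a):
        # identical code runs on either backend via the module alias
        return self.np.fft.irfft(self.np.fft.rfft(self.np.asarray(a)), n=len(a))
```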
@@ -143,31 +168,34 @@ def postfac(self, value):

     def _setup(self):
         if self.Nin < 2:
-            raise ValueError("input size must not be smaller than 2")
-        Delta = np.log(self.x[-1] / self.x[0]) / (self.Nin - 1)
-        x_head = self.x[:10]
-        if not np.allclose(np.log(x_head[1:] / x_head[:-1]), Delta, rtol=1e-3):
+            raise ValueError(f"input size {self.Nin} must not be smaller than 2")
+        Delta = math.log(self.x[-1] / self.x[0]) / (self.Nin - 1)
+        x_head = self.x[:8]
+        if not self.np.allclose(self.np.log(x_head[1:] / x_head[:-1]), Delta,
+                                rtol=1e-3):
             warnings.warn("input must be log-spaced")
 
         if isinstance(self.N, complex):
-            folds = int(np.ceil(np.log2(self.Nin * self.N.imag)))
+            folds = math.ceil(math.log2(self.Nin * self.N.imag))
             self.N = 2**folds
         if self.N < self.Nin:
-            raise ValueError("convolution size must not be smaller than input size")
+            raise ValueError(f"convolution size {self.N} must not be smaller than "
+                             f"the input size {self.Nin}")
 
         if self.lowring and self.N % 2 == 0:
-            lnxy = Delta / np.pi * np.angle(self.MK(self.q + 1j * np.pi / Delta))
-            self.xy = np.exp(lnxy)
+            lnxy = Delta / math.pi * cmath.phase(self.MK(self.q + 1j * math.pi / Delta))
+            self.xy = math.exp(lnxy)
         else:
-            lnxy = np.log(self.xy)
-        self.y = np.exp(lnxy - Delta) / self.x[::-1]
+            lnxy = math.log(self.xy)
+        self.y = math.exp(lnxy - Delta) / self.x[::-1]
 
         self._x_ = self._pad(self.x, 0, True, False)
         self._y_ = self._pad(self.y, 0, True, True)
 
-        m = np.arange(0, self.N//2 + 1)
-        self._u = self.MK(self.q + 2j * np.pi / self.N / Delta * m)
-        self._u *= np.exp(-2j * np.pi * lnxy / self.N / Delta * m)
+        m = numpy.arange(0, self.N//2 + 1)
+        self._u = self.MK(self.q + 2j * math.pi / self.N / Delta * m)
+        self._u *= numpy.exp(-2j * math.pi * lnxy / self.N / Delta * m)
+        self._u = self.np.asarray(self._u, dtype=(self.x[0] + 0j).dtype)
 
         # following is unnecessary because hfft ignores the imag at Nyquist anyway
         #if not self.lowring and self.N % 2 == 0:
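Why the low-ringing branch takes this form (standard FFTLog reasoning, spelled out here for review convenience, not text from the diff): the Nyquist coefficient is $u_{N/2} = U(q + i\pi/\Delta)\, e^{-i\pi \ln(xy)/\Delta}$, so choosing

$$\ln(xy) = \frac{\Delta}{\pi} \arg U\!\left(q + \frac{i\pi}{\Delta}\right)$$

makes it real, which damps ringing of the periodic approximant; that is exactly what `lnxy = Delta / math.pi * cmath.phase(...)` computes.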
@@ -205,7 +233,7 @@ def __call__(self, F, axis=-1, extrap=False, keeppads=False, convonly=False):
             output function
 
         """
-        F = np.asarray(F)
+        F = self.np.asarray(F)
 
         to_axis = [1] * F.ndim
         to_axis[axis] = -1
@@ -215,9 +243,9 @@ def __call__(self, F, axis=-1, extrap=False, keeppads=False, convonly=False):
         f = self._xfac_.reshape(to_axis) * f
 
         # convolution
-        f = np.fft.rfft(f, axis=axis)  # f(x_n) -> f_m
+        f = self.np.fft.rfft(f, axis=axis)  # f(x_n) -> f_m
         g = f * self._u.reshape(to_axis)  # f_m -> g_m
-        g = np.fft.hfft(g, n=self.N, axis=axis) / self.N  # g_m -> g(y_n)
+        g = self.np.fft.hfft(g, n=self.N, axis=axis) / self.N  # g_m -> g(y_n)
 
         if not keeppads:
             G = self._unpad(g, axis, True)
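The three lines above are the whole FFTLog convolution. A self-contained sketch of that core, assuming a precomputed coefficient array `u` of length `N//2 + 1` (which mcfit builds from the Mellin transform of the kernel):

```python
import numpy as np

def fftlog_convolve(f, u):
    """Multiplicative convolution of f with a kernel given by coefficients u."""
    N = len(f)
    fm = np.fft.rfft(f)              # f(x_n) -> f_m
    gm = fm * u                      # f_m -> g_m
    return np.fft.hfft(gm, n=N) / N  # g_m -> g(y_n); hfft treats g_m as Hermitian
```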
@@ -288,8 +316,8 @@ def matrix(self, full=False, keeppads=True):
         matrix.
 
         """
-        v = np.fft.hfft(self._u, n=self.N) / self.N
-        idx = sum(np.ogrid[0:self.N, -self.N:0])
+        v = self.np.fft.hfft(self._u, n=self.N) / self.N
+        idx = sum(self.np.ogrid[0:self.N, -self.N:0])
         C = v[idx]  # follow scipy.linalg.{circulant,toeplitz,hankel}
 
         if keeppads:
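The `ogrid` index trick above is terse; a tiny worked example, independent of mcfit, shows what it builds:

```python
import numpy as np

N = 4
v = np.arange(N)                # [0, 1, 2, 3]
idx = sum(np.ogrid[0:N, -N:0])  # idx[i, j] = i + j - N, values in [-N, N-2]
C = v[idx]                      # negative indices wrap, so row i is v rotated left by i
# C == [[0, 1, 2, 3],
#       [1, 2, 3, 0],
#       [2, 3, 0, 1],
#       [3, 0, 1, 2]]
```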
@@ -352,32 +380,32 @@ def _pad(self, a, axis, extrap, out):

         if isinstance(_extrap, bool):
             if _extrap:
-                end = np.take(a, [0], axis=axis)
-                ratio = np.take(a, [1], axis=axis) / end
-                exp = np.arange(-_Npad, 0).reshape(to_axis)
+                end = self.np.take(a, self.np.array([0]), axis=axis)
+                ratio = self.np.take(a, self.np.array([1]), axis=axis) / end
+                exp = self.np.arange(-_Npad, 0).reshape(to_axis)
                 _a = end * ratio ** exp
             else:
-                _a = np.zeros(a.shape[:axis] + (_Npad,) + a.shape[axis+1:])
+                _a = self.np.zeros(a.shape[:axis] + (_Npad,) + a.shape[axis+1:])
         elif _extrap == 'const':
-            end = np.take(a, [0], axis=axis)
-            _a = np.repeat(end, _Npad, axis=axis)
+            end = self.np.take(a, self.np.array([0]), axis=axis)
+            _a = self.np.repeat(end, _Npad, axis=axis)
         else:
-            raise ValueError("left extrap not supported")
+            raise ValueError(f"left extrap {_extrap} not supported")
         if isinstance(extrap_, bool):
             if extrap_:
-                end = np.take(a, [-1], axis=axis)
-                ratio = end / np.take(a, [-2], axis=axis)
-                exp = np.arange(1, Npad_ + 1).reshape(to_axis)
+                end = self.np.take(a, self.np.array([-1]), axis=axis)
+                ratio = end / self.np.take(a, self.np.array([-2]), axis=axis)
+                exp = self.np.arange(1, Npad_ + 1).reshape(to_axis)
                 a_ = end * ratio ** exp
             else:
-                a_ = np.zeros(a.shape[:axis] + (Npad_,) + a.shape[axis+1:])
+                a_ = self.np.zeros(a.shape[:axis] + (Npad_,) + a.shape[axis+1:])
         elif extrap_ == 'const':
-            end = np.take(a, [-1], axis=axis)
-            a_ = np.repeat(end, Npad_, axis=axis)
+            end = self.np.take(a, self.np.array([-1]), axis=axis)
+            a_ = self.np.repeat(end, Npad_, axis=axis)
         else:
-            raise ValueError("right extrap not supported")
+            raise ValueError(f"right extrap {extrap_} not supported")
 
-        return np.concatenate((_a, a, a_), axis=axis)
+        return self.np.concatenate((_a, a, a_), axis=axis)

     def _unpad(self, a, axis, out):
         """Undo padding in an array.
@@ -404,4 +432,4 @@ def _unpad(self, a, axis, out):
         else:
             _Npad, Npad_ = Npad//2, Npad - Npad//2
 
-        return np.take(a, range(_Npad, self.N - Npad_), axis=axis)
+        return self.np.take(a, self.np.arange(_Npad, self.N - Npad_), axis=axis)