Skip to content

Commit

Permalink
Merge pull request #11 from DifferentiableUniverseInitiative/u/EiffL/…
Browse files Browse the repository at this point in the history
…lensing

Adds basic utilities for Born lensing
  • Loading branch information
EiffL authored May 18, 2022
2 parents e93aa07 + ff5fe80 commit 0991789
Show file tree
Hide file tree
Showing 3 changed files with 127 additions and 0 deletions.
81 changes: 81 additions & 0 deletions jaxpm/lensing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
import jax
import jax.numpy as jnp
import jax_cosmo.constants as constants
import jax_cosmo

from jax.scipy.ndimage import map_coordinates
from jaxpm.utils import gaussian_smoothing
from jaxpm.painting import cic_paint_2d

def density_plane(positions,
box_shape,
center,
width,
plane_resolution,
smoothing_sigma=None):
""" Extacts a density plane from the simulation
"""
nx, ny, nz = box_shape
xy = positions[..., :2]
d = positions[..., 2]

# Apply 2d periodic conditions
xy = jnp.mod(xy, nx)

# Rescaling positions to target grid
xy = xy / nx * plane_resolution

# Selecting only particles that fall inside the volume of interest
weight = jnp.where((d > (center - width / 2)) & (d <= (center + width / 2)), 1., 0.)
# Painting density plane
density_plane = cic_paint_2d(jnp.zeros([plane_resolution, plane_resolution]), xy, weight)

# Apply density normalization
density_plane = density_plane / ((nx / plane_resolution) *
(ny / plane_resolution) * (width))

# Apply Gaussian smoothing if requested
if smoothing_sigma is not None:
density_plane = gaussian_smoothing(density_plane,
smoothing_sigma)

return density_plane


def convergence_Born(cosmo,
density_planes,
dx, dz,
coords,
z_source):
"""
Compute the Born convergence
Args:
cosmo: `Cosmology`, cosmology object.
density_planes: list of tuples (r, a, density_plane), lens planes to use
dx: float, transverse pixel resolution of the density planes [Mpc/h]
dz: float, width of the density planes [Mpc/h]
coords: a 3-D array of angular coordinates in radians of N points with shape [batch, N, 2].
z_source: 1-D `Tensor` of source redshifts with shape [Nz] .
name: `string`, name of the operation.
Returns:
`Tensor` of shape [batch_size, N, Nz], of convergence values.
"""
# Compute constant prefactor:
constant_factor = 3 / 2 * cosmo.Omega_m * (constants.H0 / constants.c)**2
# Compute comoving distance of source galaxies
r_s = jax_cosmo.background.radial_comoving_distance(cosmo, 1 / (1 + z_source))

convergence = 0
for r, a, p in density_planes:
# Normalize density planes
density_normalization = dz * r / a
p = (p - p.mean()) * constant_factor * density_normalization

# Interpolate at the density plane coordinates
im = map_coordinates(p,
coords * r / dx - 0.5,
order=1, mode="wrap")

convergence += im * jnp.clip(1. - (r / r_s), 0, 1000).reshape([-1, 1, 1])

return convergence
28 changes: 28 additions & 0 deletions jaxpm/painting.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,34 @@ def cic_read(mesh, positions):
neighboor_coords[...,1],
neighboor_coords[...,3]]*kernel).sum(axis=-1)

def cic_paint_2d(mesh, positions, weight):
""" Paints positions onto a 2d mesh
mesh: [nx, ny]
positions: [npart, 2]
weight: [npart]
"""
positions = jnp.expand_dims(positions, 1)
floor = jnp.floor(positions)
connection = jnp.array([[0, 0], [1., 0], [0., 1], [1., 1]])

neighboor_coords = floor + connection
kernel = 1. - jnp.abs(positions - neighboor_coords)
kernel = kernel[..., 0] * kernel[..., 1]
if weight is not None:
kernel = kernel * weight[...,jnp.newaxis]

neighboor_coords = jnp.mod(neighboor_coords.reshape([-1,4,2]).astype('int32'), jnp.array(mesh.shape))

dnums = jax.lax.ScatterDimensionNumbers(
update_window_dims=(),
inserted_window_dims=(0, 1),
scatter_dims_to_operand_dims=(0, 1))
mesh = lax.scatter_add(mesh,
neighboor_coords,
kernel.reshape([-1,4]),
dnums)
return mesh

def compensate_cic(field):
"""
Compensate for CiC painting
Expand Down
18 changes: 18 additions & 0 deletions jaxpm/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import numpy as np
import jax.numpy as jnp
from jax.scipy.stats import norm

__all__ = ['power_spectrum']

Expand Down Expand Up @@ -79,3 +80,20 @@ def power_spectrum(field, kmin=5, dk=0.5, boxsize=False):
kbins = kedges[:-1] + (kedges[1:] - kedges[:-1]) / 2

return kbins, P / norm

def gaussian_smoothing(im, sigma):
"""
im: 2d image
sigma: smoothing scale in px
"""
# Compute k vector
kvec = jnp.stack(jnp.meshgrid(jnp.fft.fftfreq(im.shape[0]),
jnp.fft.fftfreq(im.shape[1])),
axis=-1)
k = jnp.linalg.norm(kvec, axis=-1)
# We compute the value of the filter at frequency k
filter = norm.pdf(k, 0, 1. / (2. * np.pi * sigma))
filter /= filter[0,0]

return jnp.fft.ifft2(jnp.fft.fft2(im) * filter).real

0 comments on commit 0991789

Please sign in to comment.