Skip to content

Commit

Permalink
cleaning up; shear calculation is faster now; added k_avg (only chang…
Browse files Browse the repository at this point in the history
…es performance by 5-10%)
  • Loading branch information
boryanah committed Nov 12, 2023
1 parent bfdaa7f commit 48c4fb6
Show file tree
Hide file tree
Showing 7 changed files with 112 additions and 133 deletions.
126 changes: 43 additions & 83 deletions abacusnbody/analysis/power_spectrum.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,8 @@ def bin_kmu(n1d, L, kedges, Nmu, weights, poles=np.empty(0, 'i8'), dtype=np.floa
mean power spectrum per k for each Legendre multipole.
counts_poles : ndarray of int
number of modes per k.
weighted_counts_k : ndarray of float
mean wavenumber per (k, mu) wedge.
"""

numba.set_num_threads(nthread)
Expand All @@ -196,6 +198,7 @@ def bin_kmu(n1d, L, kedges, Nmu, weights, poles=np.empty(0, 'i8'), dtype=np.floa
else:
poles = poles.astype(np.int64)
weighted_counts_poles = np.zeros((nthread, len(poles), Nk), dtype=dtype)
weighted_counts_k = np.zeros((nthread, Nk, Nmu), dtype=dtype)

# Loop over all k vectors
for i in numba.prange(n1d):
Expand Down Expand Up @@ -226,6 +229,7 @@ def bin_kmu(n1d, L, kedges, Nmu, weights, poles=np.empty(0, 'i8'), dtype=np.floa

counts[tid, bk, bmu] += 1 if k == 0 else 2
weighted_counts[tid, bk, bmu] += weights[i, j, k] if k == 0 else dtype(2.)*weights[i, j, k]
weighted_counts_k[tid, bk, bmu] += np.sqrt(kmag2)*dk if k == 0 else dtype(2.)*np.sqrt(kmag2)*dk
if Np > 0:
for ip in range(len(poles)):
pole = poles[ip]
Expand All @@ -238,6 +242,7 @@ def bin_kmu(n1d, L, kedges, Nmu, weights, poles=np.empty(0, 'i8'), dtype=np.floa
counts = counts.sum(axis=0)
weighted_counts = weighted_counts.sum(axis=0)
weighted_counts_poles = weighted_counts_poles.sum(axis=0)
weighted_counts_k = weighted_counts_k.sum(axis=0)
counts_poles = counts.sum(axis=1)

for i in range(Nk):
Expand All @@ -247,7 +252,9 @@ def bin_kmu(n1d, L, kedges, Nmu, weights, poles=np.empty(0, 'i8'), dtype=np.floa
for j in range(Nmu):
if counts[i, j] != 0:
weighted_counts[i, j] /= dtype(counts[i, j])
return weighted_counts, counts, weighted_counts_poles, counts_poles
weighted_counts_k[i, j] /= dtype(counts[i, j])
return weighted_counts, counts, weighted_counts_poles, counts_poles, weighted_counts_k


@numba.njit(parallel=True, fastmath=True)
def bin_kppi(n1d, L, kedges, pimax, Npi, weights, dtype=np.float32,
Expand Down Expand Up @@ -376,7 +383,7 @@ def project_3d_to_poles(k_bin_edges, raw_p3d, Lbox, poles):
nmesh = raw_p3d.shape[0]
poles = np.asarray(poles)
raw_p3d = np.asarray(raw_p3d)
binned_p3d, N3d, binned_poles, Npoles = bin_kmu(nmesh, Lbox, k_bin_edges, Nmu=1, weights=raw_p3d, poles=poles)
binned_p3d, N3d, binned_poles, Npoles, k_avg = bin_kmu(nmesh, Lbox, k_bin_edges, Nmu=1, weights=raw_p3d, poles=poles)
binned_poles *= Lbox**3
return binned_poles, Npoles

Expand Down Expand Up @@ -585,7 +592,7 @@ def pk_to_xi(pk_fn, Lbox, r_bins, poles=[0, 2, 4], key='P_k3D_tr_tr'):
# bin into xi_ell(r)
nmesh = Xi.shape[0]
poles = np.asarray(poles)
_, _, binned_poles, Npoles = bin_kmu(nmesh, Lbox, r_bins, Nmu=1, weights=Xi, poles=poles, space='real')
_, _, binned_poles, Npoles, r_avg = bin_kmu(nmesh, Lbox, r_bins, Nmu=1, weights=Xi, poles=poles, space='real')
binned_poles *= nmesh**3
return r_binc, binned_poles, Npoles

Expand Down Expand Up @@ -655,7 +662,7 @@ def get_raw_power(field_fft, field2_fft=None):
raw_p3d = (np.abs(field_fft)**2)
return raw_p3d

@numba.njit(parallel=True, fastmath=True)
@numba.njit(parallel=False, fastmath=True)
def calc_pk_from_deltak(field_fft, Lbox, k_bin_edges, mu_bin_edges, field2_fft=None, poles=np.empty(0, 'i8'), nthread=MAX_THREADS):
r"""
Calculate the power spectrum of a given Fourier field, with binning in (k,mu).
Expand Down Expand Up @@ -689,6 +696,8 @@ def calc_pk_from_deltak(field_fft, Lbox, k_bin_edges, mu_bin_edges, field2_fft=N
mean power spectrum per k for each Legendre multipole.
Npoles : array_like
number of modes per k.
k_avg : array_like
mean wavenumber per (k, mu) wedge.
"""
numba.set_num_threads(nthread)

Expand All @@ -698,13 +707,13 @@ def calc_pk_from_deltak(field_fft, Lbox, k_bin_edges, mu_bin_edges, field2_fft=N
# power spectrum
nmesh = raw_p3d.shape[0]
Nmu = len(mu_bin_edges) - 1
binned_pk, Nmode, binned_poles, N_mode_poles = bin_kmu(nmesh, Lbox, k_bin_edges, Nmu, raw_p3d, poles, nthread=nthread)
pk3d, N3d, binned_poles, Npoles, k_avg = bin_kmu(nmesh, Lbox, k_bin_edges, Nmu, raw_p3d, poles, nthread=nthread)

# quantity above is dimensionless, multiply by box size (in Mpc/h)
binned_pk *= Lbox**3
pk3d *= Lbox**3
if len(poles) > 0:
binned_poles *= Lbox**3
return binned_pk, Nmode, binned_poles, N_mode_poles
return pk3d, N3d, binned_poles, Npoles, k_avg


def get_field(pos, Lbox, nmesh, paste, w=None, d=0., nthread=MAX_THREADS, dtype=np.float32):
Expand Down Expand Up @@ -1006,70 +1015,6 @@ def get_W_compensated(Lbox, nmesh, paste, interlaced):
del s
return W


def calc_field(pos,
Lbox,
paste = 'TSC',
nmesh = 128,
compensated = True,
interlaced = True,
w = None,
pos2 = None,
w2 = None,
nthread = MAX_THREADS,
dtype = np.float32,
):
r"""
Compute the 3D Fourier field(s) for some array(s) of positions.
Parameters
----------
pos : array_like
particle positions, shape (N,3)
Lbox : float
box size of the simulation.
paste : str, optional
particle pasting approach (CIC or TSC). Default is 'TSC'.
nmesh : int, optional
size of the 3d array along x and y dimension. Default is 128.
compensated : bool, optional
want to apply first-order compensated filter? Default is True.
interlaced : bool, optional
want to apply interlacing? Default is True.
w : array_like, optional
weights for each particle.
pos2 : array_like, optional
second set of particle positions, shape (N,3)
nthread : int, optional
Number of numba threads to use
dtype : np.dtype, optional
Data type of the field
Returns
-------
field_fft : array_like
Fourier 3D field.
"""

# get the window function
if compensated:
W = get_W_compensated(Lbox, nmesh, paste, interlaced)
else:
W = None

# convert to fourier space
field_fft = get_field_fft(pos, Lbox, nmesh, paste, w, W, compensated, interlaced, nthread=nthread, dtype=dtype)

# if second field provided
if pos2 is not None:
# convert to fourier space
field2_fft = get_field_fft(pos2, Lbox, nmesh, paste, w2, W, compensated, interlaced, nthread=nthread, dtype=dtype)
else:
field2_fft = None

return field_fft, field2_fft


def calc_power(pos,
Lbox,
kbins = None,
Expand Down Expand Up @@ -1137,6 +1082,7 @@ def calc_power(pos,
The power spectrum in an astropy Table of length ``nbins_k``.
The columns are:
- ``k_mid``: arithmetic bin centers of the k wavenumbers, shape ``(nbins_k,)``
- ``k_avg``: mean wavenumber per (k, mu) wedge, shape ``(nbins_k,nbins_mu)``
- ``mu_mid``: arithmetic bin centers of the mu angles, shape ``(nbins_k,nbins_mu)``
- ``power``: mean power spectrum per (k, mu) wedge, shape ``(nbins_k,nbins_mu)``
- ``N_mode``: number of modes per (k, mu) wedge, shape ``(nbins_k,nbins_mu)``
Expand Down Expand Up @@ -1166,14 +1112,27 @@ def calc_power(pos,
)


# calculate Fourier 3D field
field_fft, field2_fft = calc_field(pos, Lbox, paste=paste, nmesh=nmesh, compensated=compensated, interlaced=interlaced, w=w, pos2=pos2, w2=w2, nthread=nthread, dtype=dtype)
# get the window function
if compensated:
W = get_W_compensated(Lbox, nmesh, paste, interlaced)
else:
W = None

# convert to fourier space
field_fft = get_field_fft(pos, Lbox, nmesh, paste, w, W, compensated, interlaced, nthread=nthread, dtype=dtype)

# if second field provided
if pos2 is not None:
# convert to fourier space
field2_fft = get_field_fft(pos2, Lbox, nmesh, paste, w2, W, compensated, interlaced, nthread=nthread, dtype=dtype)
else:
field2_fft = None

poles = np.asarray(poles or [], dtype=np.int64)

# calculate power spectrum
kbins, mubins = get_k_mu_edges(Lbox, k_max, kbins, mubins, logk)
pk, N_mode, binned_poles, N_mode_poles = calc_pk_from_deltak(field_fft, Lbox, kbins, mubins, field2_fft=field2_fft, poles=poles, nthread=nthread)
pk3d, N3d, binned_poles, Npoles, k_avg = calc_pk_from_deltak(field_fft, Lbox, kbins, mubins, field2_fft=field2_fft, poles=poles, nthread=nthread)

# define bin centers
k_binc = (kbins[1:] + kbins[:-1])*.5
Expand All @@ -1183,19 +1142,20 @@ def calc_power(pos,
k_min=kbins[:-1],
k_max=kbins[1:],
k_mid=k_binc,
power=pk,
N_mode=N_mode,
k_avg=k_avg,
power=pk3d,
N_mode=N3d,
)
if return_mubins:
res.update(
mu_min=np.broadcast_to(mubins[:-1], pk.shape),
mu_max=np.broadcast_to(mubins[1:], pk.shape),
mu_mid=np.broadcast_to(mu_binc, pk.shape),
)
if len(poles) > 0:
res.update(
poles=binned_poles.T,
N_mode_poles=N_mode_poles,
N_mode_poles=Npoles,
)
if return_mubins:
res.update(
mu_min=np.broadcast_to(mubins[:-1], pk3d.shape),
mu_max=np.broadcast_to(mubins[1:], pk3d.shape),
mu_mid=np.broadcast_to(mu_binc, pk3d.shape),
)
res = Table(res, meta=meta)

Expand Down
Loading

0 comments on commit 48c4fb6

Please sign in to comment.