Skip to content

Commit

Permalink
Merge branch 'main' into power_spec
Browse files Browse the repository at this point in the history
  • Loading branch information
lgarrison committed Oct 9, 2023
2 parents f999025 + 9e30912 commit 45a6dc1
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 18 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ exclude: |
repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: "v0.0.287"
rev: "v0.0.292"
hooks:
- id: ruff
# TODO: turning off line length check
Expand Down
45 changes: 28 additions & 17 deletions abacusnbody/analysis/power_spectrum.py
Original file line number Diff line number Diff line change
Expand Up @@ -707,7 +707,7 @@ def calc_pk_from_deltak(field_fft, Lbox, k_bin_edges, mu_bin_edges, field2_fft=N
return binned_pk, Nmode, binned_poles, N_mode_poles


def get_field(pos, Lbox, nmesh, paste, w=None, d=0., nthread=MAX_THREADS):
def get_field(pos, Lbox, nmesh, paste, w=None, d=0., nthread=MAX_THREADS, dtype=np.float32):
r"""
Construct real-space 3D field given particle positions.
Expand All @@ -727,6 +727,8 @@ def get_field(pos, Lbox, nmesh, paste, w=None, d=0., nthread=MAX_THREADS):
uniform shift to particle positions.
nthread : int, optional
Number of numba threads to use
dtype : np.dtype, optional
Data type of the field
Returns
-------
Expand All @@ -736,21 +738,19 @@ def get_field(pos, Lbox, nmesh, paste, w=None, d=0., nthread=MAX_THREADS):
# check if weights are requested
if w is not None:
assert pos.shape[0] == len(w)
pos = pos.astype(np.float32, copy=False)

field = np.zeros((nmesh, nmesh, nmesh), dtype=dtype)
paste = paste.upper()
if paste == 'TSC':
if d != 0.:
field = tsc_parallel(pos, nmesh, Lbox, weights=w, nthread=nthread, offset=d)
else:
field = tsc_parallel(pos, nmesh, Lbox, weights=w, nthread=nthread)
tsc_parallel(pos, field, Lbox, weights=w, nthread=nthread, offset=d)
elif paste == 'CIC':
field = np.zeros((nmesh, nmesh, nmesh), dtype=np.float32)
warnings.warn("Note that currently CIC pasting, unlike TSC, supports only a non-parallel implementation.")
if d != 0.:
cic_serial(pos + d, field, Lbox, weights=w)
else:
cic_serial(pos, field, Lbox, weights=w)
else:
raise ValueError(f"Unknown pasting method: {paste}")
if w is None: # in the zcv code the weights are already normalized, so don't normalize here
# TODO assuming normalized weights is fragile
# same as passing "Value" to nbodykit (1+delta)(x) V(x)
Expand Down Expand Up @@ -895,7 +895,7 @@ def get_interlaced_field_fft(pos, Lbox, nmesh, paste, w, nthread=MAX_THREADS, ve
return field_fft


def get_field_fft(pos, Lbox, nmesh, paste, w, W, compensated, interlaced, nthread=MAX_THREADS, verbose=False):
def get_field_fft(pos, Lbox, nmesh, paste, w, W, compensated, interlaced, nthread=MAX_THREADS, verbose=False, dtype=np.float32):
r"""
Calculate field from particle positions and return 3D Fourier field.
Expand All @@ -917,26 +917,28 @@ def get_field_fft(pos, Lbox, nmesh, paste, w, W, compensated, interlaced, nthrea
want to apply interlacing?
nthread : int, optional
Number of numba threads to use
verbose : bool, optional
Print out debugging info
dtype : np.dtype, optional
Data type of the field
Returns
-------
field_fft : array_like
interlaced 3D Fourier field.
"""

# get field in real space
field = get_field(pos, Lbox, nmesh, paste, w, nthread=nthread)
if verbose:
print("field, pos", field.dtype, pos.dtype)
if interlaced:
# get interlaced field
field_fft = get_interlaced_field_fft(pos, Lbox, nmesh, paste, w, nthread=nthread)
else:
# get field in real space
field = get_field(pos, Lbox, nmesh, paste, w, nthread=nthread)
field = get_field(pos, Lbox, nmesh, paste, w, nthread=nthread, dtype=dtype)
if verbose:
print("field, pos", field.dtype, pos.dtype)

# get Fourier modes from skewers grid
inv_size = np.float32(1 / field.size)
inv_size = dtype(1 / field.size)
field_fft = rfftn(field, overwrite_x=True, workers=nthread)
_normalize(field_fft, inv_size, nthread=nthread)

Expand Down Expand Up @@ -986,11 +988,14 @@ def get_W_compensated(Lbox, nmesh, paste, interlaced):
k = (fftfreq(nmesh, d=d) * 2. * np.pi).astype(np.float32) # h/Mpc

# apply deconvolution
paste = paste.upper()
if interlaced:
if paste == 'TSC':
p = 3.
elif paste == 'CIC':
p = 2.
else:
raise ValueError(f"Unknown pasting method {paste}")
W = np.sinc(0.5*k/kN)**p # sinc def
else: # first order correction of interlacing (aka aliasing)
s = np.sin(0.5 * np.pi * k/kN)**2
Expand All @@ -1012,6 +1017,7 @@ def calc_field(pos,
pos2 = None,
w2 = None,
nthread = MAX_THREADS,
dtype = np.float32,
):
r"""
Compute the 3D Fourier field(s) for some array(s) of positions.
Expand All @@ -1036,6 +1042,8 @@ def calc_field(pos,
second set of particle positions, shape (N,3)
nthread : int, optional
Number of numba threads to use
dtype : np.dtype, optional
Data type of the field
Returns
-------
Expand All @@ -1050,12 +1058,12 @@ def calc_field(pos,
W = None

# convert to fourier space
field_fft = get_field_fft(pos, Lbox, nmesh, paste, w, W, compensated, interlaced, nthread=nthread)
field_fft = get_field_fft(pos, Lbox, nmesh, paste, w, W, compensated, interlaced, nthread=nthread, dtype=dtype)

# if second field provided
if pos2 is not None:
# convert to fourier space
field2_fft = get_field_fft(pos2, Lbox, nmesh, paste, w2, W, compensated, interlaced, nthread=nthread)
field2_fft = get_field_fft(pos2, Lbox, nmesh, paste, w2, W, compensated, interlaced, nthread=nthread, dtype=dtype)
else:
field2_fft = None

Expand All @@ -1077,6 +1085,7 @@ def calc_power(pos,
w2 = None,
poles = None,
nthread = MAX_THREADS,
dtype = np.float32,
):
r"""
Compute the 3D power spectrum given particle positions by first painting them on a
Expand Down Expand Up @@ -1119,6 +1128,8 @@ def calc_power(pos,
Default of None gives the monopole.
nthread : int, optional
Number of numba threads to use
dtype : np.dtype, optional
Data type of the field
Returns
-------
Expand Down Expand Up @@ -1156,7 +1167,7 @@ def calc_power(pos,


# calculate Fourier 3D field
field_fft, field2_fft = calc_field(pos, Lbox, paste=paste, nmesh=nmesh, compensated=compensated, interlaced=interlaced, w=w, pos2=pos2, w2=w2, nthread=nthread)
field_fft, field2_fft = calc_field(pos, Lbox, paste=paste, nmesh=nmesh, compensated=compensated, interlaced=interlaced, w=w, pos2=pos2, w2=w2, nthread=nthread, dtype=dtype)

poles = np.asarray(poles or [], dtype=np.int64)

Expand Down

0 comments on commit 45a6dc1

Please sign in to comment.