diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index ab6b7e97..a4133ee7 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -13,7 +13,7 @@ exclude: | repos: - repo: https://github.com/astral-sh/ruff-pre-commit - rev: "v0.0.287" + rev: "v0.0.292" hooks: - id: ruff # TODO: turning off line length check diff --git a/abacusnbody/analysis/power_spectrum.py b/abacusnbody/analysis/power_spectrum.py index 6add22ca..614aec10 100644 --- a/abacusnbody/analysis/power_spectrum.py +++ b/abacusnbody/analysis/power_spectrum.py @@ -707,7 +707,7 @@ def calc_pk_from_deltak(field_fft, Lbox, k_bin_edges, mu_bin_edges, field2_fft=N return binned_pk, Nmode, binned_poles, N_mode_poles -def get_field(pos, Lbox, nmesh, paste, w=None, d=0., nthread=MAX_THREADS): +def get_field(pos, Lbox, nmesh, paste, w=None, d=0., nthread=MAX_THREADS, dtype=np.float32): r""" Construct real-space 3D field given particle positions. @@ -727,6 +727,8 @@ def get_field(pos, Lbox, nmesh, paste, w=None, d=0., nthread=MAX_THREADS): uniform shift to particle positions. nthread : int, optional Number of numba threads to use + dtype : np.dtype, optional + Data type of the field Returns ------- @@ -736,21 +738,19 @@ def get_field(pos, Lbox, nmesh, paste, w=None, d=0., nthread=MAX_THREADS): # check if weights are requested if w is not None: assert pos.shape[0] == len(w) - pos = pos.astype(np.float32, copy=False) + field = np.zeros((nmesh, nmesh, nmesh), dtype=dtype) paste = paste.upper() if paste == 'TSC': - if d != 0.: - field = tsc_parallel(pos, nmesh, Lbox, weights=w, nthread=nthread, offset=d) - else: - field = tsc_parallel(pos, nmesh, Lbox, weights=w, nthread=nthread) + tsc_parallel(pos, field, Lbox, weights=w, nthread=nthread, offset=d) elif paste == 'CIC': - field = np.zeros((nmesh, nmesh, nmesh), dtype=np.float32) warnings.warn("Note that currently CIC pasting, unlike TSC, supports only a non-parallel implementation.") if d != 0.: cic_serial(pos + d, field, Lbox, weights=w) else: cic_serial(pos, field, Lbox, weights=w) + else: + raise ValueError(f"Unknown pasting method: {paste}") if w is None: # in the zcv code the weights are already normalized, so don't normalize here # TODO assuming normalized weights is fragile # same as passing "Value" to nbodykit (1+delta)(x) V(x) @@ -895,7 +895,7 @@ def get_interlaced_field_fft(pos, Lbox, nmesh, paste, w, nthread=MAX_THREADS, ve return field_fft -def get_field_fft(pos, Lbox, nmesh, paste, w, W, compensated, interlaced, nthread=MAX_THREADS, verbose=False): +def get_field_fft(pos, Lbox, nmesh, paste, w, W, compensated, interlaced, nthread=MAX_THREADS, verbose=False, dtype=np.float32): r""" Calculate field from particle positions and return 3D Fourier field. @@ -917,6 +917,10 @@ def get_field_fft(pos, Lbox, nmesh, paste, w, W, compensated, interlaced, nthrea want to apply interlacing? nthread : int, optional Number of numba threads to use + verbose : bool, optional + Print out debugging info + dtype : np.dtype, optional + Data type of the field Returns ------- @@ -924,19 +928,17 @@ def get_field_fft(pos, Lbox, nmesh, paste, w, W, compensated, interlaced, nthrea interlaced 3D Fourier field. """ - # get field in real space - field = get_field(pos, Lbox, nmesh, paste, w, nthread=nthread) - if verbose: - print("field, pos", field.dtype, pos.dtype) if interlaced: # get interlaced field field_fft = get_interlaced_field_fft(pos, Lbox, nmesh, paste, w, nthread=nthread) else: # get field in real space - field = get_field(pos, Lbox, nmesh, paste, w, nthread=nthread) + field = get_field(pos, Lbox, nmesh, paste, w, nthread=nthread, dtype=dtype) + if verbose: + print("field, pos", field.dtype, pos.dtype) # get Fourier modes from skewers grid - inv_size = np.float32(1 / field.size) + inv_size = dtype(1 / field.size) field_fft = rfftn(field, overwrite_x=True, workers=nthread) _normalize(field_fft, inv_size, nthread=nthread) @@ -986,11 +988,14 @@ def get_W_compensated(Lbox, nmesh, paste, interlaced): k = (fftfreq(nmesh, d=d) * 2. * np.pi).astype(np.float32) # h/Mpc # apply deconvolution + paste = paste.upper() if interlaced: if paste == 'TSC': p = 3. elif paste == 'CIC': p = 2. + else: + raise ValueError(f"Unknown pasting method {paste}") W = np.sinc(0.5*k/kN)**p # sinc def else: # first order correction of interlacing (aka aliasing) s = np.sin(0.5 * np.pi * k/kN)**2 @@ -1012,6 +1017,7 @@ def calc_field(pos, pos2 = None, w2 = None, nthread = MAX_THREADS, + dtype = np.float32, ): r""" Compute the 3D Fourier field(s) for some array(s) of positions. @@ -1036,6 +1042,8 @@ def calc_field(pos, second set of particle positions, shape (N,3) nthread : int, optional Number of numba threads to use + dtype : np.dtype, optional + Data type of the field Returns ------- @@ -1050,12 +1058,12 @@ def calc_field(pos, W = None # convert to fourier space - field_fft = get_field_fft(pos, Lbox, nmesh, paste, w, W, compensated, interlaced, nthread=nthread) + field_fft = get_field_fft(pos, Lbox, nmesh, paste, w, W, compensated, interlaced, nthread=nthread, dtype=dtype) # if second field provided if pos2 is not None: # convert to fourier space - field2_fft = get_field_fft(pos2, Lbox, nmesh, paste, w2, W, compensated, interlaced, nthread=nthread) + field2_fft = get_field_fft(pos2, Lbox, nmesh, paste, w2, W, compensated, interlaced, nthread=nthread, dtype=dtype) else: field2_fft = None @@ -1077,6 +1085,7 @@ def calc_power(pos, w2 = None, poles = None, nthread = MAX_THREADS, + dtype = np.float32, ): r""" Compute the 3D power spectrum given particle positions by first painting them on a @@ -1119,6 +1128,8 @@ def calc_power(pos, Default of None gives the monopole. nthread : int, optional Number of numba threads to use + dtype : np.dtype, optional + Data type of the field Returns ------- @@ -1156,7 +1167,7 @@ def calc_power(pos, # calculate Fourier 3D field - field_fft, field2_fft = calc_field(pos, Lbox, paste=paste, nmesh=nmesh, compensated=compensated, interlaced=interlaced, w=w, pos2=pos2, w2=w2, nthread=nthread) + field_fft, field2_fft = calc_field(pos, Lbox, paste=paste, nmesh=nmesh, compensated=compensated, interlaced=interlaced, w=w, pos2=pos2, w2=w2, nthread=nthread, dtype=dtype) poles = np.asarray(poles or [], dtype=np.int64)