Skip to content

Commit

Permalink
Merge pull request #556 from IainHammond/master
Browse files Browse the repository at this point in the history
subtract radial profile and recentering multiprocessing
  • Loading branch information
VChristiaens authored Oct 21, 2022
2 parents 20f5b51 + 090ab5b commit aed258b
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 15 deletions.
6 changes: 3 additions & 3 deletions vip_hci/fm/fakecomp.py
Original file line number Diff line number Diff line change
Expand Up @@ -597,7 +597,7 @@ def normalize_psf(array, fwhm='fit', size=None, threshold=None, mask_core=None,
estimate the FWHM in 2D or 3D PSF arrays.
size : int or None, optional
If int it will correspond to the size of the centered sub-image to be
cropped form the PSF array. The PSF is assumed to be rougly centered wrt
cropped form the PSF array. The PSF is assumed to be roughly centered wrt
the array.
threshold : None or float, optional
Sets to zero values smaller than threshold (in the normalized image).
Expand All @@ -612,13 +612,13 @@ def normalize_psf(array, fwhm='fit', size=None, threshold=None, mask_core=None,
See the documentation of the ``vip_hci.preproc.frame_shift`` function.
interpolation : str, optional
See the documentation of the ``vip_hci.preproc.frame_shift`` function.
force_odd : str, optional
force_odd : bool, optional
If True the resulting array will have odd size (and the PSF will be
placed at its center). If False, and the frame size is even, then the
PSF will be put at the center of an even-sized frame.
correct_outliers: bool, optional
For an input 3D cube (IFS) of PSFs, if the 2D fit fails for one of the
channels, whether to interpolate fwhm value from surrounding channels,
channels, whether to interpolate FWHM value from surrounding channels,
and recalculate flux and normalization.
full_output : bool, optional
If True the flux in a FWHM aperture is returned along with the
Expand Down
21 changes: 14 additions & 7 deletions vip_hci/preproc/recentering.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,7 @@ def frame_shift(array, shift_y, shift_x, imlib='vip-fft',


def cube_shift(cube, shift_y, shift_x, imlib='vip-fft',
interpolation='lanczos4', border_mode='reflect', nproc=1):
interpolation='lanczos4', border_mode='reflect', nproc=None):
""" Shifts the X-Y coordinates of a cube or 3D array by x and y values.
Parameters
Expand Down Expand Up @@ -1094,7 +1094,7 @@ def cube_recenter_dft_upsampling(array, center_fr1=None, negative=False,
fwhm=4, subi_size=None, upsample_factor=100,
imlib='vip-fft', interpolation='lanczos4',
mask=None, border_mode='reflect',
full_output=False, verbose=True, nproc=1,
full_output=False, verbose=True, nproc=None,
save_shifts=False, debug=False, plot=True):
""" Recenters a cube of frames using the DFT upsampling method as proposed
in [GUI08]_ and implemented in the ``register_translation`` function from
Expand Down Expand Up @@ -1218,6 +1218,9 @@ def cube_recenter_dft_upsampling(array, center_fr1=None, negative=False,

# Finding the shifts with DFT upsampling of each frame wrt the first

if nproc is None:
nproc = cpu_count() // 2 # Hyper-threading doubles the # of cores

if nproc == 1:
for i in Progressbar(range(1, n_frames),
desc="frames", verbose=verbose):
Expand Down Expand Up @@ -1250,7 +1253,7 @@ def cube_recenter_dft_upsampling(array, center_fr1=None, negative=False,
x[:] += cx - x1
y[:] += cy - y1
array_rec = cube_shift(array, shift_y=y, shift_x=x, imlib=imlib,
interpolation=interpolation)
interpolation=interpolation, nproc=nproc)
if verbose:
msg = "Shift for first frame X,Y=({:.3f}, {:.3f})"
print(msg.format(x[0], y[0]))
Expand Down Expand Up @@ -1317,7 +1320,7 @@ def _shift_dft(array_rec, array, frnum, upsample_factor, mask, interpolation,


def cube_recenter_2dfit(array, xy=None, fwhm=4, subi_size=5, model='gauss',
nproc=1, imlib='vip-fft', interpolation='lanczos4',
nproc=None, imlib='vip-fft', interpolation='lanczos4',
offset=None, negative=False, threshold=False,
sigfactor=2, fix_neg=False, params_2g=None,
border_mode='reflect', save_shifts=False,
Expand Down Expand Up @@ -1663,7 +1666,7 @@ def cube_recenter_via_speckles(cube_sci, cube_ref=None, alignment_iter=5,
fit_type='gaus', negative=True, crop=True,
subframesize=21, mask=None, imlib='vip-fft',
interpolation='lanczos4', border_mode='reflect',
plot=True, full_output=False):
plot=True, full_output=False, nproc=None):
""" Registers frames based on the median speckle pattern. Optionally centers
based on the position of the vortex null in the median frame. Images are
filtered to isolate speckle spatial frequencies.
Expand Down Expand Up @@ -1722,6 +1725,7 @@ def cube_recenter_via_speckles(cube_sci, cube_ref=None, alignment_iter=5,
full_output: bool, optional
Whether to return more variables, useful for debugging.
Returns
-------
cube_reg_sci : numpy 3d ndarray
Expand All @@ -1744,6 +1748,9 @@ def cube_recenter_via_speckles(cube_sci, cube_ref=None, alignment_iter=5,
n, y, x = cube_sci.shape
check_array(cube_sci, dim=3)

if nproc is None:
nproc = cpu_count()//2

if recenter_median and fit_type not in {'gaus', 'ann'}:
raise TypeError("fit type not recognized. Should be 'ann' or 'gaus'")

Expand Down Expand Up @@ -1861,7 +1868,7 @@ def cube_recenter_via_speckles(cube_sci, cube_ref=None, alignment_iter=5,
subi_size=None, full_output=True,
verbose=False, plot=False,
mask=mask_tmp, imlib=imlib,
interpolation=interpolation)
interpolation=interpolation, nproc=nproc)
_, y_shift, x_shift = res
sqsum_shifts = np.sum(np.sqrt(y_shift ** 2 + x_shift ** 2))
print('Square sum of shift vecs: ' + str(sqsum_shifts))
Expand Down Expand Up @@ -1930,7 +1937,7 @@ def cube_recenter_via_speckles(cube_sci, cube_ref=None, alignment_iter=5,
def _fit_2dannulus(array, fwhm=4, crop=False, cent=None, cropsize=15,
hole_rad=0.5, sampl_cen=0.1, sampl_rad=None, ann_width=0.5,
unc_in=2.):
"""Finds the center the center of a donut-shape signal (e.g. a coronagraphic
"""Finds the center of a donut-shape signal (e.g. a coronagraphic
PSF) by fitting an annulus, using a grid of positions for the center and
radius of the annulus. The best fit is found by maximizing the mean flux
measured in the annular mask. Requires the image to be already roughly
Expand Down
26 changes: 21 additions & 5 deletions vip_hci/stats/im_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,12 @@
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
from ..var import frame_center
from ..var import frame_center, mask_circle
from ..config.utils_conf import check_array, vip_figsize


def frame_average_radprofile(frame, sep=1, init_rad=None, plot=True):
def frame_average_radprofile(frame, sep=1, init_rad=None, subtr_profile=False,
plot=True):
""" Calculates the average radial profile of an image.
Parameters
Expand All @@ -24,13 +25,21 @@ def frame_average_radprofile(frame, sep=1, init_rad=None, plot=True):
Input image or 2d array.
sep : int, optional
The average radial profile is recorded every ``sep`` pixels.
init_rad : int, optional
Initial radius in pixels from the center of the image to begin
calculating the average radial profile.
subtr_profile : boolean, optional
If True, the average radial profile is subtracted from the frame and
returned as a second output. Inner mask is applied if init_rad is provided.
plot : bool, optional
If True the profile is plotted.
If True, the profile is plotted.
Returns
-------
df : dataframe
Pandas dataframe with the radial profile and distances.
subtr_frame : numpy ndarray, 2d
[subtr_profile=True] Frame with the radial profile subtracted.
Note
----
Expand All @@ -44,7 +53,7 @@ def frame_average_radprofile(frame, sep=1, init_rad=None, plot=True):

if init_rad is None:
init_rad = 1
x, y = np.indices((frame.shape))
x, y = np.indices(frame.shape)
r = np.sqrt((x - cx) ** 2 + (y - cy) ** 2)
r = r.astype(int)
tbin = np.bincount(r.ravel(), frame.ravel())
Expand All @@ -66,7 +75,14 @@ def frame_average_radprofile(frame, sep=1, init_rad=None, plot=True):
plt.minorticks_on()
plt.xlim(0)

return df
if subtr_profile:
radprofile_img = radprofile[r]
subtr_frame = frame - radprofile_img
if init_rad > 1:
subtr_frame = mask_circle(subtr_frame, radius=init_rad)
return df, subtr_frame
else:
return df


def frame_histo_stats(image_array, plot=True):
Expand Down

0 comments on commit aed258b

Please sign in to comment.