diff --git a/vip_hci/greedy/ipca_fullfr.py b/vip_hci/greedy/ipca_fullfr.py index e0fe3899..694a3f2e 100644 --- a/vip_hci/greedy/ipca_fullfr.py +++ b/vip_hci/greedy/ipca_fullfr.py @@ -34,6 +34,13 @@ from ..metrics import stim_map, inverse_stim_map from ..var import prepare_matrix, mask_circle, frame_filter_lowpass +try: + from GreeDS import GreeDS + no_greeds = False +except ImportError: + msg = "GreeDS python bindings are missing." + warnings.warn(msg, ImportWarning) + no_greeds = True @dataclass class IPCA_Params(PCA_Params): @@ -91,7 +98,7 @@ def ipca(*all_args: List, **all_kwargs: dict): shortest wavelength in the cube (more thorough approaches can be used to get the scaling factors). This scaling factors are used to re-scale the spectral channels and align the speckles. - mode: str or None, opt {'Pairet18', 'Pairet21'} + mode: str or None, opt {'Pairet18', 'Pairet21','Juillard23'} - If None: runs with provided value of 'n_comp' for 'nit' iterations, and considering threshold 'thr'. - If 'Pairet18': runs for n_comp iterations, with n_comp=1,...,n_comp, @@ -101,6 +108,12 @@ def ipca(*all_args: List, **all_kwargs: dict): n_comp (i.e. outer loop on n_comp, inner loop on nit). thr set to 0 (ignored if provided). 'thr' parameter discarded, always set to 0. - If 'Christiaens21': same as 'Pairet21', but with 'thr' parameter used. + - If 'Juillard23': Exact implementation from Juillard et al. 2023 using + Torch. Parameter conventions are the same as in 'Christiaens21' and + 'Pairet21'. This method has no additional options for significant + signal extraction and works exclusively with Torch, making it faster + but also more prone to propagate noise and disk flux. Installation of + the GreeDS package is required for this option. ncomp : int or tuple/list of int, optional How many PCs are used as a lower-dimensional subspace to project the target frames. @@ -121,7 +134,7 @@ def ipca(*all_args: List, **all_kwargs: dict): total number of iterations - if mode is 'Pairet18': this parameter is ignored. Number of iterations will be ncomp. - - if mode is 'Pairet21' or 'Christiaens21': + - if mode is 'Pairet21', 'Christiaens21' or 'Juillard23': iterations per tested ncomp. strategy: str {'ADI, 'RDI', 'ARDI', 'RADI'}, opt Whether to do iterative ADI only ('ADI'), iterative RDI only ('RDI'), @@ -351,269 +364,344 @@ def _blurring_3d(array, mask_center_sz, fwhm_sz=2): # force full_output pca_params['full_output'] = True pca_params['verbose'] = False # too verbose otherwise + r = algo_params.ncomp + l = algo_params.nit + + if algo_params.mode=="Juillard23": + if no_greeds: + msg = 'GreeDS Python bindings cannot be imported. Install GreeDS (pip install GreeDS) or use a different method.' + raise RuntimeError(msg) + if algo_params.strategy not in ['ADI',"ARDI"]: + msg = 'Juillard23 not compatible with this mode.' + raise RuntimeError(msg) + + if algo_params.strategy == 'ARDI': + ref = algo_params.cube_ref.copy() + else : ref = None + + mask_center_px = algo_params.mask_center_px + pup = mask_center_px if mask_center_px is not None else 0 + + if algo_params.full_output is True : + it_cube, star_estim = GreeDS(algo_params.cube, algo_params.angle_list, refs=ref,r=r, l=l, r_start=1, pup=pup, full_output=1, returnL=True) + else : it_cube = GreeDS(algo_params.cube, algo_params.angle_list, refs=ref,r=r, l=l, r_start=1, pup=pup, full_output=1, returnL=False) + frame = it_cube[-1] + + # Set results matching full outputs + it=len(it_cube)-1 + algo_params.thr = 0 + + stim_cube = it_cube_nd = sig_images = nstim = sig_mask = np.zeros(it_cube.shape) + + if algo_params.full_output is True : + print(algo_params.cube.shape) + print(star_estim.shape) + residuals_cube_ = cube_derotate(algo_params.cube - star_estim[-1], + algo_params.angle_list, + imlib="torch-fft", + nproc=algo_params.nproc) - frame + + residuals_cube = cube_derotate(residuals_cube_, + -algo_params.angle_list, + imlib="torch-fft", + nproc=algo_params.nproc) + + if algo_params.thr_mode == 'STIM': + for it_i in range(len(it_cube)): + residuals_cube__i = cube_derotate(algo_params.cube - star_estim[it_i], + algo_params.angle_list, + imlib="torch-fft", + nproc=algo_params.nproc) - it_cube[it_i] + + residuals_cube_i = cube_derotate(residuals_cube__i, + -algo_params.angle_list, + imlib="torch-fft", + nproc=algo_params.nproc) + + sig_mask_i, nstim_i = _find_significant_signals(residuals_cube_i, + residuals_cube__i, + algo_params.angle_list, + algo_params.thr, + mask=mask_center_px, + r_out=algo_params.r_out) + + sig_mask[it_i] = sig_mask_i.copy() + nstim[it_i] = nstim_i.copy() - # 1. Prepare/format additional parameters depending on chosen options - mask_center_px = algo_params.mask_center_px # None? what's better in pca? - mask_rdi_tmp = None - if algo_params.strategy == 'ADI' and algo_params.cube_ref is None: - ref_cube = None - mask_rdi_tmp = algo_params.mask_rdi - elif algo_params.cube_ref is not None: - if algo_params.strategy == 'ADI': - msg = "WARNING: requested strategy is 'ADI' but reference cube " - msg += "detected! Strategy automatically switched to 'ARDI'." - print(msg) - algo_params.strategy = 'ARDI' - if algo_params.mask_rdi is not None: - if isinstance(algo_params.mask_rdi, (list, tuple)): - mask_rdi_tmp = algo_params.mask_rdi else: - mask_rdi_tmp = algo_params.mask_rdi.copy() - if algo_params.cube_ref is None: - raise ValueError("cube_ref should be provided for RDI or RADI") - if algo_params.strategy == 'ARDI' and algo_params.mask_rdi is None: - ref_cube = np.concatenate((algo_params.cube, - algo_params.cube_ref), axis=0) - else: - ref_cube = algo_params.cube_ref.copy() + sig_mask = np.ones_like(it_cube) + sig_mask[np.where(it_cube < algo_params.thr)] = 0 + nstim = sig_mask.copy() + + sig_images = it_cube.copy() + sig_images[np.where(1-sig_mask)] = 0 + sig_images[np.where(sig_images < 0)] = 0 + stim_cube = nstim.copy() + else: - msg = "strategy not recognized: must be ADI, RDI, ARDI or RADI" - raise ValueError(msg) - - if isinstance(algo_params.ncomp, (float, int)): - ncomp_list = [algo_params.ncomp] - if algo_params.strategy == 'RADI': - ncomp_list.append(algo_params.ncomp) - elif isinstance(algo_params.ncomp, (tuple, list)): - ncomp_list = algo_params.ncomp - if len(algo_params.ncomp) == 1: + + # 1. Prepare/format additional parameters depending on chosen options + mask_center_px = algo_params.mask_center_px # None? what's better in pca? + mask_rdi_tmp = None + if algo_params.strategy == 'ADI' and algo_params.cube_ref is None: + ref_cube = None + mask_rdi_tmp = algo_params.mask_rdi + elif algo_params.cube_ref is not None: + if algo_params.strategy == 'ADI': + msg = "WARNING: requested strategy is 'ADI' but reference cube " + msg += "detected! Strategy automatically switched to 'ARDI'." + print(msg) + algo_params.strategy = 'ARDI' + if algo_params.mask_rdi is not None: + if isinstance(algo_params.mask_rdi, (list, tuple)): + mask_rdi_tmp = algo_params.mask_rdi + else: + mask_rdi_tmp = algo_params.mask_rdi.copy() + if algo_params.cube_ref is None: + raise ValueError("cube_ref should be provided for RDI or RADI") + if algo_params.strategy == 'ARDI' and algo_params.mask_rdi is None: + ref_cube = np.concatenate((algo_params.cube, + algo_params.cube_ref), axis=0) + else: + ref_cube = algo_params.cube_ref.copy() + else: + msg = "strategy not recognized: must be ADI, RDI, ARDI or RADI" + raise ValueError(msg) + + if isinstance(algo_params.ncomp, (float, int)): + ncomp_list = [algo_params.ncomp] if algo_params.strategy == 'RADI': ncomp_list.append(algo_params.ncomp) - elif not len(algo_params.ncomp) == 2: - raise ValueError("Length of npc list cannot be larger than 2") - else: - raise TypeError("ncomp should be float, int, tuple or list") + elif isinstance(algo_params.ncomp, (tuple, list)): + ncomp_list = algo_params.ncomp + if len(algo_params.ncomp) == 1: + if algo_params.strategy == 'RADI': + ncomp_list.append(algo_params.ncomp) + elif not len(algo_params.ncomp) == 2: + raise ValueError("Length of npc list cannot be larger than 2") + else: + raise TypeError("ncomp should be float, int, tuple or list") - ncomp_tmp = ncomp_list[0] - nframes = algo_params.cube.shape[0] - nit_ori = algo_params.nit + ncomp_tmp = ncomp_list[0] + nframes = algo_params.cube.shape[0] + nit_ori = algo_params.nit - if algo_params.mode is not None: - final_ncomp = list(range(1, ncomp_tmp+1, algo_params.ncomp_step)) - if algo_params.mode == 'Pairet18': - algo_params.nit = ncomp_tmp + if algo_params.mode is not None: final_ncomp = list(range(1, ncomp_tmp+1, algo_params.ncomp_step)) - algo_params.thr = 0 - elif algo_params.mode in ['Pairet21', 'Christiaens21']: - final_ncomp = [] - for npc in range(1, ncomp_tmp+1, algo_params.ncomp_step): - for ii in range(algo_params.nit): - final_ncomp.append(npc) - algo_params.nit = len(final_ncomp) - if algo_params.mode == 'Pairet21': + if algo_params.mode == 'Pairet18': + algo_params.nit = ncomp_tmp + final_ncomp = list(range(1, ncomp_tmp+1, algo_params.ncomp_step)) algo_params.thr = 0 - else: - final_ncomp = [ncomp_tmp]*algo_params.nit - - # Scale cube and cube_ref if necessary - cube_tmp = prepare_matrix(algo_params.cube, scaling=algo_params.scaling, - mask_center_px=mask_center_px, mode='fullfr', - verbose=False) - cube_tmp = np.reshape(cube_tmp, algo_params.cube.shape) - if ref_cube is not None: - cube_ref_tmp = prepare_matrix(ref_cube, scaling=algo_params.scaling, - mask_center_px=mask_center_px, - mode='fullfr', verbose=False) - cube_ref_tmp = np.reshape(cube_ref_tmp, ref_cube.shape) - else: - cube_ref_tmp = None - - # 2. Get a first disc estimate, using PCA - pca_params['ncomp'] = final_ncomp[0] - pca_params['cube_ref'] = ref_cube - res = pca(**pca_params, **rot_options) - frame = res[0] - residuals_cube = res[-2] - residuals_cube_ = res[-1] - # smoothing and manual derotation if requested - smooth_ker = algo_params.smooth_ker - if smooth_ker is None or np.isscalar(smooth_ker): - smooth_ker = np.array([smooth_ker]*algo_params.nit, dtype=object) - else: - smooth_ker = np.array(smooth_ker, dtype=object) - # if smooth_ker[0] is not None: - # residuals_cube = _blurring_3d(residuals_cube, None, - # fwhm_sz=smooth_ker[0]) - # residuals_cube_ = cube_derotate(residuals_cube, - # algo_params.angle_list, - # imlib=algo_params.imlib, - # nproc=algo_params.nproc) - # frame = cube_collapse(residuals_cube_, algo_params.collapse) - if smooth_ker[0] is not None: - frame = frame_filter_lowpass(frame, fwhm_size=smooth_ker[0]) - - # 3. Identify significant signals with STIM map - it_cube = np.zeros([algo_params.nit, frame.shape[0], frame.shape[1]]) - it_cube_nd = np.zeros_like(it_cube) - stim_cube = np.zeros_like(it_cube) - sig_images = np.zeros_like(it_cube) - it_cube[0] = frame.copy() - it_cube_nd[0] = frame.copy() - if algo_params.thr_mode == 'STIM': - sig_mask, nstim = _find_significant_signals(residuals_cube, - residuals_cube_, + elif algo_params.mode in ['Pairet21', 'Christiaens21']: + final_ncomp = [] + for npc in range(1, ncomp_tmp+1, algo_params.ncomp_step): + for ii in range(algo_params.nit): + final_ncomp.append(npc) + algo_params.nit = len(final_ncomp) + if algo_params.mode == 'Pairet21': + algo_params.thr = 0 + else: + final_ncomp = [ncomp_tmp]*algo_params.nit + + # Scale cube and cube_ref if necessary + cube_tmp = prepare_matrix(algo_params.cube, scaling=algo_params.scaling, + mask_center_px=mask_center_px, mode='fullfr', + verbose=False) + cube_tmp = np.reshape(cube_tmp, algo_params.cube.shape) + if ref_cube is not None: + cube_ref_tmp = prepare_matrix(ref_cube, scaling=algo_params.scaling, + mask_center_px=mask_center_px, + mode='fullfr', verbose=False) + cube_ref_tmp = np.reshape(cube_ref_tmp, ref_cube.shape) + else: + cube_ref_tmp = None + + # 2. Get a first disc estimate, using PCA + pca_params['ncomp'] = final_ncomp[0] + pca_params['cube_ref'] = ref_cube + res = pca(**pca_params, **rot_options) + frame = res[0] + residuals_cube = res[-2] + residuals_cube_ = res[-1] + # smoothing and manual derotation if requested + smooth_ker = algo_params.smooth_ker + if smooth_ker is None or np.isscalar(smooth_ker): + smooth_ker = np.array([smooth_ker]*algo_params.nit, dtype=object) + else: + smooth_ker = np.array(smooth_ker, dtype=object) + # if smooth_ker[0] is not None: + # residuals_cube = _blurring_3d(residuals_cube, None, + # fwhm_sz=smooth_ker[0]) + # residuals_cube_ = cube_derotate(residuals_cube, + # algo_params.angle_list, + # imlib=algo_params.imlib, + # nproc=algo_params.nproc) + # frame = cube_collapse(residuals_cube_, algo_params.collapse) + if smooth_ker[0] is not None: + frame = frame_filter_lowpass(frame, fwhm_size=smooth_ker[0]) + + # 3. Identify significant signals with STIM map + it_cube = np.zeros([algo_params.nit, frame.shape[0], frame.shape[1]]) + it_cube_nd = np.zeros_like(it_cube) + stim_cube = np.zeros_like(it_cube) + sig_images = np.zeros_like(it_cube) + it_cube[0] = frame.copy() + it_cube_nd[0] = frame.copy() + if algo_params.thr_mode == 'STIM': + sig_mask, nstim = _find_significant_signals(residuals_cube, + residuals_cube_, + algo_params.angle_list, + algo_params.thr, + mask=mask_center_px, + r_out=algo_params.r_out) + else: + sig_mask = np.ones_like(frame) + sig_mask[np.where(frame < algo_params.thr)] = 0 + nstim = sig_mask.copy() + sig_image = frame.copy() + sig_image[np.where(1-sig_mask)] = 0 + sig_image[np.where(sig_image < 0)] = 0 + sig_images[0] = sig_image.copy() + stim_cube[0] = nstim.copy() + mask_rdi_tmp = None # after first iteration do not use it any more + + # 4.Loop, updating the reference cube before projection by subtracting the + # best disc estimate. This is done by providing sig_cube. + cond_skip = False # whether to skip an iteration (e.g. in incremental mode) + for it in Progressbar(range(1, algo_params.nit), desc="Iterating..."): + if not cond_skip: + # Uncomment here (and comment below) to do like IROLL + # if smooth_ker[it] is not None: + # frame = _blurring_2d(frame, None, fwhm_sz=smooth_ker[it]) + # create and rotate sig cube + sig_cube = np.repeat(frame[np.newaxis, :, :], nframes, axis=0) + sig_cube = cube_derotate(sig_cube, -algo_params.angle_list, + imlib=algo_params.imlib, + nproc=algo_params.nproc) + + if algo_params.thr_mode == 'STIM': + # create and rotate binary mask + mask_sig = np.zeros_like(sig_image) + mask_sig[np.where(sig_image > 0)] = 1 + sig_mcube = np.repeat(mask_sig[np.newaxis, :, :], nframes, + axis=0) + sig_mcube = cube_derotate(sig_mcube, -algo_params.angle_list, + imlib='skimage', + interpolation='bilinear', + nproc=algo_params.nproc) + sig_cube[np.where(sig_mcube < 0.5)] = 0 + sig_cube[np.where(sig_cube < 0)] = 0 + else: + sig_cube[np.where(sig_cube < algo_params.thr)] = 0 + + if algo_params.strategy == 'ARDI': + ref_cube = np.concatenate((algo_params.cube-sig_cube, + algo_params.cube_ref), axis=0) + cube_ref_tmp = prepare_matrix(ref_cube, + scaling=algo_params.scaling, + mask_center_px=mask_center_px, + mode='fullfr', verbose=False) + cube_ref_tmp = np.reshape(cube_ref_tmp, ref_cube.shape) + + # Run PCA on original cube + # Update PCA PARAMS + pca_params['cube'] = algo_params.cube + pca_params['cube_ref'] = ref_cube + pca_params['ncomp'] = final_ncomp[it] + pca_params['scaling'] = algo_params.scaling + pca_params['cube_sig'] = sig_cube + pca_params['mask_rdi'] = mask_rdi_tmp + + res = pca(**pca_params, **rot_options) + + frame = res[0] + residuals_cube = res[-2] + it_cube[it] = frame.copy() + + # DON'T otherwise frame is smoothed twice! + # smoothing and manual derotation if requested + if smooth_ker[it] is not None: + residuals_cube = _blurring_3d(residuals_cube, None, + fwhm_sz=smooth_ker[it]) + residuals_cube_ = cube_derotate(residuals_cube, algo_params.angle_list, - algo_params.thr, - mask=mask_center_px, - r_out=algo_params.r_out) - else: - sig_mask = np.ones_like(frame) - sig_mask[np.where(frame < algo_params.thr)] = 0 - nstim = sig_mask.copy() - sig_image = frame.copy() - sig_image[np.where(1-sig_mask)] = 0 - sig_image[np.where(sig_image < 0)] = 0 - sig_images[0] = sig_image.copy() - stim_cube[0] = nstim.copy() - mask_rdi_tmp = None # after first iteration do not use it any more - - # 4.Loop, updating the reference cube before projection by subtracting the - # best disc estimate. This is done by providing sig_cube. - cond_skip = False # whether to skip an iteration (e.g. in incremental mode) - for it in Progressbar(range(1, algo_params.nit), desc="Iterating..."): - if not cond_skip: - # Uncomment here (and comment below) to do like IROLL - # if smooth_ker[it] is not None: - # frame = _blurring_2d(frame, None, fwhm_sz=smooth_ker[it]) - # create and rotate sig cube - sig_cube = np.repeat(frame[np.newaxis, :, :], nframes, axis=0) - sig_cube = cube_derotate(sig_cube, -algo_params.angle_list, - imlib=algo_params.imlib, - nproc=algo_params.nproc) - - if algo_params.thr_mode == 'STIM': - # create and rotate binary mask - mask_sig = np.zeros_like(sig_image) - mask_sig[np.where(sig_image > 0)] = 1 - sig_mcube = np.repeat(mask_sig[np.newaxis, :, :], nframes, - axis=0) - sig_mcube = cube_derotate(sig_mcube, -algo_params.angle_list, - imlib='skimage', - interpolation='bilinear', - nproc=algo_params.nproc) - sig_cube[np.where(sig_mcube < 0.5)] = 0 - sig_cube[np.where(sig_cube < 0)] = 0 - else: - sig_cube[np.where(sig_cube < algo_params.thr)] = 0 - - if algo_params.strategy == 'ARDI': - ref_cube = np.concatenate((algo_params.cube-sig_cube, - algo_params.cube_ref), axis=0) - cube_ref_tmp = prepare_matrix(ref_cube, - scaling=algo_params.scaling, - mask_center_px=mask_center_px, - mode='fullfr', verbose=False) - cube_ref_tmp = np.reshape(cube_ref_tmp, ref_cube.shape) - - # Run PCA on original cube - # Update PCA PARAMS - pca_params['cube'] = algo_params.cube - pca_params['cube_ref'] = ref_cube - pca_params['ncomp'] = final_ncomp[it] - pca_params['scaling'] = algo_params.scaling - pca_params['cube_sig'] = sig_cube - pca_params['mask_rdi'] = mask_rdi_tmp - - res = pca(**pca_params, **rot_options) - - frame = res[0] - residuals_cube = res[-2] - it_cube[it] = frame.copy() - - # DON'T otherwise frame is smoothed twice! - # smoothing and manual derotation if requested - if smooth_ker[it] is not None: - residuals_cube = _blurring_3d(residuals_cube, None, - fwhm_sz=smooth_ker[it]) - residuals_cube_ = cube_derotate(residuals_cube, - algo_params.angle_list, - imlib=algo_params.imlib, - nproc=algo_params.nproc) - frame = cube_collapse(residuals_cube_, algo_params.collapse) - - # Run PCA on disk-empty cube - # Update PCA PARAMS - pca_params['cube'] = cube_tmp-sig_cube - pca_params['cube_ref'] = cube_ref_tmp - pca_params['cube_sig'] = None - pca_params['scaling'] = None - - res_nd = pca(**pca_params, **rot_options) - - residuals_cube_nd = res_nd[-2] - frame_nd = res_nd[0] - - if algo_params.thr_mode == 'STIM': - sig_mask, nstim = _find_significant_signals(residuals_cube_nd, - residuals_cube_, - algo_params.angle_list, - algo_params.thr, - mask=mask_center_px, - r_out=algo_params.r_out) - else: - sig_mask = np.ones_like(frame) - sig_mask[np.where(frame < algo_params.thr)] = 0 - nstim = sig_mask.copy() - inv_sig_mask = np.ones_like(sig_mask) - inv_sig_mask[np.where(sig_mask)] = 0 - if mask_center_px: - inv_sig_mask = mask_circle(inv_sig_mask, mask_center_px, - fillwith=1) - sig_image = frame.copy() - sig_image[np.where(inv_sig_mask)] = 0 - sig_image[np.where(sig_image < 0)] = 0 - - # whether skipped or not: - it_cube[it] = frame.copy() - it_cube_nd[it] = frame_nd.copy() - sig_images[it] = sig_image.copy() - stim_cube[it] = nstim.copy() - - # check if improvement compared to last iteration - if it > 1: - cond1 = np.allclose(sig_image, sig_images[it-1], - rtol=algo_params.rtol, atol=algo_params.atol) - cond2 = np.allclose(sig_image, sig_images[it-2], - rtol=algo_params.rtol, atol=algo_params.atol) - if cond1 or cond2: - # if convergence in incremental mode: skip iterations until the - # next increment in ncomp - cond_mode = algo_params.mode in ['Pairet21', 'Christiaens21'] - cond_it = (it % nit_ori != nit_ori-1) - if cond_mode and cond_it: - cond_skip = True + imlib=algo_params.imlib, + nproc=algo_params.nproc) + frame = cube_collapse(residuals_cube_, algo_params.collapse) + + # Run PCA on disk-empty cube + # Update PCA PARAMS + pca_params['cube'] = cube_tmp-sig_cube + pca_params['cube_ref'] = cube_ref_tmp + pca_params['cube_sig'] = None + pca_params['scaling'] = None + + res_nd = pca(**pca_params, **rot_options) + + residuals_cube_nd = res_nd[-2] + frame_nd = res_nd[0] + + if algo_params.thr_mode == 'STIM': + sig_mask, nstim = _find_significant_signals(residuals_cube_nd, + residuals_cube_, + algo_params.angle_list, + algo_params.thr, + mask=mask_center_px, + r_out=algo_params.r_out) else: - cond_skip = False - if algo_params.strategy in ['ADI', 'RDI', 'ARDI']: - msg = "Convergence criterion met after {} iterations" - condB = algo_params.continue_without_smooth_after_conv - if smooth_ker[it] is not None and condB: - smooth_ker[it+1:] = None - msg2 = "...Smoothing turned off and iterating more" - if algo_params.verbose: - print(msg.format(it)+msg2) - else: - if algo_params.verbose: - print("Final " + msg.format(it)) - break - if algo_params.strategy == 'RADI': - # continue to iterate with ADI - ncomp_tmp = ncomp_list[1] - algo_params.strategy = 'ADI' - ref_cube = None - if algo_params.verbose: - msg = "After {:.0f} iterations, PCA-RDI -> PCA-ADI." - print(msg.format(it)) + sig_mask = np.ones_like(frame) + sig_mask[np.where(frame < algo_params.thr)] = 0 + nstim = sig_mask.copy() + inv_sig_mask = np.ones_like(sig_mask) + inv_sig_mask[np.where(sig_mask)] = 0 + if mask_center_px: + inv_sig_mask = mask_circle(inv_sig_mask, mask_center_px, + fillwith=1) + sig_image = frame.copy() + sig_image[np.where(inv_sig_mask)] = 0 + sig_image[np.where(sig_image < 0)] = 0 + + # whether skipped or not: + it_cube[it] = frame.copy() + it_cube_nd[it] = frame_nd.copy() + sig_images[it] = sig_image.copy() + stim_cube[it] = nstim.copy() + + # check if improvement compared to last iteration + if it > 1: + cond1 = np.allclose(sig_image, sig_images[it-1], + rtol=algo_params.rtol, atol=algo_params.atol) + cond2 = np.allclose(sig_image, sig_images[it-2], + rtol=algo_params.rtol, atol=algo_params.atol) + if cond1 or cond2: + # if convergence in incremental mode: skip iterations until the + # next increment in ncomp + cond_mode = algo_params.mode in ['Pairet21', 'Christiaens21'] + cond_it = (it % nit_ori != nit_ori-1) + if cond_mode and cond_it: + cond_skip = True + else: + cond_skip = False + if algo_params.strategy in ['ADI', 'RDI', 'ARDI']: + msg = "Convergence criterion met after {} iterations" + condB = algo_params.continue_without_smooth_after_conv + if smooth_ker[it] is not None and condB: + smooth_ker[it+1:] = None + msg2 = "...Smoothing turned off and iterating more" + if algo_params.verbose: + print(msg.format(it)+msg2) + else: + if algo_params.verbose: + print("Final " + msg.format(it)) + break + if algo_params.strategy == 'RADI': + # continue to iterate with ADI + ncomp_tmp = ncomp_list[1] + algo_params.strategy = 'ADI' + ref_cube = None + if algo_params.verbose: + msg = "After {:.0f} iterations, PCA-RDI -> PCA-ADI." + print(msg.format(it)) # mask everything last if mask_center_px is not None: diff --git a/vip_hci/preproc/derotation.py b/vip_hci/preproc/derotation.py index ab605e7e..805445ca 100644 --- a/vip_hci/preproc/derotation.py +++ b/vip_hci/preproc/derotation.py @@ -39,6 +39,13 @@ warnings.warn(msg, ImportWarning) no_opencv = True +try: + import torch as torch + no_torch = False +except ImportError: + msg = "Pytorch python bindings are missing" + warnings.warn(msg, ImportWarning) + no_torch = True def frame_rotate(array, angle, imlib='vip-fft', interpolation='lanczos4', cxy=None, border_mode='constant', mask_val=np.nan, @@ -51,7 +58,7 @@ def frame_rotate(array, angle, imlib='vip-fft', interpolation='lanczos4', Input image, 2d array. angle : float Rotation angle. - imlib : {'opencv', 'skimage', 'vip-fft'}, str optional + imlib : {'opencv', 'skimage', 'vip-fft', 'torch-fft'}, str optional Library used for image transformations. Opencv is faster than skimage or 'vip-fft', but vip-fft slightly better preserves the flux in the image (followed by skimage with a biquintic interpolation). 'vip-fft' @@ -118,7 +125,7 @@ def frame_rotate(array, angle, imlib='vip-fft', interpolation='lanczos4', if edge_blend is None: edge_blend = '' - if edge_blend != '' or imlib == 'vip-fft': + if edge_blend != '' or imlib in ['vip-fft', 'torch-fft'] : # fill with nans cy_ori, cx_ori = frame_center(array) y_ori, x_ori = array.shape @@ -139,7 +146,7 @@ def frame_rotate(array, angle, imlib='vip-fft', interpolation='lanczos4', stdfunc=np.nanstd) # pad and interpolate, about 1.2x original size - if imlib == 'vip-fft': + if imlib in ['vip-fft', 'torch-fft'] : fac = 1.5 else: fac = 1.1 @@ -219,7 +226,7 @@ def frame_rotate(array, angle, imlib='vip-fft', interpolation='lanczos4', cy, cx = frame_center(array_prep) else: cx, cy = cxy - if imlib == 'vip-fft' and (cy, cx) != frame_center(array_prep): + if imlib in ['vip-fft', 'torch-fft'] and (cy, cx) != frame_center(array_prep): msg = "'vip-fft'imlib does not yet allow for custom center to be " msg += " provided " raise ValueError(msg) @@ -300,11 +307,18 @@ def frame_rotate(array, angle, imlib='vip-fft', interpolation='lanczos4', M = cv2.getRotationMatrix2D((cx, cy), angle, 1) array_out = cv2.warpAffine(array_prep.astype(np.float32), M, (x, y), flags=intp, borderMode=bormo) + elif imlib == 'torch-fft': + if no_torch: + msg = 'Pytorch bindings cannot be imported. Install torch or' + msg += ' set imlib to skimage' + raise RuntimeError(msg) + + array_out = (tensor_rotate_fft(torch.unsqueeze(torch.from_numpy(array_prep),0), angle)[0]).numpy() else: raise ValueError('Image transformation library not recognized') - if edge_blend != '' or imlib == 'vip-fft': + if edge_blend != '' or imlib in ['vip-fft', 'torch-fft'] : array_out = array_out[y0:y1, x0:x1] # remove padding array_out[mask_ori] = mask_val # mask again original masked values @@ -622,3 +636,93 @@ def _fft_shear(arr, arr_ori, c, ax, pad=0, shift_ini=True): s_x = fftshift(s_x) return s_x + + + +def tensor_rotate_fft(tensor: torch.Tensor, angle: float) -> torch.Tensor: + """ Rotates Tensor using Fourier transform phases: + Rotation = 3 consecutive lin. shears = 3 consecutive FFT phase shifts + See details in Larkin et al. (1997) and Hagelberg et al. (2016). + Note: this is significantly slower than interpolation methods + (e.g. opencv/lanczos4 or ndimage), but preserves the flux better + (by construction it preserves the total power). It is more prone to + large-scale Gibbs artefacts, so make sure no sharp edge is present in + the image to be rotated. + /!\ This is a blindly coded adaptation for Tensor of the vip function rotate_fft + (https://github.com/vortex-exoplanet/VIP/blob/51e1d734dcdbee1fbd0175aa3d0ab62eec83d5fa/vip_hci/preproc/derotation.py#L507) + /!\ This suppose the frame is perfectly centred + ! Warning: if input frame has even dimensions, the center of rotation + will NOT be between the 4 central pixels, instead it will be on the top + right of those 4 pixels. Make sure your images are centered with + respect to that pixel before rotation. + Parameters + ---------- + tensor : torch.Tensor + Input image, 2d array. + angle : float + Rotation angle. + Returns + ------- + array_out : torch.Tensor + Resulting frame. + """ + y_ori, x_ori = tensor.shape[1:] + + while angle < 0: + angle += 360 + while angle > 360: + angle -= 360 + + if angle > 45: + dangle = angle % 90 + if dangle > 45: + dangle = -(90 - dangle) + nangle = int(np.rint(angle / 90)) + tensor_in = torch.rot90(tensor, nangle, [1, 2]) + else: + dangle = angle + tensor_in = tensor.clone() + + if y_ori % 2 or x_ori % 2: + # NO NEED TO SHIFT BY 0.5px: FFT assumes rot. center on cx+0.5, cy+0.5! + tensor_in = tensor_in[:, :-1, :-1] + + a = np.tan(np.deg2rad(dangle) / 2).item() + b = -np.sin(np.deg2rad(dangle)).item() + + y_new, x_new = tensor_in.shape[1:] + arr_xy = torch.from_numpy(np.mgrid[0:y_new, 0:x_new]) + cy, cx = frame_center(tensor[0]) + arr_y = arr_xy[0] - cy + arr_x = arr_xy[1] - cx + + s_x = tensor_fft_shear(tensor_in, arr_x, a, ax=2) + s_xy = tensor_fft_shear(s_x, arr_y, b, ax=1) + s_xyx = tensor_fft_shear(s_xy, arr_x, a, ax=2) + + if y_ori % 2 or x_ori % 2: + # set it back to original dimensions + array_out = torch.zeros([1, s_xyx.shape[1]+1, s_xyx.shape[2]+1]) + array_out[0, :-1, :-1] = torch.real(s_xyx) + else: + array_out = torch.real(s_xyx) + + return array_out + + +def tensor_fft_shear(arr, arr_ori, c, ax): + ax2 = 1 - (ax-1) % 2 + freqs = torch.fft.fftfreq(arr_ori.shape[ax2], dtype=torch.float64) + sh_freqs = torch.fft.fftshift(freqs) + arr_u = torch.tile(sh_freqs, (arr_ori.shape[ax-1], 1)) + if ax == 2: + arr_u = torch.transpose(arr_u, 0, 1) + s_x = torch.fft.fftshift(arr) + s_x = torch.fft.fft(s_x, dim=ax) + s_x = torch.fft.fftshift(s_x) + s_x = torch.exp(-2j * torch.pi * c * arr_u * arr_ori) * s_x + s_x = torch.fft.fftshift(s_x) + s_x = torch.fft.ifft(s_x, dim=ax) + s_x = torch.fft.fftshift(s_x) + + return s_x