From 69b8fee68ba8e27adc8eb4006f3d04569332714e Mon Sep 17 00:00:00 2001 From: Sand-jrd Date: Wed, 6 Nov 2024 08:15:32 +0100 Subject: [PATCH 1/2] Torch FFT Rotation & Juillard23 IPCA MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 1-Adding Torch for Faster FFT Rotation Mode Fft rotation using vip-fft mode use numpy library. The use of Torch instead of Numpy consistently improves speed. New imlib='torch-fft'. 2-Adding “Juillard23” Mode to IPCA The “Juillard23” mode has been added to the IPCA mode, enabling users to utilize the implementation from Juillard et al. (2023/2024). This method has no additional options for significant signal extraction and works exclusively with Torch, making it faster but also more prone to propagating noise and disk flux. The mode can return a stim map and residual map if full-output is set to true. This option is only compatible with 'ADI' and 'ARDI'; an error is raised if the options do not match. Other parameters, such as imbil and interpolation, are ignored. This option only uses nit, ncomp, pup_mask_center_px, cube_ref, cube, and angle_list. Installation of the GreeDS package (via pip install GreeDS) is required for this option. In particular, the threshold is always set to 0, and STIM is only computed at the end and is not part of the iterative process. --- vip_hci/greedy/ipca_fullfr.py | 596 +++++++++++++++++++--------------- vip_hci/preproc/derotation.py | 114 ++++++- 2 files changed, 451 insertions(+), 259 deletions(-) diff --git a/vip_hci/greedy/ipca_fullfr.py b/vip_hci/greedy/ipca_fullfr.py index e0fe3899..694a3f2e 100644 --- a/vip_hci/greedy/ipca_fullfr.py +++ b/vip_hci/greedy/ipca_fullfr.py @@ -34,6 +34,13 @@ from ..metrics import stim_map, inverse_stim_map from ..var import prepare_matrix, mask_circle, frame_filter_lowpass +try: + from GreeDS import GreeDS + no_greeds = False +except ImportError: + msg = "GreeDS python bindings are missing." + warnings.warn(msg, ImportWarning) + no_greeds = True @dataclass class IPCA_Params(PCA_Params): @@ -91,7 +98,7 @@ def ipca(*all_args: List, **all_kwargs: dict): shortest wavelength in the cube (more thorough approaches can be used to get the scaling factors). This scaling factors are used to re-scale the spectral channels and align the speckles. - mode: str or None, opt {'Pairet18', 'Pairet21'} + mode: str or None, opt {'Pairet18', 'Pairet21','Juillard23'} - If None: runs with provided value of 'n_comp' for 'nit' iterations, and considering threshold 'thr'. - If 'Pairet18': runs for n_comp iterations, with n_comp=1,...,n_comp, @@ -101,6 +108,12 @@ def ipca(*all_args: List, **all_kwargs: dict): n_comp (i.e. outer loop on n_comp, inner loop on nit). thr set to 0 (ignored if provided). 'thr' parameter discarded, always set to 0. - If 'Christiaens21': same as 'Pairet21', but with 'thr' parameter used. + - If 'Juillard23': Exact implementation from Juillard et al. 2023 using + Torch. Parameter conventions are the same as in 'Christiaens21' and + 'Pairet21'. This method has no additional options for significant + signal extraction and works exclusively with Torch, making it faster + but also more prone to propagate noise and disk flux. Installation of + the GreeDS package is required for this option. ncomp : int or tuple/list of int, optional How many PCs are used as a lower-dimensional subspace to project the target frames. @@ -121,7 +134,7 @@ def ipca(*all_args: List, **all_kwargs: dict): total number of iterations - if mode is 'Pairet18': this parameter is ignored. Number of iterations will be ncomp. - - if mode is 'Pairet21' or 'Christiaens21': + - if mode is 'Pairet21', 'Christiaens21' or 'Juillard23': iterations per tested ncomp. strategy: str {'ADI, 'RDI', 'ARDI', 'RADI'}, opt Whether to do iterative ADI only ('ADI'), iterative RDI only ('RDI'), @@ -351,269 +364,344 @@ def _blurring_3d(array, mask_center_sz, fwhm_sz=2): # force full_output pca_params['full_output'] = True pca_params['verbose'] = False # too verbose otherwise + r = algo_params.ncomp + l = algo_params.nit + + if algo_params.mode=="Juillard23": + if no_greeds: + msg = 'GreeDS Python bindings cannot be imported. Install GreeDS (pip install GreeDS) or use a different method.' + raise RuntimeError(msg) + if algo_params.strategy not in ['ADI',"ARDI"]: + msg = 'Juillard23 not compatible with this mode.' + raise RuntimeError(msg) + + if algo_params.strategy == 'ARDI': + ref = algo_params.cube_ref.copy() + else : ref = None + + mask_center_px = algo_params.mask_center_px + pup = mask_center_px if mask_center_px is not None else 0 + + if algo_params.full_output is True : + it_cube, star_estim = GreeDS(algo_params.cube, algo_params.angle_list, refs=ref,r=r, l=l, r_start=1, pup=pup, full_output=1, returnL=True) + else : it_cube = GreeDS(algo_params.cube, algo_params.angle_list, refs=ref,r=r, l=l, r_start=1, pup=pup, full_output=1, returnL=False) + frame = it_cube[-1] + + # Set results matching full outputs + it=len(it_cube)-1 + algo_params.thr = 0 + + stim_cube = it_cube_nd = sig_images = nstim = sig_mask = np.zeros(it_cube.shape) + + if algo_params.full_output is True : + print(algo_params.cube.shape) + print(star_estim.shape) + residuals_cube_ = cube_derotate(algo_params.cube - star_estim[-1], + algo_params.angle_list, + imlib="torch-fft", + nproc=algo_params.nproc) - frame + + residuals_cube = cube_derotate(residuals_cube_, + -algo_params.angle_list, + imlib="torch-fft", + nproc=algo_params.nproc) + + if algo_params.thr_mode == 'STIM': + for it_i in range(len(it_cube)): + residuals_cube__i = cube_derotate(algo_params.cube - star_estim[it_i], + algo_params.angle_list, + imlib="torch-fft", + nproc=algo_params.nproc) - it_cube[it_i] + + residuals_cube_i = cube_derotate(residuals_cube__i, + -algo_params.angle_list, + imlib="torch-fft", + nproc=algo_params.nproc) + + sig_mask_i, nstim_i = _find_significant_signals(residuals_cube_i, + residuals_cube__i, + algo_params.angle_list, + algo_params.thr, + mask=mask_center_px, + r_out=algo_params.r_out) + + sig_mask[it_i] = sig_mask_i.copy() + nstim[it_i] = nstim_i.copy() - # 1. Prepare/format additional parameters depending on chosen options - mask_center_px = algo_params.mask_center_px # None? what's better in pca? - mask_rdi_tmp = None - if algo_params.strategy == 'ADI' and algo_params.cube_ref is None: - ref_cube = None - mask_rdi_tmp = algo_params.mask_rdi - elif algo_params.cube_ref is not None: - if algo_params.strategy == 'ADI': - msg = "WARNING: requested strategy is 'ADI' but reference cube " - msg += "detected! Strategy automatically switched to 'ARDI'." - print(msg) - algo_params.strategy = 'ARDI' - if algo_params.mask_rdi is not None: - if isinstance(algo_params.mask_rdi, (list, tuple)): - mask_rdi_tmp = algo_params.mask_rdi else: - mask_rdi_tmp = algo_params.mask_rdi.copy() - if algo_params.cube_ref is None: - raise ValueError("cube_ref should be provided for RDI or RADI") - if algo_params.strategy == 'ARDI' and algo_params.mask_rdi is None: - ref_cube = np.concatenate((algo_params.cube, - algo_params.cube_ref), axis=0) - else: - ref_cube = algo_params.cube_ref.copy() + sig_mask = np.ones_like(it_cube) + sig_mask[np.where(it_cube < algo_params.thr)] = 0 + nstim = sig_mask.copy() + + sig_images = it_cube.copy() + sig_images[np.where(1-sig_mask)] = 0 + sig_images[np.where(sig_images < 0)] = 0 + stim_cube = nstim.copy() + else: - msg = "strategy not recognized: must be ADI, RDI, ARDI or RADI" - raise ValueError(msg) - - if isinstance(algo_params.ncomp, (float, int)): - ncomp_list = [algo_params.ncomp] - if algo_params.strategy == 'RADI': - ncomp_list.append(algo_params.ncomp) - elif isinstance(algo_params.ncomp, (tuple, list)): - ncomp_list = algo_params.ncomp - if len(algo_params.ncomp) == 1: + + # 1. Prepare/format additional parameters depending on chosen options + mask_center_px = algo_params.mask_center_px # None? what's better in pca? + mask_rdi_tmp = None + if algo_params.strategy == 'ADI' and algo_params.cube_ref is None: + ref_cube = None + mask_rdi_tmp = algo_params.mask_rdi + elif algo_params.cube_ref is not None: + if algo_params.strategy == 'ADI': + msg = "WARNING: requested strategy is 'ADI' but reference cube " + msg += "detected! Strategy automatically switched to 'ARDI'." + print(msg) + algo_params.strategy = 'ARDI' + if algo_params.mask_rdi is not None: + if isinstance(algo_params.mask_rdi, (list, tuple)): + mask_rdi_tmp = algo_params.mask_rdi + else: + mask_rdi_tmp = algo_params.mask_rdi.copy() + if algo_params.cube_ref is None: + raise ValueError("cube_ref should be provided for RDI or RADI") + if algo_params.strategy == 'ARDI' and algo_params.mask_rdi is None: + ref_cube = np.concatenate((algo_params.cube, + algo_params.cube_ref), axis=0) + else: + ref_cube = algo_params.cube_ref.copy() + else: + msg = "strategy not recognized: must be ADI, RDI, ARDI or RADI" + raise ValueError(msg) + + if isinstance(algo_params.ncomp, (float, int)): + ncomp_list = [algo_params.ncomp] if algo_params.strategy == 'RADI': ncomp_list.append(algo_params.ncomp) - elif not len(algo_params.ncomp) == 2: - raise ValueError("Length of npc list cannot be larger than 2") - else: - raise TypeError("ncomp should be float, int, tuple or list") + elif isinstance(algo_params.ncomp, (tuple, list)): + ncomp_list = algo_params.ncomp + if len(algo_params.ncomp) == 1: + if algo_params.strategy == 'RADI': + ncomp_list.append(algo_params.ncomp) + elif not len(algo_params.ncomp) == 2: + raise ValueError("Length of npc list cannot be larger than 2") + else: + raise TypeError("ncomp should be float, int, tuple or list") - ncomp_tmp = ncomp_list[0] - nframes = algo_params.cube.shape[0] - nit_ori = algo_params.nit + ncomp_tmp = ncomp_list[0] + nframes = algo_params.cube.shape[0] + nit_ori = algo_params.nit - if algo_params.mode is not None: - final_ncomp = list(range(1, ncomp_tmp+1, algo_params.ncomp_step)) - if algo_params.mode == 'Pairet18': - algo_params.nit = ncomp_tmp + if algo_params.mode is not None: final_ncomp = list(range(1, ncomp_tmp+1, algo_params.ncomp_step)) - algo_params.thr = 0 - elif algo_params.mode in ['Pairet21', 'Christiaens21']: - final_ncomp = [] - for npc in range(1, ncomp_tmp+1, algo_params.ncomp_step): - for ii in range(algo_params.nit): - final_ncomp.append(npc) - algo_params.nit = len(final_ncomp) - if algo_params.mode == 'Pairet21': + if algo_params.mode == 'Pairet18': + algo_params.nit = ncomp_tmp + final_ncomp = list(range(1, ncomp_tmp+1, algo_params.ncomp_step)) algo_params.thr = 0 - else: - final_ncomp = [ncomp_tmp]*algo_params.nit - - # Scale cube and cube_ref if necessary - cube_tmp = prepare_matrix(algo_params.cube, scaling=algo_params.scaling, - mask_center_px=mask_center_px, mode='fullfr', - verbose=False) - cube_tmp = np.reshape(cube_tmp, algo_params.cube.shape) - if ref_cube is not None: - cube_ref_tmp = prepare_matrix(ref_cube, scaling=algo_params.scaling, - mask_center_px=mask_center_px, - mode='fullfr', verbose=False) - cube_ref_tmp = np.reshape(cube_ref_tmp, ref_cube.shape) - else: - cube_ref_tmp = None - - # 2. Get a first disc estimate, using PCA - pca_params['ncomp'] = final_ncomp[0] - pca_params['cube_ref'] = ref_cube - res = pca(**pca_params, **rot_options) - frame = res[0] - residuals_cube = res[-2] - residuals_cube_ = res[-1] - # smoothing and manual derotation if requested - smooth_ker = algo_params.smooth_ker - if smooth_ker is None or np.isscalar(smooth_ker): - smooth_ker = np.array([smooth_ker]*algo_params.nit, dtype=object) - else: - smooth_ker = np.array(smooth_ker, dtype=object) - # if smooth_ker[0] is not None: - # residuals_cube = _blurring_3d(residuals_cube, None, - # fwhm_sz=smooth_ker[0]) - # residuals_cube_ = cube_derotate(residuals_cube, - # algo_params.angle_list, - # imlib=algo_params.imlib, - # nproc=algo_params.nproc) - # frame = cube_collapse(residuals_cube_, algo_params.collapse) - if smooth_ker[0] is not None: - frame = frame_filter_lowpass(frame, fwhm_size=smooth_ker[0]) - - # 3. Identify significant signals with STIM map - it_cube = np.zeros([algo_params.nit, frame.shape[0], frame.shape[1]]) - it_cube_nd = np.zeros_like(it_cube) - stim_cube = np.zeros_like(it_cube) - sig_images = np.zeros_like(it_cube) - it_cube[0] = frame.copy() - it_cube_nd[0] = frame.copy() - if algo_params.thr_mode == 'STIM': - sig_mask, nstim = _find_significant_signals(residuals_cube, - residuals_cube_, + elif algo_params.mode in ['Pairet21', 'Christiaens21']: + final_ncomp = [] + for npc in range(1, ncomp_tmp+1, algo_params.ncomp_step): + for ii in range(algo_params.nit): + final_ncomp.append(npc) + algo_params.nit = len(final_ncomp) + if algo_params.mode == 'Pairet21': + algo_params.thr = 0 + else: + final_ncomp = [ncomp_tmp]*algo_params.nit + + # Scale cube and cube_ref if necessary + cube_tmp = prepare_matrix(algo_params.cube, scaling=algo_params.scaling, + mask_center_px=mask_center_px, mode='fullfr', + verbose=False) + cube_tmp = np.reshape(cube_tmp, algo_params.cube.shape) + if ref_cube is not None: + cube_ref_tmp = prepare_matrix(ref_cube, scaling=algo_params.scaling, + mask_center_px=mask_center_px, + mode='fullfr', verbose=False) + cube_ref_tmp = np.reshape(cube_ref_tmp, ref_cube.shape) + else: + cube_ref_tmp = None + + # 2. Get a first disc estimate, using PCA + pca_params['ncomp'] = final_ncomp[0] + pca_params['cube_ref'] = ref_cube + res = pca(**pca_params, **rot_options) + frame = res[0] + residuals_cube = res[-2] + residuals_cube_ = res[-1] + # smoothing and manual derotation if requested + smooth_ker = algo_params.smooth_ker + if smooth_ker is None or np.isscalar(smooth_ker): + smooth_ker = np.array([smooth_ker]*algo_params.nit, dtype=object) + else: + smooth_ker = np.array(smooth_ker, dtype=object) + # if smooth_ker[0] is not None: + # residuals_cube = _blurring_3d(residuals_cube, None, + # fwhm_sz=smooth_ker[0]) + # residuals_cube_ = cube_derotate(residuals_cube, + # algo_params.angle_list, + # imlib=algo_params.imlib, + # nproc=algo_params.nproc) + # frame = cube_collapse(residuals_cube_, algo_params.collapse) + if smooth_ker[0] is not None: + frame = frame_filter_lowpass(frame, fwhm_size=smooth_ker[0]) + + # 3. Identify significant signals with STIM map + it_cube = np.zeros([algo_params.nit, frame.shape[0], frame.shape[1]]) + it_cube_nd = np.zeros_like(it_cube) + stim_cube = np.zeros_like(it_cube) + sig_images = np.zeros_like(it_cube) + it_cube[0] = frame.copy() + it_cube_nd[0] = frame.copy() + if algo_params.thr_mode == 'STIM': + sig_mask, nstim = _find_significant_signals(residuals_cube, + residuals_cube_, + algo_params.angle_list, + algo_params.thr, + mask=mask_center_px, + r_out=algo_params.r_out) + else: + sig_mask = np.ones_like(frame) + sig_mask[np.where(frame < algo_params.thr)] = 0 + nstim = sig_mask.copy() + sig_image = frame.copy() + sig_image[np.where(1-sig_mask)] = 0 + sig_image[np.where(sig_image < 0)] = 0 + sig_images[0] = sig_image.copy() + stim_cube[0] = nstim.copy() + mask_rdi_tmp = None # after first iteration do not use it any more + + # 4.Loop, updating the reference cube before projection by subtracting the + # best disc estimate. This is done by providing sig_cube. + cond_skip = False # whether to skip an iteration (e.g. in incremental mode) + for it in Progressbar(range(1, algo_params.nit), desc="Iterating..."): + if not cond_skip: + # Uncomment here (and comment below) to do like IROLL + # if smooth_ker[it] is not None: + # frame = _blurring_2d(frame, None, fwhm_sz=smooth_ker[it]) + # create and rotate sig cube + sig_cube = np.repeat(frame[np.newaxis, :, :], nframes, axis=0) + sig_cube = cube_derotate(sig_cube, -algo_params.angle_list, + imlib=algo_params.imlib, + nproc=algo_params.nproc) + + if algo_params.thr_mode == 'STIM': + # create and rotate binary mask + mask_sig = np.zeros_like(sig_image) + mask_sig[np.where(sig_image > 0)] = 1 + sig_mcube = np.repeat(mask_sig[np.newaxis, :, :], nframes, + axis=0) + sig_mcube = cube_derotate(sig_mcube, -algo_params.angle_list, + imlib='skimage', + interpolation='bilinear', + nproc=algo_params.nproc) + sig_cube[np.where(sig_mcube < 0.5)] = 0 + sig_cube[np.where(sig_cube < 0)] = 0 + else: + sig_cube[np.where(sig_cube < algo_params.thr)] = 0 + + if algo_params.strategy == 'ARDI': + ref_cube = np.concatenate((algo_params.cube-sig_cube, + algo_params.cube_ref), axis=0) + cube_ref_tmp = prepare_matrix(ref_cube, + scaling=algo_params.scaling, + mask_center_px=mask_center_px, + mode='fullfr', verbose=False) + cube_ref_tmp = np.reshape(cube_ref_tmp, ref_cube.shape) + + # Run PCA on original cube + # Update PCA PARAMS + pca_params['cube'] = algo_params.cube + pca_params['cube_ref'] = ref_cube + pca_params['ncomp'] = final_ncomp[it] + pca_params['scaling'] = algo_params.scaling + pca_params['cube_sig'] = sig_cube + pca_params['mask_rdi'] = mask_rdi_tmp + + res = pca(**pca_params, **rot_options) + + frame = res[0] + residuals_cube = res[-2] + it_cube[it] = frame.copy() + + # DON'T otherwise frame is smoothed twice! + # smoothing and manual derotation if requested + if smooth_ker[it] is not None: + residuals_cube = _blurring_3d(residuals_cube, None, + fwhm_sz=smooth_ker[it]) + residuals_cube_ = cube_derotate(residuals_cube, algo_params.angle_list, - algo_params.thr, - mask=mask_center_px, - r_out=algo_params.r_out) - else: - sig_mask = np.ones_like(frame) - sig_mask[np.where(frame < algo_params.thr)] = 0 - nstim = sig_mask.copy() - sig_image = frame.copy() - sig_image[np.where(1-sig_mask)] = 0 - sig_image[np.where(sig_image < 0)] = 0 - sig_images[0] = sig_image.copy() - stim_cube[0] = nstim.copy() - mask_rdi_tmp = None # after first iteration do not use it any more - - # 4.Loop, updating the reference cube before projection by subtracting the - # best disc estimate. This is done by providing sig_cube. - cond_skip = False # whether to skip an iteration (e.g. in incremental mode) - for it in Progressbar(range(1, algo_params.nit), desc="Iterating..."): - if not cond_skip: - # Uncomment here (and comment below) to do like IROLL - # if smooth_ker[it] is not None: - # frame = _blurring_2d(frame, None, fwhm_sz=smooth_ker[it]) - # create and rotate sig cube - sig_cube = np.repeat(frame[np.newaxis, :, :], nframes, axis=0) - sig_cube = cube_derotate(sig_cube, -algo_params.angle_list, - imlib=algo_params.imlib, - nproc=algo_params.nproc) - - if algo_params.thr_mode == 'STIM': - # create and rotate binary mask - mask_sig = np.zeros_like(sig_image) - mask_sig[np.where(sig_image > 0)] = 1 - sig_mcube = np.repeat(mask_sig[np.newaxis, :, :], nframes, - axis=0) - sig_mcube = cube_derotate(sig_mcube, -algo_params.angle_list, - imlib='skimage', - interpolation='bilinear', - nproc=algo_params.nproc) - sig_cube[np.where(sig_mcube < 0.5)] = 0 - sig_cube[np.where(sig_cube < 0)] = 0 - else: - sig_cube[np.where(sig_cube < algo_params.thr)] = 0 - - if algo_params.strategy == 'ARDI': - ref_cube = np.concatenate((algo_params.cube-sig_cube, - algo_params.cube_ref), axis=0) - cube_ref_tmp = prepare_matrix(ref_cube, - scaling=algo_params.scaling, - mask_center_px=mask_center_px, - mode='fullfr', verbose=False) - cube_ref_tmp = np.reshape(cube_ref_tmp, ref_cube.shape) - - # Run PCA on original cube - # Update PCA PARAMS - pca_params['cube'] = algo_params.cube - pca_params['cube_ref'] = ref_cube - pca_params['ncomp'] = final_ncomp[it] - pca_params['scaling'] = algo_params.scaling - pca_params['cube_sig'] = sig_cube - pca_params['mask_rdi'] = mask_rdi_tmp - - res = pca(**pca_params, **rot_options) - - frame = res[0] - residuals_cube = res[-2] - it_cube[it] = frame.copy() - - # DON'T otherwise frame is smoothed twice! - # smoothing and manual derotation if requested - if smooth_ker[it] is not None: - residuals_cube = _blurring_3d(residuals_cube, None, - fwhm_sz=smooth_ker[it]) - residuals_cube_ = cube_derotate(residuals_cube, - algo_params.angle_list, - imlib=algo_params.imlib, - nproc=algo_params.nproc) - frame = cube_collapse(residuals_cube_, algo_params.collapse) - - # Run PCA on disk-empty cube - # Update PCA PARAMS - pca_params['cube'] = cube_tmp-sig_cube - pca_params['cube_ref'] = cube_ref_tmp - pca_params['cube_sig'] = None - pca_params['scaling'] = None - - res_nd = pca(**pca_params, **rot_options) - - residuals_cube_nd = res_nd[-2] - frame_nd = res_nd[0] - - if algo_params.thr_mode == 'STIM': - sig_mask, nstim = _find_significant_signals(residuals_cube_nd, - residuals_cube_, - algo_params.angle_list, - algo_params.thr, - mask=mask_center_px, - r_out=algo_params.r_out) - else: - sig_mask = np.ones_like(frame) - sig_mask[np.where(frame < algo_params.thr)] = 0 - nstim = sig_mask.copy() - inv_sig_mask = np.ones_like(sig_mask) - inv_sig_mask[np.where(sig_mask)] = 0 - if mask_center_px: - inv_sig_mask = mask_circle(inv_sig_mask, mask_center_px, - fillwith=1) - sig_image = frame.copy() - sig_image[np.where(inv_sig_mask)] = 0 - sig_image[np.where(sig_image < 0)] = 0 - - # whether skipped or not: - it_cube[it] = frame.copy() - it_cube_nd[it] = frame_nd.copy() - sig_images[it] = sig_image.copy() - stim_cube[it] = nstim.copy() - - # check if improvement compared to last iteration - if it > 1: - cond1 = np.allclose(sig_image, sig_images[it-1], - rtol=algo_params.rtol, atol=algo_params.atol) - cond2 = np.allclose(sig_image, sig_images[it-2], - rtol=algo_params.rtol, atol=algo_params.atol) - if cond1 or cond2: - # if convergence in incremental mode: skip iterations until the - # next increment in ncomp - cond_mode = algo_params.mode in ['Pairet21', 'Christiaens21'] - cond_it = (it % nit_ori != nit_ori-1) - if cond_mode and cond_it: - cond_skip = True + imlib=algo_params.imlib, + nproc=algo_params.nproc) + frame = cube_collapse(residuals_cube_, algo_params.collapse) + + # Run PCA on disk-empty cube + # Update PCA PARAMS + pca_params['cube'] = cube_tmp-sig_cube + pca_params['cube_ref'] = cube_ref_tmp + pca_params['cube_sig'] = None + pca_params['scaling'] = None + + res_nd = pca(**pca_params, **rot_options) + + residuals_cube_nd = res_nd[-2] + frame_nd = res_nd[0] + + if algo_params.thr_mode == 'STIM': + sig_mask, nstim = _find_significant_signals(residuals_cube_nd, + residuals_cube_, + algo_params.angle_list, + algo_params.thr, + mask=mask_center_px, + r_out=algo_params.r_out) else: - cond_skip = False - if algo_params.strategy in ['ADI', 'RDI', 'ARDI']: - msg = "Convergence criterion met after {} iterations" - condB = algo_params.continue_without_smooth_after_conv - if smooth_ker[it] is not None and condB: - smooth_ker[it+1:] = None - msg2 = "...Smoothing turned off and iterating more" - if algo_params.verbose: - print(msg.format(it)+msg2) - else: - if algo_params.verbose: - print("Final " + msg.format(it)) - break - if algo_params.strategy == 'RADI': - # continue to iterate with ADI - ncomp_tmp = ncomp_list[1] - algo_params.strategy = 'ADI' - ref_cube = None - if algo_params.verbose: - msg = "After {:.0f} iterations, PCA-RDI -> PCA-ADI." - print(msg.format(it)) + sig_mask = np.ones_like(frame) + sig_mask[np.where(frame < algo_params.thr)] = 0 + nstim = sig_mask.copy() + inv_sig_mask = np.ones_like(sig_mask) + inv_sig_mask[np.where(sig_mask)] = 0 + if mask_center_px: + inv_sig_mask = mask_circle(inv_sig_mask, mask_center_px, + fillwith=1) + sig_image = frame.copy() + sig_image[np.where(inv_sig_mask)] = 0 + sig_image[np.where(sig_image < 0)] = 0 + + # whether skipped or not: + it_cube[it] = frame.copy() + it_cube_nd[it] = frame_nd.copy() + sig_images[it] = sig_image.copy() + stim_cube[it] = nstim.copy() + + # check if improvement compared to last iteration + if it > 1: + cond1 = np.allclose(sig_image, sig_images[it-1], + rtol=algo_params.rtol, atol=algo_params.atol) + cond2 = np.allclose(sig_image, sig_images[it-2], + rtol=algo_params.rtol, atol=algo_params.atol) + if cond1 or cond2: + # if convergence in incremental mode: skip iterations until the + # next increment in ncomp + cond_mode = algo_params.mode in ['Pairet21', 'Christiaens21'] + cond_it = (it % nit_ori != nit_ori-1) + if cond_mode and cond_it: + cond_skip = True + else: + cond_skip = False + if algo_params.strategy in ['ADI', 'RDI', 'ARDI']: + msg = "Convergence criterion met after {} iterations" + condB = algo_params.continue_without_smooth_after_conv + if smooth_ker[it] is not None and condB: + smooth_ker[it+1:] = None + msg2 = "...Smoothing turned off and iterating more" + if algo_params.verbose: + print(msg.format(it)+msg2) + else: + if algo_params.verbose: + print("Final " + msg.format(it)) + break + if algo_params.strategy == 'RADI': + # continue to iterate with ADI + ncomp_tmp = ncomp_list[1] + algo_params.strategy = 'ADI' + ref_cube = None + if algo_params.verbose: + msg = "After {:.0f} iterations, PCA-RDI -> PCA-ADI." + print(msg.format(it)) # mask everything last if mask_center_px is not None: diff --git a/vip_hci/preproc/derotation.py b/vip_hci/preproc/derotation.py index ab605e7e..805445ca 100644 --- a/vip_hci/preproc/derotation.py +++ b/vip_hci/preproc/derotation.py @@ -39,6 +39,13 @@ warnings.warn(msg, ImportWarning) no_opencv = True +try: + import torch as torch + no_torch = False +except ImportError: + msg = "Pytorch python bindings are missing" + warnings.warn(msg, ImportWarning) + no_torch = True def frame_rotate(array, angle, imlib='vip-fft', interpolation='lanczos4', cxy=None, border_mode='constant', mask_val=np.nan, @@ -51,7 +58,7 @@ def frame_rotate(array, angle, imlib='vip-fft', interpolation='lanczos4', Input image, 2d array. angle : float Rotation angle. - imlib : {'opencv', 'skimage', 'vip-fft'}, str optional + imlib : {'opencv', 'skimage', 'vip-fft', 'torch-fft'}, str optional Library used for image transformations. Opencv is faster than skimage or 'vip-fft', but vip-fft slightly better preserves the flux in the image (followed by skimage with a biquintic interpolation). 'vip-fft' @@ -118,7 +125,7 @@ def frame_rotate(array, angle, imlib='vip-fft', interpolation='lanczos4', if edge_blend is None: edge_blend = '' - if edge_blend != '' or imlib == 'vip-fft': + if edge_blend != '' or imlib in ['vip-fft', 'torch-fft'] : # fill with nans cy_ori, cx_ori = frame_center(array) y_ori, x_ori = array.shape @@ -139,7 +146,7 @@ def frame_rotate(array, angle, imlib='vip-fft', interpolation='lanczos4', stdfunc=np.nanstd) # pad and interpolate, about 1.2x original size - if imlib == 'vip-fft': + if imlib in ['vip-fft', 'torch-fft'] : fac = 1.5 else: fac = 1.1 @@ -219,7 +226,7 @@ def frame_rotate(array, angle, imlib='vip-fft', interpolation='lanczos4', cy, cx = frame_center(array_prep) else: cx, cy = cxy - if imlib == 'vip-fft' and (cy, cx) != frame_center(array_prep): + if imlib in ['vip-fft', 'torch-fft'] and (cy, cx) != frame_center(array_prep): msg = "'vip-fft'imlib does not yet allow for custom center to be " msg += " provided " raise ValueError(msg) @@ -300,11 +307,18 @@ def frame_rotate(array, angle, imlib='vip-fft', interpolation='lanczos4', M = cv2.getRotationMatrix2D((cx, cy), angle, 1) array_out = cv2.warpAffine(array_prep.astype(np.float32), M, (x, y), flags=intp, borderMode=bormo) + elif imlib == 'torch-fft': + if no_torch: + msg = 'Pytorch bindings cannot be imported. Install torch or' + msg += ' set imlib to skimage' + raise RuntimeError(msg) + + array_out = (tensor_rotate_fft(torch.unsqueeze(torch.from_numpy(array_prep),0), angle)[0]).numpy() else: raise ValueError('Image transformation library not recognized') - if edge_blend != '' or imlib == 'vip-fft': + if edge_blend != '' or imlib in ['vip-fft', 'torch-fft'] : array_out = array_out[y0:y1, x0:x1] # remove padding array_out[mask_ori] = mask_val # mask again original masked values @@ -622,3 +636,93 @@ def _fft_shear(arr, arr_ori, c, ax, pad=0, shift_ini=True): s_x = fftshift(s_x) return s_x + + + +def tensor_rotate_fft(tensor: torch.Tensor, angle: float) -> torch.Tensor: + """ Rotates Tensor using Fourier transform phases: + Rotation = 3 consecutive lin. shears = 3 consecutive FFT phase shifts + See details in Larkin et al. (1997) and Hagelberg et al. (2016). + Note: this is significantly slower than interpolation methods + (e.g. opencv/lanczos4 or ndimage), but preserves the flux better + (by construction it preserves the total power). It is more prone to + large-scale Gibbs artefacts, so make sure no sharp edge is present in + the image to be rotated. + /!\ This is a blindly coded adaptation for Tensor of the vip function rotate_fft + (https://github.com/vortex-exoplanet/VIP/blob/51e1d734dcdbee1fbd0175aa3d0ab62eec83d5fa/vip_hci/preproc/derotation.py#L507) + /!\ This suppose the frame is perfectly centred + ! Warning: if input frame has even dimensions, the center of rotation + will NOT be between the 4 central pixels, instead it will be on the top + right of those 4 pixels. Make sure your images are centered with + respect to that pixel before rotation. + Parameters + ---------- + tensor : torch.Tensor + Input image, 2d array. + angle : float + Rotation angle. + Returns + ------- + array_out : torch.Tensor + Resulting frame. + """ + y_ori, x_ori = tensor.shape[1:] + + while angle < 0: + angle += 360 + while angle > 360: + angle -= 360 + + if angle > 45: + dangle = angle % 90 + if dangle > 45: + dangle = -(90 - dangle) + nangle = int(np.rint(angle / 90)) + tensor_in = torch.rot90(tensor, nangle, [1, 2]) + else: + dangle = angle + tensor_in = tensor.clone() + + if y_ori % 2 or x_ori % 2: + # NO NEED TO SHIFT BY 0.5px: FFT assumes rot. center on cx+0.5, cy+0.5! + tensor_in = tensor_in[:, :-1, :-1] + + a = np.tan(np.deg2rad(dangle) / 2).item() + b = -np.sin(np.deg2rad(dangle)).item() + + y_new, x_new = tensor_in.shape[1:] + arr_xy = torch.from_numpy(np.mgrid[0:y_new, 0:x_new]) + cy, cx = frame_center(tensor[0]) + arr_y = arr_xy[0] - cy + arr_x = arr_xy[1] - cx + + s_x = tensor_fft_shear(tensor_in, arr_x, a, ax=2) + s_xy = tensor_fft_shear(s_x, arr_y, b, ax=1) + s_xyx = tensor_fft_shear(s_xy, arr_x, a, ax=2) + + if y_ori % 2 or x_ori % 2: + # set it back to original dimensions + array_out = torch.zeros([1, s_xyx.shape[1]+1, s_xyx.shape[2]+1]) + array_out[0, :-1, :-1] = torch.real(s_xyx) + else: + array_out = torch.real(s_xyx) + + return array_out + + +def tensor_fft_shear(arr, arr_ori, c, ax): + ax2 = 1 - (ax-1) % 2 + freqs = torch.fft.fftfreq(arr_ori.shape[ax2], dtype=torch.float64) + sh_freqs = torch.fft.fftshift(freqs) + arr_u = torch.tile(sh_freqs, (arr_ori.shape[ax-1], 1)) + if ax == 2: + arr_u = torch.transpose(arr_u, 0, 1) + s_x = torch.fft.fftshift(arr) + s_x = torch.fft.fft(s_x, dim=ax) + s_x = torch.fft.fftshift(s_x) + s_x = torch.exp(-2j * torch.pi * c * arr_u * arr_ori) * s_x + s_x = torch.fft.fftshift(s_x) + s_x = torch.fft.ifft(s_x, dim=ax) + s_x = torch.fft.fftshift(s_x) + + return s_x From 0b726d5d216568fe71744e4abd910f2777429ad6 Mon Sep 17 00:00:00 2001 From: Sand-jrd Date: Wed, 6 Nov 2024 08:39:56 +0100 Subject: [PATCH 2/2] Remove type definitions in function to fix check failures --- vip_hci/preproc/derotation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vip_hci/preproc/derotation.py b/vip_hci/preproc/derotation.py index 805445ca..8b5c9607 100644 --- a/vip_hci/preproc/derotation.py +++ b/vip_hci/preproc/derotation.py @@ -639,7 +639,7 @@ def _fft_shear(arr, arr_ori, c, ax, pad=0, shift_ini=True): -def tensor_rotate_fft(tensor: torch.Tensor, angle: float) -> torch.Tensor: +def tensor_rotate_fft(tensor, angle): """ Rotates Tensor using Fourier transform phases: Rotation = 3 consecutive lin. shears = 3 consecutive FFT phase shifts See details in Larkin et al. (1997) and Hagelberg et al. (2016).