Skip to content

Commit

Permalink
Merge pull request #643 from VChristiaens/master
Browse files Browse the repository at this point in the history
Multiprocessing support + bug fix for IPCA-ARDI with DI initialization
  • Loading branch information
VChristiaens authored Aug 1, 2024
2 parents f336ce0 + a2ce534 commit 4c16ea8
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 5 deletions.
16 changes: 15 additions & 1 deletion vip_hci/config/utils_conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from inspect import signature, Parameter
from functools import wraps
import multiprocessing
import warnings
from vip_hci import __version__

sep = "―" * 80
Expand Down Expand Up @@ -471,7 +472,20 @@ def pool_map(nproc, fkt, *args, **kwargs):
if not _generator:
res = list(res)
else:
multiprocessing.set_start_method("fork", force=True)
# Check available start methods and pick accordingly (machine-dependent)
avail_methods = multiprocessing.get_all_start_methods()
if 'fork' in avail_methods:
# faster when available
warnings.filterwarnings("error") # allows to catch warning as error
try:
multiprocessing.set_start_method("fork", force=True)
except (DeprecationWarning, OSError):
multiprocessing.set_start_method("spawn", force=True)
elif 'forkserver' in avail_methods:
multiprocessing.set_start_method("forkserver", force=True)
else:
multiprocessing.set_start_method("spawn", force=True)
warnings.resetwarnings() # reset warning behaviour to default
from multiprocessing import Pool

# deactivate multithreading
Expand Down
4 changes: 2 additions & 2 deletions vip_hci/greedy/ipca_fullfr.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,7 +334,7 @@ def _blurring_3d(array, mask_center_sz, fwhm_sz=2):
class_params, rot_options = separate_kwargs_dict(
initial_kwargs=all_kwargs, parent_class=IPCA_Params
)
# Do the same to separate IROLL and ROLL params
# Do the same to separate IPCA and PCA params
pca_params, ipca_params = separate_kwargs_dict(
initial_kwargs=class_params, parent_class=PCA_Params
)
Expand Down Expand Up @@ -371,7 +371,7 @@ def _blurring_3d(array, mask_center_sz, fwhm_sz=2):
mask_rdi_tmp = algo_params.mask_rdi.copy()
if algo_params.cube_ref is None:
raise ValueError("cube_ref should be provided for RDI or RADI")
if algo_params.strategy == 'ARDI':
if algo_params.strategy == 'ARDI' and algo_params.mask_rdi is None:
ref_cube = np.concatenate((algo_params.cube,
algo_params.cube_ref), axis=0)
else:
Expand Down
4 changes: 2 additions & 2 deletions vip_hci/preproc/recentering.py
Original file line number Diff line number Diff line change
Expand Up @@ -2098,7 +2098,7 @@ def cube_recenter_via_speckles(cube_sci, cube_ref=None, alignment_iter=5,
cum_x_shifts_sci = cum_x_shifts[1:(n + 1)]
cube_reg_sci = cube_shift(cube_sci, cum_y_shifts_sci, cum_x_shifts_sci,
imlib=imlib, interpolation=interpolation,
border_mode=border_mode)
border_mode=border_mode, nproc=nproc)

if plot:
plt.figure(figsize=vip_figsize)
Expand All @@ -2124,7 +2124,7 @@ def cube_recenter_via_speckles(cube_sci, cube_ref=None, alignment_iter=5,
cum_x_shifts_ref = cum_x_shifts[(n + 1):]
cube_reg_ref = cube_shift(cube_ref, cum_y_shifts_ref, cum_x_shifts_ref,
imlib=imlib, interpolation=interpolation,
border_mode=border_mode)
border_mode=border_mode, nproc=nproc)

if ref_star:
if full_output:
Expand Down

0 comments on commit 4c16ea8

Please sign in to comment.