Merge pull request #573 from ADACS-Australia/master
Incorporating multiprocessing into time consuming functions.
VChristiaens authored Jan 31, 2023
2 parents b5a56ed + 5da8fdc commit e219768
Showing 7 changed files with 244 additions and 44 deletions.
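
The main change adds an nproc argument to cube_fix_badpix_isolated and cube_fix_badpix_clump in vip_hci/preproc/badpixremoval.py. Below is a hedged usage sketch, not part of the diff: the import path, the random cube and the parameter values are illustrative assumptions, and for the isolated routine frame_by_frame=True is required for the parallel branch to be used.

import numpy as np
from vip_hci.preproc import cube_fix_badpix_isolated, cube_fix_badpix_clump

cube = np.random.rand(20, 100, 100)  # placeholder for a real ADI cube

# isolated bad pixels, corrected frame by frame over 4 processes
cube_corr = cube_fix_badpix_isolated(cube, sigma_clip=3, num_neig=5, size=5,
                                     frame_by_frame=True, verbose=False,
                                     nproc=4)

# clumps of bad pixels, also parallelized over 4 processes
cube_corr_clump = cube_fix_badpix_clump(cube, fwhm=4., sig=4., verbose=False,
                                        nproc=4)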
2 changes: 1 addition & 1 deletion vip_hci/metrics/detection.py
@@ -561,4 +561,4 @@ def mask_sources(mask, ap_rad):
rad_arr = dist_matrix(ny, cx=s_coords[1][s], cy=s_coords[0][s])
mask_out[np.where(rad_arr < ap_rad)] = 0

return mask_out
return mask_out
2 changes: 1 addition & 1 deletion vip_hci/metrics/snr_source.py
@@ -573,4 +573,4 @@ def frame_report(array, fwhm, source_xy=None, verbose=True):
print(msg3.format(stdsnr_i))
print(sep)

return source_xy, obj_flux, snr_centpx, meansnr_pixels
return source_xy, obj_flux, snr_centpx, meansnr_pixels
2 changes: 1 addition & 1 deletion vip_hci/preproc/badframes.py
@@ -372,4 +372,4 @@ def cube_detect_badfr_correlation(array, frame_ref, crop_size=30,
if full_output:
return good_index_list, bad_index_list, distances
else:
return good_index_list, bad_index_list
return good_index_list, bad_index_list
223 changes: 188 additions & 35 deletions vip_hci/preproc/badpixremoval.py
@@ -34,6 +34,20 @@
from ..config.utils_conf import pool_map, iterable
from .rescaling import find_scal_vector, frame_rescaling
from .cosmetics import frame_pad
from multiprocessing import Process
import multiprocessing
from multiprocessing import set_start_method
try:
    from multiprocessing import shared_memory
except ImportError:
    print('Failed to import shared_memory from multiprocessing')
    try:
        print('Trying to import shared_memory directly (for Python 3.7)')
        import shared_memory
    except ModuleNotFoundError:
        print('On Python 3.7, install the shared_memory backport to enable')
        print('multiprocessing for bad pixel correction:')
        print('pip install shared-memory38')

import warnings
try:
@@ -160,7 +174,7 @@ def cube_fix_badpix_isolated(array, bpm_mask=None, correct_only=False,
sigma_clip=3, num_neig=5, size=5,
frame_by_frame=False, protect_mask=0, cxy=None,
mad=False, ignore_nan=True, verbose=True,
full_output=False):
full_output=False, nproc=1):
""" Corrects the bad pixels, marked in the bad pixel mask. The bad pixel is
replaced by the median of the adjacent pixels. This function is very fast
but works only with isolated (sparse) pixels.
@@ -208,7 +222,12 @@ def cube_fix_badpix_isolated(array, bpm_mask=None, correct_only=False,
verbose : bool, optional
If True additional information will be printed.
full_output: bool, {False,True}, optional
Whether to return as well the cube of bad pixel maps.
Whether to return as well the cube of bad pixel maps and the cube of
defined annuli.
nproc: int, optional
Number of processes to use (added in the ADACS update). If larger than 1,
the frame-by-frame correction is run in parallel over nproc processes.
Only used when frame_by_frame=True.
Return
------
@@ -264,20 +283,73 @@
if bpm_mask.ndim == 2:
bpm_mask = [bpm_mask]*n_frames
bpm_mask = np.array(bpm_mask)
for i in Progressbar(range(n_frames), desc="processing frames"):
if nproc==1:
for i in Progressbar(range(n_frames), desc="processing frames"):
if bpm_mask is not None:
bpm_mask_tmp = bpm_mask[i]
else:
bpm_mask_tmp = None
res = frame_fix_badpix_isolated(array[i], bpm_mask=bpm_mask_tmp,
sigma_clip=sigma_clip,
num_neig=num_neig, size=size,
protect_mask=protect_mask,
verbose=False, cxy=(cx[i], cy[i]),
ignore_nan=ignore_nan,
full_output=True)
array_out[i] = res[0]
final_bpm[i] = res[1]
else:
print("Processing using ADACS' multiprocessing approach...")
# dummy call to the function below, to create a cached version of the
# code prior to forking
if bpm_mask is not None:
    bpm_mask_dum = bpm_mask[0]
else:
    bpm_mask_dum = None
# point of dummy call
frame_fix_badpix_isolated(array[0], bpm_mask=bpm_mask_dum,
                          sigma_clip=sigma_clip, num_neig=num_neig,
                          size=size, protect_mask=protect_mask,
                          verbose=False, cxy=(cx[0], cy[0]),
                          ignore_nan=ignore_nan, full_output=False)
# multiprocessing is included only in the frame-by-frame branch of the if
# statement above
# creating a shared memory buffer for the output cube
shm_array_out = shared_memory.SharedMemory(create=True, size=array.nbytes)
# creating a view of array_out backed by the shm_array_out buffer above
shared_array_out = np.ndarray(array.shape, dtype=array.dtype,
                              buffer=shm_array_out.buf)
# creating a shared memory buffer for the final bad pixel mask cube
shm_final_bpm = shared_memory.SharedMemory(create=True, size=final_bpm.nbytes)
# creating a view of final_bpm backed by the shm_final_bpm buffer above
shared_final_bpm = np.ndarray(final_bpm.shape, dtype=final_bpm.dtype,
                              buffer=shm_final_bpm.buf)

# function that calls frame_fix_badpix_isolated with the same arguments as
# in the nproc==1 branch above, writing the results into shared memory
def mp_clean_isolated(j, frame, bpm_mask=None, sigma_clip=3, num_neig=5,
                      size=5, protect_mask=0, verbose=False, cxy=None,
                      ignore_nan=True, full_output=True):
    shared_array_out[j], shared_final_bpm[j] = frame_fix_badpix_isolated(
        frame, bpm_mask=bpm_mask, sigma_clip=sigma_clip, num_neig=num_neig,
        size=size, protect_mask=protect_mask, verbose=verbose, cxy=cxy,
        ignore_nan=ignore_nan, full_output=full_output)

# function that unwraps the arguments and passes them to mp_clean_isolated
global _mp_clean_isolated

def _mp_clean_isolated(args):
    pargs = args[0:2]
    kwargs = args[2]
    mp_clean_isolated(*pargs, **kwargs)

context = multiprocessing.get_context('fork')
pool = context.Pool(processes=nproc, maxtasksperchild=1)

args = []
for j in range(n_frames):
    if bpm_mask is not None:
        bpm_mask_tmp = bpm_mask[j]
    else:
        bpm_mask_tmp = None
    dict_kwargs = {'bpm_mask': bpm_mask_tmp, 'sigma_clip': sigma_clip,
                   'num_neig': num_neig, 'size': size,
                   'protect_mask': protect_mask, 'cxy': (cx[j], cy[j]),
                   'ignore_nan': ignore_nan}
    args.append([j, array[j], dict_kwargs])

try:
    pool.map_async(_mp_clean_isolated, args,
                   chunksize=1).get(timeout=10_000_000)
finally:
    pool.close()
    pool.join()
array_out[:] = shared_array_out[:]
final_bpm[:] = shared_final_bpm[:]
shm_array_out.close()
shm_array_out.unlink()
shm_final_bpm.close()
shm_final_bpm.unlink()
count_bp = np.sum(final_bpm)
else:
if bpm_mask is None or not correct_only:
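
The hunk above implements the parallel branch: it allocates shared-memory buffers, wraps them in NumPy arrays, and lets each forked worker write its corrected frame and bad pixel map directly into them, so large results are never pickled back to the parent. The diff promotes the nested wrapper with a global statement so the pool can pickle it by name; the minimal sketch below, which is not part of the diff and uses illustrative names (_process_frame, _worker, clean_cube), achieves the same effect with plain module-level functions and a single output buffer. It assumes a platform where the 'fork' start method is available (Linux or macOS).

import multiprocessing
from multiprocessing import shared_memory

import numpy as np

# module-level references set before the pool is created, so the forked
# workers inherit them
_cube = None
_out = None


def _process_frame(frame):
    # stand-in for the real per-frame correction
    return frame - np.median(frame)


def _worker(j):
    # each worker writes its result straight into the shared output buffer
    _out[j] = _process_frame(_cube[j])


def clean_cube(cube, nproc=2):
    global _cube, _out
    shm = shared_memory.SharedMemory(create=True, size=cube.nbytes)
    _out = np.ndarray(cube.shape, dtype=cube.dtype, buffer=shm.buf)
    _cube = cube
    ctx = multiprocessing.get_context('fork')
    pool = ctx.Pool(processes=nproc, maxtasksperchild=1)
    try:
        pool.map_async(_worker, range(cube.shape[0]), chunksize=1).get()
    finally:
        pool.close()
        pool.join()
    result = _out.copy()  # copy out before releasing the shared buffer
    shm.close()
    shm.unlink()
    return result


if __name__ == "__main__":
    print(clean_cube(np.random.rand(10, 64, 64)).shape)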
@@ -589,7 +661,7 @@ def bp_removal_2d(array, cy, cx, fwhm, sig, protect_mask, r_in_std,
def cube_fix_badpix_clump(array, bpm_mask=None, correct_only=False, cy=None,
cx=None, fwhm=4., sig=4., protect_mask=0,
half_res_y=False, min_thr=None, max_nit=15, mad=True,
verbose=True, full_output=False):
verbose=True, full_output=False, nproc=1):
"""
Function to identify and correct clumps of bad pixels. Very fast when a bad
pixel map is provided. If a bad pixel map is not provided, the bad pixel
@@ -653,6 +725,10 @@ def cube_fix_badpix_clump(array, bpm_mask=None, correct_only=False, cy=None,
full_output: bool, {False,True}, optional
Whether to return as well the cube of bad pixel maps and the cube of
defined annuli.
nproc: int, optional
Number of processes to use (added in the ADACS update). If larger than 1,
the correction of the frames is run in parallel over nproc processes.
Returns
-------
@@ -803,18 +879,53 @@ def bp_removal_2d(array_corr, cy, cx, fwhm, sig, protect_mask, bpm_mask_ori,
cx = [cx]*n_z
if isinstance(fwhm, (float, int)):
fwhm = [fwhm]*n_z
bpix_map_cumul = np.zeros_like(array_corr)
for i in range(n_z):
if verbose:
print('************Frame # ', i, ' *************')
array_corr[i], bpix_map_cumul[i] = bp_removal_2d(array_corr[i],
cy[i], cx[i],
fwhm[i], sig,
protect_mask,
bpm_mask,
min_thr,
half_res_y, mad,
verbose)
if nproc==1:
bpix_map_cumul = np.zeros_like(array_corr)
for i in range(n_z):
if verbose:
print('************Frame # ', i, ' *************')
array_corr[i], bpix_map_cumul[i] = bp_removal_2d(array_corr[i], cy[i],
cx[i], fwhm[i],
sig, protect_mask,
bpm_mask, min_thr,
half_res_y, mad,
verbose)
else:
msg = "Cleaning frames using ADACS' multiprocessing approach"
print(msg)
# creating a shared memory buffer for the image cube
shm_clump = shared_memory.SharedMemory(create=True, size=array_corr.nbytes)
obj_tmp_shared_clump = np.ndarray(array_corr.shape, dtype=array_corr.dtype,
                                  buffer=shm_clump.buf)
# creating a shared memory buffer for the bad pixel cube
shm_clump_bpix = shared_memory.SharedMemory(create=True,
                                            size=array_corr.nbytes)
# works with dtype=array_corr.dtype but not dtype=int
bpix_map_cumul_shared = np.ndarray(array_corr.shape, dtype=array_corr.dtype,
                                   buffer=shm_clump_bpix.buf)

# function that calls bp_removal_2d on one frame and writes the results
# into shared memory
def mp_clump_slow(j, array_corr, cy, cx, fwhm, sig, protect_mask, bpm_mask,
                  min_thr, half_res_y, mad, verbose):
    obj_tmp_shared_clump[j], bpix_map_cumul_shared[j] = bp_removal_2d(
        array_corr, cy, cx, fwhm, sig, protect_mask, bpm_mask, min_thr,
        half_res_y, mad, verbose)

global _mp_clump_slow

# function that unpacks the argument list and passes it to mp_clump_slow
def _mp_clump_slow(args):
    mp_clump_slow(*args)

context = multiprocessing.get_context('fork')
pool = context.Pool(processes=nproc, maxtasksperchild=1)
args = []
for i in range(n_z):
    args.append([i, array_corr[i], cy[i], cx[i], fwhm[i], sig, protect_mask,
                 bpm_mask, min_thr, half_res_y, mad, verbose])
try:
    pool.map_async(_mp_clump_slow, args,
                   chunksize=1).get(timeout=10_000_000)
finally:
    pool.close()
    pool.join()
bpix_map_cumul = np.zeros_like(array_corr, dtype=array_corr.dtype)
bpix_map_cumul[:] = bpix_map_cumul_shared[:]
array_corr[:] = obj_tmp_shared_clump[:]
shm_clump.close()
shm_clump.unlink()
shm_clump_bpix.close()
shm_clump_bpix.unlink()

else:
if isinstance(fwhm, (float, int)):
fwhm_round = int(round(fwhm))
@@ -823,15 +934,57 @@ def bp_removal_2d(array_corr, cy, cx, fwhm, sig, protect_mask, bpm_mask_ori,
fwhm_round = fwhm_round+1-(fwhm_round % 2) # make it odd
neighbor_box = max(3, fwhm_round) # to not replace a companion
nneig = sum(np.arange(3, neighbor_box+2, 2))
for i in range(n_z):
if verbose:
print('************Frame # ', i, ' *************')

if nproc==1:
for i in range(n_z):
if verbose:
print('Using serial approach')
print('************Frame # ', i, ' *************')
if bpm_mask.ndim == 3:
bpm = bpm_mask[i]
else:
bpm = bpm_mask
array_corr[i] = sigma_filter(array_corr[i], bpm, neighbor_box,
nneig, half_res_y, verbose)
else:
msg = "Cleaning frames using ADACS' multiprocessing approach"
print(msg)
# dummy call to the sigma_filter function to create a cached version of
# the numba-compiled code prior to forking
if bpm_mask.ndim == 3:
    dummy_bpm = bpm_mask[0]
else:
    dummy_bpm = bpm_mask
# actual dummy call is here
sigma_filter(array_corr[0], dummy_bpm, neighbor_box, nneig, half_res_y,
             verbose)
# creating shared memory that each process writes into
shm_clump = shared_memory.SharedMemory(create=True, size=array_corr.nbytes)
# creating an array backed by the shared memory buffer, with the shape and
# dtype of array_corr
obj_tmp_shared_clump = np.ndarray(array_corr.shape, dtype=array_corr.dtype,
                                  buffer=shm_clump.buf)
# function that is called repeatedly by each worker process
def mp_clean_clump(j, array_corr, bpm, neighbor_box, nneig, half_res_y,
                   verbose):
    obj_tmp_shared_clump[j] = sigma_filter(array_corr, bpm, neighbor_box,
                                           nneig, half_res_y, verbose)

global _mp_clean_clump

# function that unpacks the argument list and passes it to mp_clean_clump
def _mp_clean_clump(args):
    mp_clean_clump(*args)
context = multiprocessing.get_context('fork')
pool = context.Pool(processes=nproc, maxtasksperchild=1)
args = []
for j in range(n_z):
    if bpm_mask.ndim == 3:
        bpm = bpm_mask[j]
    else:
        bpm = bpm_mask
    args.append([j, array_corr[j], bpm, neighbor_box, nneig, half_res_y,
                 verbose])
try:
    pool.map_async(_mp_clean_clump, args,
                   chunksize=1).get(timeout=10_000_000)
finally:
    pool.close()
    pool.join()
array_corr[:] = obj_tmp_shared_clump[:]
shm_clump.close()
shm_clump.unlink()
bpix_map_cumul = bpm_mask

# make it a binary map
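
Both parallel branches also make a "dummy call" (to frame_fix_badpix_isolated or to the numba-compiled sigma_filter) before the pool is created, so the JIT-compiled machine code is already cached in the parent and is simply inherited by the forked workers, which are short-lived because of maxtasksperchild=1. The sketch below is not part of the diff and only illustrates that warm-up idea with an illustrative kernel; it assumes numba is installed and that the 'fork' start method is available.

import multiprocessing

import numpy as np
from numba import njit


@njit
def expensive_kernel(frame):
    # stand-in for a JIT-compiled routine such as sigma_filter
    return frame * 2.0


def _worker(frame):
    return expensive_kernel(frame)


def run(cube, nproc=2):
    # dummy call: triggers JIT compilation once in the parent, before forking
    expensive_kernel(cube[0])
    ctx = multiprocessing.get_context('fork')
    pool = ctx.Pool(processes=nproc, maxtasksperchild=1)
    try:
        # every worker reuses the compiled kernel inherited through fork
        return pool.map_async(_worker, list(cube), chunksize=1).get()
    finally:
        pool.close()
        pool.join()


if __name__ == "__main__":
    run(np.random.rand(8, 32, 32))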
