Skip to content

Commit

Permalink
Merge pull request #605 from mperrin/calc_datacube_progressbar
Browse files Browse the repository at this point in the history
Add optional progress bar display during PSF and datacube calculations
  • Loading branch information
BradleySappington authored Apr 29, 2024
2 parents 9107ebb + f77dfdd commit c7dab74
Show file tree
Hide file tree
Showing 5 changed files with 47 additions and 4 deletions.
11 changes: 9 additions & 2 deletions poppy/instrument.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,14 +298,20 @@ def calc_psf(self, outfile=None, source=None, nlambda=None, monochromatic=None,
else:
return result

def calc_datacube(self, wavelengths, *args, **kwargs):
def calc_datacube(self, wavelengths, progressbar=False, *args, **kwargs):
"""Calculate a spectral datacube of PSFs
Parameters
-----------
wavelengths : iterable of floats
List or ndarray or tuple of floating point wavelengths in meters, such as
you would supply in a call to calc_psf via the "monochromatic" option
progressbar : bool
Optionally display a progress bar indicator for status
while iterating over wavelengths. Note, this requires the
optional dependency package 'tqdm', which is not included as
a requirement.
"""

# Allow up to 10,000 wavelength slices. The number matters because FITS
Expand All @@ -329,8 +335,9 @@ def calc_datacube(self, wavelengths, *args, **kwargs):
cube[ext].data[0] = psf[ext].data
cube[ext].header[label_wl(0)] = wavelengths[0]

iterate_wrapper = utils.get_progressbar_wrapper(progressbar, nwaves=nwavelengths)
# iterate rest of wavelengths
for i in range(1, nwavelengths):
for i in iterate_wrapper(range(1, nwavelengths)):
wl = wavelengths[i]
psf = self.calc_psf(*args, monochromatic=wl, **kwargs)
for ext in range(len(psf)):
Expand Down
9 changes: 8 additions & 1 deletion poppy/poppy_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1579,6 +1579,7 @@ def calc_psf(self, wavelength=1e-6,
source=None,
normalize='first',
display_intermediates=False,
progressbar=False,
inwave=None):
"""Calculate a PSF, either multi-wavelength or monochromatic.
Expand Down Expand Up @@ -1612,6 +1613,11 @@ def calc_psf(self, wavelength=1e-6,
display_intermediates: bool, optional
Display intermediate optical planes? Default is False. This option is incompatible with
parallel calculations using `multiprocessing`. (If calculating in parallel, it will have no effect.)
progressbar : bool
Optionally display a progress bar indicator for status
while iterating over wavelengths. Note, this requires the
optional dependency package 'tqdm', which is not included as
a requirement.
Returns
-------
Expand Down Expand Up @@ -1725,7 +1731,8 @@ def calc_psf(self, wavelength=1e-6,
else: # ######### single-threaded computations (may still use multi cores if FFTW enabled ######
if display:
plt.clf()
for wlen, wave_weight in zip(wavelength, normwts):
iterate_wrapper = utils.get_progressbar_wrapper(progressbar, nwaves=len(wavelength))
for wlen, wave_weight in iterate_wrapper(zip(wavelength, normwts)):
mono_psf, mono_intermediate_wfs = self.propagate_mono(
wlen,
retain_intermediates=retain_intermediates,
Expand Down
3 changes: 3 additions & 0 deletions poppy/tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,9 @@ def test_multiwavelength_opticalsystem():
assert np.allclose(psf[0].data, output), \
"Multi-wavelength PSF does not match weighted sum of individual wavelength PSFs"

# test that it's also possible to display a progress bar for multi wave calculations
psf = osys.calc_psf(wavelength=wavelengths, weight=weights, progressbar=True)

return psf


Expand Down
2 changes: 1 addition & 1 deletion poppy/tests/test_instrument.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ def test_instrument_calc_datacube():

inst = instrument.Instrument()
psf = inst.calc_datacube(WAVELENGTHS_ARRAY, fov_pixels=FOV_PIXELS,
detector_oversample=2, fft_oversample=2)
detector_oversample=2, fft_oversample=2, progressbar=True)
assert psf[0].header['NWAVES'] == len(WAVELENGTHS_ARRAY), \
"Number of wavelengths in PSF header does not match number requested"
assert len(psf[0].data.shape) == 3, "Incorrect dimensions for output cube"
Expand Down
26 changes: 26 additions & 0 deletions poppy/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1855,3 +1855,29 @@ def fftw_load_wisdom(filename=None):
"optimization measurements (automatically). ")

_loaded_fftw_wisdom = True

# ##################################################################
# Progress bar (optional convenience)
#

def get_progressbar_wrapper(progressbar=True, nwaves=None):
""" Utility function to return an iterator that MAY display a progress bar,
or may not, depending
"""
if progressbar:
# this relies on an optional dependency, tqdm
# if it's not present, just don't try to display a progressbar
try:
from tqdm import tqdm
except ImportError:
progressbar = False

if progressbar:
import functools
# set up an optional progressbar wrapper
iterate_wrapper = functools.partial(tqdm, ncols=80, total=nwaves)
else:
# null wrapper that does nothing, for no progress bar
iterate_wrapper = lambda x: x

return iterate_wrapper

0 comments on commit c7dab74

Please sign in to comment.