Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add optional progress bar display during PSF and datacube calculations #605

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions poppy/instrument.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,14 +298,20 @@ def calc_psf(self, outfile=None, source=None, nlambda=None, monochromatic=None,
else:
return result

def calc_datacube(self, wavelengths, *args, **kwargs):
def calc_datacube(self, wavelengths, progressbar=False, *args, **kwargs):
"""Calculate a spectral datacube of PSFs

Parameters
-----------
wavelengths : iterable of floats
List or ndarray or tuple of floating point wavelengths in meters, such as
you would supply in a call to calc_psf via the "monochromatic" option
progressbar : bool
Optionally display a progress bar indicator for status
while iterating over wavelengths. Note, this requires the
optional dependency package 'tqdm', which is not included as
a requirement.

"""

# Allow up to 10,000 wavelength slices. The number matters because FITS
Expand All @@ -329,8 +335,9 @@ def calc_datacube(self, wavelengths, *args, **kwargs):
cube[ext].data[0] = psf[ext].data
cube[ext].header[label_wl(0)] = wavelengths[0]

iterate_wrapper = utils.get_progressbar_wrapper(progressbar, nwaves=nwavelengths)
# iterate rest of wavelengths
for i in range(1, nwavelengths):
for i in iterate_wrapper(range(1, nwavelengths)):
wl = wavelengths[i]
psf = self.calc_psf(*args, monochromatic=wl, **kwargs)
for ext in range(len(psf)):
Expand Down
9 changes: 8 additions & 1 deletion poppy/poppy_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1579,6 +1579,7 @@ def calc_psf(self, wavelength=1e-6,
source=None,
normalize='first',
display_intermediates=False,
progressbar=False,
inwave=None):
"""Calculate a PSF, either multi-wavelength or monochromatic.

Expand Down Expand Up @@ -1612,6 +1613,11 @@ def calc_psf(self, wavelength=1e-6,
display_intermediates: bool, optional
Display intermediate optical planes? Default is False. This option is incompatible with
parallel calculations using `multiprocessing`. (If calculating in parallel, it will have no effect.)
progressbar : bool
Optionally display a progress bar indicator for status
while iterating over wavelengths. Note, this requires the
optional dependency package 'tqdm', which is not included as
a requirement.

Returns
-------
Expand Down Expand Up @@ -1725,7 +1731,8 @@ def calc_psf(self, wavelength=1e-6,
else: # ######### single-threaded computations (may still use multi cores if FFTW enabled ######
if display:
plt.clf()
for wlen, wave_weight in zip(wavelength, normwts):
iterate_wrapper = utils.get_progressbar_wrapper(progressbar, nwaves=len(wavelength))
for wlen, wave_weight in iterate_wrapper(zip(wavelength, normwts)):
mono_psf, mono_intermediate_wfs = self.propagate_mono(
wlen,
retain_intermediates=retain_intermediates,
Expand Down
3 changes: 3 additions & 0 deletions poppy/tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,9 @@ def test_multiwavelength_opticalsystem():
assert np.allclose(psf[0].data, output), \
"Multi-wavelength PSF does not match weighted sum of individual wavelength PSFs"

# test that it's also possible to display a progress bar for multi wave calculations
psf = osys.calc_psf(wavelength=wavelengths, weight=weights, progressbar=True)

return psf


Expand Down
2 changes: 1 addition & 1 deletion poppy/tests/test_instrument.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ def test_instrument_calc_datacube():

inst = instrument.Instrument()
psf = inst.calc_datacube(WAVELENGTHS_ARRAY, fov_pixels=FOV_PIXELS,
detector_oversample=2, fft_oversample=2)
detector_oversample=2, fft_oversample=2, progressbar=True)
assert psf[0].header['NWAVES'] == len(WAVELENGTHS_ARRAY), \
"Number of wavelengths in PSF header does not match number requested"
assert len(psf[0].data.shape) == 3, "Incorrect dimensions for output cube"
Expand Down
26 changes: 26 additions & 0 deletions poppy/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1855,3 +1855,29 @@ def fftw_load_wisdom(filename=None):
"optimization measurements (automatically). ")

_loaded_fftw_wisdom = True

# ##################################################################
# Progress bar (optional convenience)
#

def get_progressbar_wrapper(progressbar=True, nwaves=None):
""" Utility function to return an iterator that MAY display a progress bar,
or may not, depending
"""
if progressbar:
# this relies on an optional dependency, tqdm
# if it's not present, just don't try to display a progressbar
try:
from tqdm import tqdm
except ImportError:
progressbar = False

if progressbar:
import functools
# set up an optional progressbar wrapper
iterate_wrapper = functools.partial(tqdm, ncols=80, total=nwaves)
else:
# null wrapper that does nothing, for no progress bar
iterate_wrapper = lambda x: x

return iterate_wrapper
Loading