diff --git a/poppy/instrument.py b/poppy/instrument.py index 33e57df3..aacb7b1c 100644 --- a/poppy/instrument.py +++ b/poppy/instrument.py @@ -298,7 +298,7 @@ def calc_psf(self, outfile=None, source=None, nlambda=None, monochromatic=None, else: return result - def calc_datacube(self, wavelengths, *args, **kwargs): + def calc_datacube(self, wavelengths, progressbar=False, *args, **kwargs): """Calculate a spectral datacube of PSFs Parameters @@ -306,6 +306,12 @@ def calc_datacube(self, wavelengths, *args, **kwargs): wavelengths : iterable of floats List or ndarray or tuple of floating point wavelengths in meters, such as you would supply in a call to calc_psf via the "monochromatic" option + progressbar : bool + Optionally display a progress bar indicator for status + while iterating over wavelengths. Note, this requires the + optional dependency package 'tqdm', which is not included as + a requirement. + """ # Allow up to 10,000 wavelength slices. The number matters because FITS @@ -329,8 +335,9 @@ def calc_datacube(self, wavelengths, *args, **kwargs): cube[ext].data[0] = psf[ext].data cube[ext].header[label_wl(0)] = wavelengths[0] + iterate_wrapper = utils.get_progressbar_wrapper(progressbar, nwaves=nwavelengths) # iterate rest of wavelengths - for i in range(1, nwavelengths): + for i in iterate_wrapper(range(1, nwavelengths)): wl = wavelengths[i] psf = self.calc_psf(*args, monochromatic=wl, **kwargs) for ext in range(len(psf)): diff --git a/poppy/poppy_core.py b/poppy/poppy_core.py index 41a59a56..847ba72d 100644 --- a/poppy/poppy_core.py +++ b/poppy/poppy_core.py @@ -1579,6 +1579,7 @@ def calc_psf(self, wavelength=1e-6, source=None, normalize='first', display_intermediates=False, + progressbar=False, inwave=None): """Calculate a PSF, either multi-wavelength or monochromatic. @@ -1612,6 +1613,11 @@ def calc_psf(self, wavelength=1e-6, display_intermediates: bool, optional Display intermediate optical planes? Default is False. This option is incompatible with parallel calculations using `multiprocessing`. (If calculating in parallel, it will have no effect.) + progressbar : bool + Optionally display a progress bar indicator for status + while iterating over wavelengths. Note, this requires the + optional dependency package 'tqdm', which is not included as + a requirement. Returns ------- @@ -1725,7 +1731,8 @@ def calc_psf(self, wavelength=1e-6, else: # ######### single-threaded computations (may still use multi cores if FFTW enabled ###### if display: plt.clf() - for wlen, wave_weight in zip(wavelength, normwts): + iterate_wrapper = utils.get_progressbar_wrapper(progressbar, nwaves=len(wavelength)) + for wlen, wave_weight in iterate_wrapper(zip(wavelength, normwts)): mono_psf, mono_intermediate_wfs = self.propagate_mono( wlen, retain_intermediates=retain_intermediates, diff --git a/poppy/tests/test_core.py b/poppy/tests/test_core.py index b9218598..be0c0a01 100644 --- a/poppy/tests/test_core.py +++ b/poppy/tests/test_core.py @@ -178,6 +178,9 @@ def test_multiwavelength_opticalsystem(): assert np.allclose(psf[0].data, output), \ "Multi-wavelength PSF does not match weighted sum of individual wavelength PSFs" + # test that it's also possible to display a progress bar for multi wave calculations + psf = osys.calc_psf(wavelength=wavelengths, weight=weights, progressbar=True) + return psf diff --git a/poppy/tests/test_instrument.py b/poppy/tests/test_instrument.py index 1086087e..931f0c12 100644 --- a/poppy/tests/test_instrument.py +++ b/poppy/tests/test_instrument.py @@ -166,7 +166,7 @@ def test_instrument_calc_datacube(): inst = instrument.Instrument() psf = inst.calc_datacube(WAVELENGTHS_ARRAY, fov_pixels=FOV_PIXELS, - detector_oversample=2, fft_oversample=2) + detector_oversample=2, fft_oversample=2, progressbar=True) assert psf[0].header['NWAVES'] == len(WAVELENGTHS_ARRAY), \ "Number of wavelengths in PSF header does not match number requested" assert len(psf[0].data.shape) == 3, "Incorrect dimensions for output cube" diff --git a/poppy/utils.py b/poppy/utils.py index b54ed45c..7fe3ddc0 100644 --- a/poppy/utils.py +++ b/poppy/utils.py @@ -1855,3 +1855,29 @@ def fftw_load_wisdom(filename=None): "optimization measurements (automatically). ") _loaded_fftw_wisdom = True + +# ################################################################## +# Progress bar (optional convenience) +# + +def get_progressbar_wrapper(progressbar=True, nwaves=None): + """ Utility function to return an iterator that MAY display a progress bar, + or may not, depending + """ + if progressbar: + # this relies on an optional dependency, tqdm + # if it's not present, just don't try to display a progressbar + try: + from tqdm import tqdm + except ImportError: + progressbar = False + + if progressbar: + import functools + # set up an optional progressbar wrapper + iterate_wrapper = functools.partial(tqdm, ncols=80, total=nwaves) + else: + # null wrapper that does nothing, for no progress bar + iterate_wrapper = lambda x: x + + return iterate_wrapper