From 1f7e58d15dd0f702740c1656b884795f728fffa3 Mon Sep 17 00:00:00 2001 From: Marshall Perrin Date: Fri, 27 Sep 2024 08:52:38 -0400 Subject: [PATCH] Fix an issue with providing astropy Units to calc_datacube --- poppy/instrument.py | 8 ++++++-- poppy/tests/test_instrument.py | 14 ++++++++++++++ 2 files changed, 20 insertions(+), 2 deletions(-) diff --git a/poppy/instrument.py b/poppy/instrument.py index aacb7b1c..1975d9b0 100644 --- a/poppy/instrument.py +++ b/poppy/instrument.py @@ -325,6 +325,10 @@ def calc_datacube(self, wavelengths, progressbar=False, *args, **kwargs): raise ValueError("Maximum number of wavelengths exceeded. " "Cannot be more than 10,000.") + def wavelength_as_meters(wavelength): + """helper function to avoid trying to put a Quantity into a FITS header """ + return wavelength.to_value(units.meter) if isinstance(wavelength, units.Quantity) else wavelength + # Set up cube and initialize structure based on PSF at first wavelength poppy_core._log.info("Starting multiwavelength data cube calculation.") psf = self.calc_psf(*args, monochromatic=wavelengths[0], **kwargs) @@ -333,7 +337,7 @@ def calc_datacube(self, wavelengths, progressbar=False, *args, **kwargs): for ext in range(len(psf)): cube[ext].data = np.zeros((nwavelengths, psf[ext].data.shape[0], psf[ext].data.shape[1])) cube[ext].data[0] = psf[ext].data - cube[ext].header[label_wl(0)] = wavelengths[0] + cube[ext].header[label_wl(0)] = wavelength_as_meters(wavelengths[0]) iterate_wrapper = utils.get_progressbar_wrapper(progressbar, nwaves=nwavelengths) # iterate rest of wavelengths @@ -342,7 +346,7 @@ def calc_datacube(self, wavelengths, progressbar=False, *args, **kwargs): psf = self.calc_psf(*args, monochromatic=wl, **kwargs) for ext in range(len(psf)): cube[ext].data[i] = psf[ext].data - cube[ext].header[label_wl(i)] = wl + cube[ext].header[label_wl(i)] = wavelength_as_meters(wl) cube[ext].header.add_history("--- Cube Plane {} ---".format(i)) for h in psf[ext].header['HISTORY']: cube[ext].header.add_history(h) diff --git a/poppy/tests/test_instrument.py b/poppy/tests/test_instrument.py index 931f0c12..320d6815 100644 --- a/poppy/tests/test_instrument.py +++ b/poppy/tests/test_instrument.py @@ -184,3 +184,17 @@ def test_instrument_calc_datacube(): "Multi-wavelength PSF does not match weighted sum of individual wavelength PSFs" return psf + + +def test_instrument_datacube_wavelengths_units(): + """Test input of datacube wavelengths with and without astropy units """ + inst = instrument.Instrument() + + psf = inst.calc_datacube(WAVELENGTHS_ARRAY, fov_pixels=FOV_PIXELS, + detector_oversample=2, fft_oversample=2, progressbar=True) + + wavelengths_quantity = WAVELENGTHS_ARRAY * u.meter + psf2 = inst.calc_datacube(wavelengths_quantity, fov_pixels=FOV_PIXELS, + detector_oversample=2, fft_oversample=2, progressbar=True) + + assert np.allclose(psf[0].data, psf2[0].data), "Should get same outputs with/without units"