Skip to content

Commit

Permalink
Merge pull request #639 from mperrin/datacube_units_fix
Browse files Browse the repository at this point in the history
Fix an issue with providing astropy Units to calc_datacube
  • Loading branch information
mperrin authored Oct 1, 2024
2 parents 3a83d51 + 1f7e58d commit 4780261
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 2 deletions.
8 changes: 6 additions & 2 deletions poppy/instrument.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,6 +325,10 @@ def calc_datacube(self, wavelengths, progressbar=False, *args, **kwargs):
raise ValueError("Maximum number of wavelengths exceeded. "
"Cannot be more than 10,000.")

def wavelength_as_meters(wavelength):
"""helper function to avoid trying to put a Quantity into a FITS header """
return wavelength.to_value(units.meter) if isinstance(wavelength, units.Quantity) else wavelength

# Set up cube and initialize structure based on PSF at first wavelength
poppy_core._log.info("Starting multiwavelength data cube calculation.")
psf = self.calc_psf(*args, monochromatic=wavelengths[0], **kwargs)
Expand All @@ -333,7 +337,7 @@ def calc_datacube(self, wavelengths, progressbar=False, *args, **kwargs):
for ext in range(len(psf)):
cube[ext].data = np.zeros((nwavelengths, psf[ext].data.shape[0], psf[ext].data.shape[1]))
cube[ext].data[0] = psf[ext].data
cube[ext].header[label_wl(0)] = wavelengths[0]
cube[ext].header[label_wl(0)] = wavelength_as_meters(wavelengths[0])

iterate_wrapper = utils.get_progressbar_wrapper(progressbar, nwaves=nwavelengths)
# iterate rest of wavelengths
Expand All @@ -342,7 +346,7 @@ def calc_datacube(self, wavelengths, progressbar=False, *args, **kwargs):
psf = self.calc_psf(*args, monochromatic=wl, **kwargs)
for ext in range(len(psf)):
cube[ext].data[i] = psf[ext].data
cube[ext].header[label_wl(i)] = wl
cube[ext].header[label_wl(i)] = wavelength_as_meters(wl)
cube[ext].header.add_history("--- Cube Plane {} ---".format(i))
for h in psf[ext].header['HISTORY']:
cube[ext].header.add_history(h)
Expand Down
14 changes: 14 additions & 0 deletions poppy/tests/test_instrument.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,3 +184,17 @@ def test_instrument_calc_datacube():
"Multi-wavelength PSF does not match weighted sum of individual wavelength PSFs"

return psf


def test_instrument_datacube_wavelengths_units():
"""Test input of datacube wavelengths with and without astropy units """
inst = instrument.Instrument()

psf = inst.calc_datacube(WAVELENGTHS_ARRAY, fov_pixels=FOV_PIXELS,
detector_oversample=2, fft_oversample=2, progressbar=True)

wavelengths_quantity = WAVELENGTHS_ARRAY * u.meter
psf2 = inst.calc_datacube(wavelengths_quantity, fov_pixels=FOV_PIXELS,
detector_oversample=2, fft_oversample=2, progressbar=True)

assert np.allclose(psf[0].data, psf2[0].data), "Should get same outputs with/without units"

0 comments on commit 4780261

Please sign in to comment.