Skip to content

Commit

Permalink
Merge pull request #149 from lincc-frameworks/fnu_to_flam
Browse files Browse the repository at this point in the history
Utility function to convert F_ν to F_λ
  • Loading branch information
hombit authored Oct 3, 2024
2 parents dd22ca2 + 0a0fc80 commit 7efe1be
Show file tree
Hide file tree
Showing 2 changed files with 79 additions and 9 deletions.
61 changes: 54 additions & 7 deletions src/tdastro/astro_utils/unit_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from astropy import constants as const


def flam_to_fnu(flux_flam, wavelengths, wave_unit=None, flam_unit=None, fnu_unit=None):
def flam_to_fnu(flux_flam, wavelengths, *, wave_unit, flam_unit, fnu_unit):
"""
Covert flux from f_lambda unit to f_nu unit
Expand Down Expand Up @@ -30,18 +30,65 @@ def flam_to_fnu(flux_flam, wavelengths, wave_unit=None, flam_unit=None, fnu_unit
flux_flam = np.array(flux_flam) * flam_unit
wavelengths = np.array(wavelengths) * wave_unit

# Check if we need to repeat wavelengths to match the number
# Check if we need to reshape wavelengths to match the number
# of rows in flux_flam.
if len(flux_flam.shape) > 1 and len(wavelengths.shape) == 1:
num_rows = flux_flam.shape[0]
wavelengths = np.tile(wavelengths, (num_rows, 1))
if flux_flam.ndim > 1 and wavelengths.ndim == 1:
wavelengths = wavelengths[None, :]

# Check that the shapes match.
if flux_flam.shape != wavelengths.shape:
try:
_ = np.broadcast_shapes(flux_flam.shape, wavelengths.shape)
except ValueError as err:
raise ValueError(
f"Mismatched sizes for flux_flam={flux_flam.shape} " f"and wavelengths={wavelengths.shape}."
)
) from err

# convert flux in flam_unit (e.g. ergs/s/cm^2/A) to fnu_unit (e.g. nJy or ergs/s/cm^2/Hz)
flux_fnu = (flux_flam * (wavelengths**2) / const.c).to_value(fnu_unit)
return flux_fnu


def fnu_to_flam(flux_fnu, wavelengths, *, wave_unit, flam_unit, fnu_unit):
"""
Covert flux from f_nu unit to f_lambda unit
Parameters
----------
flux_fnu : `list` or `numpy.ndarray`
The flux values in fnu units. This can be a single N-length array
or an M x N matrix.
wavelengths: `list` or `numpy.ndarray`
The wavelength values associated with the input flux values.
This can be a single N-length array or an M x N matrix. If it is an
N-length array, the same wavelength values are used for each flux_fnu.
wave_unit: `astropy.units.Unit`
The unit for the wavelength values.
flam_unit: `astropy.units.Unit`
The unit for the output flux_flam values.
fnu_unit: `astropy.units.Unit`
The unit for the input flux_fnu values.
Returns
-------
flux_flam : `list` or `np.array`
The flux values in flam units.
"""
flux_fnu = np.array(flux_fnu) * fnu_unit
wavelengths = np.array(wavelengths) * wave_unit

# Check if we need to reshape wavelengths to match the number
# of rows in flux_fnu.
if flux_fnu.ndim > 1 and wavelengths.ndim == 1:
wavelengths = wavelengths[None, :]

# Check that the shapes match.
try:
_ = np.broadcast_shapes(flux_fnu.shape, wavelengths.shape)
except ValueError as err:
raise ValueError(
f"Mismatched sizes for flux_fnu={flux_fnu.shape} " f"and wavelengths={wavelengths.shape}."
) from err

# convert flux in fnu_unit (e.g. nJy or ergs/s/cm^2/Hz) to flam_unit (e.g. ergs/s/cm^2/A)
flux_flam = (flux_fnu * const.c / wavelengths**2).to_value(flam_unit)
return flux_flam
27 changes: 25 additions & 2 deletions tests/tdastro/astro_utils/test_unit_utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import numpy as np
import pytest
from astropy import units as u
from tdastro.astro_utils.unit_utils import flam_to_fnu
from tdastro.astro_utils.unit_utils import flam_to_fnu, fnu_to_flam


def test_flam_to_fnu():
Expand Down Expand Up @@ -62,8 +62,31 @@ def test_flam_to_fnu_matrix():
with pytest.raises(ValueError):
_ = flam_to_fnu(
[[c * 1.0e-4, 1.0e5], [1000.0, 2000.0], [c * 1.0e-4, c * 1.0e-4]],
[100.0],
[100.0, 200.0, 300.0],
wave_unit=u.AA,
flam_unit=u.erg / u.second / u.cm**2 / u.AA,
fnu_unit=u.erg / u.second / u.cm**2 / u.Hz,
)


def test_flam_to_fnu_to_flam():
"""Test that flam_to_fnu(fnu_to_flam) is the identity."""
rng = np.random.default_rng(None)
n = 100
waves = rng.uniform(low=100.0, high=1.0e5, size=n)
flam0 = rng.lognormal(mean=0.0, sigma=1.0, size=n)
fnu = flam_to_fnu(
flam0,
waves,
wave_unit=u.AA,
flam_unit=u.erg / u.second / u.cm**2 / u.AA,
fnu_unit=u.nJy,
)
flam1 = fnu_to_flam(
fnu,
waves,
wave_unit=u.AA,
flam_unit=u.erg / u.second / u.cm**2 / u.AA,
fnu_unit=u.nJy,
)
np.testing.assert_allclose(flam0, flam1)

0 comments on commit 7efe1be

Please sign in to comment.