Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Speed up unit conversions #179

Merged
merged 2 commits into from
Oct 28, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 32 additions & 9 deletions src/tdastro/astro_utils/unit_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,27 @@
from astropy import constants as const


def bulk_convert_matrix(mat, in_units, out_units):
"""Bulk convert a numpy matrix from one set of units to another.

Parameters
----------
mat : numpy.ndarray
The numpy matrix of values to convert.
in_units : astropy.units
The current (input) units of the matrix.
out_units : astropy.units
The desired (output) units of the matrix.

Returns
-------
output : numpy.ndarray
The converted numpy matrix of values.
"""
multipler = (1.0 * in_units).to_value(out_units)
return multipler * mat


def flam_to_fnu(flux_flam, wavelengths, *, wave_unit, flam_unit, fnu_unit):
"""
Covert flux from f_lambda unit to f_nu unit
Expand All @@ -27,8 +48,8 @@ def flam_to_fnu(flux_flam, wavelengths, *, wave_unit, flam_unit, fnu_unit):
flux_fnu : `list` or `np.array`
The flux values in fnu units.
"""
flux_flam = np.array(flux_flam) * flam_unit
wavelengths = np.array(wavelengths) * wave_unit
flux_flam = np.asarray(flux_flam)
wavelengths = np.asarray(wavelengths)

# Check if we need to reshape wavelengths to match the number
# of rows in flux_flam.
Expand All @@ -40,12 +61,13 @@ def flam_to_fnu(flux_flam, wavelengths, *, wave_unit, flam_unit, fnu_unit):
_ = np.broadcast_shapes(flux_flam.shape, wavelengths.shape)
except ValueError as err:
raise ValueError(
f"Mismatched sizes for flux_flam={flux_flam.shape} " f"and wavelengths={wavelengths.shape}."
f"Mismatched sizes for flux_flam={flux_flam.shape} and wavelengths={wavelengths.shape}."
) from err

# convert flux in flam_unit (e.g. ergs/s/cm^2/A) to fnu_unit (e.g. nJy or ergs/s/cm^2/Hz)
flux_fnu = (flux_flam * (wavelengths**2) / const.c).to_value(fnu_unit)
return flux_fnu
input_units = (flam_unit * wave_unit * wave_unit) / const.c.unit
flux_fnu = flux_flam * (wavelengths**2) / const.c.value
return bulk_convert_matrix(flux_fnu, input_units, fnu_unit)


def fnu_to_flam(flux_fnu, wavelengths, *, wave_unit, flam_unit, fnu_unit):
Expand Down Expand Up @@ -73,8 +95,8 @@ def fnu_to_flam(flux_fnu, wavelengths, *, wave_unit, flam_unit, fnu_unit):
flux_flam : `list` or `np.array`
The flux values in flam units.
"""
flux_fnu = np.array(flux_fnu) * fnu_unit
wavelengths = np.array(wavelengths) * wave_unit
flux_fnu = np.asarray(flux_fnu)
wavelengths = np.asarray(wavelengths)

# Check if we need to reshape wavelengths to match the number
# of rows in flux_fnu.
Expand All @@ -90,5 +112,6 @@ def fnu_to_flam(flux_fnu, wavelengths, *, wave_unit, flam_unit, fnu_unit):
) from err

# convert flux in fnu_unit (e.g. nJy or ergs/s/cm^2/Hz) to flam_unit (e.g. ergs/s/cm^2/A)
flux_flam = (flux_fnu * const.c / wavelengths**2).to_value(flam_unit)
return flux_flam
input_units = (fnu_unit * const.c.unit) / (wave_unit * wave_unit)
flux_flam = flux_fnu * const.c.value / wavelengths**2
return bulk_convert_matrix(flux_flam, input_units, flam_unit)
Loading