Skip to content

Commit

Permalink
Speed up unit conversions
Browse files Browse the repository at this point in the history
  • Loading branch information
jeremykubica committed Oct 28, 2024
1 parent 88bd226 commit 124323f
Showing 1 changed file with 32 additions and 9 deletions.
41 changes: 32 additions & 9 deletions src/tdastro/astro_utils/unit_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,27 @@
from astropy import constants as const


def bulk_convert_matrix(mat, in_units, out_units):
"""Bulk convert a numpy matrix from one set of units to another.
Parameters
----------
mat : numpy.ndarray
The numpy matrix of values to convert.
in_units : astropy.units
The current (input) units of the matrix.
out_units : astropy.units
The desired (output) units of the matrix.
Returns
-------
output : numpy.ndarray
The converted numpy matrix of values.
"""
multipler = (1.0 * in_units).to_value(out_units)
return multipler * mat


def flam_to_fnu(flux_flam, wavelengths, *, wave_unit, flam_unit, fnu_unit):
"""
Covert flux from f_lambda unit to f_nu unit
Expand All @@ -27,8 +48,8 @@ def flam_to_fnu(flux_flam, wavelengths, *, wave_unit, flam_unit, fnu_unit):
flux_fnu : `list` or `np.array`
The flux values in fnu units.
"""
flux_flam = np.array(flux_flam) * flam_unit
wavelengths = np.array(wavelengths) * wave_unit
flux_flam = np.array(flux_flam)
wavelengths = np.array(wavelengths)

# Check if we need to reshape wavelengths to match the number
# of rows in flux_flam.
Expand All @@ -40,12 +61,13 @@ def flam_to_fnu(flux_flam, wavelengths, *, wave_unit, flam_unit, fnu_unit):
_ = np.broadcast_shapes(flux_flam.shape, wavelengths.shape)
except ValueError as err:
raise ValueError(
f"Mismatched sizes for flux_flam={flux_flam.shape} " f"and wavelengths={wavelengths.shape}."
f"Mismatched sizes for flux_flam={flux_flam.shape} and wavelengths={wavelengths.shape}."
) from err

# convert flux in flam_unit (e.g. ergs/s/cm^2/A) to fnu_unit (e.g. nJy or ergs/s/cm^2/Hz)
flux_fnu = (flux_flam * (wavelengths**2) / const.c).to_value(fnu_unit)
return flux_fnu
input_units = (flam_unit * wave_unit * wave_unit) / const.c.unit
flux_fnu = flux_flam * (wavelengths**2) / const.c.value
return bulk_convert_matrix(flux_fnu, input_units, fnu_unit)


def fnu_to_flam(flux_fnu, wavelengths, *, wave_unit, flam_unit, fnu_unit):
Expand Down Expand Up @@ -73,8 +95,8 @@ def fnu_to_flam(flux_fnu, wavelengths, *, wave_unit, flam_unit, fnu_unit):
flux_flam : `list` or `np.array`
The flux values in flam units.
"""
flux_fnu = np.array(flux_fnu) * fnu_unit
wavelengths = np.array(wavelengths) * wave_unit
flux_fnu = np.array(flux_fnu)
wavelengths = np.array(wavelengths)

# Check if we need to reshape wavelengths to match the number
# of rows in flux_fnu.
Expand All @@ -90,5 +112,6 @@ def fnu_to_flam(flux_fnu, wavelengths, *, wave_unit, flam_unit, fnu_unit):
) from err

# convert flux in fnu_unit (e.g. nJy or ergs/s/cm^2/Hz) to flam_unit (e.g. ergs/s/cm^2/A)
flux_flam = (flux_fnu * const.c / wavelengths**2).to_value(flam_unit)
return flux_flam
input_units = (fnu_unit * const.c.unit) / (wave_unit * wave_unit)
flux_flam = flux_fnu * const.c.value / wavelengths**2
return bulk_convert_matrix(flux_flam, input_units, flam_unit)

0 comments on commit 124323f

Please sign in to comment.