diff --git a/src/tdastro/astro_utils/unit_utils.py b/src/tdastro/astro_utils/unit_utils.py index 37dcfa9..35acb72 100644 --- a/src/tdastro/astro_utils/unit_utils.py +++ b/src/tdastro/astro_utils/unit_utils.py @@ -2,6 +2,27 @@ from astropy import constants as const +def bulk_convert_matrix(mat, in_units, out_units): + """Bulk convert a numpy matrix from one set of units to another. + + Parameters + ---------- + mat : numpy.ndarray + The numpy matrix of values to convert. + in_units : astropy.units + The current (input) units of the matrix. + out_units : astropy.units + The desired (output) units of the matrix. + + Returns + ------- + output : numpy.ndarray + The converted numpy matrix of values. + """ + multipler = (1.0 * in_units).to_value(out_units) + return multipler * mat + + def flam_to_fnu(flux_flam, wavelengths, *, wave_unit, flam_unit, fnu_unit): """ Covert flux from f_lambda unit to f_nu unit @@ -27,8 +48,8 @@ def flam_to_fnu(flux_flam, wavelengths, *, wave_unit, flam_unit, fnu_unit): flux_fnu : `list` or `np.array` The flux values in fnu units. """ - flux_flam = np.array(flux_flam) * flam_unit - wavelengths = np.array(wavelengths) * wave_unit + flux_flam = np.asarray(flux_flam) + wavelengths = np.asarray(wavelengths) # Check if we need to reshape wavelengths to match the number # of rows in flux_flam. @@ -40,12 +61,13 @@ def flam_to_fnu(flux_flam, wavelengths, *, wave_unit, flam_unit, fnu_unit): _ = np.broadcast_shapes(flux_flam.shape, wavelengths.shape) except ValueError as err: raise ValueError( - f"Mismatched sizes for flux_flam={flux_flam.shape} " f"and wavelengths={wavelengths.shape}." + f"Mismatched sizes for flux_flam={flux_flam.shape} and wavelengths={wavelengths.shape}." ) from err # convert flux in flam_unit (e.g. ergs/s/cm^2/A) to fnu_unit (e.g. nJy or ergs/s/cm^2/Hz) - flux_fnu = (flux_flam * (wavelengths**2) / const.c).to_value(fnu_unit) - return flux_fnu + input_units = (flam_unit * wave_unit * wave_unit) / const.c.unit + flux_fnu = flux_flam * (wavelengths**2) / const.c.value + return bulk_convert_matrix(flux_fnu, input_units, fnu_unit) def fnu_to_flam(flux_fnu, wavelengths, *, wave_unit, flam_unit, fnu_unit): @@ -73,8 +95,8 @@ def fnu_to_flam(flux_fnu, wavelengths, *, wave_unit, flam_unit, fnu_unit): flux_flam : `list` or `np.array` The flux values in flam units. """ - flux_fnu = np.array(flux_fnu) * fnu_unit - wavelengths = np.array(wavelengths) * wave_unit + flux_fnu = np.asarray(flux_fnu) + wavelengths = np.asarray(wavelengths) # Check if we need to reshape wavelengths to match the number # of rows in flux_fnu. @@ -90,5 +112,6 @@ def fnu_to_flam(flux_fnu, wavelengths, *, wave_unit, flam_unit, fnu_unit): ) from err # convert flux in fnu_unit (e.g. nJy or ergs/s/cm^2/Hz) to flam_unit (e.g. ergs/s/cm^2/A) - flux_flam = (flux_fnu * const.c / wavelengths**2).to_value(flam_unit) - return flux_flam + input_units = (fnu_unit * const.c.unit) / (wave_unit * wave_unit) + flux_flam = flux_fnu * const.c.value / wavelengths**2 + return bulk_convert_matrix(flux_flam, input_units, flam_unit)