diff --git a/analysis/spectra_chisq.py b/analysis/spectra_chisq.py new file mode 100644 index 0000000..cc7c07c --- /dev/null +++ b/analysis/spectra_chisq.py @@ -0,0 +1,145 @@ +#!/usr/bin/env python3 +# -*- coding: UTF-8 -*- + +"""Calculate the chi-squared values for modeled spectra""" + +from copy import deepcopy + +import numpy as np +import sncosmo +from astropy.table import Table + +from . import utils + + +# Todo: Write the color_warp_spectrum function +def color_warp_spectrum(wave, flux, flux_err, **kwargs): + """Color warp a spectrum + + Args: + wave (ndarray): An array of wavelengths + flux (ndarray): An array of flux values + flux_err (ndarray): An array of error values for ``flux`` + + Returns: + An array of color warped flux values + An array of error values for the color warped fluxes + """ + + return flux, flux_err + + +def band_limits(band_name, trans_limit): + """Return wavelength range where a band is above a given transmission + + Args: + wave (ndarray): An array of wavelengths + flux (ndarray): An array of flux values + flux_err (ndarray): An array of error values for ``flux`` + band_name (str): Name of an sncosmo registered band + trans_limit (float): The transmission limit + + Returns: + The wavelengths, fluxes, and flux errors inside the bandpass + """ + + if band_name.lower() == 'all': + return -np.inf, np.inf + + band = sncosmo.get_bandpass(band_name) + transmission_limits = band.wave[band.trans >= trans_limit] + return np.min(transmission_limits), np.max(transmission_limits) + + +def band_chisq(wave, flux, flux_err, model_flux, band_start, band_end): + """Calculate the chi-squared for a spectrum within a wavelength range + + Args: + wave (ndarray): An array of wavelengths + flux (ndarray): An array of flux values + flux_err (ndarray): An array of error values for ``flux`` + model_flux (ndarray): An array of model flux values + band_start (float): The starting wavelength for the band + band_end (float): The ending wavelength for the band + + Returns: + A dictionary with the chi_squared value in each band + """ + + if band_start < np.min(wave) or np.max(wave) < band_end: + raise ValueError + + indices = np.where((band_start < wave) & (wave < band_end))[0] + chisq_arr = (flux[indices] - model_flux[indices]) / flux_err[indices] + return np.sum(chisq_arr) + + +def create_empty_output_table(band_names): + """Create an empty astropy table for storing chi-squared results + + Args: + band_names (list): List band names + + Returns: + An astropy table + """ + + names, dtype = ['obj_id', 'source', 'version'], ['U100', 'U100', 'U100'] + names.extend(band_names) + dtype.extend((float for _ in band_names)) + + out_table = Table(names=names, dtype=dtype, masked=True) + return out_table + + +def tabulate_chi_squared(data_release, models, bands, out_path=None): + """Tabulate chi-squared values for spectroscopic observations + + Args: + data_release (module): An sndata data release + models (list): List of sncosmo models + bands (list): A list of band names + out_path (str): Optionally write results to file + + Returns: + An astropy table of chi-squared values + """ + + out_table = create_empty_output_table(bands) + data_iter = data_release.iter_data( + verbose={'desc': 'Targets'}, filter_func=utils.filter_has_csp_data) + + for data_table in data_iter: + obj_id = data_table.meta['obj_id'] + ebv = utils.get_csp_ebv(obj_id) + t0 = utils.get_csp_t0(obj_id) + obs_time, wave, flux = utils.parse_spectra_table(data_table) + flux_err = .1 * flux # Todo: get actual CSP DR1 errors + phase = obs_time - t0 + + for model in models: + model = deepcopy(model) + model.set(extebv=ebv) + + for p, w, f, fe in zip(phase, wave, flux, flux_err): + new_row = [obj_id, model.source.name, model.source.version] + mask = [False, False, False] + + for band in bands: + band_start, band_end = band_limits(band, .1) + model_flux = model.flux(p, w) + + try: + chisq = band_chisq(w, f, fe, model_flux, band_start, band_end) + new_row.append(chisq) + mask.append(False) + + except ValueError: + new_row.append(np.NAN) + mask.append(True) + + out_table.add_row(new_row, mask=mask) + if out_table: + out_table.write(out_path, overwrite=True) + + return out_table diff --git a/analysis/utils.py b/analysis/utils.py index 7a01e81..b9f1020 100644 --- a/analysis/utils.py +++ b/analysis/utils.py @@ -9,6 +9,55 @@ from tqdm import tqdm +class NoCSPData(Exception): + pass + + +def filter_has_csp_data(data_table): + """Return whether an object ID has an available t0 and E(B - V) value + + Args: + data_table (Table): A table from sndata + + Returns: + A boolean + """ + + obj_id = data_table.meta['obj_id'] + try: + get_csp_t0(obj_id) + get_csp_ebv(obj_id) + + except NoCSPData: + return False + + else: + return True + + +@np.vectorize +def convert_to_jd(time): + """Convert MJD and Snoopy dates into JD + + Args: + time (float): Time stamp in JD, MJD, or SNPY format + + Returns: + The time value in JD format + """ + + # Snoopy time format + if time < 53000: + return time + 53000 + 2400000.5 + + # Snoopy time format + elif 53000 < time < 2400000.5: + return time + 2400000.5 + + else: + return time + + def get_csp_t0(obj_id): """Get the t0 value published by CSP DR3 for a given object @@ -19,9 +68,11 @@ def get_csp_t0(obj_id): The published MJD of maximum minus 53000 """ + dr3.download_module_data() params = dr3.load_table(3) + params = params[~params['T(Bmax)'].mask] if obj_id not in params['SN']: - raise ValueError(f'No published t0 for {obj_id}') + raise NoCSPData(f'No published t0 for {obj_id}') return params[params['SN'] == obj_id]['T(Bmax)'][0] @@ -36,9 +87,10 @@ def get_csp_ebv(obj_id): The published E(B - V) value """ + dr1.download_module_data() extinction_table = dr1.load_table(1) if obj_id not in extinction_table['SN']: - raise ValueError(f'No published E(B-V) for {obj_id}') + raise NoCSPData(f'No published E(B-V) for {obj_id}') data_for_target = extinction_table[extinction_table['SN'] == obj_id] return data_for_target['E(B-V)'][0] @@ -99,5 +151,5 @@ def parse_spectra_table(data): wavelength.append(data_for_date['wavelength']) flux.append(data_for_date['flux']) - obs_dates = np.array(obs_dates) - 2400000.5 # Convert from JD to MJD + obs_dates = np.array(obs_dates) return obs_dates, np.array(wavelength), np.array(flux)