This repository has been archived by the owner on Jun 22, 2022. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 0
/
fitting.py
314 lines (247 loc) · 11 KB
/
fitting.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
#!/usr/bin/env python3.7
# -*- coding: UTF-8 -*-
"""The ``fitting`` module runs a series of fits on individual light-curves and
tabulates the results. This includes the ability to fit each bandpass
independently (like in SiFTO) or to fit rest-frame blue/red bandpasses as
separate, collective sets.
Function Documentation
----------------------
"""
from copy import deepcopy
import numpy as np
import sncosmo
from astropy.table import Table
from matplotlib import pyplot
from . import utils
# Single ``sncosmo.F99Dust`` effect instance, shared by every model that
# ``_create_fit_data_iter`` builds in this module.
DUST = sncosmo.F99Dust()
def create_empty_table(parameters, **kwargs):
    """Create an empty table for storing fit results

    Columns:
         - obj_id
         - band
         - source
         - pre_max
         - post_max
         - vparams
         - *parameters
         - *parameters + _err
         - chisq
         - ndof
         - b_max
         - delta_15
         - message

    Args:
        parameters (iter): Parameter names to add columns for. Any iterable
            is accepted, including one-shot iterators such as generators.
        Any arguments to pass ``astropy.Table``

    Returns:
        An empty astropy Table. The table is masked unless ``masked`` is
        explicitly overridden through the keyword arguments.
    """

    # Materialize once: the names are needed twice (value columns and error
    # columns) and their count is needed for the dtype list. Without this a
    # generator would be exhausted after the first pass.
    parameters = list(parameters)

    # Specify column names
    names = ['obj_id', 'band', 'source', 'pre_max', 'post_max', 'vparams']
    names += parameters + [param + '_err' for param in parameters]
    names += ['chisq', 'ndof', 'b_max', 'delta_15', 'message']

    # Specify column data types (one float per parameter and per error)
    dtype = ['U20', 'U100', 'U100', int, int, 'U100']
    dtype += [float] * (2 * len(parameters))
    dtype += [float, float, float, float, 'U10000']

    # Unless otherwise specified, we default to returning a masked table
    kwargs = deepcopy(kwargs)
    kwargs.setdefault('masked', True)
    return Table(names=names, dtype=dtype, **kwargs)
def fit_results_to_dict(data, obj_id, band_set, results, fitted_model):
    """Convert the outcome of an sncosmo fit into a table-ready dictionary

    The keys of the returned dictionary follow the column layout produced by
    ``create_empty_table``.

    Args:
        data         (Table): The data used in the fit
        obj_id         (str): The id of the object that was fit
        band_set       (str): The name of the band set ('all', 'blue', 'red')
        results     (Result): Fitting results returned by ``sncosmo``
        fitted_model (Model): A fitted ``sncosmo`` model

    Returns:
        Fit results as a dictionary
    """

    row = dict(
        obj_id=obj_id,
        band=band_set,
        source=fitted_model.source.name,
        vparams=','.join(results.vparam_names),
    )

    # Count the observations on either side of the fitted t0
    t0_index = results.param_names.index('t0')
    t_max = results.parameters[t0_index]
    row['pre_max'] = sum(data['time'] < t_max)
    row['post_max'] = sum(data['time'] >= t_max)

    # Parameter values first, then the reported error for each entry in
    # ``results.errors``
    row.update(zip(results.param_names, results.parameters))
    for name, uncertainty in results.errors.items():
        row[name + '_err'] = uncertainty

    # Goodness of fit
    chisq, ndof = utils.calc_model_chisq(data, results, fitted_model)
    row['chisq'] = np.round(chisq, 2)
    row['ndof'] = ndof

    # Peak B-band magnitude, and the change in B-band magnitude between the
    # phase of peak and 15 days later
    source = fitted_model.source
    b_max = fitted_model.source_peakabsmag('bessellb', 'ab')
    peak_phase = source.peakphase('bessellb')
    mag_at_peak = source.bandmag('bessellb', 'ab', peak_phase)
    mag_at_15 = source.bandmag('bessellb', 'ab', peak_phase + 15)
    row['b_max'] = np.round(b_max, 2)
    row['delta_15'] = np.round(mag_at_15 - mag_at_peak, 3)

    # Some fitting routines do not report an exit message; fall back to the
    # placeholder 'NONE' in that case.
    row['message'] = getattr(results, 'message', 'NONE')
    return row
def _plot_lc(data, result, fitted_model, show=True):
    """Plot a fitted light-curve and print its chi-squared over ndof

    Args:
        data         (Table): The data used in the fit
        result      (Result): The fit results
        fitted_model (Model): Model with params set to fitted values
        show          (bool): Call ``pyplot.show`` after plotting

    Returns:
        The figure returned by ``sncosmo.plot_lc``
    """

    figure = sncosmo.plot_lc(data, fitted_model, errors=result.errors)

    chisq, ndof = utils.calc_model_chisq(data, result, fitted_model)
    reduced = chisq / ndof
    print(f'chisq / ndof = {chisq} / {ndof} = {reduced}', flush=True)

    if show:
        pyplot.show()

    return figure
def _create_fit_data_iter(priors_hs, priors_bg, kwargs_hs, kwargs_bg):
    """Create an iterable of data used to run light-curve fits

    Args:
        priors_hs (dict): Priors to use when fitting hsiao
        priors_bg (dict): Priors to use when fitting sn91bg
        kwargs_hs (dict): Kwargs to pass ``fit_func`` when fitting hsiao
        kwargs_bg (dict): Kwargs to pass ``fit_func`` when fitting sn91bg

    Returns:
        - A single-use iterator of (model, vparams, priors, kwargs) tuples,
          first for the ``hsiao_x1`` model and then for ``sn91bg``. Each
          ``vparams`` entry is a set of parameter names.
        - An empty table to store output data
    """

    # Replace missing (or empty) arguments with fresh dicts and copy the rest
    # so that the caller's objects are never mutated downstream
    priors_hs = deepcopy(priors_hs) or dict()
    priors_bg = deepcopy(priors_bg) or dict()
    kwargs_hs = deepcopy(kwargs_hs) or dict()
    kwargs_bg = deepcopy(kwargs_bg) or dict()

    # Define models for normal and 91bg SNe. Both share one dust effect that
    # is registered under the name 'mw' and applied in the observer frame.
    dust_kw = dict(effects=[DUST], effect_names=['mw'], effect_frames=['obs'])
    bg_source = sncosmo.get_source('sn91bg', version='hsiao_phase')
    sn91bg = sncosmo.Model(bg_source, **dust_kw)
    hsiao = sncosmo.Model('hsiao_x1', **dust_kw)

    # The output columns are listed in an explicit order. Building them from
    # a set of strings would make the column order change from one
    # interpreter run to the next.
    table_params = ('z', 't0', 'amplitude', 'x1', 'c', 'mwebv')
    out_data = create_empty_table(table_params)

    # Determine what parameters to vary for each model. mwebv gets a column
    # in the table but is never varied. Redshift is only held fixed when
    # both models have a prior for it.
    vparams = {'z', 't0', 'amplitude', 'x1', 'c'}
    if 'z' in priors_bg and 'z' in priors_hs:
        vparams -= {'z'}

    # The intersection drops names that a model does not define (per the
    # original author's note, Hsiao does not have a c parameter).
    hsiao_vparams = set(hsiao.param_names).intersection(vparams)
    sn91bg_vparams = set(sn91bg.param_names).intersection(vparams)

    # Create iterators over the data we need to fit
    model_args = zip(
        (hsiao, sn91bg),  # The models
        (hsiao_vparams, sn91bg_vparams),  # The parameters to vary
        (priors_hs, priors_bg),  # The priors
        (kwargs_hs, kwargs_bg)  # The fitting kwargs
    )

    return model_args, out_data
def run_band_fits(
        obj_id, data, fit_func,
        priors_hs=None, priors_bg=None,
        kwargs_hs=None, kwargs_bg=None,
        show_plots=False):
    """Run light curve fits on a given target using the Hsiao and 91bg model

    Fits are run using both the ``hsiao_x1`` and ``sn91bg`` models for all
    available bands and then for each band individually.

    Varied parameters include ``z``, ``t0``, ``amplitude``, ``x1``, and ``c``.
    If the ``z`` is specified in the priors for both models, it is not varied
    in any fit. The parameters ``t0`` and ``z`` are not varied in the
    individual band fits.

    Args:
        obj_id      (str): Id of the object being fitted
        data      (Table): Table of photometric data
        fit_func   (func): Function to use to run fits (eg. ``fit_funcs.fit_lc``)
        priors_hs  (dict): Priors to use when fitting hsiao
        priors_bg  (dict): Priors to use when fitting sn91bg
        kwargs_hs  (dict): Kwargs to pass ``fit_func`` when fitting hsiao
        kwargs_bg  (dict): Kwargs to pass ``fit_func`` when fitting sn91bg
        show_plots (bool): Plot and display each individual fit

    Returns:
        A table with results for each model / dataset combination
    """

    model_args, out_data = _create_fit_data_iter(
        priors_hs, priors_bg, kwargs_hs, kwargs_bg)

    # The per-band grouping does not depend on the model, so compute it once.
    # It is kept under its own name so that ``data`` is never rebound and
    # every model's all-band fit receives the table exactly as it was passed.
    grouped = data.group_by('band')

    # Tabulate fit results for each model
    for model, vparams, prior, kwarg in model_args:
        model.update(prior)

        # Only keep bounds for parameters that are actually being varied
        kwarg['bounds'] = \
            {p: v for p, v in kwarg.get('bounds', {}).items() if p in vparams}

        # Fit data in all bands
        result_all, fit_all = fit_func(data, model, vparams, **kwarg)
        new_row = fit_results_to_dict(data, obj_id, 'all', result_all, fit_all)
        out_data.add_row(new_row)

        if show_plots:
            _plot_lc(data, result_all, fit_all)

        # Fix t0 and z during individual band fits
        band_vparams = deepcopy(vparams) - {'t0', 'z'}
        kwarg['bounds'].pop('t0', None)
        kwarg['bounds'].pop('z', None)

        # Using amplitude from all data fit as initial guess works better.
        # This must be set after the all-band fit above, which still uses
        # whatever ``guess_amplitude`` the caller supplied.
        kwarg['guess_amplitude'] = False

        # Fit data in individual bands, starting from the all-band model
        band_iter = zip(grouped.groups.keys['band'], grouped.groups)
        for band_name, band_data in band_iter:
            result, fit = fit_func(band_data, fit_all, band_vparams, **kwarg)
            new_row = fit_results_to_dict(
                band_data, obj_id, band_name, result, fit)
            out_data.add_row(new_row)

            if show_plots:
                _plot_lc(band_data, result, fit)

    return out_data
def run_collective_fits(
        obj_id, data, fit_func,
        band_names, lambda_eff,
        priors_hs=None, priors_bg=None,
        kwargs_hs=None, kwargs_bg=None,
        show_plots=False):
    """Fit a target with the Hsiao and 91bg models using blue / red band sets

    For each model the full data set is fit first. The data is then split
    into a blue and a red subset by ``utils.split_data`` using the redshift
    of the all-band fit, and each subset is fit with ``t0`` and ``z`` held
    fixed.

    Args:
        obj_id      (str): Id of the object being fitted
        data      (Table): Table of photometric data
        fit_func   (func): Function to use to run fits (eg. ``fit_funcs.fit_lc``)
        band_names (list): Name of bands included in ``data``
        lambda_eff (list): Effective wavelength for bands in ``band_names``
        priors_hs  (dict): Priors to use when fitting hsiao
        priors_bg  (dict): Priors to use when fitting sn91bg
        kwargs_hs  (dict): Kwargs to pass ``fit_func`` when fitting hsiao
        kwargs_bg  (dict): Kwargs to pass ``fit_func`` when fitting sn91bg
        show_plots (bool): Plot and display each individual fit

    Returns:
        A table with results for each model / dataset combination
    """

    model_args, out_data = _create_fit_data_iter(
        priors_hs, priors_bg, kwargs_hs, kwargs_bg)

    # One pass per model
    for model, vparams, prior, kwarg in model_args:
        model.update(prior)

        # Discard bounds for parameters that are not being varied
        requested_bounds = kwarg.get('bounds', {})
        kwarg['bounds'] = {
            name: bound
            for name, bound in requested_bounds.items()
            if name in vparams
        }

        # Fit the complete data set
        result_all, fit_all = fit_func(data, model, vparams, **kwarg)
        out_data.add_row(
            fit_results_to_dict(data, obj_id, 'all', result_all, fit_all))

        if show_plots:
            _plot_lc(data, result_all, fit_all)

        # t0 and z stay fixed for the blue / red subset fits
        band_vparams = deepcopy(vparams) - {'t0', 'z'}
        for fixed_param in ('t0', 'z'):
            kwarg['bounds'].pop(fixed_param, None)

        # Split the data into blue and red subsets using the fitted redshift
        z_index = fit_all.param_names.index('z')
        redshift = fit_all.parameters[z_index]
        blue_data, red_data = utils.split_data(
            data, band_names, lambda_eff, redshift)

        # Using amplitude from all data fit as initial guess works better
        kwarg['guess_amplitude'] = False

        subsets = (('blue', blue_data), ('red', red_data))
        for subset_name, subset_data in subsets:
            result, fit = fit_func(subset_data, fit_all, band_vparams, **kwarg)
            out_data.add_row(fit_results_to_dict(
                subset_data, obj_id, subset_name, result, fit))

            if show_plots:
                _plot_lc(subset_data, result, fit)

    return out_data