Skip to content

Commit

Permalink
Modularized filter, fixed lightcurve fitting, checked all bands can b…
Browse files Browse the repository at this point in the history
…e fitted at once, started top level classifier
  • Loading branch information
kdesoto-astro committed Sep 20, 2023
1 parent fd5651a commit af50ddf
Show file tree
Hide file tree
Showing 9 changed files with 685 additions and 199 deletions.
Binary file added src/elasticc2_training/12900721_2_svi.pdf
Binary file not shown.
Binary file added src/elasticc2_training/12900721_3_svi.pdf
Binary file not shown.
Binary file added src/elasticc2_training/12900721_4_svi.pdf
Binary file not shown.
Binary file added src/elasticc2_training/12900721_5_svi.pdf
Binary file not shown.
Binary file added src/elasticc2_training/12900721_6_svi.pdf
Binary file not shown.
Binary file added src/elasticc2_training/12900721_svi.pdf
Binary file not shown.
582 changes: 521 additions & 61 deletions src/elasticc2_training/elasticc2_train.ipynb

Large diffs are not rendered by default.

201 changes: 63 additions & 138 deletions src/elasticc2_training/elasticc_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,122 +2,24 @@
import numpy as np
from astropy.coordinates import SkyCoord
from scipy.stats import truncnorm
from scipy.optimize import curve_fit
import P4J

import io
import torch
import torch.nn as nn
import torch.nn.functional as F

from jax import random, lax, jit
from numpyro.infer import SVI, Trace_ELBO

from superphot_plus.lightcurve import Lightcurve
from superphot_plus.samplers.dynesty_sampler import DynestySampler
from superphot_plus.surveys.surveys import Survey
from superphot_plus.import_utils import clip_lightcurve_end


def get_meta_features(lc, ra, dec):
    """
    Calculate the meta features used in the top level and recurring classifier.

    Parameters
    ----------
    lc : object
        Light curve exposing ``times``, ``fluxes``, ``flux_errors`` and
        ``bands`` (indexed below with boolean masks).
        NOTE(review): bands are compared against integer codes 0-5; the
        ratio names used historically (u, g, i, z, Y over r) suggest
        0=u, 1=g, 2=r, 3=i, 4=z, 5=Y -- confirm against the caller.
    ra, dec :
        Sky position forwarded to ``get_galactic_coordinates``.

    Returns
    -------
    np.ndarray
        ``[gal_b, gal_l, positive_fraction, best_period, best_period_long,
        slope_0 .. slope_5,`` followed by ``(max_ratio, mean_ratio)`` pairs
        for bands 0, 1, 3, 4, 5 relative to band 2``]``.

    TO ADD:
    - MWEBV
    - HOST/SOURCE OFFSET (both in alert file)
    - HOST GAL MAG
    """
    t, f, ferr, b = lc.times, lc.fluxes, lc.flux_errors, lc.bands

    # Per-band linear slope of flux vs. time; 0 when the band has fewer than
    # two points or the fit fails.
    lin_slopes = []
    for b_opt in range(6):
        in_band = b == b_opt
        slope = 0.
        if len(f[in_band]) >= 2:
            try:
                slope = curve_fit(
                    lambda x, *p: p[0]*x + p[1],
                    t[in_band],
                    f[in_band],
                    [0, 0],
                    sigma=ferr[in_band]
                )[0][0]
            except (RuntimeError, ValueError, TypeError):
                # Best-effort feature: a failed fit falls back to slope 0
                # instead of aborting the whole feature extraction.
                pass
        lin_slopes.append(slope)

    # Max/mean flux ratios of every non-reference band against band 2.
    # A ratio pair stays at 0 when either band has no observations (the case
    # the former bare ``except`` blocks were silently absorbing).
    ref_flux = f[b == 2]
    ratios = []
    for aux_band in (0, 1, 3, 4, 5):
        aux_flux = f[b == aux_band]
        if len(aux_flux) == 0 or len(ref_flux) == 0:
            ratios.extend((0, 0))
            continue
        ratios.append(np.abs(np.max(aux_flux) / np.max(ref_flux)))
        ratios.append(np.abs(np.mean(aux_flux) / np.mean(ref_flux)))

    positive_fraction = len(f[f > 0]) / len(f)
    best_period = estimate_period(t, f, ferr)
    best_period_long = estimate_period(t, f, ferr, 0.2)  # only 5 days or longer
    gal_b, gal_l = get_galactic_coordinates(ra, dec)

    return np.array([
        gal_b, gal_l, positive_fraction, best_period, best_period_long,
        *lin_slopes, *ratios
    ])


def get_galactic_coordinates(ra, dec):
    """
    Convert an ICRS position given in degrees to galactic coordinates.

    Returns the pair ``(b, l)`` in degrees -- note that ``b`` comes first.
    """
    galactic = SkyCoord(ra, dec, frame='icrs', unit="deg").galactic
    return galactic.b.degree, galactic.l.degree


def estimate_period(t, f, ferr, fmax=50.):
    """
    Use MHAOV to estimate the best period of an assumed periodic signal.

    ``t``, ``f`` and ``ferr`` are handed to a P4J periodogram; ``fmax`` is
    the upper bound of the frequency sweep. The return value is the
    reciprocal of the best frequency found.
    """
    # Shift the fluxes so they are centered vertically around 0.
    centered_flux = f - np.mean(f)

    periodogram = P4J.periodogram(method='MHAOV')
    # NOTE(review): the meaning of the trailing positional 6 depends on the
    # installed P4J ``set_data`` signature -- confirm.
    periodogram.set_data(t, centered_flux, ferr, 6)

    # Coarse frequency sweep first, then refine around the single best peak.
    periodogram.frequency_grid_evaluation(fmin=0.0, fmax=fmax, fresolution=1e-3)
    periodogram.finetune_best_frequencies(fresolution=1e-4, n_local_optima=1)

    best_freqs, _ = periodogram.get_best_frequencies()
    return 1. / best_freqs[0]
from top_level_utils import (
get_meta_features,
get_galactic_coordinates,
estimate_period
)


def preprocess_lightcurve(lc_arr, name, *, survey: Survey):
Expand Down Expand Up @@ -281,6 +183,14 @@ def setup(self):

self.sampler = DynestySampler()
self.survey = Survey.LSST()

optimizer = numpyro.optim.Adam(step_size=0.001)
self.svi = SVI(jax_model, jax_guide, optimizer, loss=Trace_ELBO())
self.svi_state = None
self.num_iter = 10_000
self.lax_jit = jit(lax_helper_function, static_argnums=(0, 2))



def add_classification(self, class_id, prob):
"""Helper function to add classification result.
Expand Down Expand Up @@ -339,7 +249,11 @@ class and confidence score.

# apply top-level classifier
meta_features = get_meta_features(lc, ra, dec)
self.recurring_prob, self.nonrecurring_prob = self.top_level_model.classify_from_fit_param(meta_features)
(
self.recurring_prob,
self.nonrecurring_prob
) = self.top_level_model.classify_from_fit_param(meta_features)

self.add_classification(2, self.recurring_prob)
self.add_classification(1, self.nonrecurring_prob)

Expand All @@ -360,49 +274,60 @@ class and confidence score.
else: # non-recurring
#print("starting run nested sampling")
gri_lc = lc.filter_by_band(["g", "r", "i"], in_place=False)
gri_samples = self.sampler.run_single_curve(gri_lc, self.survey.priors)
max_flux_r = gri_lc.find_max_flux(band="r")

gri_samples, red_neg_chisq, self.svi_state = _svi_helper_no_recompile(
gri_lc,
max_flux_r,
self.survey.priors,
self.svi,
self.svi_state,
self.lax_jit,
self.num_iter,
)

if gri_samples is None:
self.distribute_prob_evenly(False)

mean_params = gri_samples.sample_mean()
mean_r = mean_params[7:14]
for aux_b in [0, 4, 5]:
#if len(b[b == aux_b]) == 0:
# eq_samples_aux = np.array([
# max_flux_aux = max_flux
#else:
eq_samples_aux, max_flux_aux = run_nested_sampling(mjd[b == aux_b], f[b == aux_b], ferr[b == aux_b], b[b == aux_b], [aux_b,], None, median_red, max_flux)
mean_params = gri_samples.sample_mean() # TODO: only get included bands
mean_params = np.append(mean_params, np.mean(red_neg_chisq))

"""
ref_band_idx = np.where(
self.priors.ordered_bands == self.priors.reference_band
)[0][0]
mean_r = mean_params[7*ref_band_idx:7*(ref_band_idx+1)]
for aux_b in ["u", "z", "Y"]:
aux_lc = lc.filter_by_band(aux_b, in_place=False)
if eq_samples_aux is None:
for class_id in self.nonrecurring_classes:
self.add_classification(
int(class_id),
probs[1].item() / 15.,
)
aux_priors = self.survey.priors.filter_by_band(
[self.priors.reference_band, aux_b]
)
aux_samples, aux_neg_chisq, self.svi_state = _svi_helper_no_recompile(
aux_lc,
max_flux_r,
aux_priors,
self.svi,
self.svi_state,
self.lax_jit,
self.num_iter,
)
if samples_aux is None:
self.distribute_prob_evenly(False)
if aux_b == 0:
eq_samples = np.hstack((eq_samples_aux, eq_samples))
else:
eq_samples = np.hstack((eq_samples, eq_samples_aux))

if eq_samples is None: # max_flux < 0
for class_id in self.nonrecurring_classes:
self.add_classification(
int(class_id),
probs[1].item() / 15.,
)
#print(int(class_id), probs[1].item() / 15.)


"""
probs_all = []
for eq_s in eq_samples:
post = np.append(eq_s, meta_features)
logL = calc_logL(eq_s, mjd, f, ferr, b)
post = np.append(post, logL)
adjusted_params = adjust_log_dists(post) # converts some params to log space
normed_params = (adjusted_params - self.nonrecurring_means) / self.nonrecurring_stddevs # normalizes by mean and stddev used in training normalization

probs_single = get_predictions(self.nonrecurring_model, torch.Tensor(np.array([normed_params,])), 'cpu').numpy() # uses model to output SN type probabilities

probs_all.append(probs_single)

probs_avg = np.mean(np.array(probs_all), axis=0)
Expand Down
101 changes: 101 additions & 0 deletions src/elasticc2_training/top_level_classifier.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
"""NOTE: THIS WILL BE MOVED INTO SUPERPHOT+ EVENTUALLY"""
import numpy as np
from scipy.optimize import curve_fit
import P4J
from astropy.coordinates import SkyCoord


def get_galactic_coordinates(ra, dec):
    """
    Convert an ICRS position given in degrees to galactic coordinates.

    Returns the pair ``(b, l)`` in degrees -- note that ``b`` comes first.
    """
    galactic = SkyCoord(ra, dec, frame='icrs', unit="deg").galactic
    return galactic.b.degree, galactic.l.degree


def estimate_period(t, f, ferr, fmax=50.):
    """
    Use MHAOV to estimate the best period of an assumed periodic signal.

    ``t``, ``f`` and ``ferr`` are handed to a P4J periodogram; ``fmax`` is
    the upper bound of the frequency sweep. The return value is the
    reciprocal of the best frequency found.
    """
    # Shift the fluxes so they are centered vertically around 0.
    centered_flux = f - np.mean(f)

    periodogram = P4J.periodogram(method='MHAOV')
    # NOTE(review): the meaning of the trailing positional 6 depends on the
    # installed P4J ``set_data`` signature -- confirm.
    periodogram.set_data(t, centered_flux, ferr, 6)

    # Coarse frequency sweep first, then refine around the single best peak.
    periodogram.frequency_grid_evaluation(fmin=0.0, fmax=fmax, fresolution=1e-3)
    periodogram.finetune_best_frequencies(fresolution=1e-4, n_local_optima=1)

    best_freqs, _ = periodogram.get_best_frequencies()
    return 1. / best_freqs[0]


def get_meta_features(
    lc,
    priors,
):
    """
    Calculate the meta features used in the top level and recurring classifier.

    Parameters
    ----------
    lc : object
        Light curve exposing ``times``, ``fluxes``, ``flux_errors``,
        ``bands`` and an ``extras`` mapping with the keys ``'ra'``,
        ``'dec'``, ``'mwebv'``, ``'host_sep'`` and ``'host_mag'``.
    priors : object
        Provides ``ordered_bands``, ``bands``, ``reference_band`` and
        ``aux_bands``.
        NOTE(review): the ratio arrays are sized ``len(priors.bands) - 1``
        but filled by enumerating ``priors.aux_bands`` -- this assumes
        ``aux_bands`` is every band except the reference one; confirm.

    Returns
    -------
    np.ndarray
        ``[gal_b, gal_l, mwebv, host_sep, host_mag, positive_fraction,
        best_period, best_period_long, *lin_slopes, *max_ratios,
        *mean_ratios]``.

    TO ADD:
    - MWEBV
    - HOST/SOURCE OFFSET (both in alert file)
    - HOST GAL MAG
    """
    extra_info = lc.extras
    ra, dec = extra_info['ra'], extra_info['dec']  # TODO: add coords and mwebv fields
    mwebv = extra_info['mwebv']
    host_sep = extra_info['host_sep']
    host_mag = extra_info['host_mag']

    t, f, ferr, b = lc.times, lc.fluxes, lc.flux_errors, lc.bands

    # Per-band linear slope of flux vs. time; 0 for bands with < 2 points.
    lin_slopes = []
    for unique_b in priors.ordered_bands:
        in_band = b == unique_b
        if len(f[in_band]) >= 2:
            lin_slopes.append(
                curve_fit(
                    lambda x, *p: p[0]*x + p[1],
                    t[in_band],
                    f[in_band],
                    [0, 0],
                    # The errors must be selected with the same band mask as
                    # the times and fluxes, not indexed by the band label.
                    sigma=ferr[in_band],
                )[0][0]
            )
        else:
            lin_slopes.append(0.)

    N = len(priors.bands)
    max_ratios = np.zeros(N-1)
    mean_ratios = np.zeros(N-1)

    # Without any reference-band observation np.max would raise on the empty
    # array; in that case every ratio keeps its zero default, mirroring how
    # sparse auxiliary bands are skipped below.
    ref_flux = f[b == priors.reference_band]
    if len(ref_flux) > 0:
        max_ref = np.max(ref_flux)
        mean_ref = np.mean(ref_flux)

        for i, b_i in enumerate(priors.aux_bands):
            f_b = f[b == b_i]
            if len(f_b) < 2:
                continue
            max_ratios[i] = np.abs(np.max(f_b) / max_ref)
            mean_ratios[i] = np.abs(np.mean(f_b) / mean_ref)

    positive_fraction = len(f[f > 0]) / len(f)
    best_period = estimate_period(t, f, ferr)
    best_period_long = estimate_period(t, f, ferr, 0.2)  # only 5 days or longer
    gal_b, gal_l = get_galactic_coordinates(ra, dec)

    return np.array([
        gal_b, gal_l, mwebv, host_sep,
        host_mag, positive_fraction,
        best_period, best_period_long,
        *lin_slopes, *max_ratios, *mean_ratios
    ])


def train_top_level_model():
    """Generate dataset and train top-level model.

    Placeholder: no training logic exists yet, so calling this does nothing
    and returns ``None``.
    """
    return None

0 comments on commit af50ddf

Please sign in to comment.