Skip to content

Commit

Permalink
Add customisation to feature extractor in SN Ia. (#350)
Browse files Browse the repository at this point in the history
* Add customisation to feature extractor in SN Ia. Update the test suite for AL loop.

* Update constructor

* Fix test suite
  • Loading branch information
JulienPeloton authored Oct 19, 2023
1 parent cfa509c commit 8d87ab6
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 11 deletions.
3 changes: 0 additions & 3 deletions .github/workflows/run_test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,8 @@ jobs:
- name: Run test suites
run: |
pip uninstall -y supernnova
pip install git+https://github.com/supernnova/SuperNNova.git@6cfcb72009a734b26caa46eacb54d3e2dde9e2c3#egg=supernnova
pip uninstall -y actsnfink
pip install git+https://github.com/emilleishida/fink_sn_activelearning.git@bf8d4e263e02d42781642f872f7bc030c24792bc#egg=actsnfink
pip install fink-utils==0.13.3
./run_tests.sh
curl -s https://codecov.io/bash | bash
56 changes: 48 additions & 8 deletions fink_science/random_forest_snia/processor.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2019-2022 AstroLab Software
# Copyright 2019-2023 AstroLab Software
# Author: Julien Peloton
#
# Licensed under the Apache License, Version 2.0 (the "License");
Expand Down Expand Up @@ -70,7 +70,12 @@ def apply_selection_cuts_ztf(
return mask

@pandas_udf(DoubleType(), PandasUDFType.SCALAR)
def rfscore_sigmoid_full(jd, fid, magpsf, sigmapsf, cdsxmatch, ndethist, model=None) -> pd.Series:
def rfscore_sigmoid_full(
jd, fid, magpsf, sigmapsf, cdsxmatch, ndethist,
min_rising_points=pd.Series([2]),
min_data_points=pd.Series([4]),
rising_criteria=pd.Series(['ewma']),
model=None) -> pd.Series:
""" Return the probability of an alert to be a SNe Ia using a Random
Forest Classifier (sigmoid fit).
Expand All @@ -88,6 +93,10 @@ def rfscore_sigmoid_full(jd, fid, magpsf, sigmapsf, cdsxmatch, ndethist, model=N
Type of object found in Simbad (string)
ndethist: Spark DataFrame Column
Column containing the number of detection by ZTF at 3 sigma (int)
min_rising_points, min_data_points: int
Parameters from fink_sn_activelearning.git
rising_criteria: str
How to compute derivatives: ewma (default), or diff.
model: Spark DataFrame Column, optional
Path to the trained model. Default is None, in which case the default
model `data/models/default-model.obj` is loaded.
Expand Down Expand Up @@ -139,16 +148,26 @@ def rfscore_sigmoid_full(jd, fid, magpsf, sigmapsf, cdsxmatch, ndethist, model=N
+----------------+-----+
<BLANKLINE>
# Note that we can also specify a model
# We can also specify fink_sn_activelearning parameters
>>> args = [F.col(i) for i in what_prefix]
>>> args += [F.col('cdsxmatch'), F.col('candidate.ndethist')]
>>> args += [F.lit(model_path_sigmoid)]
>>> args += [F.lit(2), F.lit(4), F.lit('ewma')]
>>> df = df.withColumn('pIa', rfscore_sigmoid_full(*args))
>>> df.filter(df['pIa'] > 0.5).count()
6
>>> df.agg({"pIa": "max"}).collect()[0][0] < 1.0
# We can also specify a different model
>>> args = [F.col(i) for i in what_prefix]
>>> args += [F.col('cdsxmatch'), F.col('candidate.ndethist')]
>>> args += [F.lit(1), F.lit(3), F.lit('diff')]
>>> args += [F.lit(model_path_al_loop)]
>>> df = df.withColumn('pIaAL', rfscore_sigmoid_full(*args))
>>> df.filter(df['pIaAL'] > 0.5).count()
5
>>> df.agg({"pIaAL": "max"}).collect()[0][0] < 1.0
True
"""
mask = apply_selection_cuts_ztf(magpsf, ndethist, cdsxmatch)
Expand All @@ -171,7 +190,12 @@ def rfscore_sigmoid_full(jd, fid, magpsf, sigmapsf, cdsxmatch, ndethist, model=N
flag = []
for id in np.unique(pdf['SNID']):
pdf_sub = pdf[pdf['SNID'] == id]
features = get_sigmoid_features_dev(pdf_sub)
features = get_sigmoid_features_dev(
pdf_sub,
min_rising_points=min_rising_points.values[0],
min_data_points=min_data_points.values[0],
rising_criteria=rising_criteria.values[0]
)
if (features[0] == 0) or (features[6] == 0):
flag.append(False)
else:
Expand All @@ -194,7 +218,11 @@ def rfscore_sigmoid_full(jd, fid, magpsf, sigmapsf, cdsxmatch, ndethist, model=N
return pd.Series(to_return)

@pandas_udf(StringType(), PandasUDFType.SCALAR)
def extract_features_rf_snia(jd, fid, magpsf, sigmapsf, cdsxmatch, ndethist) -> pd.Series:
def extract_features_rf_snia(
jd, fid, magpsf, sigmapsf, cdsxmatch, ndethist,
min_rising_points=pd.Series([2]),
min_data_points=pd.Series([4]),
rising_criteria=pd.Series(['ewma'])) -> pd.Series:
""" Return the features used by the RF classifier.
There are 12 features. Order is:
Expand All @@ -213,6 +241,10 @@ def extract_features_rf_snia(jd, fid, magpsf, sigmapsf, cdsxmatch, ndethist) ->
Type of object found in Simbad (string)
ndethist: Spark DataFrame Column
Column containing the number of detection by ZTF at 3 sigma (int)
min_rising_points, min_data_points: int
Parameters from fink_sn_activelearning.git
rising_criteria: str
How to compute derivatives: ewma (default), or diff.
Returns
----------
Expand Down Expand Up @@ -263,7 +295,12 @@ def extract_features_rf_snia(jd, fid, magpsf, sigmapsf, cdsxmatch, ndethist) ->
test_features = []
for id in np.unique(pdf['SNID']):
pdf_sub = pdf[pdf['SNID'] == id]
features = get_sigmoid_features_dev(pdf_sub)
features = get_sigmoid_features_dev(
pdf_sub,
min_rising_points=min_rising_points.values[0],
min_data_points=min_data_points.values[0],
rising_criteria=rising_criteria.values[0]
)
test_features.append(features)

to_return_features = np.zeros((len(jd), len(RF_FEATURE_NAMES)), dtype=float)
Expand Down Expand Up @@ -418,5 +455,8 @@ def rfscore_sigmoid_elasticc(
model_path_sigmoid = '{}/data/models/default-model_sigmoid.obj'.format(path)
globs["model_path_sigmoid"] = model_path_sigmoid

model_path_al_loop = '{}/data/models/for_al_loop/model_20231009.pkl'.format(path)
globs["model_path_al_loop"] = model_path_al_loop

# Run the test suite
spark_unit_tests(globs)

0 comments on commit 8d87ab6

Please sign in to comment.