Skip to content

Commit

Permalink
Merge pull request #354 from astrolabsoftware/master
Browse files Browse the repository at this point in the history
merge master into branch for AL loop
  • Loading branch information
emilleishida authored Oct 25, 2023
2 parents 4f7c8e2 + 18d7686 commit 86a63f2
Show file tree
Hide file tree
Showing 9 changed files with 60 additions and 21 deletions.
2 changes: 0 additions & 2 deletions .github/workflows/run_test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,6 @@ jobs:
- name: Run test suites
run: |
pip uninstall -y supernnova
pip install git+https://github.com/supernnova/SuperNNova.git@6cfcb72009a734b26caa46eacb54d3e2dde9e2c3#egg=supernnova
pip uninstall -y actsnfink
pip install git+https://github.com/emilleishida/fink_sn_activelearning.git@bf8d4e263e02d42781642f872f7bc030c24792bc#egg=actsnfink
Expand Down
2 changes: 1 addition & 1 deletion fink_science/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,4 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
__version__ = "4.8.1"
__version__ = "5.1.0"
7 changes: 3 additions & 4 deletions fink_science/agn/classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import pickle
import joblib
import fink_science.agn.kernel as k
import fink_science.agn.feature_extraction as fe
import os
Expand All @@ -27,7 +27,7 @@
def load_classifier(source):
"""
load the random forest classifier trained to recognize the AGN
on binary cases : AGNs vs non-AGNs (pickle format).
on binary cases : AGNs vs non-AGNs (joblib format).
Parameters
----------
Expand Down Expand Up @@ -58,8 +58,7 @@ def load_classifier(source):
elif source == 'ZTF':
model_path = k.CLASSIFIER_ZTF

with open(model_path, "rb") as f:
clf = pickle.load(f)
clf = joblib.load(model_path)

return clf

Expand Down
4 changes: 2 additions & 2 deletions fink_science/agn/kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from fink_science import __file__

curdir = os.path.dirname(os.path.abspath(__file__))
CLASSIFIER_ELASTICC = curdir + "/data/models/AGN_elasticc_alerts.pkl"
CLASSIFIER_ZTF = curdir + "/data/models/AGN_binary.pkl"
CLASSIFIER_ELASTICC = curdir + "/data/models/AGN_elasticc_alerts.joblib"
CLASSIFIER_ZTF = curdir + "/data/models/AGN_binary.joblib"
MINIMUM_POINTS = 4
MAXFEV = 3000
6 changes: 2 additions & 4 deletions fink_science/agn/unit_examples.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
import pandas as pd
import numpy as np
import fink_science.agn.kernel as k
import pickle
import joblib


with open(k.CLASSIFIER_ELASTICC, "rb") as f:
clf_unit = pickle.load(f)
clf_unit = joblib.load(k.CLASSIFIER_ELASTICC)

raw_ztf_unit = pd.DataFrame(
{
Expand Down
Binary file not shown.
Binary file removed fink_science/data/models/AGN_elasticc_alerts.pkl
Binary file not shown.
Binary file modified fink_science/data/models/SLSN_elasticc_alerts.joblib
Binary file not shown.
60 changes: 52 additions & 8 deletions fink_science/random_forest_snia/processor.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2019-2022 AstroLab Software
# Copyright 2019-2023 AstroLab Software
# Author: Julien Peloton
#
# Licensed under the Apache License, Version 2.0 (the "License");
Expand Down Expand Up @@ -70,7 +70,12 @@ def apply_selection_cuts_ztf(
return mask

@pandas_udf(DoubleType(), PandasUDFType.SCALAR)
def rfscore_sigmoid_full(jd, fid, magpsf, sigmapsf, cdsxmatch, ndethist, model=None) -> pd.Series:
def rfscore_sigmoid_full(
jd, fid, magpsf, sigmapsf, cdsxmatch, ndethist,
min_rising_points=pd.Series([2]),
min_data_points=pd.Series([4]),
rising_criteria=pd.Series(['ewma']),
model=None) -> pd.Series:
""" Return the probability of an alert to be a SNe Ia using a Random
Forest Classifier (sigmoid fit).
Expand All @@ -88,6 +93,10 @@ def rfscore_sigmoid_full(jd, fid, magpsf, sigmapsf, cdsxmatch, ndethist, model=N
Type of object found in Simbad (string)
ndethist: Spark DataFrame Column
Column containing the number of detection by ZTF at 3 sigma (int)
min_rising_points, min_data_points: int
Parameters from fink_sn_activelearning.git
rising_criteria: str
How to compute derivatives: ewma (default), or diff.
model: Spark DataFrame Column, optional
Path to the trained model. Default is None, in which case the default
model `data/models/default-model.obj` is loaded.
Expand Down Expand Up @@ -139,16 +148,26 @@ def rfscore_sigmoid_full(jd, fid, magpsf, sigmapsf, cdsxmatch, ndethist, model=N
+----------------+-----+
<BLANKLINE>
# Note that we can also specify a model
# We can also specify fink_sn_activelearning parameters
>>> args = [F.col(i) for i in what_prefix]
>>> args += [F.col('cdsxmatch'), F.col('candidate.ndethist')]
>>> args += [F.lit(model_path_sigmoid)]
>>> args += [F.lit(2), F.lit(4), F.lit('ewma')]
>>> df = df.withColumn('pIa', rfscore_sigmoid_full(*args))
>>> df.filter(df['pIa'] > 0.5).count()
6
>>> df.agg({"pIa": "max"}).collect()[0][0] < 1.0
# We can also specify a different model
>>> args = [F.col(i) for i in what_prefix]
>>> args += [F.col('cdsxmatch'), F.col('candidate.ndethist')]
>>> args += [F.lit(1), F.lit(3), F.lit('diff')]
>>> args += [F.lit(model_path_al_loop)]
>>> df = df.withColumn('pIaAL', rfscore_sigmoid_full(*args))
>>> df.filter(df['pIaAL'] > 0.5).count()
5
>>> df.agg({"pIaAL": "max"}).collect()[0][0] < 1.0
True
"""
mask = apply_selection_cuts_ztf(magpsf, ndethist, cdsxmatch)
Expand All @@ -171,7 +190,12 @@ def rfscore_sigmoid_full(jd, fid, magpsf, sigmapsf, cdsxmatch, ndethist, model=N
flag = []
for id in np.unique(pdf['SNID']):
pdf_sub = pdf[pdf['SNID'] == id]
features = get_sigmoid_features_dev(pdf_sub)
features = get_sigmoid_features_dev(
pdf_sub,
min_rising_points=min_rising_points.values[0],
min_data_points=min_data_points.values[0],
rising_criteria=rising_criteria.values[0]
)
if (features[0] == 0) or (features[6] == 0):
flag.append(False)
else:
Expand All @@ -183,14 +207,22 @@ def rfscore_sigmoid_full(jd, fid, magpsf, sigmapsf, cdsxmatch, ndethist, model=N
# Make predictions
probabilities = clf.predict_proba(test_features)

# pIa = 0.0 for objects that do not
# have both features non-zero.
probabilities[~flag] = [1.0, 0.0]

# Take only probabilities to be Ia
to_return = np.zeros(len(jd), dtype=float)
to_return[mask] = probabilities.T[1]

return pd.Series(to_return)

@pandas_udf(StringType(), PandasUDFType.SCALAR)
def extract_features_rf_snia(jd, fid, magpsf, sigmapsf, cdsxmatch, ndethist) -> pd.Series:
def extract_features_rf_snia(
jd, fid, magpsf, sigmapsf, cdsxmatch, ndethist,
min_rising_points=pd.Series([2]),
min_data_points=pd.Series([4]),
rising_criteria=pd.Series(['ewma'])) -> pd.Series:
""" Return the features used by the RF classifier.
There are 12 features. Order is:
Expand All @@ -209,6 +241,10 @@ def extract_features_rf_snia(jd, fid, magpsf, sigmapsf, cdsxmatch, ndethist) ->
Type of object found in Simbad (string)
ndethist: Spark DataFrame Column
Column containing the number of detection by ZTF at 3 sigma (int)
min_rising_points, min_data_points: int
Parameters from fink_sn_activelearning.git
rising_criteria: str
How to compute derivatives: ewma (default), or diff.
Returns
----------
Expand Down Expand Up @@ -259,7 +295,12 @@ def extract_features_rf_snia(jd, fid, magpsf, sigmapsf, cdsxmatch, ndethist) ->
test_features = []
for id in np.unique(pdf['SNID']):
pdf_sub = pdf[pdf['SNID'] == id]
features = get_sigmoid_features_dev(pdf_sub)
features = get_sigmoid_features_dev(
pdf_sub,
min_rising_points=min_rising_points.values[0],
min_data_points=min_data_points.values[0],
rising_criteria=rising_criteria.values[0]
)
test_features.append(features)

to_return_features = np.zeros((len(jd), len(RF_FEATURE_NAMES)), dtype=float)
Expand Down Expand Up @@ -414,5 +455,8 @@ def rfscore_sigmoid_elasticc(
model_path_sigmoid = '{}/data/models/default-model_sigmoid.obj'.format(path)
globs["model_path_sigmoid"] = model_path_sigmoid

model_path_al_loop = '{}/data/models/for_al_loop/model_20231009.pkl'.format(path)
globs["model_path_al_loop"] = model_path_al_loop

# Run the test suite
spark_unit_tests(globs)

0 comments on commit 86a63f2

Please sign in to comment.