Skip to content

Commit

Permalink
Return full vector for SNN (elasticc version) (#360)
Browse files Browse the repository at this point in the history
* Test new return type for SNN

* PEP8 and pandas syntax

* Restore full test suite
  • Loading branch information
JulienPeloton authored Dec 10, 2023
1 parent 429bc9a commit 07397aa
Showing 1 changed file with 6 additions and 15 deletions.
21 changes: 6 additions & 15 deletions fink_science/snn/processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -386,19 +386,17 @@ def snn_broad_elasticc(
>>> args += [F.lit('elasticc_broad')]
>>> df = df.withColumn('preds', snn_broad_elasticc(*args))
>>> df = df.withColumn('snn_class', F.col('preds').getItem(0).astype('int'))
>>> df = df.withColumn('snn_max_prob', F.col('preds').getItem(1))
>>> pdf = df.select('preds').toPandas()
>>> df.filter(df['snn_class'] == 0).count()
# 11 objects have been classified as class 0
>>> np.sum(pdf['preds'].apply(lambda x: np.argmax(x) == 0))
11
"""
# No a priori cuts
mask = np.ones(len(diaSourceId), dtype=bool)

if len(midPointTai[mask]) == 0:
snn_class = np.ones(len(midPointTai), dtype=float) * -1
snn_max_prob = np.zeros(len(midPointTai), dtype=float)
return pd.Series([[i, j] for i, j in zip(snn_class, snn_max_prob)])
return pd.Series([[0.0 for i in range(5)] for j in range(len(diaSourceId))])

# Conversion to FLUXCAL
fac = 10**(-(31.4 - 27.5) / 2.5)
Expand Down Expand Up @@ -451,19 +449,12 @@ def snn_broad_elasticc(
preds_df = reformat_to_df(pred_probs, ids=ids)
preds_df.index = preds_df.SNID

# Take only probabilities to be Ia
snn_class = np.ones(len(midPointTai), dtype=float) * -1
snn_max_prob = np.zeros(len(midPointTai), dtype=float)

all_preds = preds_df.reindex([str(i) for i in diaSourceId[mask].values])

cols = ['prob_class{}'.format(i) for i in range(5)]
all_preds[['snn_class', 'snn_max_prob']] = all_preds[cols].apply(lambda x: extract_max_prob(x), axis=1, result_type="expand")
snn_class[mask] = all_preds.snn_class.values
snn_max_prob[mask] = all_preds.snn_max_prob.values
all_preds['all'] = all_preds[cols].values.tolist()

# return main class and associated probability
return pd.Series([[i, j] for i, j in zip(snn_class, snn_max_prob)])
return all_preds['all']


if __name__ == "__main__":
Expand Down

0 comments on commit 07397aa

Please sign in to comment.