Skip to content

Commit

Permalink
Test new return type for SNN
Browse files Browse the repository at this point in the history
  • Loading branch information
JulienPeloton committed Dec 10, 2023
1 parent 429bc9a commit 815cd1e
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 13 deletions.
27 changes: 14 additions & 13 deletions fink_science/snn/processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -386,19 +386,17 @@ def snn_broad_elasticc(
>>> args += [F.lit('elasticc_broad')]
>>> df = df.withColumn('preds', snn_broad_elasticc(*args))
>>> df = df.withColumn('snn_class', F.col('preds').getItem(0).astype('int'))
>>> df = df.withColumn('snn_max_prob', F.col('preds').getItem(1))
>>> pdf = df.select('preds').toPandas()
>>> df.filter(df['snn_class'] == 0).count()
# 11 objects have been classified as class 0
>>> np.sum(pdf.apply(lambda x: np.argmax(x) == 0))
11
"""
# No a priori cuts
mask = np.ones(len(diaSourceId), dtype=bool)

if len(midPointTai[mask]) == 0:
snn_class = np.ones(len(midPointTai), dtype=float) * -1
snn_max_prob = np.zeros(len(midPointTai), dtype=float)
return pd.Series([[i, j] for i, j in zip(snn_class, snn_max_prob)])
return pd.Series([[0.0 for i in range(5)] for j in range(len(diaSourceId))])

# Conversion to FLUXCAL
fac = 10**(-(31.4 - 27.5) / 2.5)
Expand Down Expand Up @@ -452,19 +450,22 @@ def snn_broad_elasticc(
preds_df.index = preds_df.SNID

# Take only probabilities to be Ia
snn_class = np.ones(len(midPointTai), dtype=float) * -1
snn_max_prob = np.zeros(len(midPointTai), dtype=float)
# snn_class = np.ones(len(midPointTai), dtype=float) * -1
# snn_max_prob = np.zeros(len(midPointTai), dtype=float)

all_preds = preds_df.reindex([str(i) for i in diaSourceId[mask].values])

cols = ['prob_class{}'.format(i) for i in range(5)]
all_preds[['snn_class', 'snn_max_prob']] = all_preds[cols].apply(lambda x: extract_max_prob(x), axis=1, result_type="expand")
snn_class[mask] = all_preds.snn_class.values
snn_max_prob[mask] = all_preds.snn_max_prob.values
all_preds['all'] = all_preds[cols].values.tolist()

# all_preds[['snn_class', 'snn_max_prob']] = all_preds[cols].apply(lambda x: extract_max_prob(x), axis=1, result_type="expand")
# snn_class[mask] = all_preds.snn_class.values
# snn_max_prob[mask] = all_preds.snn_max_prob.values

# return main class and associated probability
return pd.Series([[i, j] for i, j in zip(snn_class, snn_max_prob)])
# # return main class and associated probability
# return pd.Series([[i, j] for i, j in zip(snn_class, snn_max_prob)])

return all_preds['all']

if __name__ == "__main__":
""" Execute the test suite """
Expand Down
4 changes: 4 additions & 0 deletions run_tests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,10 @@ done
export PYTHONPATH="${SPARK_HOME}/python/test_coverage:$PYTHONPATH"
export COVERAGE_PROCESS_START="${ROOTPATH}/.coveragerc"

coverage run --source=${ROOTPATH} --rcfile ${ROOTPATH}/.coveragerc fink_science/snn/processor.py

dasdlsa

# Run the test suite on the utilities
for filename in fink_science/*.py
do
Expand Down

0 comments on commit 815cd1e

Please sign in to comment.