diff --git a/fink_science/anomaly_detection/processor.py b/fink_science/anomaly_detection/processor.py
index 37c29b5a..1431331d 100644
--- a/fink_science/anomaly_detection/processor.py
+++ b/fink_science/anomaly_detection/processor.py
@@ -14,13 +14,17 @@
 # limitations under the License.
 import logging
 import os
-import pickle
 import zipfile
 
 from pyspark.sql.functions import udf
 from pyspark.sql.types import DoubleType
 
 import pandas as pd
+import numpy as np
+
+from onnx import load
+import onnxruntime as rt
+
 from fink_science import __file__
 from fink_science.tester import spark_unit_tests
 
@@ -43,96 +47,133 @@ def __init__(self, forest_g, forest_r) -> None:
         self.forest_g = forest_g
 
     def anomaly_score(self, data_g, data_r):
-        scores_g = self.forest_g.score_samples(data_g)
-        scores_r = self.forest_r.score_samples(data_r)
-        return (scores_g + scores_r) / 2
+        scores_g = self.forest_g.run(None, {"X": data_g.values.astype(np.float32)})
+        scores_r = self.forest_r.run(None, {"X": data_r.values.astype(np.float32)})
+        return (scores_g[-1] + scores_r[-1]) / 2
 
 
 path = os.path.dirname(os.path.abspath(__file__))
 model_path = f"{path}/data/models/anomaly_detection"
-g_model_path = f"{model_path}/forest_g.pickle"
-r_model_path = f"{model_path}/forest_r.pickle"
+g_model_path = f"{model_path}/forest_g.onnx"
+r_model_path = f"{model_path}/forest_r.onnx"
+g_model_path_AAD = f"{model_path}/forest_g_AAD.onnx"
+r_model_path_AAD = f"{model_path}/forest_r_AAD.onnx"
 
 if not (os.path.exists(r_model_path) and os.path.exists(g_model_path)):
     # unzip in a tmp place
     tmp_path = '/tmp'
-    g_model_path = f"{tmp_path}/forest_g.pickle"
-    r_model_path = f"{tmp_path}/forest_r.pickle"
+    g_model_path = f"{tmp_path}/forest_g.onnx"
+    r_model_path = f"{tmp_path}/forest_r.onnx"
+    g_model_path_AAD = f"{tmp_path}/forest_g_AAD.onnx"
+    r_model_path_AAD = f"{tmp_path}/forest_r_AAD.onnx"
     # check it does not exist to avoid concurrent write
     if not (os.path.exists(r_model_path) and os.path.exists(g_model_path)):
         with zipfile.ZipFile(f"{model_path}/anomaly_detection_forest.zip", 'r') as zip_ref:
             zip_ref.extractall(tmp_path)
+    if not (os.path.exists(g_model_path_AAD) and os.path.exists(r_model_path_AAD)):
+        with zipfile.ZipFile(f"{model_path}/anomaly_detection_forest_AAD.zip", 'r') as zip_ref:
+            zip_ref.extractall(tmp_path)
 
-with open(r_model_path, 'rb') as forest_file:
-    forest_r = pickle.load(forest_file)
-with open(g_model_path, 'rb') as forest_file:
-    forest_g = pickle.load(forest_file)
-r_means = pd.read_csv(f"{model_path}/r_means.csv", header=None, index_col=0, squeeze=True)
-g_means = pd.read_csv(f"{model_path}/g_means.csv", header=None, index_col=0, squeeze=True)
-
-model = TwoBandModel(forest_g, forest_r)
-
-
-@udf(returnType=DoubleType())
-def anomaly_score(lc_features) -> float:
-    """ Returns anomaly score for an observation
-    Parameters
-    ----------
-    lc_features: Spark Map
-        Dict of dicts of floats. Keys of first dict - filters (fid), keys of inner dicts - names of features.
+class WrapInferenceSession:
+    """
+    The class is an additional wrapper over InferenceSession
+    to solve the pyspark serialisation problem
 
-    Returns
-    ----------
-    out: float
-        Anomaly score
+    https://github.com/microsoft/onnxruntime/pull/800#issuecomment-844326099
+    """
+    def __init__(self, onnx_bytes):
+        self.sess = rt.InferenceSession(onnx_bytes.SerializeToString())
+        self.onnx_bytes = onnx_bytes
 
-    Examples
-    ---------
-    >>> from fink_utils.spark.utils import concat_col
-    >>> from pyspark.sql import functions as F
-    >>> from fink_science.ad_features.processor import extract_features_ad
+    def run(self, *args):
+        return self.sess.run(*args)
 
-    >>> df = spark.read.load(ztf_alert_sample)
+    def __getstate__(self):
+        return {'onnx_bytes': self.onnx_bytes}
 
-    # Required alert columns, concatenated with historical data
-    >>> what = ['magpsf', 'jd', 'sigmapsf', 'fid', 'distnr', 'magnr', 'sigmagnr', 'isdiffpos']
-    >>> prefix = 'c'
-    >>> what_prefix = [prefix + i for i in what]
-    >>> for colname in what:
-    ...    df = concat_col(df, colname, prefix=prefix)
+    def __setstate__(self, values):
+        self.onnx_bytes = values['onnx_bytes']
+        self.sess = rt.InferenceSession(self.onnx_bytes.SerializeToString())
 
-    >>> cols = ['cmagpsf', 'cjd', 'csigmapsf', 'cfid', 'objectId', 'cdistnr', 'cmagnr', 'csigmagnr', 'cisdiffpos']
-    >>> df = df.withColumn('lc_features', extract_features_ad(*cols))
-    >>> df = df.withColumn("anomaly_score", anomaly_score("lc_features"))
-    >>> df.filter(df["anomaly_score"] < -0.5).count()
-    7
+forest_r = WrapInferenceSession(load(r_model_path))
+forest_g = WrapInferenceSession(load(g_model_path))
+forest_r_AAD = WrapInferenceSession(load(r_model_path_AAD))
+forest_g_AAD = WrapInferenceSession(load(g_model_path_AAD))
 
-    >>> df.filter(df["anomaly_score"] == 0).count()
-    84
-    """
+r_means = pd.read_csv(f"{model_path}/r_means.csv", header=None, index_col=0, squeeze=True)
+g_means = pd.read_csv(f"{model_path}/g_means.csv", header=None, index_col=0, squeeze=True)
 
-    if (
-        lc_features is None
-        or len(lc_features) != 2  # noqa: W503 (https://www.flake8rules.com/rules/W503.html, https://www.flake8rules.com/rules/W504.html)
-        or any(map(  # noqa: W503
-            lambda fs: (fs is None or len(fs) == 0),
-            lc_features.values()
-        ))
-    ):
-        return 0.0
-    if any(map(lambda fid: fid not in lc_features, (1, 2))):
-        logger.exception(f"Unsupported 'lc_features' format in '{__file__}/{anomaly_score.__name__}'")
-
-    data_r, data_g = (
-        pd.DataFrame.from_dict({k: [v] for k, v in lc_features[i].items()})[MODEL_COLUMNS]
-        for i in (1, 2)
-    )
-    for data, means in ((data_r, r_means), (data_g, g_means)):
-        for col in data.columns[data.isna().any()]:
-            data[col].fillna(means[col], inplace=True)
-    return model.anomaly_score(data_r, data_g)[0].item()
+model = TwoBandModel(forest_g, forest_r)
+model_AAD = TwoBandModel(forest_g_AAD, forest_r_AAD)
+
+
+def anomaly_score(lc_features, model_type='AADForest'):
+    @udf(returnType=DoubleType())
+    def anomaly_score(lc_features) -> float:
+        """ Returns anomaly score for an observation
+
+        Parameters
+        ----------
+        lc_features: Spark Map
+            Dict of dicts of floats. Keys of first dict - filters (fid), keys of inner dicts - names of features.
+
+        Returns
+        ----------
+        out: float
+            Anomaly score
+
+        Examples
+        ---------
+        >>> from fink_utils.spark.utils import concat_col
+        >>> from pyspark.sql import functions as F
+        >>> from fink_science.ad_features.processor import extract_features_ad
+
+        >>> df = spark.read.load(ztf_alert_sample)
+
+        # Required alert columns, concatenated with historical data
+        >>> what = ['magpsf', 'jd', 'sigmapsf', 'fid', 'distnr', 'magnr', 'sigmagnr', 'isdiffpos']
+        >>> prefix = 'c'
+        >>> what_prefix = [prefix + i for i in what]
+        >>> for colname in what:
+        ...    df = concat_col(df, colname, prefix=prefix)
+
+        >>> cols = ['cmagpsf', 'cjd', 'csigmapsf', 'cfid', 'objectId', 'cdistnr', 'cmagnr', 'csigmagnr', 'cisdiffpos']
+        >>> df = df.withColumn('lc_features', extract_features_ad(*cols))
+        >>> df = df.withColumn("anomaly_score", anomaly_score("lc_features"))
+
+        >>> df.filter(df["anomaly_score"] < -0.5).count()
+        7
+
+        >>> df.filter(df["anomaly_score"] == 0).count()
+        84
+
+        """
+
+        if (
+            lc_features is None
+            or len(lc_features) != 2  # noqa: W503 (https://www.flake8rules.com/rules/W503.html, https://www.flake8rules.com/rules/W504.html)
+            or any(map(  # noqa: W503
+                lambda fs: (fs is None or len(fs) == 0),
+                lc_features.values()
+            ))
+        ):
+            return 0.0
+        if any(map(lambda fid: fid not in lc_features, (1, 2))):
+            logger.exception(f"Unsupported 'lc_features' format in '{__file__}/{anomaly_score.__name__}'")
+
+        data_r, data_g = (
+            pd.DataFrame.from_dict({k: [v] for k, v in lc_features[i].items()})[MODEL_COLUMNS]
+            for i in (1, 2)
+        )
+        for data, means in ((data_r, r_means), (data_g, g_means)):
+            for col in data.columns[data.isna().any()]:
+                data[col].fillna(means[col], inplace=True)
+        if model_type == 'AADForest':
+            return model_AAD.anomaly_score(data_r, data_g)[0].item()
+        return model.anomaly_score(data_r, data_g)[0].item()
+    return anomaly_score(lc_features)
 
 
 if __name__ == "__main__":
diff --git a/fink_science/data/models/anomaly_detection/anomaly_detection_forest.zip b/fink_science/data/models/anomaly_detection/anomaly_detection_forest.zip
index af2ec97e..fa6f562a 100644
Binary files a/fink_science/data/models/anomaly_detection/anomaly_detection_forest.zip and b/fink_science/data/models/anomaly_detection/anomaly_detection_forest.zip differ
diff --git a/fink_science/data/models/anomaly_detection/g_means.csv b/fink_science/data/models/anomaly_detection/g_means.csv
index ff6223c4..d3902ecc 100644
--- a/fink_science/data/models/anomaly_detection/g_means.csv
+++ b/fink_science/data/models/anomaly_detection/g_means.csv
@@ -1,18 +1,19 @@
+,0
 amplitude,0.21948567063093466
 anderson_darling_normal,0.5738338537540897
-beyond_1_std,0.2975957082645236
+beyond_1_std,0.29759570826452364
 chi2,37.704414101216535
 cusum,0.3407494402523421
-kurtosis,0.29150516293433526
-linear_fit_slope,0.0008679869820436599
-linear_fit_slope_sigma,0.002219643796867781
-linear_trend_noise,0.1395519301423175
-linear_trend_sigma,0.006530832951936918
+kurtosis,0.2915051629343353
+linear_fit_slope,0.0008679869820436663
+linear_fit_slope_sigma,0.002219643796867831
+linear_trend_noise,0.13955193014231754
+linear_trend_sigma,0.006530832951936966
 magnitude_percentage_ratio_20_10,0.7096091060287604
-magnitude_percentage_ratio_40_5,0.21345309681339728
+magnitude_percentage_ratio_40_5,0.2134530968133973
 maximum_slope,20.197335141851806
 median,16.09982754358941
-median_absolute_deviation,0.07874299095497293
+median_absolute_deviation,0.07874299095497296
 median_buffer_range_percentage_10,0.226612517162576
 skew,0.06777187504700471
 stetson_K,0.853534164617388
diff --git a/fink_science/data/models/anomaly_detection/r_means.csv b/fink_science/data/models/anomaly_detection/r_means.csv
index ec6f4ba7..4efa95b3 100644
--- a/fink_science/data/models/anomaly_detection/r_means.csv
+++ b/fink_science/data/models/anomaly_detection/r_means.csv
@@ -1,18 +1,19 @@
-amplitude,0.27761581012726816
+,0
+amplitude,0.2776158101272682
 anderson_darling_normal,0.6324932481221309
-beyond_1_std,0.2984203432923638
+beyond_1_std,0.29842034329236383
 chi2,47.98632140376363
 cusum,0.3353034770643174
 kurtosis,0.19552357158753803
-linear_fit_slope,0.0004677950689235035
-linear_fit_slope_sigma,0.004115135468353454
-linear_trend_noise,0.17699726764662013
-linear_trend_sigma,0.008118738252789526
+linear_fit_slope,0.00046779506892350695
+linear_fit_slope_sigma,0.004115135468353504
+linear_trend_noise,0.17699726764662016
+linear_trend_sigma,0.008118738252789576
 magnitude_percentage_ratio_20_10,0.7110312348018119
-magnitude_percentage_ratio_40_5,0.2119754721716743
+magnitude_percentage_ratio_40_5,0.21197547217167437
 maximum_slope,27.23340912206943
 median,17.306199665913482
-median_absolute_deviation,0.09875090359331291
-median_buffer_range_percentage_10,0.22553694294088789
+median_absolute_deviation,0.09875090359331296
+median_buffer_range_percentage_10,0.22553694294088797
 skew,0.022823198234573093
 stetson_K,0.8580286805997064
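
Note on the serialisation workaround introduced in processor.py: onnxruntime.InferenceSession objects cannot be pickled, so capturing one directly in a Spark UDF closure fails. WrapInferenceSession therefore keeps only the ONNX ModelProto bytes in its pickled state (__getstate__) and rebuilds the session on deserialisation (__setstate__), following the linked onnxruntime discussion. The sketch below shows that round trip outside Spark; it is illustrative only and assumes a local forest_g.onnx file and a placeholder feature width (18 here, taken from the *_means.csv rows; the real width is len(MODEL_COLUMNS)).

import pickle

import numpy as np
import onnx
import onnxruntime as rt


class WrapInferenceSession:
    """Pickle-friendly wrapper: only the ONNX model bytes travel through pickle."""
    def __init__(self, onnx_bytes):
        self.sess = rt.InferenceSession(onnx_bytes.SerializeToString())
        self.onnx_bytes = onnx_bytes

    def run(self, *args):
        return self.sess.run(*args)

    def __getstate__(self):
        # Drop the unpicklable InferenceSession; keep only the ModelProto.
        return {'onnx_bytes': self.onnx_bytes}

    def __setstate__(self, values):
        # Rebuild the session from the stored bytes on the worker side.
        self.onnx_bytes = values['onnx_bytes']
        self.sess = rt.InferenceSession(self.onnx_bytes.SerializeToString())


# 'forest_g.onnx' is a stand-in path; any forest exported with sklearn-onnx works.
forest = WrapInferenceSession(onnx.load("forest_g.onnx"))

# Round trip through pickle, as Spark does when shipping the UDF closure to executors.
forest = pickle.loads(pickle.dumps(forest))

# As in TwoBandModel.anomaly_score, the input tensor is named "X" and must be float32;
# the last output of run() holds the score array.
n_features = 18  # placeholder: must equal len(MODEL_COLUMNS) of the exported model
dummy = np.zeros((1, n_features), dtype=np.float32)
scores = forest.run(None, {"X": dummy})[-1]
print(scores)

The new top-level anomaly_score(lc_features, model_type='AADForest') keeps the same scoring logic; model_type only selects between the plain forest pair and the _AAD pair before the inner UDF is returned.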