diff --git a/lc_classifier/classifier/models.py b/lc_classifier/classifier/models.py index c6b65e1..f2d54fc 100644 --- a/lc_classifier/classifier/models.py +++ b/lc_classifier/classifier/models.py @@ -5,6 +5,7 @@ import pickle5 as pickle import wget from imblearn.ensemble import BalancedRandomForestClassifier as RandomForestClassifier +from imblearn.under_sampling import RandomUnderSampler from lc_classifier.classifier.preprocessing import FeaturePreprocessor from lc_classifier.classifier.preprocessing import intersect_oids_in_dataframes from abc import ABC, abstractmethod @@ -390,7 +391,8 @@ def predict_in_pipeline(self, input_features: pd.DataFrame) -> dict: class ElasticcRandomForest(HierarchicalRandomForest): def __init__(self, taxonomy_dictionary, - non_used_features=None, + feature_list_dict: dict, + sampling_strategy=None, n_trees=500, n_jobs=1, verbose: bool = False): @@ -406,6 +408,13 @@ def __init__(self, min_samples_leaf = 1 min_samples_split = 2 + if sampling_strategy is None: + sampling_strategy = {} + sampling_strategy['Top'] = 'auto' + sampling_strategy['Stochastic'] = 'auto' + sampling_strategy['Periodic'] = 'auto' + sampling_strategy['Transient'] = 'auto' + # imblearn uses a weight mask, not slicing max_samples = None # 10000 bootstrap = False @@ -416,6 +425,7 @@ def __init__(self, max_features=max_features, min_samples_leaf=min_samples_leaf, min_samples_split=min_samples_split, + sampling_strategy=sampling_strategy['Top'], bootstrap=bootstrap, max_samples=max_samples, n_jobs=n_jobs, @@ -428,6 +438,7 @@ def __init__(self, max_features=max_features, min_samples_leaf=min_samples_leaf, min_samples_split=min_samples_split, + sampling_strategy=sampling_strategy['Stochastic'], bootstrap=bootstrap, max_samples=max_samples, n_jobs=n_jobs, @@ -440,6 +451,7 @@ def __init__(self, max_features=max_features, min_samples_leaf=min_samples_leaf, min_samples_split=min_samples_split, + sampling_strategy=sampling_strategy['Periodic'], bootstrap=bootstrap, max_samples=max_samples, n_jobs=n_jobs, @@ -447,30 +459,187 @@ def __init__(self, ) self.transient_classifier = RandomForestClassifier( - n_estimators=n_trees, + n_estimators=300, # n_trees, max_depth=max_depth, max_features=max_features, min_samples_leaf=min_samples_leaf, min_samples_split=min_samples_split, + min_impurity_decrease=0.0, # 0.00004, + sampling_strategy=sampling_strategy['Transient'], bootstrap=bootstrap, max_samples=max_samples, n_jobs=n_jobs, verbose=verbose_number ) - self.feature_preprocessor = FeaturePreprocessor( - non_used_features=non_used_features - ) + self.feature_preprocessor = FeaturePreprocessor() - self.feature_list = None + # each tree has its own features + self.feature_list_dict = feature_list_dict self.taxonomy_dictionary = taxonomy_dictionary self.inverted_dictionary = invert_dictionary(self.taxonomy_dictionary) self.pickles = { - "features_list": "features_RF_model.pkl", - "top_rf": "top_level_BRF_model.pkl", - "periodic_rf": "periodic_level_BRF_model.pkl", - "stochastic_rf": "stochastic_level_BRF_model.pkl", - "transient_rf": "transient_level_BRF_model.pkl", + "feature_list_dict": "features_HRF_model.pkl", + "top_rf": "top_level_HRF_model.pkl", + "periodic_rf": "periodic_level_HRF_model.pkl", + "stochastic_rf": "stochastic_level_HRF_model.pkl", + "transient_rf": "transient_level_HRF_model.pkl", } + + def check_missing_features(self, available_features): + available_features = set(available_features) + required_features = set(np.concatenate( + list(self.feature_list_dict.values()))) + + missing_features = 
required_features - available_features + if len(missing_features) != 0: + raise ValueError(f'missing required features: {missing_features}') + + def fit(self, samples: pd.DataFrame, labels: pd.DataFrame) -> None: + labels = labels.copy() + + # Check that the received labels are in the taxonomy + feeded_labels = labels.classALeRCE.unique() + expected_labels = self.inverted_dictionary.keys() + + for label in feeded_labels: + if label not in expected_labels: + raise Exception(f"{label} is not in the taxonomy dictionary") + + # Create top class + labels["top_class"] = labels["classALeRCE"].map(self.inverted_dictionary) + + # Preprocessing + samples = self.feature_preprocessor.preprocess_features(samples) + samples = self.feature_preprocessor.remove_duplicates(samples) + samples, labels = intersect_oids_in_dataframes(samples, labels) + + self.check_missing_features(samples.columns) + + # Train top classifier + if self.verbose: + print("training top classifier") + + rus = RandomUnderSampler() + selected_top_snids, _ = rus.fit_resample( + samples.index.values.reshape(-1, 1), + labels['classALeRCE']) + selected_top_snids = selected_top_snids.flatten() + + self.top_classifier.fit( + samples.loc[selected_top_snids][self.feature_list_dict['top']].values, + labels.loc[selected_top_snids]["top_class"].values) + + # Train specialized classifiers + if self.verbose: + print("training stochastic classifier") + is_stochastic = labels["top_class"] == "Stochastic" + self.stochastic_classifier.fit( + samples[is_stochastic][self.feature_list_dict['stochastic']].values, + labels[is_stochastic]["classALeRCE"].values) + + if self.verbose: + print("training periodic classifier") + is_periodic = labels["top_class"] == "Periodic" + self.periodic_classifier.fit( + samples[is_periodic][self.feature_list_dict['periodic']].values, + labels[is_periodic]["classALeRCE"].values) + + if self.verbose: + print("training transient classifier") + is_transient = labels["top_class"] == "Transient" + self.transient_classifier.fit( + samples[is_transient][self.feature_list_dict['transient']].values, + labels[is_transient]["classALeRCE"].values) + + def save_model(self, directory: str) -> None: + Path(directory).mkdir(parents=True, exist_ok=True) + with open(os.path.join(directory, self.pickles["top_rf"]), "wb") as f: + pickle.dump(self.top_classifier, f, pickle.HIGHEST_PROTOCOL) + + with open(os.path.join(directory, self.pickles["stochastic_rf"]), "wb") as f: + pickle.dump(self.stochastic_classifier, f, pickle.HIGHEST_PROTOCOL) + + with open(os.path.join(directory, self.pickles["periodic_rf"]), "wb") as f: + pickle.dump(self.periodic_classifier, f, pickle.HIGHEST_PROTOCOL) + + with open(os.path.join(directory, self.pickles["transient_rf"]), "wb") as f: + pickle.dump(self.transient_classifier, f, pickle.HIGHEST_PROTOCOL) + + with open(os.path.join(directory, self.pickles["feature_list_dict"]), "wb") as f: + pickle.dump(self.feature_list_dict, f, pickle.HIGHEST_PROTOCOL) + + def load_model(self, directory: str) -> None: + with open(os.path.join(directory, self.pickles["top_rf"]), 'rb') as f: + self.top_classifier = pickle.load(f) + + with open(os.path.join(directory, self.pickles["stochastic_rf"]), 'rb') as f: + self.stochastic_classifier = pickle.load(f) + + with open(os.path.join(directory, self.pickles["periodic_rf"]), 'rb') as f: + self.periodic_classifier = pickle.load(f) + + with open(os.path.join(directory, self.pickles["transient_rf"]), 'rb') as f: + self.transient_classifier = pickle.load(f) + + with 
open(os.path.join(directory, self.pickles["feature_list_dict"]), 'rb') as f: + self.feature_list_dict = pickle.load(f) + + self.check_loaded_models() + + def check_loaded_models(self): + assert set(self.top_classifier.classes_.tolist()) \ + == set(self.taxonomy_dictionary.keys()) + + assert set(self.stochastic_classifier.classes_.tolist()) \ + == set(self.taxonomy_dictionary["Stochastic"]) + + assert set(self.transient_classifier.classes_.tolist()) \ + == set(self.taxonomy_dictionary["Transient"]) + + assert set(self.periodic_classifier.classes_.tolist()) \ + == set(self.taxonomy_dictionary["Periodic"]) + + def predict_proba(self, samples: pd.DataFrame) -> pd.DataFrame: + self.check_missing_features(samples.columns) + + samples = self.feature_preprocessor.preprocess_features(samples) + + top_probs = self.top_classifier.predict_proba( + samples[self.feature_list_dict['top']].values) + + stochastic_probs = self.stochastic_classifier.predict_proba( + samples[self.feature_list_dict['stochastic']].values) + periodic_probs = self.periodic_classifier.predict_proba( + samples[self.feature_list_dict['periodic']].values) + transient_probs = self.transient_classifier.predict_proba( + samples[self.feature_list_dict['transient']].values) + + stochastic_index = self.top_classifier.classes_.tolist().index("Stochastic") + periodic_index = self.top_classifier.classes_.tolist().index("Periodic") + transient_index = self.top_classifier.classes_.tolist().index("Transient") + + stochastic_probs = stochastic_probs * top_probs[:, stochastic_index].reshape( + [-1, 1] + ) + periodic_probs = periodic_probs * top_probs[:, periodic_index].reshape([-1, 1]) + transient_probs = transient_probs * top_probs[:, transient_index].reshape( + [-1, 1] + ) + + # This line must have the same order as in get_list_of_classes() + final_probs = np.concatenate( + [stochastic_probs, periodic_probs, transient_probs], + axis=1 + ) + + df = pd.DataFrame( + data=final_probs, + index=samples.index, + columns=self.get_list_of_classes() + ) + + df.index.name = samples.index.name + return df diff --git a/lc_classifier/features/__init__.py b/lc_classifier/features/__init__.py index a8f41ab..1c5f4bb 100644 --- a/lc_classifier/features/__init__.py +++ b/lc_classifier/features/__init__.py @@ -19,7 +19,7 @@ from .extractors.gp_drw_extractor import GPDRWExtractor from .extractors.sn_features_phase_ii import SNFeaturesPhaseIIExtractor from .extractors.sn_parametric_model_computer import SPMExtractorPhaseII - +from .extractors.elasticc_metadata_extractor import ElasticcMetadataExtractor from .custom.ztf_feature_extractor import ZTFFeatureExtractor, ZTFForcedPhotometryFeatureExtractor from .custom.elasticc_feature_extractor import ElasticcFeatureExtractor @@ -58,5 +58,6 @@ 'FeatureExtractorComposer', 'SNFeaturesPhaseIIExtractor', 'SPMExtractorPhaseII', - 'ElasticcFeatureExtractor' + 'ElasticcFeatureExtractor', + 'ElasticcMetadataExtractor' ] diff --git a/lc_classifier/features/custom/elasticc_feature_extractor.py b/lc_classifier/features/custom/elasticc_feature_extractor.py index ce7fc95..e82be83 100644 --- a/lc_classifier/features/custom/elasticc_feature_extractor.py +++ b/lc_classifier/features/custom/elasticc_feature_extractor.py @@ -11,6 +11,7 @@ from lc_classifier.features import GPDRWExtractor from lc_classifier.features import SNFeaturesPhaseIIExtractor from lc_classifier.features.extractors.sn_parametric_model_computer import SPMExtractorElasticc +from lc_classifier.features import ElasticcMetadataExtractor from ..core.base import 
FeatureExtractor from ..core.base import FeatureExtractorComposer @@ -23,9 +24,6 @@ class ElasticcFeatureExtractor(FeatureExtractor): def __init__(self): self.bands = ['u', 'g', 'r', 'i', 'z', 'Y'] - # input: metadata - # self.gal_extractor = GalacticCoordinatesExtractor(from_metadata=True) - magnitude_extractors = [ # input: apparent magnitude IQRExtractor(self.bands), @@ -39,6 +37,7 @@ def __init__(self): flux_extractors = [ # input: difference flux + ElasticcMetadataExtractor(), ElasticcColorFeatureExtractor(self.bands), MHPSFluxExtractor(self.bands), SNFeaturesPhaseIIExtractor(self.bands), @@ -95,7 +94,7 @@ def _compute_features(self, detections, **kwargs): ------- """ - required = [] # ['metadata'] + required = ['metadata'] for key in required: if key not in kwargs: raise Exception(f"{key} argument is missing") diff --git a/lc_classifier/features/extractors/elasticc_metadata_extractor.py b/lc_classifier/features/extractors/elasticc_metadata_extractor.py new file mode 100644 index 0000000..91d9932 --- /dev/null +++ b/lc_classifier/features/extractors/elasticc_metadata_extractor.py @@ -0,0 +1,40 @@ +import pandas as pd +from typing import Tuple +from functools import lru_cache +from ..core.base import FeatureExtractor + + +class ElasticcMetadataExtractor(FeatureExtractor): + @lru_cache(1) + def get_features_keys(self) -> Tuple[str, ...]: + return 'redshift_helio', 'mwebv' + + @lru_cache(1) + def get_required_keys(self) -> Tuple[str, ...]: + # not very useful + return tuple() + + def _compute_features(self, detections, **kwargs): + return self._compute_features_from_df_groupby( + detections.groupby(level=0), + **kwargs) + + def _compute_features_from_df_groupby( + self, detections, **kwargs) -> pd.DataFrame: + columns = self.get_features_keys() + metadata = kwargs['metadata'] + + def aux_function(oid_detections, **kwargs): + oid = oid_detections.index.values[0] + metadata_lightcurve = metadata.loc[oid] + redshift_helio = metadata_lightcurve['REDSHIFT_HELIO'] + mwebv = metadata_lightcurve['MWEBV'] + + out = pd.Series( + data=[redshift_helio, mwebv], + index=columns) + return out + + features = detections.apply(aux_function) + features.index.name = 'oid' + return features
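
A minimal usage sketch of the new hierarchical classifier, assuming precomputed features and labels DataFrames indexed by oid. The 'top'/'stochastic'/'periodic'/'transient' keys, the 'classALeRCE' label column and the 'redshift_helio'/'mwebv' features follow the code above; the taxonomy leaves, the extra feature names and the output directory are hypothetical placeholders.

import pandas as pd
from lc_classifier.classifier.models import ElasticcRandomForest

# Placeholder taxonomy: the keys must be the three top classes used above,
# the leaf class names here are only illustrative.
taxonomy = {
    'Stochastic': ['AGN', 'CART'],
    'Periodic': ['RRL', 'EB'],
    'Transient': ['SNIa', 'SNII'],
}

# One feature list per classifier. 'redshift_helio' and 'mwebv' are produced by
# ElasticcMetadataExtractor; 'example_feature_1' is a placeholder for any other
# column present in the features DataFrame.
feature_list_dict = {
    'top': ['redshift_helio', 'mwebv', 'example_feature_1'],
    'stochastic': ['redshift_helio', 'mwebv', 'example_feature_1'],
    'periodic': ['redshift_helio', 'mwebv', 'example_feature_1'],
    'transient': ['redshift_helio', 'mwebv', 'example_feature_1'],
}

model = ElasticcRandomForest(taxonomy, feature_list_dict, n_trees=500, n_jobs=4)

# Assumed inputs (not built here):
#   features: pd.DataFrame indexed by oid, containing every column named in
#             feature_list_dict.
#   labels:   pd.DataFrame indexed by oid, with a 'classALeRCE' column whose
#             values belong to the taxonomy leaves.
model.fit(features, labels)

# Leaf-class probabilities; columns are ordered as in get_list_of_classes().
leaf_probs = model.predict_proba(features)

model.save_model('trained_models/elasticc_hrf')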