Commit: elasticc metadata
ignacioreyes committed Sep 2, 2022
1 parent ba6ddfc commit 20500d4
Showing 4 changed files with 226 additions and 17 deletions.
191 changes: 180 additions & 11 deletions lc_classifier/classifier/models.py
@@ -5,6 +5,7 @@
import pickle5 as pickle
import wget
from imblearn.ensemble import BalancedRandomForestClassifier as RandomForestClassifier
from imblearn.under_sampling import RandomUnderSampler
from lc_classifier.classifier.preprocessing import FeaturePreprocessor
from lc_classifier.classifier.preprocessing import intersect_oids_in_dataframes
from abc import ABC, abstractmethod
@@ -390,7 +391,8 @@ def predict_in_pipeline(self, input_features: pd.DataFrame) -> dict:
class ElasticcRandomForest(HierarchicalRandomForest):
def __init__(self,
taxonomy_dictionary,
-                 non_used_features=None,
+                 feature_list_dict: dict,
+                 sampling_strategy=None,
n_trees=500,
n_jobs=1,
verbose: bool = False):
@@ -406,6 +408,13 @@ def __init__(self,
min_samples_leaf = 1
min_samples_split = 2

if sampling_strategy is None:
sampling_strategy = {}
sampling_strategy['Top'] = 'auto'
sampling_strategy['Stochastic'] = 'auto'
sampling_strategy['Periodic'] = 'auto'
sampling_strategy['Transient'] = 'auto'
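            # 'auto' is imblearn's default policy: under-sample each majority
            # class down to the size of the minority class at that level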

# imblearn uses a weight mask, not slicing
max_samples = None # 10000
bootstrap = False
@@ -416,6 +425,7 @@ def __init__(self,
max_features=max_features,
min_samples_leaf=min_samples_leaf,
min_samples_split=min_samples_split,
sampling_strategy=sampling_strategy['Top'],
bootstrap=bootstrap,
max_samples=max_samples,
n_jobs=n_jobs,
@@ -428,6 +438,7 @@ def __init__(self,
max_features=max_features,
min_samples_leaf=min_samples_leaf,
min_samples_split=min_samples_split,
sampling_strategy=sampling_strategy['Stochastic'],
bootstrap=bootstrap,
max_samples=max_samples,
n_jobs=n_jobs,
@@ -440,37 +451,195 @@ def __init__(self,
max_features=max_features,
min_samples_leaf=min_samples_leaf,
min_samples_split=min_samples_split,
sampling_strategy=sampling_strategy['Periodic'],
bootstrap=bootstrap,
max_samples=max_samples,
n_jobs=n_jobs,
verbose=verbose_number
)

self.transient_classifier = RandomForestClassifier(
-            n_estimators=n_trees,
+            n_estimators=300, # n_trees,
max_depth=max_depth,
max_features=max_features,
min_samples_leaf=min_samples_leaf,
min_samples_split=min_samples_split,
min_impurity_decrease=0.0, # 0.00004,
sampling_strategy=sampling_strategy['Transient'],
bootstrap=bootstrap,
max_samples=max_samples,
n_jobs=n_jobs,
verbose=verbose_number
)

-        self.feature_preprocessor = FeaturePreprocessor(
-            non_used_features=non_used_features
-        )
+        self.feature_preprocessor = FeaturePreprocessor()

-        self.feature_list = None
+        # each tree has its own features
+        self.feature_list_dict = feature_list_dict

self.taxonomy_dictionary = taxonomy_dictionary

self.inverted_dictionary = invert_dictionary(self.taxonomy_dictionary)
self.pickles = {
"features_list": "features_RF_model.pkl",
"top_rf": "top_level_BRF_model.pkl",
"periodic_rf": "periodic_level_BRF_model.pkl",
"stochastic_rf": "stochastic_level_BRF_model.pkl",
"transient_rf": "transient_level_BRF_model.pkl",
"feature_list_dict": "features_HRF_model.pkl",
"top_rf": "top_level_HRF_model.pkl",
"periodic_rf": "periodic_level_HRF_model.pkl",
"stochastic_rf": "stochastic_level_HRF_model.pkl",
"transient_rf": "transient_level_HRF_model.pkl",
}

def check_missing_features(self, available_features):
available_features = set(available_features)
required_features = set(np.concatenate(
list(self.feature_list_dict.values())))

missing_features = required_features - available_features
if len(missing_features) != 0:
raise ValueError(f'missing required features: {missing_features}')

def fit(self, samples: pd.DataFrame, labels: pd.DataFrame) -> None:
labels = labels.copy()

# Check that the received labels are in the taxonomy
feeded_labels = labels.classALeRCE.unique()
expected_labels = self.inverted_dictionary.keys()

for label in feeded_labels:
if label not in expected_labels:
raise Exception(f"{label} is not in the taxonomy dictionary")

# Create top class
labels["top_class"] = labels["classALeRCE"].map(self.inverted_dictionary)

# Preprocessing
samples = self.feature_preprocessor.preprocess_features(samples)
samples = self.feature_preprocessor.remove_duplicates(samples)
samples, labels = intersect_oids_in_dataframes(samples, labels)

self.check_missing_features(samples.columns)

# Train top classifier
if self.verbose:
print("training top classifier")

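        # under-sample object ids so every leaf class (classALeRCE) is
        # equally represented when fitting the top-level classifier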
rus = RandomUnderSampler()
selected_top_snids, _ = rus.fit_resample(
samples.index.values.reshape(-1, 1),
labels['classALeRCE'])
selected_top_snids = selected_top_snids.flatten()

self.top_classifier.fit(
samples.loc[selected_top_snids][self.feature_list_dict['top']].values,
labels.loc[selected_top_snids]["top_class"].values)

# Train specialized classifiers
if self.verbose:
print("training stochastic classifier")
is_stochastic = labels["top_class"] == "Stochastic"
self.stochastic_classifier.fit(
samples[is_stochastic][self.feature_list_dict['stochastic']].values,
labels[is_stochastic]["classALeRCE"].values)

if self.verbose:
print("training periodic classifier")
is_periodic = labels["top_class"] == "Periodic"
self.periodic_classifier.fit(
samples[is_periodic][self.feature_list_dict['periodic']].values,
labels[is_periodic]["classALeRCE"].values)

if self.verbose:
print("training transient classifier")
is_transient = labels["top_class"] == "Transient"
self.transient_classifier.fit(
samples[is_transient][self.feature_list_dict['transient']].values,
labels[is_transient]["classALeRCE"].values)

def save_model(self, directory: str) -> None:
Path(directory).mkdir(parents=True, exist_ok=True)
with open(os.path.join(directory, self.pickles["top_rf"]), "wb") as f:
pickle.dump(self.top_classifier, f, pickle.HIGHEST_PROTOCOL)

with open(os.path.join(directory, self.pickles["stochastic_rf"]), "wb") as f:
pickle.dump(self.stochastic_classifier, f, pickle.HIGHEST_PROTOCOL)

with open(os.path.join(directory, self.pickles["periodic_rf"]), "wb") as f:
pickle.dump(self.periodic_classifier, f, pickle.HIGHEST_PROTOCOL)

with open(os.path.join(directory, self.pickles["transient_rf"]), "wb") as f:
pickle.dump(self.transient_classifier, f, pickle.HIGHEST_PROTOCOL)

with open(os.path.join(directory, self.pickles["feature_list_dict"]), "wb") as f:
pickle.dump(self.feature_list_dict, f, pickle.HIGHEST_PROTOCOL)

def load_model(self, directory: str) -> None:
with open(os.path.join(directory, self.pickles["top_rf"]), 'rb') as f:
self.top_classifier = pickle.load(f)

with open(os.path.join(directory, self.pickles["stochastic_rf"]), 'rb') as f:
self.stochastic_classifier = pickle.load(f)

with open(os.path.join(directory, self.pickles["periodic_rf"]), 'rb') as f:
self.periodic_classifier = pickle.load(f)

with open(os.path.join(directory, self.pickles["transient_rf"]), 'rb') as f:
self.transient_classifier = pickle.load(f)

with open(os.path.join(directory, self.pickles["feature_list_dict"]), 'rb') as f:
self.feature_list_dict = pickle.load(f)

self.check_loaded_models()

def check_loaded_models(self):
assert set(self.top_classifier.classes_.tolist()) \
== set(self.taxonomy_dictionary.keys())

assert set(self.stochastic_classifier.classes_.tolist()) \
== set(self.taxonomy_dictionary["Stochastic"])

assert set(self.transient_classifier.classes_.tolist()) \
== set(self.taxonomy_dictionary["Transient"])

assert set(self.periodic_classifier.classes_.tolist()) \
== set(self.taxonomy_dictionary["Periodic"])

def predict_proba(self, samples: pd.DataFrame) -> pd.DataFrame:
self.check_missing_features(samples.columns)

samples = self.feature_preprocessor.preprocess_features(samples)

top_probs = self.top_classifier.predict_proba(
samples[self.feature_list_dict['top']].values)

stochastic_probs = self.stochastic_classifier.predict_proba(
samples[self.feature_list_dict['stochastic']].values)
periodic_probs = self.periodic_classifier.predict_proba(
samples[self.feature_list_dict['periodic']].values)
transient_probs = self.transient_classifier.predict_proba(
samples[self.feature_list_dict['transient']].values)

stochastic_index = self.top_classifier.classes_.tolist().index("Stochastic")
periodic_index = self.top_classifier.classes_.tolist().index("Periodic")
transient_index = self.top_classifier.classes_.tolist().index("Transient")

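        # chain rule: multiply each child-class probability by the
        # probability of its parent branch from the top-level classifier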
stochastic_probs = stochastic_probs * top_probs[:, stochastic_index].reshape(
[-1, 1]
)
periodic_probs = periodic_probs * top_probs[:, periodic_index].reshape([-1, 1])
transient_probs = transient_probs * top_probs[:, transient_index].reshape(
[-1, 1]
)

# This line must have the same order as in get_list_of_classes()
final_probs = np.concatenate(
[stochastic_probs, periodic_probs, transient_probs],
axis=1
)

df = pd.DataFrame(
data=final_probs,
index=samples.index,
columns=self.get_list_of_classes()
)

df.index.name = samples.index.name
return df
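
For reference, a minimal usage sketch of the new ElasticcRandomForest interface; the taxonomy, feature names, and random training data below are hypothetical placeholders, not the real ELAsTiCC configuration:

    import numpy as np
    import pandas as pd
    from lc_classifier.classifier.models import ElasticcRandomForest

    # hypothetical two-level taxonomy and per-level feature lists
    taxonomy = {
        'Stochastic': ['AGN', 'QSO'],
        'Periodic': ['RRL', 'EB'],
        'Transient': ['SNIa', 'SNII'],
    }
    feature_list_dict = {
        'top': ['mwebv', 'redshift_helio'],
        'stochastic': ['mwebv'],
        'periodic': ['redshift_helio'],
        'transient': ['mwebv', 'redshift_helio'],
    }

    # hypothetical training set: one feature row and one label per oid
    rng = np.random.default_rng(0)
    oids = pd.Index([f'oid{i}' for i in range(60)], name='oid')
    features_df = pd.DataFrame(
        rng.normal(size=(60, 2)),
        columns=['mwebv', 'redshift_helio'], index=oids)
    labels_df = pd.DataFrame(
        {'classALeRCE': rng.choice(
            ['AGN', 'QSO', 'RRL', 'EB', 'SNIa', 'SNII'], size=60)},
        index=oids)

    model = ElasticcRandomForest(taxonomy, feature_list_dict,
                                 n_trees=100, n_jobs=1)
    model.fit(features_df, labels_df)
    probs = model.predict_proba(features_df)  # one column per leaf class
    model.save_model('elasticc_rf_sketch')    # hypothetical output directory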
5 changes: 3 additions & 2 deletions lc_classifier/features/__init__.py
@@ -19,7 +19,7 @@
from .extractors.gp_drw_extractor import GPDRWExtractor
from .extractors.sn_features_phase_ii import SNFeaturesPhaseIIExtractor
from .extractors.sn_parametric_model_computer import SPMExtractorPhaseII

from .extractors.elasticc_metadata_extractor import ElasticcMetadataExtractor

from .custom.ztf_feature_extractor import ZTFFeatureExtractor, ZTFForcedPhotometryFeatureExtractor
from .custom.elasticc_feature_extractor import ElasticcFeatureExtractor
@@ -58,5 +58,6 @@
'FeatureExtractorComposer',
'SNFeaturesPhaseIIExtractor',
'SPMExtractorPhaseII',
-    'ElasticcFeatureExtractor'
+    'ElasticcFeatureExtractor',
+    'ElasticcMetadataExtractor'
]
7 changes: 3 additions & 4 deletions lc_classifier/features/custom/elasticc_feature_extractor.py
@@ -11,6 +11,7 @@
from lc_classifier.features import GPDRWExtractor
from lc_classifier.features import SNFeaturesPhaseIIExtractor
from lc_classifier.features.extractors.sn_parametric_model_computer import SPMExtractorElasticc
from lc_classifier.features import ElasticcMetadataExtractor

from ..core.base import FeatureExtractor
from ..core.base import FeatureExtractorComposer
@@ -23,9 +24,6 @@ class ElasticcFeatureExtractor(FeatureExtractor):
def __init__(self):
self.bands = ['u', 'g', 'r', 'i', 'z', 'Y']

-        # input: metadata
-        # self.gal_extractor = GalacticCoordinatesExtractor(from_metadata=True)

magnitude_extractors = [
# input: apparent magnitude
IQRExtractor(self.bands),
Expand All @@ -39,6 +37,7 @@ def __init__(self):

flux_extractors = [
# input: difference flux
ElasticcMetadataExtractor(),
ElasticcColorFeatureExtractor(self.bands),
MHPSFluxExtractor(self.bands),
SNFeaturesPhaseIIExtractor(self.bands),
@@ -95,7 +94,7 @@ def _compute_features(self, detections, **kwargs):
-------
"""
-        required = [] # ['metadata']
+        required = ['metadata']
for key in required:
if key not in kwargs:
raise Exception(f"{key} argument is missing")
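
The extractor now hard-requires the metadata keyword argument; a small sketch of the failure mode the new `required = ['metadata']` check produces, calling the private hook directly only for illustration and assuming the check is the method's first executable statement:

    import pandas as pd
    from lc_classifier.features import ElasticcFeatureExtractor

    extractor = ElasticcFeatureExtractor()
    try:
        extractor._compute_features(pd.DataFrame())  # no metadata kwarg
    except Exception as error:
        print(error)  # -> metadata argument is missing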
40 changes: 40 additions & 0 deletions lc_classifier/features/extractors/elasticc_metadata_extractor.py
@@ -0,0 +1,40 @@
import pandas as pd
from typing import Tuple
from functools import lru_cache
from ..core.base import FeatureExtractor


class ElasticcMetadataExtractor(FeatureExtractor):
@lru_cache(1)
def get_features_keys(self) -> Tuple[str, ...]:
return 'redshift_helio', 'mwebv'

@lru_cache(1)
def get_required_keys(self) -> Tuple[str, ...]:
# not very useful
return tuple()

def _compute_features(self, detections, **kwargs):
return self._compute_features_from_df_groupby(
detections.groupby(level=0),
**kwargs)

def _compute_features_from_df_groupby(
self, detections, **kwargs) -> pd.DataFrame:
columns = self.get_features_keys()
metadata = kwargs['metadata']

def aux_function(oid_detections, **kwargs):
oid = oid_detections.index.values[0]
metadata_lightcurve = metadata.loc[oid]
redshift_helio = metadata_lightcurve['REDSHIFT_HELIO']
mwebv = metadata_lightcurve['MWEBV']

out = pd.Series(
data=[redshift_helio, mwebv],
index=columns)
return out

features = detections.apply(aux_function)
features.index.name = 'oid'
return features
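
And a self-contained sketch of what the new metadata extractor computes; the oids and values are hypothetical, and the real ELAsTiCC metadata table carries many more columns than the two used here:

    import pandas as pd
    from lc_classifier.features import ElasticcMetadataExtractor

    # two detections for oid1, one for oid2, indexed by oid
    detections = pd.DataFrame(
        {'mag': [19.1, 19.3, 20.0]},
        index=pd.Index(['oid1', 'oid1', 'oid2'], name='oid'))

    # per-object metadata, also indexed by oid
    metadata = pd.DataFrame(
        {'REDSHIFT_HELIO': [0.12, 0.45], 'MWEBV': [0.03, 0.10]},
        index=pd.Index(['oid1', 'oid2'], name='oid'))

    features = ElasticcMetadataExtractor()._compute_features(
        detections, metadata=metadata)
    print(features)
    #       redshift_helio  mwebv
    # oid
    # oid1            0.12   0.03
    # oid2            0.45   0.10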
