diff --git a/feature_step/features/step.py b/feature_step/features/step.py index bfc73edc7..29345f785 100644 --- a/feature_step/features/step.py +++ b/feature_step/features/step.py @@ -43,7 +43,7 @@ def __init__( ): super().__init__(config=config, **step_args) - self.lightcurve_preprocessor = ZTFLightcurvePreprocessor() + self.lightcurve_preprocessor = ZTFLightcurvePreprocessor(drop_bogus=True) self.feature_extractor = ZTFFeatureExtractor() scribe_class = get_class(self.config["SCRIBE_PRODUCER_CONFIG"]["CLASS"]) diff --git a/feature_step/features/utils/parsers.py b/feature_step/features/utils/parsers.py index 824ab0840..f16ad3c4e 100644 --- a/feature_step/features/utils/parsers.py +++ b/feature_step/features/utils/parsers.py @@ -6,6 +6,29 @@ from typing import List, Dict, Optional +def get_bogus_flags_for_each_detection(detections: List[Dict]): + # for each detection, it looks for the real-bogus score (available only for + # detections) and procstatus flag (available only for forced + # photometry epochs) + + keys = ["rb", "procstatus"] + + bogus_flags = [] + for detection in detections: + value = [] + for key in keys: + if key in detection["extra_fields"].keys(): + value.append(detection["extra_fields"][key]) + else: + value.append(None) + bogus_flags.append(value) + + bogus_flags = pd.DataFrame(bogus_flags, columns=keys) + bogus_flags["procstatus"] = bogus_flags["procstatus"].astype(str) + + return bogus_flags + + def get_reference_for_each_detection(detections: List[Dict]): # for each detection, it looks what is the reference id # and how far away it is @@ -84,6 +107,11 @@ def detections_to_astro_object( ) a = pd.concat([a, reference_for_each_detection], axis=1) + bogus_flags_for_each_detection: pd.DataFrame = get_bogus_flags_for_each_detection( + detections + ) + a = pd.concat([a, bogus_flags_for_each_detection], axis=1) + a = a[(a["mag"] != 100) | (a["e_mag"] != 100)].copy() a.rename( columns={"mag_corr": "brightness", "e_mag_corr_ext": "e_brightness"}, diff --git a/feature_step/tests/unittest/test_step.py b/feature_step/tests/unittest/test_step.py index eec20dc7a..8c402b7dc 100644 --- a/feature_step/tests/unittest/test_step.py +++ b/feature_step/tests/unittest/test_step.py @@ -119,6 +119,17 @@ def test_period_consistency(self): assert multiband_period_scribe == multiband_period_message + def test_drop_bogus(self): + messages = [ + spm_messages[2] + ] # after drop bogus detections, it ends with 0 detections + result_messages = self.step.execute(messages) + + nan_features = [ + (x, y) for (x, y) in result_messages[0]["features"].items() if np.isnan(y) + ] + self.assertEqual(len(nan_features), len(result_messages[0]["features"])) + @mock.patch("features.step.FeatureStep._get_sql_references") def test_read_empty_reference_from_db(self, _get_sql_ref): _get_sql_ref.return_value = [] @@ -134,10 +145,10 @@ def test_read_empty_reference_from_db(self, _get_sql_ref): def test_references_are_in_messages_only(self, _get_sql_ref): _get_sql_ref.return_value = None - messages = [spm_messages[2]] + messages = [spm_messages[3]] result_messages = self.step.execute(messages) - self.assertEqual(2, len(result_messages[0]["reference"])) + self.assertEqual(4, len(result_messages[0]["reference"])) @mock.patch("features.step.FeatureStep._get_sql_references") def test_references_some_are_in_db(self, _get_sql_ref): diff --git a/lc_classification_step/lc_classification/core/parsers/input_dto.py b/lc_classification_step/lc_classification/core/parsers/input_dto.py index f4eddaa1f..4ffc525b7 100644 --- a/lc_classification_step/lc_classification/core/parsers/input_dto.py +++ b/lc_classification_step/lc_classification/core/parsers/input_dto.py @@ -52,6 +52,10 @@ def create_detections_dto(messages: List[dict]) -> pd.DataFrame: pd.DataFrame.from_records(msg["detections"]) for msg in messages ] detections = pd.concat(detections) + + if len(detections) == 0: + return pd.DataFrame() + detections.drop_duplicates(["candid", "oid"], inplace=True) detections = detections.set_index("oid") detections["extra_fields"] = parse_extra_fields(detections) diff --git a/lc_classification_step/models_settings.py b/lc_classification_step/models_settings.py index c6720acda..92e96bd65 100644 --- a/lc_classification_step/models_settings.py +++ b/lc_classification_step/models_settings.py @@ -72,8 +72,8 @@ def mbappe_params(model_class: str): return { "PARAMS": { "model_path": os.getenv("MODEL_PATH"), - "features_quantiles_path": os.getenv("FEATURE_QUANTILES_PATH"), - "metadata_quantiles_path": os.getenv("METADATA_QUANTILES_PATH"), + "quantiles_dir": os.getenv("QUANTILES_PATH"), + "config_dir": os.getenv("CONFIG_PATH"), "mapper": os.getenv("MAPPER_CLASS"), }, "CLASS": model_class, diff --git a/lc_classification_step/tests/integration/conftest.py b/lc_classification_step/tests/integration/conftest.py index ee80d65dc..98b46b875 100644 --- a/lc_classification_step/tests/integration/conftest.py +++ b/lc_classification_step/tests/integration/conftest.py @@ -375,6 +375,44 @@ def func( return func +def add_fields_to_message( + message, + topic, + bands, + force_empty_features, + force_missing_features, + n_forced: int, +): + message = message.copy() + + for i, det in enumerate(message["detections"]): + det["oid"] = message["oid"] + det["candid"] = str(random.randint(0, 100000)) + det["extra_fields"] = generate_extra_fields() + det["fid"] = random.choice(bands) + if i < n_forced: + det["forced"] = True + det["extra_fields"]["procstatus"] = random.choice( + ["0", "55", "56", "57", "61"] + ) + else: + det["forced"] = False + det["extra_fields"]["rb"] = random.uniform(0, 1) + + message["detections"][0]["new"] = True + message["detections"][0]["has_stamp"] = True + if topic == "features_ztf": + message["features"] = features_ztf( + force_empty_features, force_missing_features + ) + elif topic == "features_elasticc": + message["features"] = features_elasticc( + force_empty_features, force_missing_features + ) + + return message + + def _produce_messages( topic, SCHEMA, @@ -396,26 +434,14 @@ def _produce_messages( producer.set_key_field("oid") for message in messages: - for i, det in enumerate(message["detections"]): - det["oid"] = message["oid"] - det["candid"] = str(random.randint(0, 100000)) - det["extra_fields"] = generate_extra_fields() - det["fid"] = random.choice(BANDS) - if i < n_forced: - det["forced"] = True - else: - det["forced"] = False - - message["detections"][0]["new"] = True - message["detections"][0]["has_stamp"] = True - if topic == "features_ztf": - message["features"] = features_ztf( - force_empty_features, force_missing_features - ) - elif topic == "features_elasticc": - message["features"] = features_elasticc( - force_empty_features, force_missing_features - ) + message = add_fields_to_message( + message, + topic, + BANDS, + force_empty_features, + force_missing_features, + n_forced, + ) producer.produce(message) diff --git a/lc_classification_step/tests/integration/test_step_mbappe.py b/lc_classification_step/tests/integration/test_step_mbappe.py index 8a9d4349c..a18d51d0d 100644 --- a/lc_classification_step/tests/integration/test_step_mbappe.py +++ b/lc_classification_step/tests/integration/test_step_mbappe.py @@ -27,12 +27,8 @@ def test_step_mbappe_result( "alerce_classifiers.mbappe.model.MbappeClassifier", { "MODEL_PATH": os.getenv("TEST_MBAPPE_MODEL_PATH"), - "FEATURE_QUANTILES_PATH": os.getenv( - "TEST_MBAPPE_FEATURES_QUANTILES_PATH" - ), - "METADATA_QUANTILES_PATH": os.getenv( - "TEST_MBAPPE_METADATA_QUANTILES_PATH" - ), + "QUANTILES_PATH": os.getenv("TEST_MBAPPE_QUANTILES_PATH"), + "CONFIG_PATH": os.getenv("TEST_MBAPPE_CONFIG_PATH"), "MAPPER_CLASS": "alerce_classifiers.mbappe.mapper.MbappeMapper", }, ) @@ -69,12 +65,8 @@ def test_step_mbappe_no_features_result( "alerce_classifiers.mbappe.model.MbappeClassifier", { "MODEL_PATH": os.getenv("TEST_MBAPPE_MODEL_PATH"), - "FEATURE_QUANTILES_PATH": os.getenv( - "TEST_MBAPPE_FEATURES_QUANTILES_PATH" - ), - "METADATA_QUANTILES_PATH": os.getenv( - "TEST_MBAPPE_METADATA_QUANTILES_PATH" - ), + "QUANTILES_PATH": os.getenv("TEST_MBAPPE_QUANTILES_PATH"), + "CONFIG_PATH": os.getenv("TEST_MBAPPE_CONFIG_PATH"), "MAPPER_CLASS": "alerce_classifiers.mbappe.mapper.MbappeMapper", }, ) @@ -111,12 +103,8 @@ def test_step_mbappe_min_detections( "alerce_classifiers.mbappe.model.MbappeClassifier", { "MODEL_PATH": os.getenv("TEST_MBAPPE_MODEL_PATH"), - "FEATURE_QUANTILES_PATH": os.getenv( - "TEST_MBAPPE_FEATURES_QUANTILES_PATH" - ), - "METADATA_QUANTILES_PATH": os.getenv( - "TEST_MBAPPE_METADATA_QUANTILES_PATH" - ), + "QUANTILES_PATH": os.getenv("TEST_MBAPPE_QUANTILES_PATH"), + "CONFIG_PATH": os.getenv("TEST_MBAPPE_CONFIG_PATH"), "MAPPER_CLASS": "alerce_classifiers.mbappe.mapper.MbappeMapper", "MIN_DETECTIONS": "6", }, @@ -138,3 +126,51 @@ def test_step_mbappe_min_detections( for message in kconsumer.consume(): assert_ztf_object_is_correct(message) kconsumer.commit() + + +@pytest.mark.ztf +def test_step_mbappe_no_detections( + env_variables_mbappe, +): + env_variables_mbappe( + "mbappe", + "alerce_classifiers.mbappe.model.MbappeClassifier", + { + "MODEL_PATH": os.getenv("TEST_MBAPPE_MODEL_PATH"), + "QUANTILES_PATH": os.getenv("TEST_MBAPPE_QUANTILES_PATH"), + "CONFIG_PATH": os.getenv("TEST_MBAPPE_CONFIG_PATH"), + "MAPPER_CLASS": "alerce_classifiers.mbappe.mapper.MbappeMapper", + }, + ) + + from settings import config + from .conftest import INPUT_SCHEMA_PATH, add_fields_to_message + from fastavro.utils import generate_many + from fastavro.schema import load_schema + import random + + random.seed(42) + schema = load_schema(str(INPUT_SCHEMA_PATH)) + messages = generate_many(schema, 2) + messages = list(messages) + + for message in messages: + message = add_fields_to_message( + message, + "features_mbappe", + ["g", "r"], + False, + False, + 5, + ) + for det in message["detections"]: + if not det["forced"]: + det["extra_fields"]["rb"] = 0.1 + + step = LateClassifier(config=config()) + + output, result_messages, features = step.execute(messages) + probabilities = output.probabilities + + assert len(probabilities) == 0 + assert len(features) == 0 diff --git a/lc_classification_step/tests/unit/test_step_anomaly.py b/lc_classification_step/tests/unit/legacy_test_step_anomaly.py similarity index 100% rename from lc_classification_step/tests/unit/test_step_anomaly.py rename to lc_classification_step/tests/unit/legacy_test_step_anomaly.py diff --git a/lc_classifier/lc_classifier/features/core/base.py b/lc_classifier/lc_classifier/features/core/base.py index c5b7b975a..cf6284258 100644 --- a/lc_classifier/lc_classifier/features/core/base.py +++ b/lc_classifier/lc_classifier/features/core/base.py @@ -104,6 +104,45 @@ def preprocess_batch(self, astro_objects: List[AstroObject], progress_bar=False) self.preprocess_single_object(astro_object) +def discard_bogus_detections(detections: List[Dict]) -> list[dict]: + RB_THRESHOLD = 0.55 + + filtered_detections = [] + + for det in detections: + bogus = False + + if "extra_fields" in det.keys(): + rb = ( + det["extra_fields"]["rb"] + if "rb" in det["extra_fields"].keys() + else None + ) + procstatus = ( + det["extra_fields"]["procstatus"] + if "procstatus" in det["extra_fields"].keys() + else None + ) + else: + rb = det["rb"] if "rb" in det.keys() else None + procstatus = det["procstatus"] if "procstatus" in det.keys() else None + + mask_rb = rb is not None and not det["forced"] and (rb < RB_THRESHOLD) + mask_procstatus = ( + procstatus is not None + and det["forced"] + and (procstatus != "0") + and (procstatus != "57") + ) + if mask_rb or mask_procstatus: + bogus = True + + if not bogus: + filtered_detections.append(det) + + return filtered_detections + + def empty_normal_dataframe() -> pd.DataFrame: df = pd.DataFrame(columns=["name", "value", "fid", "sid", "version"]) return df diff --git a/lc_classifier/lc_classifier/features/extractors/panstarrs_feature_extractor.py b/lc_classifier/lc_classifier/features/extractors/panstarrs_feature_extractor.py index 565622338..eaf0e404a 100644 --- a/lc_classifier/lc_classifier/features/extractors/panstarrs_feature_extractor.py +++ b/lc_classifier/lc_classifier/features/extractors/panstarrs_feature_extractor.py @@ -34,7 +34,7 @@ def compute_features_single_object(self, astro_object: AstroObject): sg_score = metadata[metadata["name"] == "sgscore1"]["value"].values[0] dist_nr = metadata[metadata["name"] == "distpsnr1"]["value"].values[0] - if sg_score < 0 or dist_nr < 0: + if sg_score < 0 or dist_nr < 0 or len(astro_object.detections) == 0: sg_score = np.nan dist_nr = np.nan diff --git a/lc_classifier/lc_classifier/features/extractors/sn_extractor.py b/lc_classifier/lc_classifier/features/extractors/sn_extractor.py index dcf0b0943..610b1728a 100644 --- a/lc_classifier/lc_classifier/features/extractors/sn_extractor.py +++ b/lc_classifier/lc_classifier/features/extractors/sn_extractor.py @@ -22,8 +22,12 @@ def compute_features_single_object(self, astro_object: AstroObject): detections = detections[detections["unit"] == self.unit] detections = detections.sort_values("mjd") forced_photometry = astro_object.forced_photometry - first_detection = detections.iloc[0] - first_detection_mjd = first_detection["mjd"] + + if len(detections) > 0: + first_detection = detections.iloc[0] + first_detection_mjd = first_detection["mjd"] + else: + first_detection_mjd = np.nan features = [] for band in self.bands: diff --git a/lc_classifier/lc_classifier/features/extractors/spm_extractor.py b/lc_classifier/lc_classifier/features/extractors/spm_extractor.py index a03cc66e9..485d41e71 100644 --- a/lc_classifier/lc_classifier/features/extractors/spm_extractor.py +++ b/lc_classifier/lc_classifier/features/extractors/spm_extractor.py @@ -126,6 +126,17 @@ def compute_features_single_object(self, astro_object: AstroObject): observations = self.get_observations(astro_object) + if len(observations) == 0: + parameters = [] + chis = [] + for band in self.bands: + parameters.append([np.nan] * 6) + chis.append(np.nan) + + self.parameters = np.concatenate(parameters, axis=0) + self.chis = np.array(chis) + return + times = observations["mjd"].values flux = observations["brightness"].values e_flux = observations["e_brightness"].values diff --git a/lc_classifier/lc_classifier/features/preprocess/ztf.py b/lc_classifier/lc_classifier/features/preprocess/ztf.py index 7cc3a86ad..75d2809b5 100644 --- a/lc_classifier/lc_classifier/features/preprocess/ztf.py +++ b/lc_classifier/lc_classifier/features/preprocess/ztf.py @@ -7,15 +7,22 @@ from astropy import units as u from ..core.base import LightcurvePreprocessor, AstroObject +from ..core.base import discard_bogus_detections class ZTFLightcurvePreprocessor(LightcurvePreprocessor): + def __init__(self, drop_bogus: bool = False): + self.drop_bogus = drop_bogus + def preprocess_single_object(self, astro_object: AstroObject): self._helio_time_correction(astro_object) self.drop_absurd_detections(astro_object) # TODO: does error need a np.maximum(error, 0.01) ? # the factor depends on the units + if self.drop_bogus: + self.drop_bogus_detections(astro_object) + def _helio_time_correction(self, astro_object: AstroObject): detections = astro_object.detections ra_deg, dec_deg = detections[["ra", "dec"]].mean().values @@ -69,6 +76,26 @@ def drop_absurd(table): astro_object.detections = drop_absurd(astro_object.detections) astro_object.forced_photometry = drop_absurd(astro_object.forced_photometry) + def drop_bogus_detections(self, astro_object: AstroObject): + def drop_bogus_dets(table): + keys = table.keys() + table = table.to_dict("records") + table = discard_bogus_detections(table) + table = pd.DataFrame.from_records(table) + if len(table) == 0: + table = pd.DataFrame(columns=keys) + return table + + astro_object.detections = drop_bogus_dets(astro_object.detections) + if len(astro_object.detections) == 0: + astro_object.forced_photometry = pd.DataFrame( + columns=astro_object.forced_photometry.keys() + ) + else: + astro_object.forced_photometry = drop_bogus_dets( + astro_object.forced_photometry + ) + class ShortenPreprocessor(LightcurvePreprocessor): def __init__(self, n_days: float):