Skip to content

Commit

Permalink
Merge pull request #519 from alercebroker/add_rb_procstatus_cuts
Browse files Browse the repository at this point in the history
Added rb and procstatus cuts in pipeline
  • Loading branch information
ale-munozarancibia authored Jan 10, 2025
2 parents 43cb5a4 + 3919061 commit 0fe95b2
Show file tree
Hide file tree
Showing 13 changed files with 232 additions and 46 deletions.
2 changes: 1 addition & 1 deletion feature_step/features/step.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def __init__(
):

super().__init__(config=config, **step_args)
self.lightcurve_preprocessor = ZTFLightcurvePreprocessor()
self.lightcurve_preprocessor = ZTFLightcurvePreprocessor(drop_bogus=True)
self.feature_extractor = ZTFFeatureExtractor()

scribe_class = get_class(self.config["SCRIBE_PRODUCER_CONFIG"]["CLASS"])
Expand Down
28 changes: 28 additions & 0 deletions feature_step/features/utils/parsers.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,29 @@
from typing import List, Dict, Optional


def get_bogus_flags_for_each_detection(detections: List[Dict]):
# for each detection, it looks for the real-bogus score (available only for
# detections) and procstatus flag (available only for forced
# photometry epochs)

keys = ["rb", "procstatus"]

bogus_flags = []
for detection in detections:
value = []
for key in keys:
if key in detection["extra_fields"].keys():
value.append(detection["extra_fields"][key])
else:
value.append(None)
bogus_flags.append(value)

bogus_flags = pd.DataFrame(bogus_flags, columns=keys)
bogus_flags["procstatus"] = bogus_flags["procstatus"].astype(str)

return bogus_flags


def get_reference_for_each_detection(detections: List[Dict]):
# for each detection, it looks what is the reference id
# and how far away it is
Expand Down Expand Up @@ -84,6 +107,11 @@ def detections_to_astro_object(
)
a = pd.concat([a, reference_for_each_detection], axis=1)

bogus_flags_for_each_detection: pd.DataFrame = get_bogus_flags_for_each_detection(
detections
)
a = pd.concat([a, bogus_flags_for_each_detection], axis=1)

a = a[(a["mag"] != 100) | (a["e_mag"] != 100)].copy()
a.rename(
columns={"mag_corr": "brightness", "e_mag_corr_ext": "e_brightness"},
Expand Down
15 changes: 13 additions & 2 deletions feature_step/tests/unittest/test_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,17 @@ def test_period_consistency(self):

assert multiband_period_scribe == multiband_period_message

def test_drop_bogus(self):
messages = [
spm_messages[2]
] # after drop bogus detections, it ends with 0 detections
result_messages = self.step.execute(messages)

nan_features = [
(x, y) for (x, y) in result_messages[0]["features"].items() if np.isnan(y)
]
self.assertEqual(len(nan_features), len(result_messages[0]["features"]))

@mock.patch("features.step.FeatureStep._get_sql_references")
def test_read_empty_reference_from_db(self, _get_sql_ref):
_get_sql_ref.return_value = []
Expand All @@ -134,10 +145,10 @@ def test_read_empty_reference_from_db(self, _get_sql_ref):
def test_references_are_in_messages_only(self, _get_sql_ref):
_get_sql_ref.return_value = None

messages = [spm_messages[2]]
messages = [spm_messages[3]]
result_messages = self.step.execute(messages)

self.assertEqual(2, len(result_messages[0]["reference"]))
self.assertEqual(4, len(result_messages[0]["reference"]))

@mock.patch("features.step.FeatureStep._get_sql_references")
def test_references_some_are_in_db(self, _get_sql_ref):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,10 @@ def create_detections_dto(messages: List[dict]) -> pd.DataFrame:
pd.DataFrame.from_records(msg["detections"]) for msg in messages
]
detections = pd.concat(detections)

if len(detections) == 0:
return pd.DataFrame()

detections.drop_duplicates(["candid", "oid"], inplace=True)
detections = detections.set_index("oid")
detections["extra_fields"] = parse_extra_fields(detections)
Expand Down
4 changes: 2 additions & 2 deletions lc_classification_step/models_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,8 @@ def mbappe_params(model_class: str):
return {
"PARAMS": {
"model_path": os.getenv("MODEL_PATH"),
"features_quantiles_path": os.getenv("FEATURE_QUANTILES_PATH"),
"metadata_quantiles_path": os.getenv("METADATA_QUANTILES_PATH"),
"quantiles_dir": os.getenv("QUANTILES_PATH"),
"config_dir": os.getenv("CONFIG_PATH"),
"mapper": os.getenv("MAPPER_CLASS"),
},
"CLASS": model_class,
Expand Down
66 changes: 46 additions & 20 deletions lc_classification_step/tests/integration/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -375,6 +375,44 @@ def func(
return func


def add_fields_to_message(
message,
topic,
bands,
force_empty_features,
force_missing_features,
n_forced: int,
):
message = message.copy()

for i, det in enumerate(message["detections"]):
det["oid"] = message["oid"]
det["candid"] = str(random.randint(0, 100000))
det["extra_fields"] = generate_extra_fields()
det["fid"] = random.choice(bands)
if i < n_forced:
det["forced"] = True
det["extra_fields"]["procstatus"] = random.choice(
["0", "55", "56", "57", "61"]
)
else:
det["forced"] = False
det["extra_fields"]["rb"] = random.uniform(0, 1)

message["detections"][0]["new"] = True
message["detections"][0]["has_stamp"] = True
if topic == "features_ztf":
message["features"] = features_ztf(
force_empty_features, force_missing_features
)
elif topic == "features_elasticc":
message["features"] = features_elasticc(
force_empty_features, force_missing_features
)

return message


def _produce_messages(
topic,
SCHEMA,
Expand All @@ -396,26 +434,14 @@ def _produce_messages(
producer.set_key_field("oid")

for message in messages:
for i, det in enumerate(message["detections"]):
det["oid"] = message["oid"]
det["candid"] = str(random.randint(0, 100000))
det["extra_fields"] = generate_extra_fields()
det["fid"] = random.choice(BANDS)
if i < n_forced:
det["forced"] = True
else:
det["forced"] = False

message["detections"][0]["new"] = True
message["detections"][0]["has_stamp"] = True
if topic == "features_ztf":
message["features"] = features_ztf(
force_empty_features, force_missing_features
)
elif topic == "features_elasticc":
message["features"] = features_elasticc(
force_empty_features, force_missing_features
)
message = add_fields_to_message(
message,
topic,
BANDS,
force_empty_features,
force_missing_features,
n_forced,
)
producer.produce(message)


Expand Down
72 changes: 54 additions & 18 deletions lc_classification_step/tests/integration/test_step_mbappe.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,8 @@ def test_step_mbappe_result(
"alerce_classifiers.mbappe.model.MbappeClassifier",
{
"MODEL_PATH": os.getenv("TEST_MBAPPE_MODEL_PATH"),
"FEATURE_QUANTILES_PATH": os.getenv(
"TEST_MBAPPE_FEATURES_QUANTILES_PATH"
),
"METADATA_QUANTILES_PATH": os.getenv(
"TEST_MBAPPE_METADATA_QUANTILES_PATH"
),
"QUANTILES_PATH": os.getenv("TEST_MBAPPE_QUANTILES_PATH"),
"CONFIG_PATH": os.getenv("TEST_MBAPPE_CONFIG_PATH"),
"MAPPER_CLASS": "alerce_classifiers.mbappe.mapper.MbappeMapper",
},
)
Expand Down Expand Up @@ -69,12 +65,8 @@ def test_step_mbappe_no_features_result(
"alerce_classifiers.mbappe.model.MbappeClassifier",
{
"MODEL_PATH": os.getenv("TEST_MBAPPE_MODEL_PATH"),
"FEATURE_QUANTILES_PATH": os.getenv(
"TEST_MBAPPE_FEATURES_QUANTILES_PATH"
),
"METADATA_QUANTILES_PATH": os.getenv(
"TEST_MBAPPE_METADATA_QUANTILES_PATH"
),
"QUANTILES_PATH": os.getenv("TEST_MBAPPE_QUANTILES_PATH"),
"CONFIG_PATH": os.getenv("TEST_MBAPPE_CONFIG_PATH"),
"MAPPER_CLASS": "alerce_classifiers.mbappe.mapper.MbappeMapper",
},
)
Expand Down Expand Up @@ -111,12 +103,8 @@ def test_step_mbappe_min_detections(
"alerce_classifiers.mbappe.model.MbappeClassifier",
{
"MODEL_PATH": os.getenv("TEST_MBAPPE_MODEL_PATH"),
"FEATURE_QUANTILES_PATH": os.getenv(
"TEST_MBAPPE_FEATURES_QUANTILES_PATH"
),
"METADATA_QUANTILES_PATH": os.getenv(
"TEST_MBAPPE_METADATA_QUANTILES_PATH"
),
"QUANTILES_PATH": os.getenv("TEST_MBAPPE_QUANTILES_PATH"),
"CONFIG_PATH": os.getenv("TEST_MBAPPE_CONFIG_PATH"),
"MAPPER_CLASS": "alerce_classifiers.mbappe.mapper.MbappeMapper",
"MIN_DETECTIONS": "6",
},
Expand All @@ -138,3 +126,51 @@ def test_step_mbappe_min_detections(
for message in kconsumer.consume():
assert_ztf_object_is_correct(message)
kconsumer.commit()


@pytest.mark.ztf
def test_step_mbappe_no_detections(
env_variables_mbappe,
):
env_variables_mbappe(
"mbappe",
"alerce_classifiers.mbappe.model.MbappeClassifier",
{
"MODEL_PATH": os.getenv("TEST_MBAPPE_MODEL_PATH"),
"QUANTILES_PATH": os.getenv("TEST_MBAPPE_QUANTILES_PATH"),
"CONFIG_PATH": os.getenv("TEST_MBAPPE_CONFIG_PATH"),
"MAPPER_CLASS": "alerce_classifiers.mbappe.mapper.MbappeMapper",
},
)

from settings import config
from .conftest import INPUT_SCHEMA_PATH, add_fields_to_message
from fastavro.utils import generate_many
from fastavro.schema import load_schema
import random

random.seed(42)
schema = load_schema(str(INPUT_SCHEMA_PATH))
messages = generate_many(schema, 2)
messages = list(messages)

for message in messages:
message = add_fields_to_message(
message,
"features_mbappe",
["g", "r"],
False,
False,
5,
)
for det in message["detections"]:
if not det["forced"]:
det["extra_fields"]["rb"] = 0.1

step = LateClassifier(config=config())

output, result_messages, features = step.execute(messages)
probabilities = output.probabilities

assert len(probabilities) == 0
assert len(features) == 0
39 changes: 39 additions & 0 deletions lc_classifier/lc_classifier/features/core/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,45 @@ def preprocess_batch(self, astro_objects: List[AstroObject], progress_bar=False)
self.preprocess_single_object(astro_object)


def discard_bogus_detections(detections: List[Dict]) -> list[dict]:
RB_THRESHOLD = 0.55

filtered_detections = []

for det in detections:
bogus = False

if "extra_fields" in det.keys():
rb = (
det["extra_fields"]["rb"]
if "rb" in det["extra_fields"].keys()
else None
)
procstatus = (
det["extra_fields"]["procstatus"]
if "procstatus" in det["extra_fields"].keys()
else None
)
else:
rb = det["rb"] if "rb" in det.keys() else None
procstatus = det["procstatus"] if "procstatus" in det.keys() else None

mask_rb = rb is not None and not det["forced"] and (rb < RB_THRESHOLD)
mask_procstatus = (
procstatus is not None
and det["forced"]
and (procstatus != "0")
and (procstatus != "57")
)
if mask_rb or mask_procstatus:
bogus = True

if not bogus:
filtered_detections.append(det)

return filtered_detections


def empty_normal_dataframe() -> pd.DataFrame:
df = pd.DataFrame(columns=["name", "value", "fid", "sid", "version"])
return df
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def compute_features_single_object(self, astro_object: AstroObject):
sg_score = metadata[metadata["name"] == "sgscore1"]["value"].values[0]
dist_nr = metadata[metadata["name"] == "distpsnr1"]["value"].values[0]

if sg_score < 0 or dist_nr < 0:
if sg_score < 0 or dist_nr < 0 or len(astro_object.detections) == 0:
sg_score = np.nan
dist_nr = np.nan

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,12 @@ def compute_features_single_object(self, astro_object: AstroObject):
detections = detections[detections["unit"] == self.unit]
detections = detections.sort_values("mjd")
forced_photometry = astro_object.forced_photometry
first_detection = detections.iloc[0]
first_detection_mjd = first_detection["mjd"]

if len(detections) > 0:
first_detection = detections.iloc[0]
first_detection_mjd = first_detection["mjd"]
else:
first_detection_mjd = np.nan

features = []
for band in self.bands:
Expand Down
11 changes: 11 additions & 0 deletions lc_classifier/lc_classifier/features/extractors/spm_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,17 @@ def compute_features_single_object(self, astro_object: AstroObject):

observations = self.get_observations(astro_object)

if len(observations) == 0:
parameters = []
chis = []
for band in self.bands:
parameters.append([np.nan] * 6)
chis.append(np.nan)

self.parameters = np.concatenate(parameters, axis=0)
self.chis = np.array(chis)
return

times = observations["mjd"].values
flux = observations["brightness"].values
e_flux = observations["e_brightness"].values
Expand Down
Loading

0 comments on commit 0fe95b2

Please sign in to comment.