From 5010cd631c3cfd581ee226aea0726ab726811781 Mon Sep 17 00:00:00 2001 From: JulienPeloton Date: Mon, 18 Nov 2024 13:47:53 +0100 Subject: [PATCH] New version for AL loop --- bin/active_learning_loop.py | 116 +++--------------------------------- 1 file changed, 8 insertions(+), 108 deletions(-) diff --git a/bin/active_learning_loop.py b/bin/active_learning_loop.py index 7cb1fb57..4572f858 100644 --- a/bin/active_learning_loop.py +++ b/bin/active_learning_loop.py @@ -21,9 +21,7 @@ import numpy as np from pyspark.sql import functions as F -from pyspark.sql.window import Window -from pyspark.sql.functions import lit, collect_list - +from fink_utils.spark.utils import extend_lc_with_upper_limits from fink_utils.spark.utils import concat_col from fink_utils.xmatch.simbad import return_list_of_eg_host @@ -105,111 +103,13 @@ def main(): for colname in to_expand: df = concat_col(df, colname, prefix=prefix) - # Use last limiting magnitude for feature extraction - # Explode each column with `posexplode_outer` to include null values, generating the same `pos` index - df_cfid = df.select("objectId", F.posexplode_outer("cfid").alias("pos", "cfid_exp")) - df_cmagpsf = df.select( - "objectId", F.posexplode_outer("cmagpsf").alias("pos", "cmagpsf_exp") - ) - df_cjd = df.select("objectId", F.posexplode_outer("cjd").alias("pos", "cjd_exp")) - df_csigmapsf = df.select( - "objectId", F.posexplode_outer("csigmapsf").alias("pos", "csigmapsf_exp") - ) - df_cdiffmaglim = df.select( - "objectId", F.posexplode_outer("cdiffmaglim").alias("pos", "cdiffmaglim_exp") - ) - - # GET DETECTIONS - # Join all exploded columns on `objectId` and `pos` to keep alignment - df_dets_exploded = ( - df_cfid.join(df_cmagpsf, ["objectId", "pos"], "outer") - .join(df_cjd, ["objectId", "pos"], "outer") - .join(df_csigmapsf, ["objectId", "pos"], "outer") - .drop("pos") # Drop `pos` if you don't need it in the final result - ) - # Get only valid detections - df_dets_exploded = df_dets_exploded.dropDuplicates() - df_dets_exploded = df_dets_exploded.filter(df_dets_exploded.cmagpsf_exp.isNotNull()) - # get minimum time and its corresponding magnitude per filter - window_spec = Window.partitionBy("objectId", "cfid_exp").orderBy("cjd_exp") - df_detection_min_cjd = ( - df_dets_exploded.withColumn("min_cjd_exp", F.first("cjd_exp").over(window_spec)) - .withColumn( - "corresponding_cmagpsf_exp", F.first("cmagpsf_exp").over(window_spec) - ) - .groupBy("objectId", "cfid_exp") - .agg( - F.min("min_cjd_exp").alias("min_cjd_exp"), - F.first("corresponding_cmagpsf_exp").alias("corresponding_cmagpsf_exp"), - ) - ) - - # GET LIMITS - df_lims_exploded = ( - df_cfid.join(df_cdiffmaglim, ["objectId", "pos"], "outer") - .join(df_cjd, ["objectId", "pos"], "outer") - .drop("pos") # Drop `pos` if you don't need it in the final result - ) - df_lims_exploded = df_lims_exploded.dropDuplicates() - # Filter to find last limit fainter than detection - df_filtered = df_lims_exploded.join( - df_detection_min_cjd, on=["objectId", "cfid_exp"], how="inner" - ) - df_filtered = df_filtered.filter( - (df_filtered.cjd_exp < df_filtered.min_cjd_exp) - & (df_filtered.cdiffmaglim_exp > df_filtered.corresponding_cmagpsf_exp) - ) - df_filtered = df_filtered.select( - "objectId", "cfid_exp", "cjd_exp", "cdiffmaglim_exp" - ) - # Define a window specification partitioned by 'objectId' and 'cfid_exp', ordered by 'cjd_exp' descending - window_spec = Window.partitionBy("objectId", "cfid_exp").orderBy(F.desc("cjd_exp")) - # Add a row number to each row within the partition - df_filtered_top_cjd = df_filtered.withColumn( - "row_num", F.row_number().over(window_spec) - ) - df_filtered_top_cjd = df_filtered_top_cjd.filter(F.col("row_num") == 1).drop( - "row_num" - ) - # Refactor column names - df_limits = df_filtered_top_cjd.withColumnRenamed("cdiffmaglim_exp", "cmagpsf_exp") - # add column for error, putting a large error 0.2 - df_limits = df_limits.withColumn("csigmapsf_exp", lit(0.2)) - - # Append detections and limits - df_appended = df_dets_exploded.unionByName(df_limits) - df_aggregated = df_appended.groupBy("objectId").agg( - collect_list("cfid_exp").alias("cfid_aggregated"), - collect_list("cjd_exp").alias("cjd_aggregated"), - collect_list("cmagpsf_exp").alias("cmagpsf_aggregated"), - collect_list("csigmapsf_exp").alias("csigmapsf_aggregated"), - ) - # Join the aggregated result back to the original DataFrame on "objectId" - df_joined = df.join(df_aggregated, on="objectId", how="left") - - df = ( - df_joined.select( - "*", # Select all original columns - "cfid_aggregated", - "cjd_aggregated", - "cmagpsf_aggregated", - "csigmapsf_aggregated", # Add the new aggregated columns - ) - .drop( - "cfid_exp", - "cjd_exp", - "cmagpsf_exp", - "csigmapsf_exp", # Drop the original exploded columns - "cfid", - "cjd", - "cmagpsf", - "csigmapsf", # Drop the original columns that will be replaced - ) - .withColumnRenamed("cfid_aggregated", "cfid") - .withColumnRenamed("cjd_aggregated", "cjd") - .withColumnRenamed("cmagpsf_aggregated", "cmagpsf") - .withColumnRenamed("csigmapsf_aggregated", "csigmapsf") + # Add the last upper limits per band if it exists + df = df.withColumn( + "ext", + extend_lc_with_upper_limits("cmagpsf", "csigmapsf", "cfid", "cdiffmaglim"), ) + df = df.withColumn("cmagpsf_ext", df["ext"].getItem("cmagpsf_ext")) + df = df.withColumn("csigmapsf_ext", df["ext"].getItem("csigmapsf_ext")) # Add classification cols = [ @@ -234,7 +134,7 @@ def main(): model = curdir + "/data/models/for_al_loop/model_20240821.pkl" # Run SN classification using AL model - rfscore_args = ["cjd", "cfid", "cmagpsf", "csigmapsf"] + rfscore_args = ["cjd", "cfid", "cmagpsf_ext", "csigmapsf_ext"] rfscore_args += [F.col("cdsxmatch"), F.col("candidate.ndethist")] rfscore_args += [F.lit(1), F.lit(3), F.lit("diff")] rfscore_args += [F.lit(model)]