diff --git a/bin/active_learning_loop.py b/bin/active_learning_loop.py
index fbf199e7..7cb1fb57 100644
--- a/bin/active_learning_loop.py
+++ b/bin/active_learning_loop.py
@@ -21,6 +21,9 @@
 import numpy as np
 
 from pyspark.sql import functions as F
+from pyspark.sql.window import Window
+from pyspark.sql.functions import lit, collect_list
+
 from fink_utils.spark.utils import concat_col
 from fink_utils.xmatch.simbad import return_list_of_eg_host
@@ -95,13 +98,119 @@ def main():
     df = load_parquet_files(path)
 
     # Retrieve time-series information
-    to_expand = ["jd", "fid", "magpsf", "sigmapsf"]
+    to_expand = ["jd", "fid", "magpsf", "sigmapsf", "diffmaglim"]
 
     # Append temp columns with historical + current measurements
     prefix = "c"
     for colname in to_expand:
         df = concat_col(df, colname, prefix=prefix)
 
+    # Use the last limiting magnitude for feature extraction.
+    # Explode each column with `posexplode_outer` to keep null values,
+    # generating a shared `pos` index across columns
+    df_cfid = df.select("objectId", F.posexplode_outer("cfid").alias("pos", "cfid_exp"))
+    df_cmagpsf = df.select(
+        "objectId", F.posexplode_outer("cmagpsf").alias("pos", "cmagpsf_exp")
+    )
+    df_cjd = df.select("objectId", F.posexplode_outer("cjd").alias("pos", "cjd_exp"))
+    df_csigmapsf = df.select(
+        "objectId", F.posexplode_outer("csigmapsf").alias("pos", "csigmapsf_exp")
+    )
+    df_cdiffmaglim = df.select(
+        "objectId", F.posexplode_outer("cdiffmaglim").alias("pos", "cdiffmaglim_exp")
+    )
+
+    # GET DETECTIONS
+    # Join all exploded columns on `objectId` and `pos` to keep measurements aligned
+    df_dets_exploded = (
+        df_cfid.join(df_cmagpsf, ["objectId", "pos"], "outer")
+        .join(df_cjd, ["objectId", "pos"], "outer")
+        .join(df_csigmapsf, ["objectId", "pos"], "outer")
+        .drop("pos")  # `pos` is no longer needed after the joins
+    )
+    # Keep only valid detections: drop duplicates and null magnitudes
+    df_dets_exploded = df_dets_exploded.dropDuplicates()
+    df_dets_exploded = df_dets_exploded.filter(df_dets_exploded.cmagpsf_exp.isNotNull())
+    # Get the earliest detection time and its corresponding magnitude per filter
+    window_spec = Window.partitionBy("objectId", "cfid_exp").orderBy("cjd_exp")
+    df_detection_min_cjd = (
+        df_dets_exploded.withColumn("min_cjd_exp", F.first("cjd_exp").over(window_spec))
+        .withColumn(
+            "corresponding_cmagpsf_exp", F.first("cmagpsf_exp").over(window_spec)
+        )
+        .groupBy("objectId", "cfid_exp")
+        .agg(
+            F.min("min_cjd_exp").alias("min_cjd_exp"),
+            F.first("corresponding_cmagpsf_exp").alias("corresponding_cmagpsf_exp"),
+        )
+    )
+
+    # GET LIMITS
+    df_lims_exploded = (
+        df_cfid.join(df_cdiffmaglim, ["objectId", "pos"], "outer")
+        .join(df_cjd, ["objectId", "pos"], "outer")
+        .drop("pos")  # `pos` is no longer needed after the joins
+    )
+    df_lims_exploded = df_lims_exploded.dropDuplicates()
+    # Keep limits taken before the first detection and fainter than it
+    df_filtered = df_lims_exploded.join(
+        df_detection_min_cjd, on=["objectId", "cfid_exp"], how="inner"
+    )
+    df_filtered = df_filtered.filter(
+        (df_filtered.cjd_exp < df_filtered.min_cjd_exp)
+        & (df_filtered.cdiffmaglim_exp > df_filtered.corresponding_cmagpsf_exp)
+    )
+    df_filtered = df_filtered.select(
+        "objectId", "cfid_exp", "cjd_exp", "cdiffmaglim_exp"
+    )
+    # Define a window partitioned by `objectId` and `cfid_exp`, ordered by `cjd_exp` descending
+    window_spec = Window.partitionBy("objectId", "cfid_exp").orderBy(F.desc("cjd_exp"))
+    # Add a row number to each row within the partition
+    df_filtered_top_cjd = df_filtered.withColumn(
+        "row_num", F.row_number().over(window_spec)
+    )
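+    # Keep only the most recent qualifying limit (row_num == 1) per object and filter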
+    df_filtered_top_cjd = df_filtered_top_cjd.filter(F.col("row_num") == 1).drop(
+        "row_num"
+    )
+    # Rename the limit column to match the detection schema
+    df_limits = df_filtered_top_cjd.withColumnRenamed("cdiffmaglim_exp", "cmagpsf_exp")
+    # Add an error column, assigning a conservatively large error of 0.2 mag
+    df_limits = df_limits.withColumn("csigmapsf_exp", lit(0.2))
+
+    # Append detections and limits
+    df_appended = df_dets_exploded.unionByName(df_limits)
+    # NOTE: collect_list does not guarantee element ordering across partitions
+    df_aggregated = df_appended.groupBy("objectId").agg(
+        collect_list("cfid_exp").alias("cfid_aggregated"),
+        collect_list("cjd_exp").alias("cjd_aggregated"),
+        collect_list("cmagpsf_exp").alias("cmagpsf_aggregated"),
+        collect_list("csigmapsf_exp").alias("csigmapsf_aggregated"),
+    )
+    # Join the aggregated result back to the original DataFrame on "objectId"
+    df_joined = df.join(df_aggregated, on="objectId", how="left")
+
+    # Replace the original array columns with the aggregated ones. Re-selecting
+    # the aggregated columns on top of "*" would duplicate them, so drop the
+    # originals and rename the aggregated columns instead.
+    df = (
+        df_joined.drop("cfid", "cjd", "cmagpsf", "csigmapsf")
+        .withColumnRenamed("cfid_aggregated", "cfid")
+        .withColumnRenamed("cjd_aggregated", "cjd")
+        .withColumnRenamed("cmagpsf_aggregated", "cmagpsf")
+        .withColumnRenamed("csigmapsf_aggregated", "csigmapsf")
+    )
+
     # Add classification
     cols = [
         "cdsxmatch",