New version for AL loop
JulienPeloton committed Nov 18, 2024
1 parent 842d842 commit 5010cd6
Showing 1 changed file with 8 additions and 108 deletions.
bin/active_learning_loop.py (8 additions, 108 deletions)
@@ -21,9 +21,7 @@
 import numpy as np
 
 from pyspark.sql import functions as F
-from pyspark.sql.window import Window
-from pyspark.sql.functions import lit, collect_list
-
+from fink_utils.spark.utils import extend_lc_with_upper_limits
 
 from fink_utils.spark.utils import concat_col
 from fink_utils.xmatch.simbad import return_list_of_eg_host
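
Note: the removed imports (Window, lit, collect_list) only served the hand-rolled history reconstruction deleted in the next hunk; extend_lc_with_upper_limits replaces all of it. A minimal, self-contained sketch of the explode/align/re-aggregate pattern those imports supported (toy data and column values, not from the actual file):

    from pyspark.sql import SparkSession
    from pyspark.sql import functions as F

    spark = SparkSession.builder.master("local[1]").getOrCreate()

    # Toy frame with two index-aligned array columns, as in the alert history
    df = spark.createDataFrame(
        [("a", [1.0, 2.0], [18.0, None])], ["objectId", "cjd", "cmagpsf"]
    )

    # posexplode_outer keeps null entries and emits a shared `pos` index,
    # so columns exploded separately can be re-joined in the original order
    jd = df.select("objectId", F.posexplode_outer("cjd").alias("pos", "jd"))
    mag = df.select("objectId", F.posexplode_outer("cmagpsf").alias("pos", "mag"))
    aligned = jd.join(mag, ["objectId", "pos"])

    # collect_list folds rows back into arrays (note: it silently drops nulls)
    aligned.groupBy("objectId").agg(F.collect_list("mag").alias("cmagpsf")).show()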
@@ -105,111 +103,13 @@ def main():
     for colname in to_expand:
         df = concat_col(df, colname, prefix=prefix)
 
-    # Use last limiting magnitude for feature extraction
-    # Explode each column with `posexplode_outer` to include null values, generating the same `pos` index
-    df_cfid = df.select("objectId", F.posexplode_outer("cfid").alias("pos", "cfid_exp"))
-    df_cmagpsf = df.select(
-        "objectId", F.posexplode_outer("cmagpsf").alias("pos", "cmagpsf_exp")
-    )
-    df_cjd = df.select("objectId", F.posexplode_outer("cjd").alias("pos", "cjd_exp"))
-    df_csigmapsf = df.select(
-        "objectId", F.posexplode_outer("csigmapsf").alias("pos", "csigmapsf_exp")
-    )
-    df_cdiffmaglim = df.select(
-        "objectId", F.posexplode_outer("cdiffmaglim").alias("pos", "cdiffmaglim_exp")
-    )
-
-    # GET DETECTIONS
-    # Join all exploded columns on `objectId` and `pos` to keep alignment
-    df_dets_exploded = (
-        df_cfid.join(df_cmagpsf, ["objectId", "pos"], "outer")
-        .join(df_cjd, ["objectId", "pos"], "outer")
-        .join(df_csigmapsf, ["objectId", "pos"], "outer")
-        .drop("pos")  # Drop `pos` if you don't need it in the final result
-    )
-    # Get only valid detections
-    df_dets_exploded = df_dets_exploded.dropDuplicates()
-    df_dets_exploded = df_dets_exploded.filter(df_dets_exploded.cmagpsf_exp.isNotNull())
-    # get minimum time and its corresponding magnitude per filter
-    window_spec = Window.partitionBy("objectId", "cfid_exp").orderBy("cjd_exp")
-    df_detection_min_cjd = (
-        df_dets_exploded.withColumn("min_cjd_exp", F.first("cjd_exp").over(window_spec))
-        .withColumn(
-            "corresponding_cmagpsf_exp", F.first("cmagpsf_exp").over(window_spec)
-        )
-        .groupBy("objectId", "cfid_exp")
-        .agg(
-            F.min("min_cjd_exp").alias("min_cjd_exp"),
-            F.first("corresponding_cmagpsf_exp").alias("corresponding_cmagpsf_exp"),
-        )
-    )
-
-    # GET LIMITS
-    df_lims_exploded = (
-        df_cfid.join(df_cdiffmaglim, ["objectId", "pos"], "outer")
-        .join(df_cjd, ["objectId", "pos"], "outer")
-        .drop("pos")  # Drop `pos` if you don't need it in the final result
-    )
-    df_lims_exploded = df_lims_exploded.dropDuplicates()
-    # Filter to find last limit fainter than detection
-    df_filtered = df_lims_exploded.join(
-        df_detection_min_cjd, on=["objectId", "cfid_exp"], how="inner"
-    )
-    df_filtered = df_filtered.filter(
-        (df_filtered.cjd_exp < df_filtered.min_cjd_exp)
-        & (df_filtered.cdiffmaglim_exp > df_filtered.corresponding_cmagpsf_exp)
-    )
-    df_filtered = df_filtered.select(
-        "objectId", "cfid_exp", "cjd_exp", "cdiffmaglim_exp"
-    )
-    # Define a window specification partitioned by 'objectId' and 'cfid_exp', ordered by 'cjd_exp' descending
-    window_spec = Window.partitionBy("objectId", "cfid_exp").orderBy(F.desc("cjd_exp"))
-    # Add a row number to each row within the partition
-    df_filtered_top_cjd = df_filtered.withColumn(
-        "row_num", F.row_number().over(window_spec)
-    )
-    df_filtered_top_cjd = df_filtered_top_cjd.filter(F.col("row_num") == 1).drop(
-        "row_num"
-    )
-    # Refactor column names
-    df_limits = df_filtered_top_cjd.withColumnRenamed("cdiffmaglim_exp", "cmagpsf_exp")
-    # add column for error, putting a large error 0.2
-    df_limits = df_limits.withColumn("csigmapsf_exp", lit(0.2))
-
-    # Append detections and limits
-    df_appended = df_dets_exploded.unionByName(df_limits)
-    df_aggregated = df_appended.groupBy("objectId").agg(
-        collect_list("cfid_exp").alias("cfid_aggregated"),
-        collect_list("cjd_exp").alias("cjd_aggregated"),
-        collect_list("cmagpsf_exp").alias("cmagpsf_aggregated"),
-        collect_list("csigmapsf_exp").alias("csigmapsf_aggregated"),
-    )
-    # Join the aggregated result back to the original DataFrame on "objectId"
-    df_joined = df.join(df_aggregated, on="objectId", how="left")
-
-    df = (
-        df_joined.select(
-            "*",  # Select all original columns
-            "cfid_aggregated",
-            "cjd_aggregated",
-            "cmagpsf_aggregated",
-            "csigmapsf_aggregated",  # Add the new aggregated columns
-        )
-        .drop(
-            "cfid_exp",
-            "cjd_exp",
-            "cmagpsf_exp",
-            "csigmapsf_exp",  # Drop the original exploded columns
-            "cfid",
-            "cjd",
-            "cmagpsf",
-            "csigmapsf",  # Drop the original columns that will be replaced
-        )
-        .withColumnRenamed("cfid_aggregated", "cfid")
-        .withColumnRenamed("cjd_aggregated", "cjd")
-        .withColumnRenamed("cmagpsf_aggregated", "cmagpsf")
-        .withColumnRenamed("csigmapsf_aggregated", "csigmapsf")
+    # Add the last upper limit per band, if it exists
+    df = df.withColumn(
+        "ext",
+        extend_lc_with_upper_limits("cmagpsf", "csigmapsf", "cfid", "cdiffmaglim"),
     )
+    df = df.withColumn("cmagpsf_ext", df["ext"].getItem("cmagpsf_ext"))
+    df = df.withColumn("csigmapsf_ext", df["ext"].getItem("csigmapsf_ext"))
 
     # Add classification
     cols = [
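
For reference, the rule the removed block implemented, and that extend_lc_with_upper_limits is now expected to encapsulate: per band, find the first valid detection, take the last prior upper limit (cdiffmaglim) fainter than that detection, and append it to the light curve with a fixed 0.2 mag error. A plain-NumPy sketch of that rule (the function name here is hypothetical; the helper's actual internals are not shown in this diff):

    import numpy as np

    def extend_with_last_upper_limit(jd, mag, sig, fid, diffmaglim, err=0.2):
        """Per band: find the first valid detection, then append the last
        prior upper limit fainter than it, with a fixed magnitude error."""
        jd, mag, fid, dml = map(np.asarray, (jd, mag, fid, diffmaglim))
        out = [list(jd), list(mag), list(sig), list(fid)]
        for band in np.unique(fid):
            det = (fid == band) & ~np.isnan(mag)
            if not det.any():
                continue
            i0 = np.flatnonzero(det)[np.argmin(jd[det])]  # first detection
            # upper-limit epochs: no measured mag, earlier than the first
            # detection, and limiting magnitude fainter than that detection
            lim = (fid == band) & np.isnan(mag) & (jd < jd[i0]) & (dml > mag[i0])
            if lim.any():
                i = np.flatnonzero(lim)[np.argmax(jd[lim])]  # most recent limit
                for lst, val in zip(out, (jd[i], dml[i], err, band)):
                    lst.append(val)
        return out  # extended cjd, cmagpsf, csigmapsf, cfid

For example, with jd=[1, 2, 3], mag=[nan, nan, 18.0], fid=[1, 1, 1] and diffmaglim=[19.5, 17.0, nan], only the epoch-1 limit qualifies (19.5 > 18.0), so one point (jd=1, mag=19.5, sigma=0.2) is appended.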
@@ -234,7 +134,7 @@
     model = curdir + "/data/models/for_al_loop/model_20240821.pkl"
 
     # Run SN classification using AL model
-    rfscore_args = ["cjd", "cfid", "cmagpsf", "csigmapsf"]
+    rfscore_args = ["cjd", "cfid", "cmagpsf_ext", "csigmapsf_ext"]
     rfscore_args += [F.col("cdsxmatch"), F.col("candidate.ndethist")]
     rfscore_args += [F.lit(1), F.lit(3), F.lit("diff")]
     rfscore_args += [F.lit(model)]
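
The scorer is thus fed the extended arrays (cmagpsf_ext, csigmapsf_ext) instead of the raw ones, so the appended upper limit participates in feature extraction. A hedged fragment of how these args are typically consumed downstream (the scorer name, import path, and output column follow Fink's SN module conventions and are assumptions, not shown in this diff):

    # Assumed continuation, not part of this commit: apply the AL model
    from fink_science.random_forest_snia.processor import rfscore_sigmoid_full

    df = df.withColumn("rf_snia_vs_nonia", rfscore_sigmoid_full(*rfscore_args))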