
Commit

adding last limit per filter
anaismoller committed Nov 13, 2024
1 parent f42f439 commit 842d842
Showing 1 changed file with 110 additions and 1 deletion.
111 changes: 110 additions & 1 deletion bin/active_learning_loop.py
@@ -21,6 +21,9 @@
import numpy as np

from pyspark.sql import functions as F
from pyspark.sql.window import Window
from pyspark.sql.functions import lit, collect_list


from fink_utils.spark.utils import concat_col
from fink_utils.xmatch.simbad import return_list_of_eg_host
@@ -95,13 +98,119 @@ def main():
    df = load_parquet_files(path)

    # Retrieve time-series information
    to_expand = ["jd", "fid", "magpsf", "sigmapsf"]
    to_expand = ["jd", "fid", "magpsf", "sigmapsf", "diffmaglim"]

    # Append temp columns with historical + current measurements
    prefix = "c"
    for colname in to_expand:
        df = concat_col(df, colname, prefix=prefix)

    # Use last limiting magnitude for feature extraction
    # Explode each column with `posexplode_outer` to include null values, generating the same `pos` index
    df_cfid = df.select("objectId", F.posexplode_outer("cfid").alias("pos", "cfid_exp"))
    df_cmagpsf = df.select(
        "objectId", F.posexplode_outer("cmagpsf").alias("pos", "cmagpsf_exp")
    )
    df_cjd = df.select("objectId", F.posexplode_outer("cjd").alias("pos", "cjd_exp"))
    df_csigmapsf = df.select(
        "objectId", F.posexplode_outer("csigmapsf").alias("pos", "csigmapsf_exp")
    )
    df_cdiffmaglim = df.select(
        "objectId", F.posexplode_outer("cdiffmaglim").alias("pos", "cdiffmaglim_exp")
    )

    # GET DETECTIONS
    # Join all exploded columns on `objectId` and `pos` to keep alignment
    df_dets_exploded = (
        df_cfid.join(df_cmagpsf, ["objectId", "pos"], "outer")
        .join(df_cjd, ["objectId", "pos"], "outer")
        .join(df_csigmapsf, ["objectId", "pos"], "outer")
        .drop("pos")  # Drop `pos` if you don't need it in the final result
    )
    # Get only valid detections
    df_dets_exploded = df_dets_exploded.dropDuplicates()
    df_dets_exploded = df_dets_exploded.filter(df_dets_exploded.cmagpsf_exp.isNotNull())
    # Get minimum time and its corresponding magnitude per filter
    window_spec = Window.partitionBy("objectId", "cfid_exp").orderBy("cjd_exp")
    df_detection_min_cjd = (
        df_dets_exploded.withColumn("min_cjd_exp", F.first("cjd_exp").over(window_spec))
        .withColumn(
            "corresponding_cmagpsf_exp", F.first("cmagpsf_exp").over(window_spec)
        )
        .groupBy("objectId", "cfid_exp")
        .agg(
            F.min("min_cjd_exp").alias("min_cjd_exp"),
            F.first("corresponding_cmagpsf_exp").alias("corresponding_cmagpsf_exp"),
        )
    )

    # GET LIMITS
    df_lims_exploded = (
        df_cfid.join(df_cdiffmaglim, ["objectId", "pos"], "outer")
        .join(df_cjd, ["objectId", "pos"], "outer")
        .drop("pos")  # Drop `pos` if you don't need it in the final result
    )
    df_lims_exploded = df_lims_exploded.dropDuplicates()
    # Filter to find last limit fainter than detection
    df_filtered = df_lims_exploded.join(
        df_detection_min_cjd, on=["objectId", "cfid_exp"], how="inner"
    )
    df_filtered = df_filtered.filter(
        (df_filtered.cjd_exp < df_filtered.min_cjd_exp)
        & (df_filtered.cdiffmaglim_exp > df_filtered.corresponding_cmagpsf_exp)
    )
    df_filtered = df_filtered.select(
        "objectId", "cfid_exp", "cjd_exp", "cdiffmaglim_exp"
    )
    # Define a window specification partitioned by 'objectId' and 'cfid_exp',
    # ordered by 'cjd_exp' descending
    window_spec = Window.partitionBy("objectId", "cfid_exp").orderBy(F.desc("cjd_exp"))
    # Add a row number to each row within the partition
    df_filtered_top_cjd = df_filtered.withColumn(
        "row_num", F.row_number().over(window_spec)
    )
    df_filtered_top_cjd = df_filtered_top_cjd.filter(F.col("row_num") == 1).drop(
        "row_num"
    )
    # Rename the limit column so limits can be appended to the detections
    df_limits = df_filtered_top_cjd.withColumnRenamed("cdiffmaglim_exp", "cmagpsf_exp")
    # Add an error column, assigning a large error of 0.2 mag to the limits
    df_limits = df_limits.withColumn("csigmapsf_exp", lit(0.2))

    # Append detections and limits
    df_appended = df_dets_exploded.unionByName(df_limits)
    df_aggregated = df_appended.groupBy("objectId").agg(
        collect_list("cfid_exp").alias("cfid_aggregated"),
        collect_list("cjd_exp").alias("cjd_aggregated"),
        collect_list("cmagpsf_exp").alias("cmagpsf_aggregated"),
        collect_list("csigmapsf_exp").alias("csigmapsf_aggregated"),
    )
    # Join the aggregated result back to the original DataFrame on "objectId"
    df_joined = df.join(df_aggregated, on="objectId", how="left")

    # df_joined already carries the aggregated columns from the join,
    # so only drop the replaced columns and rename the aggregated ones
    df = (
        df_joined.drop(
            "cfid",
            "cjd",
            "cmagpsf",
            "csigmapsf",  # Drop the original columns that are replaced
        )
        .withColumnRenamed("cfid_aggregated", "cfid")
        .withColumnRenamed("cjd_aggregated", "cjd")
        .withColumnRenamed("cmagpsf_aggregated", "cmagpsf")
        .withColumnRenamed("csigmapsf_aggregated", "csigmapsf")
    )

    # Add classification
    cols = [
        "cdsxmatch",
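Below is a minimal, self-contained sketch of the selection logic in this commit, for illustration only: for each object and filter, find the earliest detection, keep the upper limits that are both earlier than that detection and fainter than its magnitude, and retain the latest such limit. It assumes a local pyspark session; the toy object name, dates, and magnitudes are made up, and the column names simply mirror the exploded columns in the diff.

from pyspark.sql import SparkSession
from pyspark.sql import functions as F
from pyspark.sql.window import Window

spark = SparkSession.builder.master("local[1]").appName("last_limit_sketch").getOrCreate()

# Toy alerts for one object in one filter: two upper limits, then two detections
rows = [
    ("ZTF_toy", 1, 2459000.0, None, 20.5),  # upper limit
    ("ZTF_toy", 1, 2459002.0, None, 19.0),  # latest upper limit before detection
    ("ZTF_toy", 1, 2459004.0, 18.2, None),  # first detection
    ("ZTF_toy", 1, 2459006.0, 17.9, None),  # later detection
]
schema = "objectId string, cfid_exp int, cjd_exp double, cmagpsf_exp double, cdiffmaglim_exp double"
df_toy = spark.createDataFrame(rows, schema)

# Earliest detection time and its magnitude, per object and filter
dets = df_toy.filter(F.col("cmagpsf_exp").isNotNull())
w_asc = Window.partitionBy("objectId", "cfid_exp").orderBy("cjd_exp")
first_det = (
    dets.withColumn("min_cjd_exp", F.first("cjd_exp").over(w_asc))
    .withColumn("corresponding_cmagpsf_exp", F.first("cmagpsf_exp").over(w_asc))
    .groupBy("objectId", "cfid_exp")
    .agg(
        F.min("min_cjd_exp").alias("min_cjd_exp"),
        F.first("corresponding_cmagpsf_exp").alias("corresponding_cmagpsf_exp"),
    )
)

# Upper limits taken before the first detection and fainter than its magnitude
lims = df_toy.filter(F.col("cdiffmaglim_exp").isNotNull()).join(
    first_det, on=["objectId", "cfid_exp"], how="inner"
)
lims = lims.filter(
    (lims.cjd_exp < lims.min_cjd_exp)
    & (lims.cdiffmaglim_exp > lims.corresponding_cmagpsf_exp)
)

# Keep only the latest such limit per object and filter
w_desc = Window.partitionBy("objectId", "cfid_exp").orderBy(F.desc("cjd_exp"))
last_limit = (
    lims.withColumn("row_num", F.row_number().over(w_desc))
    .filter(F.col("row_num") == 1)
    .select("objectId", "cfid_exp", "cjd_exp", "cdiffmaglim_exp")
)

last_limit.show()
# Expected: the jd = 2459002.0 limit at 19.0 mag, i.e. the latest limit taken
# before the first detection that is fainter than that detection (18.2 mag).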
