From a3a4320201c67ed03dc4abc0e6998e69c6d531ac Mon Sep 17 00:00:00 2001 From: Johannes Linder Date: Fri, 3 May 2024 13:19:44 -0700 Subject: [PATCH] Added option to ignore soft-clip when undoing transforms. --- src/baskerville/dataset.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/src/baskerville/dataset.py b/src/baskerville/dataset.py index e127e42..785e8cc 100644 --- a/src/baskerville/dataset.py +++ b/src/baskerville/dataset.py @@ -366,7 +366,7 @@ def targets_prep_strand(targets_df): return targets_strand_df -def untransform_preds(preds, targets_df, unscale=False): +def untransform_preds(preds, targets_df, unscale=False, unclip=True): """Undo the squashing transformations performed for the tasks. Args: @@ -377,9 +377,10 @@ def untransform_preds(preds, targets_df, unscale=False): preds (np.array): Untransformed predictions LxT. """ # clip soft - cs = np.expand_dims(np.array(targets_df.clip_soft), axis=0) - preds_unclip = cs - 1 + (preds - cs + 1) ** 2 - preds = np.where(preds > cs, preds_unclip, preds) + if unclip : + cs = np.expand_dims(np.array(targets_df.clip_soft), axis=0) + preds_unclip = cs - 1 + (preds - cs + 1) ** 2 + preds = np.where(preds > cs, preds_unclip, preds) # sqrt sqrt_mask = np.array([ss.find("_sqrt") != -1 for ss in targets_df.sum_stat]) @@ -393,7 +394,7 @@ def untransform_preds(preds, targets_df, unscale=False): return preds -def untransform_preds1(preds, targets_df, unscale=False): +def untransform_preds1(preds, targets_df, unscale=False, unclip=True): """Undo the squashing transformations performed for the tasks. Args: @@ -408,9 +409,10 @@ def untransform_preds1(preds, targets_df, unscale=False): preds = preds / scale # clip soft - cs = np.expand_dims(np.array(targets_df.clip_soft), axis=0) - preds_unclip = cs + (preds - cs) ** 2 - preds = np.where(preds > cs, preds_unclip, preds) + if unclip : + cs = np.expand_dims(np.array(targets_df.clip_soft), axis=0) + preds_unclip = cs + (preds - cs) ** 2 + preds = np.where(preds > cs, preds_unclip, preds) # ** 0.75 sqrt_mask = np.array([ss.find("_sqrt") != -1 for ss in targets_df.sum_stat])