diff --git a/src/baskerville/dataset.py b/src/baskerville/dataset.py index 785e8cc..c8360c3 100644 --- a/src/baskerville/dataset.py +++ b/src/baskerville/dataset.py @@ -377,7 +377,7 @@ def untransform_preds(preds, targets_df, unscale=False, unclip=True): preds (np.array): Untransformed predictions LxT. """ # clip soft - if unclip : + if unclip: cs = np.expand_dims(np.array(targets_df.clip_soft), axis=0) preds_unclip = cs - 1 + (preds - cs + 1) ** 2 preds = np.where(preds > cs, preds_unclip, preds) @@ -409,7 +409,7 @@ def untransform_preds1(preds, targets_df, unscale=False, unclip=True): preds = preds / scale # clip soft - if unclip : + if unclip: cs = np.expand_dims(np.array(targets_df.clip_soft), axis=0) preds_unclip = cs + (preds - cs) ** 2 preds = np.where(preds > cs, preds_unclip, preds)