Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add mask replacement of nan and negative mag errors #18

Merged
merged 1 commit into from
Jun 25, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 26 additions & 8 deletions src/rail/estimation/algos/gpz.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,11 @@
import qp


# set of magnitude errors that will replace values that are negative or np.nan
default_err_repl = [0.1, 0.1, 0.1, 0.1, 0.1, 0.1]

def _prepare_data(data_dict, bands, err_bands, nondet_val, maglims, logflag):

def _prepare_data(data_dict, bands, err_bands, nondet_val, maglims, logflag, repl_err_vals):
"""Put data in 2D np array expected by GPz.
For some reason they like to take the log of the magnitude errors, so
have that as a boolean option. Also replace nondetect vals for each
Expand All @@ -21,14 +24,17 @@ def _prepare_data(data_dict, bands, err_bands, nondet_val, maglims, logflag):
numbands = len(bands)
totrows = len(data_dict[bands[0]])
data = np.empty([totrows, 2 * numbands])
for i, (band, eband, lim) in enumerate(zip(bands, err_bands, maglims.values())):
for i, (band, eband, lim, rplval) in enumerate(zip(bands, err_bands, maglims.values(), repl_err_vals)):
data[:, i] = data_dict[band]
mask = np.bitwise_or(np.isclose(data_dict[band], nondet_val), np.isnan(data_dict[band]))
data[:, i][mask] = lim
errband = data_dict[eband]
emask = np.bitwise_or(errband <= 0., np.isnan(errband))
errband[emask] = rplval
if logflag:
data[:, numbands + i] = np.log(data_dict[eband])
data[:, numbands + i] = np.log(errband)
else: # pragma: no cover
data[:, numbands + i] = data_dict[eband]
data[:, numbands + i] = errband
data[:, numbands + i][mask] = 1.0
return data

Expand Down Expand Up @@ -63,7 +69,8 @@ class GPzInformer(CatInformer):
pca_decorrelate=Param(bool, True, msg="if True, decorrelate data using PCA as preprocessing stage"),
max_iter=Param(int, 200, msg="max number of iterations"),
max_attempt=Param(int, 100, msg="max iterations if no progress on validation"),
log_errors=Param(bool, True, msg="if true, take log of magnitude errors")
log_errors=Param(bool, True, msg="if true, take log of magnitude errors"),
replace_error_vals=Param(list, default_err_repl, msg="list of values to replace negative and nan mag err values")
)

def __init__(self, args, comm=None):
Expand All @@ -81,9 +88,14 @@ def run(self):
else: # pragma: no cover
training_data = self.get_data('input')

# check that lengths of bands, err_bands, and replace_error_vals match
if not np.logical_and(len(self.config.bands) == len(self.config.err_bands),
len(self.config.err_bands) == len(self.config.replace_error_vals)): # pragma: no cover
raise ValueError("lengths of bands, err_bands, and replace_error_vals do not match!")

input_array = _prepare_data(training_data, self.config.bands, self.config.err_bands,
self.config.nondetect_val, self.config.mag_limits,
self.config.log_errors)
self.config.log_errors, self.config.replace_error_vals)

sz = np.expand_dims(training_data[self.config.redshift_col], -1)
# need permutation mask to define training vs validation
Expand Down Expand Up @@ -128,19 +140,25 @@ class GPzEstimator(CatEstimator):
bands=SHARED_PARAMS,
err_bands=SHARED_PARAMS,
ref_band=SHARED_PARAMS,
log_errors=Param(bool, True, msg="if true, take log of magnitude errors"))
log_errors=Param(bool, True, msg="if true, take log of magnitude errors"),
replace_error_vals=Param(list, default_err_repl, msg="list of values to replace negative and nan mag err values")
)

def __init__(self, args, comm=None):
""" Constructor:
Do CatEstimator specific initialization """
CatEstimator.__init__(self, args, comm=comm)
self.zgrid = None
# check that lengths of bands, err_bands, and replace_error_vals match
if not np.logical_and(len(self.config.bands) == len(self.config.err_bands),
len(self.config.err_bands) == len(self.config.replace_error_vals)): # pragma: no cover
raise ValueError("lengths of bands, err_bands, and replace_error_vals do not match!")

def _process_chunk(self, start, end, data, first):
print(f"Process {self.rank} estimating GPz PZ PDF for rows {start:,} - {end:,}")
test_array = _prepare_data(data, self.config.bands, self.config.err_bands,
self.config.nondetect_val, self.config.mag_limits,
self.config.log_errors)
self.config.log_errors, self.config.replace_error_vals)

mu, totalV, modelV, noiseV, _ = self.model.predict(test_array)
ens = qp.Ensemble(qp.stats.norm, data=dict(loc=mu, scale=totalV))
Expand Down
Loading