Skip to content

Commit

Permalink
Refactor ppe (arviz-devs#579)
Browse files Browse the repository at this point in the history
* refactor ppe

* lint
  • Loading branch information
aloctavodia authored Nov 1, 2024
1 parent 84d301f commit 551dfcc
Show file tree
Hide file tree
Showing 8 changed files with 204 additions and 113 deletions.
31 changes: 21 additions & 10 deletions preliz/internal/optimization.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,30 +253,41 @@ def interval_short(params):


def optimize_pymc_model(
fmodel, target, draws, prior, initial_guess, bounds, var_info, p_model, rng
fmodel,
target,
num_draws,
bounds,
initial_guess,
prior,
preliz_model,
transformed_var_info,
rng,
):
for _ in range(400):
for idx in range(401):
# can we sample systematically from these and less random?
# This should be more flexible and allow other targets than just
# a preliz distribution
# a PreliZ distribution
if isinstance(target, list):
obs = get_weighted_rvs(target, draws, rng)
obs = get_weighted_rvs(target, num_draws, rng)
else:
obs = target.rvs(draws, random_state=rng)
obs = target.rvs(num_draws, random_state=rng)
result = minimize(
fmodel,
initial_guess,
tol=0.001,
method="SLSQP",
args=(obs, var_info, p_model),
args=(obs, transformed_var_info, preliz_model),
bounds=bounds,
)

optimal_params = result.x
# To help minimize the effect of priors
# We don't save the first result and insteas we use it as the initial guess
# for the next optimization
# Updating the initial guess also helps to provides more spread samples
initial_guess = optimal_params

for key, param in zip(prior.keys(), optimal_params):
prior[key].append(param)
if idx:
for key, param in zip(prior.keys(), optimal_params):
prior[key].append(param)

# convert to numpy arrays
for key, value in prior.items():
Expand Down
2 changes: 1 addition & 1 deletion preliz/internal/predictive_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from .distribution_helper import get_distributions


def back_fitting(model, subset, new_families=True):
def back_fitting_ppa(model, subset, new_families=True):
"""
Use MLE to fit a subset of the prior samples to the marginal prior distributions
"""
Expand Down
27 changes: 17 additions & 10 deletions preliz/ppls/agnostic.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from preliz.distributions import Gamma, Normal, HalfNormal
from preliz.unidimensional.mle import mle
from preliz.ppls.pymc_io import get_model_information, write_pymc_string
from preliz.ppls.bambi_io import get_bmb_model_information, write_bambi_string
from preliz.ppls.bambi_io import get_pymc_model, write_bambi_string

_log = logging.getLogger("preliz")

Expand Down Expand Up @@ -41,10 +41,23 @@ def posterior_to_prior(model, idata, alternative=None, engine="auto"):
"""
_log.info(""""This is an experimental method under development, use with caution.""")
engine = get_engine(model) if engine == "auto" else engine

if engine == "bambi":
_, _, model_info, _, var_info2, *_ = get_bmb_model_information(model)
model = get_pymc_model(model)

_, _, preliz_model, _, untransformed_var_info, *_ = get_model_information(model)

new_priors = back_fitting_idata(idata, preliz_model, alternative)

if engine == "bambi":
new_model = write_bambi_string(new_priors, untransformed_var_info)
elif engine == "pymc":
_, _, model_info, _, var_info2, *_ = get_model_information(model)
new_model = write_pymc_string(new_priors, untransformed_var_info)

return new_model


def back_fitting_idata(idata, model_info, alternative):
new_priors = {}
posterior = idata.posterior.stack(sample=("chain", "draw"))

Expand All @@ -66,10 +79,4 @@ def posterior_to_prior(model, idata, alternative=None, engine="auto"):

idx, _ = mle(dists, posterior[var].values, plot=False)
new_priors[var] = dists[idx[0]]

if engine == "bambi":
new_model = write_bambi_string(new_priors, var_info2)
elif engine == "pymc":
new_model = write_pymc_string(new_priors, var_info2)

return new_model
return new_priors
12 changes: 6 additions & 6 deletions preliz/ppls/bambi_io.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,27 @@
from preliz.ppls.pymc_io import get_model_information
"""Functions to communicate with Bambi."""


def get_bmb_model_information(model):
def get_pymc_model(model):
if not model.built:
model.build()
pymc_model = model.backend.model
return get_model_information(pymc_model)
return pymc_model


def write_bambi_string(new_priors, var_info):
"""
Return a string with the new priors for the Bambi model.
So the user can copy and paste, ideally with none to minimal changes.
"""
header = "{"
header = "{\n"
for key, value in new_priors.items():
dist_name, dist_params = repr(value).split("(")
dist_params = dist_params.rstrip(")")
size = var_info[key][1]
if size > 1:
header += f'"{key}" : bmb.Prior("{dist_name}", {dist_params}, shape={size}), '
header += f'"{key}" : bmb.Prior("{dist_name}", {dist_params}, shape={size}),\n'
else:
header += f'"{key}" : bmb.Prior("{dist_name}", {dist_params}), '
header += f'"{key}" : bmb.Prior("{dist_name}", {dist_params}),\n'

header = header.rstrip(", ") + "}"
return header
67 changes: 45 additions & 22 deletions preliz/ppls/pymc_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,29 +17,29 @@
from preliz.internal.distribution_helper import get_distributions


def backfitting(prior, p_model, var_info2):
def back_fitting_pymc(prior, preliz_model, untransformed_var_info):
"""
Fit the samples from prior into user provided model's prior.
from the perspective of ppe "prior" is actually an approximated posterior
but from the users perspective is its prior.
We need to "backfitted" because we can not use arbitrary samples as priors.
We need to "backfit" because we can not use arbitrary samples as priors.
We need probability distributions.
"""
new_priors = {}
for key, size_inf in var_info2.items():
for key, size_inf in untransformed_var_info.items():
if not size_inf[2]:
size = size_inf[1]
if size > 1:
params = []
for i in range(size):
value = prior[f"{key}__{i}"]
dist = p_model[key]
dist = preliz_model[key]
dist._fit_mle(value)
params.append(dist.params)
dist._parametrization(*[np.array(x) for x in zip(*params)])
else:
value = prior[key]
dist = p_model[key]
dist = preliz_model[key]
dist._fit_mle(value)

new_priors[key] = dist
Expand Down Expand Up @@ -81,7 +81,7 @@ def get_pymc_to_preliz():
return pymc_to_preliz


def get_guess(model, free_rvs):
def get_initial_guess(model, free_rvs):
"""
Get initial guess for optimization routine.
"""
Expand All @@ -104,17 +104,32 @@ def get_guess(model, free_rvs):

def get_model_information(model): # pylint: disable=too-many-locals
"""
Get information from the PyMC model.
This needs some love. We even have a variable named var_info,
and another one var_info2!
Get information from a PyMC model.
Parameters
----------
model : a PyMC model
Returns
-------
bounds : a list of tuples with the support of each marginal distribution in the model
prior : a dictionary with a key for each marginal distribution in the model and an empty
list as value. This will be filled with the samples from a backfitting procedure.
preliz_model : a dictionary with a key for each marginal distribution in the model and the
corresponding PreliZ distribution as value
transformed_var_info : a dictionary with a key for each transformed variable in the model
and a tuple with the shape, size and the indexes of the non-constant parents as value
untransformed_var_info : same as `transformed_var_info` but the keys are untransformed
variable names
num_draws : the number of observed samples
free_rvs : a list with the free random variables in the model
"""

bounds = []
prior = {}
p_model = {}
var_info = {}
var_info2 = {}
preliz_model = {}
transformed_var_info = {}
untransformed_var_info = {}
free_rvs = []
pymc_to_preliz = get_pymc_to_preliz()
rvs_to_values = model.rvs_to_values
Expand All @@ -128,13 +143,13 @@ def get_model_information(model): # pylint: disable=too-many-locals
r_v.owner.op.name if r_v.owner.op.name else str(r_v.owner.op).split("RV", 1)[0].lower()
)
dist = copy(pymc_to_preliz[name])
p_model[r_v.name] = dist
preliz_model[r_v.name] = dist
if nc_parents:
idxs = [free_rvs.index(var_) for var_ in nc_parents]
# the keys are the name of the (transformed) variable
var_info[rvs_to_values[r_v].name] = (shape, size, idxs)
transformed_var_info[rvs_to_values[r_v].name] = (shape, size, idxs)
# the keys are the name of the (untransformed) variable
var_info2[r_v.name] = (shape, size, idxs)
untransformed_var_info[r_v.name] = (shape, size, idxs)
else:
free_rvs.append(r_v)

Expand All @@ -147,13 +162,21 @@ def get_model_information(model): # pylint: disable=too-many-locals
prior[r_v.name] = []

# the keys are the name of the (transformed) variable
var_info[rvs_to_values[r_v].name] = (shape, size, nc_parents)
transformed_var_info[rvs_to_values[r_v].name] = (shape, size, nc_parents)
# the keys are the name of the (untransformed) variable
var_info2[r_v.name] = (shape, size, nc_parents)

draws = model.observed_RVs[0].eval().size

return bounds, prior, p_model, var_info, var_info2, draws, free_rvs
untransformed_var_info[r_v.name] = (shape, size, nc_parents)

num_draws = model.observed_RVs[0].eval().size

return (
bounds,
prior,
preliz_model,
transformed_var_info,
untransformed_var_info,
num_draws,
free_rvs,
)


def write_pymc_string(new_priors, var_info):
Expand Down
4 changes: 2 additions & 2 deletions preliz/predictive/ppa.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
plot_pp_mean,
)
from ..internal.parser import get_prior_pp_samples, from_preliz, from_bambi
from ..internal.predictive_helper import back_fitting, select_prior_samples
from ..internal.predictive_helper import back_fitting_ppa, select_prior_samples
from ..distributions import Normal
from ..distributions.distributions import Distribution

Expand Down Expand Up @@ -386,7 +386,7 @@ def on_return_prior(self):
if len(selected) > 4:
subsample = select_prior_samples(selected, self.prior_samples, self.model)

string, _ = back_fitting(self.model, subsample, new_families=False)
string, _ = back_fitting_ppa(self.model, subsample, new_families=False)

self.fig.clf()
plt.text(0.05, 0.5, string, fontsize=14)
Expand Down
Loading

0 comments on commit 551dfcc

Please sign in to comment.